Fork of the espurna firmware for `mhsw` switches
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

295 lines
7.7 KiB

7 years ago
  1. /*
  2. OTA MODULE
  3. Copyright (C) 2016-2018 by Xose Pérez <xose dot perez at gmail dot com>
  4. */
  5. #include "ArduinoOTA.h"
  6. // -----------------------------------------------------------------------------
  7. // Arduino OTA
  8. // -----------------------------------------------------------------------------
  9. void _otaConfigure() {
  10. ArduinoOTA.setPort(OTA_PORT);
  11. ArduinoOTA.setHostname(getSetting("hostname").c_str());
  12. #if USE_PASSWORD
  13. ArduinoOTA.setPassword(getAdminPass().c_str());
  14. #endif
  15. }
  16. void _otaLoop() {
  17. ArduinoOTA.handle();
  18. }
  19. // -----------------------------------------------------------------------------
  20. // Terminal OTA
  21. // -----------------------------------------------------------------------------
  22. #if TERMINAL_SUPPORT || OTA_MQTT_SUPPORT
  23. #include <ESPAsyncTCP.h>
  24. AsyncClient * _ota_client;
  25. char * _ota_host;
  26. char * _ota_url;
  27. unsigned int _ota_port = 80;
  28. unsigned long _ota_size = 0;
  29. const char OTA_REQUEST_TEMPLATE[] PROGMEM =
  30. "GET %s HTTP/1.1\r\n"
  31. "Host: %s\r\n"
  32. "User-Agent: ESPurna\r\n"
  33. "Connection: close\r\n"
  34. "Content-Type: application/x-www-form-urlencoded\r\n"
  35. "Content-Length: 0\r\n\r\n\r\n";
  36. void _otaFrom(const char * host, unsigned int port, const char * url) {
  37. if (_ota_host) free(_ota_host);
  38. if (_ota_url) free(_ota_url);
  39. _ota_host = strdup(host);
  40. _ota_url = strdup(url);
  41. _ota_port = port;
  42. _ota_size = 0;
  43. if (_ota_client == NULL) {
  44. _ota_client = new AsyncClient();
  45. }
  46. _ota_client->onDisconnect([](void *s, AsyncClient *c) {
  47. DEBUG_MSG_P(PSTR("\n"));
  48. if (Update.end(true)){
  49. DEBUG_MSG_P(PSTR("[OTA] Success: %u bytes\n"), _ota_size);
  50. deferredReset(100, CUSTOM_RESET_OTA);
  51. } else {
  52. #ifdef DEBUG_PORT
  53. Update.printError(DEBUG_PORT);
  54. #endif
  55. eepromRotate(true);
  56. }
  57. DEBUG_MSG_P(PSTR("[OTA] Disconnected\n"));
  58. _ota_client->free();
  59. delete _ota_client;
  60. _ota_client = NULL;
  61. free(_ota_host);
  62. _ota_host = NULL;
  63. free(_ota_url);
  64. _ota_url = NULL;
  65. }, 0);
  66. _ota_client->onTimeout([](void *s, AsyncClient *c, uint32_t time) {
  67. _ota_client->close(true);
  68. }, 0);
  69. _ota_client->onData([](void * arg, AsyncClient * c, void * data, size_t len) {
  70. char * p = (char *) data;
  71. if (_ota_size == 0) {
  72. Update.runAsync(true);
  73. if (!Update.begin((ESP.getFreeSketchSpace() - 0x1000) & 0xFFFFF000)) {
  74. #ifdef DEBUG_PORT
  75. Update.printError(DEBUG_PORT);
  76. #endif
  77. }
  78. p = strstr((char *)data, "\r\n\r\n") + 4;
  79. len = len - (p - (char *) data);
  80. }
  81. if (!Update.hasError()) {
  82. if (Update.write((uint8_t *) p, len) != len) {
  83. #ifdef DEBUG_PORT
  84. Update.printError(DEBUG_PORT);
  85. #endif
  86. }
  87. }
  88. _ota_size += len;
  89. DEBUG_MSG_P(PSTR("[OTA] Progress: %u bytes\r"), _ota_size);
  90. delay(0);
  91. }, NULL);
  92. _ota_client->onConnect([](void * arg, AsyncClient * client) {
  93. #if ASYNC_TCP_SSL_ENABLED
  94. if (443 == _ota_port) {
  95. uint8_t fp[20] = {0};
  96. sslFingerPrintArray(getSetting("otafp", OTA_GITHUB_FP).c_str(), fp);
  97. SSL * ssl = _ota_client->getSSL();
  98. if (ssl_match_fingerprint(ssl, fp) != SSL_OK) {
  99. DEBUG_MSG_P(PSTR("[OTA] Warning: certificate doesn't match\n"));
  100. }
  101. }
  102. #endif
  103. // Disabling EEPROM rotation to prevent writing to EEPROM after the upgrade
  104. eepromRotate(false);
  105. DEBUG_MSG_P(PSTR("[OTA] Downloading %s\n"), _ota_url);
  106. char buffer[strlen_P(OTA_REQUEST_TEMPLATE) + strlen(_ota_url) + strlen(_ota_host)];
  107. snprintf_P(buffer, sizeof(buffer), OTA_REQUEST_TEMPLATE, _ota_url, _ota_host);
  108. client->write(buffer);
  109. }, NULL);
  110. #if ASYNC_TCP_SSL_ENABLED
  111. bool connected = _ota_client->connect(host, port, 443 == port);
  112. #else
  113. bool connected = _ota_client->connect(host, port);
  114. #endif
  115. if (!connected) {
  116. DEBUG_MSG_P(PSTR("[OTA] Connection failed\n"));
  117. _ota_client->close(true);
  118. }
  119. }
  120. void _otaFrom(String url) {
  121. if (!url.startsWith("http://") && !url.startsWith("https://")) {
  122. DEBUG_MSG_P(PSTR("[OTA] Incorrect URL specified\n"));
  123. return;
  124. }
  125. // Port from protocol
  126. unsigned int port = 80;
  127. if (url.startsWith("https://")) port = 443;
  128. url = url.substring(url.indexOf("/") + 2);
  129. // Get host
  130. String host = url.substring(0, url.indexOf("/"));
  131. // Explicit port
  132. int p = host.indexOf(":");
  133. if (p > 0) {
  134. port = host.substring(p + 1).toInt();
  135. host = host.substring(0, p);
  136. }
  137. // Get URL
  138. String uri = url.substring(url.indexOf("/"));
  139. _otaFrom(host.c_str(), port, uri.c_str());
  140. }
  141. #endif // TERMINAL_SUPPORT || OTA_MQTT_SUPPORT
  142. #if TERMINAL_SUPPORT
  143. void _otaInitCommands() {
  144. terminalRegisterCommand(F("OTA"), [](Embedis* e) {
  145. if (e->argc < 2) {
  146. terminalError(F("Wrong arguments"));
  147. } else {
  148. terminalOK();
  149. String url = String(e->argv[1]);
  150. _otaFrom(url);
  151. }
  152. });
  153. }
  154. #endif // TERMINAL_SUPPORT
  155. #if OTA_MQTT_SUPPORT
  156. void _otaMQTTCallback(unsigned int type, const char * topic, const char * payload) {
  157. if (type == MQTT_CONNECT_EVENT) {
  158. mqttSubscribe(MQTT_TOPIC_OTA);
  159. }
  160. if (type == MQTT_MESSAGE_EVENT) {
  161. // Match topic
  162. String t = mqttMagnitude((char *) topic);
  163. if (t.equals(MQTT_TOPIC_OTA)) {
  164. DEBUG_MSG_P(PSTR("[OTA] Initiating from URL: %s\n"), payload);
  165. _otaFrom(payload);
  166. }
  167. }
  168. }
  169. #endif // OTA_MQTT_SUPPORT
  170. // -----------------------------------------------------------------------------
  171. void otaSetup() {
  172. _otaConfigure();
  173. #if TERMINAL_SUPPORT
  174. _otaInitCommands();
  175. #endif
  176. #if OTA_MQTT_SUPPORT
  177. mqttRegister(_otaMQTTCallback);
  178. #endif
  179. // Main callbacks
  180. espurnaRegisterLoop(_otaLoop);
  181. espurnaRegisterReload(_otaConfigure);
  182. // -------------------------------------------------------------------------
  183. ArduinoOTA.onStart([]() {
  184. // Disabling EEPROM rotation to prevent writing to EEPROM after the upgrade
  185. eepromRotate(false);
  186. DEBUG_MSG_P(PSTR("[OTA] Start\n"));
  187. #if WEB_SUPPORT
  188. wsSend_P(PSTR("{\"message\": 2}"));
  189. #endif
  190. });
  191. ArduinoOTA.onEnd([]() {
  192. DEBUG_MSG_P(PSTR("\n"));
  193. DEBUG_MSG_P(PSTR("[OTA] Done, restarting...\n"));
  194. #if WEB_SUPPORT
  195. wsSend_P(PSTR("{\"action\": \"reload\"}"));
  196. #endif
  197. deferredReset(100, CUSTOM_RESET_OTA);
  198. });
  199. ArduinoOTA.onProgress([](unsigned int progress, unsigned int total) {
  200. static unsigned int _progOld;
  201. unsigned int _prog = (progress / (total / 100));
  202. if (_prog != _progOld) {
  203. DEBUG_MSG_P(PSTR("[OTA] Progress: %u%%\r"), _prog);
  204. _progOld = _prog;
  205. }
  206. });
  207. ArduinoOTA.onError([](ota_error_t error) {
  208. #if DEBUG_SUPPORT
  209. DEBUG_MSG_P(PSTR("\n[OTA] Error #%u: "), error);
  210. if (error == OTA_AUTH_ERROR) DEBUG_MSG_P(PSTR("Auth Failed\n"));
  211. else if (error == OTA_BEGIN_ERROR) DEBUG_MSG_P(PSTR("Begin Failed\n"));
  212. else if (error == OTA_CONNECT_ERROR) DEBUG_MSG_P(PSTR("Connect Failed\n"));
  213. else if (error == OTA_RECEIVE_ERROR) DEBUG_MSG_P(PSTR("Receive Failed\n"));
  214. else if (error == OTA_END_ERROR) DEBUG_MSG_P(PSTR("End Failed\n"));
  215. #endif
  216. eepromRotate(true);
  217. });
  218. ArduinoOTA.begin();
  219. }