Fork of the espurna firmware for `mhsw` switches
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

273 lines
7.1 KiB

7 years ago
  1. /*
  2. OTA MODULE
  3. Copyright (C) 2016-2018 by Xose Pérez <xose dot perez at gmail dot com>
  4. Module key prefix: ota
  5. */
  6. #include "ArduinoOTA.h"
  7. // -----------------------------------------------------------------------------
  8. // Arduino OTA
  9. // -----------------------------------------------------------------------------
  10. void _otaConfigure() {
  11. ArduinoOTA.setPort(OTA_PORT);
  12. ArduinoOTA.setHostname(getHostname().c_str());
  13. #if USE_PASSWORD
  14. ArduinoOTA.setPassword(getPassword().c_str());
  15. #endif
  16. }
  17. void _otaLoop() {
  18. ArduinoOTA.handle();
  19. }
  20. // -----------------------------------------------------------------------------
  21. // Terminal OTA
  22. // -----------------------------------------------------------------------------
  23. #if TERMINAL_SUPPORT
  24. #include <ESPAsyncTCP.h>
  25. AsyncClient * _ota_client;
  26. char * _ota_host;
  27. char * _ota_url;
  28. unsigned int _ota_port = 80;
  29. unsigned long _ota_size = 0;
  30. const char OTA_REQUEST_TEMPLATE[] PROGMEM =
  31. "GET %s HTTP/1.1\r\n"
  32. "Host: %s\r\n"
  33. "User-Agent: ESPurna\r\n"
  34. "Connection: close\r\n"
  35. "Content-Type: application/x-www-form-urlencoded\r\n"
  36. "Content-Length: 0\r\n\r\n\r\n";
  37. void _otaFrom(const char * host, unsigned int port, const char * url) {
  38. if (_ota_host) free(_ota_host);
  39. if (_ota_url) free(_ota_url);
  40. _ota_host = strdup(host);
  41. _ota_url = strdup(url);
  42. _ota_port = port;
  43. _ota_size = 0;
  44. if (_ota_client == NULL) {
  45. _ota_client = new AsyncClient();
  46. }
  47. _ota_client->onDisconnect([](void *s, AsyncClient *c) {
  48. DEBUG_MSG_P(PSTR("\n"));
  49. if (Update.end(true)){
  50. DEBUG_MSG_P(PSTR("[OTA] Success: %u bytes\n"), _ota_size);
  51. deferredReset(100, CUSTOM_RESET_OTA);
  52. } else {
  53. DEBUG_MSG_P(PSTR("[OTA] Error #%u\n"), Update.getError());
  54. eepromRotate(true);
  55. }
  56. DEBUG_MSG_P(PSTR("[OTA] Disconnected\n"));
  57. _ota_client->free();
  58. delete _ota_client;
  59. _ota_client = NULL;
  60. free(_ota_host);
  61. _ota_host = NULL;
  62. free(_ota_url);
  63. _ota_url = NULL;
  64. }, 0);
  65. _ota_client->onTimeout([](void *s, AsyncClient *c, uint32_t time) {
  66. _ota_client->close(true);
  67. }, 0);
  68. _ota_client->onData([](void * arg, AsyncClient * c, void * data, size_t len) {
  69. char * p = (char *) data;
  70. if (_ota_size == 0) {
  71. Update.runAsync(true);
  72. if (!Update.begin((ESP.getFreeSketchSpace() - 0x1000) & 0xFFFFF000)) {
  73. DEBUG_MSG_P(PSTR("[OTA] Error #%u\n"), Update.getError());
  74. }
  75. p = strstr((char *)data, "\r\n\r\n") + 4;
  76. len = len - (p - (char *) data);
  77. }
  78. if (!Update.hasError()) {
  79. if (Update.write((uint8_t *) p, len) != len) {
  80. DEBUG_MSG_P(PSTR("[OTA] Error #%u\n"), Update.getError());
  81. }
  82. }
  83. _ota_size += len;
  84. DEBUG_MSG_P(PSTR("[OTA] Progress: %u bytes\r"), _ota_size);
  85. }, NULL);
  86. _ota_client->onConnect([](void * arg, AsyncClient * client) {
  87. #if ASYNC_TCP_SSL_ENABLED
  88. if (443 == _ota_port) {
  89. uint8_t fp[20] = {0};
  90. sslFingerPrintArray(getSetting("otaFP", OTA_GITHUB_FP).c_str(), fp);
  91. SSL * ssl = _ota_client->getSSL();
  92. if (ssl_match_fingerprint(ssl, fp) != SSL_OK) {
  93. DEBUG_MSG_P(PSTR("[OTA] Warning: certificate doesn't match\n"));
  94. }
  95. }
  96. #endif
  97. // Disabling EEPROM rotation to prevent writing to EEPROM after the upgrade
  98. eepromRotate(false);
  99. DEBUG_MSG_P(PSTR("[OTA] Downloading %s\n"), _ota_url);
  100. char buffer[strlen_P(OTA_REQUEST_TEMPLATE) + strlen(_ota_url) + strlen(_ota_host)];
  101. snprintf_P(buffer, sizeof(buffer), OTA_REQUEST_TEMPLATE, _ota_url, _ota_host);
  102. client->write(buffer);
  103. }, NULL);
  104. #if ASYNC_TCP_SSL_ENABLED
  105. bool connected = _ota_client->connect(host, port, 443 == port);
  106. #else
  107. bool connected = _ota_client->connect(host, port);
  108. #endif
  109. if (!connected) {
  110. DEBUG_MSG_P(PSTR("[OTA] Connection failed\n"));
  111. _ota_client->close(true);
  112. }
  113. }
  114. void _otaFrom(String url) {
  115. // Port from protocol
  116. unsigned int port = 80;
  117. if (url.startsWith("https://")) port = 443;
  118. url = url.substring(url.indexOf("/") + 2);
  119. // Get host
  120. String host = url.substring(0, url.indexOf("/"));
  121. // Explicit port
  122. int p = host.indexOf(":");
  123. if (p > 0) {
  124. port = host.substring(p + 1).toInt();
  125. host = host.substring(0, p);
  126. }
  127. // Get URL
  128. String uri = url.substring(url.indexOf("/"));
  129. _otaFrom(host.c_str(), port, uri.c_str());
  130. }
  131. void _otaInitCommands() {
  132. settingsRegisterCommand(F("OTA"), [](Embedis* e) {
  133. if (e->argc < 2) {
  134. DEBUG_MSG_P(PSTR("-ERROR: Wrong arguments\n"));
  135. } else {
  136. DEBUG_MSG_P(PSTR("+OK\n"));
  137. String url = String(e->argv[1]);
  138. _otaFrom(url);
  139. }
  140. });
  141. }
  142. #endif // TERMINAL_SUPPORT
  143. bool _otaKeyCheck(const char * key) {
  144. return (strncmp(key, "ota", 3) == 0);
  145. }
  146. void _otaBackwards() {
  147. moveSetting("otafs", "otaFS");
  148. }
  149. // -----------------------------------------------------------------------------
  150. void otaSetup() {
  151. _otaBackwards();
  152. _otaConfigure();
  153. #if WEB_SUPPORT
  154. wsOnAfterParseRegister(_otaConfigure);
  155. #endif
  156. #if TERMINAL_SUPPORT
  157. _otaInitCommands();
  158. #endif
  159. // Register settings key check
  160. settingsRegisterKeyCheck(_otaKeyCheck);
  161. // Register loop
  162. espurnaRegisterLoop(_otaLoop);
  163. // -------------------------------------------------------------------------
  164. ArduinoOTA.onStart([]() {
  165. // Disabling EEPROM rotation to prevent writing to EEPROM after the upgrade
  166. eepromRotate(false);
  167. DEBUG_MSG_P(PSTR("[OTA] Start\n"));
  168. #if WEB_SUPPORT
  169. wsSend_P(PSTR("{\"message\": 2}"));
  170. #endif
  171. });
  172. ArduinoOTA.onEnd([]() {
  173. DEBUG_MSG_P(PSTR("\n"));
  174. DEBUG_MSG_P(PSTR("[OTA] Done, restarting...\n"));
  175. #if WEB_SUPPORT
  176. wsSend_P(PSTR("{\"action\": \"reload\"}"));
  177. #endif
  178. deferredReset(100, CUSTOM_RESET_OTA);
  179. });
  180. ArduinoOTA.onProgress([](unsigned int progress, unsigned int total) {
  181. static unsigned int _progOld;
  182. unsigned int _prog = (progress / (total / 100));
  183. if (_prog != _progOld) {
  184. DEBUG_MSG_P(PSTR("[OTA] Progress: %u%%\r"), _prog);
  185. _progOld = _prog;
  186. }
  187. });
  188. ArduinoOTA.onError([](ota_error_t error) {
  189. #if DEBUG_SUPPORT
  190. DEBUG_MSG_P(PSTR("\n[OTA] Error #%u: "), error);
  191. if (error == OTA_AUTH_ERROR) DEBUG_MSG_P(PSTR("Auth Failed\n"));
  192. else if (error == OTA_BEGIN_ERROR) DEBUG_MSG_P(PSTR("Begin Failed\n"));
  193. else if (error == OTA_CONNECT_ERROR) DEBUG_MSG_P(PSTR("Connect Failed\n"));
  194. else if (error == OTA_RECEIVE_ERROR) DEBUG_MSG_P(PSTR("Receive Failed\n"));
  195. else if (error == OTA_END_ERROR) DEBUG_MSG_P(PSTR("End Failed\n"));
  196. #endif
  197. eepromRotate(true);
  198. });
  199. ArduinoOTA.begin();
  200. }