Fork of the espurna firmware for `mhsw` switches
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

275 lines
7.1 KiB

7 years ago
  1. /*
  2. OTA MODULE
  3. Copyright (C) 2016-2018 by Xose Pérez <xose dot perez at gmail dot com>
  4. Module key prefix: ota
  5. */
  6. #include "ArduinoOTA.h"
  7. // -----------------------------------------------------------------------------
  8. // Arduino OTA
  9. // -----------------------------------------------------------------------------
  10. void _otaConfigure() {
  11. ArduinoOTA.setPort(OTA_PORT);
  12. ArduinoOTA.setHostname(getHostname().c_str());
  13. #if USE_PASSWORD
  14. ArduinoOTA.setPassword(getPassword().c_str());
  15. #endif
  16. }
  17. void _otaLoop() {
  18. ArduinoOTA.handle();
  19. }
  20. // -----------------------------------------------------------------------------
  21. // Terminal OTA
  22. // -----------------------------------------------------------------------------
  23. #if TERMINAL_SUPPORT
  24. #include <ESPAsyncTCP.h>
  25. AsyncClient * _ota_client;
  26. char * _ota_host;
  27. char * _ota_url;
  28. unsigned int _ota_port = 80;
  29. unsigned long _ota_size = 0;
  30. const char OTA_REQUEST_TEMPLATE[] PROGMEM =
  31. "GET %s HTTP/1.1\r\n"
  32. "Host: %s\r\n"
  33. "User-Agent: ESPurna\r\n"
  34. "Connection: close\r\n"
  35. "Content-Type: application/x-www-form-urlencoded\r\n"
  36. "Content-Length: 0\r\n\r\n\r\n";
  37. void _otaFrom(const char * host, unsigned int port, const char * url) {
  38. if (_ota_host) free(_ota_host);
  39. if (_ota_url) free(_ota_url);
  40. _ota_host = strdup(host);
  41. _ota_url = strdup(url);
  42. _ota_port = port;
  43. _ota_size = 0;
  44. if (_ota_client == NULL) {
  45. _ota_client = new AsyncClient();
  46. }
  47. _ota_client->onDisconnect([](void *s, AsyncClient *c) {
  48. DEBUG_MSG_P(PSTR("\n"));
  49. if (Update.end(true)){
  50. DEBUG_MSG_P(PSTR("[OTA] Success: %u bytes\n"), _ota_size);
  51. deferredReset(100, CUSTOM_RESET_OTA);
  52. } else {
  53. #ifdef DEBUG_PORT
  54. Update.printError(DEBUG_PORT);
  55. #endif
  56. eepromRotate(true);
  57. }
  58. DEBUG_MSG_P(PSTR("[OTA] Disconnected\n"));
  59. _ota_client->free();
  60. delete _ota_client;
  61. _ota_client = NULL;
  62. free(_ota_host);
  63. _ota_host = NULL;
  64. free(_ota_url);
  65. _ota_url = NULL;
  66. }, 0);
  67. _ota_client->onTimeout([](void *s, AsyncClient *c, uint32_t time) {
  68. _ota_client->close(true);
  69. }, 0);
  70. _ota_client->onData([](void * arg, AsyncClient * c, void * data, size_t len) {
  71. char * p = (char *) data;
  72. if (_ota_size == 0) {
  73. Update.runAsync(true);
  74. if (!Update.begin((ESP.getFreeSketchSpace() - 0x1000) & 0xFFFFF000)) {
  75. #ifdef DEBUG_PORT
  76. Update.printError(DEBUG_PORT);
  77. #endif
  78. }
  79. p = strstr((char *)data, "\r\n\r\n") + 4;
  80. len = len - (p - (char *) data);
  81. }
  82. if (!Update.hasError()) {
  83. if (Update.write((uint8_t *) p, len) != len) {
  84. #ifdef DEBUG_PORT
  85. Update.printError(DEBUG_PORT);
  86. #endif
  87. }
  88. }
  89. _ota_size += len;
  90. DEBUG_MSG_P(PSTR("[OTA] Progress: %u bytes\r"), _ota_size);
  91. }, NULL);
  92. _ota_client->onConnect([](void * arg, AsyncClient * client) {
  93. #if ASYNC_TCP_SSL_ENABLED
  94. if (443 == _ota_port) {
  95. uint8_t fp[20] = {0};
  96. sslFingerPrintArray(getSetting("otaFP", OTA_GITHUB_FP).c_str(), fp);
  97. SSL * ssl = _ota_client->getSSL();
  98. if (ssl_match_fingerprint(ssl, fp) != SSL_OK) {
  99. DEBUG_MSG_P(PSTR("[OTA] Warning: certificate doesn't match\n"));
  100. }
  101. }
  102. #endif
  103. // Disabling EEPROM rotation to prevent writing to EEPROM after the upgrade
  104. eepromRotate(false);
  105. DEBUG_MSG_P(PSTR("[OTA] Downloading %s\n"), _ota_url);
  106. char buffer[strlen_P(OTA_REQUEST_TEMPLATE) + strlen(_ota_url) + strlen(_ota_host)];
  107. snprintf_P(buffer, sizeof(buffer), OTA_REQUEST_TEMPLATE, _ota_url, _ota_host);
  108. client->write(buffer);
  109. }, NULL);
  110. #if ASYNC_TCP_SSL_ENABLED
  111. bool connected = _ota_client->connect(host, port, 443 == port);
  112. #else
  113. bool connected = _ota_client->connect(host, port);
  114. #endif
  115. if (!connected) {
  116. DEBUG_MSG_P(PSTR("[OTA] Connection failed\n"));
  117. _ota_client->close(true);
  118. }
  119. }
  120. void _otaFrom(String url) {
  121. // Port from protocol
  122. unsigned int port = 80;
  123. if (url.startsWith("https://")) port = 443;
  124. url = url.substring(url.indexOf("/") + 2);
  125. // Get host
  126. String host = url.substring(0, url.indexOf("/"));
  127. // Explicit port
  128. int p = host.indexOf(":");
  129. if (p > 0) {
  130. port = host.substring(p + 1).toInt();
  131. host = host.substring(0, p);
  132. }
  133. // Get URL
  134. String uri = url.substring(url.indexOf("/"));
  135. _otaFrom(host.c_str(), port, uri.c_str());
  136. }
  137. void _otaInitCommands() {
  138. settingsRegisterCommand(F("OTA"), [](Embedis* e) {
  139. if (e->argc < 2) {
  140. DEBUG_MSG_P(PSTR("-ERROR: Wrong arguments\n"));
  141. } else {
  142. DEBUG_MSG_P(PSTR("+OK\n"));
  143. String url = String(e->argv[1]);
  144. _otaFrom(url);
  145. }
  146. });
  147. }
  148. #endif // TERMINAL_SUPPORT
  149. #if WEB_SUPPORT
  150. bool _otaWebSocketOnReceive(const char * key, JsonVariant& value) {
  151. return (strncmp(key, "ota", 3) == 0);
  152. }
  153. #endif // WEB_SUPPORT
  154. void _otaBackwards() {
  155. moveSetting("otafs", "otaFS");
  156. }
  157. // -----------------------------------------------------------------------------
  158. void otaSetup() {
  159. _otaBackwards();
  160. _otaConfigure();
  161. #if WEB_SUPPORT
  162. wsOnAfterParseRegister(_otaConfigure);
  163. wsOnReceiveRegister(_otaWebSocketOnReceive);
  164. #endif
  165. #if TERMINAL_SUPPORT
  166. _otaInitCommands();
  167. #endif
  168. // Register loop
  169. espurnaRegisterLoop(_otaLoop);
  170. // -------------------------------------------------------------------------
  171. ArduinoOTA.onStart([]() {
  172. // Disabling EEPROM rotation to prevent writing to EEPROM after the upgrade
  173. eepromRotate(false);
  174. DEBUG_MSG_P(PSTR("[OTA] Start\n"));
  175. #if WEB_SUPPORT
  176. wsSend_P(PSTR("{\"message\": 2}"));
  177. #endif
  178. });
  179. ArduinoOTA.onEnd([]() {
  180. DEBUG_MSG_P(PSTR("\n"));
  181. DEBUG_MSG_P(PSTR("[OTA] Done, restarting...\n"));
  182. #if WEB_SUPPORT
  183. wsSend_P(PSTR("{\"action\": \"reload\"}"));
  184. #endif
  185. deferredReset(100, CUSTOM_RESET_OTA);
  186. });
  187. ArduinoOTA.onProgress([](unsigned int progress, unsigned int total) {
  188. DEBUG_MSG_P(PSTR("[OTA] Progress: %u%%\r"), (progress / (total / 100)));
  189. });
  190. ArduinoOTA.onError([](ota_error_t error) {
  191. #if DEBUG_SUPPORT
  192. DEBUG_MSG_P(PSTR("\n[OTA] Error #%u: "), error);
  193. if (error == OTA_AUTH_ERROR) DEBUG_MSG_P(PSTR("Auth Failed\n"));
  194. else if (error == OTA_BEGIN_ERROR) DEBUG_MSG_P(PSTR("Begin Failed\n"));
  195. else if (error == OTA_CONNECT_ERROR) DEBUG_MSG_P(PSTR("Connect Failed\n"));
  196. else if (error == OTA_RECEIVE_ERROR) DEBUG_MSG_P(PSTR("Receive Failed\n"));
  197. else if (error == OTA_END_ERROR) DEBUG_MSG_P(PSTR("End Failed\n"));
  198. #endif
  199. eepromRotate(true);
  200. });
  201. ArduinoOTA.begin();
  202. }