Fork of the espurna firmware for `mhsw` switches
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

265 lines
7.0 KiB

  1. /*
  2. ASYNC CLIENT OTA MODULE
  3. Copyright (C) 2016-2019 by Xose Pérez <xose dot perez at gmail dot com>
  4. */
  5. #if OTA_CLIENT == OTA_CLIENT_ASYNCTCP
  6. // -----------------------------------------------------------------------------
  7. // Terminal and MQTT OTA command handlers
  8. // -----------------------------------------------------------------------------
  9. #if TERMINAL_SUPPORT || OTA_MQTT_SUPPORT
  10. #include <Schedule.h>
  11. #include <ESPAsyncTCP.h>
  12. #include "system.h"
  13. #include "ota.h"
  14. #include "libs/URL.h"
  15. const char OTA_REQUEST_TEMPLATE[] PROGMEM =
  16. "GET %s HTTP/1.1\r\n"
  17. "Host: %s\r\n"
  18. "User-Agent: ESPurna\r\n"
  19. "Connection: close\r\n\r\n";
  20. struct ota_client_t {
  21. enum state_t {
  22. HEADERS,
  23. DATA,
  24. END
  25. };
  26. ota_client_t() = delete;
  27. ota_client_t(const ota_client_t&) = delete;
  28. ota_client_t(URL&& url);
  29. bool connect();
  30. state_t state = HEADERS;
  31. size_t size = 0;
  32. const URL url;
  33. AsyncClient client;
  34. };
  35. std::unique_ptr<ota_client_t> _ota_client = nullptr;
  36. // -----------------------------------------------------------------------------
  37. void _otaClientDisconnect() {
  38. DEBUG_MSG_P(PSTR("[OTA] Disconnected\n"));
  39. _ota_client = nullptr;
  40. }
  41. void _otaClientOnDisconnect(void* arg, AsyncClient* client) {
  42. DEBUG_MSG_P(PSTR("\n"));
  43. otaFinalize(reinterpret_cast<ota_client_t*>(arg)->size, CUSTOM_RESET_OTA, true);
  44. schedule_function(_otaClientDisconnect);
  45. }
  46. void _otaClientOnTimeout(void*, AsyncClient * client, uint32_t) {
  47. client->close(true);
  48. }
  49. void _otaClientOnError(void*, AsyncClient* client, err_t error) {
  50. DEBUG_MSG_P(PSTR("[OTA] ERROR: %s\n"), client->errorToString(error));
  51. }
  52. void _otaClientOnData(void* arg, AsyncClient* client, void* data, size_t len) {
  53. ota_client_t* ota_client = reinterpret_cast<ota_client_t*>(arg);
  54. auto* ptr = (char *) data;
  55. // TODO: check status
  56. // TODO: check for 3xx, discover `Location:` header and schedule
  57. // another _otaClientFrom(location_header_url)
  58. if (_ota_client->state == ota_client_t::HEADERS) {
  59. ptr = (char *) strnstr((char *) data, "\r\n\r\n", len);
  60. if (!ptr) {
  61. return;
  62. }
  63. auto diff = ptr - ((char *) data);
  64. _ota_client->state = ota_client_t::DATA;
  65. len -= diff + 4;
  66. if (!len) {
  67. return;
  68. }
  69. ptr += 4;
  70. }
  71. if (ota_client->state == ota_client_t::DATA) {
  72. if (!ota_client->size) {
  73. // Check header before anything is written to the flash
  74. if (!otaVerifyHeader((uint8_t *) ptr, len)) {
  75. DEBUG_MSG_P(PSTR("[OTA] ERROR: No magic byte / invalid flash config"));
  76. client->close(true);
  77. ota_client->state = ota_client_t::END;
  78. return;
  79. }
  80. // XXX: In case of non-chunked response, really parse headers and specify size via content-length value
  81. Update.runAsync(true);
  82. if (!Update.begin((ESP.getFreeSketchSpace() - 0x1000) & 0xFFFFF000)) {
  83. otaPrintError();
  84. client->close(true);
  85. return;
  86. }
  87. }
  88. // We can enter this callback even after client->close()
  89. if (!Update.isRunning()) {
  90. return;
  91. }
  92. if (Update.write((uint8_t *) ptr, len) != len) {
  93. otaPrintError();
  94. client->close(true);
  95. ota_client->state = ota_client_t::END;
  96. return;
  97. }
  98. ota_client->size += len;
  99. otaProgress(ota_client->size);
  100. delay(0);
  101. }
  102. }
  103. void _otaClientOnConnect(void* arg, AsyncClient* client) {
  104. ota_client_t* ota_client = reinterpret_cast<ota_client_t*>(arg);
  105. #if ASYNC_TCP_SSL_ENABLED
  106. int check = getSetting("otaScCheck", OTA_SECURE_CLIENT_CHECK).toInt();
  107. if ((check == SECURE_CLIENT_CHECK_FINGERPRINT) && (443 == _ota_url->port)) {
  108. uint8_t fp[20] = {0};
  109. sslFingerPrintArray(getSetting("otaFP", OTA_FINGERPRINT).c_str(), fp);
  110. SSL * ssl = client->getSSL();
  111. if (ssl_match_fingerprint(ssl, fp) != SSL_OK) {
  112. DEBUG_MSG_P(PSTR("[OTA] Warning: certificate fingerpint doesn't match\n"));
  113. client->close(true);
  114. return;
  115. }
  116. }
  117. #endif
  118. // Disabling EEPROM rotation to prevent writing to EEPROM after the upgrade
  119. eepromRotate(false);
  120. DEBUG_MSG_P(PSTR("[OTA] Downloading %s\n"), ota_client->url.path.c_str());
  121. char buffer[strlen_P(OTA_REQUEST_TEMPLATE) + ota_client->url.path.length() + ota_client->url.host.length()];
  122. snprintf_P(buffer, sizeof(buffer), OTA_REQUEST_TEMPLATE, ota_client->url.path.c_str(), ota_client->url.host.c_str());
  123. client->write(buffer);
  124. }
  125. ota_client_t::ota_client_t(URL&& url) :
  126. url(std::move(url))
  127. {
  128. client.setRxTimeout(5);
  129. client.onError(_otaClientOnError, nullptr);
  130. client.onTimeout(_otaClientOnTimeout, nullptr);
  131. client.onDisconnect(_otaClientOnDisconnect, this);
  132. client.onData(_otaClientOnData, this);
  133. client.onConnect(_otaClientOnConnect, this);
  134. }
  135. bool ota_client_t::connect() {
  136. #if ASYNC_TCP_SSL_ENABLED
  137. return client.connect(url.host.c_str(), url.port, 443 == url.port);
  138. #else
  139. return client.connect(url.host.c_str(), url.port);
  140. #endif
  141. }
  142. // -----------------------------------------------------------------------------
  143. void _otaClientFrom(const String& url) {
  144. if (_ota_client) {
  145. DEBUG_MSG_P(PSTR("[OTA] Already connected\n"));
  146. return;
  147. }
  148. URL _url(url);
  149. if (!_url.protocol.equals("http") && !_url.protocol.equals("https")) {
  150. DEBUG_MSG_P(PSTR("[OTA] Incorrect URL specified\n"));
  151. return;
  152. }
  153. _ota_client = std::make_unique<ota_client_t>(std::move(_url));
  154. if (!_ota_client->connect()) {
  155. DEBUG_MSG_P(PSTR("[OTA] Connection failed\n"));
  156. }
  157. }
  158. #endif // TERMINAL_SUPPORT || OTA_MQTT_SUPPORT
  #if TERMINAL_SUPPORT
  // Registers the `OTA <url>` terminal command, which starts a download
  // through _otaClientFrom.
  void _otaClientInitCommands() {
      terminalRegisterCommand(F("OTA"), [](Embedis* e) {
          if (e->argc >= 2) {
              const String target(e->argv[1]);
              _otaClientFrom(target);
              terminalOK();
              return;
          }
          terminalError(F("OTA <url>"));
      });
  }
  #endif // TERMINAL_SUPPORT
  #if OTA_MQTT_SUPPORT
  // MQTT hook: subscribes to the OTA topic on connect; a message on that topic
  // is treated as a firmware URL and handed to _otaClientFrom.
  void _otaClientMqttCallback(unsigned int type, const char * topic, const char * payload) {
      if (type == MQTT_CONNECT_EVENT) {
          mqttSubscribe(MQTT_TOPIC_OTA);
      }

      if (type != MQTT_MESSAGE_EVENT) {
          return;
      }

      const String magnitude = mqttMagnitude(const_cast<char *>(topic));
      if (!magnitude.equals(MQTT_TOPIC_OTA)) {
          return;
      }

      DEBUG_MSG_P(PSTR("[OTA] Initiating from URL: %s\n"), payload);
      _otaClientFrom(payload);
  }
  #endif // OTA_MQTT_SUPPORT
  185. // -----------------------------------------------------------------------------
  186. void otaClientSetup() {
  187. // Backwards compatibility
  188. moveSetting("otafp", "otaFP");
  189. #if TERMINAL_SUPPORT
  190. _otaClientInitCommands();
  191. #endif
  192. #if (MQTT_SUPPORT && OTA_MQTT_SUPPORT)
  193. mqttRegister(_otaClientMqttCallback);
  194. #endif
  195. }
  196. #endif // OTA_CLIENT == OTA_CLIENT_ASYNCTCP