Fork of the espurna firmware for `mhsw` switches
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

267 lines
7.0 KiB

  1. /*
  2. ASYNC CLIENT OTA MODULE
  3. Copyright (C) 2016-2019 by Xose Pérez <xose dot perez at gmail dot com>
  4. */
  5. #if OTA_CLIENT == OTA_CLIENT_ASYNCTCP
  6. // -----------------------------------------------------------------------------
  7. // Terminal and MQTT OTA command handlers
  8. // -----------------------------------------------------------------------------
  9. #if TERMINAL_SUPPORT || OTA_MQTT_SUPPORT
  10. #include <Schedule.h>
  11. #include <ESPAsyncTCP.h>
  12. #include "mqtt.h"
  13. #include "ota.h"
  14. #include "system.h"
  15. #include "terminal.h"
  16. #include "libs/URL.h"
  17. const char OTA_REQUEST_TEMPLATE[] PROGMEM =
  18. "GET %s HTTP/1.1\r\n"
  19. "Host: %s\r\n"
  20. "User-Agent: ESPurna\r\n"
  21. "Connection: close\r\n\r\n";
  22. struct ota_client_t {
  23. enum state_t {
  24. HEADERS,
  25. DATA,
  26. END
  27. };
  28. ota_client_t() = delete;
  29. ota_client_t(const ota_client_t&) = delete;
  30. ota_client_t(URL&& url);
  31. bool connect();
  32. state_t state = HEADERS;
  33. size_t size = 0;
  34. const URL url;
  35. AsyncClient client;
  36. };
  37. std::unique_ptr<ota_client_t> _ota_client = nullptr;
  38. // -----------------------------------------------------------------------------
  39. void _otaClientDisconnect() {
  40. DEBUG_MSG_P(PSTR("[OTA] Disconnected\n"));
  41. _ota_client = nullptr;
  42. }
  43. void _otaClientOnDisconnect(void* arg, AsyncClient* client) {
  44. DEBUG_MSG_P(PSTR("\n"));
  45. otaFinalize(reinterpret_cast<ota_client_t*>(arg)->size, CUSTOM_RESET_OTA, true);
  46. schedule_function(_otaClientDisconnect);
  47. }
  48. void _otaClientOnTimeout(void*, AsyncClient * client, uint32_t) {
  49. client->close(true);
  50. }
  51. void _otaClientOnError(void*, AsyncClient* client, err_t error) {
  52. DEBUG_MSG_P(PSTR("[OTA] ERROR: %s\n"), client->errorToString(error));
  53. }
  54. void _otaClientOnData(void* arg, AsyncClient* client, void* data, size_t len) {
  55. ota_client_t* ota_client = reinterpret_cast<ota_client_t*>(arg);
  56. auto* ptr = (char *) data;
  57. // TODO: check status
  58. // TODO: check for 3xx, discover `Location:` header and schedule
  59. // another _otaClientFrom(location_header_url)
  60. if (_ota_client->state == ota_client_t::HEADERS) {
  61. ptr = (char *) strnstr((char *) data, "\r\n\r\n", len);
  62. if (!ptr) {
  63. return;
  64. }
  65. auto diff = ptr - ((char *) data);
  66. _ota_client->state = ota_client_t::DATA;
  67. len -= diff + 4;
  68. if (!len) {
  69. return;
  70. }
  71. ptr += 4;
  72. }
  73. if (ota_client->state == ota_client_t::DATA) {
  74. if (!ota_client->size) {
  75. // Check header before anything is written to the flash
  76. if (!otaVerifyHeader((uint8_t *) ptr, len)) {
  77. DEBUG_MSG_P(PSTR("[OTA] ERROR: No magic byte / invalid flash config"));
  78. client->close(true);
  79. ota_client->state = ota_client_t::END;
  80. return;
  81. }
  82. // XXX: In case of non-chunked response, really parse headers and specify size via content-length value
  83. Update.runAsync(true);
  84. if (!Update.begin((ESP.getFreeSketchSpace() - 0x1000) & 0xFFFFF000)) {
  85. otaPrintError();
  86. client->close(true);
  87. return;
  88. }
  89. }
  90. // We can enter this callback even after client->close()
  91. if (!Update.isRunning()) {
  92. return;
  93. }
  94. if (Update.write((uint8_t *) ptr, len) != len) {
  95. otaPrintError();
  96. client->close(true);
  97. ota_client->state = ota_client_t::END;
  98. return;
  99. }
  100. ota_client->size += len;
  101. otaProgress(ota_client->size);
  102. delay(0);
  103. }
  104. }
  105. void _otaClientOnConnect(void* arg, AsyncClient* client) {
  106. ota_client_t* ota_client = reinterpret_cast<ota_client_t*>(arg);
  107. #if ASYNC_TCP_SSL_ENABLED
  108. const auto check = getSetting("otaScCheck", OTA_SECURE_CLIENT_CHECK);
  109. if ((check == SECURE_CLIENT_CHECK_FINGERPRINT) && (443 == ota_client->url.port)) {
  110. uint8_t fp[20] = {0};
  111. sslFingerPrintArray(getSetting("otaFP", OTA_FINGERPRINT).c_str(), fp);
  112. SSL * ssl = client->getSSL();
  113. if (ssl_match_fingerprint(ssl, fp) != SSL_OK) {
  114. DEBUG_MSG_P(PSTR("[OTA] Warning: certificate fingerpint doesn't match\n"));
  115. client->close(true);
  116. return;
  117. }
  118. }
  119. #endif
  120. // Disabling EEPROM rotation to prevent writing to EEPROM after the upgrade
  121. eepromRotate(false);
  122. DEBUG_MSG_P(PSTR("[OTA] Downloading %s\n"), ota_client->url.path.c_str());
  123. char buffer[strlen_P(OTA_REQUEST_TEMPLATE) + ota_client->url.path.length() + ota_client->url.host.length()];
  124. snprintf_P(buffer, sizeof(buffer), OTA_REQUEST_TEMPLATE, ota_client->url.path.c_str(), ota_client->url.host.c_str());
  125. client->write(buffer);
  126. }
  127. ota_client_t::ota_client_t(URL&& url) :
  128. url(std::move(url))
  129. {
  130. client.setRxTimeout(5);
  131. client.onError(_otaClientOnError, nullptr);
  132. client.onTimeout(_otaClientOnTimeout, nullptr);
  133. client.onDisconnect(_otaClientOnDisconnect, this);
  134. client.onData(_otaClientOnData, this);
  135. client.onConnect(_otaClientOnConnect, this);
  136. }
  137. bool ota_client_t::connect() {
  138. #if ASYNC_TCP_SSL_ENABLED
  139. return client.connect(url.host.c_str(), url.port, 443 == url.port);
  140. #else
  141. return client.connect(url.host.c_str(), url.port);
  142. #endif
  143. }
  144. // -----------------------------------------------------------------------------
  145. void _otaClientFrom(const String& url) {
  146. if (_ota_client) {
  147. DEBUG_MSG_P(PSTR("[OTA] Already connected\n"));
  148. return;
  149. }
  150. URL _url(url);
  151. if (!_url.protocol.equals("http") && !_url.protocol.equals("https")) {
  152. DEBUG_MSG_P(PSTR("[OTA] Incorrect URL specified\n"));
  153. return;
  154. }
  155. _ota_client = std::make_unique<ota_client_t>(std::move(_url));
  156. if (!_ota_client->connect()) {
  157. DEBUG_MSG_P(PSTR("[OTA] Connection failed\n"));
  158. }
  159. }
  160. #endif // TERMINAL_SUPPORT || OTA_MQTT_SUPPORT
  161. #if TERMINAL_SUPPORT
  162. void _otaClientInitCommands() {
  163. terminalRegisterCommand(F("OTA"), [](Embedis* e) {
  164. if (e->argc < 2) {
  165. terminalError(F("OTA <url>"));
  166. } else {
  167. _otaClientFrom(String(e->argv[1]));
  168. terminalOK();
  169. }
  170. });
  171. }
  172. #endif // TERMINAL_SUPPORT
  173. #if OTA_MQTT_SUPPORT
  174. void _otaClientMqttCallback(unsigned int type, const char * topic, const char * payload) {
  175. if (type == MQTT_CONNECT_EVENT) {
  176. mqttSubscribe(MQTT_TOPIC_OTA);
  177. }
  178. if (type == MQTT_MESSAGE_EVENT) {
  179. String t = mqttMagnitude((char *) topic);
  180. if (t.equals(MQTT_TOPIC_OTA)) {
  181. DEBUG_MSG_P(PSTR("[OTA] Initiating from URL: %s\n"), payload);
  182. _otaClientFrom(payload);
  183. }
  184. }
  185. }
  186. #endif // OTA_MQTT_SUPPORT
  187. // -----------------------------------------------------------------------------
  188. void otaClientSetup() {
  189. // Backwards compatibility
  190. moveSetting("otafp", "otaFP");
  191. #if TERMINAL_SUPPORT
  192. _otaClientInitCommands();
  193. #endif
  194. #if (MQTT_SUPPORT && OTA_MQTT_SUPPORT)
  195. mqttRegister(_otaClientMqttCallback);
  196. #endif
  197. }
  198. #endif // OTA_CLIENT == OTA_CLIENT_ASYNCTCP