Fork of the espurna firmware for `mhsw` switches
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

272 lines
7.1 KiB

  1. /*
  2. ASYNC CLIENT OTA MODULE
  3. Copyright (C) 2016-2019 by Xose Pérez <xose dot perez at gmail dot com>
  4. */
  5. #include "ota.h"
  6. #if OTA_CLIENT == OTA_CLIENT_ASYNCTCP
  7. // -----------------------------------------------------------------------------
  8. // Terminal and MQTT OTA command handlers
  9. // -----------------------------------------------------------------------------
  10. #include <Arduino.h>
  11. #include "espurna.h"
  12. #if TERMINAL_SUPPORT || OTA_MQTT_SUPPORT
  13. #include <Schedule.h>
  14. #include <ESPAsyncTCP.h>
  15. #include "mqtt.h"
  16. #include "system.h"
  17. #include "settings.h"
  18. #include "terminal.h"
  19. #include "libs/URL.h"
  20. const char OTA_REQUEST_TEMPLATE[] PROGMEM =
  21. "GET %s HTTP/1.1\r\n"
  22. "Host: %s\r\n"
  23. "User-Agent: ESPurna\r\n"
  24. "Connection: close\r\n\r\n";
  25. struct ota_client_t {
  26. enum state_t {
  27. HEADERS,
  28. DATA,
  29. END
  30. };
  31. ota_client_t() = delete;
  32. ota_client_t(const ota_client_t&) = delete;
  33. ota_client_t(URL&& url);
  34. bool connect();
  35. state_t state = HEADERS;
  36. size_t size = 0;
  37. const URL url;
  38. AsyncClient client;
  39. };
  40. std::unique_ptr<ota_client_t> _ota_client = nullptr;
  41. // -----------------------------------------------------------------------------
  42. void _otaClientDisconnect() {
  43. DEBUG_MSG_P(PSTR("[OTA] Disconnected\n"));
  44. _ota_client = nullptr;
  45. }
  46. void _otaClientOnDisconnect(void* arg, AsyncClient* client) {
  47. DEBUG_MSG_P(PSTR("\n"));
  48. otaFinalize(reinterpret_cast<ota_client_t*>(arg)->size, CUSTOM_RESET_OTA, true);
  49. schedule_function(_otaClientDisconnect);
  50. }
  51. void _otaClientOnTimeout(void*, AsyncClient * client, uint32_t) {
  52. client->close(true);
  53. }
  54. void _otaClientOnError(void*, AsyncClient* client, err_t error) {
  55. DEBUG_MSG_P(PSTR("[OTA] ERROR: %s\n"), client->errorToString(error));
  56. }
  57. void _otaClientOnData(void* arg, AsyncClient* client, void* data, size_t len) {
  58. ota_client_t* ota_client = reinterpret_cast<ota_client_t*>(arg);
  59. auto* ptr = (char *) data;
  60. // TODO: check status
  61. // TODO: check for 3xx, discover `Location:` header and schedule
  62. // another _otaClientFrom(location_header_url)
  63. if (_ota_client->state == ota_client_t::HEADERS) {
  64. ptr = (char *) strnstr((char *) data, "\r\n\r\n", len);
  65. if (!ptr) {
  66. return;
  67. }
  68. auto diff = ptr - ((char *) data);
  69. _ota_client->state = ota_client_t::DATA;
  70. len -= diff + 4;
  71. if (!len) {
  72. return;
  73. }
  74. ptr += 4;
  75. }
  76. if (ota_client->state == ota_client_t::DATA) {
  77. if (!ota_client->size) {
  78. // Check header before anything is written to the flash
  79. if (!otaVerifyHeader((uint8_t *) ptr, len)) {
  80. DEBUG_MSG_P(PSTR("[OTA] ERROR: No magic byte / invalid flash config"));
  81. client->close(true);
  82. ota_client->state = ota_client_t::END;
  83. return;
  84. }
  85. // XXX: In case of non-chunked response, really parse headers and specify size via content-length value
  86. Update.runAsync(true);
  87. if (!Update.begin((ESP.getFreeSketchSpace() - 0x1000) & 0xFFFFF000)) {
  88. otaPrintError();
  89. client->close(true);
  90. return;
  91. }
  92. }
  93. // We can enter this callback even after client->close()
  94. if (!Update.isRunning()) {
  95. return;
  96. }
  97. if (Update.write((uint8_t *) ptr, len) != len) {
  98. otaPrintError();
  99. client->close(true);
  100. ota_client->state = ota_client_t::END;
  101. return;
  102. }
  103. ota_client->size += len;
  104. otaProgress(ota_client->size);
  105. delay(0);
  106. }
  107. }
  108. void _otaClientOnConnect(void* arg, AsyncClient* client) {
  109. ota_client_t* ota_client = reinterpret_cast<ota_client_t*>(arg);
  110. #if ASYNC_TCP_SSL_ENABLED
  111. const auto check = getSetting("otaScCheck", OTA_SECURE_CLIENT_CHECK);
  112. if ((check == SECURE_CLIENT_CHECK_FINGERPRINT) && (443 == ota_client->url.port)) {
  113. uint8_t fp[20] = {0};
  114. sslFingerPrintArray(getSetting("otaFP", OTA_FINGERPRINT).c_str(), fp);
  115. SSL * ssl = client->getSSL();
  116. if (ssl_match_fingerprint(ssl, fp) != SSL_OK) {
  117. DEBUG_MSG_P(PSTR("[OTA] Warning: certificate fingerpint doesn't match\n"));
  118. client->close(true);
  119. return;
  120. }
  121. }
  122. #endif
  123. // Disabling EEPROM rotation to prevent writing to EEPROM after the upgrade
  124. eepromRotate(false);
  125. DEBUG_MSG_P(PSTR("[OTA] Downloading %s\n"), ota_client->url.path.c_str());
  126. char buffer[strlen_P(OTA_REQUEST_TEMPLATE) + ota_client->url.path.length() + ota_client->url.host.length()];
  127. snprintf_P(buffer, sizeof(buffer), OTA_REQUEST_TEMPLATE, ota_client->url.path.c_str(), ota_client->url.host.c_str());
  128. client->write(buffer);
  129. }
  130. ota_client_t::ota_client_t(URL&& url) :
  131. url(std::move(url))
  132. {
  133. client.setRxTimeout(5);
  134. client.onError(_otaClientOnError, nullptr);
  135. client.onTimeout(_otaClientOnTimeout, nullptr);
  136. client.onDisconnect(_otaClientOnDisconnect, this);
  137. client.onData(_otaClientOnData, this);
  138. client.onConnect(_otaClientOnConnect, this);
  139. }
  140. bool ota_client_t::connect() {
  141. #if ASYNC_TCP_SSL_ENABLED
  142. return client.connect(url.host.c_str(), url.port, 443 == url.port);
  143. #else
  144. return client.connect(url.host.c_str(), url.port);
  145. #endif
  146. }
  147. // -----------------------------------------------------------------------------
  148. void _otaClientFrom(const String& url) {
  149. if (_ota_client) {
  150. DEBUG_MSG_P(PSTR("[OTA] Already connected\n"));
  151. return;
  152. }
  153. URL _url(url);
  154. if (!_url.protocol.equals("http") && !_url.protocol.equals("https")) {
  155. DEBUG_MSG_P(PSTR("[OTA] Incorrect URL specified\n"));
  156. return;
  157. }
  158. _ota_client = std::make_unique<ota_client_t>(std::move(_url));
  159. if (!_ota_client->connect()) {
  160. DEBUG_MSG_P(PSTR("[OTA] Connection failed\n"));
  161. }
  162. }
  163. #endif // TERMINAL_SUPPORT || OTA_MQTT_SUPPORT
  164. #if TERMINAL_SUPPORT
  165. void _otaClientInitCommands() {
  166. terminalRegisterCommand(F("OTA"), [](Embedis* e) {
  167. if (e->argc < 2) {
  168. terminalError(F("OTA <url>"));
  169. } else {
  170. _otaClientFrom(String(e->argv[1]));
  171. terminalOK();
  172. }
  173. });
  174. }
  175. #endif // TERMINAL_SUPPORT
  176. #if OTA_MQTT_SUPPORT
  177. void _otaClientMqttCallback(unsigned int type, const char * topic, const char * payload) {
  178. if (type == MQTT_CONNECT_EVENT) {
  179. mqttSubscribe(MQTT_TOPIC_OTA);
  180. }
  181. if (type == MQTT_MESSAGE_EVENT) {
  182. String t = mqttMagnitude((char *) topic);
  183. if (t.equals(MQTT_TOPIC_OTA)) {
  184. DEBUG_MSG_P(PSTR("[OTA] Initiating from URL: %s\n"), payload);
  185. _otaClientFrom(payload);
  186. }
  187. }
  188. }
  189. #endif // OTA_MQTT_SUPPORT
  190. // -----------------------------------------------------------------------------
  191. void otaClientSetup() {
  192. // Backwards compatibility
  193. moveSetting("otafp", "otaFP");
  194. #if TERMINAL_SUPPORT
  195. _otaClientInitCommands();
  196. #endif
  197. #if (MQTT_SUPPORT && OTA_MQTT_SUPPORT)
  198. mqttRegister(_otaClientMqttCallback);
  199. #endif
  200. }
  201. #endif // OTA_CLIENT == OTA_CLIENT_ASYNCTCP