/*

ASYNC CLIENT OTA MODULE

Copyright (C) 2016-2019 by Xose Pérez <xose dot perez at gmail dot com>

*/

#include "ota.h"

#if OTA_CLIENT == OTA_CLIENT_ASYNCTCP

// -----------------------------------------------------------------------------
// Terminal and MQTT OTA command handlers
// -----------------------------------------------------------------------------

// NOTE(review): the names of the four system headers in this file (<memory>,
// <Schedule.h>, <Updater.h>, <ESPAsyncTCP.h>) were reconstructed from what the
// code below uses (std::unique_ptr, schedule_function, Update, AsyncClient).
// Confirm them against the repository copy.
#include <memory>

#include "espurna.h"

#if TERMINAL_SUPPORT || OTA_MQTT_SUPPORT

#include <Schedule.h>

#include "mqtt.h"
#include "system.h"
#include "settings.h"
#include "terminal.h"

#include "libs/URL.h"

#include <Updater.h>
#include <ESPAsyncTCP.h>

// HTTP request sent once the TCP connection is up.
// Format arguments: 1) URL path, 2) host name.
const char OTA_REQUEST_TEMPLATE[] PROGMEM =
    "GET %s HTTP/1.1\r\n"
    "Host: %s\r\n"
    "User-Agent: ESPurna\r\n"
    "Connection: close\r\n\r\n";

// State of a single OTA download: the target URL, the TCP client used to
// fetch it, the response parsing state and the number of firmware bytes
// written so far.
struct ota_client_t {

    // HEADERS: still looking for the end of the HTTP response headers
    // DATA:    response body is being written to the updater
    // END:     an error happened, any further data is ignored
    enum state_t {
        HEADERS,
        DATA,
        END
    };

    ota_client_t() = delete;
    ota_client_t(const ota_client_t&) = delete;

    // Takes ownership of the URL and registers the callbacks on `client`.
    ota_client_t(URL&& url);

    // Starts the connection to url.host:url.port. Returns the result of
    // AsyncClient::connect().
    bool connect();

    state_t state = HEADERS;
    size_t size = 0;

    const URL url;
    AsyncClient client;

};

// The only OTA download in progress, or nullptr when there is none.
std::unique_ptr<ota_client_t> _ota_client = nullptr;

// -----------------------------------------------------------------------------

// Releases the global client. Queued through schedule_function() by the
// disconnect callback instead of being called from it (presumably so the
// AsyncClient is not destroyed from inside its own callback).
void _otaClientDisconnect() {
    DEBUG_MSG_P(PSTR("[OTA] Disconnected\n"));
    _ota_client = nullptr;
}

// AsyncClient disconnect callback. `arg` is the ota_client_t registered in the
// constructor. Finalizes the update with the number of bytes received and
// schedules the release of the client.
void _otaClientOnDisconnect(void* arg, AsyncClient* client) {
    DEBUG_MSG_P(PSTR("\n"));
    otaFinalize(reinterpret_cast<ota_client_t*>(arg)->size, CUSTOM_RESET_OTA, true);
    schedule_function(_otaClientDisconnect);
}

// AsyncClient timeout callback: force-close the connection.
void _otaClientOnTimeout(void*, AsyncClient* client, uint32_t) {
    client->close(true);
}

// AsyncClient error callback: only logs the error text.
void _otaClientOnError(void*, AsyncClient* client, err_t error) {
    DEBUG_MSG_P(PSTR("[OTA] ERROR: %s\n"), client->errorToString(error));
}

// AsyncClient data callback.
//
// arg    - the ota_client_t registered in the constructor
// client - connection the data arrived on
// data   - received bytes (not NUL-terminated, `len` bytes long)
// len    - number of bytes in `data`
//
// While in HEADERS state the buffer is searched for the blank line that ends
// the HTTP headers; everything after it is treated as firmware data. In DATA
// state the bytes are written to the updater. On any error the connection is
// closed and the state is switched to END so that later calls are ignored.
void _otaClientOnData(void* arg, AsyncClient* client, void* data, size_t len) {

    auto* ota_client = reinterpret_cast<ota_client_t*>(arg);
    auto* ptr = static_cast<char*>(data);

    // TODO: check status
    // TODO: check for 3xx, discover `Location:` header and schedule
    //       another _otaClientFrom(location_header_url)
    if (ota_client->state == ota_client_t::HEADERS) {
        // Note: the terminator is only found when it is fully contained in
        // a single callback buffer.
        ptr = strnstr(static_cast<const char*>(data), "\r\n\r\n", len);
        if (!ptr) {
            return;
        }

        ota_client->state = ota_client_t::DATA;

        // Skip the headers and the 4-byte terminator
        const auto header_len = static_cast<size_t>(ptr - static_cast<char*>(data)) + 4;
        len -= header_len;
        if (!len) {
            return;
        }
        ptr += 4;
    }

    if (ota_client->state == ota_client_t::DATA) {

        auto* payload = reinterpret_cast<uint8_t*>(ptr);

        if (!ota_client->size) {

            // Check header before anything is written to the flash
            if (!otaVerifyHeader(payload, len)) {
                DEBUG_MSG_P(PSTR("[OTA] ERROR: No magic byte / invalid flash config\n"));
                client->close(true);
                ota_client->state = ota_client_t::END;
                return;
            }

            // XXX: In case of non-chunked response, really parse headers and specify size via content-length value
            Update.runAsync(true);
            if (!Update.begin((ESP.getFreeSketchSpace() - 0x1000) & 0xFFFFF000)) {
                otaPrintError();
                client->close(true);
                // Same as the other error paths: without this, a late data
                // callback would verify the header and call begin() again
                ota_client->state = ota_client_t::END;
                return;
            }

        }

        // We can enter this callback even after client->close()
        if (!Update.isRunning()) {
            return;
        }

        if (Update.write(payload, len) != len) {
            otaPrintError();
            client->close(true);
            ota_client->state = ota_client_t::END;
            return;
        }

        ota_client->size += len;
        otaProgress(ota_client->size);

        delay(0);

    }

}

// AsyncClient connect callback. `arg` is the ota_client_t registered in the
// constructor. Optionally verifies the TLS certificate fingerprint, disables
// EEPROM rotation and sends the HTTP GET request for the firmware.
void _otaClientOnConnect(void* arg, AsyncClient* client) {

    auto* ota_client = reinterpret_cast<ota_client_t*>(arg);

    #if ASYNC_TCP_SSL_ENABLED
        const auto check = getSetting("otaScCheck", OTA_SECURE_CLIENT_CHECK);
        if ((check == SECURE_CLIENT_CHECK_FINGERPRINT) && (443 == ota_client->url.port)) {
            uint8_t fp[20] = {0};
            sslFingerPrintArray(getSetting("otaFP", OTA_FINGERPRINT).c_str(), fp);
            SSL * ssl = client->getSSL();
            if (ssl_match_fingerprint(ssl, fp) != SSL_OK) {
                DEBUG_MSG_P(PSTR("[OTA] Warning: certificate fingerpint doesn't match\n"));
                client->close(true);
                return;
            }
        }
    #endif

    // Disabling EEPROM rotation to prevent writing to EEPROM after the upgrade
    eepromRotate(false);

    DEBUG_MSG_P(PSTR("[OTA] Downloading %s\n"), ota_client->url.path.c_str());

    // The two "%s" in the template take 4 characters and are replaced by the
    // path and the host, so this size also covers the terminating NUL.
    char buffer[strlen_P(OTA_REQUEST_TEMPLATE) + ota_client->url.path.length() + ota_client->url.host.length()];
    snprintf_P(buffer, sizeof(buffer), OTA_REQUEST_TEMPLATE, ota_client->url.path.c_str(), ota_client->url.host.c_str());
    client->write(buffer);

}

// Stores the URL, sets a receive timeout of 5 and registers the callbacks.
// Data, connect and disconnect callbacks receive `this` as their argument;
// error and timeout callbacks do not need it.
ota_client_t::ota_client_t(URL&& url) :
    url(std::move(url))
{
    client.setRxTimeout(5);
    client.onError(_otaClientOnError, nullptr);
    client.onTimeout(_otaClientOnTimeout, nullptr);
    client.onDisconnect(_otaClientOnDisconnect, this);
    client.onData(_otaClientOnData, this);
    client.onConnect(_otaClientOnConnect, this);
}

// Connects to the host and port of the stored URL. With SSL support enabled,
// a secure connection is requested when the port is 443.
bool ota_client_t::connect() {
    #if ASYNC_TCP_SSL_ENABLED
        return client.connect(url.host.c_str(), url.port, 443 == url.port);
    #else
        return client.connect(url.host.c_str(), url.port);
    #endif
}

// -----------------------------------------------------------------------------

// Starts an OTA download from `url`. Does nothing (besides logging) when a
// download is already in progress or the protocol is neither http nor https.
void _otaClientFrom(const String& url) {

    if (_ota_client) {
        DEBUG_MSG_P(PSTR("[OTA] Already connected\n"));
        return;
    }

    URL _url(url);
    if (!_url.protocol.equals("http") && !_url.protocol.equals("https")) {
        DEBUG_MSG_P(PSTR("[OTA] Incorrect URL specified\n"));
        return;
    }

    _ota_client = std::make_unique<ota_client_t>(std::move(_url));
    if (!_ota_client->connect()) {
        DEBUG_MSG_P(PSTR("[OTA] Connection failed\n"));
        // Release the client, otherwise every following request would be
        // rejected with "Already connected"
        _ota_client = nullptr;
    }

}

#endif // TERMINAL_SUPPORT || OTA_MQTT_SUPPORT

#if TERMINAL_SUPPORT

// Registers the `OTA <url>` terminal command.
void _otaClientInitCommands() {

    terminalRegisterCommand(F("OTA"), [](const terminal::CommandContext& ctx) {
        if (ctx.argc < 2) {
            terminalError(F("OTA <url>"));
        } else {
            _otaClientFrom(ctx.argv[1]);
            terminalOK();
        }
    });

}

#endif // TERMINAL_SUPPORT

#if OTA_MQTT_SUPPORT

// MQTT callback: subscribes to the OTA topic on connect and starts a download
// from the URL given in the payload of a message on that topic.
void _otaClientMqttCallback(unsigned int type, const char * topic, const char * payload) {

    if (type == MQTT_CONNECT_EVENT) {
        mqttSubscribe(MQTT_TOPIC_OTA);
    }

    if (type == MQTT_MESSAGE_EVENT) {
        const String t = mqttMagnitude(const_cast<char*>(topic));
        if (t.equals(MQTT_TOPIC_OTA)) {
            DEBUG_MSG_P(PSTR("[OTA] Initiating from URL: %s\n"), payload);
            _otaClientFrom(payload);
        }
    }

}

#endif // OTA_MQTT_SUPPORT

// -----------------------------------------------------------------------------

// Module entry point: migrates the old fingerprint setting key and registers
// the terminal command and the MQTT callback when those are enabled.
void otaClientSetup() {

    // Backwards compatibility
    moveSetting("otafp", "otaFP");

    #if TERMINAL_SUPPORT
        _otaClientInitCommands();
    #endif

    #if (MQTT_SUPPORT && OTA_MQTT_SUPPORT)
        mqttRegister(_otaClientMqttCallback);
    #endif

}

#endif // OTA_CLIENT == OTA_CLIENT_ASYNCTCP