Fork of the espurna firmware for `mhsw` switches
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

325 lines
7.0 KiB

  1. #ifndef LWIP_OPEN_SRC
  2. #define LWIP_OPEN_SRC
  3. #endif
  4. #include <functional>
  5. #include <WiFiUdp.h>
  6. #include "ArduinoOTA.h"
  7. #include "MD5Builder.h"
  8. extern "C" {
  9. #include "osapi.h"
  10. #include "ets_sys.h"
  11. #include "user_interface.h"
  12. }
  13. #include "lwip/opt.h"
  14. #include "lwip/udp.h"
  15. #include "lwip/inet.h"
  16. #include "lwip/igmp.h"
  17. #include "lwip/mem.h"
  18. #include "include/UdpContext.h"
  19. #include <ESP8266mDNS.h>
  20. #ifdef DEBUG_ESP_OTA
  21. #ifdef DEBUG_ESP_PORT
  22. #define OTA_DEBUG DEBUG_ESP_PORT
  23. #endif
  24. #endif
  25. ArduinoOTAClass::ArduinoOTAClass()
  26. : _port(0)
  27. , _udp_ota(0)
  28. , _initialized(false)
  29. , _state(OTA_IDLE)
  30. , _size(0)
  31. , _cmd(0)
  32. , _ota_port(0)
  33. , _start_callback(NULL)
  34. , _end_callback(NULL)
  35. , _error_callback(NULL)
  36. , _progress_callback(NULL)
  37. {
  38. }
  39. ArduinoOTAClass::~ArduinoOTAClass(){
  40. if(_udp_ota){
  41. _udp_ota->unref();
  42. _udp_ota = 0;
  43. }
  44. }
  45. void ArduinoOTAClass::onStart(OTA_CALLBACK(fn)) {
  46. _start_callback = fn;
  47. }
  48. void ArduinoOTAClass::onEnd(OTA_CALLBACK(fn)) {
  49. _end_callback = fn;
  50. }
  51. void ArduinoOTAClass::onProgress(OTA_CALLBACK_PROGRESS(fn)) {
  52. _progress_callback = fn;
  53. }
  54. void ArduinoOTAClass::onError(OTA_CALLBACK_ERROR(fn)) {
  55. _error_callback = fn;
  56. }
  57. void ArduinoOTAClass::setPort(uint16_t port) {
  58. if (!_initialized && !_port && port) {
  59. _port = port;
  60. }
  61. }
  62. void ArduinoOTAClass::setHostname(const char * hostname) {
  63. if (!_initialized && !_hostname.length() && hostname) {
  64. _hostname = hostname;
  65. }
  66. }
  67. void ArduinoOTAClass::setPassword(const char * password) {
  68. if (!_initialized && !_password.length() && password) {
  69. _password = password;
  70. }
  71. }
  72. void ArduinoOTAClass::begin() {
  73. if (_initialized)
  74. return;
  75. if (!_hostname.length()) {
  76. char tmp[15];
  77. sprintf(tmp, "esp8266-%06x", ESP.getChipId());
  78. _hostname = tmp;
  79. }
  80. if (!_port) {
  81. _port = 8266;
  82. }
  83. if(_udp_ota){
  84. _udp_ota->unref();
  85. _udp_ota = 0;
  86. }
  87. _udp_ota = new UdpContext;
  88. _udp_ota->ref();
  89. if(!_udp_ota->listen(*IP_ADDR_ANY, _port))
  90. return;
  91. _udp_ota->onRx(std::bind(&ArduinoOTAClass::_onRx, this));
  92. MDNS.begin(_hostname.c_str());
  93. if (_password.length()) {
  94. MDNS.enableArduino(_port, true);
  95. } else {
  96. MDNS.enableArduino(_port);
  97. }
  98. _initialized = true;
  99. _state = OTA_IDLE;
  100. #ifdef OTA_DEBUG
  101. OTA_DEBUG.printf("OTA server at: %s.local:%u\n", _hostname.c_str(), _port);
  102. #endif
  103. }
  104. int ArduinoOTAClass::parseInt(){
  105. char data[16];
  106. uint8_t index = 0;
  107. char value;
  108. while(_udp_ota->peek() == ' ') _udp_ota->read();
  109. while(true){
  110. value = _udp_ota->peek();
  111. if(value < '0' || value > '9'){
  112. data[index++] = '\0';
  113. return atoi(data);
  114. }
  115. data[index++] = _udp_ota->read();
  116. }
  117. return 0;
  118. }
  119. String ArduinoOTAClass::readStringUntil(char end){
  120. String res = "";
  121. char value;
  122. while(true){
  123. value = _udp_ota->read();
  124. if(value == '\0' || value == end){
  125. return res;
  126. }
  127. res += value;
  128. }
  129. return res;
  130. }
  131. void ArduinoOTAClass::_onRx(){
  132. if(!_udp_ota->next()) return;
  133. ip_addr_t ota_ip;
  134. if (_state == OTA_IDLE) {
  135. int cmd = parseInt();
  136. if (cmd != U_FLASH && cmd != U_SPIFFS)
  137. return;
  138. _ota_ip = _udp_ota->getRemoteAddress();
  139. _cmd = cmd;
  140. _ota_port = parseInt();
  141. _size = parseInt();
  142. _udp_ota->read();
  143. _md5 = readStringUntil('\n');
  144. _md5.trim();
  145. if(_md5.length() != 32)
  146. return;
  147. ota_ip.addr = (uint32_t)_ota_ip;
  148. if (_password.length()){
  149. MD5Builder nonce_md5;
  150. nonce_md5.begin();
  151. nonce_md5.add(String(micros()));
  152. nonce_md5.calculate();
  153. _nonce = nonce_md5.toString();
  154. char auth_req[38];
  155. sprintf(auth_req, "AUTH %s", _nonce.c_str());
  156. _udp_ota->append((const char *)auth_req, strlen(auth_req));
  157. _udp_ota->send(&ota_ip, _udp_ota->getRemotePort());
  158. _state = OTA_WAITAUTH;
  159. return;
  160. } else {
  161. _udp_ota->append("OK", 2);
  162. _udp_ota->send(&ota_ip, _udp_ota->getRemotePort());
  163. _state = OTA_RUNUPDATE;
  164. }
  165. } else if (_state == OTA_WAITAUTH) {
  166. int cmd = parseInt();
  167. if (cmd != U_AUTH) {
  168. _state = OTA_IDLE;
  169. return;
  170. }
  171. _udp_ota->read();
  172. String cnonce = readStringUntil(' ');
  173. String response = readStringUntil('\n');
  174. if (cnonce.length() != 32 || response.length() != 32) {
  175. _state = OTA_IDLE;
  176. return;
  177. }
  178. MD5Builder _passmd5;
  179. _passmd5.begin();
  180. _passmd5.add(_password);
  181. _passmd5.calculate();
  182. String passmd5 = _passmd5.toString();
  183. String challenge = passmd5 + ":" + String(_nonce) + ":" + cnonce;
  184. MD5Builder _challengemd5;
  185. _challengemd5.begin();
  186. _challengemd5.add(challenge);
  187. _challengemd5.calculate();
  188. String result = _challengemd5.toString();
  189. ota_ip.addr = (uint32_t)_ota_ip;
  190. if(result.equals(response)){
  191. _udp_ota->append("OK", 2);
  192. _udp_ota->send(&ota_ip, _udp_ota->getRemotePort());
  193. _state = OTA_RUNUPDATE;
  194. } else {
  195. _udp_ota->append("Authentication Failed", 21);
  196. _udp_ota->send(&ota_ip, _udp_ota->getRemotePort());
  197. if (_error_callback) _error_callback(OTA_AUTH_ERROR);
  198. _state = OTA_IDLE;
  199. }
  200. }
  201. while(_udp_ota->next()) _udp_ota->flush();
  202. }
  203. void ArduinoOTAClass::_runUpdate() {
  204. if (!Update.begin(_size, _cmd)) {
  205. #ifdef OTA_DEBUG
  206. OTA_DEBUG.println("Update Begin Error");
  207. #endif
  208. if (_error_callback) {
  209. _error_callback(OTA_BEGIN_ERROR);
  210. }
  211. _udp_ota->listen(*IP_ADDR_ANY, _port);
  212. _state = OTA_IDLE;
  213. return;
  214. }
  215. Update.setMD5(_md5.c_str());
  216. WiFiUDP::stopAll();
  217. WiFiClient::stopAll();
  218. if (_start_callback) {
  219. _start_callback();
  220. }
  221. if (_progress_callback) {
  222. _progress_callback(0, _size);
  223. }
  224. WiFiClient client;
  225. if (!client.connect(_ota_ip, _ota_port)) {
  226. #ifdef OTA_DEBUG
  227. OTA_DEBUG.printf("Connect Failed\n");
  228. #endif
  229. _udp_ota->listen(*IP_ADDR_ANY, _port);
  230. if (_error_callback) {
  231. _error_callback(OTA_CONNECT_ERROR);
  232. }
  233. _state = OTA_IDLE;
  234. }
  235. uint32_t written, total = 0;
  236. while (!Update.isFinished() && client.connected()) {
  237. int waited = 1000;
  238. while (!client.available() && waited--)
  239. delay(1);
  240. if (!waited){
  241. #ifdef OTA_DEBUG
  242. OTA_DEBUG.printf("Receive Failed\n");
  243. #endif
  244. _udp_ota->listen(*IP_ADDR_ANY, _port);
  245. if (_error_callback) {
  246. _error_callback(OTA_RECEIVE_ERROR);
  247. }
  248. _state = OTA_IDLE;
  249. }
  250. written = Update.write(client);
  251. if (written > 0) {
  252. client.print(written, DEC);
  253. total += written;
  254. if(_progress_callback) {
  255. _progress_callback(total, _size);
  256. }
  257. }
  258. }
  259. if (Update.end()) {
  260. client.print("OK");
  261. client.stop();
  262. delay(10);
  263. #ifdef OTA_DEBUG
  264. OTA_DEBUG.printf("Update Success\nRebooting...\n");
  265. #endif
  266. if (_end_callback) {
  267. _end_callback();
  268. }
  269. ESP.restart();
  270. } else {
  271. _udp_ota->listen(*IP_ADDR_ANY, _port);
  272. if (_error_callback) {
  273. _error_callback(OTA_END_ERROR);
  274. }
  275. Update.printError(client);
  276. #ifdef OTA_DEBUG
  277. Update.printError(OTA_DEBUG);
  278. #endif
  279. _state = OTA_IDLE;
  280. }
  281. }
  282. void ArduinoOTAClass::handle() {
  283. if (_state == OTA_RUNUPDATE) {
  284. _runUpdate();
  285. _state = OTA_IDLE;
  286. }
  287. }
  288. ArduinoOTAClass ArduinoOTA;