Fork of the espurna firmware for `mhsw` switches
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

230 lines
6.7 KiB

  1. // -----------------------------------------------------------------------------
  2. // WiFiClientSecure validation helpers
  3. // -----------------------------------------------------------------------------
  4. #pragma once
  5. #if SECURE_CLIENT != SECURE_CLIENT_NONE
  #include <functional>
  #include <memory>
  #include <utility>

  #if SECURE_CLIENT == SECURE_CLIENT_BEARSSL
  #include <WiFiClientSecureBearSSL.h>
  #elif SECURE_CLIENT == SECURE_CLIENT_AXTLS
  #include <WiFiClientSecureAxTLS.h>
  #endif
  11. namespace SecureClientHelpers {
  12. using host_callback_f = std::function<String()>;
  13. using check_callback_f = std::function<bool()>;
  14. using fp_callback_f = std::function<String()>;
  15. using cert_callback_f = std::function<const char*()>;
  16. using mfln_callback_f = std::function<uint16_t()>;
  17. #if SECURE_CLIENT == SECURE_CLIENT_AXTLS
  18. using SecureClientClass = axTLS::WiFiClientSecure;
  19. struct SecureClientConfig {
  20. SecureClientConfig(const char* tag, host_callback_f host_cb, check_callback_f check_cb, fp_callback_f fp_cb, bool debug = false) :
  21. tag(tag),
  22. on_host(host_cb),
  23. on_check(check_cb),
  24. on_fingerprint(fp_cb),
  25. debug(debug)
  26. {}
  27. String tag;
  28. host_callback_f on_host;
  29. check_callback_f on_check;
  30. fp_callback_f on_fingerprint;
  31. bool debug;
  32. };
  33. struct SecureClientChecks {
  34. SecureClientChecks(SecureClientConfig& config) :
  35. config(config)
  36. {}
  37. int getCheck() {
  38. return (config.on_check) ? config.on_check() : (SECURE_CLIENT_CHECK);
  39. }
  40. bool beforeConnected(SecureClientClass& client) {
  41. return true;
  42. }
  43. // Special condition for legacy client!
  44. // Otherwise, we are required to connect twice. And it is deemed broken & deprecated anyways...
  45. bool afterConnected(SecureClientClass& client) {
  46. bool result = false;
  47. int check = getCheck();
  48. if (check == SECURE_CLIENT_CHECK_NONE) {
  49. if (config.debug) DEBUG_MSG_P(PSTR("[%s] !!! Secure connection will not be validated !!!\n"), config.tag.c_str());
  50. result = true;
  51. } else if (check == SECURE_CLIENT_CHECK_FINGERPRINT) {
  52. if (config.on_fingerprint) {
  53. char _buffer[60] = {0};
  54. if (config.on_fingerprint && config.on_host && sslFingerPrintChar(config.on_fingerprint().c_str(), _buffer)) {
  55. result = client.verify(_buffer, config.on_host().c_str());
  56. }
  57. if (!result) DEBUG_MSG_P(PSTR("[%s] Wrong fingerprint, cannot connect\n"), config.tag.c_str());
  58. }
  59. } else if (check == SECURE_CLIENT_CHECK_CA) {
  60. if (config.debug) DEBUG_MSG_P(PSTR("[%s] CA verification is not supported with axTLS client\n"), config.tag.c_str());
  61. }
  62. return result;
  63. }
  64. SecureClientConfig& config;
  65. bool debug;
  66. };
  67. #endif // SECURE_CLIENT_AXTLS
  68. #if SECURE_CLIENT == SECURE_CLIENT_BEARSSL
  69. using SecureClientClass = BearSSL::WiFiClientSecure;
  70. struct SecureClientConfig {
  71. SecureClientConfig(const char* tag, check_callback_f check_cb, cert_callback_f cert_cb, fp_callback_f fp_cb, mfln_callback_f mfln_cb, bool debug = false) :
  72. tag(tag),
  73. on_check(check_cb),
  74. on_certificate(cert_cb),
  75. on_fingerprint(fp_cb),
  76. on_mfln(mfln_cb),
  77. debug(debug)
  78. {}
  79. String tag;
  80. check_callback_f on_check;
  81. cert_callback_f on_certificate;
  82. fp_callback_f on_fingerprint;
  83. mfln_callback_f on_mfln;
  84. bool debug;
  85. };
  86. struct SecureClientChecks {
  87. SecureClientChecks(SecureClientConfig& config) :
  88. config(config)
  89. {}
  90. int getCheck() {
  91. return (config.on_check) ? config.on_check() : (SECURE_CLIENT_CHECK);
  92. }
  93. bool prepareMFLN(SecureClientClass& client) {
  94. const uint16_t requested_mfln = (config.on_mfln) ? config.on_mfln() : (SECURE_CLIENT_MFLN);
  95. bool result = false;
  96. switch (requested_mfln) {
  97. // default, do nothing
  98. case 0:
  99. result = true;
  100. break;
  101. // match valid sizes only
  102. case 512:
  103. case 1024:
  104. case 2048:
  105. case 4096:
  106. {
  107. client.setBufferSizes(requested_mfln, requested_mfln);
  108. result = true;
  109. if (config.debug) {
  110. DEBUG_MSG_P(PSTR("[%s] MFLN buffer size set to %u\n"), config.tag.c_str(), requested_mfln);
  111. }
  112. break;
  113. }
  114. default:
  115. {
  116. if (config.debug) {
  117. DEBUG_MSG_P(PSTR("[%s] Warning: MFLN buffer size must be one of 512, 1024, 2048 or 4096\n"), config.tag.c_str());
  118. }
  119. }
  120. }
  121. return result;
  122. }
  123. bool beforeConnected(SecureClientClass& client) {
  124. int check = getCheck();
  125. bool settime = (check == SECURE_CLIENT_CHECK_CA);
  126. if (!ntpSynced() && settime) {
  127. if (config.debug) DEBUG_MSG_P(PSTR("[%s] Time not synced! Cannot use CA validation\n"), config.tag.c_str());
  128. return false;
  129. }
  130. prepareMFLN(client);
  131. if (check == SECURE_CLIENT_CHECK_NONE) {
  132. if (config.debug) DEBUG_MSG_P(PSTR("[%s] !!! Secure connection will not be validated !!!\n"), config.tag.c_str());
  133. client.setInsecure();
  134. } else if (check == SECURE_CLIENT_CHECK_FINGERPRINT) {
  135. uint8_t _buffer[20] = {0};
  136. if (config.on_fingerprint && sslFingerPrintArray(config.on_fingerprint().c_str(), _buffer)) {
  137. client.setFingerprint(_buffer);
  138. }
  139. } else if (check == SECURE_CLIENT_CHECK_CA) {
  140. client.setX509Time(ntpLocal2UTC(now()));
  141. if (!certs.getCount()) {
  142. if (config.on_certificate) certs.append(config.on_certificate());
  143. }
  144. client.setTrustAnchors(&certs);
  145. }
  146. return true;
  147. }
  148. bool afterConnected(SecureClientClass&) {
  149. return true;
  150. }
  151. bool debug;
  152. SecureClientConfig& config;
  153. BearSSL::X509List certs;
  154. };
  155. #endif // SECURE_CLIENT_BEARSSL
  156. class SecureClient {
  157. public:
  158. SecureClient(SecureClientConfig& config) :
  159. _config(config),
  160. _checks(_config),
  161. _client(std::make_unique<SecureClientClass>())
  162. {}
  163. bool afterConnected() {
  164. return _checks.afterConnected(get());
  165. }
  166. bool beforeConnected() {
  167. return _checks.beforeConnected(get());
  168. }
  169. SecureClientClass& get() {
  170. return *_client.get();
  171. }
  172. private:
  173. SecureClientConfig _config;
  174. SecureClientChecks _checks;
  175. std::unique_ptr<SecureClientClass> _client;
  176. };
  177. };
  178. using SecureClientConfig = SecureClientHelpers::SecureClientConfig;
  179. using SecureClientChecks = SecureClientHelpers::SecureClientChecks;
  180. using SecureClient = SecureClientHelpers::SecureClient;
  181. #endif // SECURE_CLIENT != SECURE_CLIENT_NONE