Fork of the espurna firmware for `mhsw` switches
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

247 lines
7.3 KiB

  1. // -----------------------------------------------------------------------------
  2. // WiFiClientSecure validation helpers
  3. // -----------------------------------------------------------------------------
  4. #pragma once
  5. #if SECURE_CLIENT != SECURE_CLIENT_NONE
  6. #if SECURE_CLIENT == SECURE_CLIENT_BEARSSL
  7. #include <WiFiClientSecureBearSSL.h>
  8. #elif SECURE_CLIENT == SECURE_CLIENT_AXTLS
  9. #include <WiFiClientSecureAxTLS.h>
  10. #endif
  11. namespace SecureClientHelpers {
// Callback signatures stored in SecureClientConfig; SecureClientChecks
// invokes them (when set) while preparing / validating a connection.

// Returns the host name passed to the axTLS fingerprint verification.
using host_callback_f = std::function<String()>;
// Returns one of the SECURE_CLIENT_CHECK_* values.
using check_callback_f = std::function<int()>;
// Returns the expected certificate fingerprint as text.
using fp_callback_f = std::function<String()>;
// Returns the certificate data appended to the BearSSL trust-anchor list.
using cert_callback_f = std::function<const char*()>;
// Returns the requested MFLN buffer size (0 keeps the default).
using mfln_callback_f = std::function<uint16_t()>;
  17. const char * _secureClientCheckAsString(int check) {
  18. switch (check) {
  19. case SECURE_CLIENT_CHECK_NONE: return "no validation";
  20. case SECURE_CLIENT_CHECK_FINGERPRINT: return "fingerprint validation";
  21. case SECURE_CLIENT_CHECK_CA: return "CA validation";
  22. default: return "unknown";
  23. }
  24. }
  25. #if SECURE_CLIENT == SECURE_CLIENT_AXTLS
// Client implementation wrapped by the helpers below (axTLS build).
using SecureClientClass = axTLS::WiFiClientSecure;
  27. struct SecureClientConfig {
  28. SecureClientConfig(const char* tag, host_callback_f host_cb, check_callback_f check_cb, fp_callback_f fp_cb, bool debug = false) :
  29. tag(tag),
  30. on_host(host_cb),
  31. on_check(check_cb),
  32. on_fingerprint(fp_cb),
  33. debug(debug)
  34. {}
  35. String tag;
  36. host_callback_f on_host;
  37. check_callback_f on_check;
  38. fp_callback_f on_fingerprint;
  39. bool debug;
  40. };
  41. struct SecureClientChecks {
  42. SecureClientChecks(SecureClientConfig& config) :
  43. config(config)
  44. {}
  45. int getCheck() {
  46. return (config.on_check) ? config.on_check() : (SECURE_CLIENT_CHECK);
  47. }
  48. bool beforeConnected(SecureClientClass& client) {
  49. return true;
  50. }
  51. // Special condition for legacy client!
  52. // Otherwise, we are required to connect twice. And it is deemed broken & deprecated anyways...
  53. bool afterConnected(SecureClientClass& client) {
  54. bool result = false;
  55. int check = getCheck();
  56. if(config.debug) {
  57. DEBUG_MSG_P(PSTR("[%s] Using SSL check type: %s\n"), config.tag.c_str(), _secureClientCheckAsString(check));
  58. }
  59. if (check == SECURE_CLIENT_CHECK_NONE) {
  60. if (config.debug) DEBUG_MSG_P(PSTR("[%s] !!! Secure connection will not be validated !!!\n"), config.tag.c_str());
  61. result = true;
  62. } else if (check == SECURE_CLIENT_CHECK_FINGERPRINT) {
  63. if (config.on_fingerprint) {
  64. char _buffer[60] = {0};
  65. if (config.on_fingerprint && config.on_host && sslFingerPrintChar(config.on_fingerprint().c_str(), _buffer)) {
  66. result = client.verify(_buffer, config.on_host().c_str());
  67. }
  68. if (!result) DEBUG_MSG_P(PSTR("[%s] Wrong fingerprint, cannot connect\n"), config.tag.c_str());
  69. }
  70. } else if (check == SECURE_CLIENT_CHECK_CA) {
  71. if (config.debug) DEBUG_MSG_P(PSTR("[%s] CA verification is not supported with axTLS client\n"), config.tag.c_str());
  72. }
  73. return result;
  74. }
  75. SecureClientConfig& config;
  76. bool debug;
  77. };
  78. #endif // SECURE_CLIENT_AXTLS
  79. #if SECURE_CLIENT == SECURE_CLIENT_BEARSSL
// Client implementation wrapped by the helpers below (BearSSL build).
using SecureClientClass = BearSSL::WiFiClientSecure;
  81. struct SecureClientConfig {
  82. SecureClientConfig(const char* tag, check_callback_f check_cb, cert_callback_f cert_cb, fp_callback_f fp_cb, mfln_callback_f mfln_cb, bool debug = false) :
  83. tag(tag),
  84. on_check(check_cb),
  85. on_certificate(cert_cb),
  86. on_fingerprint(fp_cb),
  87. on_mfln(mfln_cb),
  88. debug(debug)
  89. {}
  90. String tag;
  91. check_callback_f on_check;
  92. cert_callback_f on_certificate;
  93. fp_callback_f on_fingerprint;
  94. mfln_callback_f on_mfln;
  95. bool debug;
  96. };
  97. struct SecureClientChecks {
  98. SecureClientChecks(SecureClientConfig& config) :
  99. config(config)
  100. {}
  101. int getCheck() {
  102. return (config.on_check) ? config.on_check() : (SECURE_CLIENT_CHECK);
  103. }
  104. bool prepareMFLN(SecureClientClass& client) {
  105. const uint16_t requested_mfln = (config.on_mfln) ? config.on_mfln() : (SECURE_CLIENT_MFLN);
  106. bool result = false;
  107. switch (requested_mfln) {
  108. // default, do nothing
  109. case 0:
  110. result = true;
  111. break;
  112. // match valid sizes only
  113. case 512:
  114. case 1024:
  115. case 2048:
  116. case 4096:
  117. {
  118. client.setBufferSizes(requested_mfln, requested_mfln);
  119. result = true;
  120. if (config.debug) {
  121. DEBUG_MSG_P(PSTR("[%s] MFLN buffer size set to %u\n"), config.tag.c_str(), requested_mfln);
  122. }
  123. break;
  124. }
  125. default:
  126. {
  127. if (config.debug) {
  128. DEBUG_MSG_P(PSTR("[%s] Warning: MFLN buffer size must be one of 512, 1024, 2048 or 4096\n"), config.tag.c_str());
  129. }
  130. }
  131. }
  132. return result;
  133. }
  134. bool beforeConnected(SecureClientClass& client) {
  135. int check = getCheck();
  136. bool settime = (check == SECURE_CLIENT_CHECK_CA);
  137. if(config.debug) {
  138. DEBUG_MSG_P(PSTR("[%s] Using SSL check type: %s\n"), config.tag.c_str(), _secureClientCheckAsString(check));
  139. }
  140. if (!ntpSynced() && settime) {
  141. if (config.debug) DEBUG_MSG_P(PSTR("[%s] Time not synced! Cannot use CA validation\n"), config.tag.c_str());
  142. return false;
  143. }
  144. prepareMFLN(client);
  145. if (check == SECURE_CLIENT_CHECK_NONE) {
  146. if (config.debug) DEBUG_MSG_P(PSTR("[%s] !!! Secure connection will not be validated !!!\n"), config.tag.c_str());
  147. client.setInsecure();
  148. } else if (check == SECURE_CLIENT_CHECK_FINGERPRINT) {
  149. uint8_t _buffer[20] = {0};
  150. if (config.on_fingerprint && sslFingerPrintArray(config.on_fingerprint().c_str(), _buffer)) {
  151. client.setFingerprint(_buffer);
  152. }
  153. } else if (check == SECURE_CLIENT_CHECK_CA) {
  154. client.setX509Time(ntpLocal2UTC(now()));
  155. if (!certs.getCount()) {
  156. if (config.on_certificate) certs.append(config.on_certificate());
  157. }
  158. client.setTrustAnchors(&certs);
  159. }
  160. return true;
  161. }
  162. bool afterConnected(SecureClientClass&) {
  163. return true;
  164. }
  165. bool debug;
  166. SecureClientConfig& config;
  167. BearSSL::X509List certs;
  168. };
  169. #endif // SECURE_CLIENT_BEARSSL
  170. class SecureClient {
  171. public:
  172. SecureClient(SecureClientConfig& config) :
  173. _config(config),
  174. _checks(_config),
  175. _client(std::make_unique<SecureClientClass>())
  176. {}
  177. bool afterConnected() {
  178. return _checks.afterConnected(get());
  179. }
  180. bool beforeConnected() {
  181. return _checks.beforeConnected(get());
  182. }
  183. SecureClientClass& get() {
  184. return *_client.get();
  185. }
  186. private:
  187. SecureClientConfig _config;
  188. SecureClientChecks _checks;
  189. std::unique_ptr<SecureClientClass> _client;
  190. };
  191. };
// Re-export the helper types into the global namespace for existing callers.
using SecureClientConfig = SecureClientHelpers::SecureClientConfig;
using SecureClientChecks = SecureClientHelpers::SecureClientChecks;
using SecureClient = SecureClientHelpers::SecureClient;
  195. #endif // SECURE_CLIENT != SECURE_CLIENT_NONE