Fork of the espurna firmware for `mhsw` switches
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

253 lines
7.5 KiB

  1. // -----------------------------------------------------------------------------
  2. // WiFiClientSecure validation helpers
  3. // -----------------------------------------------------------------------------
  4. #pragma once
#include <functional>
#include <memory>
#include <utility>

#include "../espurna.h"
  6. #if SECURE_CLIENT != SECURE_CLIENT_NONE
  7. #include "../ntp.h"
  8. #if SECURE_CLIENT == SECURE_CLIENT_BEARSSL
  9. #include "ntp_timelib.h"
  10. #include <WiFiClientSecureBearSSL.h>
  11. #elif SECURE_CLIENT == SECURE_CLIENT_AXTLS
  12. #include <WiFiClientSecureAxTLS.h>
  13. #endif
  14. namespace SecureClientHelpers {
  15. using host_callback_f = std::function<String()>;
  16. using check_callback_f = std::function<int()>;
  17. using fp_callback_f = std::function<String()>;
  18. using cert_callback_f = std::function<const char*()>;
  19. using mfln_callback_f = std::function<uint16_t()>;
  20. // TODO: workaround for `multiple definition of `SecureClientHelpers::_secureClientCheckAsString(int);'
  21. inline const char * _secureClientCheckAsString(int check) {
  22. switch (check) {
  23. case SECURE_CLIENT_CHECK_NONE: return "no validation";
  24. case SECURE_CLIENT_CHECK_FINGERPRINT: return "fingerprint validation";
  25. case SECURE_CLIENT_CHECK_CA: return "CA validation";
  26. default: return "unknown";
  27. }
  28. }
  29. #if SECURE_CLIENT == SECURE_CLIENT_AXTLS
  30. using SecureClientClass = axTLS::WiFiClientSecure;
  31. struct SecureClientConfig {
  32. SecureClientConfig(const char* tag, host_callback_f host_cb, check_callback_f check_cb, fp_callback_f fp_cb, bool debug = false) :
  33. tag(tag),
  34. on_host(host_cb),
  35. on_check(check_cb),
  36. on_fingerprint(fp_cb),
  37. debug(debug)
  38. {}
  39. String tag;
  40. host_callback_f on_host;
  41. check_callback_f on_check;
  42. fp_callback_f on_fingerprint;
  43. bool debug;
  44. };
  45. struct SecureClientChecks {
  46. SecureClientChecks(SecureClientConfig& config) :
  47. config(config)
  48. {}
  49. int getCheck() {
  50. return (config.on_check) ? config.on_check() : (SECURE_CLIENT_CHECK);
  51. }
  52. bool beforeConnected(SecureClientClass& client) {
  53. return true;
  54. }
  55. // Special condition for legacy client!
  56. // Otherwise, we are required to connect twice. And it is deemed broken & deprecated anyways...
  57. bool afterConnected(SecureClientClass& client) {
  58. bool result = false;
  59. int check = getCheck();
  60. if(config.debug) {
  61. DEBUG_MSG_P(PSTR("[%s] Using SSL check type: %s\n"), config.tag.c_str(), _secureClientCheckAsString(check));
  62. }
  63. if (check == SECURE_CLIENT_CHECK_NONE) {
  64. if (config.debug) DEBUG_MSG_P(PSTR("[%s] !!! Secure connection will not be validated !!!\n"), config.tag.c_str());
  65. result = true;
  66. } else if (check == SECURE_CLIENT_CHECK_FINGERPRINT) {
  67. if (config.on_fingerprint) {
  68. char _buffer[60] = {0};
  69. if (config.on_fingerprint && config.on_host && sslFingerPrintChar(config.on_fingerprint().c_str(), _buffer)) {
  70. result = client.verify(_buffer, config.on_host().c_str());
  71. }
  72. if (!result) DEBUG_MSG_P(PSTR("[%s] Wrong fingerprint, cannot connect\n"), config.tag.c_str());
  73. }
  74. } else if (check == SECURE_CLIENT_CHECK_CA) {
  75. if (config.debug) DEBUG_MSG_P(PSTR("[%s] CA verification is not supported with axTLS client\n"), config.tag.c_str());
  76. }
  77. return result;
  78. }
  79. SecureClientConfig& config;
  80. bool debug;
  81. };
  82. #endif // SECURE_CLIENT_AXTLS
  83. #if SECURE_CLIENT == SECURE_CLIENT_BEARSSL
  84. using SecureClientClass = BearSSL::WiFiClientSecure;
  85. struct SecureClientConfig {
  86. SecureClientConfig(const char* tag, check_callback_f check_cb, cert_callback_f cert_cb, fp_callback_f fp_cb, mfln_callback_f mfln_cb, bool debug = false) :
  87. tag(tag),
  88. on_check(check_cb),
  89. on_certificate(cert_cb),
  90. on_fingerprint(fp_cb),
  91. on_mfln(mfln_cb),
  92. debug(debug)
  93. {}
  94. String tag;
  95. check_callback_f on_check;
  96. cert_callback_f on_certificate;
  97. fp_callback_f on_fingerprint;
  98. mfln_callback_f on_mfln;
  99. bool debug;
  100. };
  101. struct SecureClientChecks {
  102. SecureClientChecks(SecureClientConfig& config) :
  103. config(config)
  104. {}
  105. int getCheck() {
  106. return (config.on_check) ? config.on_check() : (SECURE_CLIENT_CHECK);
  107. }
  108. bool prepareMFLN(SecureClientClass& client) {
  109. const uint16_t requested_mfln = (config.on_mfln) ? config.on_mfln() : (SECURE_CLIENT_MFLN);
  110. bool result = false;
  111. switch (requested_mfln) {
  112. // default, do nothing
  113. case 0:
  114. result = true;
  115. break;
  116. // match valid sizes only
  117. case 512:
  118. case 1024:
  119. case 2048:
  120. case 4096:
  121. {
  122. client.setBufferSizes(requested_mfln, requested_mfln);
  123. result = true;
  124. if (config.debug) {
  125. DEBUG_MSG_P(PSTR("[%s] MFLN buffer size set to %u\n"), config.tag.c_str(), requested_mfln);
  126. }
  127. break;
  128. }
  129. default:
  130. {
  131. if (config.debug) {
  132. DEBUG_MSG_P(PSTR("[%s] Warning: MFLN buffer size must be one of 512, 1024, 2048 or 4096\n"), config.tag.c_str());
  133. }
  134. }
  135. }
  136. return result;
  137. }
  138. bool beforeConnected(SecureClientClass& client) {
  139. int check = getCheck();
  140. bool settime = (check == SECURE_CLIENT_CHECK_CA);
  141. if(config.debug) {
  142. DEBUG_MSG_P(PSTR("[%s] Using SSL check type: %s\n"), config.tag.c_str(), _secureClientCheckAsString(check));
  143. }
  144. if (!ntpSynced() && settime) {
  145. if (config.debug) DEBUG_MSG_P(PSTR("[%s] Time not synced! Cannot use CA validation\n"), config.tag.c_str());
  146. return false;
  147. }
  148. prepareMFLN(client);
  149. if (check == SECURE_CLIENT_CHECK_NONE) {
  150. if (config.debug) DEBUG_MSG_P(PSTR("[%s] !!! Secure connection will not be validated !!!\n"), config.tag.c_str());
  151. client.setInsecure();
  152. } else if (check == SECURE_CLIENT_CHECK_FINGERPRINT) {
  153. uint8_t _buffer[20] = {0};
  154. if (config.on_fingerprint && sslFingerPrintArray(config.on_fingerprint().c_str(), _buffer)) {
  155. client.setFingerprint(_buffer);
  156. }
  157. } else if (check == SECURE_CLIENT_CHECK_CA) {
  158. client.setX509Time(now());
  159. if (!certs.getCount()) {
  160. if (config.on_certificate) certs.append(config.on_certificate());
  161. }
  162. client.setTrustAnchors(&certs);
  163. }
  164. return true;
  165. }
  166. bool afterConnected(SecureClientClass&) {
  167. return true;
  168. }
  169. bool debug;
  170. SecureClientConfig& config;
  171. BearSSL::X509List certs;
  172. };
  173. #endif // SECURE_CLIENT_BEARSSL
  174. class SecureClient {
  175. public:
  176. SecureClient(SecureClientConfig& config) :
  177. _config(config),
  178. _checks(_config),
  179. _client(std::make_unique<SecureClientClass>())
  180. {}
  181. bool afterConnected() {
  182. return _checks.afterConnected(get());
  183. }
  184. bool beforeConnected() {
  185. return _checks.beforeConnected(get());
  186. }
  187. SecureClientClass& get() {
  188. return *_client.get();
  189. }
  190. private:
  191. SecureClientConfig _config;
  192. SecureClientChecks _checks;
  193. std::unique_ptr<SecureClientClass> _client;
  194. };
  195. };
  196. using SecureClientConfig = SecureClientHelpers::SecureClientConfig;
  197. using SecureClientChecks = SecureClientHelpers::SecureClientChecks;
  198. using SecureClient = SecureClientHelpers::SecureClient;
  199. #endif // SECURE_CLIENT != SECURE_CLIENT_NONE