From 67335e62a9156ec5638eef1e434c771df1845fed Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xose=20P=C3=A9rez?= Date: Thu, 12 Oct 2017 09:22:12 +0200 Subject: [PATCH] =?UTF-8?q?Safer=20buffer=20handling=20for=20websocket=20d?= =?UTF-8?q?ata=20(thanks=20to=20Hermann=20Kraus=20&=20Bj=C3=B6rn=20Bergman?= =?UTF-8?q?)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- code/espurna/web.h | 83 ++++++++++++++++++++++++++++++++++++++++++++ code/espurna/web.ino | 34 +++++++++--------- code/platformio.ini | 4 +-- 3 files changed, 101 insertions(+), 20 deletions(-) create mode 100644 code/espurna/web.h diff --git a/code/espurna/web.h b/code/espurna/web.h new file mode 100644 index 00000000..3762a623 --- /dev/null +++ b/code/espurna/web.h @@ -0,0 +1,83 @@ +/* + +WebSocketIncommingBuffer + +Code by Hermann Kraus (https://bitbucket.org/hermr2d2/) +and slightly modified. + +https://bitbucket.org/xoseperez/espurna/pull-requests/30/safer-buffer-handling-for-websocket-data + +*/ + +#pragma once + +#define MAX_WS_MSG_SIZE 4000 +typedef std::function AwsMessageHandler; + +class WebSocketIncommingBuffer { + + public: + WebSocketIncommingBuffer(AwsMessageHandler cb, bool terminate_string = true, bool cb_on_fragments = false) : + _cb(cb), + _cb_on_fragments(cb_on_fragments), + _terminate_string(terminate_string), + _buffer(0) + {} + + ~WebSocketIncommingBuffer() { + if (_buffer) delete _buffer; + } + + void data_event(AsyncWebSocketClient *client, AwsFrameInfo *info, uint8_t *data, size_t len) { + + if((info->final || _cb_on_fragments) && + !_terminate_string && info->index == 0 && info->len == len) { + /* The whole message is in a single frame and we got all of it's + data therefore we can parse it without copying the data first.*/ + _cb(client, data, len); + } else { + if (info->len > MAX_WS_MSG_SIZE) return; + /* Check if previous fragment was discarded because it was too long. */ + if (!_cb_on_fragments && info->num > 0 && !_buffer) return; + + if (!_buffer) { + _buffer = new std::vector(); + } + if (info->index == 0) { + //New frame => preallocate memory + if (_cb_on_fragments) { + _buffer->reserve(info->len + 1); + } else { + /* The current fragment would lead to a message which is + too long. So discard everything received so far. */ + if (info->len + _buffer->size() > MAX_WS_MSG_SIZE) { + delete _buffer; + _buffer = 0; + return; + } else { + _buffer->reserve(info->len + _buffer->size() + 1); + } + } + } + //assert(_buffer->size() == info->index); + _buffer->insert(_buffer->end(), data, data+len); + if (info->index + len == info->len && + (info->final || _cb_on_fragments)) { + // Frame/message complete + if (_terminate_string) { + _buffer->push_back(0); + } + _cb(client, _buffer->data(), _buffer->size()); + _buffer->clear(); + } + } + } + + private: + + AwsMessageHandler _cb; + bool _cb_on_fragments; + bool _terminate_string; + std::vector *_buffer; + +}; diff --git a/code/espurna/web.ino b/code/espurna/web.ino index 8e603246..47afc731 100644 --- a/code/espurna/web.ino +++ b/code/espurna/web.ino @@ -16,6 +16,7 @@ Copyright (C) 2016-2017 by Xose PĂ©rez #include #include #include +#include "web.h" #if WEB_EMBEDDED #include "static/index.html.gz.h" @@ -67,7 +68,10 @@ void _wsMQTTCallback(unsigned int type, const char * topic, const char * payload } -void _wsParse(uint32_t client_id, uint8_t * payload, size_t length) { +void _wsParse(AsyncWebSocketClient *client, uint8_t * payload, size_t length) { + + // Get client ID + uint32_t client_id = client->id(); // Parse JSON input DynamicJsonBuffer jsonBuffer; @@ -689,8 +693,6 @@ bool _wsAuth(AsyncWebSocketClient * client) { void _wsEvent(AsyncWebSocket * server, AsyncWebSocketClient * client, AwsEventType type, void * arg, uint8_t *data, size_t len){ - static uint8_t * message; - // Authorize #ifndef NOWSAUTH if (!_wsAuth(client)) return; @@ -700,32 +702,28 @@ void _wsEvent(AsyncWebSocket * server, AsyncWebSocketClient * client, AwsEventTy IPAddress ip = client->remoteIP(); DEBUG_MSG_P(PSTR("[WEBSOCKET] #%u connected, ip: %d.%d.%d.%d, url: %s\n"), client->id(), ip[0], ip[1], ip[2], ip[3], server->url()); _wsStart(client->id()); + client->_tempObject = new WebSocketIncommingBuffer(&_wsParse, true); + } else if(type == WS_EVT_DISCONNECT) { DEBUG_MSG_P(PSTR("[WEBSOCKET] #%u disconnected\n"), client->id()); + if (client->_tempObject) { + delete (WebSocketIncommingBuffer *) client->_tempObject; + } + } else if(type == WS_EVT_ERROR) { DEBUG_MSG_P(PSTR("[WEBSOCKET] #%u error(%u): %s\n"), client->id(), *((uint16_t*)arg), (char*)data); + } else if(type == WS_EVT_PONG) { DEBUG_MSG_P(PSTR("[WEBSOCKET] #%u pong(%u): %s\n"), client->id(), len, len ? (char*) data : ""); - } else if(type == WS_EVT_DATA) { + } else if(type == WS_EVT_DATA) { + WebSocketIncommingBuffer *buffer = (WebSocketIncommingBuffer *)client->_tempObject; AwsFrameInfo * info = (AwsFrameInfo*)arg; - - // First packet - if (info->index == 0) { - message = (uint8_t*) malloc(info->len); - } - - // Store data - memcpy(message + info->index, data, len); - - // Last packet - if (info->index + len == info->len) { - _wsParse(client->id(), message, info->len); - free(message); - } + buffer->data_event(client, info, data, len); } + } // ----------------------------------------------------------------------------- diff --git a/code/platformio.ini b/code/platformio.ini index d3d1165c..33fb1457 100644 --- a/code/platformio.ini +++ b/code/platformio.ini @@ -13,8 +13,8 @@ lib_deps = Adafruit Unified Sensor https://github.com/xoseperez/Time ArduinoJson - https://github.com/me-no-dev/ESPAsyncTCP#991f855 - https://github.com/me-no-dev/ESPAsyncWebServer#a94265d + https://github.com/me-no-dev/ESPAsyncTCP#9b0cc37 + https://github.com/me-no-dev/ESPAsyncWebServer#313f337 https://github.com/marvinroger/async-mqtt-client#v0.8.1 PubSubClient Embedis