From 34fecbcd18894684090be950282fdc6bd4beda17 Mon Sep 17 00:00:00 2001 From: Maxim Prokhorov Date: Thu, 11 May 2023 22:45:53 +0300 Subject: [PATCH] draft --- code/espurna/network.cpp | 1164 ++++++++++++++++++++++++++++++++++++++ code/espurna/network.h | 609 ++++++++++++++++++++ code/espurna/types.h | 294 +++++++++- 3 files changed, 2062 insertions(+), 5 deletions(-) diff --git a/code/espurna/network.cpp b/code/espurna/network.cpp index aa16cbc0..3e293530 100644 --- a/code/espurna/network.cpp +++ b/code/espurna/network.cpp @@ -12,10 +12,13 @@ Copyright (C) 2022 by Maxim Prokhorov #include #include #include +#include #include +#include #include +#include "network.h" #include "libs/URL.h" // not yet CONNECTING or LISTENING @@ -27,6 +30,1112 @@ extern "C" struct tcp_pcb *tcp_tw_pcbs; namespace espurna { namespace network { + +Packet::~Packet() { + pbuf_free(_pbuf); +} + +Packet::Packet(Packet&& other) noexcept : + _pbuf(other._pbuf) +{ + other._pbuf = nullptr; +} + +Packet& Packet::operator=(Packet&& other) noexcept { + _pbuf = other._pbuf; + other._pbuf = nullptr; + return *this; +} + +void Packet::append(pbuf* pb) { + if (!_pbuf) { + _pbuf = pb; + } else { + pbuf_cat(_pbuf, pb); + } +} + +void Packet::consume(size_t size) { + if (_pbuf) { + _pbuf = pbuf_free_header(_pbuf, size); + } +} + +size_t Packet::size() const { + if (_pbuf) { + return _pbuf->tot_len; + } + + return 0; +} + +size_t Packet::size_chunk() const { + if (_pbuf) { + return _pbuf->len; + } + + return 0; +} + +ConstData Packet::peek_chunk(size_t size) { + if (_pbuf) { + const auto available = size_t{ _pbuf->len }; + const auto* payload = reinterpret_cast(_pbuf->payload); + return {payload, std::min({size, available, RecvMax})}; + } + + return {}; +} + +size_t Packet::peek(Data out) { + if (_pbuf) { + return pbuf_copy_partial(_pbuf, out.data(), out.size(), 0); + } + + return 0; +} + +std::vector Packet::peek(size_t size) { + std::vector out; + if (size) { + size = std::min(size, RecvMax); + out.resize(size); + peek(Data{out.data(), size}); + } + + return out; +} + +size_t Packet::read(Data out) { + const auto received = peek(out); + if (received) { + consume(received); + } + + return 0; +} + +std::vector Packet::read(size_t size) { + auto received = peek(size); + consume(received.size()); + return received; +} + +namespace tcp { + +Address remote_address(tcp_pcb* pcb) { + Address out; + if (pcb) { + ip_addr_copy(out.ip, pcb->remote_ip); + out.port = pcb->remote_port; + } + + return out; +} + +Address remote_address(const Control& control) { + return remote_address(control.get()); +} + +Address local_address(tcp_pcb* pcb) { + Address out; + if (pcb) { + ip_addr_copy(out.ip, pcb->local_ip); + out.port = pcb->local_port; + } + + return out; +} + +Address local_address(const Control& control) { + return local_address(control.get()); +} + +template +std::shared_ptr make_io_completion(size_t size, Result&& result) { + return std::make_shared(size, std::forward(result)); +} + +template +static bool poll(Timeout timeout, Interval interval, T&& blocked) { + return time::blockingDelay( + std::chrono::duration_cast(timeout), + std::chrono::duration_cast(interval), + std::forward(blocked)); +} + +template +static void poll(Interval interval, T&& blocked) { + for (;;) { + const auto result = blocked(); + if (!result) { + break; + } + + time::delay(interval); + } +} + +std::errc from_lwip_error(err_t err) { + switch (err) { + case ERR_OK: + break; + + case ERR_MEM: + return std::errc::not_enough_memory; + + case ERR_BUF: + return std::errc::no_buffer_space; + + case ERR_TIMEOUT: + return std::errc::timed_out; + + case ERR_RTE: + return std::errc::network_unreachable; + + case ERR_INPROGRESS: + return std::errc::operation_in_progress; + + case ERR_VAL: + return std::errc::invalid_argument; + + case ERR_WOULDBLOCK: + return std::errc::operation_would_block; + + case ERR_USE: + return std::errc::address_in_use; + + case ERR_ALREADY: + return std::errc::already_connected; + + case ERR_ISCONN: + return std::errc::already_connected; + + case ERR_CONN: + return std::errc::not_connected; + + case ERR_IF: + return std::errc::network_down; + + case ERR_ABRT: + return std::errc::operation_canceled; + + case ERR_RST: + return std::errc::connection_reset; + + case ERR_CLSD: + return std::errc::connection_aborted; + + case ERR_ARG: + return std::errc::invalid_argument; + + } + + return std::errc{}; +} + +void Control::set_nagle() { + if (_pcb) { + tcp_nagle_enable(_pcb); + } +} + +void Control::set_nodelay() { + if (_pcb) { + tcp_nagle_disable(_pcb); + } +} + +err_t Control::connect(Address address, tcp_connected_fn handler) { + if (_pcb) { + return ERR_VAL; + } + + if (_pcb->state == ESTABLISHED) { + return ERR_ALREADY; + } + + return tcp_connect(_pcb, &address.ip, address.port, handler); +} + +err_t Control::attach(Address address, ServerHandler handler) { + if (_pcb) { + return ERR_VAL; + } + + auto* pcb = tcp_new(); + if (!pcb) { + return ERR_MEM; + } + + pcb->so_options |= SOF_REUSEADDR; + pcb->prio = TCP_PRIO_MIN; + + const auto err = tcp_bind(pcb, &address.ip, address.port); + if (err != ERR_OK) { + return err; + } + + auto* listen = tcp_listen(pcb); + if (!listen) { + tcp_abort(pcb); + return ERR_VAL; + } + + tcp_arg(pcb, handler.arg); + tcp_accept(pcb, handler.on_accept); + + _pcb = pcb; + + return ERR_OK; +} + + +err_t Control::attach(ClientHandler handler) { + tcp_arg(_pcb, handler.arg); + tcp_err(_pcb, handler.on_error); + tcp_poll(_pcb, handler.on_poll, 1); + tcp_recv(_pcb, handler.on_recv); + tcp_sent(_pcb, handler.on_sent); + + return ERR_OK; +} + +err_t Control::attach(tcp_pcb* other, ClientHandler handler) { + if (!_pcb) { + _pcb = other; + attach(handler); + return ERR_OK; + } + + return ERR_VAL; +} + +size_t Control::write_buffer() const { + if (_pcb) { + return tcp_sndbuf(_pcb); + } + + return 0; +} + +void Control::recv_consume(size_t size) { + if (_pcb) { + tcp_recved(_pcb, size); + } +} + +err_t Control::write(const uint8_t* data, size_t size, uint8_t flags) { + return tcp_write(_pcb, data, size, flags); +} + +err_t Control::try_send() { + return tcp_output(_pcb); +} + +void Control::detach() noexcept { + if (_pcb) { + tcp_arg(_pcb, nullptr); + tcp_recv(_pcb, nullptr); + tcp_err(_pcb, nullptr); + tcp_sent(_pcb, nullptr); + tcp_poll(_pcb, nullptr, 0); + } +} + +err_t Control::close() noexcept { + err_t err = ERR_OK; + if (_pcb) { + detach(); + err = tcp_close(_pcb); + if (err != ERR_OK) { + err = abort(); + } + + _pcb = nullptr; + } + + return err; +} + +err_t Control::abort(err_t err) { + tcp_abort(_pcb); + + _pcb = nullptr; + _last_err = err; + + return ERR_ABRT; +} + +Client::Client(Control&& control) noexcept : + _remote(remote_address(control.get())), + _control(std::move(control)) +{ + _control.attach( + Control::ClientHandler{ + .arg = this, + .on_error = s_on_tcp_error, + .on_poll = s_on_tcp_poll, + .on_recv = s_on_tcp_recv, + .on_sent = s_on_tcp_sent, + }); + _control.set_nagle(); +} + +Client::Client(tcp_pcb* pcb) noexcept : + Client(Control(pcb)) +{} + +Client::Client(Client&& other) noexcept : + _remote(other._remote), + _control(std::move(other._control)), + _packet(std::move(other._packet)), + _execution(std::move(other._execution)) +{} + +Client& Client::operator=(Client&& other) noexcept { + abort(); + + _control = std::move(other._control); + _execution = std::move(other._execution); + + _remote = other._remote; + + return *this; +} + +Address Client::local() const { + return local_address(_control); +} + +std::errc Client::connect(Address address) { + if (_control && !_execution.connection) { + return from_lwip_error( + _control.connect(address, s_on_tcp_connect)); + } + + if (_execution.connection) { + return std::errc::connection_already_in_progress; + } + + return std::errc::invalid_argument; +} + +err_t Client::on_connect(err_t err) { + if (!_execution.connection) { + return ERR_OK; + } + + decltype(_execution.connection) connection; + std::swap(_execution.connection, connection); + + if (err != ERR_OK) { + connection->set_error(from_lwip_error(err)); + } else { + connection->set_done(); + } + + return _control.error(); +} + +err_t Client::on_poll() { + try_readers(); + try_writers(); + return _control.error(); +} + +err_t Client::on_sent(uint16_t) { + try_writers(); + return _control.error(); +} + +void Client::stop_connection(CompletionPtr ptr) { + if (_execution.connection == ptr) { + decltype(_execution.connection) connection; + std::swap(_execution.connection, connection); + connection->set_error(std::errc::operation_canceled); + close(); + } +} + +bool ClientCancelation::set_timeout(Client& client, TimeSource::duration timeout) { + auto ptr = _ptr.lock(); + if (!ptr) { + return false; + } + + client._execution.timeouts.push_front( + Client::Timeout{ + .timeout = timeout, + .start = TimeSource::now(), + .type = _type, + .completion = ptr, + }); + + return true; +} + +bool ClientCancelation::try_completion_once(Client& client) { + switch (_type) { + case Type::Connection: + break; + case Type::Writing: + client.try_writers(); + break; + case Type::Reading: + client.try_readers(); + break; + } + + return _ptr.expired(); +} + +void ClientCancelation::try_completion(Client& client, TimeSource::duration interval) { + for (;;) { + if (try_completion_once(client)) { + return; + } + + time::delay(std::chrono::duration_cast(interval)); + } +} + +void ClientCancelation::wait_for(Client& client, TimeSource::duration timeout, TimeSource::duration interval) { + const auto result = set_timeout(client, timeout); + if (!result) { + return; + } + + try_completion(client, interval); +} + +void ClientCancelation::wait_for(Client& client, TimeSource::duration timeout) { + wait_for(client, timeout, duration::Milliseconds{ 10 }); +} + +void ClientCancelation::wait(Client& client) { + try_completion(client, duration::Milliseconds{ 10 }); +} + +void ClientCancelation::cancel(Client& client) { + auto ptr = _ptr.lock(); + if (!ptr) { + return; + } + + client.stop_completion(_type, ptr); +} + +size_t Client::available() const { + return _packet.size(); +} + +size_t Client::available_chunk() const { + return _packet.size_chunk(); +} + +bool Client::connected() const { + return _control.state() == ESTABLISHED; +} + +bool Client::closed() const { + switch (_control.state()) { + case tcp_state::CLOSED: + case tcp_state::CLOSE_WAIT: + return true; + + default: + break; + } + + return false; +} + +void Client::consume(size_t size) { + _packet.consume(size); + _control.recv_consume(size); +} + + +size_t Client::peek(Data out) { + return _packet.peek(out); +} + +std::vector Client::peek(size_t size) { + return _packet.peek(size); +} + +size_t Client::read(Data out) { + const auto received = peek(out); + if (received) { + consume(received); + } + + return 0; +} + +std::vector Client::read(size_t size) { + auto received = peek(size); + consume(received.size()); + return received; +} + +void Client::stop_completion(ClientCancelation::Type type, CompletionPtr ptr) { + switch (type) { + case ClientCancelation::Type::Connection: + stop_connection(ptr); + break; + + case ClientCancelation::Type::Reading: + stop_reader(ptr); + break; + + case ClientCancelation::Type::Writing: + stop_writer(ptr); + break; + } +} + +template +void try_timeouts_impl(Timeouts& timeouts, T&& handler) { + timeouts.remove_if( + [](const typename Timeouts::value_type& timeout) { + return timeout.completion.expired(); + }); + + if (timeouts.empty()) { + return; + } + + const auto now = TimeSource::now(); + for (auto& timeout : timeouts) { + if (now - timeout.start < timeout.timeout) { + continue; + } + + auto ptr = timeout.completion.lock(); + if (!ptr) { + continue; + } + + ptr->set_error(std::errc::timed_out); + handler(timeout, ptr); + } +} + +void Client::try_timeouts() { + try_timeouts_impl( + _execution.timeouts, + [&](Timeout& timeout, CompletionPtr ptr) { + stop_completion(timeout.type, ptr); + }); +} + +template +void Client::try_requests(Requests& requests, Handler&& handler) { + try_timeouts(); + for (;;) { + const auto it = requests.begin(); + if (it == requests.end()) { + return; + } + + switch (handler(*it)) { + // still a pending request, but cannot continue just this moment + case std::errc::operation_in_progress: + return; + + // (possibly) interrupted by external means, cannot use `it` + case std::errc::interrupted: + return; + + // everything else should discard the request + default: + break; + } + + requests.erase(it); + } +} + +template +void Client::stop_requests(Requests& requests, CompletionPtr ptr) { + auto it = std::find_if( + requests.begin(), + requests.end(), + [&](const Request& request) { + return request.ptr == ptr; + }); + + if (it != requests.end()) { + IoCompletionPtr ptr; + std::swap((*it).ptr, ptr); + requests.erase(it); + ptr->set_error(std::errc::operation_canceled); + } +} + +std::errc Client::read_some(ReadRequest& request) { + const size_t pending = _packet.size(); + if (!pending) { + return std::errc::operation_in_progress; + } + + auto completion = request.ptr; + + const auto chunk = std::min(request.size, pending); + if (request.at_most && chunk != 0) { + completion->notify_partial(pending); + return std::errc::operation_in_progress; + } + + auto err = std::errc::operation_in_progress; + if (chunk == request.size) { + err = std::errc{}; + } + + if (chunk > 0) { + completion->notify_total(chunk); + } + + if (completion->is_done()) { + err = std::errc::interrupted; + } + + return err; +} + +void Client::try_readers() { + try_requests(_execution.readers, + [&](ReadRequest& request) -> std::errc { + return read_some(request); + }); +} + +void Client::stop_reader(CompletionPtr ptr) { + stop_requests(_execution.readers, ptr); +} + +void Client::try_writers() { + try_requests(_execution.writers, + [&](WriteRequest& request) -> std::errc { + return write_some(request); + }); +} + +void Client::stop_writer(CompletionPtr ptr) { + stop_requests(_execution.writers, ptr); +} + +void Client::start_reader(IoCompletionPtr ptr, size_t size, bool at_most) { + _execution.readers.push_back(ReadRequest{ + .ptr = std::move(ptr), + .size = size, + .at_most = at_most, + }); +} + +std::errc Client::write_some(WriteRequest& request) { + for (;;) { + if (request.current == request.data.end()) { + return std::errc{}; + } + + if (!_control) { + return std::errc::not_connected; + } + + const size_t available = _control.write_buffer(); + auto& current = *request.current; + + auto size = std::min(current.size(), available); + if (!size) { + return std::errc::operation_in_progress; + } + + auto flags = TCP_WRITE_FLAG_COPY; + if ((request.data.end() != (request.current + 1)) + || (size != current.size())) + { + flags |= TCP_WRITE_FLAG_MORE; + } + + const auto* ptr = current.data(); + + if (pointerInFlash(ptr)) { + size = std::min(size, RecvBufferSize); + request.copy_buffer.resize(size); + memcpy_P( + request.copy_buffer.data(), ptr, + request.copy_buffer.size()); + ptr = request.copy_buffer.data(); + } + + auto completion = request.ptr; + + auto err = _control.write(ptr, size, flags); + if (err != ERR_OK) { + completion->set_error(from_lwip_error(err)); + return std::errc::interrupted; + } + + if (current.size() - size != 0) { + current = current.subspan(size); + } else { + std::advance(request.current, 1); + } + + err = _control.try_send(); + if (err != ERR_OK) { + completion->set_error(from_lwip_error(err)); + return std::errc::interrupted; + } + + completion->notify_progress(size); + if (completion->is_done()) { + return std::errc::interrupted; + } + } + + return std::errc{}; +} + +void Client::start_writer(IoCompletionPtr ptr, ConstDataSequence data) { + _execution.writers.push_back(WriteRequest{ + .ptr = std::move(ptr), + .data = std::move(data), + .current = data.begin(), + .start = TimeSource::now(), + }); +} + +std::shared_ptr make_basic_completion(BasicCompletion::Result&& result) { + return std::make_shared(std::move(result)); +} + +std::errc connect(Client& client, Address address) { + auto out = std::errc::operation_in_progress; + + auto cancelation = connect_async( + client, + address, + [&](std::errc err) { + out = err; + }); + cancelation.wait(client); + + return out; +} + +ClientCancelation connect_async(Client& client, Address address, BasicCompletion::Result&& result) { + ClientCancelation out; + if (client._execution.connection) { + result(std::errc::operation_in_progress); + return out; + } + + const auto err = client.connect(address); + if (err != std::errc{}) { + result(err); + return out; + } + + auto completion = make_basic_completion(std::move(result)); + client._execution.connection = completion; + + return ClientCancelation(ClientCancelation::Type::Connection, completion); +} + +err_t Client::close() { + complete(ERR_CLSD); + return _control.close(); +} + +void Client::complete(err_t err) { + const auto errc = from_lwip_error(err); + + decltype(_execution) execution; + std::swap(_execution, execution); + if (execution.connection) { + execution.connection->set_error(errc); + } + + for (auto& reader : execution.readers) { + reader.ptr->set_error(errc); + } + + for (auto& writer : execution.writers) { + writer.ptr->set_error(errc); + } +} + +err_t Client::abort(err_t err) { + detach(); + complete(err); + return _control.abort(err); +} + +void Client::detach() { + _control.detach(); +} + +err_t Client::on_recv(pbuf* pb, err_t err) { + // safe to assume on err, just abort + if (!pb || (err != ERR_OK)) { + return close(); + } + + _packet.append(pb); + try_readers(); + + return _control.error(); +} + +ClientCancelation write_async(Client& client, ConstDataSequence data, IoCompletion::Result&& result) { + ClientCancelation out; + if (!client.connected()) { + result(0, std::errc::not_connected); + return out; + } + + auto completion = make_io_completion( + data.size(), + std::move(result)); + client.start_writer(completion, std::move(data)); + + return ClientCancelation(ClientCancelation::Type::Writing, completion); +} + +WriteResult write(Client& client, ConstDataSequence data) { + WriteResult out; + out.size = 0; + out.err = std::errc::operation_in_progress; + + auto completion = write_async( + client, + std::move(data), + [&](size_t size, std::errc err) { + out.size = size; + out.err = err; + }); + + for (;;) { + if (out.err != std::errc::operation_in_progress) { + break; + } + + client.try_writers(); + time::delay(duration::Milliseconds{ 10 }); + } + + return out; +} + +ClientCancelation read_async(Client& client, size_t size, bool at_most, IoCompletion::Result&& result) { + ClientCancelation out; + if (!client.connected()) { + result(0, std::errc::not_connected); + return out; + } + + auto completion = make_io_completion(size, std::move(result)); + client.start_reader(completion, size, at_most); + + return ClientCancelation(ClientCancelation::Type::Reading, completion); +} + +ReadResult read(Client& client, size_t size) { + ReadResult out; + out.err = std::errc::operation_in_progress; + + size_t received = 0; + auto cancelation = read_async( + client, + size, + false, + [&](size_t size, std::errc err) { + out.err = err; + received = size; + }); + cancelation.wait(client); + + out.data = client.read(received); + + return out; +} + +Server::~Server() { + close(); +} + +Address Server::local() const { + return local_address(_control); +} + +bool Server::accept(AcceptCompletionPtr ptr, bool retry) { + if (_control) { + _accept.push_back( + Server::Accept{ + .completion = std::move(ptr), + .retry = retry, + }); + return true; + } + + return false; +} + +ServerCancelation accept(Server& server, bool retry, AcceptCompletion::Result&& result) { + auto completion = std::make_shared(std::move(result)); + if (!server.accept(completion, retry)) { + return ServerCancelation(); + } + + return ServerCancelation(completion); +} + +ServerCancelation accept_once(Server& server, AcceptCompletion::Result&& result) { + return accept(server, false, std::move(result)); +} + +ServerCancelation accept(Server& server, AcceptCompletion::Result&& result) { + return accept(server, true, std::move(result)); +} + +void AcceptCompletion::set_error(std::errc err) { + set_done(nullptr, err); +} + +void AcceptCompletion::set_done(ClientPtr ptr, std::errc err) { + if (!_done) { + _result(std::move(ptr), err); + _done = true; + } +} + +err_t Server::on_accept(tcp_pcb* pcb, err_t err) { + if (!_accept.size()) { + return ERR_VAL; + } + + const auto errc = from_lwip_error(err); + auto it = _accept.begin(); + + auto accept = (*it); + if (!pcb) { + accept.completion->set_error(errc); + } else if (err != ERR_OK) { + accept.completion->set_error(errc); + } else { + accept.completion->set_done( + std::make_unique(pcb), std::errc{}); + } + + const auto last = _control.error(); + if (last != ERR_OK) { + return err; + } + + if (!accept.retry) { + _accept.erase(it); + return ERR_OK; + } + + accept.completion->retry(); + + return ERR_OK; +} + +void Server::complete(err_t err) { + const auto errc = from_lwip_error(err); + + decltype(_accept) accept; + std::swap(_accept, accept); + for (auto& handler : accept) { + handler.completion->set_error(errc); + } +} + +err_t Server::close() { + complete(ERR_CLSD); + return _control.close(); +} + +err_t Server::listen(Address address) { + return _control.attach( + address, + Control::ServerHandler{ + .arg = this, + .on_accept = s_on_tcp_accept, + }); +} + +void Server::detach() { + _control.detach(); +} + +void Server::try_timeouts() { + try_timeouts_impl( + _timeouts, + [&](const Timeout&, CompletionPtr ptr) { + stop_completion(ptr); + }); +} + +void Server::stop_completion(CompletionPtr ptr) { + _accept.erase(std::remove_if( + _accept.begin(), + _accept.end(), + [ptr](const Accept& accept) { + return accept.completion == ptr; + })); +} + +void ServerCancelation::cancel(Server& server) { + auto ptr = _ptr.lock(); + if (!ptr) { + return; + } + + server.stop_completion(ptr); +} + +void ServerCancelation::wait_for(Server& server, TimeSource::duration timeout, TimeSource::duration interval) { + const auto result = poll( + timeout, interval, + [&]() { + return !_ptr.expired(); + }); + + if (result) { + return; + } + + const auto ptr = _ptr.lock(); + if (!ptr) { + return; + } + + ptr->set_error(std::errc::timed_out); + server.stop_completion(ptr); +} + +void ServerCancelation::wait_for(Server& server, TimeSource::duration timeout) { + wait_for(server, timeout, duration::Milliseconds{ 10 }); +} + +void ServerCancelation::wait(Server& server) { + for (;;) { + if (_ptr.expired()) { + return; + } + + time::delay(duration::Milliseconds{ 10 }); + } +} + +} // namespace tcp + namespace dns { namespace { @@ -198,6 +1307,60 @@ void netstat(::terminal::CommandContext&& ctx) { } } +PROGMEM_STRING(Ncat, "NCAT"); + +void ncat(::terminal::CommandContext&& ctx) { + if (ctx.argv.size() != 3) { + terminalError(ctx, F("NCAT ")); + return; + } + + const auto host = dns::gethostbyname(ctx.argv[1]); + if (!host.isSet()) { + terminalError(ctx, F("cannot resolve")); + return; + } + + const auto port = settings::internal::convert(ctx.argv[2]); + + auto* pcb = tcp_new(); + if (!pcb) { + terminalError(ctx, F("no pcb")); + return; + } + + const auto serialize_error = [&](std::errc err) { + String error; + error += "error "; + error += (int)err; + return error; + }; + + tcp::Client x(pcb); + const auto result = tcp::connect(x, Address{host, port}); + if (result != std::errc{}) { + terminalError(ctx, serialize_error(result)); + return; + } + + auto cancelation = tcp::read_async(x, 1024, true, + [&](size_t size, std::errc err) { + if (err != std::errc{}) { + terminalError(ctx, serialize_error(err)); + return; + } + + const auto data = x.read(size); + ctx.output.printf("notified %zu read %zu\n", + size, data.size()); + ctx.output.println( + hexEncode(data.data(), data.data() + data.size())); + }); + cancelation.wait_for(x, duration::Seconds{ 5 }); + + terminalOK(ctx); +} + #if SECURE_CLIENT == SECURE_CLIENT_BEARSSL PROGMEM_STRING(MflnProbe, "MFLN.PROBE"); @@ -227,6 +1390,7 @@ void mfln_probe(::terminal::CommandContext&& ctx) { static constexpr ::terminal::Command List[] PROGMEM { {Host, host}, {Netstat, netstat}, + {Ncat, ncat}, #if SECURE_CLIENT == SECURE_CLIENT_BEARSSL {MflnProbe, mfln_probe}, #endif diff --git a/code/espurna/network.h b/code/espurna/network.h index 003719fb..708a541a 100644 --- a/code/espurna/network.h +++ b/code/espurna/network.h @@ -12,14 +12,623 @@ Copyright (C) 2022 by Maxim Prokhorov #include #include +#include +#include +#include +#include +#include +#include #include +#include #include +#include "system.h" #include "types.h" namespace espurna { namespace network { + +using TimeSource = time::SystemClock; + +struct Address { + ip_addr_t ip; + uint16_t port; +}; + +using Data = Span; +using ConstData = Span; + +struct ConstDataSequence { + using value_type = ConstData; + using container_type = std::vector; + + using iterator = container_type::iterator; + using const_iterator = container_type::const_iterator; + + size_t size() const; + + iterator begin() { + return data.begin(); + } + + const_iterator begin() const { + return data.begin(); + } + + iterator end() { + return data.end(); + } + + const_iterator end() const { + return data.end(); + } + + container_type data; +}; + +struct Packet { + static constexpr auto RecvMax = size_t{ std::numeric_limits::max() }; + + Packet() = default; + ~Packet(); + + Packet(const Packet&) = delete; + Packet& operator=(const Packet&) = delete; + + Packet(Packet&&) noexcept; + Packet& operator=(Packet&&) noexcept; + + void append(pbuf*); + void consume(size_t); + + size_t size_chunk() const; + ConstData peek_chunk(size_t); + + size_t size() const; + + size_t peek(Data); + std::vector peek(size_t); + + size_t read(Data); + std::vector read(size_t); + +private: + pbuf* _pbuf { nullptr }; +}; + +namespace tcp { + +class Client; + +struct Completion { + virtual ~Completion() = default; + + virtual bool is_done() const = 0; + virtual void set_done() = 0; + virtual void set_error(std::errc) = 0; +}; + +using CompletionPtr = std::shared_ptr; + +struct IoCompletion : public Completion { + using Result = std::function; + + IoCompletion() = delete; + + template + IoCompletion(size_t size, T&& result) : + _result(std::forward(result)), + _total(size) + {} + + bool is_done() const override { + return _done; + } + + void set_done() override { + set_done(std::errc{}); + } + + void set_error(std::errc err) override { + set_done(err); + } + + void notify_progress(size_t size) { + if (_done) { + return; + } + + _size += size; + if (_size == _total) { + set_done(std::errc{}); + } + } + + void notify_total(size_t size) { + if (_done) { + return; + } + + _size = size; + if (_size == _total) { + set_done(std::errc{}); + } + } + + void notify_partial(size_t size) { + if (_done) { + return; + } + + _size = size; + set_done(std::errc{}); + retry(size); + } + + size_t size() const { + return _size; + } + + size_t total() const { + return _total; + } + + void retry(size_t total) { + _done = false; + _size = 0; + _total = total; + } + +private: + void set_done(std::errc err) { + if (!_done) { + _result(_size, err); + _done = true; + } + } + + Result _result; + + size_t _size { 0 }; + size_t _total { 0 }; + + bool _done { false }; +}; + +using IoCompletionPtr = std::shared_ptr; + +struct BasicCompletion : public Completion { + using Result = std::function; + + template + BasicCompletion(T&& result) : + _result(std::forward(result)) + {} + + bool is_done() const override { + return _done; + } + + void set_done() override { + set_done(std::errc{}); + } + + void set_error(std::errc err) override { + set_done(err); + } + + void reset() { + _done = false; + } + +private: + void set_done(std::errc err) { + if (!_done) { + _result(err); + _done = true; + } + } + + Result _result; + bool _done { false }; +}; + +struct ClientCancelation { + enum class Type { + Connection, + Reading, + Writing, + }; + + ClientCancelation() = default; + ClientCancelation(Type type, CompletionPtr ptr) : + _type(type), + _ptr(ptr) + {} + + explicit operator bool() const { + return _ptr.expired(); + } + + bool set_timeout(Client&, TimeSource::duration); + + void wait_for(Client&, TimeSource::duration); + void wait_for(Client&, TimeSource::duration, TimeSource::duration); + + void wait(Client&); + + void cancel(Client&); + +private: + bool try_completion_once(Client&); + void try_completion(Client&, TimeSource::duration); + + Type _type; + std::weak_ptr _ptr; +}; + +[[nodiscard]] +ClientCancelation connect_async(Client&, Address, BasicCompletion::Result&&); +std::errc connect(Client&, Address); + +struct ReadResult { + std::vector data; + std::errc err { std::errc::invalid_argument }; +}; + +[[nodiscard]] +ClientCancelation read_async(Client&, size_t, bool, IoCompletion::Result&&); +ReadResult read(Client&, size_t); + +struct WriteResult { + size_t size { 0 }; + std::errc err { std::errc::invalid_argument }; +}; + +[[nodiscard]] +ClientCancelation write_async(Client&, ConstDataSequence, IoCompletion::Result&&); +WriteResult write(Client&, ConstDataSequence); + +class Control { +public: + struct ClientHandler { + void* arg; + tcp_err_fn on_error; + tcp_poll_fn on_poll; + tcp_recv_fn on_recv; + tcp_sent_fn on_sent; + }; + + struct ServerHandler { + void* arg; + tcp_accept_fn on_accept; + }; + + Control() = default; + explicit Control(tcp_pcb* pcb) : + _pcb(pcb) + {} + + explicit operator bool() const { + return _pcb != nullptr; + } + + Control(const Control&) = delete; + Control& operator=(const Control&) = delete; + + Control(Control&&) noexcept; + Control& operator=(Control&&) noexcept; + + void set_nodelay(); + void set_nagle(); + + err_t connect(Address, tcp_connected_fn); + + err_t attach(Address, ServerHandler); + + err_t attach(tcp_pcb*, ClientHandler); + err_t attach(ClientHandler); + + tcp_state state() const { + if (_pcb) { + return _pcb->state; + } + + return CLOSED; + } + + err_t error() const { + return _last_err; + } + + tcp_pcb* get() const { + return _pcb; + } + + void recv_consume(size_t); + + size_t write_buffer() const; + err_t write(const uint8_t*, size_t, uint8_t); + err_t try_send(); + + err_t close() noexcept; + void detach() noexcept; + + err_t abort(err_t); + err_t abort() { + return abort(ERR_ABRT); + } + +private: + err_t _last_err { ERR_OK }; + tcp_pcb* _pcb { nullptr }; +}; + +class Client { +public: + friend Completion; + friend IoCompletion; + + friend ClientCancelation; + + friend ClientCancelation connect_async(Client&, Address, BasicCompletion::Result&&); + friend std::errc connect(Client&, Address); + + friend ClientCancelation write_async(Client&, ConstDataSequence, IoCompletion::Result&&); + friend WriteResult write(Client&, ConstDataSequence); + + friend ClientCancelation read_async(Client&, size_t, bool, IoCompletion::Result&&); + friend ReadResult read(Client&, size_t); + + static constexpr auto RecvBufferSize = size_t{ 536 }; + + Client() = delete; + explicit Client(tcp_pcb*) noexcept; + explicit Client(Control&&) noexcept; + + // tcp_pcb can be held by only one client + Client(const Client&) = delete; + Client& operator=(const Client&) = delete; + + // but it is allowed to transfer ownership + Client(Client&&) noexcept; + Client& operator=(Client&&) noexcept; + + ~Client() { + close(); + } + + std::errc connect(Address); + + err_t close(); + + err_t error() const { + return _control.error(); + } + + Address local() const; + Address remote() const { + return _remote; + } + + size_t available() const; + size_t available_chunk() const; + + bool connected() const; + bool closed() const; + void consume(size_t); + + ConstData peek_chunk(size_t); + ConstData read_chunk(size_t); + + size_t peek(Data); + std::vector peek(size_t); + + size_t read(Data); + std::vector read(size_t); + + size_t write(ConstData); + size_t write(ConstDataSequence); + +private: + err_t abort(err_t); + err_t abort() { + return abort(ERR_ABRT); + } + + void detach(); + void complete(err_t); + + err_t on_connect(err_t); + static err_t s_on_tcp_connect(void* arg, tcp_pcb*, err_t err) { + return reinterpret_cast(arg)->on_connect(err); + } + + static void s_on_tcp_error(void* arg, err_t err) { + reinterpret_cast(arg)->abort(err); + } + + err_t on_poll(); + static err_t s_on_tcp_poll(void* arg, tcp_pcb* pcb) { + return reinterpret_cast(arg)->on_poll(); + } + + err_t on_sent(uint16_t); + static err_t s_on_tcp_sent(void* arg, tcp_pcb*, uint16_t len) { + return reinterpret_cast(arg)->on_sent(len); + } + + err_t on_recv(pbuf*, err_t); + static err_t s_on_tcp_recv(void* arg, tcp_pcb*, pbuf* pb, err_t err) { + return reinterpret_cast(arg)->on_recv(pb, err); + } + + Address _remote; + Control _control; + Packet _packet; + + struct Timeout { + TimeSource::duration timeout; + TimeSource::time_point start; + ClientCancelation::Type type; + std::weak_ptr completion; + }; + + struct ReadRequest { + IoCompletionPtr ptr; + size_t size; + bool at_most; + }; + + struct WriteRequest { + IoCompletionPtr ptr; + ConstDataSequence data; + ConstDataSequence::iterator current; + std::vector copy_buffer; + TimeSource::time_point start; + }; + + struct Execution { + std::forward_list timeouts; + CompletionPtr connection; + std::vector readers; + std::vector writers; + }; + + Execution _execution; + + void try_timeouts(); + void stop_completion(ClientCancelation::Type, CompletionPtr); + + void stop_connection(CompletionPtr); + + template + void try_requests(Requests&, Handler&&); + + template + void stop_requests(Requests&, CompletionPtr); + + void start_reader(IoCompletionPtr, size_t, bool); + void stop_reader(CompletionPtr); + + std::errc read_some(ReadRequest&); + void try_readers(); + + void start_writer(IoCompletionPtr, ConstDataSequence data); + void stop_writer(CompletionPtr); + + std::errc write_some(WriteRequest&); + void try_writers(); +}; + +struct AcceptCompletion : public Completion { + using ClientPtr = std::unique_ptr; + using Result = std::function; + + template + explicit AcceptCompletion(T&& result) noexcept : + _result(std::move(result)) + {} + + void set_done() override { + _done = true; + } + + void set_error(std::errc) override; + + void set_done(ClientPtr, std::errc); + void retry() { + _done = false; + } + +private: + bool _done { false }; + Result _result; +}; + +class Server; + +struct ServerCancelation { + ServerCancelation() = default; + + explicit ServerCancelation(CompletionPtr ptr) : + _ptr(ptr) + {} + + explicit operator bool() const { + return _ptr.expired(); + } + + // server cancelations are always polling in current context, + // `set_timeout` not yet implemented to allow async wait + + void wait_for(Server&, TimeSource::duration); + void wait_for(Server&, TimeSource::duration, TimeSource::duration); + + void wait(Server&); + + void cancel(Server&); + +private: + std::weak_ptr _ptr; +}; + +ServerCancelation accept(Server&, bool, AcceptCompletion::Result&&); + +ServerCancelation accept_once(Server&, AcceptCompletion::Result&&); +ServerCancelation accept(Server&, AcceptCompletion::Result&&); + +class Server { +public: + friend ServerCancelation; + friend ServerCancelation accept(Server&, bool, AcceptCompletion::Result&&); + + Server() = default; + ~Server(); + + Address local() const; + + err_t close(); + err_t listen(Address); + +private: + using AcceptCompletionPtr = std::shared_ptr; + + err_t on_accept(tcp_pcb*, err_t); + static err_t s_on_tcp_accept(void* arg, tcp_pcb* pcb, err_t err) { + return reinterpret_cast(arg)->on_accept(pcb, err); + } + + void detach(); + + bool accept(AcceptCompletionPtr, bool retry); + + void complete(err_t); + void stop_completion(CompletionPtr); + + void try_timeouts(); + + struct Accept { + AcceptCompletionPtr completion; + bool retry; + }; + + std::vector _accept; + + struct Timeout { + TimeSource::duration timeout; + TimeSource::time_point start; + std::weak_ptr completion; + }; + + std::forward_list _timeouts; + + Control _control; +}; + +} // namespace tcp + namespace dns { struct Host { diff --git a/code/espurna/types.h b/code/espurna/types.h index 6922ab17..232eaba3 100644 --- a/code/espurna/types.h +++ b/code/espurna/types.h @@ -238,6 +238,14 @@ private: bool& _handle; }; +// common comparison would use >=0x40000000 +// instead, slightly reduce the footprint by +// checking *only* for numbers below it +inline bool pointerInFlash(const void* ptr) { + static constexpr uintptr_t Mask { 1 << 30 }; + return (reinterpret_cast(ptr) & Mask) > 0; +} + struct StringView { constexpr StringView() noexcept : _ptr(nullptr), @@ -347,11 +355,7 @@ private: } #else static bool inFlash(const char* ptr) { - // common comparison would use >=0x40000000 - // instead, slightly reduce the footprint by - // checking *only* for numbers below it - static constexpr uintptr_t Mask { 1 << 30 }; - return (reinterpret_cast(ptr) & Mask) > 0; + return pointerInFlash(ptr); } #endif @@ -409,4 +413,284 @@ inline String operator+(StringView lhs, const String& rhs) { #define STRING_VIEW_SETTING(X)\ ((__builtin_strlen(X) > 0) ? STRING_VIEW(X) : StringView()) +// ref. https://en.cppreference.com/w/cpp/types/type_identity + +template +struct TypeIdentity { + using type = T; +}; + +// ref. +// - https://www.open-std.org/jtc1/sc22/wg21/docs/papers/2018/p0122r7.html +// - https://www.open-std.org/jtc1/sc22/wg21/docs/papers/2020/p1976r2.html +// - https://github.com/microsoft/STL/issues/4 +// - https://github.com/microsoft/GSL/blob/main/include/gsl/span + +template +struct SpanIterator { +#if __cplusplus > 201103L +#define SPAN_ITERATOR_CONSTEXPR constexpr +#else +#define SPAN_ITERATOR_CONSTEXPR +#endif + using iterator_category = std::random_access_iterator_tag; + + using difference_type = std::ptrdiff_t; + using pointer = T*; + using reference = T&; + using value_type = typename std::remove_cv::type; + + SpanIterator() = delete; + constexpr SpanIterator(pointer begin, pointer end, pointer current) : + _begin(begin), + _end(end), + _current(current) + {} + + constexpr reference operator*() const noexcept { + return *_current; + } + + constexpr pointer operator->() const noexcept { + return _current; + } + + constexpr SpanIterator& operator++() noexcept { + ++_current; + return *this; + } + + constexpr SpanIterator operator++(int) noexcept { + auto& self = *this; + SpanIterator tmp{self}; + ++self; + return self; + } + + constexpr SpanIterator& operator--() noexcept { + --_current; + return *this; + } + + constexpr SpanIterator operator--(int) noexcept { + auto& self = *this; + SpanIterator tmp{self}; + --self; + return self; + } + + constexpr SpanIterator& operator+=(const difference_type offset) noexcept { + _current += offset; + return *this; + } + + constexpr SpanIterator operator+(const difference_type offset) noexcept { + SpanIterator out{*this}; + out += offset; + return out; + } + + constexpr SpanIterator& operator-=(const difference_type offset) noexcept { + _current -= offset; + return *this; + } + + constexpr SpanIterator operator-(const difference_type offset) noexcept { + SpanIterator out{*this}; + out -= offset; + return out; + } + + constexpr difference_type operator-(const SpanIterator& other) const noexcept { + return _current - other._current; + } + + constexpr reference operator[](const difference_type offset) const noexcept { + return *(*this + offset); + } + + constexpr bool operator==(const SpanIterator& other) const noexcept { + return _current == other._current; + } + + constexpr bool operator!=(const SpanIterator& other) const noexcept { + return _current != other._current; + } + + constexpr bool operator<(const SpanIterator& other) const noexcept { + return _current < other._current; + } + + constexpr bool operator>(const SpanIterator& other) const noexcept { + return _current > other._current; + } + + constexpr bool operator<=(const SpanIterator& other) const noexcept { + return _current <= other._current; + } + + constexpr bool operator>=(const SpanIterator& other) const noexcept { + return _current <= other._current; + } + +private: + pointer _begin; + pointer _end; + pointer _current; + +#undef SPAN_ITERATOR_CONSTEXPR +}; + +// storage helper. either store size in type info, or as a member +// using the same magic trick as most implementations +// - limits::max() holds member inside of the struct +// - everything else is encoded in type + +auto constexpr SpanDynamicExtent = std::numeric_limits::max(); + +template +struct __SpanBase { + constexpr __SpanBase() noexcept = default; + constexpr explicit __SpanBase(size_t) noexcept { + } + + constexpr size_t size() const noexcept { + return Size; + } +}; + +template <> +struct __SpanBase { + constexpr __SpanBase() noexcept = default; + constexpr explicit __SpanBase(size_t size) noexcept : + _size(size) + {} + + constexpr size_t size() const noexcept { + return _size; + } +private: + size_t _size{}; +}; + +template <> +struct __SpanBase<0> { + constexpr __SpanBase() = delete; + constexpr explicit __SpanBase(size_t) noexcept = delete; +}; + +template +struct Span : public __SpanBase { + using element_type = T; + using value_type = typename std::remove_cv::type; + using size_type = size_t; + using difference_type = std::ptrdiff_t; + using pointer = T*; + using const_pointer = const T*; + using reference = T&; + using const_reference = const T&; + using iterator = SpanIterator; + + static constexpr size_t extent = Extent; + + constexpr Span() = default; + constexpr Span(const Span&) = default; + constexpr Span& operator=(const Span&) = default; + + constexpr Span(Span&&) = default; + constexpr Span& operator=(Span&&) = default; + + constexpr explicit Span(pointer data) noexcept : + __SpanBase{}, + _data(data) + {} + + constexpr Span(pointer data, size_t size) noexcept : + __SpanBase{size}, + _data(data) + {} + + constexpr Span(pointer first, pointer last) noexcept : + __SpanBase{last - first}, + _data(first) + {} + + template + constexpr Span(typename TypeIdentity::type (&data)[Size]) noexcept : + Span(&data[0], Size) + {} + + template + constexpr Span(typename std::array& data) noexcept : + __SpanBase{}, + _data(data.data()) + {} + + template + constexpr Span(const typename std::array& data) noexcept : + __SpanBase{}, + _data(data.data()) + {} + + constexpr reference operator[](size_t index) const { + return _data[index]; + } + + constexpr pointer data() const noexcept { + return _data; + } + + constexpr iterator begin() const noexcept { + return {_data, _data + size(), &_data[0]}; + } + + constexpr iterator end() const noexcept { + return {_data, _data + size(), &_data[size()]}; + } + + constexpr size_type size() const { + return __SpanBase::size(); + } + + constexpr Span subspan(size_type offset) const { + return {data() + offset, size() - offset}; + } + + constexpr reference front() const noexcept { + return _data[0]; + } + + constexpr reference back() const noexcept { + return _data[size() - 1]; + } + +private: + T* _data; +}; + +template +inline constexpr Span make_span(typename TypeIdentity::type (&data)[Size]) { + return Span(data); +} + +template +inline constexpr Span make_span(typename std::array& data) { + return Span(data); +} + +template +inline constexpr Span make_span(const typename std::array& data) { + return Span(data); +} + +template +inline constexpr Span make_span(std::vector& data) { + return Span(data.data(), data.size()); +} + +template +inline constexpr Span make_span(const std::vector& data) { + return Span(data.data(), data.size()); +} + } // namespace espurna