From 175c04acc1afeacdd17f7a7721ebaa21ce5ef347 Mon Sep 17 00:00:00 2001 From: Maxim Prokhorov Date: Wed, 15 Mar 2023 17:01:54 +0300 Subject: [PATCH] system: task'ify every pending dns request allow more than one request, until lwip api rejects the call --- code/espurna/network.cpp | 203 +++++++++++++++++++++------------------ code/espurna/network.h | 33 ++++++- code/espurna/telnet.cpp | 27 +++--- 3 files changed, 157 insertions(+), 106 deletions(-) diff --git a/code/espurna/network.cpp b/code/espurna/network.cpp index 16f05f48..aa16cbc0 100644 --- a/code/espurna/network.cpp +++ b/code/espurna/network.cpp @@ -8,9 +8,10 @@ Copyright (C) 2022 by Maxim Prokhorov #include "espurna.h" +#include #include -#include #include +#include #include #include @@ -26,79 +27,133 @@ extern "C" struct tcp_pcb *tcp_tw_pcbs; namespace espurna { namespace network { +namespace dns { namespace { -namespace dns { +struct PendingHost { + HostPtr ptr; + HostCallback callback; +}; + namespace internal { -struct Task { - Task() = delete; - explicit Task(String hostname, IpFoundCallback callback) : - _hostname(std::move(hostname)), - _callback(std::move(callback)) - {} +using Pending = std::forward_list; +Pending pending; - IPAddress addr() const { - return _addr; - } +} // namespace internal - const String& hostname() const { - return _hostname; - } +void dns_found_callback_impl(const char* name, const ip_addr_t* addr, void* arg) { + auto* pending = reinterpret_cast(arg); - void found_callback(const char* name, const ip_addr_t* addr, void*) { - _callback(name, addr); + if (addr) { + pending->addr = addr; + pending->err = ERR_OK; + } else { + pending->err = ERR_ABRT; } - void found_callback() { - _callback(_hostname, _addr); - } + internal::pending.remove_if( + [&](const PendingHost& lhs) { + if (lhs.ptr.get() == pending) { + if (lhs.callback) { + lhs.callback(lhs.ptr); + } -private: - IPAddress _addr { IPADDR_NONE }; - String _hostname; + return true; + } - IpFoundCallback _callback; -}; + return false; + }); +} -using TaskPtr = std::unique_ptr; -TaskPtr task; +HostPtr resolve_impl(String hostname, HostCallback callback) { + auto host = std::make_shared( + Host{ + .name = std::move(hostname), + .addr = IPAddress{}, + .err = ERR_INPROGRESS, + }); + + const auto err = dns_gethostbyname( + host->name.c_str(), + host->addr, + dns_found_callback_impl, + host.get()); + + host->err = err; + + switch (err) { + case ERR_OK: + case ERR_MEM: + if (callback) { + callback(host); + } + break; -void found_callback(const char* name, const ip_addr_t* addr, void* arg) { - if (task) { - task->found_callback(name, addr, arg); - task.reset(); + case ERR_INPROGRESS: + internal::pending.push_front( + PendingHost{ + .ptr = host, + .callback = std::move(callback) + }); + break; } + + return host; } -} // namespace internal +} // namespace -bool started() { - return static_cast(internal::task); +HostPtr resolve(String hostname) { + return resolve_impl(hostname, nullptr); } -void start(String hostname, IpFoundCallback callback) { - auto task = std::make_unique( - std::move(hostname), std::move(callback)); +void resolve(String hostname, HostCallback callback) { + if (!callback) { + return; + } - const auto result = dns_gethostbyname( - task->hostname().c_str(), task->addr(), - internal::found_callback, nullptr); + resolve_impl(std::move(hostname), std::move(callback)); +} - switch (result) { - // No need to wait, return result immediately - case ERR_OK: - task->found_callback(); - break; - // Task needs to linger for a bit - case ERR_INPROGRESS: - internal::task = std::move(task); - break; +bool wait_for(HostPtr ptr, duration::Milliseconds timeout) { + if (ptr->err == ERR_OK) { + return true; } + + if (ptr->err != ERR_INPROGRESS) { + return false; + } + + time::blockingDelay( + timeout, + duration::Milliseconds{ 10 }, + [&]() { + return ptr->err == ERR_INPROGRESS; + }); + + return ptr->err == ERR_OK; +} + +IPAddress gethostbyname(String hostname, duration::Milliseconds timeout) { + IPAddress out; + + auto result = resolve(hostname); + if (wait_for(result, timeout)) { + out = result->addr; + } + + return out; +} + +IPAddress gethostbyname(String hostname) { + return gethostbyname(hostname, duration::Seconds{ 3 }); } } // namespace dns +namespace { + #if TERMINAL_SUPPORT namespace terminal { namespace commands { @@ -111,20 +166,15 @@ void host(::terminal::CommandContext&& ctx) { return; } - dns::start(std::move(ctx.argv[1]), - [&](const String& name, IPAddress addr) { - if (!addr) { - ctx.output.printf_P(PSTR("%s not found\n"), name.c_str()); - return; - } - - ctx.output.printf_P(PSTR("%s has address %s\n"), - name.c_str(), addr.toString().c_str()); - }); - - while (dns::started()) { - delay(10); + const auto result = dns::gethostbyname(ctx.argv[1]); + if (result.isSet()) { + ctx.output.printf_P(PSTR("%s has address %s\n"), + ctx.argv[1].c_str(), result.toString().c_str()); + terminalOK(ctx); + return; } + + ctx.output.printf_P(PSTR("%s not found\n"), ctx.argv[1].c_str()); } PROGMEM_STRING(Netstat, "NETSTAT"); @@ -191,29 +241,6 @@ void setup() { } // namespace terminal #endif -void gethostbyname(String hostname, IpFoundCallback callback) { - dns::start(std::move(hostname), std::move(callback)); -} - -IPAddress gethostbyname(String hostname) { - IPAddress out; - - dns::start(std::move(hostname), - [&](const String& name, IPAddress ip) { - if (!ip.isSet()) { - return; - } - - out = ip; - }); - - while (dns::started()) { - delay(10); - } - - return out; -} - void setup() { #if TERMINAL_SUPPORT terminal::setup(); @@ -224,14 +251,6 @@ void setup() { } // namespace network } // namespace espurna -void networkGetHostByName(String hostname, espurna::network::IpFoundCallback callback) { - return espurna::network::gethostbyname(std::move(hostname), std::move(callback)); -} - -IPAddress networkGetHostByName(String hostname) { - return espurna::network::gethostbyname(std::move(hostname)); -} - void networkSetup() { espurna::network::setup(); } diff --git a/code/espurna/network.h b/code/espurna/network.h index 36e79e59..003719fb 100644 --- a/code/espurna/network.h +++ b/code/espurna/network.h @@ -11,14 +11,41 @@ Copyright (C) 2022 by Maxim Prokhorov #include #include +#include + +#include +#include + +#include "types.h" + namespace espurna { namespace network { +namespace dns { + +struct Host { + String name; + IPAddress addr; + err_t err; +}; + +using HostPtr = std::shared_ptr; +using HostCallback = std::function; + +// DNS request is lauched in the background, HostPtr should be waited upon +HostPtr resolve(String); + +// ...or, user callback is executed when DNS client is ready to return something +void resolve(String, HostCallback); + +// Block until the HostPtr becomes available for reading, or when timeout occurs +bool wait_for(HostPtr, duration::Milliseconds); -using IpFoundCallback = std::function; +// Arduino style result +IPAddress gethostbyname(String, duration::Milliseconds); +IPAddress gethostbyname(String); +} // namespace dns } // namespace network } // namespace espurna -IPAddress networkGetHostByName(String); -void networkGetHostByName(String, espurna::network::IpFoundCallback); void networkSetup(); diff --git a/code/espurna/telnet.cpp b/code/espurna/telnet.cpp index e8e90915..60a30ba2 100644 --- a/code/espurna/telnet.cpp +++ b/code/espurna/telnet.cpp @@ -934,7 +934,7 @@ void reverse(::terminal::CommandContext&& ctx) { return; } - const auto ip = networkGetHostByName(ctx.argv[1]); + const auto ip = network::dns::gethostbyname(ctx.argv[1]); if (!ip.isSet()) { terminalError(ctx, F("Host not found")); return; @@ -984,17 +984,22 @@ void connect_url(String url) { } const auto port = parsed.port; - networkGetHostByName(std::move(parsed.host), - [port](const String& host, IPAddress ip) { - const auto addr = Address{ - .ip = ip, - .port = port, - }; - - if (!connect(addr)) { - DEBUG_MSG_P(PSTR("[TELNET] Cannot connect to %s:%hu\n"), - host.c_str(), port); + network::dns::resolve( + std::move(parsed.host), + [port](network::dns::HostPtr host) { + if (host->err == ERR_OK) { + const auto addr = Address{ + .ip = host->addr, + .port = port, + }; + + if (connect(addr)) { + return; + } } + + DEBUG_MSG_P(PSTR("[TELNET] Cannot connect to %s:%hu\n"), + host->name.c_str(), port); }); }