From 4b63d1011fb9c3d758d0274f2d077619067c86a3 Mon Sep 17 00:00:00 2001 From: Patrick Balsiger Date: Thu, 28 Aug 2025 20:46:27 +0200 Subject: [PATCH] feat: improve task handling, refactoring --- include/ApiServer.h | 8 +-- include/NodeContext.h | 3 +- include/TaskManager.h | 13 +--- platformio.ini | 1 - src/ApiServer.cpp | 93 +++++++++------------------ src/NodeContext.cpp | 2 - src/NodeInfo.cpp | 2 +- src/TaskManager.cpp | 144 +++++++++++++----------------------------- 8 files changed, 83 insertions(+), 183 deletions(-) diff --git a/include/ApiServer.h b/include/ApiServer.h index 7bf88ae..4856c39 100644 --- a/include/ApiServer.h +++ b/include/ApiServer.h @@ -11,9 +11,6 @@ #include "NodeInfo.h" #include "TaskManager.h" -using namespace std; -using namespace std::placeholders; - class ApiServer { public: ApiServer(NodeContext& ctx, TaskManager& taskMgr, uint16_t port = 80); @@ -47,7 +44,7 @@ private: std::vector> serviceRegistry; std::vector capabilityRegistry; void onClusterMembersRequest(AsyncWebServerRequest *request); - void methodToStr(const std::tuple &endpoint, ArduinoJson::V742PB22::JsonObject &apiObj); + void methodToStr(const std::tuple &endpoint, JsonObject &apiObj); void onSystemStatusRequest(AsyncWebServerRequest *request); void onFirmwareUpdateRequest(AsyncWebServerRequest *request); void onFirmwareUpload(AsyncWebServerRequest *request, const String &filename, size_t index, uint8_t *data, size_t len, bool final); @@ -59,4 +56,7 @@ private: // Capabilities endpoint void onCapabilitiesRequest(AsyncWebServerRequest *request); + + // Internal helpers + void registerServiceForLocalNode(const String& uri, int method); }; diff --git a/include/NodeContext.h b/include/NodeContext.h index f5f4bef..85e67ac 100644 --- a/include/NodeContext.h +++ b/include/NodeContext.h @@ -1,5 +1,5 @@ #pragma once -#include + #include #include #include "NodeInfo.h" @@ -11,7 +11,6 @@ class NodeContext { public: NodeContext(); ~NodeContext(); - Scheduler* scheduler; WiFiUDP* udp; String hostname; IPAddress localIP; diff --git a/include/TaskManager.h b/include/TaskManager.h index c08136f..ffe1488 100644 --- a/include/TaskManager.h +++ b/include/TaskManager.h @@ -7,10 +7,6 @@ #include "NodeContext.h" #include -// Forward declarations to avoid multiple definition errors -class Task; -class Scheduler; - // Define our own callback type to avoid conflict with TaskScheduler using TaskFunction = std::function; @@ -60,13 +56,8 @@ public: private: NodeContext& ctx; - std::vector tasks; std::vector taskDefinitions; + std::vector lastExecutionTimes; - Task* findTask(const std::string& name) const; - void createTask(const TaskDefinition& taskDef); - - // Static callback registry for all TaskManager instances - static std::map> callbackRegistry; - static void executeCallback(const std::string& taskName); + int findTaskIndex(const std::string& name) const; }; \ No newline at end of file diff --git a/platformio.ini b/platformio.ini index aba648c..c764471 100644 --- a/platformio.ini +++ b/platformio.ini @@ -17,7 +17,6 @@ monitor_speed = 115200 lib_deps = esp32async/ESPAsyncWebServer@^3.8.0 bblanchon/ArduinoJson@^7.4.2 - arkhipenko/TaskScheduler@^3.8.5 [env:esp01_1m] platform = platformio/espressif8266@^4.2.1 diff --git a/src/ApiServer.cpp b/src/ApiServer.cpp index e7059f8..535cb9e 100644 --- a/src/ApiServer.cpp +++ b/src/ApiServer.cpp @@ -1,29 +1,38 @@ #include "ApiServer.h" #include +// Shared helper for HTTP method to string +static const char* methodStrFromInt(int method) { + switch (method) { + case HTTP_GET: return "GET"; + case HTTP_POST: return "POST"; + case HTTP_PUT: return "PUT"; + case HTTP_DELETE: return "DELETE"; + case HTTP_PATCH: return "PATCH"; + default: return "UNKNOWN"; + } +} + ApiServer::ApiServer(NodeContext& ctx, TaskManager& taskMgr, uint16_t port) : server(port), ctx(ctx), taskManager(taskMgr) {} -void ApiServer::addEndpoint(const String& uri, int method, std::function requestHandler) { +void ApiServer::registerServiceForLocalNode(const String& uri, int method) { serviceRegistry.push_back(std::make_tuple(uri, method)); - // Store in NodeInfo for local node if (ctx.memberList && !ctx.memberList->empty()) { auto it = ctx.memberList->find(ctx.hostname); if (it != ctx.memberList->end()) { it->second.apiEndpoints.push_back(std::make_tuple(uri, method)); } } +} + +void ApiServer::addEndpoint(const String& uri, int method, std::function requestHandler) { + registerServiceForLocalNode(uri, method); server.on(uri.c_str(), method, requestHandler); } void ApiServer::addEndpoint(const String& uri, int method, std::function requestHandler, std::function uploadHandler) { - serviceRegistry.push_back(std::make_tuple(uri, method)); - if (ctx.memberList && !ctx.memberList->empty()) { - auto it = ctx.memberList->find(ctx.hostname); - if (it != ctx.memberList->end()) { - it->second.apiEndpoints.push_back(std::make_tuple(uri, method)); - } - } + registerServiceForLocalNode(uri, method); server.on(uri.c_str(), method, requestHandler, uploadHandler); } @@ -31,13 +40,7 @@ void ApiServer::addEndpoint(const String& uri, int method, std::function requestHandler, const std::vector& params) { capabilityRegistry.push_back(EndpointCapability{uri, method, params}); - serviceRegistry.push_back(std::make_tuple(uri, method)); - if (ctx.memberList && !ctx.memberList->empty()) { - auto it = ctx.memberList->find(ctx.hostname); - if (it != ctx.memberList->end()) { - it->second.apiEndpoints.push_back(std::make_tuple(uri, method)); - } - } + registerServiceForLocalNode(uri, method); server.on(uri.c_str(), method, requestHandler); } @@ -45,13 +48,7 @@ void ApiServer::addEndpoint(const String& uri, int method, std::function uploadHandler, const std::vector& params) { capabilityRegistry.push_back(EndpointCapability{uri, method, params}); - serviceRegistry.push_back(std::make_tuple(uri, method)); - if (ctx.memberList && !ctx.memberList->empty()) { - auto it = ctx.memberList->find(ctx.hostname); - if (it != ctx.memberList->end()) { - it->second.apiEndpoints.push_back(std::make_tuple(uri, method)); - } - } + registerServiceForLocalNode(uri, method); server.on(uri.c_str(), method, requestHandler, uploadHandler); } @@ -131,37 +128,15 @@ void ApiServer::onClusterMembersRequest(AsyncWebServerRequest *request) { request->send(200, "application/json", json); } -void ApiServer::methodToStr(const std::tuple &endpoint, ArduinoJson::V742PB22::JsonObject &apiObj) +void ApiServer::methodToStr(const std::tuple &endpoint, JsonObject &apiObj) { int method = std::get<1>(endpoint); - const char *methodStr = nullptr; - switch (method) - { - case HTTP_GET: - methodStr = "GET"; - break; - case HTTP_POST: - methodStr = "POST"; - break; - case HTTP_PUT: - methodStr = "PUT"; - break; - case HTTP_DELETE: - methodStr = "DELETE"; - break; - case HTTP_PATCH: - methodStr = "PATCH"; - break; - default: - methodStr = "UNKNOWN"; - break; - } - apiObj["method"] = methodStr; + apiObj["method"] = methodStrFromInt(method); } void ApiServer::onFirmwareUpdateRequest(AsyncWebServerRequest *request) { - bool hasError = !Update.hasError(); - AsyncWebServerResponse *response = request->beginResponse(200, "application/json", hasError ? "{\"status\": \"OK\"}" : "{\"status\": \"FAIL\"}"); + bool success = !Update.hasError(); + AsyncWebServerResponse *response = request->beginResponse(200, "application/json", success ? "{\"status\": \"OK\"}" : "{\"status\": \"FAIL\"}"); response->addHeader("Connection", "close"); request->send(response); request->onDisconnect([]() { @@ -183,8 +158,6 @@ void ApiServer::onFirmwareUpload(AsyncWebServerRequest *request, const String &f response->addHeader("Connection", "close"); request->send(response); return; - } else { - Update.printError(Serial); } } if (!Update.hasError()){ @@ -221,11 +194,15 @@ void ApiServer::onRestartRequest(AsyncWebServerRequest *request) { } void ApiServer::onTaskStatusRequest(AsyncWebServerRequest *request) { - JsonDocument doc; + // Use a separate document as scratch space for task statuses to avoid interfering with the response root + JsonDocument scratch; // Get comprehensive task status from TaskManager - auto taskStatuses = taskManager.getAllTaskStatuses(doc); + auto taskStatuses = taskManager.getAllTaskStatuses(scratch); + // Build response document + JsonDocument doc; + // Add summary information JsonObject summaryObj = doc["summary"].to(); summaryObj["totalTasks"] = taskStatuses.size(); @@ -344,16 +321,6 @@ void ApiServer::onCapabilitiesRequest(AsyncWebServerRequest *request) { auto makeKey = [](const String& uri, int method) { String k = uri; k += "|"; k += method; return k; }; - auto methodStrFromInt = [](int method) -> const char* { - switch (method) { - case HTTP_GET: return "GET"; - case HTTP_POST: return "POST"; - case HTTP_PUT: return "PUT"; - case HTTP_DELETE: return "DELETE"; - case HTTP_PATCH: return "PATCH"; - default: return "UNKNOWN"; - } - }; // Rich entries first for (const auto& cap : capabilityRegistry) { diff --git a/src/NodeContext.cpp b/src/NodeContext.cpp index 4d2f14b..77d4a3c 100644 --- a/src/NodeContext.cpp +++ b/src/NodeContext.cpp @@ -1,14 +1,12 @@ #include "NodeContext.h" NodeContext::NodeContext() { - scheduler = new Scheduler(); udp = new WiFiUDP(); memberList = new std::map(); hostname = ""; } NodeContext::~NodeContext() { - delete scheduler; delete udp; delete memberList; } diff --git a/src/NodeInfo.cpp b/src/NodeInfo.cpp index 3bfaab5..2aa8f88 100644 --- a/src/NodeInfo.cpp +++ b/src/NodeInfo.cpp @@ -14,7 +14,7 @@ void updateNodeStatus(NodeInfo &node, unsigned long now, unsigned long inactive_ if (diff < inactive_threshold) { node.status = NodeInfo::ACTIVE; } else if (diff < dead_threshold) { - node.status = NodeInfo::DEAD; + node.status = NodeInfo::INACTIVE; } else { node.status = NodeInfo::DEAD; } diff --git a/src/TaskManager.cpp b/src/TaskManager.cpp index fc5d64f..c211fd5 100644 --- a/src/TaskManager.cpp +++ b/src/TaskManager.cpp @@ -1,18 +1,9 @@ #include "TaskManager.h" #include -#include - -// Define static members -std::map> TaskManager::callbackRegistry; TaskManager::TaskManager(NodeContext& ctx) : ctx(ctx) {} TaskManager::~TaskManager() { - // Clean up tasks - for (auto task : tasks) { - delete task; - } - tasks.clear(); } void TaskManager::registerTask(const std::string& name, unsigned long interval, TaskFunction callback, bool enabled, bool autoStart) { @@ -25,44 +16,21 @@ void TaskManager::registerTask(const TaskDefinition& taskDef) { } void TaskManager::initialize() { - // Initialize the scheduler - ctx.scheduler->init(); - - // Create all registered tasks - for (const auto& taskDef : taskDefinitions) { - createTask(taskDef); - } + // Ensure timing vector matches number of tasks + lastExecutionTimes.assign(taskDefinitions.size(), 0UL); // Enable tasks that should auto-start - for (const auto& taskDef : taskDefinitions) { + for (auto& taskDef : taskDefinitions) { if (taskDef.autoStart && taskDef.enabled) { - enableTask(taskDef.name); + taskDef.enabled = true; } } } -void TaskManager::createTask(const TaskDefinition& taskDef) { - // Store the callback in the static registry - callbackRegistry[taskDef.name] = taskDef.callback; - - // Create a dummy task - we'll handle execution ourselves - Task* task = new Task(0, TASK_FOREVER, []() { /* Dummy callback */ }); - task->setInterval(taskDef.interval); - - // Add to scheduler - ctx.scheduler->addTask(*task); - - // Store task pointer - tasks.push_back(task); - - Serial.printf("[TaskManager] Created task: %s (interval: %lu ms)\n", - taskDef.name.c_str(), taskDef.interval); -} - void TaskManager::enableTask(const std::string& name) { - Task* task = findTask(name); - if (task) { - task->enable(); + int idx = findTaskIndex(name); + if (idx >= 0) { + taskDefinitions[idx].enabled = true; Serial.printf("[TaskManager] Enabled task: %s\n", name.c_str()); } else { Serial.printf("[TaskManager] Warning: Task not found: %s\n", name.c_str()); @@ -70,9 +38,9 @@ void TaskManager::enableTask(const std::string& name) { } void TaskManager::disableTask(const std::string& name) { - Task* task = findTask(name); - if (task) { - task->disable(); + int idx = findTaskIndex(name); + if (idx >= 0) { + taskDefinitions[idx].enabled = false; Serial.printf("[TaskManager] Disabled task: %s\n", name.c_str()); } else { Serial.printf("[TaskManager] Warning: Task not found: %s\n", name.c_str()); @@ -80,9 +48,9 @@ void TaskManager::disableTask(const std::string& name) { } void TaskManager::setTaskInterval(const std::string& name, unsigned long interval) { - Task* task = findTask(name); - if (task) { - task->setInterval(interval); + int idx = findTaskIndex(name); + if (idx >= 0) { + taskDefinitions[idx].interval = interval; Serial.printf("[TaskManager] Set interval for task %s: %lu ms\n", name.c_str(), interval); } else { Serial.printf("[TaskManager] Warning: Task not found: %s\n", name.c_str()); @@ -90,50 +58,38 @@ void TaskManager::setTaskInterval(const std::string& name, unsigned long interva } void TaskManager::startTask(const std::string& name) { - Task* task = findTask(name); - if (task) { - task->enable(); - Serial.printf("[TaskManager] Started task: %s\n", name.c_str()); - } else { - Serial.printf("[TaskManager] Warning: Task not found: %s\n", name.c_str()); - } + enableTask(name); } void TaskManager::stopTask(const std::string& name) { - Task* task = findTask(name); - if (task) { - task->disable(); - Serial.printf("[TaskManager] Stopped task: %s\n", name.c_str()); - } else { - Serial.printf("[TaskManager] Warning: Task not found: %s\n", name.c_str()); - } + disableTask(name); } bool TaskManager::isTaskEnabled(const std::string& name) const { - Task* task = findTask(name); - return task ? task->isEnabled() : false; + int idx = findTaskIndex(name); + return idx >= 0 ? taskDefinitions[idx].enabled : false; } bool TaskManager::isTaskRunning(const std::string& name) const { - Task* task = findTask(name); - return task ? task->isEnabled() : false; + int idx = findTaskIndex(name); + return idx >= 0 ? taskDefinitions[idx].enabled : false; } unsigned long TaskManager::getTaskInterval(const std::string& name) const { - Task* task = findTask(name); - return task ? task->getInterval() : 0; + int idx = findTaskIndex(name); + return idx >= 0 ? taskDefinitions[idx].interval : 0; } void TaskManager::enableAllTasks() { - for (auto task : tasks) { - task->enable(); + for (auto& taskDef : taskDefinitions) { + taskDef.enabled = true; } Serial.println("[TaskManager] Enabled all tasks"); } void TaskManager::disableAllTasks() { - for (auto task : tasks) { - task->disable(); + for (auto& taskDef : taskDefinitions) { + taskDef.enabled = false; } Serial.println("[TaskManager] Disabled all tasks"); } @@ -142,41 +98,32 @@ void TaskManager::printTaskStatus() const { Serial.println("\n[TaskManager] Task Status:"); Serial.println("=========================="); - for (size_t i = 0; i < tasks.size() && i < taskDefinitions.size(); ++i) { - const auto& taskDef = taskDefinitions[i]; - const auto& task = tasks[i]; - + for (const auto& taskDef : taskDefinitions) { Serial.printf(" %s: %s (interval: %lu ms)\n", taskDef.name.c_str(), - task->isEnabled() ? "ENABLED" : "DISABLED", - task->getInterval()); + taskDef.enabled ? "ENABLED" : "DISABLED", + taskDef.interval); } Serial.println("==========================\n"); } void TaskManager::execute() { - // Execute all enabled tasks by calling their stored callbacks - static unsigned long lastExecutionTimes[100] = {0}; // Simple array for timing - static int taskCount = 0; - - if (taskCount == 0) { - taskCount = tasks.size(); + // Ensure timing vector matches number of tasks + if (lastExecutionTimes.size() != taskDefinitions.size()) { + lastExecutionTimes.assign(taskDefinitions.size(), 0UL); } unsigned long currentTime = millis(); - for (size_t i = 0; i < tasks.size() && i < taskDefinitions.size(); ++i) { - Task* task = tasks[i]; - const std::string& taskName = taskDefinitions[i].name; + for (size_t i = 0; i < taskDefinitions.size(); ++i) { + auto& taskDef = taskDefinitions[i]; - if (task->isEnabled()) { + if (taskDef.enabled) { // Check if it's time to run this task - if (currentTime - lastExecutionTimes[i] >= task->getInterval()) { - // Execute the stored callback - if (callbackRegistry.find(taskName) != callbackRegistry.end()) { - callbackRegistry[taskName](); + if (currentTime - lastExecutionTimes[i] >= taskDef.interval) { + if (taskDef.callback) { + taskDef.callback(); } - // Update the last execution time lastExecutionTimes[i] = currentTime; } @@ -184,27 +131,26 @@ void TaskManager::execute() { } } -Task* TaskManager::findTask(const std::string& name) const { - for (size_t i = 0; i < tasks.size() && i < taskDefinitions.size(); ++i) { +int TaskManager::findTaskIndex(const std::string& name) const { + for (size_t i = 0; i < taskDefinitions.size(); ++i) { if (taskDefinitions[i].name == name) { - return tasks[i]; + return static_cast(i); } } - return nullptr; + return -1; } std::vector> TaskManager::getAllTaskStatuses(JsonDocument& doc) const { std::vector> taskStatuses; - for (size_t i = 0; i < tasks.size() && i < taskDefinitions.size(); ++i) { + for (size_t i = 0; i < taskDefinitions.size(); ++i) { const auto& taskDef = taskDefinitions[i]; - const auto& task = tasks[i]; JsonObject taskStatus = doc.add(); taskStatus["name"] = taskDef.name; - taskStatus["interval"] = task->getInterval(); - taskStatus["enabled"] = task->isEnabled(); - taskStatus["running"] = task->isEnabled(); // For now, enabled = running + taskStatus["interval"] = taskDef.interval; + taskStatus["enabled"] = taskDef.enabled; + taskStatus["running"] = taskDef.enabled; // For now, enabled = running taskStatus["autoStart"] = taskDef.autoStart; taskStatuses.push_back(std::make_pair(taskDef.name, taskStatus));