diff --git a/cpp_utils/GeneralUtils.cpp b/cpp_utils/GeneralUtils.cpp index 019c81bd..2f31900b 100644 --- a/cpp_utils/GeneralUtils.cpp +++ b/cpp_utils/GeneralUtils.cpp @@ -126,7 +126,7 @@ void GeneralUtils::dumpInfo() { /** * @brief Does the string end with a specific character? * @param [in] str The string to examine. - * @param [in] c The character to look form. + * @param [in] c The character to look for. * @return True if the string ends with the given character. */ bool GeneralUtils::endsWith(std::string str, char c) { @@ -139,6 +139,19 @@ bool GeneralUtils::endsWith(std::string str, char c) { return false; } // endsWidth +/** + * @brief Does the string end with a specific string? + * @param [in] str The string to examine. + * @param [in] suffix the string to look for. + * @return True if the string ends with the given suffix. + */ +bool GeneralUtils::endsWith(std::string str, std::string suffix) { + if (str.empty() || suffix.empty() || (suffix.size() > str.size())) { + return false; + } + return std::equal(str.begin() + str.size() - suffix.size(), str.end(), suffix.begin()); +} // endsWidth + static int DecodedLength(const std::string& in) { int numEq = 0; @@ -542,3 +555,14 @@ std::string GeneralUtils::trim(const std::string& str) { size_t last = str.find_last_not_of(' '); return str.substr(first, (last - first + 1)); } // trim + +/** + * @brief combine the base and path, include only one "/" */ +std::string GeneralUtils::makePath(std::string base, std::string path) { + if ( path[0] == '/') + { + return (base + path); + } + + return (base + "/" + path); +} // makePath diff --git a/cpp_utils/GeneralUtils.h b/cpp_utils/GeneralUtils.h index 8eecbd4d..81c927ec 100644 --- a/cpp_utils/GeneralUtils.h +++ b/cpp_utils/GeneralUtils.h @@ -22,6 +22,7 @@ class GeneralUtils { static bool base64Encode(const std::string& in, std::string* out); static void dumpInfo(); static bool endsWith(std::string str, char c); + static bool endsWith(std::string str, std::string); static const char* errorToString(esp_err_t errCode); static const char* wifiErrorToString(uint8_t value); static void hexDump(const uint8_t* pData, uint32_t length); @@ -29,7 +30,7 @@ class GeneralUtils { static std::vector split(std::string source, char delimiter); static std::string toLower(std::string& value); static std::string trim(const std::string& str); - + static std::string makePath(std::string base, std::string path); }; #endif /* COMPONENTS_CPP_UTILS_GENERALUTILS_H_ */ diff --git a/cpp_utils/HttpRequest.cpp b/cpp_utils/HttpRequest.cpp index ff9336fb..c352331d 100644 --- a/cpp_utils/HttpRequest.cpp +++ b/cpp_utils/HttpRequest.cpp @@ -39,7 +39,10 @@ #include "GeneralUtils.h" #include + +// for SDK 3.3 #include +//#include #define STATE_NAME 0 #define STATE_VALUE 1 @@ -53,6 +56,7 @@ const char HttpRequest::HTTP_HEADER_ALLOW[] = "Allow"; const char HttpRequest::HTTP_HEADER_CONNECTION[] = "Connection"; const char HttpRequest::HTTP_HEADER_CONTENT_LENGTH[] = "Content-Length"; const char HttpRequest::HTTP_HEADER_CONTENT_TYPE[] = "Content-Type"; +const char HttpRequest::HTTP_HEADER_CONTENT_ENCODING[] = "Content-Encoding"; const char HttpRequest::HTTP_HEADER_COOKIE[] = "Cookie"; const char HttpRequest::HTTP_HEADER_HOST[] = "Host"; const char HttpRequest::HTTP_HEADER_LAST_MODIFIED[] = "Last-Modified"; diff --git a/cpp_utils/HttpRequest.h b/cpp_utils/HttpRequest.h index 5984ed4f..9719e812 100644 --- a/cpp_utils/HttpRequest.h +++ b/cpp_utils/HttpRequest.h @@ -25,6 +25,7 @@ class HttpRequest { static const char HTTP_HEADER_CONNECTION[]; static const char HTTP_HEADER_CONTENT_LENGTH[]; static const char HTTP_HEADER_CONTENT_TYPE[]; + static const char HTTP_HEADER_CONTENT_ENCODING[]; static const char HTTP_HEADER_COOKIE[]; static const char HTTP_HEADER_HOST[]; static const char HTTP_HEADER_LAST_MODIFIED[]; diff --git a/cpp_utils/HttpResponse.cpp b/cpp_utils/HttpResponse.cpp index 46f3e9bd..193302f9 100644 --- a/cpp_utils/HttpResponse.cpp +++ b/cpp_utils/HttpResponse.cpp @@ -8,6 +8,7 @@ #include #include "HttpRequest.h" #include "HttpResponse.h" +#include "GeneralUtils.h" #include static const char* LOG_TAG = "HttpResponse"; @@ -121,19 +122,84 @@ void HttpResponse::sendData(uint8_t* pData, size_t size) { ESP_LOGD(LOG_TAG, "<< sendData"); } // sendData +class String : std::string +{ +public: + String(std::string initStr) : std::string(initStr) + { + } + + bool endsWith(const std::string& suffix) { + if (suffix.size() > size()) return false; + return std::equal(begin() + size() - suffix.size(), end(), suffix.begin()); + }; +}; + +std::string getContentType(std::string fileStr) { + // if (server.hasArg("download")) + // return "application/octet-stream"; + // else + + if (GeneralUtils::endsWith(fileStr, ".htm")) + return "text/html"; + else if (GeneralUtils::endsWith(fileStr, ".html")) + return "text/html"; + else if (GeneralUtils::endsWith(fileStr, ".css")) + return "text/css"; + else if (GeneralUtils::endsWith(fileStr, ".js")) + return "application/javascript"; + else if (GeneralUtils::endsWith(fileStr, ".png")) + return "image/png"; + else if (GeneralUtils::endsWith(fileStr, ".gif")) + return "image/gif"; + else if (GeneralUtils::endsWith(fileStr, ".jpg")) + return "image/jpeg"; + else if (GeneralUtils::endsWith(fileStr, ".ico")) + return "image/x-icon"; + else if (GeneralUtils::endsWith(fileStr, ".xml")) + return "text/xml"; + else if (GeneralUtils::endsWith(fileStr, ".pdf")) + return "application/x-pdf"; + else if (GeneralUtils::endsWith(fileStr, ".zip")) + return "application/x-zip"; + else if (GeneralUtils::endsWith(fileStr, ".gz")) + return "application/x-gzip"; + return "text/plain"; +} + void HttpResponse::sendFile(std::string fileName, size_t bufSize) { ESP_LOGI(LOG_TAG, "Opening file: %s", fileName.c_str()); + std::string contentType = getContentType(fileName); + ESP_LOGD(LOG_TAG, "content type: %s", contentType.c_str()); + std::string encodingStr = m_request->getHeader("Accept-Encoding"); + bool canGz = encodingStr.find("gzip") != encodingStr.npos; + ESP_LOGD(LOG_TAG, "canGz: %d",canGz); + std::ifstream ifStream; ifStream.open(fileName, std::ifstream::in | std::ifstream::binary); // Attempt to open the file for reading. // If we failed to open the requested file, then it probably didn't exist so return a not found. if (!ifStream.is_open()) { + if (canGz) + { + // try to find a compressed version of the file + std::string fileNameGz = fileName + ".gz"; + ifStream.open(fileNameGz, std::ifstream::in | std::ifstream::binary); // Attempt to open the file for reading. + } + + if (!ifStream.is_open()) { ESP_LOGE(LOG_TAG, "Unable to open file %s for reading", fileName.c_str()); setStatus(HttpResponse::HTTP_STATUS_NOT_FOUND, "Not Found"); addHeader(HttpRequest::HTTP_HEADER_CONTENT_TYPE, "text/plain"); sendData("Not Found"); close(); return; // Since we failed to open the file, no further work to be done. + } + + // the type should now be compressed + + // contentType = std::string("application/x-gzip"); + addHeader(HttpRequest::HTTP_HEADER_CONTENT_ENCODING, "gzip"); } // We now have an open file and want to push the content of that file through to the browser. @@ -141,6 +207,8 @@ void HttpResponse::sendFile(std::string fileName, size_t bufSize) { // RAM at one time. Instead what we have to do is ensure that we only have enough data in RAM to be sent. setStatus(HttpResponse::HTTP_STATUS_OK, "OK"); + addHeader(HttpRequest::HTTP_HEADER_CONTENT_TYPE, contentType); + uint8_t *pData = new uint8_t[bufSize]; while (!ifStream.eof()) { ifStream.read((char*) pData, bufSize); diff --git a/cpp_utils/HttpServer.cpp b/cpp_utils/HttpServer.cpp index fb90fcc9..d55a496d 100644 --- a/cpp_utils/HttpServer.cpp +++ b/cpp_utils/HttpServer.cpp @@ -472,14 +472,15 @@ PathHandler::PathHandler(std::string method, std::string matchPath, bool PathHandler::match(std::string method, std::string path) { if (method != m_method) return false; if (m_isRegex) { - ESP_LOGD("PathHandler", "regex matching: %s with %s", m_textPattern.c_str(), path.c_str()); + ESP_LOGI("PathHandler", "regex matching: %s with %s", m_textPattern.c_str(), path.c_str()); return std::regex_search(path, *m_pRegex); } - ESP_LOGD("PathHandler", "plain matching: %s with %s", m_textPattern.c_str(), path.c_str()); + ESP_LOGI("PathHandler", "plain matching: %s with %s", m_textPattern.c_str(), path.c_str()); return m_textPattern.compare(0, m_textPattern.length(), path) ==0; } // match + /** * @brief Invoke the handler. * @param [in] request An object representing the request. diff --git a/cpp_utils/SPI.cpp b/cpp_utils/SPI.cpp index 1312e706..45128316 100644 --- a/cpp_utils/SPI.cpp +++ b/cpp_utils/SPI.cpp @@ -54,7 +54,7 @@ void SPI::init(int mosiPin, int misoPin, int clkPin, int csPin) { bus_config.quadwp_io_num = -1; // Not used bus_config.quadhd_io_num = -1; // Not used bus_config.max_transfer_sz = 0; // 0 means use default. - bus_config.flags = (SPICOMMON_BUSFLAG_SCLK | SPICOMMON_BUSFLAG_MOSI | SPICOMMON_BUSFLAG_MISO); + // bus_config.flags = (SPICOMMON_BUSFLAG_SCLK | SPICOMMON_BUSFLAG_MOSI | SPICOMMON_BUSFLAG_MISO); ESP_LOGI(LOG_TAG, "... Initializing bus; host=%d", m_host); @@ -79,7 +79,7 @@ void SPI::init(int mosiPin, int misoPin, int clkPin, int csPin) { dev_config.cs_ena_pretrans = 0; dev_config.clock_speed_hz = 100000; dev_config.spics_io_num = csPin; - dev_config.flags = SPI_DEVICE_NO_DUMMY; + // dev_config.flags = SPI_DEVICE_NO_DUMMY; dev_config.queue_size = 1; dev_config.pre_cb = NULL; dev_config.post_cb = NULL; diff --git a/cpp_utils/SockServ.cpp b/cpp_utils/SockServ.cpp index 020ed096..756f2e92 100644 --- a/cpp_utils/SockServ.cpp +++ b/cpp_utils/SockServ.cpp @@ -70,7 +70,7 @@ SockServ::~SockServ() { pSockServ->m_clientSet.insert(tempSock); xQueueSendToBack(pSockServ->m_acceptQueue, &tempSock, portMAX_DELAY); pSockServ->m_clientSemaphore.give(); - } catch (std::exception e) { + } catch (const std::exception &e) { ESP_LOGD(LOG_TAG, "acceptTask ending"); pSockServ->m_clientSemaphore.give(); // Wake up any waiting clients. FreeRTOS::deleteTask(); diff --git a/cpp_utils/WebServer.cpp b/cpp_utils/WebServer.cpp index 0fc666af..7ab27ae6 100644 --- a/cpp_utils/WebServer.cpp +++ b/cpp_utils/WebServer.cpp @@ -15,8 +15,51 @@ #define MG_ENABLE_FILESYSTEM 1 #include "WebServer.h" #include -#include +#include "mongoose.h" #include +#include +//#include + +#include + +const char WebServer::HTTPRequest::HTTP_HEADER_ACCEPT[] = "Accept"; +const char WebServer::HTTPRequest::HTTP_HEADER_ALLOW[] = "Allow"; +const char WebServer::HTTPRequest::HTTP_HEADER_CONNECTION[] = "Connection"; +const char WebServer::HTTPRequest::HTTP_HEADER_CONTENT_LENGTH[] = "Content-Length"; +const char WebServer::HTTPRequest::HTTP_HEADER_CONTENT_TYPE[] = "Content-Type"; +const char WebServer::HTTPRequest::HTTP_HEADER_CONTENT_ENCODING[] = "Content-Encoding"; +const char WebServer::HTTPRequest::HTTP_HEADER_COOKIE[] = "Cookie"; +const char WebServer::HTTPRequest::HTTP_HEADER_HOST[] = "Host"; +const char WebServer::HTTPRequest::HTTP_HEADER_LAST_MODIFIED[] = "Last-Modified"; +const char WebServer::HTTPRequest::HTTP_HEADER_ORIGIN[] = "Origin"; +const char WebServer::HTTPRequest::HTTP_HEADER_SEC_WEBSOCKET_ACCEPT[] = "Sec-WebSocket-Accept"; +const char WebServer::HTTPRequest::HTTP_HEADER_SEC_WEBSOCKET_PROTOCOL[] = "Sec-WebSocket-Protocol"; +const char WebServer::HTTPRequest::HTTP_HEADER_SEC_WEBSOCKET_KEY[] = "Sec-WebSocket-Key"; +const char WebServer::HTTPRequest::HTTP_HEADER_SEC_WEBSOCKET_VERSION[] = "Sec-WebSocket-Version"; +const char WebServer::HTTPRequest::HTTP_HEADER_UPGRADE[] = "Upgrade"; +const char WebServer::HTTPRequest::HTTP_HEADER_USER_AGENT[] = "User-Agent"; + +const char WebServer::HTTPRequest::HTTP_METHOD_CONNECT[] = "CONNECT"; +const char WebServer::HTTPRequest::HTTP_METHOD_DELETE[] = "DELETE"; +const char WebServer::HTTPRequest::HTTP_METHOD_GET[] = "GET"; +const char WebServer::HTTPRequest::HTTP_METHOD_HEAD[] = "HEAD"; +const char WebServer::HTTPRequest::HTTP_METHOD_OPTIONS[] = "OPTIONS"; +const char WebServer::HTTPRequest::HTTP_METHOD_PATCH[] = "PATCH"; +const char WebServer::HTTPRequest::HTTP_METHOD_POST[] = "POST"; +const char WebServer::HTTPRequest::HTTP_METHOD_PUT[] = "PUT"; + +const int WebServer::HTTPResponse::HTTP_STATUS_CONTINUE = 100; +const int WebServer::HTTPResponse::HTTP_STATUS_SWITCHING_PROTOCOL = 101; +const int WebServer::HTTPResponse::HTTP_STATUS_OK = 200; +const int WebServer::HTTPResponse::HTTP_STATUS_MOVED_PERMANENTLY = 301; +const int WebServer::HTTPResponse::HTTP_STATUS_BAD_REQUEST = 400; +const int WebServer::HTTPResponse::HTTP_STATUS_UNAUTHORIZED = 401; +const int WebServer::HTTPResponse::HTTP_STATUS_FORBIDDEN = 403; +const int WebServer::HTTPResponse::HTTP_STATUS_NOT_FOUND = 404; +const int WebServer::HTTPResponse::HTTP_STATUS_METHOD_NOT_ALLOWED = 405; +const int WebServer::HTTPResponse::HTTP_STATUS_INTERNAL_SERVER_ERROR = 500; +const int WebServer::HTTPResponse::HTTP_STATUS_NOT_IMPLEMENTED = 501; +const int WebServer::HTTPResponse::HTTP_STATUS_SERVICE_UNAVAILABLE = 503; #define STATE_NAME 0 #define STATE_VALUE 1 @@ -31,6 +74,35 @@ struct WebServerUserData { }; +static std::string getContentType(std::string filename) +{ + if (GeneralUtils::endsWith(filename,".htm")) + return "text/html; charset=UTF-8"; + else if (GeneralUtils::endsWith(filename,".html")) + return "text/html; charset=UTF-8"; + else if (GeneralUtils::endsWith(filename,".css")) + return "text/css; charset=UTF-8"; + else if (GeneralUtils::endsWith(filename,".js")) + return "application/javascript; charset=UTF-8"; + else if (GeneralUtils::endsWith(filename,".png")) + return "image/png"; + else if (GeneralUtils::endsWith(filename,".gif")) + return "image/gif"; + else if (GeneralUtils::endsWith(filename,".jpg")) + return "image/jpeg"; + else if (GeneralUtils::endsWith(filename,".ico")) + return "image/x-icon"; + else if (GeneralUtils::endsWith(filename,".xml")) + return "text/xml; charset=UTF-8"; + else if (GeneralUtils::endsWith(filename,".pdf")) + return "application/x-pdf"; + else if (GeneralUtils::endsWith(filename,".zip")) + return "application/x-zip"; + else if (GeneralUtils::endsWith(filename,".gz")) + return "application/x-gzip"; + return "text/plain; charset=UTF-8"; +} + /** * @brief Convert a Mongoose event type to a string. * @param [in] event The received event type. @@ -157,34 +229,38 @@ static void mongoose_event_handler_web_server(struct mg_connection* mgConnection } // MG_EV_HTTP_REQUEST case MG_EV_HTTP_MULTIPART_REQUEST: { - struct WebServerUserData *pWebServerUserData = (struct WebServerUserData*) mgConnection->user_data; - ESP_LOGD(LOG_TAG, "User_data address 0x%d", (uint32_t) pWebServerUserData); + struct http_message* message = (struct http_message*) eventData; + ESP_LOGD(LOG_TAG, "Multipart Request user data %p", mgConnection->user_data); +// dumpHttpMessage(message); + struct WebServerUserData* pWebServerUserData = (struct WebServerUserData*) mgConnection->user_data; WebServer* pWebServer = pWebServerUserData->pWebServer; - if (pWebServer->m_pMultiPartFactory == nullptr) return; - WebServer::HTTPMultiPart* pMultiPart = pWebServer->m_pMultiPartFactory->newInstance(); - struct WebServerUserData* p2 = new WebServerUserData(); - ESP_LOGD(LOG_TAG, "New User_data address 0x%d", (uint32_t) p2); - p2->originalUserData = pWebServerUserData; - p2->pWebServer = pWebServerUserData->pWebServer; - p2->pMultiPart = pMultiPart; - p2->pWebSocketHandler = nullptr; - mgConnection->user_data = p2; - //struct http_message* message = (struct http_message*) eventData; - //dumpHttpMessage(message); + pWebServer->processMultiRequest(mgConnection, message); break; } // MG_EV_HTTP_MULTIPART_REQUEST case MG_EV_HTTP_MULTIPART_REQUEST_END: { struct WebServerUserData* pWebServerUserData = (struct WebServerUserData*) mgConnection->user_data; + ESP_LOGD(LOG_TAG, "Multipart Request End orig user data %p user data %p", pWebServerUserData->originalUserData, + mgConnection->user_data); if (pWebServerUserData->pMultiPart != nullptr) { + // get the status of this upload + int status = pWebServerUserData->pMultiPart->getStatus(); delete pWebServerUserData->pMultiPart; pWebServerUserData->pMultiPart = nullptr; + mgConnection->user_data = pWebServerUserData->originalUserData; + delete pWebServerUserData; + WebServer::HTTPResponse httpResponse = WebServer::HTTPResponse(mgConnection); + httpResponse.setStatus(status); + ESP_LOGD(LOG_TAG, "MultiPart status %d", status); + if (status == 200) + { + httpResponse.sendData("File uploaded OK"); + } + else + { + httpResponse.sendData("File upload FAILED"); + } } - mgConnection->user_data = pWebServerUserData->originalUserData; - delete pWebServerUserData; - WebServer::HTTPResponse httpResponse = WebServer::HTTPResponse(mgConnection); - httpResponse.setStatus(200); - httpResponse.sendData(""); break; } // MG_EV_HTTP_MULTIPART_REQUEST_END @@ -231,8 +307,10 @@ static void mongoose_event_handler_web_server(struct mg_connection* mgConnection struct WebServerUserData* p2 = new WebServerUserData(); ESP_LOGD(LOG_TAG, "New User_data address 0x%d", (uint32_t) p2); p2->originalUserData = pWebServerUserData; - p2->pWebServer = pWebServerUserData->pWebServer; + p2->pWebServer = pWebServerUserData->pWebServer; p2->pWebSocketHandler = pWebServer->m_pWebSocketHandlerFactory->newInstance(); + // the WebSocketHandler needs the connection + p2->pWebSocketHandler->setConnection(mgConnection); mgConnection->user_data = p2; } else { ESP_LOGD(LOG_TAG, "We received a WebSocket request but we have no handler factory!"); @@ -280,11 +358,6 @@ WebServer::WebServer() { m_pWebSocketHandlerFactory = nullptr; } // WebServer - -WebServer::~WebServer() { -} - - /** * @brief Get the current root path. * @return The current root path. @@ -336,16 +409,16 @@ void WebServer::addPathHandler(std::string&& method, const std::string& pathExpr */ void WebServer::start(uint16_t port) { ESP_LOGD(LOG_TAG, "WebServer task starting"); - struct mg_mgr mgr; - mg_mgr_init(&mgr, NULL); + mg_mgr_init(&m_mgr, NULL); + TaskHandle_t self = 0; std::stringstream stringStream; stringStream << ':' << port; - struct mg_connection *mgConnection = mg_bind(&mgr, stringStream.str().c_str(), mongoose_event_handler_web_server); + struct mg_connection *mgConnection = mg_bind(&m_mgr, stringStream.str().c_str(), mongoose_event_handler_web_server); if (mgConnection == NULL) { ESP_LOGE(LOG_TAG, "No connection from the mg_bind()"); - vTaskDelete(NULL); + vTaskDelete(self); return; } @@ -358,7 +431,7 @@ void WebServer::start(uint16_t port) { ESP_LOGD(LOG_TAG, "WebServer listening on port %d", port); while (true) { - mg_mgr_poll(&mgr, 2000); + mg_mgr_poll(&m_mgr, 2000); } } // run @@ -409,6 +482,44 @@ void WebServer::setWebSocketHandlerFactory(WebSocketHandlerFactory* pWebSocketHa m_pWebSocketHandlerFactory = pWebSocketHandlerFactory; } // setWebSocketHandlerFactory +/** + * @brief Send data down all the WebSocket(s) + * @param [in] message The message to send down the socket. + * @return N/A. + */ +void WebServer::broadcastData(const std::string& message) { + ESP_LOGD(LOG_TAG, "broadcastData(length=%d)", message.length()); + struct mg_connection *conn; + for (conn = mg_next(&m_mgr, NULL); conn != NULL; conn = mg_next(&m_mgr, conn)) + { + if (conn->flags & MG_F_IS_WEBSOCKET) + { + mg_send_websocket_frame(conn, + WEBSOCKET_OP_TEXT | WEBSOCKET_OP_CONTINUE, + message.data(), message.length()); + } + } +} // broadcastData + + +/** + * @brief Send data down all the WebSocket(s) + * @param [in] data The message to send down the socket. + * @param [in] size The size of the message + * @return N/A. + */ +void WebServer::broadcastData(const uint8_t* data, uint32_t size) { + struct mg_connection *conn; + for (conn = mg_next(&m_mgr, NULL); conn != NULL; conn = mg_next(&m_mgr, conn)) + { + if (conn->flags & MG_F_IS_WEBSOCKET) + { + mg_send_websocket_frame(conn, + WEBSOCKET_OP_BINARY | WEBSOCKET_OP_CONTINUE, + data, size); + } + } +} // broadcastData /** * @brief Constructor. @@ -447,21 +558,24 @@ std::string WebServer::HTTPResponse::buildHeaders() { if (iter != m_headers.begin()) { headers_len += 2; } + headers_len += iter->first.length(); headers_len += 2; headers_len += iter->second.length(); } headers_len += 1; - headers.resize(headers_len); // Will not have to resize and recopy during the next loop, we have 2 loops but it still ends up being faster +// headers.resize(headers_len); // Will not have to resize and recopy during the next loop, we have 2 loops but it still ends up being faster for (auto iter = m_headers.begin(); iter != m_headers.end(); iter++) { if (iter != m_headers.begin()) { headers += "\r\n"; } + headers += iter->first; headers += ": "; headers += iter->second; } + return headers; } // buildHeaders @@ -598,12 +712,17 @@ void WebServer::processRequest(struct mg_connection* mgConnection, struct http_m ESP_LOGD(LOG_TAG, "WebServer::processRequest: Matching: %.*s", (int) message->uri.len, message->uri.p); HTTPResponse httpResponse = HTTPResponse(mgConnection); + // make all rootPaths match + httpResponse.setRootPath(getRootPath()); + + ESP_LOGD(LOG_TAG, "Root = %s", httpResponse.getRootPath().c_str()); /* * Iterate through each of the path handlers looking for a match with the method and specified path. */ std::vector::iterator it; for (it = m_pathHandlers.begin(); it != m_pathHandlers.end(); ++it) { - if ((*it).match(message->method.p, message->method.len, message->uri.p)) { + if ((*it).match(message->method.p, message->method.len, message->uri.p, message->uri.len)) { + ESP_LOGD(LOG_TAG, "matched? method %.*s uri %.*s", (int) message->method.len, message->method.p, (int) message->uri.len, message->uri.p); HTTPRequest httpRequest(message); (*it).invoke(&httpRequest, &httpResponse); ESP_LOGD(LOG_TAG, "Found a match!!"); @@ -617,13 +736,52 @@ void WebServer::processRequest(struct mg_connection* mgConnection, struct http_m filePath.reserve(httpResponse.getRootPath().length() + message->uri.len + 1); filePath += httpResponse.getRootPath(); filePath.append(message->uri.p, message->uri.len); + ESP_LOGD(LOG_TAG, "Opening file: %s", filePath.c_str()); FILE* file = nullptr; if (strcmp(filePath.c_str(), "/") != 0) { - file = fopen(filePath.c_str(), "rb"); + // only try this if it is NOT a .gz file request + + if (!GeneralUtils::endsWith(filePath, "gz")) { + // if they are willing to handle a gzipped file we should serve one, assuming we + // have it + + // test to see if they are willing + HTTPRequest myHttpRequest(message); + std::string encoding = myHttpRequest.getHeader("Accept-Encoding"); + + if (encoding.find("gzip") != std::string::npos) { + // they seem willing to accept gzipped files, see if we have one + std::string gzFilePath(filePath + ".gz"); + file = fopen(gzFilePath.c_str(), "rb"); + + if (file != nullptr) { + ESP_LOGD(LOG_TAG, "Found %s file", gzFilePath.c_str()); + // add some needed headers + httpResponse.addHeader("Content-Encoding", "gzip"); + } + else + { + ESP_LOGD(LOG_TAG, "We did not open a compressed file %s (%s)", gzFilePath.c_str(), strerror(errno)); + } + } + else + { + ESP_LOGI(LOG_TAG, "They are not willing to accept gz files"); + } + } + + if ( file == nullptr ){ + file = fopen(filePath.c_str(), "rb"); + } } + if (file != nullptr) { + // we should be setting the type in the response + + httpResponse.addHeader("Content-Type", getContentType(filePath)); + auto pData = (uint8_t*)malloc(MAX_CHUNK_LENGTH); size_t read = fread(pData, 1, MAX_CHUNK_LENGTH, file); @@ -638,12 +796,80 @@ void WebServer::processRequest(struct mg_connection* mgConnection, struct http_m } free(pData); } else { + ESP_LOGD(LOG_TAG, "We did not open a file %s (%s)", filePath.c_str(), strerror(errno)); // Handle unable to open file httpResponse.setStatus(404); // Not found httpResponse.sendData(""); } } // processRequest +/** + * @brief Process an incoming HTTP multipart request. + * + * We look at the path of the request and see if it has a matching path handler. If it does, + * we invoke the proper multipart factory. If it does not, we fail with 404 + * + * @param [in] mgConnection The network connection on which the request was received. + * @param [in] message The message representing the request. + */ +void WebServer::processMultiRequest(struct mg_connection* mgConnection, struct http_message* message) { + ESP_LOGD(LOG_TAG, "WebServer::processMultiRequest: Matching: %.*s", (int) message->uri.len, message->uri.p); + HTTPResponse httpResponse = HTTPResponse(mgConnection); + + // make all rootPaths match + httpResponse.setRootPath(getRootPath()); + + ESP_LOGD(LOG_TAG, "Root = %s", httpResponse.getRootPath().c_str()); + + // are they looking for a POST method? + + std::string method; + method.append(message->method.p, message->method.len); + + if (method == "POST" || method == "DELETE") + { + std::string matchMethod = method + "_MULTI"; + + /* + * Iterate through each of the path handlers looking for a match with the method and specified path. + */ + std::vector::iterator it; + for (it = m_pathHandlers.begin(); it != m_pathHandlers.end(); ++it) { + if ((*it).match(matchMethod.c_str(), matchMethod.length(), message->uri.p, message->uri.len)) { + ESP_LOGD(LOG_TAG, "Found a match!!"); + + struct WebServerUserData *pWebServerUserData = (struct WebServerUserData*) mgConnection->user_data; + ESP_LOGD(LOG_TAG, "User_data address 0x%d", (uint32_t) pWebServerUserData); + WebServer* pWebServer = pWebServerUserData->pWebServer; + + if (pWebServer->m_pMultiPartFactory == nullptr) + break; + + WebServer::HTTPMultiPart* pMultiPart = pWebServer->m_pMultiPartFactory->newInstance(); + std::string uri; + uri.append(message->uri.p, message->uri.len); + pMultiPart->setUri(uri); + pMultiPart->setMethod(method); + struct WebServerUserData* p2 = new WebServerUserData(); + + ESP_LOGD(LOG_TAG, "New User_data address 0x%d", (uint32_t) p2); + + p2->originalUserData = pWebServerUserData; + p2->pWebServer = pWebServerUserData->pWebServer; + p2->pMultiPart = pMultiPart; + p2->pWebSocketHandler = nullptr; + mgConnection->user_data = p2; + return; + } + } // End of examine path handlers. + } + + // Handle unable to handle multipart post + httpResponse.setStatus(404); // Not found + httpResponse.sendData("No Multipart handler"); + +} // processMultiRequest + void WebServer::continueConnection(struct mg_connection* mgConnection) { if (unfinishedConnection.count(mgConnection->sock) == 0) return; @@ -692,12 +918,28 @@ WebServer::PathHandler::PathHandler(std::string&& method, const std::string& pat * @param [in] path The path to be matched. * @return True if the path matches. */ -bool WebServer::PathHandler::match(const char* method, size_t method_len, const char* path) { +bool WebServer::PathHandler::match(const char* method, size_t method_len, const char* path, size_t path_len) { //ESP_LOGD(LOG_TAG, "match: %s with %s", m_pattern.c_str(), path.c_str()); if (method_len != m_method.length() || strncmp(method, m_method.c_str(), method_len) != 0) { return false; } - return std::regex_search(path, m_pattern); + + char *temp = (char *) calloc(1, path_len); + + strncpy(temp, path, path_len); + + bool ret; + + ret = std::regex_search(temp, m_pattern); + + free(temp); + + if (ret) + { + ESP_LOGD(LOG_TAG, "path [%.*s]", (int) path_len, path); + } + + return ret; } // match @@ -720,6 +962,31 @@ void WebServer::PathHandler::invoke(WebServer::HTTPRequest* request, WebServer:: */ WebServer::HTTPRequest::HTTPRequest(struct http_message* message) { m_message = message; + + // loop through the possible headers (there are only 20) + + for (int x = 0; x < MG_MAX_HTTP_HEADERS; x++) + { + std::string headerName; + std::string headerValue; + + if (m_message->header_names[x].len <= 0) + { + break; // we are done + } + + // get a string of the name + headerName.clear(); + headerName.append(m_message->header_names[x].p, m_message->header_names[x].len); + + // get a string of the value + + headerValue.clear(); + headerValue.append(m_message->header_values[x].p, m_message->header_values[x].len); + + // now add the entry to the map + m_headers.insert(HeaderPair(headerName, headerValue)); + } } // HTTPRequest @@ -876,6 +1143,89 @@ std::vector WebServer::HTTPRequest::pathSplit() const { return ret; } // pathSplit +/** + * @brief Return a std::map of the headers. + * + * @return A map of header strings and their values. + */ + +HeaderMap WebServer::HTTPRequest::getHeaders(void) const { + return m_headers; +} //getHeaders + +/** + * @brief Return a std::map of the headers. + * + * @return A map of header strings and their values. + */ +std::string WebServer::HTTPRequest::getHeader(std::string headerName) const { + // get all the headers in a map + + HeaderMap::const_iterator headerIt = m_headers.find(headerName); + + if (headerIt != m_headers.end()) + { + return headerIt->second; + } + + return std::string(); +} + +/** + * @brief Parse the request message as a form. + * @return A map containing the names/values of the form elements that were found. + */ +std::map WebServer::HTTPRequest::parseForm() { + ESP_LOGD(LOG_TAG, ">> parseForm"); + std::map map; + // A form is composed of name=value pairs where each pair is separated with an "&" character. + // Our algorithm is to split all the pairs by "&" and then split all the name/value pairs by "=". + std::istringstream ss(getBody()); // Get the body of the request. + std::string currentEntry; + while (std::getline(ss, currentEntry, '&')) { // For each form entry. + ESP_LOGD(LOG_TAG, "Processing: %s", currentEntry.c_str()); // Debug + std::istringstream currentPair(currentEntry); // Prepare to parse the name=value string. + std::string name; // Declare the name variable. + std::string value; // Declare the value variable. + + std::getline(currentPair, name, '='); // Parse the current form entry into name/value. + currentPair >> value; // The value is what remains. + map[name] = urlDecode(value); // Decode the field which may have been encoded. + ESP_LOGD(LOG_TAG, " %s = \"%s\"", name.c_str(), map[name].c_str()); // Debug + // Add the form entry into the map. + } // Processed all form entries. + ESP_LOGD(LOG_TAG, "<< parseForm"); // Debug + return map; // Return the map of form entries. +} // parseForm + +/** + * @brief Decode a URL/form + * @param [in] str + * @return The decoded string. + */ +std::string WebServer::HTTPRequest::urlDecode(std::string str) { + // https://stackoverflow.com/questions/154536/encode-decode-urls-in-c + std::string ret; + char ch; + int ii, len = str.length(); + + for (int i = 0; i < len; i++) { + if (str[i] != '%'){ + if (str[i] == '+') { + ret += ' '; + } else { + ret += str[i]; + } + } else { + sscanf(str.substr(i + 1, 2).c_str(), "%x", &ii); + ch = static_cast(ii); + ret += ch; + i = i + 2; + } + } + return ret; +} // urlDecode + /** * @brief Indicate the beginning of a multipart part. * An HTTP Multipart form is where each of the fields in the form are broken out into distinct @@ -929,6 +1279,113 @@ void WebServer::HTTPMultiPart::multipartStart() { ESP_LOGD(LOG_TAG, "WebServer::HTTPMultiPart::multipartStart()"); } // WebServer::HTTPMultiPart::multipartStart +void WebServer::HTTPMultiPart::setUri(std::string uri) { + ESP_LOGD(LOG_TAG, "WebServer::HTTPMultiPart::setUri(%s)", uri.c_str()); +} // WebServer::HTTPMultiPart::setUri + +std::string WebServer::HTTPMultiPart::getUri() { + ESP_LOGD(LOG_TAG, "WebServer::HTTPMultiPart::getUri()"); + return "UNKNOWN"; +} // WebServer::HTTPMultiPart::geturi + +void WebServer::HTTPMultiPart::setMethod(std::string method) { + ESP_LOGD(LOG_TAG, "WebServer::HTTPMultiPart::setMethod(%s0)", method.c_str()); +} // WebServer::HTTPMultiPart::setMethod + +std::string WebServer::HTTPMultiPart::getMethod() { + ESP_LOGD(LOG_TAG, "WebServer::HTTPMultiPart::getMethod()"); + return "UNKNOWN"; +} // WebServer::HTTPMultiPart::getMethod + +int WebServer::HTTPMultiPart::getStatus() { + ESP_LOGD(LOG_TAG, "WebServer::HTTPMultiPart::getStatus()"); + return 404; +} // WebServer::HTTPMultiPart::getStatus + +class MyMultiPart : public WebServer::HTTPMultiPart { +public: + void begin(std::string varName, std::string fileName) { + ESP_LOGD(LOG_TAG, "MyMultiPart begin(): varName=%s, fileName=%s", + varName.c_str(), fileName.c_str()); + m_currentVar = varName; + if (varName == "path") { + m_path = ""; + } else if (varName == "myfile") { + m_fileData = ""; + m_fileName = fileName; + } + } // begin + + void end() { + ESP_LOGD(LOG_TAG, "MyMultiPart end()"); + if (m_currentVar == "myfile") { + std::string fileName = m_path + "/" + m_fileName; + ESP_LOGD(LOG_TAG, "Write to file: %s ... data: %s", fileName.c_str(), m_fileData.c_str()); + + //std::ofstream myfile; + //myfile.open(fileName, std::ios::out | std::ios::binary | std::ios::trunc); + //myfile << m_fileData; + //myfile.close(); + + FILE *ffile = fopen(fileName.c_str(), "w"); + fwrite(m_fileData.data(), m_fileData.length(), 1, ffile); + fclose(ffile); + } + } // end + + void data(std::string data) { + ESP_LOGD(LOG_TAG, "MyMultiPart data(): length=%d", data.length()); + if (m_currentVar == "path") { + m_path += data; + } + else if (m_currentVar == "myfile") { + m_fileData += data; + } + } // data + + void multipartEnd() { + ESP_LOGD(LOG_TAG, "MyMultiPart multipartEnd()"); + } // multipartEnd + + void multipartStart() { + ESP_LOGD(LOG_TAG, "MyMultiPart multipartStart()"); + } // multipartStart + + void setUri(std::string uri) { + m_uri = uri; + } + + std::string getUri(void) { + return m_uri; + } + + void setMethod(std::string method) { + m_method = method; + } + + std::string getMethod(void) { + return m_method; + } + + int getStatus(void) { + return m_status; + } + +private: + std::string m_fileName; + std::string m_path; + std::string m_currentVar; + std::string m_fileData; + std::string m_method; + std::string m_uri; + int m_status; +}; + +class MyMultiPartFactory : public WebServer::HTTPMultiPartFactory { + WebServer::HTTPMultiPart* newInstance() { + return new MyMultiPart(); + } +}; /** * @brief Indicate that a new WebSocket instance has been created. @@ -962,7 +1419,7 @@ void WebServer::WebSocketHandler::onClosed() { void WebServer::WebSocketHandler::sendData(const std::string& message) { ESP_LOGD(LOG_TAG, "WebSocketHandler::sendData(length=%d)", message.length()); mg_send_websocket_frame(m_mgConnection, - WEBSOCKET_OP_BINARY | WEBSOCKET_OP_CONTINUE, + WEBSOCKET_OP_TEXT | WEBSOCKET_OP_CONTINUE, message.data(), message.length()); } // sendData @@ -979,6 +1436,44 @@ void WebServer::WebSocketHandler::sendData(const uint8_t* data, uint32_t size) { data, size); } // sendData +/** + * @brief Send data down all the WebSocket(s) + * @param [in] message The message to send down the socket. + * @return N/A. + */ +void WebServer::WebSocketHandler::broadcastData(const std::string& message) { + ESP_LOGD(LOG_TAG, "WebSocketHandler::sendData(length=%d)", message.length()); + struct mg_connection *conn; + for (conn = mg_next(m_mgConnection->mgr, NULL); conn != NULL; conn = mg_next(m_mgConnection->mgr, conn)) + { + if (conn->flags & MG_F_IS_WEBSOCKET) + { + mg_send_websocket_frame(conn, + WEBSOCKET_OP_TEXT | WEBSOCKET_OP_CONTINUE, + message.data(), message.length()); + } + } +} // broadcastData + + +/** + * @brief Send data down all the WebSocket(s) + * @param [in] data The message to send down the socket. + * @param [in] size The size of the message + * @return N/A. + */ +void WebServer::WebSocketHandler::broadcastData(const uint8_t* data, uint32_t size) { + struct mg_connection *conn; + for (conn = mg_next(m_mgConnection->mgr, NULL); conn != NULL; conn = mg_next(m_mgConnection->mgr, conn)) + { + if (conn->flags & MG_F_IS_WEBSOCKET) + { + mg_send_websocket_frame(conn, + WEBSOCKET_OP_BINARY | WEBSOCKET_OP_CONTINUE, + data, size); + } + } +} // broadcastData /** * @brief Close the WebSocket from the web server end. diff --git a/cpp_utils/WebServer.h b/cpp_utils/WebServer.h index 961b6662..9c6ee537 100644 --- a/cpp_utils/WebServer.h +++ b/cpp_utils/WebServer.h @@ -11,13 +11,17 @@ #include #include #include +#include "GeneralUtils.h" #include "sdkconfig.h" #ifdef CONFIG_MONGOOSE_PRESENT -#include +#include "mongoose.h" #define MAX_CHUNK_LENGTH 4090 // 4 kilobytes +typedef std::map HeaderMap; +typedef std::pair HeaderPair; + class WebServer; /** @@ -25,6 +29,7 @@ class WebServer; * * A web server. */ + class WebServer { public: /** @@ -33,6 +38,32 @@ class WebServer { class HTTPRequest { public: HTTPRequest(struct http_message* message); + static const char HTTP_HEADER_ACCEPT[]; + static const char HTTP_HEADER_ALLOW[]; + static const char HTTP_HEADER_CONNECTION[]; + static const char HTTP_HEADER_CONTENT_LENGTH[]; + static const char HTTP_HEADER_CONTENT_TYPE[]; + static const char HTTP_HEADER_CONTENT_ENCODING[]; + static const char HTTP_HEADER_COOKIE[]; + static const char HTTP_HEADER_HOST[]; + static const char HTTP_HEADER_LAST_MODIFIED[]; + static const char HTTP_HEADER_ORIGIN[]; + static const char HTTP_HEADER_SEC_WEBSOCKET_ACCEPT[]; + static const char HTTP_HEADER_SEC_WEBSOCKET_PROTOCOL[]; + static const char HTTP_HEADER_SEC_WEBSOCKET_KEY[]; + static const char HTTP_HEADER_SEC_WEBSOCKET_VERSION[]; + static const char HTTP_HEADER_UPGRADE[]; + static const char HTTP_HEADER_USER_AGENT[]; + + static const char HTTP_METHOD_CONNECT[]; + static const char HTTP_METHOD_DELETE[]; + static const char HTTP_METHOD_GET[]; + static const char HTTP_METHOD_HEAD[]; + static const char HTTP_METHOD_OPTIONS[]; + static const char HTTP_METHOD_PATCH[]; + static const char HTTP_METHOD_POST[]; + static const char HTTP_METHOD_PUT[]; + const char* getMethod() const; const char* getPath() const; const char* getBody() const; @@ -41,9 +72,14 @@ class WebServer { size_t getBodyLen() const; std::map getQuery() const; std::vector pathSplit() const; + HeaderMap getHeaders() const; + std::string getHeader(std::string) const; + std::map parseForm(); + std::string urlDecode(std::string str); private: struct http_message* m_message; + HeaderMap m_headers; }; // HTTPRequest @@ -53,6 +89,20 @@ class WebServer { class HTTPResponse { public: HTTPResponse(struct mg_connection* nc); + + static const int HTTP_STATUS_CONTINUE; + static const int HTTP_STATUS_SWITCHING_PROTOCOL; + static const int HTTP_STATUS_OK; + static const int HTTP_STATUS_MOVED_PERMANENTLY; + static const int HTTP_STATUS_BAD_REQUEST; + static const int HTTP_STATUS_UNAUTHORIZED; + static const int HTTP_STATUS_FORBIDDEN; + static const int HTTP_STATUS_NOT_FOUND; + static const int HTTP_STATUS_METHOD_NOT_ALLOWED; + static const int HTTP_STATUS_INTERNAL_SERVER_ERROR; + static const int HTTP_STATUS_NOT_IMPLEMENTED; + static const int HTTP_STATUS_SERVICE_UNAVAILABLE; + void addHeader(const std::string& name, const std::string& value); void addHeader(std::string&& name, std::string&& value); void setStatus(int status); @@ -107,7 +157,11 @@ class WebServer { virtual void data(const std::string& data); virtual void multipartEnd(); virtual void multipartStart(); - + virtual void setUri(std::string); + virtual std::string getUri(void); + virtual void setMethod(std::string); + virtual std::string getMethod(void); + virtual int getStatus(void); }; // HTTPMultiPart /** @@ -148,6 +202,8 @@ class WebServer { */ class HTTPMultiPartFactory { public: + virtual ~HTTPMultiPartFactory() = default; + /** * @brief Create a new HTTPMultiPart instance. * @return A new HTTPMultiPart instance. @@ -164,7 +220,7 @@ class WebServer { public: PathHandler(const std::string& method, const std::string& pathPattern, void (*webServerRequestHandler) (WebServer::HTTPRequest* pHttpRequest, WebServer::HTTPResponse* pHttpResponse)); PathHandler(std::string&& method, const std::string& pathPattern, void (*webServerRequestHandler) (WebServer::HTTPRequest* pHttpRequest, WebServer::HTTPResponse* pHttpResponse)); - bool match(const char* method, size_t method_len, const char* path); + bool match(const char* method, size_t method_len, const char* path, size_t path_len); void invoke(HTTPRequest* request, HTTPResponse* response); private: @@ -179,12 +235,19 @@ class WebServer { */ class WebSocketHandler { public: + virtual ~WebSocketHandler() = default; + void onCreated(); virtual void onMessage(const std::string& message); void onClosed(); void sendData(const std::string& message); void sendData(const uint8_t* data, uint32_t size); + void broadcastData(const std::string& message); + void broadcastData(const uint8_t* data, uint32_t size); void close(); + void setConnection(struct mg_connection* mgConnection) { + m_mgConnection = mgConnection; + } private: struct mg_connection* m_mgConnection; @@ -193,12 +256,13 @@ class WebServer { class WebSocketHandlerFactory { public: + virtual ~WebSocketHandlerFactory() = default; virtual WebSocketHandler* newInstance() = 0; }; WebServer(); - virtual ~WebServer(); + virtual ~WebServer() = default; void addPathHandler(const std::string& method, const std::string& pathExpr, void (*webServerRequestHandler) (WebServer::HTTPRequest* pHttpRequest, WebServer::HTTPResponse* pHttpResponse)); void addPathHandler(std::string&& method, const std::string& pathExpr, void (*webServerRequestHandler) (WebServer::HTTPRequest* pHttpRequest, WebServer::HTTPResponse* pHttpResponse)); const std::string& getRootPath(); @@ -208,11 +272,15 @@ class WebServer { void setWebSocketHandlerFactory(WebSocketHandlerFactory* pWebSocketHandlerFactory); void start(unsigned short port = 80); void processRequest(struct mg_connection* mgConnection, struct http_message* message); + void processMultiRequest(struct mg_connection* mgConnection, struct http_message* message); void continueConnection(struct mg_connection* mgConnection); + void broadcastData(const std::string& message); + void broadcastData(const uint8_t* data, uint32_t size); HTTPMultiPartFactory* m_pMultiPartFactory; WebSocketHandlerFactory* m_pWebSocketHandlerFactory; private: + struct mg_mgr m_mgr; std::string m_rootPath; std::vector m_pathHandlers; std::map unfinishedConnection; diff --git a/cpp_utils/WebSocket.cpp b/cpp_utils/WebSocket.cpp index 6930f3b6..0a6b4ef2 100644 --- a/cpp_utils/WebSocket.cpp +++ b/cpp_utils/WebSocket.cpp @@ -5,6 +5,7 @@ * Author: kolban */ +#include "sdkconfig.h" #include #include "WebSocket.h" #include "Task.h" diff --git a/cpp_utils/WiFi.cpp b/cpp_utils/WiFi.cpp index 3c6d4112..20edcb77 100644 --- a/cpp_utils/WiFi.cpp +++ b/cpp_utils/WiFi.cpp @@ -178,6 +178,11 @@ uint8_t WiFi::connectAP(const std::string& ssid, const std::string& password, bo ESP_LOGE(LOG_TAG, "esp_wifi_set_mode: rc=%d %s", errRc, GeneralUtils::errorToString(errRc)); abort(); } + errRc = ::esp_wifi_set_ps(WIFI_PS_NONE); + if (errRc != ESP_OK) { + ESP_LOGE(LOG_TAG, "esp_wifi_set_ps: rc=%d %s", errRc, GeneralUtils::errorToString(errRc)); + //abort(); + } wifi_config_t sta_config; ::memset(&sta_config, 0, sizeof(sta_config)); ::memcpy(sta_config.sta.ssid, ssid.data(), ssid.size()); @@ -508,15 +513,26 @@ std::string WiFi::getStaSSID() { std::vector WiFi::scan() { ESP_LOGD(LOG_TAG, ">> scan"); std::vector apRecords; + wifi_mode_t wifiMode; init(); - esp_err_t errRc = ::esp_wifi_set_mode(WIFI_MODE_STA); + esp_err_t errRc = ::esp_wifi_get_mode(&wifiMode); + if (errRc != ESP_OK) { - ESP_LOGE(LOG_TAG, "esp_wifi_set_mode: rc=%d %s", errRc, GeneralUtils::errorToString(errRc)); + ESP_LOGE(LOG_TAG, "esp_wifi_get_mode: rc=%d %s", errRc, GeneralUtils::errorToString(errRc)); abort(); } + if (wifiMode != WIFI_MODE_APSTA) + { + esp_err_t errRc = ::esp_wifi_set_mode(WIFI_MODE_APSTA); + if (errRc != ESP_OK) { + ESP_LOGE(LOG_TAG, "esp_wifi_set_mode: rc=%d %s", errRc, GeneralUtils::errorToString(errRc)); + abort(); + } + } + errRc = ::esp_wifi_start(); if (errRc != ESP_OK) { ESP_LOGE(LOG_TAG, "esp_wifi_start: rc=%d %s", errRc, GeneralUtils::errorToString(errRc)); @@ -611,7 +627,7 @@ void WiFi::startAP(const std::string& ssid, const std::string& password, wifi_au init(); - esp_err_t errRc = ::esp_wifi_set_mode(WIFI_MODE_AP); + esp_err_t errRc = ::esp_wifi_set_mode(WIFI_MODE_APSTA); // change to APSTA so we can scan if (errRc != ESP_OK) { ESP_LOGE(LOG_TAG, "esp_wifi_set_mode: rc=%d %s", errRc, GeneralUtils::errorToString(errRc)); abort(); diff --git a/networking/bootwifi/BootWiFi.cpp b/networking/bootwifi/BootWiFi.cpp index cc07c898..ef02d6ed 100644 --- a/networking/bootwifi/BootWiFi.cpp +++ b/networking/bootwifi/BootWiFi.cpp @@ -6,6 +6,8 @@ * See the README.md for full information. * */ +#include "sdkconfig.h" +#include #include #include #include @@ -13,8 +15,9 @@ #include #include #include -//#include + //#include + #include #include #include @@ -22,7 +25,14 @@ #include #include #include "BootWiFi.h" -#include "sdkconfig.h" +#include "JSON.h" + +// include in case we do a restart timer, no harm. + +# include + +#include "esp_wifi.h" +#include "esp_err.h" #include "selectAP.h" // If the structure of a record saved for a subsequent reboot changes @@ -31,6 +41,10 @@ #define KEY_VERSION "version" uint32_t g_version=0x0100; +#if (CONFIG_RESTART_COUNTER > 0) +# define KEY_RESTART "restartCount" +#endif + #define KEY_CONNECTION_INFO "connectionInfo" // Key used in NVS for connection info #define BOOTWIFI_NAMESPACE "bootwifi" // Namespace in NVS for bootwifi #define SSID_SIZE (32) // Maximum SSID size @@ -51,14 +65,18 @@ typedef struct { // Forward declarations static void saveConnectionInfo(connection_info_t *pConnectionInfo); +#if (CONFIG_RESTART_COUNTER > 0) +static int restart_count_get(void); +#endif -static const char LOG_TAG[] = "bootwifi"; +static const char TAG[] = "bootwifi"; +static std::vector apList; static void dumpConnectionInfo(connection_info_t *pConnectionInfo) { - ESP_LOGD(LOG_TAG, "connection_info.ssid = %.*s", SSID_SIZE, pConnectionInfo->ssid); - ESP_LOGD(LOG_TAG, "connection_info.password = %.*s", PASSWORD_SIZE, pConnectionInfo->password); - ESP_LOGD(LOG_TAG, "ip: %s, gw: %s, netmask: %s", + ESP_LOGD(TAG, "connection_info.ssid = %.*s", SSID_SIZE, pConnectionInfo->ssid); + ESP_LOGD(TAG, "connection_info.password = %.*s", PASSWORD_SIZE, pConnectionInfo->password); + ESP_LOGD(TAG, "ip: %s, gw: %s, netmask: %s", GeneralUtils::ipToString((uint8_t*)&pConnectionInfo->ipInfo.ip).c_str(), GeneralUtils::ipToString((uint8_t*)&pConnectionInfo->ipInfo.gw).c_str(), GeneralUtils::ipToString((uint8_t*)&pConnectionInfo->ipInfo.netmask).c_str()); @@ -69,7 +87,7 @@ static void dumpConnectionInfo(connection_info_t *pConnectionInfo) { * Retrieve the connection info. A rc==0 means ok. */ static int getConnectionInfo(connection_info_t *pConnectionInfo) { - ESP_LOGD(LOG_TAG, ">> getConnectionInfo"); + ESP_LOGD(TAG, ">> getConnectionInfo"); size_t size; uint32_t version; @@ -78,7 +96,7 @@ static int getConnectionInfo(connection_info_t *pConnectionInfo) { // Check the versions match if ((version & 0xff00) != (g_version & 0xff00)) { - ESP_LOGD(LOG_TAG, "Incompatible versions ... current is %x, found is %x", version, g_version); + ESP_LOGD(TAG, "Incompatible versions ... current is %x, found is %x", version, g_version); return -1; } @@ -87,11 +105,11 @@ static int getConnectionInfo(connection_info_t *pConnectionInfo) { // Do a sanity check on the SSID if (strlen(pConnectionInfo->ssid) == 0) { - ESP_LOGD(LOG_TAG, "NULL ssid detected"); + ESP_LOGD(TAG, "NULL ssid detected"); return -1; } dumpConnectionInfo(pConnectionInfo); - ESP_LOGD(LOG_TAG, "<< getConnectionInfo"); + ESP_LOGD(TAG, "<< getConnectionInfo"); return 0; } // getConnectionInfo @@ -126,11 +144,43 @@ static int checkOverrideGpio() { #endif } // checkOverrideGpio + static void sendForm(HttpRequest* pRequest, HttpResponse* pResponse) { - pResponse->setStatus(HttpResponse::HTTP_STATUS_OK, "OK"); - pResponse->addHeader(HttpRequest::HTTP_HEADER_CONTENT_TYPE, "text/html"); - pResponse->sendData(std::string((char*)selectAP_html, selectAP_html_len)); + pResponse->setStatus(HttpResponse::HTTP_STATUS_OK, "OK"); + pResponse->addHeader(HttpRequest::HTTP_HEADER_CONTENT_TYPE, "text/html"); + + // send the prefix + + pResponse->sendData(formPrefix); + + // now create the SSID input + + std::string ssidInput = + "SSID:\n" + "\n" + "\n" + "\n" + "\n\n\n"; + + pResponse->sendData(ssidInput); + + // then send the postfix + + pResponse->sendData(formPostfix); + + // now close + pResponse->close(); + } // sendForm @@ -148,7 +198,7 @@ static void copyData(uint8_t* pTarget, size_t targetLength, std::string source) * @brief Process the form response. */ static void processForm(HttpRequest* pRequest, HttpResponse* pResponse) { - ESP_LOGD(LOG_TAG, ">> processForm"); + ESP_LOGD(TAG, ">> processForm"); std::map formMap = pRequest->parseForm(); connection_info_t connectionInfo; copyData((uint8_t*)connectionInfo.ssid, SSID_SIZE, formMap["ssid"]); @@ -157,55 +207,71 @@ static void processForm(HttpRequest* pRequest, HttpResponse* pResponse) { try { std::string ipStr = formMap.at("ip"); if (ipStr.empty()) { - ESP_LOGD(LOG_TAG, "No IP address using default 0.0.0.0"); + ESP_LOGD(TAG, "No IP address using default 0.0.0.0"); connectionInfo.ipInfo.ip.addr = 0; } else { inet_pton(AF_INET, ipStr.c_str(), &connectionInfo.ipInfo.ip); } } catch(std::out_of_range& e) { - ESP_LOGD(LOG_TAG, "No IP address using default 0.0.0.0"); + ESP_LOGD(TAG, "No IP address using default 0.0.0.0"); connectionInfo.ipInfo.ip.addr = 0; } try { - std::string gwStr = formMap.at("gw"); - if (gwStr.empty()) { - ESP_LOGD(LOG_TAG, "No GW address using default 0.0.0.0"); - connectionInfo.ipInfo.gw.addr = 0; - } else { - inet_pton(AF_INET, gwStr.c_str(), &connectionInfo.ipInfo.gw); - } - } catch(std::out_of_range& e) { - ESP_LOGD(LOG_TAG, "No GW address using default 0.0.0.0"); + std::string gwStr = formMap.at("gw"); + if (gwStr.empty()) { + ESP_LOGD(TAG, "No GW address using default 0.0.0.0"); connectionInfo.ipInfo.gw.addr = 0; + } else { + inet_pton(AF_INET, gwStr.c_str(), &connectionInfo.ipInfo.gw); + } + } catch(std::out_of_range& e) { + ESP_LOGD(TAG, "No GW address using default 0.0.0.0"); + connectionInfo.ipInfo.gw.addr = 0; } try { - std::string netmaskStr = formMap.at("netmask"); - if (netmaskStr.empty()) { - ESP_LOGD(LOG_TAG, "No Netmask address using default 0.0.0.0"); - connectionInfo.ipInfo.netmask.addr = 0; - } else { - inet_pton(AF_INET, netmaskStr.c_str(), &connectionInfo.ipInfo.netmask); - } - } catch(std::out_of_range& e) { - ESP_LOGD(LOG_TAG, "No Netmask address using default 0.0.0.0"); + std::string netmaskStr = formMap.at("netmask"); + if (netmaskStr.empty()) { + ESP_LOGD(TAG, "No Netmask address using default 0.0.0.0"); connectionInfo.ipInfo.netmask.addr = 0; + } else { + inet_pton(AF_INET, netmaskStr.c_str(), &connectionInfo.ipInfo.netmask); + } + } catch(std::out_of_range& e) { + ESP_LOGD(TAG, "No Netmask address using default 0.0.0.0"); + connectionInfo.ipInfo.netmask.addr = 0; } - ESP_LOGD(LOG_TAG, "ssid: %s, password: %s", connectionInfo.ssid, connectionInfo.password); + ESP_LOGD(TAG, "ssid: %s, password: %s", connectionInfo.ssid, connectionInfo.password); saveConnectionInfo(&connectionInfo); - pResponse->setStatus(HttpResponse::HTTP_STATUS_OK, "OK"); - pResponse->addHeader(HttpRequest::HTTP_HEADER_CONTENT_TYPE, "text/html"); + pResponse->setStatus(HttpResponse::HTTP_STATUS_OK, "OK"); + pResponse->addHeader(HttpRequest::HTTP_HEADER_CONTENT_TYPE, "text/html"); //pResponse->sendData(std::string((char*)selectAP_html, selectAP_html_len)); pResponse->close(); FreeRTOS::sleep(500); System::restart(); - ESP_LOGD(LOG_TAG, "<< processForm"); + ESP_LOGD(TAG, "<< processForm"); } // processForm +static void handleScan(HttpRequest* pRequest, HttpResponse* pResponse) { + // send back the apList + // print all the aps + pResponse->setStatus(HttpResponse::HTTP_STATUS_OK, "OK"); + pResponse->addHeader(HttpRequest::HTTP_HEADER_CONTENT_TYPE, "text/html"); + pResponse->sendData(std::string((char*)selectAP_html, selectAP_html_len)); + pResponse->close(); + + std::vector::iterator it; + for (it = apList.begin(); it != apList.end(); it++) + { + ESP_LOGI(TAG, "%s, %d rssi", it->getSSID().c_str() , it->getRSSI()); + } + +} // sendForm + class BootWifiEventHandler: public WiFiEventHandler { public: @@ -223,6 +289,7 @@ class BootWifiEventHandler: public WiFiEventHandler { m_pBootWiFi->m_httpServerStarted = true; m_pBootWiFi->m_httpServer.addPathHandler("GET", "/", sendForm); m_pBootWiFi->m_httpServer.addPathHandler("POST", "/ssidSelected", processForm); + m_pBootWiFi->m_httpServer.addPathHandler("GET", "/scan", handleScan); m_pBootWiFi->m_httpServer.start(80); } ESP_LOGD("BootWifiEventHandler", "<< apStart"); @@ -243,9 +310,9 @@ class BootWifiEventHandler: public WiFiEventHandler { esp_err_t staGotIp(system_event_sta_got_ip_t event_sta_got_ip) { ESP_LOGD("BootWifiEventHandler", ">> staGotIP"); - m_pBootWiFi->m_apConnectionStatus = ESP_OK; // Set the status to ESP_OK + m_pBootWiFi->m_apConnectionStatus = ESP_OK; // Set the status to ESP_OK m_pBootWiFi->m_completeSemaphore.give(); // If we got an IP address, then we can end the boot process. - ESP_LOGD("BootWifiEventHandler", "<< staGotIP"); + printf("IP = %s", m_pBootWiFi->m_wifi.getStaIp().c_str()); return ESP_OK; } // staGotIp @@ -263,17 +330,24 @@ class BootWifiEventHandler: public WiFiEventHandler { * */ void BootWiFi::bootWiFi2() { - ESP_LOGD(LOG_TAG, ">> bootWiFi2"); + ESP_LOGD(TAG, ">> bootWiFi2"); // Check for a GPIO override which occurs when a physical Pin is high // during the test. This can force the ability to check for new configuration // even if the existing configured access point is available. m_wifi.setWifiEventHandler(new BootWifiEventHandler(this)); - if (checkOverrideGpio()) { - ESP_LOGD(LOG_TAG, "- GPIO override detected"); + if (checkOverrideGpio() +#if (CONFIG_RESTART_COUNTER > 0) + || (restart_count_get() >= CONFIG_RESTART_COUNTER) +#endif + ) + { + ESP_LOGD(TAG, "- restart count exceeded starting AP mode"); m_wifi.startAP(m_ssid, m_password); + // do a scan to see what APs are around + apList = m_wifi.scan(); } else { - // There was NO GPIO override, proceed as normal. This means we retrieve + // There was no restart count override, proceed as normal. This means we retrieve // our stored access point information of the access point we should connect // against. If that information doesn't exist, then again we become an // access point ourselves in order to allow a client to connect and bring @@ -283,17 +357,17 @@ void BootWiFi::bootWiFi2() { if (rc == 0) { // We have received connection information, let us now become a station // and attempt to connect to the access point. - ESP_LOGD(LOG_TAG, "- Connecting to access point \"%s\" ...", connectionInfo.ssid); + ESP_LOGD(TAG, "- Connecting to access point \"%s\" ...", connectionInfo.ssid); assert(strlen(connectionInfo.ssid) > 0); m_wifi.setIPInfo( - connectionInfo.ipInfo.ip.addr, - connectionInfo.ipInfo.gw.addr, - connectionInfo.ipInfo.netmask.addr + (uint32_t) connectionInfo.ipInfo.ip.addr, + (uint32_t) connectionInfo.ipInfo.gw.addr, + (uint32_t) connectionInfo.ipInfo.netmask.addr ); - m_apConnectionStatus = m_wifi.connectAP(connectionInfo.ssid, connectionInfo.password); // Try to connect to the access point. - m_completeSemaphore.give(); // end the boot process so we don't hang... + m_apConnectionStatus = m_wifi.connectAP(connectionInfo.ssid, connectionInfo.password); // Try to connect to the access point. + m_completeSemaphore.give(); // end the boot process so we don't hang... } else { // We do NOT have connection information. Let us now become an access @@ -301,9 +375,10 @@ void BootWiFi::bootWiFi2() { // the details that will be eventually used to allow us to connect // as a station. m_wifi.startAP(m_ssid, m_password); + apList = m_wifi.scan(); } // We do NOT have connection info } - ESP_LOGD(LOG_TAG, "<< bootWiFi2"); + ESP_LOGD(TAG, "<< bootWiFi2"); } // bootWiFi2 @@ -327,21 +402,95 @@ void BootWiFi::setAccessPointCredentials(std::string ssid, std::string password) * @returns ESP_OK if successfully connected to an access point. Otherwise returns wifi_err_reason_t - to print use GeneralUtils::wifiErrorToString */ uint8_t BootWiFi::boot() { - ESP_LOGD(LOG_TAG, ">> boot"); - ESP_LOGD(LOG_TAG, " +----------+"); - ESP_LOGD(LOG_TAG, " | BootWiFi |"); - ESP_LOGD(LOG_TAG, " +----------+"); - ESP_LOGD(LOG_TAG, " Access point credentials: %s/%s", m_ssid.c_str(), m_password.c_str()); + ESP_LOGD(TAG, ">> boot"); + ESP_LOGD(TAG, " +----------+"); + ESP_LOGD(TAG, " | BootWiFi |"); + ESP_LOGD(TAG, " +----------+"); + ESP_LOGD(TAG, " Access point credentials: %s/%s", m_ssid.c_str(), m_password.c_str()); m_completeSemaphore.take("boot"); // Take the semaphore which will be unlocked when we complete booting. bootWiFi2(); m_completeSemaphore.wait("boot"); // Wait for the semaphore that indicated we have completed booting. m_wifi.setWifiEventHandler(nullptr); // Remove the WiFi boot handler when we have completed booting. - ESP_LOGD(LOG_TAG, "<< boot"); + ESP_LOGD(TAG, "<< boot"); return m_apConnectionStatus; } // boot BootWiFi::BootWiFi() { m_httpServerStarted = false; - m_apConnectionStatus = UINT8_MAX; - setAccessPointCredentials("esp32", "password"); // Default access point credentials + m_apConnectionStatus = UINT8_MAX; + uint8_t myApMac[6]; + esp_read_mac(myApMac, ESP_MAC_WIFI_SOFTAP); + char ssid[32] = ""; + sprintf(ssid, "esp32-%02x%02x", myApMac[4], myApMac[5]); + ESP_LOGI(TAG, "SoftAp = [%s]\n", ssid); + setAccessPointCredentials(ssid, "password"); // Default access point credentials +} + +BootWiFi::BootWiFi(char *ssidBase) { + m_httpServerStarted = false; + m_apConnectionStatus = UINT8_MAX; + uint8_t myApMac[6]; + esp_read_mac(myApMac, ESP_MAC_WIFI_SOFTAP); + char ssid[32] = ""; + sprintf(ssid, "%s%02x%02x", ssidBase, myApMac[4], myApMac[5]); + ESP_LOGI(TAG, "SoftAp = [%s]\n", ssid); + setAccessPointCredentials(ssid, "password"); // Default access point credentials +} + +std::string BootWiFi::getIp(void) +{ + return (m_wifi.getStaIp()); +} + +#if (CONFIG_RESTART_COUNTER > 0) +static void restart_count_erase_timercb(void *timer) +{ + ESP_LOGI(TAG, "Erase restart callback"); + if (!xTimerStop(timer, portMAX_DELAY)) { + ESP_LOGE(TAG, "xTimerStop timer: %p", timer); + } + + if (!xTimerDelete(timer, portMAX_DELAY)) { + ESP_LOGE(TAG, "xTimerDelete timer: %p", timer); + } + + NVS myNamespace(BOOTWIFI_NAMESPACE); + myNamespace.set(KEY_RESTART, 0); + ESP_LOGD(TAG, "Erase restart count"); } + +static int restart_count_get() +{ + TimerHandle_t timer = NULL; + uint32_t restart_count = 0; + + /**< If the device restarts within the instruction time, + the restart_count value will be incremented by one */ + NVS myNamespace(BOOTWIFI_NAMESPACE); + myNamespace.get(KEY_RESTART, restart_count); + restart_count++; + myNamespace.set(KEY_RESTART, restart_count); + + /* create a timer that in N seconds will clear the restart counter, so if you restart within N secs */ + /* the restart counter will increment. So to get it to increment remove power and add power then remove */ + /* power again, if you continue to reboot within N seconds each time then the restart counter will continue */ + /* to incrment, the caller can decide when to do something because of this */ + + timer = xTimerCreate("restart_count_erase", CONFIG_RESTART_TIMEOUT / portTICK_RATE_MS, + false, NULL, restart_count_erase_timercb); + if (!timer) + { + ESP_LOGE(TAG, "xTaskCreate, timer: %p", timer); + return 0; + } + + if (xTimerStart(timer, 0) != pdPASS) + { + ESP_LOGE(TAG, "failed to start"); + } + +// vTaskStartScheduler(); + ESP_LOGD(TAG, "restart count %d", restart_count); + return restart_count; +} +#endif diff --git a/networking/bootwifi/BootWiFi.h b/networking/bootwifi/BootWiFi.h index 9c4a17b6..6957678e 100644 --- a/networking/bootwifi/BootWiFi.h +++ b/networking/bootwifi/BootWiFi.h @@ -24,13 +24,16 @@ class BootWiFi { bool m_httpServerStarted; std::string m_ssid; std::string m_password; - uint8_t m_apConnectionStatus; // receives the connection status. ESP_OK = received SYSTEM_EVENT_STA_GOT_IP event. + uint8_t m_apConnectionStatus; // receives the connection status. ESP_OK = received SYSTEM_EVENT_STA_GOT_IP event. FreeRTOS::Semaphore m_completeSemaphore = FreeRTOS::Semaphore("completeSemaphore"); public: BootWiFi(); + BootWiFi(char *); void setAccessPointCredentials(std::string ssid, std::string password); uint8_t boot(); + uint8_t boot(std::string); + std::string getIp(void); }; #endif /* MAIN_BOOTWIFI_H_ */ diff --git a/networking/bootwifi/Kconfig b/networking/bootwifi/Kconfig new file mode 100644 index 00000000..52061da0 --- /dev/null +++ b/networking/bootwifi/Kconfig @@ -0,0 +1,15 @@ +menu "BootWifi Settings" + +config RESTART_COUNTER + int "# times to restart quickly to enter bootwifi" + default 0 + help + Set to a positive value to be able to restart quickly and enter config from wifi + +config RESTART_TIMEOUT + int "#ms you have to restart and be considered quickly" + default 3000 + help + Increase or decrease. + +endmenu \ No newline at end of file diff --git a/networking/bootwifi/selectAP.h b/networking/bootwifi/selectAP.h index c96f623a..70350cba 100644 --- a/networking/bootwifi/selectAP.h +++ b/networking/bootwifi/selectAP.h @@ -1,3 +1,56 @@ +std::string formPrefix = + "\n" + "\n" + "\n" + "\n" + "Select WiFi\n" + "\n" + "\n" + "\n" + "
\n" + "

Select WiFi

\n" + "
\n" + "\n" + "\n" + "\n"; +#if 0 + "\n"; + "\n" +#endif + +std::string formPostfix = + "\n" + "\n" + "\n" + "\n" + "\n" + "\n" + "\n" + "\n" + "\n" + "\n" + "\n" + "\n" + "\n" + "\n" + "\n" + "\n" + "\n" + "\n" + "
SSID:
Password:
IP address:
Gateway address:
Netmask:
\n" + "

\n" + "\n" + "

\n" + "
\n" + "
\n" + "The IP address, gateway address and netmask are optional. If not supplied " + "these values will be issued by the WiFi access point." + "
\n" + "
\n" + "\n" + "\n" + "\n" + ; unsigned char selectAP_html[] = { 0x3c, 0x21, 0x44, 0x4f, 0x43, 0x54, 0x59, 0x50, 0x45, 0x20, 0x68, 0x74, 0x6d, 0x6c, 0x3e, 0x0a, 0x3c, 0x68, 0x74, 0x6d, 0x6c, 0x3e, 0x0a, 0x3c,