Skip to content

Commit b07eb17

Browse files
mathieucarbouSuGliderpre-commit-ci-lite[bot]
authoredJan 7, 2025··
feat(webserver): Middleware with default middleware for cors, authc, curl-like logging (#10750)
* feat(webserver): Middleware with default middleware for cors, authc, curl-like logging * ci(pre-commit): Apply automatic fixes --------- Co-authored-by: Rodrigo Garcia <rodrigo.garcia@espressif.com> Co-authored-by: pre-commit-ci-lite[bot] <117423508+pre-commit-ci-lite[bot]@users.noreply.github.com>
1 parent 089cbab commit b07eb17

File tree

14 files changed

+895
-101
lines changed

14 files changed

+895
-101
lines changed
 

‎CMakeLists.txt

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -242,7 +242,11 @@ set(ARDUINO_LIBRARY_USB_SRCS
242242
set(ARDUINO_LIBRARY_WebServer_SRCS
243243
libraries/WebServer/src/WebServer.cpp
244244
libraries/WebServer/src/Parsing.cpp
245-
libraries/WebServer/src/detail/mimetable.cpp)
245+
libraries/WebServer/src/detail/mimetable.cpp
246+
libraries/WebServer/src/middleware/MiddlewareChain.cpp
247+
libraries/WebServer/src/middleware/AuthenticationMiddleware.cpp
248+
libraries/WebServer/src/middleware/CorsMiddleware.cpp
249+
libraries/WebServer/src/middleware/LoggingMiddleware.cpp)
246250

247251
set(ARDUINO_LIBRARY_NetworkClientSecure_SRCS
248252
libraries/NetworkClientSecure/src/ssl_client.cpp
Lines changed: 186 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,186 @@
1+
/**
2+
* Basic example of using Middlewares with WebServer
3+
*
4+
* Middleware are common request/response processing functions that can be applied globally to all incoming requests or to specific handlers.
5+
* They allow for a common processing thus saving memory and space to avoid duplicating code or states on multiple handlers.
6+
*
7+
* Once the example is flashed (with the correct WiFi credentials), you can test the following scenarios with the listed curl commands:
8+
* - CORS Middleware: answers to OPTIONS requests with the specified CORS headers and also add CORS headers to the response when the request has the Origin header
9+
* - Logging Middleware: logs the request and response to an output in a curl-like format
10+
* - Authentication Middleware: test the authentication with Digest Auth
11+
*
12+
* You can also add your own Middleware by extending the Middleware class and implementing the run method.
13+
* When implementing a Middleware, you can decide when to call the next Middleware in the chain by calling next().
14+
*
15+
* Middleware are execute in order of addition, the ones attached to the server will be executed first.
16+
*/
17+
#include <WiFi.h>
18+
#include <WebServer.h>
19+
#include <Middlewares.h>
20+
21+
// Your AP WiFi Credentials
22+
// ( This is the AP your ESP will broadcast )
23+
const char *ap_ssid = "ESP32_Demo";
24+
const char *ap_password = "";
25+
26+
WebServer server(80);
27+
28+
LoggingMiddleware logger;
29+
CorsMiddleware cors;
30+
AuthenticationMiddleware auth;
31+
32+
void setup(void) {
33+
Serial.begin(115200);
34+
WiFi.softAP(ap_ssid, ap_password);
35+
36+
Serial.print("IP address: ");
37+
Serial.println(WiFi.AP.localIP());
38+
39+
// curl-like output example:
40+
//
41+
// > curl -v -X OPTIONS -H "origin: http://192.168.4.1" http://192.168.4.1/
42+
//
43+
// Connection from 192.168.4.2:51683
44+
// > OPTIONS / HTTP/1.1
45+
// > Host: 192.168.4.1
46+
// > User-Agent: curl/8.10.0
47+
// > Accept: */*
48+
// > origin: http://192.168.4.1
49+
// >
50+
// * Processed in 5 ms
51+
// < HTTP/1.HTTP/1.1 200 OK
52+
// < Content-Type: text/html
53+
// < Access-Control-Allow-Origin: http://192.168.4.1
54+
// < Access-Control-Allow-Methods: POST,GET,OPTIONS,DELETE
55+
// < Access-Control-Allow-Headers: X-Custom-Header
56+
// < Access-Control-Allow-Credentials: false
57+
// < Access-Control-Max-Age: 600
58+
// < Content-Length: 0
59+
// < Connection: close
60+
// <
61+
logger.setOutput(Serial);
62+
63+
cors.setOrigin("http://192.168.4.1");
64+
cors.setMethods("POST,GET,OPTIONS,DELETE");
65+
cors.setHeaders("X-Custom-Header");
66+
cors.setAllowCredentials(false);
67+
cors.setMaxAge(600);
68+
69+
auth.setUsername("admin");
70+
auth.setPassword("admin");
71+
auth.setRealm("My Super App");
72+
auth.setAuthMethod(DIGEST_AUTH);
73+
auth.setAuthFailureMessage("Authentication Failed");
74+
75+
server.addMiddleware(&logger);
76+
server.addMiddleware(&cors);
77+
78+
// Not authenticated
79+
//
80+
// Test CORS preflight request with:
81+
// > curl -v -X OPTIONS -H "origin: http://192.168.4.1" http://192.168.4.1/
82+
//
83+
// Test cross-domain request with:
84+
// > curl -v -X GET -H "origin: http://192.168.4.1" http://192.168.4.1/
85+
//
86+
server.on("/", []() {
87+
server.send(200, "text/plain", "Home");
88+
});
89+
90+
// Authenticated
91+
//
92+
// > curl -v -X GET -H "origin: http://192.168.4.1" http://192.168.4.1/protected
93+
//
94+
// Outputs:
95+
//
96+
// * Connection from 192.168.4.2:51750
97+
// > GET /protected HTTP/1.1
98+
// > Host: 192.168.4.1
99+
// > User-Agent: curl/8.10.0
100+
// > Accept: */*
101+
// > origin: http://192.168.4.1
102+
// >
103+
// * Processed in 7 ms
104+
// < HTTP/1.HTTP/1.1 401 Unauthorized
105+
// < Content-Type: text/html
106+
// < Access-Control-Allow-Origin: http://192.168.4.1
107+
// < Access-Control-Allow-Methods: POST,GET,OPTIONS,DELETE
108+
// < Access-Control-Allow-Headers: X-Custom-Header
109+
// < Access-Control-Allow-Credentials: false
110+
// < Access-Control-Max-Age: 600
111+
// < WWW-Authenticate: Digest realm="My Super App", qop="auth", nonce="ac388a64184e3e102aae6fff1c9e8d76", opaque="e7d158f2b54d25328142d118ff0f932d"
112+
// < Content-Length: 21
113+
// < Connection: close
114+
// <
115+
//
116+
// > curl -v -X GET -H "origin: http://192.168.4.1" --digest -u admin:admin http://192.168.4.1/protected
117+
//
118+
// Outputs:
119+
//
120+
// * Connection from 192.168.4.2:53662
121+
// > GET /protected HTTP/1.1
122+
// > Authorization: Digest username="admin", realm="My Super App", nonce="db9e6824eb2a13bc7b2bf8f3c43db896", uri="/protected", cnonce="NTliZDZiNTcwODM2MzAyY2JjMDBmZGJmNzFiY2ZmNzk=", nc=00000001, qop=auth, response="6ebd145ba0d3496a4a73f5ae79ff5264", opaque="23d739c22810282ff820538cba98bda4"
123+
// > Host: 192.168.4.1
124+
// > User-Agent: curl/8.10.0
125+
// > Accept: */*
126+
// > origin: http://192.168.4.1
127+
// >
128+
// Request handling...
129+
// * Processed in 7 ms
130+
// < HTTP/1.HTTP/1.1 200 OK
131+
// < Content-Type: text/plain
132+
// < Access-Control-Allow-Origin: http://192.168.4.1
133+
// < Access-Control-Allow-Methods: POST,GET,OPTIONS,DELETE
134+
// < Access-Control-Allow-Headers: X-Custom-Header
135+
// < Access-Control-Allow-Credentials: false
136+
// < Access-Control-Max-Age: 600
137+
// < Content-Length: 9
138+
// < Connection: close
139+
// <
140+
server
141+
.on(
142+
"/protected",
143+
[]() {
144+
Serial.println("Request handling...");
145+
server.send(200, "text/plain", "Protected");
146+
}
147+
)
148+
.addMiddleware(&auth);
149+
150+
// Not found is also handled by global middleware
151+
//
152+
// curl -v -X GET -H "origin: http://192.168.4.1" http://192.168.4.1/inexsting
153+
//
154+
// Outputs:
155+
//
156+
// * Connection from 192.168.4.2:53683
157+
// > GET /inexsting HTTP/1.1
158+
// > Host: 192.168.4.1
159+
// > User-Agent: curl/8.10.0
160+
// > Accept: */*
161+
// > origin: http://192.168.4.1
162+
// >
163+
// * Processed in 16 ms
164+
// < HTTP/1.HTTP/1.1 404 Not Found
165+
// < Content-Type: text/plain
166+
// < Access-Control-Allow-Origin: http://192.168.4.1
167+
// < Access-Control-Allow-Methods: POST,GET,OPTIONS,DELETE
168+
// < Access-Control-Allow-Headers: X-Custom-Header
169+
// < Access-Control-Allow-Credentials: false
170+
// < Access-Control-Max-Age: 600
171+
// < Content-Length: 14
172+
// < Connection: close
173+
// <
174+
server.onNotFound([]() {
175+
server.send(404, "text/plain", "Page not found");
176+
});
177+
178+
server.collectAllHeaders();
179+
server.begin();
180+
Serial.println("HTTP server started");
181+
}
182+
183+
void loop(void) {
184+
server.handleClient();
185+
delay(2); //allow the cpu to switch to other tasks
186+
}
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
{
2+
"requires": [
3+
"CONFIG_SOC_WIFI_SUPPORTED=y"
4+
]
5+
}

‎libraries/WebServer/src/Middlewares.h

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
#ifndef MIDDLEWARES_H
2+
#define MIDDLEWARES_H
3+
4+
#include <WebServer.h>
5+
#include <Stream.h>
6+
7+
#include <assert.h>
8+
9+
// curl-like logging middleware
10+
class LoggingMiddleware : public Middleware {
11+
public:
12+
void setOutput(Print &output);
13+
14+
bool run(WebServer &server, Middleware::Callback next) override;
15+
16+
private:
17+
Print *_out = nullptr;
18+
};
19+
20+
class CorsMiddleware : public Middleware {
21+
public:
22+
CorsMiddleware &setOrigin(const char *origin);
23+
CorsMiddleware &setMethods(const char *methods);
24+
CorsMiddleware &setHeaders(const char *headers);
25+
CorsMiddleware &setAllowCredentials(bool credentials);
26+
CorsMiddleware &setMaxAge(uint32_t seconds);
27+
28+
void addCORSHeaders(WebServer &server);
29+
30+
bool run(WebServer &server, Middleware::Callback next) override;
31+
32+
private:
33+
String _origin = F("*");
34+
String _methods = F("*");
35+
String _headers = F("*");
36+
bool _credentials = true;
37+
uint32_t _maxAge = 86400;
38+
};
39+
40+
class AuthenticationMiddleware : public Middleware {
41+
public:
42+
AuthenticationMiddleware &setUsername(const char *username);
43+
AuthenticationMiddleware &setPassword(const char *password);
44+
AuthenticationMiddleware &setPasswordHash(const char *sha1AsBase64orHex);
45+
AuthenticationMiddleware &setCallback(WebServer::THandlerFunctionAuthCheck fn);
46+
47+
AuthenticationMiddleware &setRealm(const char *realm);
48+
AuthenticationMiddleware &setAuthMethod(HTTPAuthMethod method);
49+
AuthenticationMiddleware &setAuthFailureMessage(const char *message);
50+
51+
bool isAllowed(WebServer &server) const;
52+
53+
bool run(WebServer &server, Middleware::Callback next) override;
54+
55+
private:
56+
String _username;
57+
String _password;
58+
bool _hash = false;
59+
WebServer::THandlerFunctionAuthCheck _callback;
60+
61+
const char *_realm = nullptr;
62+
HTTPAuthMethod _method = BASIC_AUTH;
63+
String _authFailMsg;
64+
};
65+
66+
#endif

‎libraries/WebServer/src/Parsing.cpp

Lines changed: 28 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -78,8 +78,14 @@ bool WebServer::_parseRequest(NetworkClient &client) {
7878
String req = client.readStringUntil('\r');
7979
client.readStringUntil('\n');
8080
//reset header value
81-
for (int i = 0; i < _headerKeysCount; ++i) {
82-
_currentHeaders[i].value = String();
81+
if (_collectAllHeaders) {
82+
// clear previous headers
83+
collectAllHeaders();
84+
} else {
85+
// clear previous headers
86+
for (RequestArgument *header = _currentHeaders; header; header = header->next) {
87+
header->value = String();
88+
}
8389
}
8490

8591
// First line of HTTP request looks like "GET /path HTTP/1.1"
@@ -154,9 +160,6 @@ bool WebServer::_parseRequest(NetworkClient &client) {
154160
headerValue.trim();
155161
_collectHeader(headerName.c_str(), headerValue.c_str());
156162

157-
log_v("headerName: %s", headerName.c_str());
158-
log_v("headerValue: %s", headerValue.c_str());
159-
160163
if (headerName.equalsIgnoreCase(FPSTR(Content_Type))) {
161164
using namespace mime;
162165
if (headerValue.startsWith(FPSTR(mimeTable[txt].mimeType))) {
@@ -254,9 +257,6 @@ bool WebServer::_parseRequest(NetworkClient &client) {
254257
headerValue = req.substring(headerDiv + 2);
255258
_collectHeader(headerName.c_str(), headerValue.c_str());
256259

257-
log_v("headerName: %s", headerName.c_str());
258-
log_v("headerValue: %s", headerValue.c_str());
259-
260260
if (headerName.equalsIgnoreCase("Host")) {
261261
_hostHeader = headerValue;
262262
}
@@ -272,12 +272,29 @@ bool WebServer::_parseRequest(NetworkClient &client) {
272272
}
273273

274274
bool WebServer::_collectHeader(const char *headerName, const char *headerValue) {
275-
for (int i = 0; i < _headerKeysCount; i++) {
276-
if (_currentHeaders[i].key.equalsIgnoreCase(headerName)) {
277-
_currentHeaders[i].value = headerValue;
275+
RequestArgument *last = nullptr;
276+
for (RequestArgument *header = _currentHeaders; header; header = header->next) {
277+
if (header->next == nullptr) {
278+
last = header;
279+
}
280+
if (header->key.equalsIgnoreCase(headerName)) {
281+
header->value = headerValue;
282+
log_v("header collected: %s: %s", headerName, headerValue);
278283
return true;
279284
}
280285
}
286+
assert(last);
287+
if (_collectAllHeaders) {
288+
last->next = new RequestArgument();
289+
last->next->key = headerName;
290+
last->next->value = headerValue;
291+
_headerKeysCount++;
292+
log_v("header collected: %s: %s", headerName, headerValue);
293+
return true;
294+
}
295+
296+
log_v("header skipped: %s: %s", headerName, headerValue);
297+
281298
return false;
282299
}
283300

‎libraries/WebServer/src/WebServer.cpp

Lines changed: 183 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -41,31 +41,28 @@ static const char WWW_Authenticate[] = "WWW-Authenticate";
4141
static const char Content_Length[] = "Content-Length";
4242
static const char ETAG_HEADER[] = "If-None-Match";
4343

44-
WebServer::WebServer(IPAddress addr, int port)
45-
: _corsEnabled(false), _server(addr, port), _currentMethod(HTTP_ANY), _currentVersion(0), _currentStatus(HC_NONE), _statusChange(0), _nullDelay(true),
46-
_currentHandler(nullptr), _firstHandler(nullptr), _lastHandler(nullptr), _currentArgCount(0), _currentArgs(nullptr), _postArgsLen(0), _postArgs(nullptr),
47-
_headerKeysCount(0), _currentHeaders(nullptr), _contentLength(0), _clientContentLength(0), _chunked(false) {
44+
WebServer::WebServer(IPAddress addr, int port) : _server(addr, port) {
4845
log_v("WebServer::Webserver(addr=%s, port=%d)", addr.toString().c_str(), port);
4946
}
5047

51-
WebServer::WebServer(int port)
52-
: _corsEnabled(false), _server(port), _currentMethod(HTTP_ANY), _currentVersion(0), _currentStatus(HC_NONE), _statusChange(0), _nullDelay(true),
53-
_currentHandler(nullptr), _firstHandler(nullptr), _lastHandler(nullptr), _currentArgCount(0), _currentArgs(nullptr), _postArgsLen(0), _postArgs(nullptr),
54-
_headerKeysCount(0), _currentHeaders(nullptr), _contentLength(0), _clientContentLength(0), _chunked(false) {
48+
WebServer::WebServer(int port) : _server(port) {
5549
log_v("WebServer::Webserver(port=%d)", port);
5650
}
5751

5852
WebServer::~WebServer() {
5953
_server.close();
60-
if (_currentHeaders) {
61-
delete[] _currentHeaders;
62-
}
54+
55+
_clearRequestHeaders();
56+
_clearResponseHeaders();
57+
delete _chain;
58+
6359
RequestHandler *handler = _firstHandler;
6460
while (handler) {
6561
RequestHandler *next = handler->next();
6662
delete handler;
6763
handler = next;
6864
}
65+
_firstHandler = nullptr;
6966
}
7067

7168
void WebServer::begin() {
@@ -436,7 +433,17 @@ void WebServer::handleClient() {
436433
_currentClient.setTimeout(HTTP_MAX_SEND_WAIT); /* / 1000 removed, WifiClient setTimeout changed to ms */
437434
if (_parseRequest(_currentClient)) {
438435
_contentLength = CONTENT_LENGTH_NOT_SET;
439-
_handleRequest();
436+
_responseCode = 0;
437+
_clearResponseHeaders();
438+
439+
// Run server-level middlewares
440+
if (_chain) {
441+
_chain->runChain(*this, [this]() {
442+
return _handleRequest();
443+
});
444+
} else {
445+
_handleRequest();
446+
}
440447

441448
if (_currentClient.isSSE()) {
442449
_currentStatus = HC_WAIT_CLOSE;
@@ -495,16 +502,22 @@ void WebServer::stop() {
495502
}
496503

497504
void WebServer::sendHeader(const String &name, const String &value, bool first) {
498-
String headerLine = name;
499-
headerLine += F(": ");
500-
headerLine += value;
501-
headerLine += "\r\n";
505+
RequestArgument *header = new RequestArgument();
506+
header->key = name;
507+
header->value = value;
502508

503-
if (first) {
504-
_responseHeaders = headerLine + _responseHeaders;
509+
if (!_responseHeaders || first) {
510+
header->next = _responseHeaders;
511+
_responseHeaders = header;
505512
} else {
506-
_responseHeaders += headerLine;
513+
RequestArgument *last = _responseHeaders;
514+
while (last->next) {
515+
last = last->next;
516+
}
517+
last->next = header;
507518
}
519+
520+
_responseHeaderCount++;
508521
}
509522

510523
void WebServer::setContentLength(const size_t contentLength) {
@@ -529,11 +542,14 @@ void WebServer::enableETag(bool enable, ETagFunction fn) {
529542
}
530543

531544
void WebServer::_prepareHeader(String &response, int code, const char *content_type, size_t contentLength) {
532-
response = String(F("HTTP/1.")) + String(_currentVersion) + ' ';
533-
response += String(code);
534-
response += ' ';
535-
response += _responseCodeToString(code);
536-
response += "\r\n";
545+
_responseCode = code;
546+
547+
response.concat(version());
548+
response.concat(' ');
549+
response.concat(String(code));
550+
response.concat(' ');
551+
response.concat(responseCodeToString(code));
552+
response.concat(F("\r\n"));
537553

538554
using namespace mime;
539555
if (!content_type) {
@@ -558,19 +574,21 @@ void WebServer::_prepareHeader(String &response, int code, const char *content_t
558574
}
559575
sendHeader(String(F("Connection")), String(F("close")));
560576

561-
response += _responseHeaders;
562-
response += "\r\n";
563-
_responseHeaders = "";
577+
for (RequestArgument *header = _responseHeaders; header; header = header->next) {
578+
response.concat(header->key);
579+
response.concat(F(": "));
580+
response.concat(header->value);
581+
response.concat(F("\r\n"));
582+
}
583+
584+
response.concat(F("\r\n"));
564585
}
565586

566587
void WebServer::send(int code, const char *content_type, const String &content) {
567588
String header;
568589
// Can we assume the following?
569590
//if(code == 200 && content.length() == 0 && _contentLength == CONTENT_LENGTH_NOT_SET)
570591
// _contentLength = CONTENT_LENGTH_UNKNOWN;
571-
if (content.length() == 0) {
572-
log_w("content length is zero");
573-
}
574592
_prepareHeader(header, code, content_type, content.length());
575593
_currentClientWrite(header.c_str(), header.length());
576594
if (content.length()) {
@@ -728,52 +746,51 @@ bool WebServer::hasArg(const String &name) const {
728746
}
729747

730748
String WebServer::header(const String &name) const {
731-
for (int i = 0; i < _headerKeysCount; ++i) {
732-
if (_currentHeaders[i].key.equalsIgnoreCase(name)) {
733-
return _currentHeaders[i].value;
749+
for (RequestArgument *current = _currentHeaders; current; current = current->next) {
750+
if (current->key.equalsIgnoreCase(name)) {
751+
return current->value;
734752
}
735753
}
736754
return "";
737755
}
738756

739757
void WebServer::collectHeaders(const char *headerKeys[], const size_t headerKeysCount) {
740-
_headerKeysCount = headerKeysCount + 2;
741-
if (_currentHeaders) {
742-
delete[] _currentHeaders;
743-
}
744-
_currentHeaders = new RequestArgument[_headerKeysCount];
745-
_currentHeaders[0].key = FPSTR(AUTHORIZATION_HEADER);
746-
_currentHeaders[1].key = FPSTR(ETAG_HEADER);
758+
collectAllHeaders();
759+
_collectAllHeaders = false;
760+
761+
_headerKeysCount += headerKeysCount;
762+
763+
RequestArgument *last = _currentHeaders->next;
764+
747765
for (int i = 2; i < _headerKeysCount; i++) {
748-
_currentHeaders[i].key = headerKeys[i - 2];
766+
last->next = new RequestArgument();
767+
last->next->key = headerKeys[i - 2];
768+
last = last->next;
749769
}
750770
}
751771

752772
String WebServer::header(int i) const {
753-
if (i < _headerKeysCount) {
754-
return _currentHeaders[i].value;
773+
RequestArgument *current = _currentHeaders;
774+
while (current && i--) {
775+
current = current->next;
755776
}
756-
return "";
777+
return current ? current->value : emptyString;
757778
}
758779

759780
String WebServer::headerName(int i) const {
760-
if (i < _headerKeysCount) {
761-
return _currentHeaders[i].key;
781+
RequestArgument *current = _currentHeaders;
782+
while (current && i--) {
783+
current = current->next;
762784
}
763-
return "";
785+
return current ? current->key : emptyString;
764786
}
765787

766788
int WebServer::headers() const {
767789
return _headerKeysCount;
768790
}
769791

770792
bool WebServer::hasHeader(const String &name) const {
771-
for (int i = 0; i < _headerKeysCount; ++i) {
772-
if ((_currentHeaders[i].key.equalsIgnoreCase(name)) && (_currentHeaders[i].value.length() > 0)) {
773-
return true;
774-
}
775-
}
776-
return false;
793+
return header(name).length() > 0;
777794
}
778795

779796
String WebServer::hostHeader() const {
@@ -788,16 +805,17 @@ void WebServer::onNotFound(THandlerFunction fn) {
788805
_notFoundHandler = fn;
789806
}
790807

791-
void WebServer::_handleRequest() {
808+
bool WebServer::_handleRequest() {
792809
bool handled = false;
793-
if (!_currentHandler) {
794-
log_e("request handler not found");
795-
} else {
796-
handled = _currentHandler->handle(*this, _currentMethod, _currentUri);
810+
if (_currentHandler) {
811+
handled = _currentHandler->process(*this, _currentMethod, _currentUri);
797812
if (!handled) {
798813
log_e("request handler failed to handle request");
799814
}
800815
}
816+
// DO NOT LOG if _currentHandler == null !!
817+
// This is is valid use case to handle any other requests
818+
// Also, this is just causing log flooding
801819
if (!handled && _notFoundHandler) {
802820
_notFoundHandler();
803821
handled = true;
@@ -811,6 +829,7 @@ void WebServer::_handleRequest() {
811829
_finalizeResponse();
812830
}
813831
_currentUri = "";
832+
return handled;
814833
}
815834

816835
void WebServer::_finalizeResponse() {
@@ -819,7 +838,7 @@ void WebServer::_finalizeResponse() {
819838
}
820839
}
821840

822-
String WebServer::_responseCodeToString(int code) {
841+
String WebServer::responseCodeToString(int code) {
823842
switch (code) {
824843
case 100: return F("Continue");
825844
case 101: return F("Switching Protocols");
@@ -864,3 +883,108 @@ String WebServer::_responseCodeToString(int code) {
864883
default: return F("");
865884
}
866885
}
886+
887+
void WebServer::_clearResponseHeaders() {
888+
_responseHeaderCount = 0;
889+
RequestArgument *current = _responseHeaders;
890+
while (current) {
891+
RequestArgument *next = current->next;
892+
delete current;
893+
current = next;
894+
}
895+
_responseHeaders = nullptr;
896+
}
897+
898+
void WebServer::_clearRequestHeaders() {
899+
_headerKeysCount = 0;
900+
RequestArgument *current = _currentHeaders;
901+
while (current) {
902+
RequestArgument *next = current->next;
903+
delete current;
904+
current = next;
905+
}
906+
_currentHeaders = nullptr;
907+
}
908+
909+
void WebServer::collectAllHeaders() {
910+
_clearRequestHeaders();
911+
912+
_currentHeaders = new RequestArgument();
913+
_currentHeaders->key = FPSTR(AUTHORIZATION_HEADER);
914+
915+
_currentHeaders->next = new RequestArgument();
916+
_currentHeaders->next->key = FPSTR(ETAG_HEADER);
917+
918+
_headerKeysCount = 2;
919+
_collectAllHeaders = true;
920+
}
921+
922+
const String &WebServer::responseHeader(String name) const {
923+
for (RequestArgument *current = _responseHeaders; current; current = current->next) {
924+
if (current->key.equalsIgnoreCase(name)) {
925+
return current->value;
926+
}
927+
}
928+
return emptyString;
929+
}
930+
931+
const String &WebServer::responseHeader(int i) const {
932+
RequestArgument *current = _responseHeaders;
933+
while (current && i--) {
934+
current = current->next;
935+
}
936+
return current ? current->value : emptyString;
937+
}
938+
939+
const String &WebServer::responseHeaderName(int i) const {
940+
RequestArgument *current = _responseHeaders;
941+
while (current && i--) {
942+
current = current->next;
943+
}
944+
return current ? current->key : emptyString;
945+
}
946+
947+
bool WebServer::hasResponseHeader(const String &name) const {
948+
return header(name).length() > 0;
949+
}
950+
951+
int WebServer::clientContentLength() const {
952+
return _clientContentLength;
953+
}
954+
955+
const String WebServer::version() const {
956+
String v;
957+
v.reserve(8);
958+
v.concat(F("HTTP/1."));
959+
v.concat(_currentVersion);
960+
return v;
961+
}
962+
int WebServer::responseCode() const {
963+
return _responseCode;
964+
}
965+
int WebServer::responseHeaders() const {
966+
return _responseHeaderCount;
967+
}
968+
969+
WebServer &WebServer::addMiddleware(Middleware *middleware) {
970+
if (!_chain) {
971+
_chain = new MiddlewareChain();
972+
}
973+
_chain->addMiddleware(middleware);
974+
return *this;
975+
}
976+
977+
WebServer &WebServer::addMiddleware(Middleware::Function fn) {
978+
if (!_chain) {
979+
_chain = new MiddlewareChain();
980+
}
981+
_chain->addMiddleware(fn);
982+
return *this;
983+
}
984+
985+
WebServer &WebServer::removeMiddleware(Middleware *middleware) {
986+
if (_chain) {
987+
_chain->removeMiddleware(middleware);
988+
}
989+
return *this;
990+
}

‎libraries/WebServer/src/WebServer.h

Lines changed: 50 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@ typedef struct {
9292
void *data; // additional data
9393
} HTTPRaw;
9494

95+
#include "middleware/Middleware.h"
9596
#include "detail/RequestHandler.h"
9697

9798
namespace fs {
@@ -158,6 +159,10 @@ class WebServer {
158159
void onNotFound(THandlerFunction fn); //called when handler is not assigned
159160
void onFileUpload(THandlerFunction ufn); //handle file uploads
160161

162+
WebServer &addMiddleware(Middleware *middleware);
163+
WebServer &addMiddleware(Middleware::Function fn);
164+
WebServer &removeMiddleware(Middleware *middleware);
165+
161166
String uri() const {
162167
return _currentUri;
163168
}
@@ -181,17 +186,23 @@ class WebServer {
181186
int args() const; // get arguments count
182187
bool hasArg(const String &name) const; // check if argument exists
183188
void collectHeaders(const char *headerKeys[], const size_t headerKeysCount); // set the request headers to collect
189+
void collectAllHeaders(); // collect all request headers
184190
String header(const String &name) const; // get request header value by name
185191
String header(int i) const; // get request header value by number
186192
String headerName(int i) const; // get request header name by number
187193
int headers() const; // get header count
188194
bool hasHeader(const String &name) const; // check if header exists
189195

190-
int clientContentLength() const {
191-
return _clientContentLength;
192-
} // return "content-length" of incoming HTTP header from "_currentClient"
196+
int clientContentLength() const; // return "content-length" of incoming HTTP header from "_currentClient"
197+
const String version() const; // get the HTTP version string
198+
String hostHeader() const; // get request host header if available or empty String if not
193199

194-
String hostHeader() const; // get request host header if available or empty String if not
200+
int responseCode() const; // get the HTTP response code set
201+
int responseHeaders() const; // get the HTTP response headers count
202+
const String &responseHeader(String name) const; // get the HTTP response header value by name
203+
const String &responseHeader(int i) const; // get the HTTP response header value by number
204+
const String &responseHeaderName(int i) const; // get the HTTP response header name by number
205+
bool hasResponseHeader(const String &name) const; // check if response header exists
195206

196207
// send response to the client
197208
// code - HTTP response code, can be 200 or 404
@@ -228,6 +239,8 @@ class WebServer {
228239
bool _eTagEnabled = false;
229240
ETagFunction _eTagFunction = nullptr;
230241

242+
static String responseCodeToString(int code);
243+
231244
protected:
232245
virtual size_t _currentClientWrite(const char *b, size_t l) {
233246
return _currentClient.write(b, l);
@@ -237,11 +250,10 @@ class WebServer {
237250
}
238251
void _addRequestHandler(RequestHandler *handler);
239252
bool _removeRequestHandler(RequestHandler *handler);
240-
void _handleRequest();
253+
bool _handleRequest();
241254
void _finalizeResponse();
242255
bool _parseRequest(NetworkClient &client);
243256
void _parseArguments(const String &data);
244-
static String _responseCodeToString(int code);
245257
bool _parseForm(NetworkClient &client, const String &boundary, uint32_t len);
246258
bool _parseFormUploadAborted();
247259
void _uploadWriteByte(uint8_t b);
@@ -255,48 +267,57 @@ class WebServer {
255267
// for extracting Auth parameters
256268
String _extractParam(String &authReq, const String &param, const char delimit = '"');
257269

270+
void _clearResponseHeaders();
271+
void _clearRequestHeaders();
272+
258273
struct RequestArgument {
259274
String key;
260275
String value;
276+
RequestArgument *next;
261277
};
262278

263-
boolean _corsEnabled;
279+
boolean _corsEnabled = false;
264280
NetworkServer _server;
265281

266282
NetworkClient _currentClient;
267-
HTTPMethod _currentMethod;
283+
HTTPMethod _currentMethod = HTTP_ANY;
268284
String _currentUri;
269-
uint8_t _currentVersion;
270-
HTTPClientStatus _currentStatus;
271-
unsigned long _statusChange;
272-
boolean _nullDelay;
273-
274-
RequestHandler *_currentHandler;
275-
RequestHandler *_firstHandler;
276-
RequestHandler *_lastHandler;
277-
THandlerFunction _notFoundHandler;
278-
THandlerFunction _fileUploadHandler;
279-
280-
int _currentArgCount;
281-
RequestArgument *_currentArgs;
282-
int _postArgsLen;
283-
RequestArgument *_postArgs;
285+
uint8_t _currentVersion = 0;
286+
HTTPClientStatus _currentStatus = HC_NONE;
287+
unsigned long _statusChange = 0;
288+
boolean _nullDelay = true;
289+
290+
RequestHandler *_currentHandler = nullptr;
291+
RequestHandler *_firstHandler = nullptr;
292+
RequestHandler *_lastHandler = nullptr;
293+
THandlerFunction _notFoundHandler = nullptr;
294+
THandlerFunction _fileUploadHandler = nullptr;
295+
296+
int _currentArgCount = 0;
297+
RequestArgument *_currentArgs = nullptr;
298+
int _postArgsLen = 0;
299+
RequestArgument *_postArgs = nullptr;
284300

285301
std::unique_ptr<HTTPUpload> _currentUpload;
286302
std::unique_ptr<HTTPRaw> _currentRaw;
287303

288-
int _headerKeysCount;
289-
RequestArgument *_currentHeaders;
290-
size_t _contentLength;
291-
int _clientContentLength; // "Content-Length" from header of incoming POST or GET request
292-
String _responseHeaders;
304+
int _headerKeysCount = 0;
305+
RequestArgument *_currentHeaders = nullptr;
306+
size_t _contentLength = 0;
307+
int _clientContentLength = 0; // "Content-Length" from header of incoming POST or GET request
308+
RequestArgument *_responseHeaders = nullptr;
293309

294310
String _hostHeader;
295-
bool _chunked;
311+
bool _chunked = false;
296312

297313
String _snonce; // Store noance and opaque for future comparison
298314
String _sopaque;
299315
String _srealm; // Store the Auth realm between Calls
316+
317+
int _responseHeaderCount = 0;
318+
int _responseCode = 0;
319+
bool _collectAllHeaders = false;
320+
MiddlewareChain *_chain = nullptr;
300321
};
301322

302323
#endif //ESP8266WEBSERVER_H

‎libraries/WebServer/src/detail/RequestHandler.h

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,9 @@
66

77
class RequestHandler {
88
public:
9-
virtual ~RequestHandler() {}
9+
virtual ~RequestHandler() {
10+
delete _chain;
11+
}
1012

1113
/*
1214
note: old handler API for backward compatibility
@@ -75,8 +77,14 @@ class RequestHandler {
7577
_next = r;
7678
}
7779

80+
RequestHandler &addMiddleware(Middleware *middleware);
81+
RequestHandler &addMiddleware(Middleware::Function fn);
82+
RequestHandler &removeMiddleware(Middleware *middleware);
83+
bool process(WebServer &server, HTTPMethod requestMethod, String requestUri);
84+
7885
private:
7986
RequestHandler *_next = nullptr;
87+
MiddlewareChain *_chain = nullptr;
8088

8189
protected:
8290
std::vector<String> pathArgs;

‎libraries/WebServer/src/detail/RequestHandlersImpl.h

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,39 @@
1010

1111
using namespace mime;
1212

13+
RequestHandler &RequestHandler::addMiddleware(Middleware *middleware) {
14+
if (!_chain) {
15+
_chain = new MiddlewareChain();
16+
}
17+
_chain->addMiddleware(middleware);
18+
return *this;
19+
}
20+
21+
RequestHandler &RequestHandler::addMiddleware(Middleware::Function fn) {
22+
if (!_chain) {
23+
_chain = new MiddlewareChain();
24+
}
25+
_chain->addMiddleware(fn);
26+
return *this;
27+
}
28+
29+
RequestHandler &RequestHandler::removeMiddleware(Middleware *middleware) {
30+
if (_chain) {
31+
_chain->removeMiddleware(middleware);
32+
}
33+
return *this;
34+
}
35+
36+
bool RequestHandler::process(WebServer &server, HTTPMethod requestMethod, String requestUri) {
37+
if (_chain) {
38+
return _chain->runChain(server, [this, &server, &requestMethod, &requestUri]() {
39+
return handle(server, requestMethod, requestUri);
40+
});
41+
} else {
42+
return handle(server, requestMethod, requestUri);
43+
}
44+
}
45+
1346
class FunctionRequestHandler : public RequestHandler {
1447
public:
1548
FunctionRequestHandler(WebServer::THandlerFunction fn, WebServer::THandlerFunction ufn, const Uri &uri, HTTPMethod method)
Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
#include "Middlewares.h"
2+
3+
AuthenticationMiddleware &AuthenticationMiddleware::setUsername(const char *username) {
4+
_username = username;
5+
_callback = nullptr;
6+
return *this;
7+
}
8+
9+
AuthenticationMiddleware &AuthenticationMiddleware::setPassword(const char *password) {
10+
_password = password;
11+
_hash = false;
12+
_callback = nullptr;
13+
return *this;
14+
}
15+
16+
AuthenticationMiddleware &AuthenticationMiddleware::setPasswordHash(const char *sha1AsBase64orHex) {
17+
_password = sha1AsBase64orHex;
18+
_hash = true;
19+
_callback = nullptr;
20+
return *this;
21+
}
22+
23+
AuthenticationMiddleware &AuthenticationMiddleware::setCallback(WebServer::THandlerFunctionAuthCheck fn) {
24+
assert(fn);
25+
_callback = fn;
26+
_hash = false;
27+
_username = emptyString;
28+
_password = emptyString;
29+
return *this;
30+
}
31+
32+
AuthenticationMiddleware &AuthenticationMiddleware::setRealm(const char *realm) {
33+
_realm = realm;
34+
return *this;
35+
}
36+
37+
AuthenticationMiddleware &AuthenticationMiddleware::setAuthMethod(HTTPAuthMethod method) {
38+
_method = method;
39+
return *this;
40+
}
41+
42+
AuthenticationMiddleware &AuthenticationMiddleware::setAuthFailureMessage(const char *message) {
43+
_authFailMsg = message;
44+
return *this;
45+
}
46+
47+
bool AuthenticationMiddleware::isAllowed(WebServer &server) const {
48+
if (_callback) {
49+
return server.authenticate(_callback);
50+
}
51+
52+
if (!_username.isEmpty() && !_password.isEmpty()) {
53+
if (_hash) {
54+
return server.authenticateBasicSHA1(_username.c_str(), _password.c_str());
55+
} else {
56+
return server.authenticate(_username.c_str(), _password.c_str());
57+
}
58+
}
59+
60+
return true;
61+
}
62+
63+
bool AuthenticationMiddleware::run(WebServer &server, Middleware::Callback next) {
64+
bool authenticationRequired = false;
65+
66+
if (_callback) {
67+
authenticationRequired = !server.authenticate(_callback);
68+
} else if (!_username.isEmpty() && !_password.isEmpty()) {
69+
if (_hash) {
70+
authenticationRequired = !server.authenticateBasicSHA1(_username.c_str(), _password.c_str());
71+
} else {
72+
authenticationRequired = !server.authenticate(_username.c_str(), _password.c_str());
73+
}
74+
}
75+
76+
if (authenticationRequired) {
77+
server.requestAuthentication(_method, _realm, _authFailMsg);
78+
return true;
79+
} else {
80+
return next();
81+
}
82+
}
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
#include "Middlewares.h"
2+
3+
CorsMiddleware &CorsMiddleware::setOrigin(const char *origin) {
4+
_origin = origin;
5+
return *this;
6+
}
7+
8+
CorsMiddleware &CorsMiddleware::setMethods(const char *methods) {
9+
_methods = methods;
10+
return *this;
11+
}
12+
13+
CorsMiddleware &CorsMiddleware::setHeaders(const char *headers) {
14+
_headers = headers;
15+
return *this;
16+
}
17+
18+
CorsMiddleware &CorsMiddleware::setAllowCredentials(bool credentials) {
19+
_credentials = credentials;
20+
return *this;
21+
}
22+
23+
CorsMiddleware &CorsMiddleware::setMaxAge(uint32_t seconds) {
24+
_maxAge = seconds;
25+
return *this;
26+
}
27+
28+
void CorsMiddleware::addCORSHeaders(WebServer &server) {
29+
server.sendHeader(F("Access-Control-Allow-Origin"), _origin.c_str());
30+
server.sendHeader(F("Access-Control-Allow-Methods"), _methods.c_str());
31+
server.sendHeader(F("Access-Control-Allow-Headers"), _headers.c_str());
32+
server.sendHeader(F("Access-Control-Allow-Credentials"), _credentials ? F("true") : F("false"));
33+
server.sendHeader(F("Access-Control-Max-Age"), String(_maxAge).c_str());
34+
}
35+
36+
bool CorsMiddleware::run(WebServer &server, Middleware::Callback next) {
37+
// Origin header ? => CORS handling
38+
if (server.hasHeader(F("Origin"))) {
39+
addCORSHeaders(server);
40+
// check if this is a preflight request => handle it and return
41+
if (server.method() == HTTP_OPTIONS) {
42+
server.send(200);
43+
return true;
44+
}
45+
}
46+
return next();
47+
}
Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
#include "Middlewares.h"
2+
3+
void LoggingMiddleware::setOutput(Print &output) {
4+
_out = &output;
5+
}
6+
7+
bool LoggingMiddleware::run(WebServer &server, Middleware::Callback next) {
8+
if (_out == nullptr) {
9+
return next();
10+
}
11+
12+
_out->print(F("* Connection from "));
13+
_out->print(server.client().remoteIP().toString());
14+
_out->print(F(":"));
15+
_out->println(server.client().remotePort());
16+
17+
_out->print(F("> "));
18+
const HTTPMethod method = server.method();
19+
if (method == HTTP_ANY) {
20+
_out->print(F("HTTP_ANY"));
21+
} else {
22+
_out->print(http_method_str(method));
23+
}
24+
_out->print(F(" "));
25+
_out->print(server.uri());
26+
_out->print(F(" "));
27+
_out->println(server.version());
28+
29+
int n = server.headers();
30+
for (int i = 0; i < n; i++) {
31+
String v = server.header(i);
32+
if (!v.isEmpty()) {
33+
// because these 2 are always there, eventually empty: "Authorization", "If-None-Match"
34+
_out->print(F("> "));
35+
_out->print(server.headerName(i));
36+
_out->print(F(": "));
37+
_out->println(server.header(i));
38+
}
39+
}
40+
41+
_out->println(F(">"));
42+
43+
uint32_t elapsed = millis();
44+
const bool ret = next();
45+
elapsed = millis() - elapsed;
46+
47+
if (ret) {
48+
_out->print(F("* Processed in "));
49+
_out->print(elapsed);
50+
_out->println(F(" ms"));
51+
_out->print(F("< "));
52+
_out->print(F("HTTP/1."));
53+
_out->print(server.version());
54+
_out->print(F(" "));
55+
_out->print(server.responseCode());
56+
_out->print(F(" "));
57+
_out->println(WebServer::responseCodeToString(server.responseCode()));
58+
59+
n = server.responseHeaders();
60+
for (int i = 0; i < n; i++) {
61+
_out->print(F("< "));
62+
_out->print(server.responseHeaderName(i));
63+
_out->print(F(": "));
64+
_out->println(server.responseHeader(i));
65+
}
66+
67+
_out->println(F("<"));
68+
69+
} else {
70+
_out->println(F("* Not processed!"));
71+
}
72+
73+
return ret;
74+
}
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
#ifndef MIDDLEWARE_H
2+
#define MIDDLEWARE_H
3+
4+
#include <assert.h>
5+
#include <functional>
6+
7+
class MiddlewareChain;
8+
class WebServer;
9+
10+
class Middleware {
11+
public:
12+
typedef std::function<bool(void)> Callback;
13+
typedef std::function<bool(WebServer &server, Callback next)> Function;
14+
15+
virtual ~Middleware() {}
16+
17+
virtual bool run(WebServer &server, Callback next) {
18+
return next();
19+
};
20+
21+
private:
22+
friend MiddlewareChain;
23+
Middleware *_next = nullptr;
24+
bool _freeOnRemoval = false;
25+
};
26+
27+
class MiddlewareFunction : public Middleware {
28+
public:
29+
MiddlewareFunction(Middleware::Function fn) : _fn(fn) {}
30+
31+
bool run(WebServer &server, Middleware::Callback next) override {
32+
return _fn(server, next);
33+
}
34+
35+
private:
36+
Middleware::Function _fn;
37+
};
38+
39+
class MiddlewareChain {
40+
public:
41+
~MiddlewareChain();
42+
43+
void addMiddleware(Middleware::Function fn);
44+
void addMiddleware(Middleware *middleware);
45+
bool removeMiddleware(Middleware *middleware);
46+
47+
bool runChain(WebServer &server, Middleware::Callback finalizer);
48+
49+
private:
50+
Middleware *_root = nullptr;
51+
Middleware *_current = nullptr;
52+
};
53+
54+
#endif
Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
#include "Middleware.h"
2+
3+
MiddlewareChain::~MiddlewareChain() {
4+
Middleware *current = _root;
5+
while (current) {
6+
Middleware *next = current->_next;
7+
if (current->_freeOnRemoval) {
8+
delete current;
9+
}
10+
current = next;
11+
}
12+
_root = nullptr;
13+
}
14+
15+
void MiddlewareChain::addMiddleware(Middleware::Function fn) {
16+
MiddlewareFunction *middleware = new MiddlewareFunction(fn);
17+
middleware->_freeOnRemoval = true;
18+
addMiddleware(middleware);
19+
}
20+
21+
void MiddlewareChain::addMiddleware(Middleware *middleware) {
22+
if (!_root) {
23+
_root = middleware;
24+
return;
25+
}
26+
Middleware *current = _root;
27+
while (current->_next) {
28+
current = current->_next;
29+
}
30+
current->_next = middleware;
31+
}
32+
33+
bool MiddlewareChain::removeMiddleware(Middleware *middleware) {
34+
if (!_root) {
35+
return false;
36+
}
37+
if (_root == middleware) {
38+
_root = _root->_next;
39+
if (middleware->_freeOnRemoval) {
40+
delete middleware;
41+
}
42+
return true;
43+
}
44+
Middleware *current = _root;
45+
while (current->_next) {
46+
if (current->_next == middleware) {
47+
current->_next = current->_next->_next;
48+
if (middleware->_freeOnRemoval) {
49+
delete middleware;
50+
}
51+
return true;
52+
}
53+
current = current->_next;
54+
}
55+
return false;
56+
}
57+
58+
bool MiddlewareChain::runChain(WebServer &server, Middleware::Callback finalizer) {
59+
if (!_root) {
60+
return finalizer();
61+
}
62+
_current = _root;
63+
Middleware::Callback next;
64+
next = [this, &server, &next, finalizer]() {
65+
if (!_current) {
66+
return finalizer();
67+
}
68+
Middleware *that = _current;
69+
_current = _current->_next;
70+
return that->run(server, next);
71+
};
72+
return next();
73+
}

0 commit comments

Comments
 (0)
Please sign in to comment.