8000 custom http header validation implementation · robokoding/arduinoWebSockets@e589b40 · GitHub
[go: up one dir, main page]

Skip to content

Commit e589b40

Browse files
committed
custom http header validation implementation
1 parent eb2351a commit e589b40

File tree

5 files changed

+196
-16
lines changed

5 files changed

+196
-16
lines changed
Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
/*
2+
* WebSocketServer.ino
3+
*
4+
* Created on: 22.05.2015
5+
*
6+
*/
7+
8+
#include <Arduino.h>
9+
10+
#include <ESP8266WiFi.h>
11+
#include <ESP8266WiFiMulti.h>
12+
#include <WebSocketsServer.h>
13+
#include <Hash.h>
14+
15+
ESP8266WiFiMulti WiFiMulti;
16+
17+
WebSocketsServer webSocket = WebSocketsServer(81);
18+
19+
#define USE_SERIAL Serial1
20+
21+
const unsigned long int validSessionId = 12345; //some arbitrary value to act as a valid sessionId
22+
23+
/*
24+
* Returns a bool value as an indicator to describe whether a user is allowed to initiate a websocket upgrade
25+
* based on the value of a cookie. This function expects the rawCookieHeaderValue to look like this "sessionId=<someSessionIdNumberValue>|"
26+
*/
27+
bool isCookieValid(String rawCookieHeaderValue) {
28+
29+
if (rawCookieHeaderValue.indexOf("sessionId") != -1) {
30+
String sessionIdStr = rawCookieHeaderValue.substring(rawCookieHeaderValue.indexOf("sessionId=") + 10, rawCookieHeaderValue.indexOf("|"));
31+
unsigned long int sessionId = strtoul(sessionIdStr.c_str(), NULL, 10);
32+
return sessionId == validSessionId;
33+
}
34+
return false;
35+
}
36+
37+
/*
38+
* The WebSocketServerHttpHeaderValFunc delegate passed to webSocket.onValidateHttpHeader
39+
*/
40+
bool validateHttpHeader(String headerName, String headerValue) {
41+
42+
//assume a true response for any headers not handled by this validator
43+
bool valid = true;
44+
45+
if(headerName.equalsIgnoreCase("Cookie")) {
46+
//if the header passed is the Cookie header, validate it according to the rules in 'isCookieValid' function
47+
valid = isCookieValid(headerValue);
48+
}
49+
50+
return valid;
51+
}
52+
53+
void setup() {
54+
// USE_SERIAL.begin(921600);
55+
USE_SERIAL.begin(115200);
56+
57+
//Serial.setDebugOutput(true);
58+
USE_SERIAL.setDebugOutput(true);
59+
60+
USE_SERIAL.println();
61+
USE_SERIAL.println();
62+
USE_SERIAL.println();
63+
64+
for(uint8_t t = 4; t > 0; t--) {
65+
USE_SERIAL.printf("[SETUP] BOOT WAIT %d...\n", t);
66+
USE_SERIAL.flush();
67+
delay(1000);
68+
}
69+
70+
WiFiMulti.addAP("SSID", "passpasspass");
71+
72+
while(WiFiMulti.run() != WL_CONNECTED) {
73+
delay(100);
74+
}
75+
76+
//connecting clients must supply a valid session cookie at websocket upgrade handshake negotiation time
77+
const char * headerkeys[] = { "Cookie" };
78+
webSocket.onValidateHttpHeader(validateHttpHeader, headerkeys);
79+
webSocket.begin();
80+
}
81+
82+
void loop() {
83+
webSocket.loop();
84+
}
85+

src/WebSockets.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -426,7 +426,7 @@ void WebSockets::handleWebsocketPayloadCb(WSclient_t * client, bool ok, uint8_t
426426
DEBUG_WEBSOCKETS("[WS][%d][handleWebsocket] text: %s\n", client->num, payload);
427427
// no break here!
428428
case WSop_binary:
429-
messageRecived(client, header->opCode, payload, header->payloadLen);
429+
messageReceived(client, header->opCode, payload, header->payloadLen);
430430
break;
431431
case WSop_ping:
432432
// send pong back

src/WebSockets.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,9 @@ typedef struct {
183183
String base64Authorization; ///< Base64 encoded Auth request
184184
String plainAuthorization; ///< Base64 encoded Auth request
185185

186+
bool cHttpHeadersValid; ///< non-websocket http header validity indicator
187+
size_t cMandatoryHeadersCount; ///< non-websocket mandatory http headers present count
188+
186189
#if (WEBSOCKETS_NETWORK_TYPE == NETWORK_ESP8266_ASYNC)
187190
String cHttpLine; ///< HTTP header lines
188191
#endif
@@ -202,7 +205,7 @@ class WebSockets {
202205
virtual void clientDisconnect(WSclient_t * client);
203206
virtual bool clientIsConnected(WSclient_t * client);
204207

205-
virtual void messageRecived(WSclient_t * client, WSopcode_t opcode, uint8_t * payload, size_t length);
208+
virtual void messageReceived(WSclient_t * client, WSopcode_t opcode, uint8_t * payload, size_t length);
206209

207210
void clientDisconnect(WSclient_t * client, uint16_t code, char * reason = NULL, size_t reasonLen = 0);
208211
bool sendFrame(WSclient_t * client, WSopcode_t opcode, uint8_t * payload = NULL, size_t length = 0, bool mask = false, bool fin = true, bool headerToPayload = false);

src/WebSocketsServer.cpp

Lines changed: 75 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,9 @@ WebSocketsServer::WebSocketsServer(uint16_t port, String origin, String protocol
4040

4141
_cbEvent = NULL;
4242

43+
_httpHeaderValidationFunc = NULL;
44+
_mandatoryHttpHeaders = NULL;
45+
_mandatoryHttpHeaderCount = 0;
4346
}
4447

4548

@@ -53,10 +56,14 @@ WebSocketsServer::~WebSocketsServer() {
5356
// TODO how to close server?
5457
#endif
5558

59+
if (_mandatoryHttpHeaders)
60+
delete[] _mandatoryHttpHeaders;
61+
62+
_mandatoryHttpHeaderCount = 0;
5663
}
5764

5865
/**
59-
* calles to init the Websockets server
66+
* called to initialize the Websocket server
6067
*/
6168
void WebSocketsServer::begin(void) {
6269
WSclient_t * client;
@@ -83,6 +90,7 @@ void WebSocketsServer::begin(void) {
8390
client->base64Authorization = "";
8491

8592
client->cWsRXsize = 0;
93+
8694
#if (WEBSOCKETS_NETWORK_TYPE == NETWORK_ESP8266_ASYNC)
8795
client->cHttpLine = "";
8896
#endif
@@ -118,7 +126,30 @@ void WebSocketsServer::onEvent(WebSocketServerEvent cbEvent) {
118126
_cbEvent = cbEvent;
119127
}
120128

121-
/**
129+
/*
130+
* Sets the custom http header validator function
131+
* If this functionality is being used, call this function prior to calling WebSocketsServer::begin
132+
* @param httpHeaderValidationFunc WebSocketServerHttpHeaderValFunc ///< pointer to the custom http header validation function
133+
* @param mandatoryHttpHeaders const char* ///< the array of named http headers considered to be mandatory / must be present in order for websocket upgrade to succeed
134+
*/
135+
void WebSocketsServer::onValidateHttpHeader(
136+
WebSocketServerHttpHeaderValFunc validationFunc,
137+
const char* mandatoryHttpHeaders[])
138+
{
139+
_httpHeaderValidationFunc = validationFunc;
140+
141+
if (_mandatoryHttpHeaders)
142+
delete[] _mandatoryHttpHeaders;
143+
144+
_mandatoryHttpHeaderCount = (sizeof(mandatoryHttpHeaders) / sizeof(char*));
145+
_mandatoryHttpHeaders = new String[_mandatoryHttpHeaderCount];
146+
147+
for (size_t i = 0; i < _mandatoryHttpHeaderCount; i++) {
148+
_mandatoryHttpHeaders[i] = mandatoryHttpHeaders[i];
149+
}
150+
}
151+
152+
/*
122153
* send text data to client
123154
* @param num uint8_t client id
124155
* @param payload uint8_t *
@@ -279,9 +310,8 @@ void WebSocketsServer::disconnect(uint8_t num) {
279310
}
280311

281312

282-
283-
/**
284-
* set the Authorizatio for the http request
313+
/*
314+
* set the Authorization for the http request
285315
* @param user const char *
286316
* @param password const char *
287317
*/
@@ -388,7 +418,7 @@ bool WebSocketsServer::newClient(WEBSOCKETS_NETWORK_CLASS * TCPclient) {
388418
* @param payload uint8_t *
389419
* @param lenght size_t
390420
*/
391-
void WebSocketsServer::messageRecived(WSclient_t * client, WSopcode_t opcode, uint8_t * payload, size_t lenght) {
421+
void WebSocketsServer::messageReceived(WSclient_t * client, WSopcode_t opcode, uint8_t * payload, size_t lenght) {
392422
WStype_t type = WStype_ERROR;
393423

394424
switch(opcode) {
@@ -446,6 +476,7 @@ void WebSocketsServer::clientDisconnect(WSclient_t * client) {
446476
client->cIsWebsocket = false;
447477

448478
client->cWsRXsize = 0;
479+
449480
#if (WEBSOCKETS_NETWORK_TYPE == NETWORK_ESP8266_ASYNC)
450481
client->cHttpLine = "";
451482
#endif
@@ -461,7 +492,7 @@ void WebSocketsServer::clientDisconnect(WSclient_t * client) {
461492
/**
462493
* get client state
463494
* @param client WSclient_t * ptr to the client struct
464-
* @return true = conneted
495+
* @return true = connected
465496
*/
466497
bool WebSocketsServer::clientIsConnected(WSclient_t * client) {
467498

@@ -492,7 +523,7 @@ bool WebSocketsServer::clientIsConnected(WSclient_t * client) {
492523
}
493524
#if (WEBSOCKETS_NETWORK_TYPE != NETWORK_ESP8266_ASYNC)
494525
/**
495-
* Handle incomming Connection Request
526+
* Handle incoming Connection Request
496527
*/
497528
void WebSocketsServer::handleNewClients(void) {
498529

@@ -569,10 +600,22 @@ void WebSocketsServer::handleClientData(void) {
569600
}
570601
#endif
571602

603+
/*
604+
* returns an indicator whether the given named header exists in the configured _mandatoryHttpHeaders collection
605+
* @param headerName String ///< the name of the header being checked
606+
*/
607+
bool WebSocketsServer::hasMandatoryHeader(String headerName) {
608+
for (size_t i = 0; i < _mandatoryHttpHeaderCount; i++) {
609+
if (_mandatoryHttpHeaders[i].equalsIgnoreCase(headerName))
610+
return true;
611+
}
612+
return false;
613+
}
572614

573615
/**
574-
* handle the WebSocket header reading
575-
* @param client WSclient_t * ptr to the client struct
616+
* handles http header reading for WebSocket upgrade
617+
* @param client WSclient_t * ///< pointer to the client struct
618+
* @param headerLine String ///< the header being read / processed
576619
*/
577620
void WebSocketsServer::handleHeader(WSclient_t * client, String * headerLine) {
578621

@@ -581,10 +624,16 @@ void WebSocketsServer::handleHeader(WSclient_t * client, String * headerLine) {
581624
if(headerLine->length() > 0) {
582625
DEBUG_WEBSOCKETS("[WS-Server][%d][handleHeader] RX: %s\n", client->num, headerLine->c_str());
583626

584-
// websocket request starts allways with GET see rfc6455
627+
// websocket requests always start with GET see rfc6455
585628
if(headerLine->startsWith("GET ")) {
629+
586630
// cut URL out
587631
client->cUrl = headerLine->substring(4, headerLine->indexOf(' ', 4));
632+
633+
//reset non-websocket http header validation state for this client
634+
client->cHttpHeadersValid = true;
635+
client->cMandatoryHeadersCount = 0;
636+
588637
} else if(headerLine->indexOf(':')) {
589638
String headerName = headerLine->substring(0, headerLine->indexOf(':'));
590639
String headerValue = headerLine->substring(headerLine->indexOf(':') + 2);
@@ -609,7 +658,13 @@ void WebSocketsServer::handleHeader(WSclient_t * client, String * headerLine) {
609658
client->cExtensions = headerValue;
610659
} else if(headerName.equalsIgnoreCase("Authorization")) {
611660
client->base64Authorization = headerValue;
661+
} else {
662+
client->cHttpHeadersValid &= execHttpHeaderValidation(headerName, headerValue);
663+
if (_mandatoryHttpHeaderCount > 0 && hasMandatoryHeader(headerName)) {
664+
client->cMandatoryHeadersCount++;
665+
}
612666
}
667+
613668
} else {
614669
DEBUG_WEBSOCKETS("[WS-Client][handleHeader] Header error (%s)\n", headerLine->c_str());
615670
}
@@ -619,8 +674,8 @@ void WebSocketsServer::handleHeader(WSclient_t * client, String * headerLine) {
619674
client->tcp->readStringUntil('\n', &(client->cHttpLine), std::bind(&WebSocketsServer::handleHeader, this, client, &(client->cHttpLine)));
620675
#endif
621676
} else {
622-
DEBUG_WEBSOCKETS("[WS-Server][%d][handleHeader] Header read fin.\n", client->num);
623677

678+
DEBUG_WEBSOCKETS("[WS-Server][%d][handleHeader] Header read fin.\n", client->num);
624679
DEBUG_WEBSOCKETS("[WS-Server][%d][handleHeader] - cURL: %s\n", client->num, client->cUrl.c_str());
625680
DEBUG_WEBSOCKETS("[WS-Server][%d][handleHeader] - cIsUpgrade: %d\n", client->num, client->cIsUpgrade);
626681
DEBUG_WEBSOCKETS("[WS-Server][%d][handleHeader] - cIsWebsocket: %d\n", client->num, client->cIsWebsocket);
@@ -629,6 +684,8 @@ void WebSocketsServer::handleHeader(WSclient_t * client, String * headerLine) {
629684
DEBUG_WEBSOCKETS("[WS-Server][%d][handleHeader] - cExtensions: %s\n", client->num, client->cExtensions.c_str());
630685
DEBUG_WEBSOCKETS("[WS-Server][%d][handleHeader] - cVersion: %d\n", client->num, client->cVersion);
631686
DEBUG_WEBSOCKETS("[WS-Server][%d][handleHeader] - base64Authorization: %s\n", client->num, client->base64Authorization.c_str());
687+
DEBUG_WEBSOCKETS("[WS-Server][%d][handleHeader] - cHttpHeadersValid: %d\n", client->num, client->cHttpHeadersValid);
688+
DEBUG_WEBSOCKETS("[WS-Server][%d][handleHeader] - cMandatoryHeadersCount: %d\n", client->num, client->cMandatoryHeadersCount);
632689

633690
bool ok = (client->cIsUpgrade && client->cIsWebsocket);
634691

@@ -642,6 +699,12 @@ void WebSocketsServer::handleHeader(WSclient_t * client, String * headerLine) {
642699
if(client->cVersion != 13) {
643700
ok = false;
644701
}
702+
if(!client->cHttpHeadersValid) {
703+
ok = false;
704+
}
705+
if (client->cMandatoryHeadersCount != _mandatoryHttpHeaderCount) {
706+
ok = false;
707+
}
645708
}
646709

647710
if(_base64Authorization.length() > 0) {

src/WebSocketsServer.h

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,10 @@ class WebSocketsServer: private WebSockets {
3838

3939
#ifdef __AVR__
4040
typedef void (*WebSocketServerEvent)(uint8_t num, WStype_t type, uint8_t * payload, size_t length);
41+
typedef bool (*WebSocketServerHttpHeaderValFunc)(String headerName, String headerValue);
4142
#else
4243
typedef std::function<void (uint8_t num, WStype_t type, uint8_t * payload, size_t length)> WebSocketServerEvent;
44+
typedef std::function<bool (String headerName, String headerValue)> WebSocketServerHttpHeaderValFunc;
4345
#endif
4446

4547
WebSocketsServer(uint16_t port, String origin = "", String protocol = "arduino");
@@ -55,6 +57,7 @@ class WebSocketsServer: private WebSockets {
5557
#endif
5658

5759
void onEvent(WebSocketServerEvent cbEvent);
60+
void onValidateHttpHeader(WebSocketServerHttpHeaderValFunc validationFunc, const char* mandatoryHttpHeaders[]);
5861

5962

6063
bool sendTXT(uint8_t num, uint8_t * payload, size_t length = 0, bool headerToPayload = false);
@@ -90,16 +93,19 @@ class WebSocketsServer: private WebSockets {
9093
String _origin;
9194
String _protocol;
9295
String _base64Authorization; ///< Base64 encoded Auth request
96+
String * _mandatoryHttpHeaders;
97+
size_t _mandatoryHttpHeaderCount;
9398

9499
WEBSOCKETS_NETWORK_SERVER_CLASS * _server;
95100

96101
WSclient_t _clients[WEBSOCKETS_SERVER_CLIENT_MAX];
97102

98103
WebSocketServerEvent _cbEvent;
104+
WebSocketServerHttpHeaderValFunc _httpHeaderValidationFunc;
99105

100106
bool newClient(WEBSOCKETS_NETWORK_CLASS * TCPclient);
101107

102-
void messageRecived(WSclient_t * client, WSopcode_t opcode, uint8_t * payload, size_t length);
108+
void messageReceived(WSclient_t * client, WSopcode_t opcode, uint8_t * payload, size_t length);
103109

104110
void clientDisconnect(WSclient_t * client);
105111
bool clientIsConnected(WSclient_t * client);
@@ -111,7 +117,6 @@ class WebSocketsServer: private WebSockets {
111117

112118
void handleHeader(WSclient_t * client, String * headerLine);
113119

114-
115120
/**
116121
* called if a non Websocket connection is coming in.
117122
* Note: can be override
@@ -162,6 +167,30 @@ class WebSocketsServer: private WebSockets {
162167
}
163168
}
164169

170+
/*
171+
* Called at client socket connect handshake negotiation time for each http header that is not
172+
* a websocket specific http header (not Connection, Upgrade, Sec-WebSocket-*)
173+
* If the custom httpHeaderValidationFunc returns false for any headerName / headerValue passed, the
174+
* socket negotiation is considered invalid and the upgrade to websockets request is denied / rejected
175+
* This mechanism can be used to enable custom authentication schemes e.g. test the value
176+
* of a session cookie to determine if a user is logged on / authenticated
177+
*/
178+
virtual bool execHttpHeaderValidation(String headerName, String headerValue) {
179+
if(_httpHeaderValidationFunc) {
180+
//return the value of the custom http header validation function
181+
return _httpHeaderValidationFunc(headerName, headerValue);
182+
}
183+
//no custom http header validation so just assume all is good
184+
return true;
185+
}
186+
187+
private:
188+
/*
189+
* returns an indicator whether the given named header exists in the configured _mandatoryHttpHeaders collection
190+
* @param headerName String ///< the name of the header being checked
191+
*/
192+
bool hasMandatoryHeader(String headerName);
193+
165194
};
166195

167196

0 commit comments

Comments
 (0)
0