@@ -40,6 +40,9 @@ WebSocketsServer::WebSocketsServer(uint16_t port, String origin, String protocol
40
40
41
41
_cbEvent = NULL ;
42
42
43
+ _httpHeaderValidationFunc = NULL ;
44
+ _mandatoryHttpHeaders = NULL ;
45
+ _mandatoryHttpHeaderCount = 0 ;
43
46
}
44
47
45
48
@@ -53,10 +56,14 @@ WebSocketsServer::~WebSocketsServer() {
53
56
// TODO how to close server?
54
57
#endif
55
58
59
+ if (_mandatoryHttpHeaders)
60
+ delete[] _mandatoryHttpHeaders;
61
+
62
+ _mandatoryHttpHeaderCount = 0 ;
56
63
}
57
64
58
65
/* *
59
- * calles to init the Websockets server
66
+ * called to initialize the Websocket server
60
67
*/
61
68
void WebSocketsServer::begin (void ) {
62
69
WSclient_t * client;
@@ -83,6 +90,7 @@ void WebSocketsServer::begin(void) {
83
90
client->base64Authorization = " " ;
84
91
85
92
client->cWsRXsize = 0 ;
93
+
86
94
#if (WEBSOCKETS_NETWORK_TYPE == NETWORK_ESP8266_ASYNC)
87
95
client->cHttpLine = " " ;
88
96
#endif
@@ -118,7 +126,31 @@ void WebSocketsServer::onEvent(WebSocketServerEvent cbEvent) {
118
126
_cbEvent = cbEvent;
119
127
}
120
128
121
- /* *
129
+ /*
130
+ * Sets the custom http header validator function
131
+ * @param httpHeaderValidationFunc WebSocketServerHttpHeaderValFunc ///< pointer to the custom http header validation function
132
+ * @param mandatoryHttpHeaders[] const char* ///< the array of named http headers considered to be mandatory / must be present in order for websocket upgrade to succeed
133
+ * @param mandatoryHttpHeaderCount size_t ///< the number of items in the mandatoryHttpHeaders array
134
+ */
135
+ void WebSocketsServer::onValidateHttpHeader (
136
+ WebSocketServerHttpHeaderValFunc validationFunc,
137
+ const char * mandatoryHttpHeaders[],
138
+ size_t mandatoryHttpHeaderCount)
139
+ {
140
+ _httpHeaderValidationFunc = validationFunc;
141
+
142
+ if (_mandatoryHttpHeaders)
143
+ delete[] _mandatoryHttpHeaders;
144
+
145
+ _mandatoryHttpHeaderCount = mandatoryHttpHeaderCount;
146
+ _mandatoryHttpHeaders = new String[_mandatoryHttpHeaderCount];
147
+
148
+ for (size_t i = 0 ; i < _mandatoryHttpHeaderCount; i++) {
149
+ _mandatoryHttpHeaders[i] = mandatoryHttpHeaders[i];
150
+ }
151
+ }
152
+
153
+ /*
122
154
* send text data to client
123
155
* @param num uint8_t client id
124
156
* @param payload uint8_t *
@@ -279,9 +311,8 @@ void WebSocketsServer::disconnect(uint8_t num) {
279
311
}
280
312
281
313
282
-
283
- /* *
284
- * set the Authorizatio for the http request
314
+ /*
315
+ * set the Authorization for the http request
285
316
* @param user const char *
286
317
* @param password const char *
287
318
*/
@@ -388,7 +419,7 @@ bool WebSocketsServer::newClient(WEBSOCKETS_NETWORK_CLASS * TCPclient) {
388
419
* @param payload uint8_t *
389
420
* @param lenght size_t
390
421
*/
391
- void WebSocketsServer::messageRecived (WSclient_t * client, WSopcode_t opcode, uint8_t * payload, size_t lenght) {
422
+ void WebSocketsServer::messageReceived (WSclient_t * client, WSopcode_t opcode, uint8_t * payload, size_t lenght) {
392
423
WStype_t type = WStype_ERROR;
393
424
394
425
switch (opcode) {
@@ -446,6 +477,7 @@ void WebSocketsServer::clientDisconnect(WSclient_t * client) {
446
477
client->cIsWebsocket = false ;
447
478
448
479
client->cWsRXsize = 0 ;
480
+
449
481
#if (WEBSOCKETS_NETWORK_TYPE == NETWORK_ESP8266_ASYNC)
450
482
client->cHttpLine = " " ;
451
483
#endif
@@ -461,7 +493,7 @@ void WebSocketsServer::clientDisconnect(WSclient_t * client) {
461
493
/* *
462
494
* get client state
463
495
* @param client WSclient_t * ptr to the client struct
464
- * @return true = conneted
496
+ * @return true = connected
465
497
*/
466
498
bool WebSocketsServer::clientIsConnected (WSclient_t * client) {
467
499
@@ -492,7 +524,7 @@ bool WebSocketsServer::clientIsConnected(WSclient_t * client) {
492
524
}
493
525
#if (WEBSOCKETS_NETWORK_TYPE != NETWORK_ESP8266_ASYNC)
494
526
/* *
495
- * Handle incomming Connection Request
527
+ * Handle incoming Connection Request
496
528
*/
497
529
void WebSocketsServer::handleNewClients (void ) {
498
530
@@ -569,10 +601,22 @@ void WebSocketsServer::handleClientData(void) {
569
601
}
570
602
#endif
571
603
604
+ /*
605
+ * returns an indicator whether the given named header exists in the configured _mandatoryHttpHeaders collection
606
+ * @param headerName String ///< the name of the header being checked
607
+ */
608
+ bool WebSocketsServer::hasMandatoryHeader (String headerName) {
609
+ for (size_t i = 0 ; i < _mandatoryHttpHeaderCount; i++) {
610
+ if (_mandatoryHttpHeaders[i].equalsIgnoreCase (headerName))
611
+ return true ;
612
+ }
613
+ return false ;
614
+ }
572
615
573
616
/* *
574
- * handle the WebSocket header reading
575
- * @param client WSclient_t * ptr to the client struct
617
+ * handles http header reading for WebSocket upgrade
618
+ * @param client WSclient_t * ///< pointer to the client struct
619
+ * @param headerLine String ///< the header being read / processed
576
620
*/
577
621
void WebSocketsServer::handleHeader (WSclient_t * client, String * headerLine) {
578
622
@@ -581,10 +625,16 @@ void WebSocketsServer::handleHeader(WSclient_t * client, String * headerLine) {
581
625
if (headerLine->length () > 0 ) {
582
626
DEBUG_WEBSOCKETS (" [WS-Server][%d][handleHeader] RX: %s\n " , client->num , headerLine->c_str ());
583
627
584
- // websocket request starts allways with GET see rfc6455
628
+ // websocket requests always start with GET see rfc6455
585
629
if (headerLine->startsWith (" GET " )) {
630
+
586
631
// cut URL out
587
632
client->cUrl = headerLine->substring (4 , headerLine->indexOf (' ' , 4 ));
633
+
634
+ // reset non-websocket http header validation state for this client
635
+ client->cHttpHeadersValid = true ;
636
+ client->cMandatoryHeadersCount = 0 ;
637
+
588
638
} else if (headerLine->indexOf (' :' )) {
589
639
String headerName = headerLine->substring (0 , headerLine->indexOf (' :' ));
590
640
String headerValue = headerLine->substring (headerLine->indexOf (' :' ) + 2 );
@@ -609,7 +659,13 @@ void WebSocketsServer::handleHeader(WSclient_t * client, String * headerLine) {
609
659
client->cExtensions = headerValue;
610
660
} else if (headerName.equalsIgnoreCase (" Authorization" )) {
611
661
client->base64Authorization = headerValue;
662
+ } else {
663
+ client->cHttpHeadersValid &= execHttpHeaderValidation (headerName, headerValue);
664
+ if (_mandatoryHttpHeaderCount > 0 && hasMandatoryHeader (headerName)) {
665
+ client->cMandatoryHeadersCount ++;
666
+ }
612
667
}
668
+
613
669
} else {
614
670
DEBUG_WEBSOCKETS (" [WS-Client][handleHeader] Header error (%s)\n " , headerLine->c_str ());
615
671
}
@@ -619,8 +675,8 @@ void WebSocketsServer::handleHeader(WSclient_t * client, String * headerLine) {
619
675
client->tcp ->readStringUntil (' \n ' , &(client->cHttpLine ), std::bind (&WebSocketsServer::handleHeader, this , client, &(client->cHttpLine )));
620
676
#endif
621
677
} else {
622
- DEBUG_WEBSOCKETS (" [WS-Server][%d][handleHeader] Header read fin.\n " , client->num );
623
678
679
+ DEBUG_WEBSOCKETS (" [WS-Server][%d][handleHeader] Header read fin.\n " , client->num );
624
680
DEBUG_WEBSOCKETS (" [WS-Server][%d][handleHeader] - cURL: %s\n " , client->num , client->cUrl .c_str ());
625
681
DEBUG_WEBSOCKETS (" [WS-Server][%d][handleHeader] - cIsUpgrade: %d\n " , client->num , client->cIsUpgrade );
626
682
DEBUG_WEBSOCKETS (" [WS-Server][%d][handleHeader] - cIsWebsocket: %d\n " , client->num , client->cIsWebsocket );
@@ -629,6 +685,8 @@ void WebSocketsServer::handleHeader(WSclient_t * client, String * headerLine) {
629
685
DEBUG_WEBSOCKETS (" [WS-Server][%d][handleHeader] - cExtensions: %s\n " , client->num , client->cExtensions .c_str ());
630
686
DEBUG_WEBSOCKETS (" [WS-Server][%d][handleHeader] - cVersion: %d\n " , client->num , client->cVersion );
631
687
DEBUG_WEBSOCKETS (" [WS-Server][%d][handleHeader] - base64Authorization: %s\n " , client->num , client->base64Authorization .c_str ());
688
+ DEBUG_WEBSOCKETS (" [WS-Server][%d][handleHeader] - cHttpHeadersValid: %d\n " , client->num , client->cHttpHeadersValid );
689
+ DEBUG_WEBSOCKETS (" [WS-Server][%d][handleHeader] - cMandatoryHeadersCount: %d\n " , client->num , client->cMandatoryHeadersCount );
632
690
633
691
bool ok = (client->cIsUpgrade && client->cIsWebsocket );
634
692
@@ -642,6 +700,12 @@ void WebSocketsServer::handleHeader(WSclient_t * client, String * headerLine) {
642
700
if (client->cVersion != 13 ) {
643
701
ok = false ;
644
702
}
703
+ if (!client->cHttpHeadersValid ) {
704
+ ok = false ;
705
+ }
706
+ if (client->cMandatoryHeadersCount != _mandatoryHttpHeaderCount) {
707
+ ok = false ;
708
+ }
645
709
}
646
710
647
711
if (_base64Authorization.length () > 0 ) {
0 commit comments