@@ -40,6 +40,9 @@ WebSocketsServer::WebSocketsServer(uint16_t port, String origin, String protocol
40
40
41
41
_cbEvent = NULL ;
42
42
43
+ _httpHeaderValidationFunc = NULL ;
44
+ _mandatoryHttpHeaders = NULL ;
45
+ _mandatoryHttpHeaderCount = 0 ;
43
46
}
44
47
45
48
@@ -53,10 +56,14 @@ WebSocketsServer::~WebSocketsServer() {
53
56
// TODO how to close server?
54
57
#endif
55
58
59
+ if (_mandatoryHttpHeaders)
60
+ delete[] _mandatoryHttpHeaders;
61
+
62
+ _mandatoryHttpHeaderCount = 0 ;
56
63
}
57
64
58
65
/* *
59
- * calles to init the Websockets server
66
+ * called to initialize the Websocket server
60
67
*/
61
68
void WebSocketsServer::begin (void ) {
62
69
WSclient_t * client;
@@ -83,6 +90,7 @@ void WebSocketsServer::begin(void) {
83
90
client->base64Authorization = " " ;
84
91
85
92
client->cWsRXsize = 0 ;
93
+
86
94
#if (WEBSOCKETS_NETWORK_TYPE == NETWORK_ESP8266_ASYNC)
87
95
client->cHttpLine = " " ;
88
96
#endif
@@ -118,7 +126,30 @@ void WebSocketsServer::onEvent(WebSocketServerEvent cbEvent) {
118
126
_cbEvent = cbEvent;
119
127
}
120
128
121
- /* *
129
+ /*
130
+ * Sets the custom http header validator function
131
+ * If this functionality is being used, call this function prior to calling WebSocketsServer::begin
132
+ * @param httpHeaderValidationFunc WebSocketServerHttpHeaderValFunc ///< pointer to the custom http header validation function
133
+ * @param mandatoryHttpHeaders const char* ///< the array of named http headers considered to be mandatory / must be present in order for websocket upgrade to succeed
134
+ */
135
+ void WebSocketsServer::onValidateHttpHeader (
136
+ WebSocketServerHttpHeaderValFunc validationFunc,
137
+ const char * mandatoryHttpHeaders[])
138
+ {
139
+ _httpHeaderValidationFunc = validationFunc;
140
+
141
+ if (_mandatoryHttpHeaders)
142
+ delete[] _mandatoryHttpHeaders;
143
+
144
+ _mandatoryHttpHeaderCount = (sizeof (mandatoryHttpHeaders) / sizeof (char *));
145
+ _mandatoryHttpHeaders = new String[_mandatoryHttpHeaderCount];
146
+
147
+ for (size_t i = 0 ; i < _mandatoryHttpHeaderCount; i++) {
148
+ _mandatoryHttpHeaders[i] = mandatoryHttpHeaders[i];
149
+ }
150
+ }
151
+
152
+ /*
122
153
* send text data to client
123
154
* @param num uint8_t client id
124
155
* @param payload uint8_t *
@@ -279,9 +310,8 @@ void WebSocketsServer::disconnect(uint8_t num) {
279
310
}
280
311
281
312
282
-
283
- /* *
284
- * set the Authorizatio for the http request
313
+ /*
314
+ * set the Authorization for the http request
285
315
* @param user const char *
286
316
* @param password const char *
287
317
*/
@@ -388,7 +418,7 @@ bool WebSocketsServer::newClient(WEBSOCKETS_NETWORK_CLASS * TCPclient) {
388
418
* @param payload uint8_t *
389
419
* @param lenght size_t
390
420
*/
391
- void WebSocketsServer::messageRecived (WSclient_t * client, WSopcode_t opcode, uint8_t * payload, size_t lenght) {
421
+ void WebSocketsServer::messageReceived (WSclient_t * client, WSopcode_t opcode, uint8_t * payload, size_t lenght) {
392
422
WStype_t type = WStype_ERROR;
393
423
394
424
switch (opcode) {
@@ -446,6 +476,7 @@ void WebSocketsServer::clientDisconnect(WSclient_t * client) {
446
476
client->cIsWebsocket = false ;
447
477
448
478
client->cWsRXsize = 0 ;
479
+
449
480
#if (WEBSOCKETS_NETWORK_TYPE == NETWORK_ESP8266_ASYNC)
450
481
client->cHttpLine = " " ;
451
482
#endif
@@ -461,7 +492,7 @@ void WebSocketsServer::clientDisconnect(WSclient_t * client) {
461
492
/* *
462
493
* get client state
463
494
* @param client WSclient_t * ptr to the client struct
464
- * @return true = conneted
495
+ * @return true = connected
465
496
*/
466
497
bool WebSocketsServer::clientIsConnected (WSclient_t * client) {
467
498
@@ -492,7 +523,7 @@ bool WebSocketsServer::clientIsConnected(WSclient_t * client) {
492
523
}
493
524
#if (WEBSOCKETS_NETWORK_TYPE != NETWORK_ESP8266_ASYNC)
494
525
/* *
495
- * Handle incomming Connection Request
526
+ * Handle incoming Connection Request
496
527
*/
497
528
void WebSocketsServer::handleNewClients (void ) {
498
529
@@ -569,10 +600,22 @@ void WebSocketsServer::handleClientData(void) {
569
600
}
570
601
#endif
571
602
603
+ /*
604
+ * returns an indicator whether the given named header exists in the configured _mandatoryHttpHeaders collection
605
+ * @param headerName String ///< the name of the header being checked
606
+ */
607
+ bool WebSocketsServer::hasMandatoryHeader (String headerName) {
608
+ for (size_t i = 0 ; i < _mandatoryHttpHeaderCount; i++) {
609
+ if (_mandatoryHttpHeaders[i].equalsIgnoreCase (headerName))
610
+ return true ;
611
+ }
612
+ return false ;
613
+ }
572
614
573
615
/* *
574
- * handle the WebSocket header reading
575
- * @param client WSclient_t * ptr to the client struct
616
+ * handles http header reading for WebSocket upgrade
617
+ * @param client WSclient_t * ///< pointer to the client struct
618
+ * @param headerLine String ///< the header being read / processed
576
619
*/
577
620
void WebSocketsServer::handleHeader (WSclient_t * client, String * headerLine) {
578
621
@@ -581,10 +624,16 @@ void WebSocketsServer::handleHeader(WSclient_t * client, String * headerLine) {
581
624
if (headerLine->length () > 0 ) {
582
625
DEBUG_WEBSOCKETS (" [WS-Server][%d][handleHeader] RX: %s\n " , client->num , headerLine->c_str ());
583
626
584
- // websocket request starts allways with GET see rfc6455
627
+ // websocket requests always start with GET see rfc6455
585
628
if (headerLine->startsWith (" GET " )) {
629
+
586
630
// cut URL out
587
631
client->cUrl = headerLine->substring (4 , headerLine->indexOf (' ' , 4 ));
632
+
633
+ // reset non-websocket http header validation state for this client
634
+ client->cHttpHeadersValid = true ;
635
+ client->cMandatoryHeadersCount = 0 ;
636
+
588
637
} else if (headerLine->indexOf (' :' )) {
589
638
String headerName = headerLine->substring (0 , headerLine->indexOf (' :' ));
590
639
String headerValue = headerLine->substring (headerLine->indexOf (' :' ) + 2 );
@@ -609,7 +658,13 @@ void WebSocketsServer::handleHeader(WSclient_t * client, String * headerLine) {
609
658
client->cExtensions = headerValue;
610
659
} else if (headerName.equalsIgnoreCase (" Authorization" )) {
611
660
client->base64Authorization = headerValue;
661
+ } else {
662
+ client->cHttpHeadersValid &= execHttpHeaderValidation (headerName, headerValue);
663
+ if (_mandatoryHttpHeaderCount > 0 && hasMandatoryHeader (headerName)) {
664
+ client->cMandatoryHeadersCount ++;
665
+ }
612
666
}
667
+
613
668
} else {
614
669
DEBUG_WEBSOCKETS (" [WS-Client][handleHeader] Header error (%s)\n " , headerLine->c_str ());
615
670
}
@@ -619,8 +674,8 @@ void WebSocketsServer::handleHeader(WSclient_t * client, String * headerLine) {
619
674
client->tcp ->readStringUntil (' \n ' , &(client->cHttpLine ), std::bind (&WebSocketsServer::handleHeader, this , client, &(client->cHttpLine )));
620
675
#endif
621
676
} else {
622
- DEBUG_WEBSOCKETS (" [WS-Server][%d][handleHeader] Header read fin.\n " , client->num );
623
677
678
+ DEBUG_WEBSOCKETS (" [WS-Server][%d][handleHeader] Header read fin.\n " , client->num );
624
679
DEBUG_WEBSOCKETS (" [WS-Server][%d][handleHeader] - cURL: %s\n " , client->num , client->cUrl .c_str ());
625
680
DEBUG_WEBSOCKETS (" [WS-Server][%d][handleHeader] - cIsUpgrade: %d\n " , client->num , client->cIsUpgrade );
626
681
DEBUG_WEBSOCKETS (" [WS-Server][%d][handleHeader] - cIsWebsocket: %d\n " , client->num , client->cIsWebsocket );
@@ -629,6 +684,8 @@ void WebSocketsServer::handleHeader(WSclient_t * client, String * headerLine) {
629
684
DEBUG_WEBSOCKETS (" [WS-Server][%d][handleHeader] - cExtensions: %s\n " , client->num , client->cExtensions .c_str ());
630
685
DEBUG_WEBSOCKETS (" [WS-Server][%d][handleHeader] - cVersion: %d\n " , client->num , client->cVersion );
631
686
DEBUG_WEBSOCKETS (" [WS-Server][%d][handleHeader] - base64Authorization: %s\n " , client->num , client->base64Authorization .c_str ());
687
+ DEBUG_WEBSOCKETS (" [WS-Server][%d][handleHeader] - cHttpHeadersValid: %d\n " , client->num , client->cHttpHeadersValid );
688
+ DEBUG_WEBSOCKETS (" [WS-Server][%d][handleHeader] - cMandatoryHeadersCount: %d\n " , client->num , client->cMandatoryHeadersCount );
632
689
633
690
bool ok = (client->cIsUpgrade && client->cIsWebsocket );
634
691
@@ -642,6 +699,12 @@ void WebSocketsServer::handleHeader(WSclient_t * client, String * headerLine) {
642
699
if (client->cVersion != 13 ) {
643
700
ok = false ;
644
701
}
702
+ if (!client->cHttpHeadersValid ) {
703
+ ok = false ;
704
+ }
705
+ if (client->cMandatoryHeadersCount != _mandatoryHttpHeaderCount) {
706
+ ok = false ;
707
+ }
645
708
}
646
709
647
710
if (_base64Authorization.length () > 0 ) {
0 commit comments