38
38
break
39
39
40
40
ACCEPTABLE_CLIENT_ERRORS = set ((errno .ECONNRESET , errno .EPIPE ))
41
+ DEFAULT_MAX_FRAME_LENGTH = 8 << 20
41
42
42
43
__all__ = ["WebSocketWSGI" , "WebSocket" ]
43
44
PROTOCOL_GUID = b'258EAFA5-E914-47DA-95CA-C5AB0DC85B11'
@@ -75,14 +76,20 @@ def my_handler(ws):
75
76
:class:`WebSocket`. To close the socket, simply return from the
76
77
function. Note that the server will log the websocket request at
77
78
the time of closure.
79
+
80
+ An optional argument max_frame_length can be given, which will set the
81
+ maximum incoming *uncompressed* payload length of a frame. By default, this
82
+ is set to 8MiB. Note that excessive values here might create a DOS attack
83
+ vector.
78
84
"""
79
85
80
- def __init__ (self , handler ):
86
+ def __init__ (self , handler , max_frame_length = DEFAULT_MAX_FRAME_LENGTH ):
81
87
self .handler = handler
82
88
self .protocol_version = None
83
89
self .support_legacy_versions = True
84
90
self .supported_protocols = []
85
91
self .origin_checker = None
92
+ self .max_frame_length = max_frame_length
86
93
87
94
@classmethod
88
95
def configured (cls ,
@@ -324,7 +331,8 @@ def _handle_hybi_request(self, environ):
324
331
sock .sendall (b'\r \n ' .join (handshake_reply ) + b'\r \n \r \n ' )
325
332
return RFC6455WebSocket (sock , environ , self .protocol_version ,
326
333
protocol = negotiated_protocol ,
327
- extensions = parsed_extensions )
334
+ extensions = parsed_extensions ,
335
+ max_frame_length = self .max_frame_length )
328
336
329
337
def _extract_number (self , value ):
330
338
"""
@@ -503,7 +511,8 @@ class ProtocolError(ValueError):
503
511
504
512
505
513
class RFC6455WebSocket (WebSocket ):
506
- def __init__ (self , sock , environ , version = 13 , protocol = None , client = False , extensions = None ):
514
+ def __init__ (self , sock , environ , version = 13 , protocol = None , client = False , extensions = None ,
515
+ max_frame_length = DEFAULT_MAX_FRAME_LENGTH ):
507
516
super (RFC6455WebSocket , self ).__init__ (sock , environ , version )
508
517
self .iterator = self ._iter_frames ()
509
518
self .client = client
@@ -512,6 +521,8 @@ def __init__(self, sock, environ, version=13, protocol=None, client=False, exten
512
521
513
522
self ._deflate_enc = None
514
523
self ._deflate_dec = None
524
+ self .max_frame_length = max_frame_length
525
+ self ._remote_close_data = None
515
526
516
527
class UTF8Decoder (object ):
517
528
def __init__ (self ):
@@ -583,12 +594,13 @@ def _get_bytes(self, numbytes):
583
594
return data
584
595
585
596
class Message (object ):
586
- def __init__ (self , opcode , decoder = None , decompressor = None ):
597
+ def __init__ (self , opcode , max_frame_length , decoder = None , decompressor = None ):
587
598
self .decoder = decoder
588
599
self .data = []
589
600
self .finished = False
590
601
self .opcode = opcode
591
602
self .decompressor = decompressor
603
+ self .max_frame_length = max_frame_length
592
604
593
605
def push (self , data , final = False ):
594
606
self .finished = final
@@ -597,7 +609,12 @@ def push(self, data, final=False):
597
609
def getvalue (self ):
598
610
data = b"" .join (self .data )
599
611
if not self .opcode & 8 and self .decompressor :
600
- data = self .decompressor .decompress (data + b'\x00 \x00 \xff \xff ' )
612
+ data = self .decompressor .decompress (data + b"\x00 \x00 \xff \xff " , self .max_frame_length )
613
+ if self .decompressor .unconsumed_tail :
614
+ raise FailedConnectionError (
615
+ 1009 ,
616
+ "Incoming compressed frame exceeds length limit of {} bytes." .format (self .max_frame_length ))
617
+
601
618
if self .decoder :
602
619
8000
data = self .decoder .decode (data , self .finished )
603
620
return data
@@ -611,6 +628,7 @@ def _apply_mask(data, mask, length=None, offset=0):
611
628
612
629
def _handle_control_frame (self , opcode , data ):
613
630
if opcode == 8 : # connection close
631
+ self ._remote_close_data = data
614
632
if not data :
615
633
status = 1000
616
634
elif len (data ) > 1 :
@@ -710,13 +728,17 @@ def _recv_frame(self, message=None):
710
728
length = struct .unpack ('!H' , recv (2 ))[0 ]
711
729
elif length == 127 :
712
730
length = struct .unpack ('!Q' , recv (8 ))[0 ]
731
+
732
+ if length > self .max_frame_length :
733
+ raise FailedConnectionError (1009 , "Incoming frame of {} bytes is above length limit of {} bytes." .format (
734
+ length , self .max_frame_length ))
713
735
if masked :
714
736
mask = struct .unpack ('!BBBB' , recv (4 ))
715
737
received = 0
716
738
if not message or opcode & 8 :
717
739
decoder = self .UTF8Decoder () if opcode == 1 else None
718
740
decompressor = self ._get_permessage_deflate_dec (rsv1 )
719
- message = self .Message (opcode , decoder = decoder , decompressor = decompressor )
741
+ message = self .Message (opcode , self . max_frame_length , decoder = decoder , decompressor = decompressor )
720
742
if not length :
721
743
message .push (b'' , final = finished )
722
744
else :
0 commit comments