@@ -525,6 +525,7 @@ class Connection(object):
525
525
"""
526
526
527
527
socket = None
528
+ _auth_plugin_name = ''
528
529
529
530
def __init__ (self , host = None , user = None , password = "" ,
530
531
database = None , port = 3306 , unix_socket = None ,
@@ -535,7 +536,7 @@ def __init__(self, host=None, user=None, password="",
535
536
compress = None , named_pipe = None , no_delay = None ,
536
537
autocommit = False , db = None , passwd = None , local_infile = False ,
537
538
max_allowed_packet = 16 * 1024 * 1024 , defer_connect = False ,
538
- plugin_map = {}):
539
+ auth_plugin_map = {}):
539
540
"""
540
541
Establish a connection to the MySQL database. Accepts several
541
542
arguments:
@@ -571,12 +572,11 @@ def __init__(self, host=None, user=None, password="",
571
572
max_allowed_packet: Max size of packet sent to server in bytes. (default: 16MB)
572
573
defer_connect: Don't explicitly connect on contruction - wait for connect call.
573
574
(default: False)
574
- plugin_map: Map of plugin names to a class that processes that plugin. The class
575
- will take the Connection object as the argument to the constructor. The class
576
- needs an authenticate method taking an authentication packet as an argument.
577
- For the dialog plugin, a prompt(echo, prompt) method can be used (if no
578
- authenticate method) for returning a string from the user.
579
-
575
+ auth_plugin_map: A dict of plugin names to a class that processes that plugin.
576
+ The class will take the Connection object as the argument to the constructor.
577
+ The class needs an authenticate method taking an authentication packet as
578
+ an argument. For the dialog plugin, a prompt(echo, prompt) method can be used
579
+ (if no authenticate method) for returning a string from the user. (experimental)
580
580
db: Alias for database. (for compatibility to MySQLdb)
581
581
passwd: Alias for password. (for compatibility to MySQLdb)
582
582
"""
@@ -672,7 +672,7 @@ def _config(key, arg):
672
672
self .sql_mode = sql_mode
673
673
self .init_command = init_command
674
674
self .max_allowed_packet = max_allowed_packet
675
- self .plugin_map = plugin_map
675
+ self ._auth_plugin_map = auth_plugin_map
676
676
if defer_connect :
677
677
self .socket = None
678
678
else :
@@ -726,7 +726,6 @@ def __del__(self):
726
726
def autocommit (self , value ):
727
727
self .autocommit_mode = bool (value )
728
728
current = self .get_autocommit ()
729
- self .next_packet = 1
730
729
if value != current :
731
730
self ._send_autocommit_mode ()
732
731
@@ -816,15 +815,13 @@ def query(self, sql, unbuffered=False):
816
815
"You may not close previous cursor." )
817
816
# if DEBUG:
818
817
# print("DEBUG: sending query:", sql)
819
- self .next_packet = 1
820
818
if isinstance (sql , text_type ) and not (JYTHON or IRONPYTHON ):
821
819
if PY2 :
822
820
sql = sql .encode (self .encoding )
823
821
else :
824
822
sql = sql .encode (self .encoding , 'surrogateescape' )
825
823
self ._execute_command (COMMAND .COM_QUERY , sql )
826
824
self ._affected_rows = self ._read_query_result (unbuffered = unbuffered )
827
- self .next_packet = 1
828
825
return self ._affected_rows
829
826
830
827
def next_result (self , unbuffered = False ):
@@ -892,7 +889,7 @@ def connect(self, sock=None):
892
889
sock .setsockopt (socket .SOL_SOCKET , socket .SO_KEEPALIVE , 1 )
893
890
self .socket = sock
894
891
self ._rfile = _makefile (sock , 'rb' )
895
- self .next_packet = 0
892
+ self ._next_seq_id = 0
896
893
897
894
self ._get_server_information ()
898
895
self ._request_authentication ()
@@ -933,16 +930,16 @@ def connect(self, sock=None):
933
930
# So just reraise it.
934
931
raise
935
932
936
- def write_packet (self , data ):
933
+ def write_packet (self , payload ):
937
934
"""Writes an entire "mysql packet" in its entirety to the network
938
- addings its length and sequence number. Intended for use by plugins
939
- only.
935
+ addings its length and sequence number.
940
936
"""
941
- data = pack_int24 (len (data )) + int2byte (self .next_packet ) + data
937
+ # Internal note: when you build packet manualy and calls _write_bytes()
938
+ # directly, you should set self._next_seq_id properly.
939
+ data = pack_int24 (len (payload )) + int2byte (self ._next_seq_id ) + payload
942
940
if DEBUG : dump_packet (data )
943
-
944
941
self ._write_bytes (data )
945
- self .next_packet = (self .next_packet + 1 ) % 256
942
+ self ._next_seq_id = (self ._next_seq_id + 1 ) % 256
946
943
947
944
def _read_packet (self , packet_type = MysqlPacket ):
948
945
"""Read an entire "mysql packet" in its entirety from the network
@@ -952,8 +949,14 @@ def _read_packet(self, packet_type=MysqlPacket):
952
949
while True :
953
950
packet_header = self ._read_bytes (4 )
954
<
10000
/td>951
if DEBUG : dump_packet (packet_header )
952
+
955
953
btrl , btrh , packet_number = struct .unpack ('<HBB' , packet_header )
956
954
bytes_to_read = btrl + (btrh << 16 )
955
+ if packet_number != self ._next_seq_id :
956
+ raise err .InternalError ("Packet sequence number wrong - got %d expected %d" %
957
+ (packet_number , self ._next_seq_id ))
958
+ self ._next_seq_id = (self ._next_seq_id + 1 ) % 256
959
+
957
960
recv_data = self ._read_bytes (bytes_to_read )
958
961
if DEBUG : dump_packet (recv_data )
959
962
buff += recv_data
@@ -962,13 +965,7 @@ def _read_packet(self, packet_type=MysqlPacket):
962
965
continue
963
966
if bytes_to_read < MAX_PACKET_LEN :
964
967
break
965
- if packet_number != self .next_packet :
966
- pass
967
- #TODO: check sequence id
968
- #raise err.InternalError("Packet sequence number wrong - got %d expected %d" %
969
- # (packet_number, self.next_packet))
970
968
971
- self .next_packet = (packet_number + 1 ) % 256
972
969
packet = packet_type (buff , self .encoding )
973
970
packet .check_error ()
974
971
return packet
@@ -1027,33 +1024,32 @@ def _execute_command(self, command, sql):
1027
1024
if self ._result is not None and self ._result .unbuffered_active :
1028
1025
warnings .warn ("Previous unbuffered result was left incomplete" )
1029
1026
self ._result ._finish_unbuffered_query ()
1027
+ self ._result = None
1030
1028
1031
1029
if isinstance (sql , text_type ):
1032
1030
sql = sql .encode (self .encoding )
1033
1031
1034
- chunk_size = min (self .max_allowed_packet , len (sql ) + 1 ) # +1 is for command
1032
+ # +1 is for command
1033
+ chunk_size = min (self .max_allowed_packet , len (sql ) + 1 )
1035
1034
1035
+ # tiny optimization: build first packet manually instead of
1036
+ # calling self..write_packet()
1036
1037
prelude = struct .pack ('<iB' , chunk_size , command )
1037
- self ._write_bytes (prelude + sql [:chunk_size - 1 ])
1038
- if DEBUG : dump_packet (prelude + sql )
1038
+ packet = prelude + sql [:chunk_size - 1 ]
1039
+ self ._write_bytes (packet )
1040
+ if DEBUG : dump_packet (packet )
1041
+ self ._next_seq_id = 1
1039
1042
1040
- self .next_packet = 1
1041
1043
if chunk_size < self .max_allowed_packet :
1042
1044
return
1043
1045
1044
- seq_id = 1
1045
1046
sql = sql [chunk_size - 1 :]
1046
1047
while True :
1047
1048
chunk_size = min (self .max_allowed_packet , len (sql ))
1048
- prelude = struct .pack ('<i' , chunk_size )[:3 ]
1049
- data = prelude + int2byte (seq_id % 256 ) + sql [:chunk_size ]
1050
- self ._write_bytes (data )
1051
- if DEBUG : dump_packet (data )
1049
+ self .write_packet (sql [:chunk_size ])
1052
1050
sql = sql [chunk_size :]
1053
1051
if not sql and chunk_size < self .max_allowed_packet :
1054
1052
break
1055
- seq_id += 1
1056
- self .next_packet = seq_id % 256
1057
1053
1058
1054
def _request_authentication (self ):
1059
1055
# https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::HandshakeResponse
@@ -1078,18 +1074,15 @@ def _request_authentication(self):
1078
1074
1079
1075
data = data_init + self .user + b'\0 '
1080
1076
1081
- authresp = ''
1082
- if self .plugin_name == 'mysql_native_password' :
1077
+ authresp = b ''
1078
+ if self ._auth_plugin_name == 'mysql_native_password' :
1083
1079
authresp = _scramble (self .password .encode ('latin1' ), self .salt )
1084
1080
1085
1081
if self .server_capabilities & CLIENT .PLUGIN_AUTH_LENENC_CLIENT_DATA :
1086
- data += lenenc_int (len (authresp ))
1087
- data += authresp
1082
+ data += lenenc_int (len (authresp )) + authresp
1088
1083
elif self .server_capabilities & CLIENT .SECURE_CONNECTION :
1089
- length = len (authresp )
1090
- data += struct .pack ('B' , length )
1091
- data += authresp
1092
- else : # pragma: no cover - not testing against servers without secure auth (>=5.0)
1084
+ data += struct .pack ('B' , len (authresp )) + authresp
1085
+ else : # pragma: no cover - not testing against servers without secure auth (>=5.0)
1093
1086
data += authresp + b'\0 '
1094
1087
1095
1088
if self .db and self .server_capabilities & CLIENT .CONNECT_WITH_DB :
@@ -1098,15 +1091,16 @@ def _request_authentication(self):
1098
1091
data += self .db + b'\0 '
1099
1092
1100
1093
if self .server_capabilities & CLIENT .PLUGIN_AUTH :
1101
- data += self .plugin_name .encode ('latin1' ) + b'\0 '
1094
+ name = self ._auth_plugin_name
1095
+ if isinstance (name , text_type ):
1096
+ name = name .encode ('ascii' )
1097
+ data += name + b'\0 '
1102
1098
1103
1099
self .write_packet (data )
1104
-
1105
1100
auth_packet = self ._read_packet ()
1106
1101
1107
1102
# if authentication method isn't accepted the first byte
1108
1103
# will have the octet 254
1109
-
1110
1104
if auth_packet .is_auth_switch_request ():
1111
1105
# https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchRequest
1112
1106
auth_packet .read_uint8 () # 0xfe packet identifier
@@ -1119,8 +1113,13 @@ def _request_authentication(self):
1119
1113
self .write_packet (data )
1120
1114
auth_packet = self ._read_packet ()
1121
1115
1116
+ #TODO: ok packet or error packet?
1117
+
1118
+
1122
1119
def _process_auth (self , plugin_name , auth_packet ):
1123
- plugin_class = self .plugin_map .get (plugin_name )
1120
+ plugin_class = self ._auth_plugin_map .get (plugin_name )
1121
+ if not plugin_class :
1122
+ plugin_class = self ._auth_plugin_map .get (plugin_name .decode ('ascii' ))
1124
1123
if plugin_class :
1125
1124
try :
1126
1125
handler = plugin_class (self )
@@ -1246,18 +1245,13 @@ def _get_server_information(self):
1246
1245
server_end = data .find (b'\0 ' , i )
1247
1246
if server_end < 0 : # pragma: no cover - very specific upstream bug
1248
1247
# not found \0 and last field so take it all
1249
- self .plugin_name = data [i :].decode ('latin1' )
1248
+ self ._auth_plugin_name = data [i :].decode ('latin1' )
1250
1249
else :
1251
- self .plugin_name = data [i :server_end ].decode ('latin1' )
1252
- else : # pragma: no cover - not testing against any plugin uncapable servers
1253
- self .plugin_name = ''
1250
+ self ._auth_plugin_name = data [i :server_end ].decode ('latin1' )
1254
1251
1255
1252
def get_server_info (self ):
1256
1253
return self .server_version
1257
1254
1258
- def get_plugin_name (self ):
1259
- return self .plugin_name
1260
-
1261
1255
Warning = err .Warning
1262
1256
Error = err .Error
1263
1257
InterfaceError = err .InterfaceError
@@ -1331,7 +1325,11 @@ def _read_ok_packet(self, first_packet):
1331
1325
def _read_load_local_packet (self , first_packet ):
1332
1326
load_packet = LoadLocalPacketWrapper (first_packet )
1333
1327
sender = LoadLocalFile (load_packet .filename , self .connection )
1334
- sender .send_data ()
1328
+ try :
1329
+ sender .send_data ()
1330
+ except :
1331
+ self .connection ._read_packet () # skip ok packet
1332
+ raise
1335
1333
1336
1334
ok_packet = self .connection ._read_packet ()
1337
1335
if not ok_packet .is_ok_packet (): # pragma: no cover - upstream induced protocol error
@@ -1448,27 +1446,20 @@ def send_data(self):
1448
1446
"""Send data packets from the local file to the server"""
1449
1447
if not self .connection .socket :
1450
1448
raise err .InterfaceError ("(0, '')" )
1449
+ conn = self .connection
1451
1450
1452
- # sequence id is 2 as we already sent a query packet
1453
- seq_id = 2
1454
1451
try :
1455
1452
with open (self .filename , 'rb' ) as open_file :
1456
- chunk_size = self . connection .max_allowed_packet
1453
+ chunk_size = conn .max_allowed_packet
1457
1454
packet = b""
1458
1455
1459
1456
while True :
1460
1457
chunk = open_file .read (chunk_size )
1461
1458
if not chunk :
1462
1459
break
1463
- packet = struct .pack ('<i' , len (chunk ))[:3 ] + int2byte (seq_id )
1464
- format_str = '!{0}s' .format (len (chunk ))
1465
- packet += struct .pack (format_str , chunk )
1466
- self .connection ._write_bytes (packet )
1467
- seq_id = (seq_id + 1 ) % 256
1460
+ conn .write_packet (chunk )
1468
1461
except IOError :
1469
1462
raise err .OperationalError (1017 , "Can't find file '{0}'" .format (self .filename ))
1470
1463
finally :
1471
1464
# send the empty packet to signify we are done sending data
1472
- packet = struct .pack ('<i' , 0 )[:3 ] + int2byte (seq_id )
1473
- self .connection ._write_bytes (packet )
1474
- self .next_packet = (seq_id + 1 ) % 256
1465
+ conn .write_packet (b'' )
0 commit comments