10000 Merge pull request #405 from methane/feature/refactor · pkdevboxy/PyMySQL@78af0a9 · GitHub
[go: up one dir, main page]

Skip to content

Commit 78af0a9

Browse files
committed
Merge pull request PyMySQL#405 from methane/feature/refactor
refactoring auth plugin support
2 parents 1893637 + a126c4e commit 78af0a9

File tree

4 files changed

+79
-80
lines changed

4 files changed

+79
-80
lines changed

.travis.yml

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,17 @@
11
sudo: false
22
language: python
3-
python: "3.4"
3+
python: "3.5"
44
cache:
5-
pip: true
5+
directories:
6+
- $HOME/.cache/pip
67

78
env:
89
matrix:
910
- TOX_ENV=py26
1011
- TOX_ENV=py27
1112
- TOX_ENV=py33
1213
- TOX_ENV=py34
14+
- TOX_ENV=py35
1315
- TOX_ENV=pypy
1416
- TOX_ENV=pypy3
1517

@@ -36,7 +38,7 @@ matrix:
3638
sudo: required
3739
- env:
3840
- TOX_ENV=py34
39-
- DB=5.6.26
41+
- DB=5.6.28
4042
addons:
4143
apt:
4244
packages:

pymysql/connections.py

Lines changed: 57 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -525,6 +525,7 @@ class Connection(object):
525525
"""
526526

527527
socket = None
528+
_auth_plugin_name = ''
528529

529530
def __init__(self, host=None, user=None, password="",
530531
database=None, port=3306, unix_socket=None,
@@ -535,7 +536,7 @@ def __init__(self, host=None, user=None, password="",
535536
compress=None, named_pipe=None, no_delay=None,
536537
autocommit=False, db=None, passwd=None, local_infile=False,
537538
max_allowed_packet=16*1024*1024, defer_connect=False,
538-
plugin_map={}):
539+
auth_plugin_map={}):
539540
"""
540541
Establish a connection to the MySQL database. Accepts several
541542
arguments:
@@ -571,12 +572,11 @@ def __init__(self, host=None, user=None, password="",
571572
max_allowed_packet: Max size of packet sent to server in bytes. (default: 16MB)
572573
defer_connect: Don't explicitly connect on contruction - wait for connect call.
573574
(default: False)
574-
plugin_map: Map of plugin names to a class that processes that plugin. The class
575-
will take the Connection object as the argument to the constructor. The class
576-
needs an authenticate method taking an authentication packet as an argument.
577-
For the dialog plugin, a prompt(echo, prompt) method can be used (if no
578-
authenticate method) for returning a string from the user.
579-
575+
auth_plugin_map: A dict of plugin names to a class that processes that plugin.
576+
The class will take the Connection object as the argument to the constructor.
577+
The class needs an authenticate method taking an authentication packet as
578+
an argument. For the dialog plugin, a prompt(echo, prompt) method can be used
579+
(if no authenticate method) for returning a string from the user. (experimental)
580580
db: Alias for database. (for compatibility to MySQLdb)
581581
passwd: Alias for password. (for compatibility to MySQLdb)
582582
"""
@@ -672,7 +672,7 @@ def _config(key, arg):
672672
self.sql_mode = sql_mode
673673
self.init_command = init_command
674674
self.max_allowed_packet = max_allowed_packet
675-
self.plugin_map = plugin_map
675+
self._auth_plugin_map = auth_plugin_map
676676
if defer_connect:
677677
self.socket = None
678678
else:
@@ -726,7 +726,6 @@ def __del__(self):
726726
def autocommit(self, value):
727727
self.autocommit_mode = bool(value)
728728
current = self.get_autocommit()
729-
self.next_packet = 1
730729
if value != current:
731730
self._send_autocommit_mode()
732731

@@ -816,15 +815,13 @@ def query(self, sql, unbuffered=False):
816815
"You may not close previous cursor.")
817816
# if DEBUG:
818817
# print("DEBUG: sending query:", sql)
819-
self.next_packet = 1
820818
if isinstance(sql, text_type) and not (JYTHON or IRONPYTHON):
821819
if PY2:
822820
sql = sql.encode(self.encoding)
823821
else:
824822
sql = sql.encode(self.encoding, 'surrogateescape')
825823
self._execute_command(COMMAND.COM_QUERY, sql)
826824
self._affected_rows = self._read_query_result(unbuffered=unbuffered)
827-
self.next_packet = 1
828825
return self._affected_rows
829826

830827
def next_result(self, unbuffered=False):
@@ -892,7 +889,7 @@ def connect(self, sock=None):
892889
sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
893890
self.socket = sock
894891
self._rfile = _makefile(sock, 'rb')
895-
self.next_packet = 0
892+
self._next_seq_id = 0
896893

897894
self._get_server_information()
898895
self._request_authentication()
@@ -933,16 +930,16 @@ def connect(self, sock=None):
933930
# So just reraise it.
934931
raise
935932

936-
def write_packet(self, data):
933+
def write_packet(self, payload):
937934
"""Writes an entire "mysql packet" in its entirety to the network
938-
addings its length and sequence number. Intended for use by plugins
939-
only.
935+
addings its length and sequence number.
940936
"""
941-
data = pack_int24(len(data)) + int2byte(self.next_packet) + data
937+
# Internal note: when you build packet manualy and calls _write_bytes()
938+
# directly, you should set self._next_seq_id properly.
939+
data = pack_int24(len(payload)) + int2byte(self._next_seq_id) + payload
942940
if DEBUG: dump_packet(data)
943-
944941
self._write_bytes(data)
945-
self.next_packet = (self.next_packet + 1) % 256
942+
self._next_seq_id = (self._next_seq_id + 1) % 256
946943

947944
def _read_packet(self, packet_type=MysqlPacket):
948945
"""Read an entire "mysql packet" in its entirety from the network
@@ -952,8 +949,14 @@ def _read_packet(self, packet_type=MysqlPacket):
952949
while True:
953950
packet_header = self._read_bytes(4)
954< 10000 /td>951
if DEBUG: dump_packet(packet_header)
952+
955953
btrl, btrh, packet_number = struct.unpack('<HBB', packet_header)
956954
bytes_to_read = btrl + (btrh << 16)
955+
if packet_number != self._next_seq_id:
956+
raise err.InternalError("Packet sequence number wrong - got %d expected %d" %
957+
(packet_number, self._next_seq_id))
958+
self._next_seq_id = (self._next_seq_id + 1) % 256
959+
957960
recv_data = self._read_bytes(bytes_to_read)
958961
if DEBUG: dump_packet(recv_data)
959962
buff += recv_data
@@ -962,13 +965,7 @@ def _read_packet(self, packet_type=MysqlPacket):
962965
continue
963966
if bytes_to_read < MAX_PACKET_LEN:
964967
break
965-
if packet_number != self.next_packet:
966-
pass
967-
#TODO: check sequence id
968-
#raise err.InternalError("Packet sequence number wrong - got %d expected %d" %
969-
# (packet_number, self.next_packet))
970968

971-
self.next_packet = (packet_number + 1) % 256
972969
packet = packet_type(buff, self.encoding)
973970
packet.check_error()
974971
return packet
@@ -1027,33 +1024,32 @@ def _execute_command(self, command, sql):
10271024
if self._result is not None and self._result.unbuffered_active:
10281025
warnings.warn("Previous unbuffered result was left incomplete")
10291026
self._result._finish_unbuffered_query()
1027+
self._result = None
10301028

10311029
if isinstance(sql, text_type):
10321030
sql = sql.encode(self.encoding)
10331031

1034-
chunk_size = min(self.max_allowed_packet, len(sql) + 1) # +1 is for command
1032+
# +1 is for command
1033+
chunk_size = min(self.max_allowed_packet, len(sql) + 1)
10351034

1035+
# tiny optimization: build first packet manually instead of
1036+
# calling self..write_packet()
10361037
prelude = struct.pack('<iB', chunk_size, command)
1037-
self._write_bytes(prelude + sql[:chunk_size-1])
1038-
if DEBUG: dump_packet(prelude + sql)
1038+
packet = prelude + sql[:chunk_size-1]
1039+
self._write_bytes(packet)
1040+
if DEBUG: dump_packet(packet)
1041+
self._next_seq_id = 1
10391042

1040-
self.next_packet = 1
10411043
if chunk_size < self.max_allowed_packet:
10421044
return
10431045

1044-
seq_id = 1
10451046
sql = sql[chunk_size-1:]
10461047
while True:
10471048
chunk_size = min(self.max_allowed_packet, len(sql))
1048-
prelude = struct.pack('<i', chunk_size)[:3]
1049-
data = prelude + int2byte(seq_id%256) + sql[:chunk_size]
1050-
self._write_bytes(data)
1051-
if DEBUG: dump_packet(data)
1049+
self.write_packet(sql[:chunk_size])
10521050
sql = sql[chunk_size:]
10531051
if not sql and chunk_size < self.max_allowed_packet:
10541052
break
1055-
seq_id += 1
1056-
self.next_packet = seq_id%256
10571053

10581054
def _request_authentication(self):
10591055
# https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::HandshakeResponse
@@ -1078,18 +1074,15 @@ def _request_authentication(self):
10781074

10791075
data = data_init + self.user + b'\0'
10801076

1081-
authresp = ''
1082-
if self.plugin_name == 'mysql_native_password':
1077+
authresp = b''
1078+
if self._auth_plugin_name == 'mysql_native_password':
10831079
authresp = _scramble(self.password.encode('latin1'), self.salt)
10841080

10851081
if self.server_capabilities & CLIENT.PLUGIN_AUTH_LENENC_CLIENT_DATA:
1086-
data += lenenc_int(len(authresp))
1087-
data += authresp
1082+
data += lenenc_int(len(authresp)) + authresp
10881083
elif self.server_capabilities & CLIENT.SECURE_CONNECTION:
1089-
length = len(authresp)
1090-
data += struct.pack('B', length)
1091-
data += authresp
1092-
else: # pragma: no cover - not testing against servers without secure auth (>=5.0)
1084+
data += struct.pack('B', len(authresp)) + authresp
1085+
else: # pragma: no cover - not testing against servers without secure auth (>=5.0)
10931086
data += authresp + b'\0'
10941087

10951088
if self.db and self.server_capabilities & CLIENT.CONNECT_WITH_DB:
@@ -1098,15 +1091,16 @@ def _request_authentication(self):
10981091
data += self.db + b'\0'
10991092

11001093
if self.server_capabilities & CLIENT.PLUGIN_AUTH:
1101-
data += self.plugin_name.encode('latin1') + b'\0'
1094+
name = self._auth_plugin_name
1095+
if isinstance(name, text_type):
1096+
name = name.encode('ascii')
1097+
data += name + b'\0'
11021098

11031099
self.write_packet(data)
1104-
11051100
auth_packet = self._read_packet()
11061101

11071102
# if authentication method isn't accepted the first byte
11081103
# will have the octet 254
1109-
11101104
if auth_packet.is_auth_switch_request():
11111105
# https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchRequest
11121106
auth_packet.read_uint8() # 0xfe packet identifier
@@ -1119,8 +1113,13 @@ def _request_authentication(self):
11191113
self.write_packet(data)
11201114
auth_packet = self._read_packet()
11211115

1116+
#TODO: ok packet or error packet?
1117+
1118+
11221119
def _process_auth(self, plugin_name, auth_packet):
1123-
plugin_class = self.plugin_map.get(plugin_name)
1120+
plugin_class = self._auth_plugin_map.get(plugin_name)
1121+
if not plugin_class:
1122+
plugin_class = self._auth_plugin_map.get(plugin_name.decode('ascii'))
11241123
if plugin_class:
11251124
try:
11261125
handler = plugin_class(self)
@@ -1246,18 +1245,13 @@ def _get_server_information(self):
12461245
server_end = data.find(b'\0', i)
12471246
if server_end < 0: # pragma: no cover - very specific upstream bug
12481247
# not found \0 and last field so take it all
1249-
self.plugin_name = data[i:].decode('latin1')
1248+
self._auth_plugin_name = data[i:].decode('latin1')
12501249
else:
1251-
self.plugin_name = data[i:server_end].decode('latin1')
1252-
else: # pragma: no cover - not testing against any plugin uncapable servers
1253-
self.plugin_name = ''
1250+
self._auth_plugin_name = data[i:server_end].decode('latin1')
12541251

12551252
def get_server_info(self):
12561253
return self.server_version
12571254

1258-
def get_plugin_name(self):
1259-
return self.plugin_name
1260-
12611255
Warning = err.Warning
12621256
Error = err.Error
12631257
InterfaceError = err.InterfaceError
@@ -1331,7 +1325,11 @@ def _read_ok_packet(self, first_packet):
13311325
def _read_load_local_packet(self, first_packet):
13321326
load_packet = LoadLocalPacketWrapper(first_packet)
13331327
sender = LoadLocalFile(load_packet.filename, self.connection)
1334-
sender.send_data()
1328+
try:
1329+
sender.send_data()
1330+
except:
1331+
self.connection._read_packet() # skip ok packet
1332+
raise
13351333

13361334
ok_packet = self.connection._read_packet()
13371335
if not ok_packet.is_ok_packet(): # pragma: no cover - upstream induced protocol error
@@ -1448,27 +1446,20 @@ def send_data(self):
14481446
"""Send data packets from the local file to the server"""
14491447
if not self.connection.socket:
14501448
raise err.InterfaceError("(0, '')")
1449+
conn = self.connection
14511450

1452-
# sequence id is 2 as we already sent a query packet
1453-
seq_id = 2
14541451
try:
14551452
with open(self.filename, 'rb') as open_file:
1456-
chunk_size = self.connection.max_allowed_packet
1453+
chunk_size = conn.max_allowed_packet
14571454
packet = b""
14581455

14591456
while True:
14601457
chunk = open_file.read(chunk_size)
14611458
if not chunk:
14621459
break
1463-
packet = struct.pack('<i', len(chunk))[:3] + int2byte(seq_id)
1464-
format_str = '!{0}s'.format(len(chunk))
1465-
packet += struct.pack(format_str, chunk)
1466-
self.connection._write_bytes(packet)
1467-
seq_id = (seq_id + 1) % 256
1460+
conn.write_packet(chunk)
14681461
except IOError:
14691462
raise err.OperationalError(1017, "Can't find file '{0}'".format(self.filename))
14701463
finally:
14711464
# send the empty packet to signify we are done sending data
1472-
packet = struct.pack('<i', 0)[:3] + int2byte(seq_id)
1473-
self.connection._write_bytes(packet)
1474-
self.next_packet = (seq_id + 1) % 256
1465+
conn.write_packet(b'')

0 commit comments

Comments
 (0)
0