8000 Merge pull request #405 from methane/feature/refactor · PyMySQL/PyMySQL@78af0a9 · GitHub
[go: up one dir, main page]

Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit 78af0a9

Browse files
committed
Merge pull request #405 from methane/feature/refactor
refactoring auth plugin support
2 parents 1893637 + a126c4e commit 78af0a9

File tree

4 files changed

+79
-80
lines changed

4 files changed

+79
-80
lines changed

.travis.yml

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,17 @@
11
sudo: false
22
language: python
3-
python: "3.4"
3+
python: "3.5"
44
cache:
5-
pip: true
5+
directories:
6+
- $HOME/.cache/pip
67

78
env:
89
matrix:
910
- TOX_ENV=py26
1011
- TOX_ENV=py27
1112
- TOX_ENV=py33
1213
- TOX_ENV=py34
14+
- TOX_ENV=py35
1315
- TOX_ENV=pypy
1416
- TOX_ENV=pypy3
1517

@@ -36,7 +38,7 @@ matrix:
3638
sudo: required
3739
- env:
3840
- TOX_ENV=py34
39-
- DB=5.6.26
41+
- DB=5.6.28
4042
addons:
4143
apt:
4244
packages:

pymysql/connections.py

Lines changed: 57 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -525,6 +525,7 @@ class Connection(object):
525525
"""
526526

527527
socket = None
528+
_auth_plugin_name = ''
528529

529530
def __init__(self, host=None, user=None, password="",
530531
database=None, port=3306, unix_socket=None,
@@ -535,7 +536,7 @@ def __init__(self, host=None, user=None, password="",
535536
compress=None, named_pipe=None, no_delay=None,
536537
autocommit=False, db=None, passwd=None, local_infile=False,
537538
max_allowed_packet=16*1024*1024, defer_connect=False,
538-
plugin_map={}):
539+
auth_plugin_map={}):
539540
"""
540541
Establish a connection to the MySQL database. Accepts several
541542
arguments:
@@ -571,12 +572,11 @@ def __init__(self, host=None, user=None, password="",
571572
max_allowed_packet: Max size of packet sent to server in bytes. (default: 16MB)
572573
defer_connect: Don't explicitly connect on contruction - wait for connect call.
573574
(default: False)
574-
plugin_map: Map of plugin names to a class that processes that plugin. The class
575-
will take the Connection object as the argument to the constructor. The class
576-
needs an authenticate method taking an authentication packet as an argument.
577-
For the dialog plugin, a prompt(echo, prompt) method can be used (if no
578-
authenticate method) for returning a string from the user.
579-
575+
auth_plugin_map: A dict of plugin names to a class that processes that plugin.
576+
The class will take the Connection object as the argument to the constructor.
577+
The class needs an authenticate method taking an authentication packet as
578+
an argument. For the dialog plugin, a prompt(echo, prompt) method can be used
579+
(if no authenticate method) for returning a string from the user. (experimental)
580580
db: Alias for database. (for compatibility to MySQLdb)
581581
passwd: Alias for password. (for compatibility to MySQLdb)
582582
"""
@@ -672,7 +672,7 @@ def _config(key, arg):
672672
self.sql_mode = sql_mode
673673
self.init_command = init_command
674674
self.max_allowed_packet = max_allowed_packet
675-
self.plugin_map = plugin_map
675+
self._auth_plugin_map = auth_plugin_map
676676
if defer_connect:
677677
self.socket = None
678678
else:
@@ -726,7 +726,6 @@ def __del__(self):
726726
def autocommit(self, value):
727727
self.autocommit_mode = bool(value)
728728
current = self.get_autocommit()
729-
self.next_packet = 1
730729
if value != current:
731730
self._send_autocommit_mode()
732731

@@ -816,15 +815,13 @@ def query(self, sql, unbuffered=False):
816815
"You may not close previous cursor.")
817816
# if DEBUG:
818817
# print("DEBUG: sending query:", sql)
819-
self.next_packet = 1
820818
if isinstance(sql, text_type) and not (JYTHON or IRONPYTHON):
821819
if PY2:
822820
sql = sql.encode(self.encoding)
823821
else:
824822
sql = sql.encode(self.encoding, 'surrogateescape')
825823
self._execute_command(COMMAND.COM_QUERY, sql)
826824
self._affected_rows = self._read_query_result(unbuffered=unbuffered)
827-
self.next_packet = 1
828825
return self._affected_rows
829826

830827
def next_result(self, unbuffered=False):
@@ -892,7 +889,7 @@ def connect(self, sock=None):
892889
sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
893890
self.socket = sock
894891
self._rfile = _makefile(sock, 'rb')
895-
self.next_packet = 0
892+
self._next_seq_id = 0
896893

897894
self._get_server_information()
898895
self._request_authentication()
@@ -933,16 +930,16 @@ def connect(self, sock=None):
933930
# So just reraise it.
934931
raise
935932

936-
def write_packet(self, data):
933+
def write_packet(self, payload):
937934
"""Writes an entire "mysql packet" in its entirety to the network
938-
addings its length and sequence number. Intended for use by plugins
939-
only.
935+
addings its length and sequence number.
940936
"""
941-
data = pack_int24(len(data)) + int2byte(self.next_packet) + data
937+
# Internal note: when you build packet manualy and calls _write_bytes()
938+
# directly, you should set self._next_seq_id properly.
939+
data = pack_int24(len(payload)) + int2byte(self._next_seq_id) + payload
942940
if DEBUG: dump_packet(data)
943-
944941
self._write_bytes(data)
945-
self.next_packet = (self.next_packet + 1) % 256
942+
self._next_seq_id = (self._next_seq_id + 1) % 256
946943

947944
def _read_packet(self, packet_type=MysqlPacket):
948945
"""Read an entire "mysql packet" in its entirety from the network
@@ -952,8 +949,14 @@ def _read_packet(self, packet_type=MysqlPacket):
952949
while True:
953950
packet_header = self._read_bytes(4)
954951
if DEBUG: dump_packet(packet_header)
952+
955953
btrl, btrh, packet_number = struct.unpack('<HBB', packet_header)
956954
bytes_to_read = btrl + (btrh << 16)
955+
if packet_number != self._next_seq_id:
956+
raise err.InternalError("Packet sequence number wrong - got %d expected %d" %
957+
(packet_number, self._next_seq_id))
958+
self._next_seq_id = (self._next_seq_id + 1) % 256
959+
957960
recv_data = self._read_bytes(bytes_to_read)
958961
if DEBUG: dump_packet(recv_data)
959962
buff += recv_data
@@ -962,13 +965,7 @@ def _read_packet(self, packet_type=MysqlPacket):
962965
continue
963966
if bytes_to_read < MAX_PACKET_LEN:
964967
break
965-
if packet_number != self.next_packet:
966-
pass
967-
#TODO: check sequence id
968-
#raise err.InternalError("Packet sequence number wrong - got %d expected %d" %
969-
# (packet_number, self.next_packet))
970968

971-
self.next_packet = (packet_number + 1) % 256
972969
packet = packet_type(buff, self.encoding)
973970
packet.check_error()
974971
return packet
@@ -1027,33 +1024,32 @@ def _execute_command(self, command, sql):
10271024
if self._result is not None and self._result.unbuffered_active:
10281025
warnings.warn("Previous unbuffered result was left incomplete")
10291026
self._result._finish_unbuffered_query()
1027+
self._result = None
10301028

10311029
if isinstance(sql, text_type):
10321030
sql = sql.encode(self.encoding)
10331031

1034-
chunk_size = min(self.max_allowed_packet, len(sql) + 1) # +1 is for command
1032+
# +1 is for command
1033+
chunk_size = min(self.max_allowed_packet, len(sql) + 1)
10351034

1035+
# tiny optimization: build first packet manually instead of
1036+
# calling self..write_packet()
10361037
prelude = struct.pack('<iB', chunk_size, command)
1037-
self._write_bytes(prelude + sql[:chunk_size-1])
1038-
if DEBUG: dump_packet(prelude + sql)
1038+
packet = prelude + sql[:chunk_size-1]
1039+
self._write_bytes(packet)
1040+
if DEBUG: dump_packet(packet)
1041+
self._next_seq_id = 1
10391042

1040-
self.next_packet = 1
10411043
if chunk_size < self.max_allowed_packet:
10421044
return
10431045

1044-
seq_id = 1
10451046
sql = sql[chunk_size-1:]
10461047
while True:
10471048
chunk_size = min(self.max_allowed_packet, len(sql))
1048-
prelude = struct.pack('<i', chunk_size)[:3]
1049-
data = prelude + int2byte(seq_id%256) + sql[:chunk_size]
1050-
self._write_bytes(data)
1051-
if DEBUG: dump_packet(data)
1049+
self.write_packet(sql[:chunk_size])
10521050
sql = sql[chunk_size:]
10531051
if not sql and chunk_size < self.max_allowed_packet:
10541052
break
1055-
seq_id += 1
1056-
self.next_packet = seq_id%256
10571053

10581054
def _request_authentication(self):
10591055
# https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::HandshakeResponse
@@ -1078,18 +1074,15 @@ def _request_authentication(self):
10781074

10791075
data = data_init + self.user + b'\0'
10801076

1081-
authresp = ''
1082-
if self.plugin_name == 'mysql_native_password':
1077+
authresp = b''
1078+
if self._auth_plugin_name == 'mysql_native_password':
10831079
authresp = _scramble(self.password.encode('latin1'), self.salt)
10841080

10851081
if self.server_capabilities & CLIENT.PLUGIN_AUTH_LENENC_CLIENT_DATA:
1086-
data += lenenc_int(len(authresp))
1087-
data += authresp
1082+
data += lenenc_int(len(authresp)) + authresp
10881083
elif self.server_capabilities & CLIENT.SECURE_CONNECTION:
1089-
length = len(authresp)
1090-
data += struct.pack('B', length)
1091-
data += authresp
1092-
else: # pragma: no cover - not testing against servers without secure auth (>=5.0)
1084+
data += struct.pack('B', len(authresp)) + authresp
1085+
else: # pragma: no cover - not testing against servers without secure auth (>=5.0)
10931086
data += authresp + b'\0'
10941087

10951088
if self.db and self.server_capabilities & CLIENT.CONNECT_WITH_DB:
@@ -1098,15 +1091,16 @@ def _request_authentication(self):
10981091
data += self.db + b'\0'
10991092

11001093
if self.server_capabilities & CLIENT.PLUGIN_AUTH:
1101-
data += self.plugin_name.encode('latin1') + b'\0'
1094+
name = self._auth_plugin_name
1095+
if isinstance(name, text_type):
1096+
name = name.encode('ascii')
1097+
data += name + b'\0'
11021098

11031099
self.write_packet(data)
1104-
11051100
auth_packet = self._read_packet()
11061101

11071102
# if authentication method isn't accepted the first byte
11081103
# will have the octet 254
1109-
11101104
if auth_packet.is_auth_switch_request():
11111105
# https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchRequest
11121106
auth_packet.read_uint8() # 0xfe packet identifier
@@ -1119,8 +1113,13 @@ def _request_authentication(self):
11191113
self.write_packet(data)
11201114
auth_packet = self._read_packet()
11211115

1116+
#TODO: ok packet or error packet?
1117+
1118+
11221119
def _process_auth(self, plugin_name, auth_packet):
1123-
plugin_class = self.plugin_map.get(plugin_name)
1120+
plugin_class = self._auth_plugin_map.get(plugin_name)
1121+
if not plugin_class:
1122+
plugin_class = self._auth_plugin_map.get(plugin_name.decode('ascii'))
11241123
if plugin_class:
11251124
try:
11261125
handler = plugin_class(self)
@@ -1246,18 +1245,13 @@ def _get_server_information(self):
12461245
server_end = data.find(b'\0', i)
12471246
if server_end < 0: # pragma: no cover - very specific upstream bug
12481247
# not found \0 and last field so take it all
1249-
self.plugin_name = data[i:].decode('latin1')
1248+
self._auth_plugin_name = data[i:].decode('latin1')
12501249
else:
1251-
self.plugin_name = data[i:server_end].decode('latin1')
1252-
else: # pragma: no cover - not testing against any plugin uncapable servers
1253-
self.plugin_name = ''
1250+
self._auth_plugin_name = data[i:server_end].decode('latin1')
12541251

12551252
def get_server_info(self):
12561253
return self.server_version
12571254

1258-
def get_plugin_name(self):
1259-
return self.plugin_name
1260-
12611255
Warning = err.Warning
12621256
Error = err.Error
12631257
InterfaceError = err.InterfaceError
@@ -1331,7 +1325,11 @@ def _read_ok_packet(self, first_packet):
13311325
def _read_load_local_packet(self, first_packet):
13321326
load_packet = LoadLocalPacketWrapper(first_packet)
13331327
sender = LoadLocalFile(load_packet.filename, self.connection)
1334-
sender.send_data()
1328+
try:
1329+
sender.send_data()
1330+
except:
1331+
self.connection._read_packet() # skip ok packet
1332+
raise
13351333

13361334
ok_packet = self.connection._read_packet()
13371335
if not ok_packet.is_ok_packet(): # pragma: no cover - upstream induced protocol error
@@ -1448,27 +1446,20 @@ def send_data(self):
14481446
"""Send data packets from the local file to the server"""
14491447
if not self.connection.socket:
14501448
raise err.InterfaceError("(0, '')")
1449+
conn = self.connection
14511450

1452-
# sequence id is 2 as we already sent a query packet
1453-
seq_id = 2
14541451
try:
14551452
with open(self.filename, 'rb') as open_file:
1456-
chunk_size = self.connection.max_allowed_packet
1453+
chunk_size = conn.max_allowed_packet
14571454
packet = b""
14581455

14591456
while True:
14601457
chunk = open_file.read(chunk_size)
14611458
if not chunk:
14621459
break
1463-
packet = struct.pack('<i', len(chunk))[:3] + int2byte(seq_id)
1464-
format_str = '!{0}s'.format(len(chunk))
1465-
packet += struct.pack(format_str, chunk)
1466-
self.connection._write_bytes(packet)
1467-
seq_id = (seq_id + 1) % 256
1460+
conn.write_packet(chunk)
14681461
except IOError:
14691462
raise err.OperationalError(1017, "Can't find file '{0}'".format(self.filename))
14701463
finally:
14711464
# send the empty packet to signify we are done sending data
1472-
packet = struct.pack('<i', 0)[:3] + int2byte(seq_id)
1473-
self.connection._write_bytes(packet)
1474-
self.next_packet = (seq_id + 1) % 256
1465+
conn.write_packet(b'')

0 commit comments

Comments
 (0)
0