From 2a5f30bdb69b0e515d61b3841b897b7e7135b027 Mon Sep 17 00:00:00 2001 From: Ryan Grout Date: Mon, 7 May 2018 10:50:35 -0500 Subject: [PATCH 01/10] Add protocol.py --- pymysql/protocol.py | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) create mode 100644 pymysql/protocol.py diff --git a/pymysql/protocol.py b/pymysql/protocol.py new file mode 100644 index 00000000..09278c08 --- /dev/null +++ b/pymysql/protocol.py @@ -0,0 +1,30 @@ +from ._compat import PY2 + +from struct import unpack_from + +DEBUG = False + +if PY2: + def read_uint8(data, offset=0): + return ord(data[offset]) +else: + def read_uint8(data, offset=0): + return data[offset] + + +def read_uint16(data, offset=0): + return unpack_from(' Date: Mon, 7 May 2018 10:51:35 -0500 Subject: [PATCH 02/10] Use protocol module to decode ints. --- pymysql/connections.py | 25 ++++++++++--------------- 1 file changed, 10 insertions(+), 15 deletions(-) diff --git a/pymysql/connections.py b/pymysql/connections.py index 967d0c59..f57a24c7 100644 --- a/pymysql/connections.py +++ b/pymysql/connections.py @@ -23,6 +23,7 @@ from .optionfile import Parser from .util import byte2int, int2byte from . import err +from . import protocol try: import ssl @@ -276,34 +277,28 @@ def get_bytes(self, position, length=1): """ return self._data[position:(position+length)] - if PY2: - def read_uint8(self): - result = ord(self._data[self._position]) - self._position += 1 - return result - else: - def read_uint8(self): - result = self._data[self._position] - self._position += 1 - return result + def read_uint8(self): + result = protocol.read_uint8(self._data, offset=self._position) + self._position += 1 + return result def read_uint16(self): - result = struct.unpack_from(' Date: Mon, 7 May 2018 11:28:50 -0500 Subject: [PATCH 03/10] Add packet checking functions. --- pymysql/protocol.py | 34 ++++++++++++++++++++++++++++++++++ 1 file changed, 34 insertions(+) diff --git a/pymysql/protocol.py b/pymysql/protocol.py index 09278c08..8811d2e3 100644 --- a/pymysql/protocol.py +++ b/pymysql/protocol.py @@ -1,7 +1,10 @@ +from __future__ import print_function from ._compat import PY2 from struct import unpack_from +from . import err + DEBUG = False if PY2: @@ -28,3 +31,34 @@ def read_uint32(data, offset=0): def read_uint64(data, offset=0): return unpack_from('= 7 + + +def is_eof_packet(packet): + # http://dev.mysql.com/doc/internals/en/generic-response-packets.html#packet-EOF_Packet + # Caution: \xFE may be LengthEncodedInteger. + # If \xFE is LengthEncodedInteger header, 8bytes followed. + return read_uint8(packet) == 254 and len(packet) < 9 + + +def is_auth_switch_request(packet): + # http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchRequest + return read_uint8(packet) == 254 + + +def is_load_local(packet): + return read_uint8(packet) == 251 + + +def is_resultset_packet(packet): + return 1 <= read_uint8(packet) <= 250 + + +def check_error(packet): + if read_uint8(packet) == 255: + errno = read_uint16(packet, offset=1) + if DEBUG: print("errno = ", errno) + err.raise_mysql_exception(packet) From db275f338ee642167bb9bc8ea9ef76434f59b071 Mon Sep 17 00:00:00 2001 From: Ryan Grout Date: Mon, 7 May 2018 11:29:13 -0500 Subject: [PATCH 04/10] Use protocol packet checking functions. --- pymysql/connections.py | 24 +++++------------------- 1 file changed, 5 insertions(+), 19 deletions(-) diff --git a/pymysql/connections.py b/pymysql/connections.py index f57a24c7..471b6d8f 100644 --- a/pymysql/connections.py +++ b/pymysql/connections.py @@ -347,36 +347,22 @@ def read_struct(self, fmt): return result def is_ok_packet(self): - # https://dev.mysql.com/doc/internals/en/packet-OK_Packet.html - return self._data[0:1] == b'\0' and len(self._data) >= 7 + return protocol.is_ok_packet(self._data) def is_eof_packet(self): - # http://dev.mysql.com/doc/internals/en/generic-response-packets.html#packet-EOF_Packet - # Caution: \xFE may be LengthEncodedInteger. - # If \xFE is LengthEncodedInteger header, 8bytes followed. - return self._data[0:1] == b'\xfe' and len(self._data) < 9 + return protocol.is_eof_packet(self._data) def is_auth_switch_request(self): - # http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchRequest - return self._data[0:1] == b'\xfe' + return protocol.is_auth_switch_request(self._data) def is_resultset_packet(self): - field_count = ord(self._data[0:1]) - return 1 <= field_count <= 250 + return protocol.is_resultset_packet(self._data) def is_load_local_packet(self): return self._data[0:1] == b'\xfb' - def is_error_packet(self): - return self._data[0:1] == b'\xff' - def check_error(self): - if self.is_error_packet(): - self.rewind() - self.advance(1) # field_count == error (we already know that) - errno = self.read_uint16() - if DEBUG: print("errno =", errno) - err.raise_mysql_exception(self._data) + protocol.check_error(self._data) def dump(self): dump_packet(self._data) From fe72db968774343d8dbccbedc14dc7dfc125a666 Mon Sep 17 00:00:00 2001 From: Ryan Grout Date: Mon, 7 May 2018 23:25:38 -0500 Subject: [PATCH 05/10] Move read_string to protocol.py --- pymysql/connections.py | 8 +++----- pymysql/protocol.py | 8 ++++++++ 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/pymysql/connections.py b/pymysql/connections.py index 471b6d8f..e3217b6c 100644 --- a/pymysql/connections.py +++ b/pymysql/connections.py @@ -303,11 +303,9 @@ def read_uint64(self): return result def read_string(self): - end_pos = self._data.find(b'\0', self._position) - if end_pos < 0: - return None - result = self._data[self._position:end_pos] - self._position = end_pos + 1 + pos, result = protocol.read_string(self._data, offset=self._position) + # We need to add one to account for the null terminating character. + self._position += pos return result def read_length_encoded_integer(self): diff --git a/pymysql/protocol.py b/pymysql/protocol.py index 8811d2e3..f7574fa8 100644 --- a/pymysql/protocol.py +++ b/pymysql/protocol.py @@ -62,3 +62,11 @@ def check_error(packet): errno = read_uint16(packet, offset=1) if DEBUG: print("errno = ", errno) err.raise_mysql_exception(packet) + + +def read_string(data, offset=0): + end = data.find(b'\0', offset) + if end >= 0: + result = data[offset:end] + # Add one to length to account for the null byte + return len(result) + 1, result From 61928a7b1e63563e9231ffbf98d59c10a4c56ccf Mon Sep 17 00:00:00 2001 From: Ryan Grout Date: Tue, 8 May 2018 00:22:18 -0500 Subject: [PATCH 06/10] Move length coded integer to protocol.py --- pymysql/connections.py | 14 +++----------- pymysql/protocol.py | 30 ++++++++++++++++++++++++++++++ 2 files changed, 33 insertions(+), 11 deletions(-) diff --git a/pymysql/connections.py b/pymysql/connections.py index e3217b6c..83a24a9f 100644 --- a/pymysql/connections.py +++ b/pymysql/connections.py @@ -314,17 +314,9 @@ def read_length_encoded_integer(self): Length coded numbers can be anywhere from 1 to 9 bytes depending on the value of the first byte. """ - c = self.read_uint8() - if c == NULL_COLUMN: - return None - if c < UNSIGNED_CHAR_COLUMN: - return c - elif c == UNSIGNED_SHORT_COLUMN: - return self.read_uint16() - elif c == UNSIGNED_INT24_COLUMN: - return self.read_uint24() - elif c == UNSIGNED_INT64_COLUMN: - return self.read_uint64() + bytes_read, result = protocol.read_length_encoded_integer(self._data, self._position) + self._position += bytes_read + return result def read_length_coded_string(self): """Read a 'Length Coded String' from the data buffer. diff --git a/pymysql/protocol.py b/pymysql/protocol.py index f7574fa8..4a36b952 100644 --- a/pymysql/protocol.py +++ b/pymysql/protocol.py @@ -9,26 +9,32 @@ if PY2: def read_uint8(data, offset=0): + """Read 1 byte of data""" return ord(data[offset]) else: def read_uint8(data, offset=0): + """Read 1 byte of data""" return data[offset] def read_uint16(data, offset=0): + """Read 2 bytes of data beginning at offset""" return unpack_from(' Date: Tue, 8 May 2018 00:46:20 -0500 Subject: [PATCH 07/10] Move reading byte strings to protocol.py --- pymysql/connections.py | 14 +++----------- pymysql/protocol.py | 16 ++++++++++++++++ 2 files changed, 19 insertions(+), 11 deletions(-) diff --git a/pymysql/connections.py b/pymysql/connections.py index 83a24a9f..b651f5be 100644 --- a/pymysql/connections.py +++ b/pymysql/connections.py @@ -231,16 +231,8 @@ def get_all_data(self): def read(self, size): """Read the first 'size' bytes in packet and advance cursor past them.""" - result = self._data[self._position:(self._position+size)] - if len(result) != size: - error = ('Result length not requested length:\n' - 'Expected=%s. Actual=%s. Position: %s. Data Length: %s' - % (size, len(result), self._position, len(self._data))) - if DEBUG: - print(error) - self.dump() - raise AssertionError(error) - self._position += size + bytes_read, result = protocol.read_bytes(self._data, size, offset=self._position) + self._position += bytes_read return result def read_all(self): @@ -248,7 +240,7 @@ def read_all(self): (Subsequent read() will return errors.) """ - result = self._data[self._position:] + bytes_read, result = protocol.read_bytes(self._data, None, offset=self._position) self._position = None # ensure no subsequent read() return result diff --git a/pymysql/protocol.py b/pymysql/protocol.py index 4a36b952..59dbf6c6 100644 --- a/pymysql/protocol.py +++ b/pymysql/protocol.py @@ -100,3 +100,19 @@ def read_length_encoded_integer(data, offset=0): else: raise ValueError + +def read_bytes(data, nbytes, offset=0): + if nbytes is None: + result = data[offset:] + return len(result), result + else: + result = data[offset:offset+nbytes] + if len(result) == nbytes: + return nbytes, result + + error = ('Result length not requested length:\n' + 'Expected=%s Actual=%s Position: %s Data Length: %s' + % (nbytes, len(result), offset, len(data))) + if DEBUG: + print(error) + raise AssertionError(error) \ No newline at end of file From dd6ad876a303a764a011fc8cf092dea05de43056 Mon Sep 17 00:00:00 2001 From: Ryan Grout Date: Tue, 8 May 2018 00:49:21 -0500 Subject: [PATCH 08/10] Raise an error if reading a null terminated string fails. --- pymysql/protocol.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pymysql/protocol.py b/pymysql/protocol.py index 59dbf6c6..14f60e44 100644 --- a/pymysql/protocol.py +++ b/pymysql/protocol.py @@ -76,6 +76,8 @@ def read_string(data, offset=0): result = data[offset:end] # Add one to length to account for the null byte return len(result) + 1, result + else: + raise ValueError("Invalid read on non-null terminated string") def read_length_encoded_integer(data, offset=0): From 3be686e06eeec5e0530314be560b8c120595746c Mon Sep 17 00:00:00 2001 From: Ryan Grout Date: Tue, 8 May 2018 12:44:35 -0500 Subject: [PATCH 09/10] Move reading length encoded strings to protocol.py --- pymysql/connections.py | 7 +++---- pymysql/protocol.py | 12 +++++++++++- 2 files changed, 14 insertions(+), 5 deletions(-) diff --git a/pymysql/connections.py b/pymysql/connections.py index b651f5be..f046983f 100644 --- a/pymysql/connections.py +++ b/pymysql/connections.py @@ -317,10 +317,9 @@ def read_length_coded_string(self): (unsigned, positive) integer represented in 1-9 bytes followed by that many bytes of binary data. (For example "cat" would be "3cat".) """ - length = self.read_length_encoded_integer() - if length is None: - return None - return self.read(length) + bytes_read, result = protocol.read_length_coded_string(self._data, offset=self._position) + self._position += bytes_read + return result def read_struct(self, fmt): s = struct.Struct(fmt) diff --git a/pymysql/protocol.py b/pymysql/protocol.py index 14f60e44..fdf700af 100644 --- a/pymysql/protocol.py +++ b/pymysql/protocol.py @@ -117,4 +117,14 @@ def read_bytes(data, nbytes, offset=0): % (nbytes, len(result), offset, len(data))) if DEBUG: print(error) - raise AssertionError(error) \ No newline at end of file + raise AssertionError(error) + + +def read_length_coded_string(data, offset=0): + bytes_read, length = read_length_encoded_integer(data, offset=offset) + if length is not None: + _br, result = read_bytes(data, length, offset=offset+bytes_read) + return bytes_read + _br, result + else: + # Null column + return bytes_read, None \ No newline at end of file From 6ae4ce9255593e2a9c462e51a03b426a19cc9c06 Mon Sep 17 00:00:00 2001 From: Ryan Grout Date: Tue, 8 May 2018 12:48:59 -0500 Subject: [PATCH 10/10] Remove old comment. --- pymysql/connections.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pymysql/connections.py b/pymysql/connections.py index f046983f..6d4951ab 100644 --- a/pymysql/connections.py +++ b/pymysql/connections.py @@ -296,7 +296,6 @@ def read_uint64(self): def read_string(self): pos, result = protocol.read_string(self._data, offset=self._position) - # We need to add one to account for the null terminating character. self._position += pos return result