diff --git a/pymysql/connections.py b/pymysql/connections.py index 967d0c59..6d4951ab 100644 --- a/pymysql/connections.py +++ b/pymysql/connections.py @@ -23,6 +23,7 @@ from .optionfile import Parser from .util import byte2int, int2byte from . import err +from . import protocol try: import ssl @@ -230,16 +231,8 @@ def get_all_data(self): def read(self, size): """Read the first 'size' bytes in packet and advance cursor past them.""" - result = self._data[self._position:(self._position+size)] - if len(result) != size: - error = ('Result length not requested length:\n' - 'Expected=%s. Actual=%s. Position: %s. Data Length: %s' - % (size, len(result), self._position, len(self._data))) - if DEBUG: - print(error) - self.dump() - raise AssertionError(error) - self._position += size + bytes_read, result = protocol.read_bytes(self._data, size, offset=self._position) + self._position += bytes_read return result def read_all(self): @@ -247,7 +240,7 @@ def read_all(self): (Subsequent read() will return errors.) """ - result = self._data[self._position:] + bytes_read, result = protocol.read_bytes(self._data, None, offset=self._position) self._position = None # ensure no subsequent read() return result @@ -276,43 +269,34 @@ def get_bytes(self, position, length=1): """ return self._data[position:(position+length)] - if PY2: - def read_uint8(self): - result = ord(self._data[self._position]) - self._position += 1 - return result - else: - def read_uint8(self): - result = self._data[self._position] - self._position += 1 - return result + def read_uint8(self): + result = protocol.read_uint8(self._data, offset=self._position) + self._position += 1 + return result def read_uint16(self): - result = struct.unpack_from('= 7 + return protocol.is_ok_packet(self._data) def is_eof_packet(self): - # http://dev.mysql.com/doc/internals/en/generic-response-packets.html#packet-EOF_Packet - # Caution: \xFE may be LengthEncodedInteger. - # If \xFE is LengthEncodedInteger header, 8bytes followed. - return self._data[0:1] == b'\xfe' and len(self._data) < 9 + return protocol.is_eof_packet(self._data) def is_auth_switch_request(self): - # http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchRequest - return self._data[0:1] == b'\xfe' + return protocol.is_auth_switch_request(self._data) def is_resultset_packet(self): - field_count = ord(self._data[0:1]) - return 1 <= field_count <= 250 + return protocol.is_resultset_packet(self._data) def is_load_local_packet(self): return self._data[0:1] == b'\xfb' - def is_error_packet(self): - return self._data[0:1] == b'\xff' - def check_error(self): - if self.is_error_packet(): - self.rewind() - self.advance(1) # field_count == error (we already know that) - errno = self.read_uint16() - if DEBUG: print("errno =", errno) - err.raise_mysql_exception(self._data) + protocol.check_error(self._data) def dump(self): dump_packet(self._data) diff --git a/pymysql/protocol.py b/pymysql/protocol.py new file mode 100644 index 00000000..fdf700af --- /dev/null +++ b/pymysql/protocol.py @@ -0,0 +1,130 @@ +from __future__ import print_function +from ._compat import PY2 + +from struct import unpack_from + +from . import err + +DEBUG = False + +if PY2: + def read_uint8(data, offset=0): + """Read 1 byte of data""" + return ord(data[offset]) +else: + def read_uint8(data, offset=0): + """Read 1 byte of data""" + return data[offset] + + +def read_uint16(data, offset=0): + """Read 2 bytes of data beginning at offset""" + return unpack_from('= 7 + + +def is_eof_packet(packet): + # http://dev.mysql.com/doc/internals/en/generic-response-packets.html#packet-EOF_Packet + # Caution: \xFE may be LengthEncodedInteger. + # If \xFE is LengthEncodedInteger header, 8bytes followed. + return read_uint8(packet) == 254 and len(packet) < 9 + + +def is_auth_switch_request(packet): + # http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchRequest + return read_uint8(packet) == 254 + + +def is_load_local(packet): + return read_uint8(packet) == 251 + + +def is_resultset_packet(packet): + return 1 <= read_uint8(packet) <= 250 + + +def check_error(packet): + if read_uint8(packet) == 255: + errno = read_uint16(packet, offset=1) + if DEBUG: print("errno = ", errno) + err.raise_mysql_exception(packet) + + +def read_string(data, offset=0): + end = data.find(b'\0', offset) + if end >= 0: + result = data[offset:end] + # Add one to length to account for the null byte + return len(result) + 1, result + else: + raise ValueError("Invalid read on non-null terminated string") + + +def read_length_encoded_integer(data, offset=0): + col = read_uint8(data, offset=offset) + bytes_read = 1 + + if col == 251: + return bytes_read, None + + # Unsigned char column + if col < 251: + return bytes_read, col + # Unsigned short column + elif col == 252: + return bytes_read + 2, read_uint16(data, offset=offset+bytes_read) + # Unsigned int24 column + elif col == 253: + return bytes_read + 3, read_uint24(data, offset=offset+bytes_read) + # Unsigned int64 column + elif col == 254: + return bytes_read + 8, read_uint64(data, offset=offset+bytes_read) + else: + raise ValueError + + +def read_bytes(data, nbytes, offset=0): + if nbytes is None: + result = data[offset:] + return len(result), result + else: + result = data[offset:offset+nbytes] + if len(result) == nbytes: + return nbytes, result + + error = ('Result length not requested length:\n' + 'Expected=%s Actual=%s Position: %s Data Length: %s' + % (nbytes, len(result), offset, len(data))) + if DEBUG: + print(error) + raise AssertionError(error) + + +def read_length_coded_string(data, offset=0): + bytes_read, length = read_length_encoded_integer(data, offset=offset) + if length is not None: + _br, result = read_bytes(data, length, offset=offset+bytes_read) + return bytes_read + _br, result + else: + # Null column + return bytes_read, None \ No newline at end of file