8000 Protocol module by groutr · Pull Request #669 · PyMySQL/PyMySQL · GitHub
[go: up one dir, main page]

Skip to content

Protocol module #669

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 10 commits into from
91 changes: 26 additions & 65 deletions pymysql/connections.py
10000
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from .optionfile import Parser
from .util import byte2int, int2byte
from . import err
from . import protocol

try:
import ssl
Expand Down Expand Up @@ -230,24 +231,16 @@ def get_all_data(self):

def read(self, size):
"""Read the first 'size' bytes in packet and advance cursor past them."""
result = self._data[self._position:(self._position+size)]
if len(result) != size:
error = ('Result length not requested length:\n'
'Expected=%s. Actual=%s. Position: %s. Data Length: %s'
% (size, len(result), self._position, len(self._data)))
if DEBUG:
print(error)
self.dump()
raise AssertionError(error)
self._position += size
bytes_read, result = protocol.read_bytes(self._data, size, offset=self._position)
self._position += bytes_read
return result

def read_all(self):
"""Read all remaining data in the packet.

(Subsequent read() will return errors.)
"""
result = self._data[self._position:]
bytes_read, result = protocol.read_bytes(self._data, None, offset=self._position)
self._position = None # ensure no subsequent read()
return result

Expand Down Expand Up @@ -276,43 +269,34 @@ def get_bytes(self, position, length=1):
"""
return self._data[position:(position+length)]

if PY2:
def read_uint8(self):
result = ord(self._data[self._position])
self._position += 1
return result
else:
def read_uint8(self):
result = self._data[self._position]
self._position += 1
return result
def read_uint8(self):
result = protocol.read_uint8(self._data, offset=self._position)
self._position += 1
return result
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't split methods. Just move.


def read_uint16(self):
result = struct.unpack_from('<H', self._data, self._position)[0]
result = protocol.read_uint16(self._data, offset=self._position)
self._position += 2
return result

def read_uint24(self):
low, high = struct.unpack_from('<HB', self._data, self._position)
result = protocol.read_uint24(self._data, offset=self._position)
self._position += 3
return low + (high << 16)
return result

def read_uint32(self):
result = struct.unpack_from('<I', self._data, self._position)[0]
result = protocol.read_uint32(self._data, offset=self._position)
self._position += 4
return result

def read_uint64(self):
result = struct.unpack_from('<Q', self._data, self._position)[0]
result = protocol.read_uint64(self._data, offset=self._position)
self._position += 8
return result

def read_string(self):
end_pos = self._data.find(b'\0', self._position)
if end_pos < 0:
return None
result = self._data[self._position:end_pos]
self._position = end_pos + 1
pos, result = protocol.read_string(self._data, offset=self._position)
self._position += pos
return result

def read_length_encoded_integer(self):
Expand All @@ -321,17 +305,9 @@ def read_length_encoded_integer(self):
Length coded numbers can be anywhere from 1 to 9 bytes depending
on the value of the first byte.
"""
c = self.read_uint8()
if c == NULL_COLUMN:
return None
if c < UNSIGNED_CHAR_COLUMN:
return c
elif c == UNSIGNED_SHORT_COLUMN:
return self.read_uint16()
elif c == UNSIGNED_INT24_COLUMN:
return self.read_uint24()
elif c == UNSIGNED_INT64_COLUMN:
return self.read_uint64()
bytes_read, result = protocol.read_length_encoded_integer(self._data, self._position)
self._position += bytes_read
return result

def read_length_coded_string(self):
"""Read a 'Length Coded String' from the data buffer.
Expand All @@ -340,10 +316,9 @@ def read_length_coded_string(self):
(unsigned, positive) integer represented in 1-9 bytes followed by
that many bytes of binary data. (For example "cat" would be "3cat".)
"""
length = self.read_length_encoded_integer()
if length is None:
return None
return self.read(length)
bytes_read, result = protocol.read_length_coded_string(self._data, offset=self._position)
self._position += bytes_read
return result

def read_struct(self, fmt):
s = struct.Struct(fmt)
Expand All @@ -352,36 +327,22 @@ def read_struct(self, fmt):
return result

def is_ok_packet(self):
# https://dev.mysql.com/doc/internals/en/packet-OK_Packet.html
return self._data[0:1] == b'\0' and len(self._data) >= 7
return protocol.is_ok_packet(self._data)

def is_eof_packet(self):
# http://dev.mysql.com/doc/internals/en/generic-response-packets.html#packet-EOF_Packet
# Caution: \xFE may be LengthEncodedInteger.
# If \xFE is LengthEncodedInteger header, 8bytes followed.
return self._data[0:1] == b'\xfe' and len(self._data) < 9
return protocol.is_eof_packet(self._data)

def is_auth_switch_request(self):
# http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchRequest
return self._data[0:1] == b'\xfe'
return protocol.is_auth_switch_request(self._data)

def is_resultset_packet(self):
field_count = ord(self._data[0:1])
return 1 <= field_count <= 250
return protocol.is_resultset_packet(self._data)

def is_load_local_packet(self):
return self._data[0:1] == b'\xfb'

def is_error_packet(self):
return self._data[0:1] == b'\xff'

def check_error(self):
if self.is_error_packet():
self.rewind()
self.advance(1) # field_count == error (we already know that)
errno = self.read_uint16()
if DEBUG: print("errno =", errno)
err.raise_mysql_exception(self._data)
protocol.check_error(self._data)

def dump(self):
dump_packet(self._data)
Expand Down
130 changes: 130 additions & 0 deletions pymysql/protocol.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
from __future__ import print_function
from ._compat import PY2

from struct import unpack_from

from . import err

DEBUG = False

if PY2:
def read_uint8(data, offset=0):
"""Read 1 byte of data"""
return ord(data[offset])
else:
def read_uint8(data, offset=0):
"""Read 1 byte of data"""
return data[offset]


def read_uint16(data, offset=0):
"""Read 2 bytes of data beginning at offset"""
return unpack_from('<H', data, offset=offset)[0]


def read_uint24(data, offset=0):
"""Read 3 bytes of data beginning at offset"""
low, high = unpack_from('<HB', data, offset=offset)
return low + (high << 16)


def read_uint32(data, offset=0):
"""Read 4 bytes of data beginning at offset"""
return unpack_from('<I', data, offset=offset)[0]


def read_uint64(data, offset=0):
"""Read 8 bytes of data beginning at offset"""
return unpack_from('<Q', data, offset=offset)[0]


def is_ok_packet(packet):
# https://dev.mysql.com/doc/internals/en/packet-OK_Packet.html
return read_uint8(packet) == 0 and len(packet) >= 7


def is_eof_packet(packet):
# http://dev.mysql.com/doc/internals/en/generic-response-packets.html#packet-EOF_Packet
# Caution: \xFE may be LengthEncodedInteger.
# If \xFE is LengthEncodedInteger header, 8bytes followed.
return read_uint8(packet) == 254 and len(packet) < 9


def is_auth_switch_request(packet):
# http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchRequest
return read_uint8(packet) == 254


def is_load_local(packet):
return read_uint8(packet) == 251


def is_resultset_packet(packet):
return 1 <= read_uint8(packet) <= 250


def check_error(packet):
if read_uint8(packet) == 255:
errno = read_uint16(packet, offset=1)
if DEBUG: print("errno = ", errno)
err.raise_mysql_exception(packet)


def read_string(data, offset=0):
end = data.find(b'\0', offset)
if end >= 0:
result = data[offset:end]
# Add one to length to account for the null byte
return len(result) + 1, result
else:
raise ValueError("Invalid read on non-null terminated string")


def read_length_encoded_integer(data, offset=0):
col = read_uint8(data, offset=offset)
bytes_read = 1

if col == 251:
return bytes_read, None

# Unsigned char column
if col < 251:
return bytes_read, col
# Unsigned short column
elif col == 252:
return bytes_read + 2, read_uint16(data, offset=offset+bytes_read)
# Unsigned int24 column
elif col == 253:
return bytes_read + 3, read_uint24(data, offset=offset+bytes_read)
# Unsigned int64 column
elif col == 254:
return bytes_read + 8, read_uint64(data, offset=offset+bytes_read)
else:
raise ValueError


def read_bytes(data, nbytes, offset=0):
if nbytes is None:
result = data[offset:]
return len(result), result
else:
result = data[offset:offset+nbytes]
if len(result) == nbytes:
return nbytes, result

error = ('Result length not requested length:\n'
'Expected=%s Actual=%s Position: %s Data Length: %s'
% (nbytes, len(result), offset, len(data)))
if DEBUG:
print(error)
raise AssertionError(error)


def read_length_coded_string(data, offset=0):
bytes_read, length = read_length_encoded_integer(data, offset=offset)
if length is not None:
_br, result = read_bytes(data, length, offset=offset+bytes_read)
return bytes_read + _br, result
else:
# Null column
return bytes_read, None
0