8000 refactoring before next release by methane · Pull Request #405 · PyMySQL/PyMySQL · GitHub
[go: up one dir, main page]

Skip to content

refactoring before next release #405

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Jan 4, 2016
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
refactor connection's sequence id
  • Loading branch information
methane committed Jan 3, 2016
commit 344b933385d1c45e2777940ff76ff260af3113ee
62 changes: 26 additions & 36 deletions pymysql/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -726,7 +726,6 @@ def __del__(self):
def autocommit(self, value):
self.autocommit_mode = bool(value)
current = self.get_autocommit()
self.next_packet = 1
if value != current:
self._send_autocommit_mode()

Expand Down Expand Up @@ -816,15 +815,13 @@ def query(self, sql, unbuffered=False):
"You may not close previous cursor.")
# if DEBUG:
# print("DEBUG: sending query:", sql)
self.next_packet = 1
if isinstance(sql, text_type) and not (JYTHON or IRONPYTHON):
if PY2:
sql = sql.encode(self.encoding)
else:
sql = sql.encode(self.encoding, 'surrogateescape')
self._execute_command(COMMAND.COM_QUERY, sql)
self._affected_rows = self._read_query_result(unbuffered=unbuffered)
self.next_packet = 1
return self._affected_rows

def next_result(self, unbuffered=False):
Expand Down Expand Up @@ -892,7 +889,7 @@ def connect(self, sock=None):
sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
self.socket = sock
self._rfile = _makefile(sock, 'rb')
self.next_packet = 0
self._next_seq_id = 0

self._get_server_information()
self._request_authentication()
Expand Down Expand Up @@ -933,16 +930,16 @@ def connect(self, sock=None):
# So just reraise it.
raise

def write_packet(self, data):
def write_packet(self, payload):
"""Writes an entire "mysql packet" in its entirety to the network
addings its length and sequence number. Intended for use by plugins
only.
addings its length and sequence number.
"""
data = pack_int24(len(data)) + int2byte(self.next_packet) + data
# Internal note: when you build packet manualy and calls _write_bytes()
# directly, you should set self._next_seq_id properly.
data = pack_int24(len(payload)) + int2byte(self._next_seq_id) + payload
if DEBUG: dump_packet(data)

self._write_bytes(data)
self.next_packet = (self.next_packet + 1) % 256
self._next_seq_id = (self._next_seq_id + 1) % 256

def _read_packet(self, packet_type=MysqlPacket):
"""Read an entire "mysql packet" in its entirety from the network
Expand All @@ -952,8 +949,14 @@ def _read_packet(self, packet_type=MysqlPacket):
while True:
packet_header = self._read_bytes(4)
if DEBUG: dump_packet(packet_header)

btrl, btrh, packet_number = struct.unpack('<HBB', packet_header)
bytes_to_read = btrl + (btrh << 16)
if packet_number != self._next_seq_id:
raise err.InternalError("Packet sequence number wrong - got %d expected %d" %
(packet_number, self._next_seq_id))
self._next_seq_id = (self._next_seq_id + 1) % 256

recv_data = self._read_bytes(bytes_to_read)
if DEBUG: dump_packet(recv_data)
buff += recv_data
Expand All @@ -962,13 +965,7 @@ def _read_packet(self, packet_type=MysqlPacket):
continue
if bytes_to_read < MAX_PACKET_LEN:
break
if packet_number != self.next_packet:
pass
#TODO: check sequence id
#raise err.InternalError("Packet sequence number wrong - got %d expected %d" %
# (packet_number, self.next_packet))

self.next_packet = (packet_number + 1) % 256
packet = packet_type(buff, self.encoding)
packet.check_error()
return packet
Expand Down Expand Up @@ -1027,33 +1024,32 @@ def _execute_command(self, command, sql):
if self._result is not None and self._result.unbuffered_active:
warnings.warn("Previous unbuffered result was left incomplete")
self._result._finish_unbuffered_query()
self._result = None

if isinstance(sql, text_type):
sql = sql.encode(self.encoding)

chunk_size = min(self.max_allowed_packet, len(sql) + 1) # +1 is for command
# +1 is for command
chunk_size = min(self.max_allowed_packet, len(sql) + 1)

# tiny optimization: build first packet manually instead of
# calling self..write_packet()
prelude = struct.pack('<iB', chunk_size, command)
self._write_bytes(prelude + sql[:chunk_size-1])
if DEBUG: dump_packet(prelude + sql)
packet = prelude + sql[:chunk_size-1]
self._write_bytes(packet)
if DEBUG: dump_packet(packet)
self._next_seq_id = 1

self.next_packet = 1
if chunk_size < self.max_allowed_packet:
return

seq_id = 1
sql = sql[chunk_size-1:]
while True:
chunk_size = min(self.max_allowed_packet, len(sql))
prelude = struct.pack('<i', chunk_size)[:3]
data = prelude + int2byte(seq_id%256) + sql[:chunk_size]
self._write_bytes(data)
if DEBUG: dump_packet(data)
self.write_packet(sql[:chunk_size])
sql = sql[chunk_size:]
if not sql and chunk_size < self.max_allowed_packet:
break
seq_id += 1
self.next_packet = seq_id%256

def _request_authentication(self):
# https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::HandshakeResponse
Expand Down Expand Up @@ -1448,9 +1444,8 @@ def send_data(self):
"""Send data packets from the local file to the server"""
if not self.connection.socket:
raise err.InterfaceError("(0, '')")
conn = self.connection

# sequence id is 2 as we already sent a query packet
seq_id = 2
try:
with open(self.filename, 'rb') as open_file:
chunk_size = self.connection.max_allowed_packet
Expand All @@ -1460,14 +1455,9 @@ def send_data(self):
chunk = open_file.read(chunk_size)
if not chunk:
break
packet = struct.pack('<i', len(chunk))[:3] + int2byte(seq_id)
format_str = '!{0}s'.format(len(chunk))
packet += struct.pack(format_str, chunk)
self.connection._write_bytes(packet)
seq_id = (seq_id + 1) % 256
conn.write_packet(chunk)
except IOError:
raise err.OperationalError(1017, "Can't find file '{0}'".format(self.filename))
finally:
# send the empty packet to signify we are done sending data
packet = struct.pack('<i', 0)[:3] + int2byte(seq_id)
self.connection._write_bytes(packet)
conn.write_packet(b'')
0