8000 refactoring before next release by methane · Pull Request #405 · PyMySQL/PyMySQL · GitHub
[go: up one dir, main page]

Skip to content

refactoring before next release #405

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Jan 4, 2016
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions .travis.yml
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
sudo: false
language: python
python: "3.4"
python: "3.5"
cache:
pip: true
directories:
- $HOME/.cache/pip

env:
matrix:
- TOX_ENV=py26
- TOX_ENV=py27
- TOX_ENV=py33
- TOX_ENV=py34
- TOX_ENV=py35
- TOX_ENV=pypy
- TOX_ENV=pypy3

Expand All @@ -36,7 +38,7 @@ matrix:
sudo: required
- env:
- TOX_ENV=py34
- DB=5.6.26
- DB=5.6.28
addons:
apt:
packages:
Expand Down
123 changes: 57 additions & 66 deletions pymysql/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -525,6 +525,7 @@ class Connection(object):
"""

socket = None
_auth_plugin_name = ''

def __init__(self, host=None, user=None, password="",
database=None, port=3306, unix_socket=None,
Expand All @@ -535,7 +536,7 @@ def __init__(self, host=None, user=None, password="",
compress=None, named_pipe=None, no_delay=None,
autocommit=False, db=None, passwd=None, local_infile=False,
max_allowed_packet=16*1024*1024, defer_connect=False,
plugin_map={}):
auth_plugin_map={}):
"""
Establish a connection to the MySQL database. Accepts several
arguments:
Expand Down Expand Up @@ -571,12 +572,11 @@ def __init__(self, host=None, user=None, password="",
max_allowed_packet: Max size of packet sent to server in bytes. (default: 16MB)
defer_connect: Don't explicitly connect on contruction - wait for connect call.
(default: False)
plugin_map: Map of plugin names to a class that processes that plugin. The class
will take the Connection object as the argument to the constructor. The class
needs an authenticate method taking an authentication packet as an argument.
For the dialog plugin, a prompt(echo, prompt) method can be used (if no
authenticate method) for returning a string from the user.

auth_plugin_map: A dict of plugin names to a class that processes that plugin.
The class will take the Connection object as the argument to the constructor.
The class needs an authenticate method taking an authentication packet as
an argument. For the dialog plugin, a prompt(echo, prompt) method can be used
(if no authenticate method) for returning a string from the user. (experimental)
db: Alias for database. (for compatibility to MySQLdb)
passwd: Alias for password. (for compatibility to MySQLdb)
"""
Expand Down Expand Up @@ -672,7 +672,7 @@ def _config(key, arg):
self.sql_mode = sql_mode
self.init_command = init_command
self.max_allowed_packet = max_allowed_packet
self.plugin_map = plugin_map
self._auth_plugin_map = auth_plugin_map
if defer_connect:
self.socket = None
else:
Expand Down Expand Up @@ -726,7 +726,6 @@ def __del__(self):
def autocommit(self, value):
self.autocommit_mode = bool(value)
current = self.get_autocommit()
self.next_packet = 1
if value != current:
self._send_autocommit_mode()

Expand Down Expand Up @@ -816,15 +815,13 @@ def query(self, sql, unbuffered=False):
"You may not close previous cursor.")
# if DEBUG:
# print("DEBUG: sending query:", sql)
self.next_packet = 1
if isinstance(sql, text_type) and not (JYTHON or IRONPYTHON):
if PY2:
sql = sql.encode(self.encoding)
else:
sql = sql.encode(self.encoding, 'surrogateescape')
self._execute_command(COMMAND.COM_QUERY, sql)
self._affected_rows = self._read_query_result(unbuffered=unbuffered)
self.next_packet = 1
return self._affected_rows

def next_result(self, unbuffered=False):
Expand Down Expand Up @@ -892,7 +889,7 @@ def connect(self, sock=None):
sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
self.socket = sock
self._rfile = _makefile(sock, 'rb')
self.next_packet = 0
self._next_seq_id = 0

self._get_server_information()
self._request_authentication()
Expand Down Expand Up @@ -933,16 +930,16 @@ def connect(self, sock=None):
# So just reraise it.
raise

def write_packet(self, data):
def write_packet(self, payload):
"""Writes an entire "mysql packet" in its entirety to the network
addings its length and sequence number. Intended for use by plugins
only.
addings its length and sequence number.
"""
data = pack_int24(len(data)) + int2byte(self.next_packet) + data
# Internal note: when you build packet manualy and calls _write_bytes()
# directly, you should set self._next_seq_id properly.
data = pack_int24(len(payload)) + int2byte(self._next_seq_id) + payload
if DEBUG: dump_packet(data)

self._write_bytes(data)
self.next_packet = (self.next_packet + 1) % 256
self._next_seq_id = (self._next_seq_id + 1) % 256

def _read_packet(self, packet_type=MysqlPacket):
"""Read an entire "mysql packet" in its entirety from the network
Expand All @@ -952,8 +949,14 @@ def _read_packet(self, packet_type=MysqlPacket):
while True:
packet_header = self._read_bytes(4)
if DEBUG: dump_packet(packet_header)

btrl, btrh, packet_number = struct.unpack('<HBB', packet_header)
bytes_to_read = btrl + (btrh << 16)
if packet_number != self._next_seq_id:
raise err.InternalError("Packet sequence number wrong - got %d expected %d" %
(packet_number, self._next_seq_id))
self._next_seq_id = (self._next_seq_id + 1) % 256

recv_data = self._read_bytes(bytes_to_read)
if DEBUG: dump_packet(recv_data)
buff += recv_data
Expand All @@ -962,13 +965,7 @@ def _read_packet(self, packet_type=MysqlPacket):
continue
if bytes_to_read < MAX_PACKET_LEN:
break
if packet_number != self.next_packet:
pass
#TODO: check sequence id
#raise err.InternalError("Packet sequence number wrong - got %d expected %d" %
# (packet_number, self.next_packet))

self.next_packet = (packet_number + 1) % 256
packet = packet_type(buff, self.encoding)
packet.check_error()
return packet
Expand Down Expand Up @@ -1027,33 +1024,32 @@ def _execute_command(self, command, sql):
if self._result is not None and self._result.unbuffered_active:
warnings.warn("Previous unbuffered result was left incomplete")
self._result._finish_unbuffered_query()
self._result = None

if isinstance(sql, text_type):
sql = sql.encode(self.encoding)

chunk_size = min(self.max_allowed_packet, len(sql) + 1) # +1 is for command
# +1 is for command
chunk_size = min(self.max_allowed_packet, len(sql) + 1)

# tiny optimization: build first packet manually instead of
# calling self..write_packet()
prelude = struct.pack('<iB', chunk_size, command)
self._write_bytes(prelude + sql[:chunk_size-1])
if DEBUG: dump_packet(prelude + sql)
packet = prelude + sql[:chunk_size-1]
self._write_bytes(packet)
if DEBUG: dump_packet(packet)
self._next_seq_id = 1

self.next_packet = 1
if chunk_size < self.max_allowed_packet:
return

seq_id = 1
sql = sql[chunk_size-1:]
while True:
chunk_size = min(self.max_allowed_packet, len(sql))
prelude = struct.pack('<i', chunk_size)[:3]
data = prelude + int2byte(seq_id%256) + sql[:chunk_size]
self._write_bytes(data)
if DEBUG: dump_packet(data)
self.write_packet(sql[:chunk_size])
sql = sql[chunk_size:]
if not sql and chunk_size < self.max_allowed_packet:
break
seq_id += 1
self.next_packet = seq_id%256

def _request_authentication(self):
# https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::HandshakeResponse
Expand All @@ -1078,18 +1074,15 @@ def _request_authentication(self):

data = data_init + self.user + b'\0'

authresp = ''
if self.plugin_name == 'mysql_native_password':
authresp = b''
if self._auth_plugin_name == 'mysql_native_password':
authresp = _scramble(self.password.encode('latin1'), self.salt)

if self.server_capabilities & CLIENT.PLUGIN_AUTH_LENENC_CLIENT_DATA:
data += lenenc_int(len(authresp))
data += authresp
data += lenenc_int(len(authresp)) + authresp
elif self.server_capabilities & CLIENT.SECURE_CONNECTION:
length = len(authresp)
data += struct.pack('B', length)
data += authresp
else: # pragma: no cover - not testing against servers without secure auth (>=5.0)
data += struct.pack('B', len(authresp)) + authresp
else: # pragma: no cover - not testing against servers without secure auth (>=5.0)
data += authresp + b'\0'

if self.db and self.server_capabilities & CLIENT.CONNECT_WITH_DB:
Expand All @@ -1098,15 +1091,16 @@ def _request_authentication(self):
data += self.db + b'\0'

if self.server_capabilities & CLIENT.PLUGIN_AUTH:
data += self.plugin_name.encode('latin1') + b'\0'
name = self._auth_plugin_name
if isinstance(name, text_type):
name = name.encode('ascii')
data += name + b'\0'

self.write_packet(data)

auth_packet = self._read_packet()

# if authentication method isn't accepted the first byte
# will have the octet 254

if auth_packet.is_auth_switch_request():
# https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchRequest
auth_packet.read_uint8() # 0xfe packet identifier
Expand All @@ -1119,8 +1113,13 @@ def _request_authentication(self):
self.write_packet(data)
auth_packet = self._read_packet()

#TODO: ok packet or error packet?


def _process_auth(self, plugin_name, auth_packet):
plugin_class = self.plugin_map.get(plugin_name)
plugin_class = self._auth_plugin_map.get(plugin_name)
if not plugin_class:
plugin_class = self._auth_plugin_map.get(plugin_name.decode('ascii'))
if plugin_class:
try:
handler = plugin_class(self)
Expand Down Expand Up @@ -1246,18 +1245,13 @@ def _get_server_information(self):
server_end = data.find(b'\0', i)
if server_end < 0: # pragma: no cover - very specific upstream bug
# not found \0 and last field so take it all
self.plugin_name = data[i:].decode('latin1')
self._auth_plugin_name = data[i:].decode('latin1')
else:
self.plugin_name = data[i:server_end].decode('latin1')
else: # pragma: no cover - not testing against any plugin uncapable servers
self.plugin_name = ''
self._auth_plugin_name = data[i:server_end].decode('latin1')

def get_server_info(self):
return self.server_version

def get_plugin_name(self):
return self.plugin_name

Warning = err.Warning
Error = err.Error
InterfaceError = err.InterfaceError
Expand Down Expand Up @@ -1331,7 +1325,11 @@ def _read_ok_packet(self, first_packet):
def _read_load_local_packet(self, first_packet):
load_packet = LoadLocalPacketWrapper(first_packet)
sender = LoadLocalFile(load_packet.filename, self.connection)
sender.send_data()
try:
sender.send_data()
except:
self.connection._read_packet() # skip ok packet
raise

ok_packet = self.connection._read_packet()
if not ok_packet.is_ok_packet(): # pragma: no cover - upstream induced protocol error
Expand Down Expand Up @@ -1448,27 +1446,20 @@ def send_data(self):
"""Send data packets from the local file to the server"""
if not self.connection.socket:
raise err.InterfaceError("(0, '')")
conn = self.connection

# sequence id is 2 as we already sent a query packet
seq_id = 2
try:
with open(self.filename, 'rb') as open_file:
chunk_size = self.connection.max_allowed_packet
chunk_size = conn.max_allowed_packet
packet = b""

while True:
chunk = open_file.read(chunk_size)
if not chunk:
break
packet = struct.pack('<i', len(chunk))[:3] + int2byte(seq_id)
format_str = '!{0}s'.format(len(chunk))
packet += struct.pack(format_str, chunk)
self.connection._write_bytes(packet)
seq_id = (seq_id + 1) % 256
conn.write_packet(chunk)
except IOError:
raise err.OperationalError(1017, "Can't find file '{0}'".format(self.filename))
finally:
# send the empty packet to signify we are done sending data
packet = struct.pack('<i', 0)[:3] + int2byte(seq_id)
self.connection._write_bytes(packet)
self.next_packet = (seq_id + 1) % 256
conn.write_packet(b'')
Loading
0