8000 Add support for SHA256 auth plugin · PyMySQL/PyMySQL@5915482 · GitHub
[go: up one dir, main page]

Skip to content

Commit 5915482

Browse files
elemountmethane
authored andcommitted
Add support for SHA256 auth plugin
1 parent 14e4c25 commit 5915482

File tree

Expand file tree

5 files changed

+107
-18
lines changed

5 files changed

+107
-18
lines changed

pymysql/auth/__init__.py

Whitespace-only changes.
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
from ..constants import CLIENT
2+
from ..err import OperationalError
3+
4+
# Import cryptography for RSA_PKCS1_OAEP_PADDING algorithm
5+
# which is needed when use sha256_password_plugin with no SSL
6+
try:
7+
from cryptography.hazmat.backends import default_backend
8+
from cryptography.hazmat.primitives import serialization, hashes
9+
from cryptography.hazmat.primitives.asymmetric import padding
10+
HAVE_CRYPTOGRAPHY = True
11+
except ImportError:
12+
HAVE_CRYPTOGRAPHY = False
13+
14+
15+
def _xor_password(password, salt):
16+
password_bytes = bytearray(password, 'ascii')
17+
salt_len = len(salt)
18+
for i in range(len(password_bytes)):
19+
password_bytes[i] ^= ord(salt[i % salt_len])
20+
return password_bytes
21+
22+
23+
def _sha256_rsa_crypt(password, salt, public_key):
24+
if not HAVE_CRYPTOGRAPHY:
25+
raise OperationalError("cryptography module not found for sha256_password_plugin")
26+
message = _xor_password(password + b'\0', salt)
27+
rsa_key = serialization.load_pem_public_key(public_key, default_backend())
28+
return rsa_key.encrypt(
29+
message.decode('latin1').encode('latin1'), padding.OAEP(
30+
mgf=padding.MGF1(algorithm=hashes.SHA1()),
31+
algorithm=hashes.SHA1(),
32+
label=None))
33+
34+
35+
class SHA256PasswordPlugin(object):
36+
def __init__(self, con):
37+
self.con = con
38+
39+
def authenticate(self, pkt):
40+
if self.con.ssl and self.con.server_capabilities & CLIENT.SSL:
41+
data = self.con.password.encode('latin1') + b'\0'
42+
else:
43+
if pkt.is_auth_switch_request():
44+
self.con.salt = pkt.read_all()
45+
if self.con.server_public_key == '':
46+
self.con.write_packet(b'\1')
47+
pkt = self.con._read_packet()
48+
if pkt.is_extra_auth_data() and self.con.server_public_key == '':
49+
pkt.read_uint8()
50+
self.con.server_public_key = pkt.read_all()
51+
data = _sha256_rsa_crypt(
52+
self.con.password,
53+
self.con.salt,
54+
self.con.server_public_key)
55+
self.con.write_packet(data)
56+
pkt = self.con._read_packet()
57+
pkt.check_error()
58+
return pkt

pymysql/connections.py

Lines changed: 42 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,12 @@
1616
import traceback
1717
import warnings
1818

19+
<<<<<<< HEAD
1920
from .charset import charset_by_name, charset_by_id
21+
=======
22+
from .auth import sha256_password_plugin as _auth
23+
from .charset import MBLENGTH, charset_by_name, charset_by_id
24+
>>>>>>> 695df63... Add support for SHA256 auth plugin
2025
from .constants import CLIENT, COMMAND, CR, FIELD_TYPE, SERVER_STATUS
2126
from . import converters
2227
from .cursors import Cursor
@@ -43,7 +48,6 @@
4348
# KeyError occurs when there's no entry in OS database for a current user.
4449
DEFAULT_USER = None
4550

46-
4751
DEBUG = False
4852

4953
_py_version = sys.version_info[:2]
@@ -107,7 +111,6 @@ def _scramble(password, message):
107111
result = s.digest()
108112
return _my_crypt(result, stage1)
109113

110-
111114
def _my_crypt(message1, message2):
112115
length = len(message1)
113116
result = b''
@@ -186,6 +189,7 @@ def lenenc_int(i):
186189
else:
187190
raise ValueError("Encoding %x is larger than %x - no representation in LengthEncodedInteger" % (i, (1 << 64)))
188191

192+
189193
class Connection(object):
190194
"""
191195
Representation of a socket with a mysql server.
@@ -240,6 +244,7 @@ class Connection(object):
240244
The class needs an authenticate method taking an authentication packet as
241245
an argument. For the dialog plugin, a prompt(echo, prompt) method can be used
242246
(if no authenticate method) for returning a string from the user. (experimental)
247+
:param server_public_key: SHA256 authenticaiton plugin public key value. (default: '')
243248
:param db: Alias for database. (for compatibility to MySQLdb)
244249
:param passwd: Alias for password. (for compatibility to MySQLdb)
245250
:param binary_prefix: Add _binary prefix on bytes and bytearray. (default: False)
@@ -262,7 +267,7 @@ def __init__(self, host=None, user=None, password="",
262267
autocommit=False, db=None, passwd=None, local_infile=False,
263268
max_allowed_packet=16*1024*1024, defer_connect=False,
264269
auth_plugin_map={}, read_timeout=None, write_timeout=None,
265-
bind_address=None, binary_prefix=False):
270+
bind_address=None, binary_prefix=False, server_public_key=''):
266271
if no_delay is not None:
267272
warnings.warn("no_delay option is deprecated", DeprecationWarning)
268273

@@ -379,6 +384,9 @@ def _config(key, arg):
379384
self.max_allowed_packet = max_allowed_packet
380385
self._auth_plugin_map = auth_plugin_map
381386
self._binary_prefix = binary_prefix
387+
if b"sha256_password" not in self._auth_plugin_map:
388+
self._auth_plugin_map[b"sha256_password"] = _auth.SHA256PasswordPlugin
389+
self.server_public_key = server_public_key
382390
if defer_connect:
383391
self._sock = None
384392
else:
@@ -507,7 +515,7 @@ def select_db(self, db):
507515

508516
def escape(self, obj, mapping=None):
509517
"""Escape whatever value you pass to it.
510-
518+
511519
Non-standard, for internal use; do not use this in your applications.
512520
"""
513521
if isinstance(obj, str_type):
@@ -521,7 +529,7 @@ def escape(self, obj, mapping=None):
521529

522530
def literal(self, obj):
523531
"""Alias for escape()
524-
532+
525533
Non-standard, for internal use; do not use this in your applications.
526534
"""
527535
return self.escape(obj, self.encoders)
@@ -861,6 +869,14 @@ def _request_authentication(self):
861869
authresp = b''
862870
if self._auth_plugin_name in ('', 'mysql_native_password'):
863871
authresp = _scramble(self.password.encode('latin1'), self.salt)
872+
elif self._auth_plugin_name == 'sha256_password':
873+
if self.ssl and self.server_capabilities & CLIENT.SSL:
874+
authresp = self.password.encode('latin1') + b'\0'
875+
else:
876+
if self.password is not None:
877+
authresp = b'\1'
878+
else:
879+
authresp = b'\0'
864880

865881
if self.server_capabilities & CLIENT.PLUGIN_AUTH_LENENC_CLIENT_DATA:
866882
data += lenenc_int(len(authresp)) + authresp
@@ -896,24 +912,20 @@ def _request_authentication(self):
896912
data = _scramble_323(self.password.encode('latin1'), self.salt) + b'\0'
897913
self.write_packet(data)
898914
auth_packet = self._read_packet()
915+
elif auth_packet.is_extra_auth_data():
916+
# https://dev.mysql.com/doc/internals/en/successful-authentication.html
917+
handler = self._get_auth_plugin_handler(self._auth_plugin_name)
918+
handler.authenticate(auth_packet)
899919

900920
def _process_auth(self, plugin_name, auth_packet):
901-
plugin_class = self._auth_plugin_map.get(plugin_name)
902-
if not plugin_class:
903-
plugin_class = self._auth_plugin_map.get(plugin_name.decode('ascii'))
904-
if plugin_class:
921+
handler = self._get_auth_plugin_handler(plugin_name)
922+
if handler != None:
905923
try:
906-
handler = plugin_class(self)
907924
return handler.authenticate(auth_packet)
908925
except AttributeError:
909926
if plugin_name != b'dialog':
910927
raise err.OperationalError(2059, "Authentication plugin '%s'" \
911928
" not loaded: - %r missing authenticate method" % (plugin_name, plugin_class))
912-
except TypeError:
913-
raise err.OperationalError(2059, "Authentication plugin '%s'" \
914-
" not loaded: - %r cannot be constructed with connection object" % (plugin_name, plugin_class))
915-
else:
916-
handler = None
917929
if plugin_name == b"mysql_native_password":
918930
# https://dev.mysql.com/doc/internals/en/secure-password-authentication.html#packet-Authentication::Native41
919931
data = _scramble(self.password.encode('latin1'), auth_packet.read_all())
@@ -958,6 +970,20 @@ def _process_auth(self, plugin_name, auth_packet):
958970
pkt = self._read_packet()
959971
pkt.check_error()
960972
return pkt
973+
974+
def _get_auth_plugin_handler(self, plugin_name):
975+
plugin_class = self._auth_plugin_map.get(plugin_name)
976+
if not plugin_class:
977+
plugin_class = self._auth_plugin_map.get(plugin_name.decode('ascii'))
978+
if plugin_class:
979+
try:
980+
handler = plugin_class(self)
981+
except TypeError:
982+
raise err.OperationalError(2059, "Authentication plugin '%s'" \
983+
" not loaded: - %r cannot be constructed with connection object" % (plugin_name, plugin_class))
984+
else:
985+
handler = None
986+
return handler
961987

962988
# _mysql support
963989
def thread_id(self):
@@ -1232,7 +1258,7 @@ def _get_descriptions(self):
12321258
# This behavior is different from TEXT / BLOB.
12331259
# We should decode result by connection encoding regardless charsetnr.
12341260
# See https://github.com/PyMySQL/PyMySQL/issues/488
1235-
encoding = conn_encoding # SELECT CAST(... AS JSON)
1261+
encoding = conn_encoding # SELECT CAST(... AS JSON)
12361262
elif field_type in TEXT_TYPES:
12371263
if field.charsetnr == 63: # binary
12381264
# TEXTs with charset=binary means BINARY types.

pymysql/protocol.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -196,6 +196,10 @@ def is_auth_switch_request(self):
196196
# http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchRequest
197197
return self._data[0:1] == b'\xfe'
198198

199+
def is_extra_auth_data(self):
200+
# https://dev.mysql.com/doc/internals/en/successful-authentication.html
201+
return self._data[0:1] == b'\x01'
202+
199203
def is_resultset_packet(self):
200204
field_count = ord(self._data[0:1])
201205
return 1 <= field_count <= 250

pymysql/tests/test_connection.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -360,11 +360,12 @@ def testAuthSHA256(self):
360360
else:
361361
c.execute('SET old_passwords = 2')
362362
c.execute("SET PASSWORD FOR 'pymysql_sha256'@'localhost' = PASSWORD('Sh@256Pa33')")
363+
c.execute("FLUSH PRIVILEGES")
363364
db = self.db.copy()
364365
db['password'] = "Sh@256Pa33"
365-
# not implemented yet so thows error
366+
# Although SHA256 is supported, need the configuration of public key of the mysql server. Currently w 597D ill get error by this test.
366367
with self.assertRaises(pymysql.err.OperationalError):
367-
pymysql.connect(user='pymysql_256', **db)
368+
pymysql.connect(user='pymysql_sha256', **db)
368369

369370
class TestConnection(base.PyMySQLTestCase):
370371

0 commit comments

Comments
 (0)
0