8000 fix auth_switch_request handling (#1200) · PyMySQL/PyMySQL@01af30f · GitHub
[go: up one dir, main page]

Skip to content

Commit 01af30f

Browse files
authored
fix auth_switch_request handling (#1200)
1 parent 53efd1e commit 01af30f

File tree

4 files changed

+38
-3
lines changed

4 files changed

+38
-3
lines changed

.coveragerc

+1
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
branch = True
33
source =
44
pymysql
5+
tests
56
omit = pymysql/tests/*
67
pymysql/tests/thirdparty/test_MySQLdb/*
78

pymysql/_auth.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,8 @@ def sha256_password_auth(conn, pkt):
166166

167167
if pkt.is_auth_switch_request():
168168
conn.salt = pkt.read_all()
169+
if conn.salt.endswith(b"\0"):
170+
conn.salt = conn.salt[:-1]
169171
if not conn.server_public_key and conn.password:
170172
# Request server public key
171173
if DEBUG:
@@ -215,9 +217,11 @@ def caching_sha2_password_auth(conn, pkt):
215217

216218
if pkt.is_auth_switch_request():
217219
# Try from fast auth
218-
if DEBUG:
219-
print("caching sha2: Trying fast path")
220220
conn.salt = pkt.read_all()
221+
if conn.salt.endswith(b"\0"): # str.removesuffix is available in 3.9
222+
conn.salt = conn.salt[:-1]
223+
if DEBUG:
224+
print(f"caching sha2: Trying fast path. salt={conn.salt.hex()!r}")
221225
scrambled = scramble_caching_sha2(conn.password, conn.salt)
222226
pkt = _roundtrip(conn, scrambled)
223227
# else: fast auth is tried in initial handshake

pymysql/connections.py

+4
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
DEFAULT_USER = None
4848

4949
DEBUG = False
50+
_DEFAULT_AUTH_PLUGIN = None # if this is not None, use it instead of server's default.
5051

5152
TEXT_TYPES = {
5253
FIELD_TYPE.BIT,
@@ -1158,6 +1159,9 @@ def _get_server_information(self):
11581159
else:
11591160
self._auth_plugin_name = data[i:server_end].decode("utf-8")
11601161

1162+
if _DEFAULT_AUTH_PLUGIN is not None: # for tests
1163+
self._auth_plugin_name = _DEFAULT_AUTH_PLUGIN
1164+
11611165
def get_server_info(self):
11621166
return self.server_version
11631167

tests/test_auth.py

+27-1
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,19 @@ def test_caching_sha2_password():
7171
con.query("FLUSH PRIVILEGES")
7272
con.close()
7373

74+
# Fast path after auth_switch_request
75+
pymysql.connections._DEFAULT_AUTH_PLUGIN = "mysql_native_password"
76+
con = pymysql.connect(
77+
user="user_caching_sha2",
78+
password=pass_caching_sha2,
79+
host=host,
80+
port=port,
81+
ssl=ssl,
82+
)
83+
con.query("FLUSH PRIVILEGES")
84+
con.close()
85+
pymysql.connections._DEFAULT_AUTH_PLUGIN = None
86+
7487

7588
def test_caching_sha2_password_ssl():
7689
con = pymysql.connect(
@@ -88,7 +101,20 @@ def test_caching_sha2_password_ssl():
88101
password=pass_caching_sha2,
89102
host=host,
90103
port=port,
91-
ssl=None,
104+
ssl=ssl,
105+
)
106+
con.query("FLUSH PRIVILEGES")
107+
con.close()
108+
109+
# Fast path after auth_switch_request
110+
pymysql.connections._DEFAULT_AUTH_PLUGIN = "mysql_native_password"
111+
con = pymysql.connect(
112+
user="user_caching_sha2",
113+
password=pass_caching_sha2,
114+
host=host,
115+
port=port,
116+
ssl=ssl,
92117
)
93118
con.query("FLUSH PRIVILEGES")
94119
con.close()
120+
pymysql.connections._DEFAULT_AUTH_PLUGIN = None

0 commit comments

Comments
 (0)
0