8000 Add sha256 and chaching_sha2 auth support (#682) · jiangsanyin/PyMySQL@83a8c92 · GitHub
[go: up one dir, main page]

Skip to content

Commit 83a8c92

Browse files
authored
Add sha256 and chaching_sha2 auth support (PyMySQL#682)
1 parent 935afc8 commit 83a8c92

File tree

11 files changed

+449
-139
lines changed

11 files changed

+449
-139
lines changed

.gitignore

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
11
*.pyc
22
*.pyo
3-
__pycache__
4-
.coverage
5-
/dist
6-
/PyMySQL.egg-info
3+
/.cache
4+
/.coverage
5+
/.idea
76
/.tox
7+
/.venv
8+
/.vscode
9+
/PyMySQL.egg-info
810
/build
11+
/dist
12+
/docs/build
913
/pymysql/tests/databases.json
10-
11-
/.idea
12-
docs/build
14+
__pycache__

.travis.yml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,13 +35,14 @@ matrix:
3535
python: "3.4"
3636
- env:
3737
- DB=mysql:8.0
38+
- TEST_AUTH=yes
3839
python: "3.7-dev"
3940

4041
# different py version from 5.6 and 5.7 as cache seems to be based on py version
4142
# http://dev.mysql.com/downloads/mysql/5.7.html has latest development release version
4243
# really only need libaio1 for DB builds however libaio-dev is whitelisted for container builds and liaio1 isn't
4344
install:
44-
- pip install -U coveralls unittest2 coverage
45+
- pip install -U coveralls unittest2 coverage cryptography pytest
4546

4647
before_script:
4748
- ./.travis/initializedb.sh
@@ -51,6 +52,9 @@ before_script:
5152

5253
script:
5354
- coverage run ./runtests.py
55+
- if [ "${TEST_AUTH}" = "yes" ];
56+
then pytest -v tests;
57+
fi
5458
- if [ ! -z "${DB}" ];
5559
then docker logs mysqld;
5660
fi

.travis/initializedb.sh

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,16 @@ if [ ! -z "${DB}" ]; then
3737
docker cp mysqld:/var/lib/mysql/server-cert.pem "${HOME}"
3838
docker cp mysqld:/var/lib/mysql/client-key.pem "${HOME}"
3939
docker cp mysqld:/var/lib/mysql/client-cert.pem "${HOME}"
40+
41+
# Test user for auth test
42+
mysql -e '
43+
CREATE USER
44+
user_sha256 IDENTIFIED WITH "sha256_password" BY "pass_sha256",
45+
nopass_sha256 IDENTIFIED WITH "sha256_password",
46+
user_caching_sha2 IDENTIFIED WITH "caching_sha2_password" BY "pass_caching_sha2",
47+
nopass_caching_sha2 IDENTIFIED WITH "caching_sha2_password"
48+
PASSWORD EXPIRE NEVER;'
49+
mysql -e 'GRANT RELOAD ON *.* TO user_caching_sha2;'
4050
else
4151
WITH_PLUGIN=''
4252
fi

pymysql/_auth.py

Lines changed: 252 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,252 @@
1+
"""
2+
Implements auth methods
3+
"""
4+
from ._compat import text_type
5+
from .constants import CLIENT
6+
from .err import OperationalError
7+
8+
from cryptography.hazmat.backends import default_backend
9+
from cryptography.hazmat.primitives import serialization, hashes
10+
from cryptography.hazmat.primitives.asymmetric import padding
11+
12+
from functools import partial
13+
import hashlib
14+
import struct
15+
16+
17+
DEBUG = True
18+
SCRAMBLE_LENGTH = 20
19+
sha1_new = partial(hashlib.new, 'sha1')
20+
21+
22+
# mysql_native_password
23+
# https://dev.mysql.com/doc/internals/en/secure-password-authentication.html#packet-Authentication::Native41
24+
25+
26+
def scramble_native_password(password, message):
27+
"""Scramble used for mysql_native_password"""
28+
if not password:
29+
return b''
30+
31+
stage1 = sha1_new(password).digest()
32+
stage2 = sha1_new(stage1).digest()
33+
s = sha1_new()
34+
s.update(message[:SCRAMBLE_LENGTH])
35+
s.update(stage2)
36+
result = s.digest()
37+
return _my_crypt(result, stage1)
38+
39+
40+
def _my_crypt(message1, message2):
41+
length = len(message1)
42+
result = b''
43+
for i in range(length):
44+
x = (
45+
struct.unpack('B', message1[i:i + 1])[0] ^
46+
struct.unpack('B', message2[i:i + 1])[0]
47+
)
48+
result += struct.pack('B', x)
49+
return result
50+
51+
52+
# old_passwords support ported from libmysql/password.c
53+
# https://dev.mysql.com/doc/internals/en/old-password-authentication.html
54+
55+
SCRAMBLE_LENGTH_323 = 8
56+
57+
58+
class RandStruct_323(object):
59+
60+
def __init__(self, seed1, seed2):
61+
self.max_value = 0x3FFFFFFF
62+
self.seed1 = seed1 % self.max_value
63+
self.seed2 = seed2 % self.max_value
64+
65+
def my_rnd(self):
66+
self.seed1 = (self.seed1 * 3 + self.seed2) % self.max_value
67+
self.seed2 = (self.seed1 + self.seed2 + 33) % self.max_value
68+
return float(self.seed1) / float(self.max_value)
69+
70+
71+
def scramble_old_password(password, message):
72+
"""Scramble for old_password"""
73+
hash_pass = _hash_password_323(password)
74+
hash_message = _hash_password_323(message[:SCRAMBLE_LENGTH_323])
75+
hash_pass_n = struct.unpack(">LL", hash_pass)
76+
hash_message_n = struct.unpack(">LL", hash_message)
77+
78+
rand_st = RandStruct_323(
79+
hash_pass_n[0] ^ hash_message_n[0], hash_pass_n[1] ^ hash_message_n[1]
80+
)
81+
outbuf = io.BytesIO()
82+
for _ in range(min(SCRAMBLE_LENGTH_323, len(message))):
83+
outbuf.write(int2byte(int(rand_st.my_rnd() * 31) + 64))
84+
extra = int2byte(int(rand_st.my_rnd() * 31))
85+
out = outbuf.getvalue()
86+
outbuf = io.BytesIO()
87+
for c in out:
88+
outbuf.write(int2byte(byte2int(c) ^ byte2int(extra)))
89+
return outbuf.getvalue()
90+
91+
92+
def _hash_password_323(password):
93+
nr = 1345345333
94+
add = 7
95+
nr2 = 0x12345671
96+
97+
# x in py3 is numbers, p27 is chars
98+
for c in [byte2int(x) for x in password if x not in (' ', '\t', 32, 9)]:
99+
nr ^= (((nr & 63) + add) * c) + (nr << 8) & 0xFFFFFFFF
100+
nr2 = (nr2 + ((nr2 << 8) ^ nr)) & 0xFFFFFFFF
101+
add = (add + c) & 0xFFFFFFFF
102+
103+
r1 = nr & ((1 << 31) - 1) # kill sign bits
104+
r2 = nr2 & ((1 << 31) - 1)
105+
return struct.pack(">LL", r1, r2)
106+
107+
108+
# sha256_password
109+
110+
111+
def _roundtrip(conn, send_data):
112+
conn.write_packet(send_data)
113+
pkt = conn._read_packet()
114+
pkt.check_error()
115+
return pkt
116+
117+
118+
def _xor_password(password, salt):
119+
password_bytes = bytearray(password)
120+
salt = bytearray(salt) # for PY2 compat.
121+
salt_len = len(salt)
122+
for i in range(len(password_bytes)):
123+
password_bytes[i] ^= salt[i % salt_len]
124+
return bytes(password_bytes)
125+
126+
127+
def sha2_rsa_encrypt(password, salt, public_key):
128+
"""Encrypt password with salt and public_key.
129+
130+
Used for sha256_password and caching_sha2_password.
131+
"""
132+
message = _xor_password(password + b'\0', salt)
133+
rsa_key = serialization.load_pem_public_key(public_key, default_backend())
134+
return rsa_key.encrypt(
135+
message,
136+
padding.OAEP(
137+
mgf=padding.MGF1(algorithm=hashes.SHA1()),
138+
algorithm=hashes.SHA1(),
139+
label=None,
140+
),
141+
)
142+
143+
144+
def sha256_password_auth(conn, pkt):
145+
if conn.ssl and conn.server_capabilities & CLIENT.SSL:
146+
if DEBUG:
147+
print("sha256: Sending plain password")
148+
data = conn.password + b'\0'
149+
return _roundtrip(conn, data)
150+
151+
if pkt.is_auth_switch_request():
152+
conn.salt = pkt.read_all()
153+
if not conn.server_public_key and conn.password:
154+
# Request server public key
155+
if DEBUG:
156+
print("sha256: Requesting server public key")
157+
pkt = _roundtrip(conn, b'\1')
158+
159+
if pkt.is_extra_auth_data():
160+
conn.server_public_key = pkt._data[1:]
161+
if DEBUG:
162+
print("Received public key:\n", conn.server_public_key.decode('ascii'))
163+
164+
if conn.password:
165+
if not conn.server_public_key:
166+
raise OperationalError("Couldn't receive server's public key")
167+
168+
data = sha2_rsa_encrypt(conn.password, conn.salt, conn.server_public_key)
169+
else:
170+
data = b''
171+
172+
return _roundtrip(conn, data)
173+
174+
175+
def scramble_caching_sha2(password, nonce):
176+
# (bytes, bytes) -> bytes
177+
"""Scramble algorithm used in cached_sha2_password fast path.
178+
179+
XOR(SHA256(password), SHA256(SHA256(SHA256(password)), nonce))
180+
"""
181+
if not password:
182+
return b''
183+
184+
p1 = hashlib.sha256(password).digest()
185+
p2 = hashlib.sha256(p1).digest()
186+
p3 = hashlib.sha256(p2 + nonce).digest()
187+
188+
res = bytearray(p1)
189+
for i in range(len(p3)):
190+
res[i] ^= p3[i]
191+
192+
return bytes(res)
193+
194+
195+
def caching_sha2_password_auth(conn, pkt):
196+
# No password fast path
197+
if not conn.password:
198+
return _roundtrip(conn, b'')
199+
200+
if pkt.is_auth_switch_request():
201+
# Try from fast auth
202+
if DEBUG:
203+
print("caching sha2: Trying fast path")
204+
conn.salt = pkt.read_all()
205+
scrambled = scramble_caching_sha2(conn.password, conn.salt)
206+
pkt = _roundtrip(conn, scrambled)
207+
# else: fast auth is tried in initial handshake
208+
209+
if not pkt.is_extra_auth_data():
210+
raise OperationalError(
211+
"caching sha2: Unknown packet for fast auth: %s" % pkt._data[:1]
212+
)
213+
214+
# magic numbers:
215+
# 2 - request public key
216+
# 3 - fast auth succeeded
217+
# 4 - need full auth
218+
219+
pkt.advance(1)
220+
n = pkt.read_uint8()
221+
222+
if n == 3:
223+
if DEBUG:
224+
print("caching sha2: succeeded by fast path.")
225+
pkt = conn._read_packet()
226+
pkt.check_error() # pkt must be OK packet
227+
return pkt
228+
229+
if n != 4:
230+
raise OperationalError("caching sha2: Unknwon result for fast auth: %s" % n)
231+
232+
if DEBUG:
233+
print("caching sha2: Trying full auth...")
234+
235+
if conn.ssl and conn.server_capabilities & CLIENT.SSL:
236+
if DEBUG:
237+
print("caching sha2: Sending plain password via SSL")
238+
return _roundtrip(conn, conn.password + b'\0')
239+
240+
if not conn.server_public_key:
241+
pkt = _roundtrip(conn, b'\x02') # Request public key
242+
if not pkt.is_extra_auth_data():
243+
raise OperationalError(
244+
"caching sha2: Unknown packet for public key: %s" % pkt._data[:1]
245+
)
246+
247+
conn.server_public_key = pkt._data[1:]
248+
if DEBUG:
249+
print(conn.server_public_key.decode('ascii'))
250+
251+
data = sha2_rsa_encrypt(conn.password, conn.salt, conn.server_public_key)
252+
pkt = _roundtrip(conn, data)

0 commit comments

Comments
 (0)
0