8000 Add ssl_key_password param (#1145) · PyMySQL/PyMySQL@d206182 · GitHub
[go: up one dir, main page]

Skip to content

Commit d206182

Browse files
authored
Add ssl_key_password param (#1145)
Add support for SSL private key password in Connection class to handle encrypted keys. Co-authored-by: Sergei Vaskov <self@insightoutofspace.com>
1 parent 84d3f93 commit d206182

File tree

2 files changed

+82
-9
lines changed

2 files changed

+82
-9
lines changed

pymysql/connections.py

+8-2
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,7 @@ class Connection:
135135
:param ssl_disabled: A boolean value that disables usage of TLS.
136136
:param ssl_key: Path to the file that contains a PEM-formatted private key for
137137
the client certificate.
138+
:param ssl_key_password: The password for the client certificate private key.
138139
:param ssl_verify_cert: Set to true to check the server certificate's validity.
139140
:param ssl_verify_identity: Set to true to check the server's identity.
140141
:param read_default_group: Group to read from in the configuration file.
@@ -201,6 +202,7 @@ def __init__(
201202
ssl_cert=None,
202203
ssl_disabled=None,
203204
ssl_key=None,
205+
ssl_key_password=None,
204206
ssl_verify_cert=None,
205207
ssl_verify_identity=None,
206208
compress=None, # not supported
@@ -262,7 +264,7 @@ def _config(key, arg):
262264
if not ssl:
263265
ssl = {}
264266
if isinstance(ssl, dict):
265-
for key in ["ca", "capath", "cert", "key", "cipher"]:
267+
for key in ["ca", "capath", "cert", "key", "password", "cipher"]:
266268
value = _config("ssl-" + key, ssl.get(key))
267269
if value:
268270
ssl[key] = value
@@ -281,6 +283,8 @@ def _config(key, arg):
281283
ssl["cert"] = ssl_cert
282284
if ssl_key is not None:
283285
ssl["key"] = ssl_key
286+
if ssl_key_password is not None:
287+
ssl["password"] = ssl_key_password
284288
if ssl:
285289
if not SSL_ENABLED:
286290
raise NotImplementedError("ssl module not found")
@@ -389,7 +393,9 @@ def _create_ssl_ctx(self, sslp):
389393
else:
390394
ctx.verify_mode = ssl.CERT_NONE if hasnoca else ssl.CERT_REQUIRED
391395
if "cert" in sslp:
392-
ctx.load_cert_chain(sslp["cert"], keyfile=sslp.get("key"))
396+
ctx.load_cert_chain(
397+
sslp["cert"], keyfile=sslp.get("key"), password=sslp.get("password")
398+
)
393399
if "cipher" in sslp:
394400
ctx.set_ciphers(sslp["cipher"])
395401
ctx.options |= ssl.OP_NO_SSLv2

pymysql/tests/test_connection.py

+74-7
Original file line numberDiff line numberDiff line change
@@ -574,7 +574,11 @@ def test_ssl_connect(self):
574574
assert create_default_context.called
575575
assert dummy_ssl_context.check_hostname
576576
assert dummy_ssl_context.verify_mode == ssl.CERT_REQUIRED
577-
dummy_ssl_context.load_cert_chain.assert_called_with("cert", keyfile="key")
577+
dummy_ssl_context.load_cert_chain.assert_called_with(
578+
"cert",
579+
keyfile="key",
580+
password=None,
581+
)
578582
dummy_ssl_context.set_ciphers.assert_called_with("cipher")
579583

580584
dummy_ssl_context = mock.Mock(options=0)
@@ -592,7 +596,34 @@ def test_ssl_connect(self):
592596
assert create_default_context.called
593597
assert dummy_ssl_context.check_hostname
594598
assert dummy_ssl_context.verify_mode == ssl.CERT_REQUIRED
595-
dummy_ssl_context.load_cert_chain.assert_called_with("cert", keyfile="key")
599+
dummy_ssl_context.load_cert_chain.assert_called_with(
600+
"cert",
601+
keyfile="key",
602+
password=None,
603+
)
604+
dummy_ssl_context.set_ciphers.assert_not_called
605+
606+
dummy_ssl_context = mock.Mock(options=0)
607+
wi 6D40 th mock.patch("pymysql.connections.Connection.connect"), mock.patch(
608+
"pymysql.connections.ssl.create_default_context",
609+
new=mock.Mock(return_value=dummy_ssl_context),
610+
) as create_default_context:
611+
pymysql.connect(
612+
ssl={
613+
"ca": "ca",
614+
"cert": "cert",
615+
"key": "key",
616+
"password": "password",
617+
},
618+
)
619+
assert create_default_context.called
620+
assert dummy_ssl_context.check_hostname
621+
assert dummy_ssl_context.verify_mode == ssl.CERT_REQUIRED
622+
dummy_ssl_context.load_cert_chain.assert_called_with(
623+
"cert",
624+
keyfile="key",
625+
password="password",
626+
)
596627
dummy_ssl_context.set_ciphers.assert_not_called
597628

598629
dummy_ssl_context = mock.Mock(options=0)
@@ -622,7 +653,11 @@ def test_ssl_connect(self):
622653
assert create_default_context.called
623654
assert not dummy_ssl_context.check_hostname
624655
assert dummy_ssl_context.verify_mode == ssl.CERT_NONE
625-
dummy_ssl_context.load_cert_chain.assert_called_with("cert", keyfile="key")
656+
dummy_ssl_context.load_cert_chain.assert_called_with(
657+
"cert",
658+
keyfile="key",
659+
password=None,
660+
)
626661
dummy_ssl_context.set_ciphers.assert_not_called
627662

628663
for ssl_verify_cert in (True, "1", "yes", "true"):
@@ -640,7 +675,9 @@ def test_ssl_connect(self):
640675
assert not dummy_ssl_context.check_hostname
641676
assert dummy_ssl_context.verify_mode == ssl.CERT_REQUIRED
642677
dummy_ssl_context.load_cert_chain.assert_called_with(
643-
"cert", keyfile="key"
678+
"cert",
679+
keyfile="key",
680+
password=None,
644681
)
645682
dummy_ssl_context.set_ciphers.assert_not_called
646683

@@ -659,7 +696,9 @@ def test_ssl_connect(self):
659696
assert not dummy_ssl_context.check_hostname
660697
assert dummy_ssl_context.verify_mode == ssl.CERT_NONE
661698
dummy_ssl_context.load_cert_chain.assert_called_with(
662-
"cert", keyfile="key"
699+
"cert",
700+
keyfile="key",
701+
password=None,
663702
)
664703
dummy_ssl_context.set_ciphers.assert_not_called
665704

@@ -682,7 +721,9 @@ def test_ssl_connect(self):
682721
ssl.CERT_REQUIRED if ssl_ca is not None else ssl.CERT_NONE
683722
), (ssl_ca, ssl_verify_cert)
684723
dummy_ssl_context.load_cert_chain.assert_called_with(
685-
"cert", keyfile="key"
724+
"cert",
725+
keyfile="key",
726+
password=None,
686727
)
687728
dummy_ssl_context.set_ciphers.assert_not_called
688729

@@ -700,7 +741,33 @@ def test_ssl_connect(self):
700741
assert create_default_context.called
701742
assert dummy_ssl_context.check_hostname
702743
assert dummy_ssl_context.verify_mode == ssl.CERT_NONE
703-
dummy_ssl_context.load_cert_chain.assert_called_with("cert", keyfile="key")
744+
dummy_ssl_context.load_cert_chain.assert_called_with(
745+
"cert",
746+
keyfile="key",
747+
password=None,
748+
)
749+
dummy_ssl_context.set_ciphers.assert_not_called
750+
751+
dummy_ssl_context = mock.Mock(options=0)
752+
with mock.patch("pymysql.connections.Connection.connect"), mock.patch(
753+
"pymysql.connections.ssl.create_default_context",
754+
new=mock.Mock(return_value=dummy_ssl_context),
755+
) as create_default_context:
756+
pymysql.connect(
757+
ssl_ca="ca",
758+
ssl_cert="cert",
759+
ssl_key="key",
760+
ssl_key_password="password",
761+
ssl_verify_identity=True,
762+
)
763+
assert create_default_context.called
764+
assert dummy_ssl_context.check_hostname
765+
assert dummy_ssl_context.verify_mode == ssl.CERT_NONE
766+
dummy_ssl_context.load_cert_chain.assert_called_with(
767+
"cert",
768+
keyfile="key",
769+
password="password",
770+
)
704771
dummy_ssl_context.set_ciphers.assert_not_called
705772

706773
dummy_ssl_context = mock.Mock(options=0)

0 commit comments

Comments
 (0)
0