8000 Added ssl_key_password param by svaskov · Pull Request #1145 · PyMySQL/PyMySQL · GitHub
[go: up one dir, main page]

Skip to content

Added ssl_key_password param #1145

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Nov 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions pymysql/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,7 @@ class Connection:
:param ssl_disabled: A boolean value that disables usage of TLS.
:param ssl_key: Path to the file that contains a PEM-formatted private key for
the client certificate.
:param ssl_key_password: The password for the client certificate private key.
:param ssl_verify_cert: Set to true to check the server certificate's validity.
:param ssl_verify_identity: Set to true to check the server's identity.
:param read_default_group: Group to read from in the configuration file.
Expand Down Expand Up @@ -201,6 +202,7 @@ def __init__(
ssl_cert=None,
ssl_disabled=None,
ssl_key=None,
ssl_key_password=None,
ssl_verify_cert=None,
ssl_verify_identity=None,
compress=None, # not supported
Expand Down Expand Up @@ -262,7 +264,7 @@ def _config(key, arg):
if not ssl:
ssl = {}
if isinstance(ssl, dict):
for key in ["ca", "capath", "cert", "key", "cipher"]:
for key in ["ca", "capath", "cert", "key", "password", "cipher"]:
value = _config("ssl-" + key, ssl.get(key))
if value:
ssl[key] = value
Expand All @@ -281,6 +283,8 @@ def _config(key, arg):
ssl["cert"] = ssl_cert
if ssl_key is not None:
ssl["key"] = ssl_key
if ssl_key_password is not None:
ssl["password"] = ssl_key_password
if ssl:
if not SSL_ENABLED:
raise NotImplementedError("ssl module not found")
Expand Down Expand Up @@ -389,7 +393,9 @@ def _create_ssl_ctx(self, sslp):
else:
ctx.verify_mode = ssl.CERT_NONE if hasnoca else ssl.CERT_REQUIRED
if "cert" in sslp:
ctx.load_cert_chain(sslp["cert"], keyfile=sslp.get("key"))
ctx.load_cert_chain(
sslp["cert"], keyfile=sslp.get("key"), password=sslp.get("password")
)
if "cipher" in sslp:
ctx.set_ciphers(sslp["cipher"])
ctx.options |= ssl.OP_NO_SSLv2
Expand Down
81 changes: 74 additions & 7 deletions pymysql/tests/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -574,7 +574,11 @@ def test_ssl_connect(self):
assert create_default_context.called
assert dummy_ssl_context.check_hostname
assert dummy_ssl_context.verify_mode == ssl.CERT_REQUIRED
dummy_ssl_context.load_cert_chain.assert_called_with("cert", keyfile="key")
dummy_ssl_context.load_cert_chain.assert_called_with(
10000 "cert",
keyfile="key",
password=None,
)
dummy_ssl_context.set_ciphers.assert_called_with("cipher")

dummy_ssl_context = mock.Mock(options=0)
Expand All @@ -592,7 +596,34 @@ def test_ssl_connect(self):
assert create_default_context.called
assert dummy_ssl_context.check_hostname
assert dummy_ssl_context.verify_mode == ssl.CERT_REQUIRED
dummy_ssl_context.load_cert_chain.assert_called_with("cert", keyfile="key")
dummy_ssl_context.load_cert_chain.assert_called_with(
"cert",
keyfile="key",
password=None,
)
dummy_ssl_context.set_ciphers.assert_not_called

dummy_ssl_context = mock.Mock(options=0)
with mock.patch("pymysql.connections.Connection.connect"), mock.patch(
"pymysql.connections.ssl.create_default_context",
new=mock.Mock(return_value=dummy_ssl_context),
) as create_default_context:
pymysql.connect(
ssl={
"ca": "ca",
"cert": "cert",
"key": "key",
"password": "password",
},
)
assert create_default_context.called
assert dummy_ssl_context.check_hostname
assert dummy_ssl_context.verify_mode == ssl.CERT_REQUIRED
dummy_ssl_context.load_cert_chain.assert_called_with(
"cert",
keyfile="key",
password="password",
)
dummy_ssl_context.set_ciphers.assert_not_called

dummy_ssl_context = mock.Mock(options=0)
Expand Down Expand Up @@ -622,7 +653,11 @@ def test_ssl_connect(self):
assert create_default_context.called
assert not dummy_ssl_context.check_hostname
assert dummy_ssl_context.verify_mode == ssl.CERT_NONE
dummy_ssl_context.load_cert_chain.assert_called_with("cert", keyfile="key")
dummy_ssl_context.load_cert_chain.assert_called_with(
"cert",
keyfile="key",
password=None,
)
dummy_ssl_context.set_ciphers.assert_not_called

for ssl_verify_cert in (True, "1", "yes", "true"):
Expand All @@ -640,7 +675,9 @@ def test_ssl_connect(self):
assert not dummy_ssl_context.check_hostname
assert dummy_ssl_context.verify_mode == ssl.CERT_REQUIRED
dummy_ssl_context.load_cert_chain.assert_called_with(
"cert", keyfile="key"
"cert",
keyfile="key",
password=None,
)
dummy_ssl_context.set_ciphers.assert_not_called

Expand All @@ -659,7 +696,9 @@ def test_ssl_connect(self):
assert not dummy_ssl_context.check_hostname
assert dummy_ssl_context.verify_mode == ssl.CERT_NONE
dummy_ssl_context.load_cert_chain.assert_called_with(
"cert", keyfile="key"
"cert",
keyfile="key",
password=None,
)
dummy_ssl_context.set_ciphers.assert_not_called

Expand All @@ -682,7 +721,9 @@ def test_ssl_connect(self):
ssl.CERT_REQUIRED if ssl_ca is not None else ssl.CERT_NONE
), (ssl_ca, ssl_verify_cert)
dummy_ssl_context.load_cert_chain.assert_called_with(
"cert", keyfile="key"
"cert",
keyfile="key",
password=None,
)
dummy_ssl_context.set_ciphers.assert_not_called

Expand All @@ -700,7 +741,33 @@ def test_ssl_connect(self):
assert create_default_context.called
assert dummy_ssl_context.check_hostname
assert dummy_ssl_context.verify_mode == ssl.CERT_NONE
dummy_ssl_context.load_cert_chain.assert_called_with("cert", keyfile="key")
dummy_ssl_context.load_cert_chain.assert_called_with(
"cert",
keyfile="key",
password=None,
)
dummy_ssl_context.set_ciphers.assert_not_called

dummy_ssl_context = mock.Mock(options=0)
with mock.patch("pymysql.connections.Connection.connect"), mock.patch(
"pymysql.connections.ssl.create_default_context",
new=mock.Mock(return_value=dummy_ssl_context),
) as create_default_context:
pymysql.connect(
ssl_ca="ca",
ssl_cert="cert",
ssl_key="key",
ssl_key_password="password",
ssl_verify_identity=True,
)
assert create_default_context.called
assert dummy_ssl_context.check_hostname
assert dummy_ssl_context.verify_mode == ssl.CERT_NONE
dummy_ssl_context.load_cert_chain.assert_called_with(
"cert",
keyfile="key",
password="password",
)
dummy_ssl_context.set_ciphers.assert_not_called

dummy_ssl_context = mock.Mock(options=0)
Expand Down
0