From a547c2ff0b3c6f43201e913c970fd5ad20e9795f Mon Sep 17 00:00:00 2001 From: Sergei Vaskov Date: Wed, 15 Nov 2023 13:57:35 +0200 Subject: [PATCH 1/2] Add support for SSL key password in Connection class. --- pymysql/connections.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/pymysql/connections.py b/pymysql/connections.py index 843bea5e..c5bbdf2b 100644 --- a/pymysql/connections.py +++ b/pymysql/connections.py @@ -135,6 +135,7 @@ class Connection: :param ssl_disabled: A boolean value that disables usage of TLS. :param ssl_key: Path to the file that contains a PEM-formatted private key for the client certificate. + :param ssl_key_password: The password for the client certificate private key. :param ssl_verify_cert: Set to true to check the server certificate's validity. :param ssl_verify_identity: Set to true to check the server's identity. :param read_default_group: Group to read from in the configuration file. @@ -201,6 +202,7 @@ def __init__( ssl_cert=None, ssl_disabled=None, ssl_key=None, + ssl_key_password=None, ssl_verify_cert=None, ssl_verify_identity=None, compress=None, # not supported @@ -281,6 +283,8 @@ def _config(key, arg): ssl["cert"] = ssl_cert if ssl_key is not None: ssl["key"] = ssl_key + if ssl_key_password is not None: + ssl["password"] = ssl_key_password if ssl: if not SSL_ENABLED: raise NotImplementedError("ssl module not found") @@ -389,7 +393,7 @@ def _create_ssl_ctx(self, sslp): else: ctx.verify_mode = ssl.CERT_NONE if hasnoca else ssl.CERT_REQUIRED if "cert" in sslp: - ctx.load_cert_chain(sslp["cert"], keyfile=sslp.get("key")) + ctx.load_cert_chain(sslp["cert"], keyfile=sslp.get("key"), password=sslp.get("password")) if "cipher" in sslp: ctx.set_ciphers(sslp["cipher"]) ctx.options |= ssl.OP_NO_SSLv2 From 3de1c6cf40410e6de048f3eec722f85e2ca85167 Mon Sep 17 00:00:00 2001 From: Sergei Vaskov Date: Wed, 15 Nov 2023 19:53:02 +0200 Subject: [PATCH 2/2] 1145 test and linter fixes (#1) Fixed tests and linter issues caused by adding SSL key password parameter. --- pymysql/connections.py | 6 ++- pymysql/tests/test_connection.py | 81 +++++++++++++++++++++++++++++--- 2 files changed, 78 insertions(+), 9 deletions(-) diff --git a/pymysql/connections.py b/pymysql/connections.py index c5bbdf2b..7e12e169 100644 --- a/pymysql/connections.py +++ b/pymysql/connections.py @@ -264,7 +264,7 @@ def _config(key, arg): if not ssl: ssl = {} if isinstance(ssl, dict): - for key in ["ca", "capath", "cert", "key", "cipher"]: + for key in ["ca", "capath", "cert", "key", "password", "cipher"]: value = _config("ssl-" + key, ssl.get(key)) if value: ssl[key] = value @@ -393,7 +393,9 @@ def _create_ssl_ctx(self, sslp): else: ctx.verify_mode = ssl.CERT_NONE if hasnoca else ssl.CERT_REQUIRED if "cert" in sslp: - ctx.load_cert_chain(sslp["cert"], keyfile=sslp.get("key"), password=sslp.get("password")) + ctx.load_cert_chain( + sslp["cert"], keyfile=sslp.get("key"), password=sslp.get("password") + ) if "cipher" in sslp: ctx.set_ciphers(sslp["cipher"]) ctx.options |= ssl.OP_NO_SSLv2 diff --git a/pymysql/tests/test_connection.py b/pymysql/tests/test_connection.py index 0803efc9..ccfc4a32 100644 --- a/pymysql/tests/test_connection.py +++ b/pymysql/tests/test_connection.py @@ -574,7 +574,11 @@ def test_ssl_connect(self): assert create_default_context.called assert dummy_ssl_context.check_hostname assert dummy_ssl_context.verify_mode == ssl.CERT_REQUIRED - dummy_ssl_context.load_cert_chain.assert_called_with("cert", keyfile="key") + dummy_ssl_context.load_cert_chain.assert_called_with( + "cert", + keyfile="key", + password=None, + ) dummy_ssl_context.set_ciphers.assert_called_with("cipher") dummy_ssl_context = mock.Mock(options=0) @@ -592,7 +596,34 @@ def test_ssl_connect(self): assert create_default_context.called assert dummy_ssl_context.check_hostname assert dummy_ssl_context.verify_mode == ssl.CERT_REQUIRED - dummy_ssl_context.load_cert_chain.assert_called_with("cert", keyfile="key") + dummy_ssl_context.load_cert_chain.assert_called_with( + "cert", + keyfile="key", + password=None, + ) + dummy_ssl_context.set_ciphers.assert_not_called + + dummy_ssl_context = mock.Mock(options=0) + with mock.patch("pymysql.connections.Connection.connect"), mock.patch( + "pymysql.connections.ssl.create_default_context", + new=mock.Mock(return_value=dummy_ssl_context), + ) as create_default_context: + pymysql.connect( + ssl={ + "ca": "ca", + "cert": "cert", + "key": "key", + "password": "password", + }, + ) + assert create_default_context.called + assert dummy_ssl_context.check_hostname + assert dummy_ssl_context.verify_mode == ssl.CERT_REQUIRED + dummy_ssl_context.load_cert_chain.assert_called_with( + "cert", + keyfile="key", + password="password", + ) dummy_ssl_context.set_ciphers.assert_not_called dummy_ssl_context = mock.Mock(options=0) @@ -622,7 +653,11 @@ def test_ssl_connect(self): assert create_default_context.called assert not dummy_ssl_context.check_hostname assert dummy_ssl_context.verify_mode == ssl.CERT_NONE - dummy_ssl_context.load_cert_chain.assert_called_with("cert", keyfile="key") + dummy_ssl_context.load_cert_chain.assert_called_with( + "cert", + keyfile="key", + password=None, + ) dummy_ssl_context.set_ciphers.assert_not_called for ssl_verify_cert in (True, "1", "yes", "true"): @@ -640,7 +675,9 @@ def test_ssl_connect(self): assert not dummy_ssl_context.check_hostname assert dummy_ssl_context.verify_mode == ssl.CERT_REQUIRED dummy_ssl_context.load_cert_chain.assert_called_with( - "cert", keyfile="key" + "cert", + keyfile="key", + password=None, ) dummy_ssl_context.set_ciphers.assert_not_called @@ -659,7 +696,9 @@ def test_ssl_connect(self): assert not dummy_ssl_context.check_hostname assert dummy_ssl_context.verify_mode == ssl.CERT_NONE dummy_ssl_context.load_cert_chain.assert_called_with( - "cert", keyfile="key" + "cert", + keyfile="key", + password=None, ) dummy_ssl_context.set_ciphers.assert_not_called @@ -682,7 +721,9 @@ def test_ssl_connect(self): ssl.CERT_REQUIRED if ssl_ca is not None else ssl.CERT_NONE ), (ssl_ca, ssl_verify_cert) dummy_ssl_context.load_cert_chain.assert_called_with( - "cert", keyfile="key" + "cert", + keyfile="key", + password=None, ) dummy_ssl_context.set_ciphers.assert_not_called @@ -700,7 +741,33 @@ def test_ssl_connect(self): assert create_default_context.called assert dummy_ssl_context.check_hostname assert dummy_ssl_context.verify_mode == ssl.CERT_NONE - dummy_ssl_context.load_cert_chain.assert_called_with("cert", keyfile="key") + dummy_ssl_context.load_cert_chain.assert_called_with( + "cert", + keyfile="key", + password=None, + ) + dummy_ssl_context.set_ciphers.assert_not_called + + dummy_ssl_context = mock.Mock(options=0) + with mock.patch("pymysql.connections.Connection.connect"), mock.patch( + "pymysql.connections.ssl.create_default_context", + new=mock.Mock(return_value=dummy_ssl_context), + ) as create_default_context: + pymysql.connect( + ssl_ca="ca", + ssl_cert="cert", + ssl_key="key", + ssl_key_password="password", + ssl_verify_identity=True, + ) + assert create_default_context.called + assert dummy_ssl_context.check_hostname + assert dummy_ssl_context.verify_mode == ssl.CERT_NONE + dummy_ssl_context.load_cert_chain.assert_called_with( + "cert", + keyfile="key", + password="password", + ) dummy_ssl_context.set_ciphers.assert_not_called dummy_ssl_context = mock.Mock(options=0)