diff --git a/cassandra/connection.py b/cassandra/connection.py index bfe38fc702..5c7d71b72e 100644 --- a/cassandra/connection.py +++ b/cassandra/connection.py @@ -245,9 +245,9 @@ def create(self, row): class SniEndPoint(EndPoint): """SNI Proxy EndPoint implementation.""" - def __init__(self, proxy_address, server_name, port=9042): + def __init__(self, proxy_address, server_name, port=9042, init_index=0): self._proxy_address = proxy_address - self._index = 0 + self._index = init_index self._resolved_address = None # resolved address self._port = port self._server_name = server_name @@ -267,8 +267,7 @@ def ssl_options(self): def resolve(self): try: - resolved_addresses = socket.getaddrinfo(self._proxy_address, self._port, - socket.AF_UNSPEC, socket.SOCK_STREAM) + resolved_addresses = self._resolve_proxy_addresses() except socket.gaierror: log.debug('Could not resolve sni proxy hostname "%s" ' 'with port %d' % (self._proxy_address, self._port)) @@ -280,6 +279,10 @@ def resolve(self): return self._resolved_address, self._port + def _resolve_proxy_addresses(self): + return socket.getaddrinfo(self._proxy_address, self._port, + socket.AF_UNSPEC, socket.SOCK_STREAM) + def __eq__(self, other): return (isinstance(other, SniEndPoint) and self.address == other.address and self.port == other.port and @@ -305,16 +308,24 @@ class SniEndPointFactory(EndPointFactory): def __init__(self, proxy_address, port): self._proxy_address = proxy_address self._port = port + # Initial lookup index to prevent all SNI endpoints to be resolved + # into the same starting IP address (which might not be available currently). + # If SNI resolves to 3 IPs, first endpoint will connect to first + # IP address, and subsequent resolutions to next IPs in round-robin + # fusion. + self._init_index = -1 def create(self, row): host_id = row.get("host_id") if host_id is None: raise ValueError("No host_id to create the SniEndPoint") - return SniEndPoint(self._proxy_address, str(host_id), self._port) + self._init_index += 1 + return SniEndPoint(self._proxy_address, str(host_id), self._port, self._init_index) def create_from_sni(self, sni): - return SniEndPoint(self._proxy_address, sni, self._port) + self._init_index += 1 + return SniEndPoint(self._proxy_address, sni, self._port, self._init_index) @total_ordering diff --git a/tests/unit/test_cluster.py b/tests/unit/test_cluster.py index 90bcfbdca8..517a709c9d 100644 --- a/tests/unit/test_cluster.py +++ b/tests/unit/test_cluster.py @@ -15,13 +15,16 @@ import logging import socket +import uuid -from unittest.mock import patch, Mock +from unittest.mock import patch, Mock, MagicMock from cassandra import ConsistencyLevel, DriverException, Timeout, Unavailable, RequestExecutionException, ReadTimeout, WriteTimeout, CoordinationFailure, ReadFailure, WriteFailure, FunctionFailure, AlreadyExists,\ InvalidRequest, Unauthorized, AuthenticationFailed, OperationTimedOut, UnsupportedOperation, RequestValidationException, ConfigurationException, ProtocolVersion from cassandra.cluster import _Scheduler, Session, Cluster, default_lbp_factory, \ - ExecutionProfile, _ConfigMode, EXEC_PROFILE_DEFAULT + ExecutionProfile, _ConfigMode, EXEC_PROFILE_DEFAULT, ControlConnection +from cassandra.connection import SniEndPoint, Connection, SniEndPointFactory +from cassandra.datastax.cloud import CloudConfig from cassandra.pool import Host from cassandra.policies import HostDistance, RetryPolicy, RoundRobinPolicy, DowngradingConsistencyRetryPolicy, SimpleConvictionPolicy from cassandra.query import SimpleStatement, named_tuple_factory, tuple_factory @@ -31,6 +34,7 @@ log = logging.getLogger(__name__) + class ExceptionTypeTest(unittest.TestCase): def test_exception_types(self): @@ -85,6 +89,12 @@ def test_exception_types(self): self.assertTrue(issubclass(UnsupportedOperation, DriverException)) +class MockOrderedPolicy(RoundRobinPolicy): + all_hosts = set() + + def make_query_plan(self, working_keyspace=None, query=None): + return sorted(self.all_hosts, key=lambda x: x.endpoint.ssl_options['server_hostname']) + class ClusterTest(unittest.TestCase): def test_tuple_for_contact_points(self): @@ -119,6 +129,66 @@ def test_requests_in_flight_threshold(self): for n in (0, mn, 128): self.assertRaises(ValueError, c.set_max_requests_per_connection, d, n) + # Tests verifies that driver can connect to SNI endpoint even when one IP + # returned by the DNS resolution of SNI raises error. Mocked SNI resolution method + # returns two IPs. Trying to connect to the first one always fails + # with socket exception. + def test_sni_round_robin_dns_resolution(self): + def _mocked_cloud_config(cloud_config, create_pyopenssl_context): + config = CloudConfig.from_dict({}) + config.sni_host = 'proxy.datastax.com' + config.sni_port = 9042 + # for 2e25021d-8d72-41a7-a247-3da85c5d92d2 we return IP 127.0.0.1 to which connection fails + config.host_ids = ['2e25021d-8d72-41a7-a247-3da85c5d92d2', '8c4b6ed7-f505-4226-b7a4-41f322520c1f'] + return config + + def _mocked_proxy_dns_resolution(self): + return [ + (socket.AF_UNIX, socket.SOCK_STREAM, 0, None, ('127.0.0.1', 9042)), + (socket.AF_UNIX, socket.SOCK_STREAM, 0, None, ('127.0.0.2', 9042)) + ] + + def _mocked_try_connect(self, host): + address, port = host.endpoint.resolve() + if address == '127.0.0.1': + raise socket.error + return MagicMock(spec=Connection) + + with patch('cassandra.datastax.cloud.get_cloud_config', _mocked_cloud_config): + with patch.object(SniEndPoint, '_resolve_proxy_addresses', _mocked_proxy_dns_resolution): + cloud_config = { + 'secure_connect_bundle': '/path/to/secure-connect-dbname.zip' + } + cluster = Cluster(cloud=cloud_config) + lbp = MockOrderedPolicy() + cluster.load_balancing_policy = lbp + with patch.object(ControlConnection, '_try_connect', _mocked_try_connect): + for endpoint in cluster.endpoints_resolved: + host, new = cluster.add_host(endpoint, signal=False) + lbp.all_hosts.add(host) + # No NoHostAvailable indicates that test passed. + cluster.control_connection.connect() + cluster.shutdown() + + # Validate that at least the default LBP can create a query plan with end points that resolve + # to different addresses initially. This may not be exactly how things play out in practice + # (the control connection will muck with this even if nothing else does) but it should be + # a pretty good approximation. + def test_query_plan_for_sni_contains_unique_addresses(self): + node_cnt = 5 + def _mocked_proxy_dns_resolution(self): + return [(socket.AF_UNIX, socket.SOCK_STREAM, 0, None, ('127.0.0.%s' % (i,), 9042)) for i in range(node_cnt)] + + c = Cluster() + lbp = c.load_balancing_policy + lbp.local_dc = "dc1" + factory = SniEndPointFactory("proxy.foo.bar", 9042) + for host in (Host(factory.create({"host_id": uuid.uuid4().hex, "dc": "dc1"}), SimpleConvictionPolicy) for _ in range(node_cnt)): + lbp.on_up(host) + with patch.object(SniEndPoint, '_resolve_proxy_addresses', _mocked_proxy_dns_resolution): + addrs = [host.endpoint.resolve() for host in lbp.make_query_plan()] + self.assertEqual(len(addrs), len(set(addrs))) + class SchedulerTest(unittest.TestCase): # TODO: this suite could be expanded; for now just adding a test covering a ticket diff --git a/tests/unit/test_endpoints.py b/tests/unit/test_endpoints.py index b0841962ca..4352afb9a5 100644 --- a/tests/unit/test_endpoints.py +++ b/tests/unit/test_endpoints.py @@ -65,3 +65,15 @@ def test_endpoint_resolve(self): for i in range(10): (address, _) = endpoint.resolve() self.assertEqual(address, next(it)) + + def test_sni_resolution_start_index(self): + factory = SniEndPointFactory("proxy.datastax.com", 9999) + initial_index = factory._init_index + + endpoint1 = factory.create_from_sni('sni1') + self.assertEqual(factory._init_index, initial_index + 1) + self.assertEqual(endpoint1._index, factory._init_index) + + endpoint2 = factory.create_from_sni('sni2') + self.assertEqual(factory._init_index, initial_index + 2) + self.assertEqual(endpoint2._index, factory._init_index)