diff --git a/cassandra/cluster.py b/cassandra/cluster.py index a1d098a035..b97e70541c 100644 --- a/cassandra/cluster.py +++ b/cassandra/cluster.py @@ -5482,6 +5482,7 @@ class ResponseFuture(object): _errbacks = None _current_host = None _connection = None + _connection_pool = None _query_retries = 0 _start_time = None _metrics = None @@ -5578,7 +5579,7 @@ def _on_timeout(self, _attempts=0): # Capture connection stats before pool.return_connection() can alter state conn_in_flight = self._connection.in_flight - pool = self.session._get_pool_by_host_identity(self._current_host) + pool = self._connection_pool if pool and not pool.is_shutdown: # Do not return the stream ID to the pool yet. We cannot reuse it # because the node might still be processing the query and will @@ -5661,7 +5662,24 @@ def _query(self, host, message=None, cb=None): if message is None: message = self.message - pool = self.session._get_pool_by_host_identity(host) + expected_endpoint = None + if isinstance(host, Host): + with host.lock: + expected_endpoint = host.endpoint + pool = self.session._get_pool_by_host_identity( + host, expected_endpoint=expected_endpoint) + else: + pool = self.session._get_pool_by_host_identity(host) + + if pool and expected_endpoint is not None: + with host.lock: + endpoint_changed = not self.session._endpoints_match( + host.endpoint, expected_endpoint) + if endpoint_changed: + self._errors[host] = ConnectionException( + "Host endpoint changed while borrowing connection") + return None + if not pool: self._errors[host] = ConnectionException("Host has been marked down or removed") return None @@ -5678,7 +5696,25 @@ def _query(self, host, message=None, cb=None): connection, request_id = pool.borrow_connection(timeout=2.0, routing_key=self.query.routing_key, keyspace=self.query.keyspace, table=self.query.table) else: connection, request_id = pool.borrow_connection(timeout=2.0) + + if expected_endpoint is not None: + with host.lock: + endpoint_changed = not self.session._endpoints_match( + host.endpoint, expected_endpoint) + if endpoint_changed: + try: + with connection.lock: + connection.request_ids.append(request_id) + pool.return_connection(connection) + finally: + connection = None + self._errors[host] = ConnectionException( + "Host endpoint changed while borrowing connection") + return None + self._connection = connection + self._connection_pool = pool + result_meta = self.prepared_statement.result_metadata if self.prepared_statement else [] if cb is None: diff --git a/tests/unit/test_response_future.py b/tests/unit/test_response_future.py index dd7fa75045..298751cb9c 100644 --- a/tests/unit/test_response_future.py +++ b/tests/unit/test_response_future.py @@ -13,6 +13,7 @@ # limitations under the License. import unittest +import uuid from collections import deque from threading import RLock @@ -20,7 +21,7 @@ from cassandra import ConsistencyLevel, Unavailable, SchemaTargetType, SchemaChangeType, OperationTimedOut from cassandra.cluster import Session, ResponseFuture, NoHostAvailable, ProtocolVersion -from cassandra.connection import Connection, ConnectionException +from cassandra.connection import Connection, ConnectionException, DefaultEndPoint from cassandra.protocol import (ReadTimeoutErrorMessage, WriteTimeoutErrorMessage, UnavailableErrorMessage, ResultMessage, QueryMessage, OverloadedErrorMessage, IsBootstrappingErrorMessage, @@ -28,8 +29,8 @@ RESULT_KIND_ROWS, RESULT_KIND_SET_KEYSPACE, RESULT_KIND_SCHEMA_CHANGE, RESULT_KIND_PREPARED, ProtocolHandler) -from cassandra.policies import RetryPolicy, ExponentialBackoffRetryPolicy -from cassandra.pool import NoConnectionsAvailable +from cassandra.policies import RetryPolicy, ExponentialBackoffRetryPolicy, SimpleConvictionPolicy +from cassandra.pool import Host, NoConnectionsAvailable from cassandra.query import SimpleStatement from tests.util import assertEqual, assertIsInstance import pytest @@ -37,6 +38,24 @@ class ResponseFutureTests(unittest.TestCase): + class _EndpointSwapOnExitLock(object): + + def __init__(self, host, new_endpoint): + self._lock = RLock() + self._host = host + self._new_endpoint = new_endpoint + self._exits = 0 + + def __enter__(self): + self._lock.acquire() + return self + + def __exit__(self, exc_type, exc_value, traceback): + self._lock.release() + self._exits += 1 + if self._exits == 1: + self._host.endpoint = self._new_endpoint + def make_basic_session(self): s = Mock(spec=Session) s.row_factory = lambda col_names, rows: [(col_names, rows)] @@ -52,7 +71,7 @@ def make_pool(self): def make_session(self): session = self.make_basic_session() session.cluster._default_load_balancing_policy.make_query_plan.return_value = ['ip1', 'ip2'] - session._pools.get.return_value = self.make_pool() + session._get_pool_by_host_identity.return_value = self.make_pool() return session def make_response_future(self, session): @@ -66,7 +85,7 @@ def make_mock_response(self, col_names, rows): def test_result_message(self): session = self.make_basic_session() session.cluster._default_load_balancing_policy.make_query_plan.return_value = ['ip1', 'ip2'] - pool = session._pools.get.return_value + pool = session._get_pool_by_host_identity.return_value pool.is_shutdown = False connection = Mock(spec=Connection) @@ -75,7 +94,7 @@ def test_result_message(self): rf = self.make_response_future(session) rf.send_request() - rf.session._pools.get.assert_called_once_with('ip1') + rf.session._get_pool_by_host_identity.assert_called_once_with('ip1') pool.borrow_connection.assert_called_once_with(timeout=ANY, routing_key=ANY, keyspace=ANY, table=ANY) connection.send_msg.assert_called_once_with(rf.message, 1, cb=ANY, encoder=ProtocolHandler.encode_message, decoder=ProtocolHandler.decode_message, result_metadata=[]) @@ -87,7 +106,7 @@ def test_result_message(self): def test_unknown_result_class(self): session = self.make_session() - pool = session._pools.get.return_value + pool = session._get_pool_by_host_identity.return_value connection = Mock(spec=Connection) pool.borrow_connection.return_value = (connection, 1) @@ -151,7 +170,7 @@ def test_heartbeat_defunct_deadlock(self): session = self.make_basic_session() session.cluster._default_load_balancing_policy.make_query_plan.return_value = [Mock(), Mock()] - session._pools.get.return_value = pool + session._get_pool_by_host_identity.return_value = pool query = SimpleStatement("SELECT * FROM foo") message = QueryMessage(query=query, consistency_level=ConsistencyLevel.ONE) @@ -252,7 +271,7 @@ def test_retry_policy_says_ignore(self): def test_retry_policy_says_retry(self): session = self.make_session() - pool = session._pools.get.return_value + pool = session._get_pool_by_host_identity.return_value query = SimpleStatement("INSERT INFO foo (a, b) VALUES (1, 2)") message = QueryMessage(query=query, consistency_level=ConsistencyLevel.QUORUM) @@ -266,7 +285,7 @@ def test_retry_policy_says_retry(self): rf = ResponseFuture(session, message, query, 1, retry_policy=retry_policy) rf.send_request() - rf.session._pools.get.assert_called_once_with('ip1') + rf.session._get_pool_by_host_identity.assert_called_once_with('ip1') pool.borrow_connection.assert_called_once_with(timeout=ANY, routing_key=ANY, keyspace=ANY, table=ANY) connection.send_msg.assert_called_once_with(rf.message, 1, cb=ANY, encoder=ProtocolHandler.encode_message, decoder=ProtocolHandler.decode_message, result_metadata=[]) @@ -285,13 +304,13 @@ def test_retry_policy_says_retry(self): # it should try again with the same host since this was # an UnavailableException - rf.session._pools.get.assert_called_with(host) + rf.session._get_pool_by_host_identity.assert_called_with(host) pool.borrow_connection.assert_called_with(timeout=ANY, routing_key=ANY, keyspace=ANY, table=ANY) connection.send_msg.assert_called_with(rf.message, 2, cb=ANY, encoder=ProtocolHandler.encode_message, decoder=ProtocolHandler.decode_message, result_metadata=[]) def test_retry_with_different_host(self): session = self.make_session() - pool = session._pools.get.return_value + pool = session._get_pool_by_host_identity.return_value connection = Mock(spec=Connection) pool.borrow_connection.return_value = (connection, 1) @@ -300,7 +319,7 @@ def test_retry_with_different_host(self): rf.message.consistency_level = ConsistencyLevel.QUORUM rf.send_request() - rf.session._pools.get.assert_called_once_with('ip1') + rf.session._get_pool_by_host_identity.assert_called_once_with('ip1') pool.borrow_connection.assert_called_once_with(timeout=ANY, routing_key=ANY, keyspace=ANY, table=ANY) connection.send_msg.assert_called_once_with(rf.message, 1, cb=ANY, encoder=ProtocolHandler.encode_message, decoder=ProtocolHandler.decode_message, result_metadata=[]) assert ConsistencyLevel.QUORUM == rf.message.consistency_level @@ -319,7 +338,7 @@ def test_retry_with_different_host(self): rf._retry_task(False, host) # it should try with a different host - rf.session._pools.get.assert_called_with('ip2') + rf.session._get_pool_by_host_identity.assert_called_with('ip2') pool.borrow_connection.assert_called_with(timeout=ANY, routing_key=ANY, keyspace=ANY, table=ANY) connection.send_msg.assert_called_with(rf.message, 2, cb=ANY, encoder=ProtocolHandler.encode_message, decoder=ProtocolHandler.decode_message, result_metadata=[]) @@ -328,13 +347,13 @@ def test_retry_with_different_host(self): def test_all_retries_fail(self): session = self.make_session() - pool = session._pools.get.return_value + pool = session._get_pool_by_host_identity.return_value connection = Mock(spec=Connection) pool.borrow_connection.return_value = (connection, 1) rf = self.make_response_future(session) rf.send_request() - rf.session._pools.get.assert_called_once_with('ip1') + rf.session._get_pool_by_host_identity.assert_called_once_with('ip1') result = Mock(spec=IsBootstrappingErrorMessage, info={}) host = Mock() @@ -346,7 +365,7 @@ def test_all_retries_fail(self): rf._retry_task(False, host) # it should try with a different host - rf.session._pools.get.assert_called_with('ip2') + rf.session._get_pool_by_host_identity.assert_called_with('ip2') result = Mock(spec=IsBootstrappingErrorMessage, info={}) rf._set_result(host, None, None, result) @@ -360,7 +379,7 @@ def test_all_retries_fail(self): def test_exponential_retry_policy_fail(self): session = self.make_session() - pool = session._pools.get.return_value + pool = session._get_pool_by_host_identity.return_value connection = Mock(spec=Connection) pool.borrow_connection.return_value = (connection, 1) @@ -368,7 +387,7 @@ def test_exponential_retry_policy_fail(self): message = QueryMessage(query=query, consistency_level=ConsistencyLevel.ONE) rf = ResponseFuture(session, message, query, 1, retry_policy=ExponentialBackoffRetryPolicy(2)) rf.send_request() - rf.session._pools.get.assert_called_once_with('ip1') + rf.session._get_pool_by_host_identity.assert_called_once_with('ip1') result = Mock(spec=IsBootstrappingErrorMessage, info={}) host = Mock() @@ -384,7 +403,7 @@ def test_exponential_retry_policy_fail(self): def test_all_pools_shutdown(self): session = self.make_basic_session() session.cluster._default_load_balancing_policy.make_query_plan.return_value = ['ip1', 'ip2'] - session._pools.get.return_value.is_shutdown = True + session._get_pool_by_host_identity.return_value.is_shutdown = True rf = ResponseFuture(session, Mock(), Mock(), 1) rf.send_request() @@ -399,7 +418,7 @@ def test_first_pool_shutdown(self): pool_shutdown.is_shutdown = True pool_ok = self.make_pool() pool_ok.is_shutdown = True - session._pools.get.side_effect = [pool_shutdown, pool_ok] + session._get_pool_by_host_identity.side_effect = [pool_shutdown, pool_ok] rf = self.make_response_future(session) rf.send_request() @@ -424,7 +443,7 @@ def test_timeout_getting_connection_from_pool(self): connection = Mock(spec=Connection) second_pool.borrow_connection.return_value = (connection, 1) - session._pools.get.side_effect = [first_pool, second_pool] + session._get_pool_by_host_identity.side_effect = [first_pool, second_pool] rf = self.make_response_future(session) rf.send_request() @@ -459,7 +478,7 @@ def test_callback(self): def test_errback(self): session = self.make_session() - pool = session._pools.get.return_value + pool = session._get_pool_by_host_identity.return_value connection = Mock(spec=Connection) pool.borrow_connection.return_value = (connection, 1) @@ -508,7 +527,7 @@ def test_multiple_callbacks(self): def test_multiple_errbacks(self): session = self.make_session() - pool = session._pools.get.return_value + pool = session._get_pool_by_host_identity.return_value connection = Mock(spec=Connection) pool.borrow_connection.return_value = (connection, 1) @@ -581,7 +600,7 @@ def test_add_callbacks(self): def test_prepared_query_not_found(self): session = self.make_session() - pool = session._pools.get.return_value + pool = session._get_pool_by_host_identity.return_value connection = Mock(spec=Connection) pool.borrow_connection.return_value = (connection, 1) @@ -606,7 +625,7 @@ def test_prepared_query_not_found(self): def test_prepared_query_not_found_bad_keyspace(self): session = self.make_session() - pool = session._pools.get.return_value + pool = session._get_pool_by_host_identity.return_value connection = Mock(spec=Connection) pool.borrow_connection.return_value = (connection, 1) @@ -655,7 +674,7 @@ def test_timeout_does_not_release_stream_id(self): session = self.make_basic_session() session.cluster._default_load_balancing_policy.make_query_plan.return_value = [Mock(endpoint='ip1'), Mock(endpoint='ip2')] pool = self.make_pool() - session._pools.get.return_value = pool + session._get_pool_by_host_identity.return_value = pool connection = Mock(spec=Connection, lock=RLock(), _requests={}, request_ids=deque(), orphaned_request_ids=set(), orphaned_threshold=256, in_flight=3) pool.borrow_connection.return_value = (connection, 1) @@ -675,6 +694,127 @@ def test_timeout_does_not_release_stream_id(self): assert len(connection.request_ids) == 0, \ "Request IDs should be empty but it's not: {}".format(connection.request_ids) + def test_timeout_returns_orphan_to_original_pool_after_endpoint_swap(self): + session = self.make_basic_session() + host = Host(DefaultEndPoint('127.0.0.1'), SimpleConvictionPolicy, + host_id=uuid.uuid4()) + session.cluster._default_load_balancing_policy.make_query_plan.return_value = [host] + old_pool = self.make_pool() + replacement_pool = self.make_pool() + connection = Mock(spec=Connection, lock=RLock(), _requests={}, request_ids=deque(), + orphaned_request_ids=set(), orphaned_threshold=256, in_flight=1) + old_pool.borrow_connection.return_value = (connection, 1) + session._get_pool_by_host_identity.side_effect = [old_pool, replacement_pool] + + rf = self.make_response_future(session) + rf.send_request() + connection._requests[1] = (connection._handle_options_response, + ProtocolHandler.decode_message, []) + host.endpoint = DefaultEndPoint('127.0.0.2') + + rf._on_timeout() + + replacement_pool.return_connection.assert_not_called() + old_pool.return_connection.assert_called_once_with( + connection, stream_was_orphaned=True) + + def test_query_does_not_borrow_stale_pool_after_endpoint_swap(self): + session = self.make_basic_session() + host = Host(DefaultEndPoint('127.0.0.1'), SimpleConvictionPolicy, + host_id=uuid.uuid4()) + stale_pool = self.make_pool() + stale_pool.host = host + stale_pool.endpoint = host.endpoint + stale_pool.is_shutdown = False + connection = Mock(spec=Connection) + stale_pool.borrow_connection.return_value = (connection, 1) + + session._lock = RLock() + session._pools = {host: stale_pool} + session._endpoints_match = Session._endpoints_match.__get__(session, Session) + session._pool_matches_expected = Session._pool_matches_expected.__get__(session, Session) + session._get_pool_by_host_identity = Session._get_pool_by_host_identity.__get__(session, Session) + session.cluster._endpoints_match.side_effect = Cluster._endpoints_match + session.cluster._default_load_balancing_policy.make_query_plan.return_value = [host] + host.endpoint = DefaultEndPoint('127.0.0.2') + + rf = self.make_response_future(session) + + assert not rf.send_request() + stale_pool.borrow_connection.assert_not_called() + assert isinstance(rf._errors[host], ConnectionException) + + def test_query_rechecks_endpoint_after_pool_lookup_race(self): + session = self.make_basic_session() + host = Host(DefaultEndPoint('127.0.0.1'), SimpleConvictionPolicy, + host_id=uuid.uuid4()) + stale_pool = self.make_pool() + stale_pool.host = host + stale_pool.endpoint = host.endpoint + stale_pool.is_shutdown = False + connection = Mock(spec=Connection) + stale_pool.borrow_connection.return_value = (connection, 1) + + session._lock = RLock() + session._pools = {host: stale_pool} + session._endpoints_match = Session._endpoints_match.__get__(session, Session) + session._pool_matches_expected = Session._pool_matches_expected.__get__(session, Session) + session._get_pool_by_host_identity = Session._get_pool_by_host_identity.__get__(session, Session) + session.cluster._endpoints_match.side_effect = Cluster._endpoints_match + session.cluster._default_load_balancing_policy.make_query_plan.return_value = [host] + host.lock = self._EndpointSwapOnExitLock( + host, DefaultEndPoint('127.0.0.2')) + + rf = self.make_response_future(session) + + assert not rf.send_request() + stale_pool.borrow_connection.assert_not_called() + assert isinstance(rf._errors[host], ConnectionException) + + def test_query_releases_request_id_after_post_borrow_endpoint_swap(self): + session = self.make_basic_session() + host = Host(DefaultEndPoint('127.0.0.1'), SimpleConvictionPolicy, + host_id=uuid.uuid4()) + old_endpoint = host.endpoint + new_endpoint = DefaultEndPoint('127.0.0.2') + stale_pool = self.make_pool() + stale_pool.host = host + stale_pool.endpoint = old_endpoint + stale_pool.is_shutdown = False + connection = Mock(spec=Connection, lock=RLock(), request_ids=deque([1]), + in_flight=0) + + def borrow_connection(**kwargs): + with connection.lock: + connection.in_flight += 1 + request_id = connection.request_ids.popleft() + host.endpoint = new_endpoint + return connection, request_id + + def return_connection(returned_connection): + with returned_connection.lock: + returned_connection.in_flight -= 1 + + stale_pool.borrow_connection.side_effect = borrow_connection + stale_pool.return_connection.side_effect = return_connection + + session._lock = RLock() + session._pools = {host: stale_pool} + session._endpoints_match = Session._endpoints_match.__get__(session, Session) + session._pool_matches_expected = Session._pool_matches_expected.__get__(session, Session) + session._get_pool_by_host_identity = Session._get_pool_by_host_identity.__get__(session, Session) + session.cluster._endpoints_match.side_effect = Cluster._endpoints_match + session.cluster._default_load_balancing_policy.make_query_plan.return_value = [host] + + rf = self.make_response_future(session) + + assert not rf.send_request() + stale_pool.return_connection.assert_called_once_with(connection) + connection.send_msg.assert_not_called() + assert list(connection.request_ids) == [1] + assert connection.in_flight == 0 + assert isinstance(rf._errors[host], ConnectionException) + def test_single_host_query_plan_exhausted_after_one_retry(self): """ Test that when a specific host is provided, the query plan is properly @@ -686,7 +826,7 @@ def test_single_host_query_plan_exhausted_after_one_retry(self): """ session = self.make_basic_session() pool = self.make_pool() - session._pools.get.return_value = pool + session._get_pool_by_host_identity.return_value = pool # Create a specific host specific_host = Mock() @@ -702,7 +842,7 @@ def test_single_host_query_plan_exhausted_after_one_retry(self): rf.send_request() # Verify initial request was sent - rf.session._pools.get.assert_called_once_with(specific_host) + rf.session._get_pool_by_host_identity.assert_called_once_with(specific_host) pool.borrow_connection.assert_called_once_with(timeout=ANY, routing_key=ANY, keyspace=ANY, table=ANY) connection.send_msg.assert_called_once_with(rf.message, 1, cb=ANY, encoder=ProtocolHandler.encode_message, decoder=ProtocolHandler.decode_message, result_metadata=[])