From d30b6d4a5b223279ad5d23f2e2e8e398435a8281 Mon Sep 17 00:00:00 2001 From: Dmitry Kropachev Date: Tue, 5 May 2026 17:02:34 -0400 Subject: [PATCH] client-routes: preserve partial route state Partial CLIENT_ROUTES_CHANGE handling must not treat filtered event entries as affected route state. Limit merge invalidation to configured event pairs so unrelated connection IDs cannot drop cached proxy routes. For same-host partial updates, fetch all configured connection IDs for affected hosts. This lets the route store keep the currently preferred proxy route when it is still present instead of switching because the partial event omitted it. Also keep ClientRoutesEndPoint identity port-aware and sortable when original_port is missing. Fixes #846 Refs #813 --- cassandra/client_routes.py | 22 +++--- cassandra/connection.py | 16 ++-- tests/unit/test_client_routes.py | 122 +++++++++++++++++++++++++++++++ 3 files changed, 142 insertions(+), 18 deletions(-) diff --git a/cassandra/client_routes.py b/cassandra/client_routes.py index 80b2477a6d..e447e37df2 100644 --- a/cassandra/client_routes.py +++ b/cassandra/client_routes.py @@ -294,7 +294,7 @@ def handle_client_routes_change(self, connection: 'Connection', timeout: float, return routes = self._query_routes_for_change_event(connection, timeout, pairs) - self._routes.merge(routes, affected_host_ids=set(host_uuids)) + self._routes.merge(routes, affected_host_ids={host_id for _, host_id in pairs}) def _query_all_routes_for_connections(self, connection: 'Connection', timeout: float, connection_ids: Set[str]) -> List[_Route]: @@ -322,27 +322,25 @@ def _query_all_routes_for_connections(self, connection: 'Connection', timeout: f def _query_routes_for_change_event(self, connection: 'Connection', timeout: float, route_pairs: List[Tuple[str, uuid.UUID]]) -> List[_Route]: """ - Query specific routes affected by a CLIENT_ROUTES_CHANGE event. + Query current routes for hosts affected by a CLIENT_ROUTES_CHANGE event. - Takes a list of (connection_id, host_id) pairs that represent the exact - routes affected by an operation. This provides precise updates without - fetching unrelated routes. - - If the pairs list is empty or None, falls back to a complete refresh - of all routes for safety. + The in-memory route store keeps a single preferred route per host. When + any configured connection_id changes for a host, fetch all configured + connection_ids for that host so the existing preferred route can be + retained if it is still present. :param connection: Connection to execute query on :param timeout: Query timeout in seconds - :param route_pairs: List of (connection_id, host_id) tuples + :param route_pairs: List of affected (connection_id, host_id) tuples :return: List of _Route """ unique_pairs = list(dict.fromkeys(route_pairs)) - conn_ids = list(dict.fromkeys(cid for cid, _ in unique_pairs)) + conn_ids = sorted(self._connection_ids) host_ids = list(dict.fromkeys(hid for _, hid in unique_pairs)) - log.debug("[client routes] Querying route pairs from CLIENT_ROUTES_CHANGE " - "(first 5 of %d): %s", len(unique_pairs), unique_pairs[:5]) + log.debug("[client routes] Querying routes from CLIENT_ROUTES_CHANGE " + "for host_ids (first 5 of %d): %s", len(host_ids), host_ids[:5]) conn_ph = ', '.join('?' for _ in conn_ids) host_ph = ', '.join('?' for _ in host_ids) diff --git a/cassandra/connection.py b/cassandra/connection.py index 08501d0a2b..c129bfb3a5 100644 --- a/cassandra/connection.py +++ b/cassandra/connection.py @@ -468,21 +468,25 @@ def resolve(self) -> Tuple[str, int]: def __eq__(self, other): return (isinstance(other, ClientRoutesEndPoint) and self._host_id == other._host_id and - self._original_address == other._original_address) + self._original_address == other._original_address and + self._original_port == other._original_port) def __hash__(self): - return hash((self._host_id, self._original_address)) + return hash((self._host_id, self._original_address, self._original_port)) + + def _comparison_key(self): + return (self._host_id, self._original_address, + self._original_port is None, self._original_port) def __lt__(self, other): - return ((self._host_id, self._original_address) < - (other._host_id, other._original_address)) + return self._comparison_key() < other._comparison_key() def __str__(self): return str("%s (host_id=%s)" % (self._original_address, self._host_id)) def __repr__(self): - return "<%s: host_id=%s, original_addr=%s>" % ( - self.__class__.__name__, self._host_id, self._original_address) + return "<%s: host_id=%s, original_addr=%s, original_port=%s>" % ( + self.__class__.__name__, self._host_id, self._original_address, self._original_port) class _Frame(object): diff --git a/tests/unit/test_client_routes.py b/tests/unit/test_client_routes.py index 0aa82fc76a..bca430c628 100644 --- a/tests/unit/test_client_routes.py +++ b/tests/unit/test_client_routes.py @@ -233,6 +233,92 @@ def test_handle_change_merges_when_host_ids_present(self, mock_query): self.assertIsNotNone(handler._routes.get_by_host_id(existing_host)) self.assertIsNotNone(handler._routes.get_by_host_id(new_host)) + @patch.object(_ClientRoutesHandler, '_query_routes_for_change_event') + def test_handle_change_preserves_routes_for_unrelated_connection_ids(self, mock_query): + """Routes for unrelated connection_ids in mixed events should not be removed.""" + handler = _ClientRoutesHandler(self.config) + mock_conn = Mock() + + conn_id = str(self.conn_id) + changed_host = uuid.uuid4() + unrelated_host = uuid.uuid4() + + handler._routes.update([ + _Route(connection_id=conn_id, host_id=changed_host, address="old.com", port=9042), + _Route(connection_id=conn_id, host_id=unrelated_host, address="keep.com", port=9042), + ]) + + mock_query.return_value = [ + _Route(connection_id=conn_id, host_id=changed_host, address="new.com", port=9042), + ] + + handler.handle_client_routes_change( + mock_conn, 5.0, + ClientRoutesChangeType.UPDATE_NODES, + connection_ids=[conn_id, "unrelated-conn-id"], + host_ids=[str(changed_host), str(unrelated_host)], + ) + + self.assertEqual(handler._routes.get_by_host_id(changed_host).address, "new.com") + self.assertEqual(handler._routes.get_by_host_id(unrelated_host).address, "keep.com") + + def test_handle_change_preserves_preferred_route_for_same_host(self): + conn_a = str(uuid.uuid4()) + conn_b = str(uuid.uuid4()) + host_id = uuid.uuid4() + config = ClientRoutesConfig([ + ClientRouteProxy(conn_a), + ClientRouteProxy(conn_b), + ]) + handler = _ClientRoutesHandler(config) + handler._routes.update([ + _Route(connection_id=conn_b, host_id=host_id, + address="current.example.com", port=9042), + ]) + + table_routes = [ + _Route(connection_id=conn_a, host_id=host_id, + address="changed.example.com", port=9042), + _Route(connection_id=conn_b, host_id=host_id, + address="current.example.com", port=9042), + ] + + def wait_for_response(query_msg, timeout): + conn_placeholders = query_msg.query.split( + "connection_id IN (", 1)[1].split(")", 1)[0].count("?") + conn_ids = { + param.decode("utf-8") + for param in query_msg.query_params[:conn_placeholders] + } + host_ids = { + uuid.UUID(bytes=param) + for param in query_msg.query_params[conn_placeholders:] + } + rows = [ + (route.connection_id, route.host_id, route.address, + route.port, route.port) + for route in table_routes + if route.connection_id in conn_ids and route.host_id in host_ids + ] + return Mock( + column_names=["connection_id", "host_id", "address", "port", "tls_port"], + parsed_rows=rows, + ) + + mock_conn = Mock() + mock_conn.wait_for_response.side_effect = wait_for_response + + handler.handle_client_routes_change( + mock_conn, 5.0, + ClientRoutesChangeType.UPDATE_NODES, + connection_ids=[conn_a], + host_ids=[str(host_id)], + ) + + route = handler._routes.get_by_host_id(host_id) + self.assertEqual(route.connection_id, conn_b) + self.assertEqual(route.address, "current.example.com") + @patch.object(_ClientRoutesHandler, '_query_all_routes_for_connections') def test_handle_change_updates_when_no_host_ids(self, mock_query): """When no host_ids are provided, routes should be fully replaced.""" @@ -388,6 +474,42 @@ def test_resolve_host_missing_port_raises(self): with self.assertRaises(ValueError): self.handler.resolve_host(host_id) + def test_endpoint_identity_includes_original_port(self): + host_id = uuid.uuid4() + first = ClientRoutesEndPoint( + host_id=host_id, + handler=self.handler, + original_address="10.0.0.1", + original_port=9042, + ) + second = ClientRoutesEndPoint( + host_id=host_id, + handler=self.handler, + original_address="10.0.0.1", + original_port=9142, + ) + + self.assertNotEqual(first, second) + self.assertEqual(len({first, second}), 2) + + def test_endpoint_ordering_handles_missing_original_port(self): + host_id = uuid.uuid4() + without_port = ClientRoutesEndPoint( + host_id=host_id, + handler=self.handler, + original_address="10.0.0.1", + original_port=None, + ) + with_port = ClientRoutesEndPoint( + host_id=host_id, + handler=self.handler, + original_address="10.0.0.1", + original_port=9042, + ) + + self.assertCountEqual( + sorted([without_port, with_port]), [without_port, with_port]) + class TestClientRoutesEndPointFactory(unittest.TestCase):