diff --git a/cassandra/client_routes.py b/cassandra/client_routes.py index 80b2477a6d..e447e37df2 100644 --- a/cassandra/client_routes.py +++ b/cassandra/client_routes.py @@ -294,7 +294,7 @@ def handle_client_routes_change(self, connection: 'Connection', timeout: float, return routes = self._query_routes_for_change_event(connection, timeout, pairs) - self._routes.merge(routes, affected_host_ids=set(host_uuids)) + self._routes.merge(routes, affected_host_ids={host_id for _, host_id in pairs}) def _query_all_routes_for_connections(self, connection: 'Connection', timeout: float, connection_ids: Set[str]) -> List[_Route]: @@ -322,27 +322,25 @@ def _query_all_routes_for_connections(self, connection: 'Connection', timeout: f def _query_routes_for_change_event(self, connection: 'Connection', timeout: float, route_pairs: List[Tuple[str, uuid.UUID]]) -> List[_Route]: """ - Query specific routes affected by a CLIENT_ROUTES_CHANGE event. + Query current routes for hosts affected by a CLIENT_ROUTES_CHANGE event. - Takes a list of (connection_id, host_id) pairs that represent the exact - routes affected by an operation. This provides precise updates without - fetching unrelated routes. - - If the pairs list is empty or None, falls back to a complete refresh - of all routes for safety. + The in-memory route store keeps a single preferred route per host. When + any configured connection_id changes for a host, fetch all configured + connection_ids for that host so the existing preferred route can be + retained if it is still present. :param connection: Connection to execute query on :param timeout: Query timeout in seconds - :param route_pairs: List of (connection_id, host_id) tuples + :param route_pairs: List of affected (connection_id, host_id) tuples :return: List of _Route """ unique_pairs = list(dict.fromkeys(route_pairs)) - conn_ids = list(dict.fromkeys(cid for cid, _ in unique_pairs)) + conn_ids = sorted(self._connection_ids) host_ids = list(dict.fromkeys(hid for _, hid in unique_pairs)) - log.debug("[client routes] Querying route pairs from CLIENT_ROUTES_CHANGE " - "(first 5 of %d): %s", len(unique_pairs), unique_pairs[:5]) + log.debug("[client routes] Querying routes from CLIENT_ROUTES_CHANGE " + "for host_ids (first 5 of %d): %s", len(host_ids), host_ids[:5]) conn_ph = ', '.join('?' for _ in conn_ids) host_ph = ', '.join('?' for _ in host_ids) diff --git a/cassandra/connection.py b/cassandra/connection.py index 08501d0a2b..c129bfb3a5 100644 --- a/cassandra/connection.py +++ b/cassandra/connection.py @@ -468,21 +468,25 @@ def resolve(self) -> Tuple[str, int]: def __eq__(self, other): return (isinstance(other, ClientRoutesEndPoint) and self._host_id == other._host_id and - self._original_address == other._original_address) + self._original_address == other._original_address and + self._original_port == other._original_port) def __hash__(self): - return hash((self._host_id, self._original_address)) + return hash((self._host_id, self._original_address, self._original_port)) + + def _comparison_key(self): + return (self._host_id, self._original_address, + self._original_port is None, self._original_port) def __lt__(self, other): - return ((self._host_id, self._original_address) < - (other._host_id, other._original_address)) + return self._comparison_key() < other._comparison_key() def __str__(self): return str("%s (host_id=%s)" % (self._original_address, self._host_id)) def __repr__(self): - return "<%s: host_id=%s, original_addr=%s>" % ( - self.__class__.__name__, self._host_id, self._original_address) + return "<%s: host_id=%s, original_addr=%s, original_port=%s>" % ( + self.__class__.__name__, self._host_id, self._original_address, self._original_port) class _Frame(object): diff --git a/tests/unit/test_client_routes.py b/tests/unit/test_client_routes.py index 0aa82fc76a..bca430c628 100644 --- a/tests/unit/test_client_routes.py +++ b/tests/unit/test_client_routes.py @@ -233,6 +233,92 @@ def test_handle_change_merges_when_host_ids_present(self, mock_query): self.assertIsNotNone(handler._routes.get_by_host_id(existing_host)) self.assertIsNotNone(handler._routes.get_by_host_id(new_host)) + @patch.object(_ClientRoutesHandler, '_query_routes_for_change_event') + def test_handle_change_preserves_routes_for_unrelated_connection_ids(self, mock_query): + """Routes for unrelated connection_ids in mixed events should not be removed.""" + handler = _ClientRoutesHandler(self.config) + mock_conn = Mock() + + conn_id = str(self.conn_id) + changed_host = uuid.uuid4() + unrelated_host = uuid.uuid4() + + handler._routes.update([ + _Route(connection_id=conn_id, host_id=changed_host, address="old.com", port=9042), + _Route(connection_id=conn_id, host_id=unrelated_host, address="keep.com", port=9042), + ]) + + mock_query.return_value = [ + _Route(connection_id=conn_id, host_id=changed_host, address="new.com", port=9042), + ] + + handler.handle_client_routes_change( + mock_conn, 5.0, + ClientRoutesChangeType.UPDATE_NODES, + connection_ids=[conn_id, "unrelated-conn-id"], + host_ids=[str(changed_host), str(unrelated_host)], + ) + + self.assertEqual(handler._routes.get_by_host_id(changed_host).address, "new.com") + self.assertEqual(handler._routes.get_by_host_id(unrelated_host).address, "keep.com") + + def test_handle_change_preserves_preferred_route_for_same_host(self): + conn_a = str(uuid.uuid4()) + conn_b = str(uuid.uuid4()) + host_id = uuid.uuid4() + config = ClientRoutesConfig([ + ClientRouteProxy(conn_a), + ClientRouteProxy(conn_b), + ]) + handler = _ClientRoutesHandler(config) + handler._routes.update([ + _Route(connection_id=conn_b, host_id=host_id, + address="current.example.com", port=9042), + ]) + + table_routes = [ + _Route(connection_id=conn_a, host_id=host_id, + address="changed.example.com", port=9042), + _Route(connection_id=conn_b, host_id=host_id, + address="current.example.com", port=9042), + ] + + def wait_for_response(query_msg, timeout): + conn_placeholders = query_msg.query.split( + "connection_id IN (", 1)[1].split(")", 1)[0].count("?") + conn_ids = { + param.decode("utf-8") + for param in query_msg.query_params[:conn_placeholders] + } + host_ids = { + uuid.UUID(bytes=param) + for param in query_msg.query_params[conn_placeholders:] + } + rows = [ + (route.connection_id, route.host_id, route.address, + route.port, route.port) + for route in table_routes + if route.connection_id in conn_ids and route.host_id in host_ids + ] + return Mock( + column_names=["connection_id", "host_id", "address", "port", "tls_port"], + parsed_rows=rows, + ) + + mock_conn = Mock() + mock_conn.wait_for_response.side_effect = wait_for_response + + handler.handle_client_routes_change( + mock_conn, 5.0, + ClientRoutesChangeType.UPDATE_NODES, + connection_ids=[conn_a], + host_ids=[str(host_id)], + ) + + route = handler._routes.get_by_host_id(host_id) + self.assertEqual(route.connection_id, conn_b) + self.assertEqual(route.address, "current.example.com") + @patch.object(_ClientRoutesHandler, '_query_all_routes_for_connections') def test_handle_change_updates_when_no_host_ids(self, mock_query): """When no host_ids are provided, routes should be fully replaced.""" @@ -388,6 +474,42 @@ def test_resolve_host_missing_port_raises(self): with self.assertRaises(ValueError): self.handler.resolve_host(host_id) + def test_endpoint_identity_includes_original_port(self): + host_id = uuid.uuid4() + first = ClientRoutesEndPoint( + host_id=host_id, + handler=self.handler, + original_address="10.0.0.1", + original_port=9042, + ) + second = ClientRoutesEndPoint( + host_id=host_id, + handler=self.handler, + original_address="10.0.0.1", + original_port=9142, + ) + + self.assertNotEqual(first, second) + self.assertEqual(len({first, second}), 2) + + def test_endpoint_ordering_handles_missing_original_port(self): + host_id = uuid.uuid4() + without_port = ClientRoutesEndPoint( + host_id=host_id, + handler=self.handler, + original_address="10.0.0.1", + original_port=None, + ) + with_port = ClientRoutesEndPoint( + host_id=host_id, + handler=self.handler, + original_address="10.0.0.1", + original_port=9042, + ) + + self.assertCountEqual( + sorted([without_port, with_port]), [without_port, with_port]) + class TestClientRoutesEndPointFactory(unittest.TestCase):