From fa61d4db9c63c52720cc234e16a95bd9334ed28e Mon Sep 17 00:00:00 2001 From: Dmitry Kropachev Date: Tue, 5 May 2026 13:15:18 -0400 Subject: [PATCH 1/6] control-connection: keep zero-token hosts --- cassandra/cluster.py | 3 +-- tests/unit/test_control_connection.py | 19 +++++++++++++++++-- 2 files changed, 18 insertions(+), 4 deletions(-) diff --git a/cassandra/cluster.py b/cassandra/cluster.py index 8da9df6a55..e76ece316d 100644 --- a/cassandra/cluster.py +++ b/cassandra/cluster.py @@ -3963,9 +3963,8 @@ def _is_valid_peer(row): if "tokens" in row and not row.get("tokens"): log.debug( - "Found a zero-token node - tokens is None (broadcast_rpc: %s, host_id: %s). Ignoring host." % + "Found a zero-token node - tokens is None (broadcast_rpc: %s, host_id: %s). Adding host without tokens." % (broadcast_rpc, host_id)) - return False return True diff --git a/tests/unit/test_control_connection.py b/tests/unit/test_control_connection.py index d759e12332..631b98347f 100644 --- a/tests/unit/test_control_connection.py +++ b/tests/unit/test_control_connection.py @@ -82,6 +82,7 @@ def update_host(self, host, old_endpoint): host, created = self.add_or_return_host(host) self._host_id_by_endpoint[host.endpoint] = host.host_id self._host_id_by_endpoint.pop(old_endpoint, False) + self._host_id_by_endpoint[host.endpoint] = host.host_id def all_hosts_items(self): return list(self.hosts.items()) @@ -321,7 +322,6 @@ def refresh_and_validate_added_hosts(): [None, None, "a", "dc1", "rack1", ["1", "101", "201"], 'uuid1'], ["192.168.1.7", "10.0.0.1", "a", None, "rack1", ["1", "101", "201"], 'uuid2'], ["192.168.1.6", "10.0.0.1", "a", "dc1", None, ["1", "101", "201"], 'uuid3'], - ["192.168.1.5", "10.0.0.1", "a", "dc1", "rack1", None, 'uuid4'], ["192.168.1.4", "10.0.0.1", "a", "dc1", "rack1", ["1", "101", "201"], None]]]) refresh_and_validate_added_hosts() @@ -335,7 +335,6 @@ def refresh_and_validate_added_hosts(): [None, 9042, None, 7040, "a", "dc1", "rack1", ["2", "102", "202"], "uuid2"], ["192.168.1.5", 9042, "10.0.0.2", 7040, "a", None, "rack1", ["2", "102", "202"], "uuid2"], ["192.168.1.5", 9042, "10.0.0.2", 7040, "a", "dc1", None, ["2", "102", "202"], "uuid2"], - ["192.168.1.5", 9042, "10.0.0.2", 7040, "a", "dc1", "rack1", None, "uuid2"], ["192.168.1.5", 9042, "10.0.0.2", 7040, "a", "dc1", "rack1", ["2", "102", "202"], None]]]) refresh_and_validate_added_hosts() @@ -411,6 +410,22 @@ def test_refresh_nodes_and_tokens_add_host(self): assert self.cluster.added_hosts[0].rack == "rack1" assert self.cluster.added_hosts[0].host_id == "uuid4" + def test_refresh_nodes_and_tokens_adds_zero_token_host_without_token_map_entry(self): + # Zero-token nodes are valid topology members, but they do not own token ranges. + self.connection.peer_results[1].append( + ["192.168.1.3", "10.0.0.3", "a", "dc1", "rack1", None, "uuid4"] + ) + self.cluster.scheduler.schedule = lambda delay, f, *args, **kwargs: f(*args, **kwargs) + + self.control_connection.refresh_node_list_and_token_map() + + zero_token_host = self.cluster.metadata.get_host(DefaultEndPoint("192.168.1.3")) + assert zero_token_host is not None + assert zero_token_host.host_id == "uuid4" + assert zero_token_host.datacenter == "dc1" + assert zero_token_host.rack == "rack1" + assert zero_token_host not in self.cluster.metadata.token_map + def test_refresh_nodes_and_tokens_remove_host(self): del self.connection.peer_results[1][1] self.control_connection.refresh_node_list_and_token_map() From 5d225d68203a741ace9e37adfea3ae8aa5d4e300 Mon Sep 17 00:00:00 2001 From: Dmitry Kropachev Date: Tue, 5 May 2026 13:53:46 -0400 Subject: [PATCH 2/6] test: expand zero-token host coverage --- tests/unit/test_control_connection.py | 67 ++++++++++++++++++++++++--- 1 file changed, 60 insertions(+), 7 deletions(-) diff --git a/tests/unit/test_control_connection.py b/tests/unit/test_control_connection.py index 631b98347f..b8be438bf4 100644 --- a/tests/unit/test_control_connection.py +++ b/tests/unit/test_control_connection.py @@ -80,7 +80,6 @@ def add_or_return_host(self, host): def update_host(self, host, old_endpoint): host, created = self.add_or_return_host(host) - self._host_id_by_endpoint[host.endpoint] = host.host_id self._host_id_by_endpoint.pop(old_endpoint, False) self._host_id_by_endpoint[host.endpoint] = host.host_id @@ -207,6 +206,15 @@ def setUp(self): self.control_connection._connection = self.connection self.control_connection._time = self.time + def _assert_zero_token_host_without_token_map_entry(self, endpoint, host_id): + zero_token_host = self.cluster.metadata.get_host(endpoint) + assert zero_token_host is not None + assert zero_token_host.host_id == host_id + assert zero_token_host.datacenter == "dc1" + assert zero_token_host.rack == "rack1" + assert zero_token_host not in self.cluster.metadata.token_map + return zero_token_host + def test_wait_for_schema_agreement(self): """ Basic test with all schema versions agreeing @@ -419,12 +427,34 @@ def test_refresh_nodes_and_tokens_adds_zero_token_host_without_token_map_entry(s self.control_connection.refresh_node_list_and_token_map() - zero_token_host = self.cluster.metadata.get_host(DefaultEndPoint("192.168.1.3")) - assert zero_token_host is not None - assert zero_token_host.host_id == "uuid4" - assert zero_token_host.datacenter == "dc1" - assert zero_token_host.rack == "rack1" - assert zero_token_host not in self.cluster.metadata.token_map + zero_token_host = self._assert_zero_token_host_without_token_map_entry( + DefaultEndPoint("192.168.1.3"), "uuid4") + assert 1 == len(self.cluster.added_hosts) + assert self.cluster.added_hosts[0] is zero_token_host + assert [] == self.cluster.metadata.removed_hosts + + def test_refresh_nodes_and_tokens_adds_empty_token_host_without_token_map_entry(self): + self.connection.peer_results[1].append( + ["192.168.1.3", "10.0.0.3", "a", "dc1", "rack1", [], "uuid4"] + ) + self.cluster.scheduler.schedule = lambda delay, f, *args, **kwargs: f(*args, **kwargs) + + self.control_connection.refresh_node_list_and_token_map() + + zero_token_host = self._assert_zero_token_host_without_token_map_entry( + DefaultEndPoint("192.168.1.3"), "uuid4") + assert 1 == len(self.cluster.added_hosts) + assert self.cluster.added_hosts[0] is zero_token_host + + def test_refresh_nodes_and_tokens_keeps_zero_token_local_host_without_token_map_entry(self): + self.connection.local_results[1][0][7] = None + + self.control_connection.refresh_node_list_and_token_map() + + self._assert_zero_token_host_without_token_map_entry( + DefaultEndPoint("192.168.1.0"), "uuid1") + assert [] == self.cluster.added_hosts + assert [] == self.cluster.metadata.removed_hosts def test_refresh_nodes_and_tokens_remove_host(self): del self.connection.peer_results[1][1] @@ -604,6 +634,29 @@ def test_refresh_nodes_and_tokens_add_host_detects_port(self): assert self.cluster.added_hosts[0].datacenter == "dc1" assert self.cluster.added_hosts[0].rack == "rack1" + def test_refresh_nodes_and_tokens_adds_zero_token_host_from_peers_v2_without_token_map_entry(self): + del self.connection.peer_results[:] + self.connection.peer_results.extend(self.connection.peer_results_v2) + self.connection.peer_results[1].append( + ["192.168.1.3", 555, "10.0.0.3", 666, "a", "dc1", "rack1", None, "uuid4"] + ) + self.connection.wait_for_responses = Mock(return_value=_node_meta_results( + self.connection.local_results, self.connection.peer_results)) + self.cluster.scheduler.schedule = lambda delay, f, *args, **kwargs: f(*args, **kwargs) + + self.control_connection.refresh_node_list_and_token_map() + + zero_token_host = self._assert_zero_token_host_without_token_map_entry( + DefaultEndPoint("192.168.1.3", 555), "uuid4") + assert 1 == len(self.cluster.added_hosts) + assert self.cluster.added_hosts[0] is zero_token_host + assert zero_token_host.endpoint.port == 555 + assert zero_token_host.broadcast_rpc_address == "192.168.1.3" + assert zero_token_host.broadcast_rpc_port == 555 + assert zero_token_host.broadcast_address == "10.0.0.3" + assert zero_token_host.broadcast_port == 666 + assert [] == self.cluster.metadata.removed_hosts + def test_refresh_nodes_and_tokens_add_host_detects_invalid_port(self): del self.connection.peer_results[:] self.connection.peer_results.extend(self.connection.peer_results_v2) From 11f039dd9af0a19f33c60a983a903e7be15a3323 Mon Sep 17 00:00:00 2001 From: Dmitry Kropachev Date: Tue, 5 May 2026 17:57:09 -0400 Subject: [PATCH 3/6] Ignore zero-token hosts by default Track zero-token status during topology refreshes and rebuild routing state when nodes gain tokens. Filter zero-token hosts from default load balancing policies while preserving explicit whitelist behavior. --- cassandra/cluster.py | 182 ++++++++++++++++++-------- cassandra/policies.py | 73 ++++++++--- cassandra/pool.py | 7 +- tests/unit/test_control_connection.py | 22 +++- tests/unit/test_policies.py | 90 ++++++++++++- 5 files changed, 299 insertions(+), 75 deletions(-) diff --git a/cassandra/cluster.py b/cassandra/cluster.py index e76ece316d..4e5524c293 100644 --- a/cassandra/cluster.py +++ b/cassandra/cluster.py @@ -1916,8 +1916,9 @@ def on_up(self, host): log.debug("Waiting to acquire lock for handling up status of node %s", host) with host.lock: - if host._currently_handling_node_up: - log.debug("Another thread is already handling up status of node %s", host) + if (host._currently_handling_node_up or + getattr(host, "_currently_handling_node_addition", False)): + log.debug("Another thread is already handling up/add status of node %s", host) return if host.is_up: @@ -2050,69 +2051,90 @@ def on_add(self, host, refresh_nodes=True): log.debug("Handling new host %r and notifying listeners", host) - self.profile_manager.on_add(host) - self.control_connection.on_add(host, refresh_nodes) + # Keep refresh-time pool rebuilds from racing this host's pool creation. + with host.lock: + if getattr(host, "_currently_handling_node_addition", False): + log.debug("Another thread is already handling add status of node %s", host) + return + host._currently_handling_node_addition = True - distance = self.profile_manager.distance(host) - if distance != HostDistance.IGNORED: - self._prepare_all_queries(host) - log.debug("Done preparing queries for new host %r", host) + have_future = False + try: + self.profile_manager.on_add(host) + self.control_connection.on_add(host, refresh_nodes) - if distance == HostDistance.IGNORED: - log.debug("Not adding connection pool for new host %r because the " - "load balancing policy has marked it as IGNORED", host) - self._finalize_add(host, set_up=False) - return + distance = self.profile_manager.distance(host) + if distance != HostDistance.IGNORED: + self._prepare_all_queries(host) + log.debug("Done preparing queries for new host %r", host) - futures_lock = Lock() - futures_results = [] - futures = set() + if distance == HostDistance.IGNORED: + log.debug("Not adding connection pool for new host %r because the " + "load balancing policy has marked it as IGNORED", host) + self._finalize_add(host, set_up=False) + return - def future_completed(future): - with futures_lock: - futures.discard(future) + futures_lock = Lock() + futures_results = [] + futures = set() - try: - futures_results.append(future.result()) - except Exception as exc: - futures_results.append(exc) + def future_completed(future): + with futures_lock: + futures.discard(future) - if futures: - return + try: + futures_results.append(future.result()) + except Exception as exc: + futures_results.append(exc) - log.debug('All futures have completed for added host %s', host) + if futures: + return - for exc in [f for f in futures_results if isinstance(f, Exception)]: - log.error("Unexpected failure while adding node %s, will not mark up:", host, exc_info=exc) - return + log.debug('All futures have completed for added host %s', host) - if not all(futures_results): - log.warning("Connection pool could not be created, not marking node %s up", host) - return + for exc in [f for f in futures_results if isinstance(f, Exception)]: + log.error("Unexpected failure while adding node %s, will not mark up:", host, exc_info=exc) + with host.lock: + host._currently_handling_node_addition = False + return - self._finalize_add(host) + if not all(futures_results): + log.warning("Connection pool could not be created, not marking node %s up", host) + with host.lock: + host._currently_handling_node_addition = False + return - have_future = False - for session in tuple(self.sessions): - future = session.add_or_renew_pool(host, is_host_addition=True) - if future is not None: - have_future = True - futures.add(future) - future.add_done_callback(future_completed) + self._finalize_add(host) + + for session in tuple(self.sessions): + future = session.add_or_renew_pool(host, is_host_addition=True) + if future is not None: + have_future = True + futures.add(future) + future.add_done_callback(future_completed) - if not have_future: - self._finalize_add(host) + if not have_future: + self._finalize_add(host) + except Exception: + if not have_future: + with host.lock: + host._currently_handling_node_addition = False + raise def _finalize_add(self, host, set_up=True): - if set_up: - host.set_up() + try: + if set_up: + host.set_up() - for listener in self.listeners: - listener.on_add(host) + for listener in self.listeners: + listener.on_add(host) - # see if there are any pools to add or remove now that the host is marked up - for session in tuple(self.sessions): - session.update_created_pools() + # see if there are any pools to add or remove now that the host is marked up + for session in tuple(self.sessions): + session.update_created_pools() + finally: + with host.lock: + host._currently_handling_node_addition = False def on_remove(self, host): if self.is_shutdown: @@ -2137,7 +2159,8 @@ def signal_connection_failure(self, host, connection_exc, is_host_addition, expe self.on_down(host, is_host_addition, expect_host_to_be_down) return is_down - def add_host(self, endpoint, datacenter=None, rack=None, signal=True, refresh_nodes=True, host_id=None): + def add_host(self, endpoint, datacenter=None, rack=None, signal=True, refresh_nodes=True, host_id=None, + is_zero_token=None): """ Called when adding initial contact points and when the control connection subsequently discovers a new node. @@ -2147,8 +2170,16 @@ def add_host(self, endpoint, datacenter=None, rack=None, signal=True, refresh_no """ with self.metadata._hosts_lock: if endpoint in self.metadata._host_id_by_endpoint: - return self.metadata._hosts[self.metadata._host_id_by_endpoint[endpoint]], False - host, new = self.metadata.add_or_return_host(Host(endpoint, self.conviction_policy_factory, datacenter, rack, host_id=host_id)) + host = self.metadata._hosts[self.metadata._host_id_by_endpoint[endpoint]] + if is_zero_token is not None: + host.is_zero_token = is_zero_token + return host, False + host = Host(endpoint, self.conviction_policy_factory, datacenter, rack, host_id=host_id) + if is_zero_token is not None: + host.is_zero_token = is_zero_token + host, new = self.metadata.add_or_return_host(host) + if not new and is_zero_token is not None: + host.is_zero_token = is_zero_token if new and signal: log.info("New Cassandra host %r discovered", host) self.on_add(host, refresh_nodes) @@ -3315,7 +3346,10 @@ def update_created_pools(self): # we don't eagerly set is_up on previously ignored hosts. None is included here # to allow us to attempt connections to hosts that have gone from ignored to something # else. - if distance != HostDistance.IGNORED and host.is_up in (True, None): + # on_up() and on_add() already own pool creation for hosts in flight. + if (distance != HostDistance.IGNORED and host.is_up in (True, None) and + not host._currently_handling_node_up and + not getattr(host, "_currently_handling_node_addition", False)): future = self.add_or_renew_pool(host, False) elif distance != pool.host_distance: # the distance has changed @@ -3864,6 +3898,8 @@ def _refresh_node_list_and_token_map(self, connection, preloaded_results=None, # every node in the cluster was in the contact points, we won't discover # any new nodes, so we need this additional check. (See PYTHON-90) should_rebuild_token_map = force_token_rebuild or self._cluster.metadata.partitioner is None + zero_token_status_changed = False + promoted_zero_token_hosts = [] for row in peers_result: if not self._is_valid_peer(row): continue @@ -3884,10 +3920,20 @@ def _refresh_node_list_and_token_map(self, connection, preloaded_results=None, host = self._cluster.metadata.get_host(endpoint) datacenter = row.get("data_center") rack = row.get("rack") + tokens = row.get("tokens", None) + has_token_status = "tokens" in row + is_zero_token = has_token_status and not tokens + token_status = is_zero_token if has_token_status else None if host is None: host = self._cluster.metadata.get_host_by_host_id(host_id) if host and host.endpoint != endpoint: + if has_token_status: + status_changed = self._update_zero_token_info(host, is_zero_token) + zero_token_status_changed |= status_changed + should_rebuild_token_map |= status_changed + if status_changed and not is_zero_token and host.is_up is not True: + promoted_zero_token_hosts.append(host) log.debug("[control connection] Updating host ip from %s to %s for (%s)", host.endpoint, endpoint, host_id) reconnector = host.get_and_set_reconnection_handler(None) if reconnector: @@ -3901,11 +3947,20 @@ def _refresh_node_list_and_token_map(self, connection, preloaded_results=None, if host is None: log.debug("[control connection] Found new host to connect to: %s", endpoint) - host, _ = self._cluster.add_host(endpoint, datacenter=datacenter, rack=rack, signal=True, refresh_nodes=False, host_id=host_id) + host, _ = self._cluster.add_host(endpoint, datacenter=datacenter, rack=rack, signal=True, + refresh_nodes=False, host_id=host_id, + is_zero_token=token_status) should_rebuild_token_map = True else: should_rebuild_token_map |= self._update_location_info(host, datacenter, rack) + if has_token_status: + status_changed = self._update_zero_token_info(host, is_zero_token) + zero_token_status_changed |= status_changed + should_rebuild_token_map |= status_changed + if status_changed and not is_zero_token and host.is_up is not True: + promoted_zero_token_hosts.append(host) + host.host_id = host_id host.broadcast_address = _NodeInfo.get_broadcast_address(row) host.broadcast_port = _NodeInfo.get_broadcast_port(row) @@ -3916,7 +3971,6 @@ def _refresh_node_list_and_token_map(self, connection, preloaded_results=None, host.dse_workload = row.get("workload") host.dse_workloads = row.get("workloads") - tokens = row.get("tokens", None) if partitioner and tokens and self._token_meta_enabled: token_map[host] = tokens self._cluster.metadata.update_host(host, old_endpoint=endpoint) @@ -3932,6 +3986,22 @@ def _refresh_node_list_and_token_map(self, connection, preloaded_results=None, log.debug("[control connection] Rebuilding token map due to topology changes") self._cluster.metadata.rebuild_token_map(partitioner, token_map) + for host in promoted_zero_token_hosts: + self._cluster.on_up(host) + + if zero_token_status_changed: + for session in tuple(getattr(self._cluster, "sessions", ())): + session.update_created_pools() + + @staticmethod + def _update_zero_token_info(host, is_zero_token): + is_zero_token = bool(is_zero_token) + if host.is_zero_token == is_zero_token: + return False + + host.is_zero_token = is_zero_token + return True + @staticmethod def _is_valid_peer(row): broadcast_rpc = _NodeInfo.get_broadcast_rpc_address(row) @@ -3963,7 +4033,7 @@ def _is_valid_peer(row): if "tokens" in row and not row.get("tokens"): log.debug( - "Found a zero-token node - tokens is None (broadcast_rpc: %s, host_id: %s). Adding host without tokens." % + "Found a zero-token node - tokens are empty (broadcast_rpc: %s, host_id: %s). Adding host without tokens." % (broadcast_rpc, host_id)) return True diff --git a/cassandra/policies.py b/cassandra/policies.py index e742708019..eb0f016b3c 100644 --- a/cassandra/policies.py +++ b/cassandra/policies.py @@ -121,10 +121,25 @@ class LoadBalancingPolicy(HostStateListener): """ _hosts_lock = None + _ignore_zero_token_hosts = True def __init__(self): self._hosts_lock = Lock() + def _is_ignored_zero_token_host(self, host): + if getattr(host, 'is_zero_token', False) is not True: + return False + + child_policy = getattr(self, '_child_policy', None) + if child_policy is not None: + # Preserve child opt-outs through wrapper layers. + return bool(child_policy._is_ignored_zero_token_host(host)) + + return bool(self._ignore_zero_token_hosts) + + def _filter_zero_token_hosts(self, hosts): + return tuple(h for h in hosts if not self._is_ignored_zero_token_host(h)) + def distance(self, host): """ Returns a measure of how remote a :class:`~.pool.Host` is in @@ -178,10 +193,13 @@ class RoundRobinPolicy(LoadBalancingPolicy): def populate(self, cluster, hosts): self._live_hosts = frozenset(hosts) - if len(hosts) > 1: - self._position = randint(0, len(hosts) - 1) + live_hosts = self._filter_zero_token_hosts(hosts) + if len(live_hosts) > 1: + self._position = randint(0, len(live_hosts) - 1) def distance(self, host): + if self._is_ignored_zero_token_host(host): + return HostDistance.IGNORED return HostDistance.LOCAL def make_query_plan(self, working_keyspace=None, query=None): @@ -190,7 +208,7 @@ def make_query_plan(self, working_keyspace=None, query=None): pos = self._position self._position += 1 - hosts = self._live_hosts + hosts = self._filter_zero_token_hosts(self._live_hosts) length = len(hosts) if length: pos %= length @@ -257,6 +275,9 @@ def populate(self, cluster, hosts): self._position = randint(0, len(hosts) - 1) if hosts else 0 def distance(self, host): + if self._is_ignored_zero_token_host(host): + return HostDistance.IGNORED + dc = self._dc(host) if dc == self.local_dc: return HostDistance.LOCAL @@ -264,7 +285,7 @@ def distance(self, host): if not self.used_hosts_per_remote_dc: return HostDistance.IGNORED else: - dc_hosts = self._dc_live_hosts.get(dc) + dc_hosts = self._filter_zero_token_hosts(self._dc_live_hosts.get(dc, ())) if not dc_hosts: return HostDistance.IGNORED @@ -279,7 +300,7 @@ def make_query_plan(self, working_keyspace=None, query=None): pos = self._position self._position += 1 - local_live = self._dc_live_hosts.get(self.local_dc, ()) + local_live = self._filter_zero_token_hosts(self._dc_live_hosts.get(self.local_dc, ())) pos = (pos % len(local_live)) if local_live else 0 for host in islice(cycle(local_live), pos, pos + len(local_live)): yield host @@ -287,7 +308,7 @@ def make_query_plan(self, working_keyspace=None, query=None): # the dict can change, so get candidate DCs iterating over keys of a copy other_dcs = [dc for dc in self._dc_live_hosts.copy().keys() if dc != self.local_dc] for dc in other_dcs: - remote_live = self._dc_live_hosts.get(dc, ()) + remote_live = self._filter_zero_token_hosts(self._dc_live_hosts.get(dc, ())) for host in remote_live[:self.used_hosts_per_remote_dc]: yield host @@ -372,6 +393,9 @@ def populate(self, cluster, hosts): self._position = randint(0, len(hosts) - 1) if hosts else 0 def distance(self, host): + if self._is_ignored_zero_token_host(host): + return HostDistance.IGNORED + rack = self._rack(host) dc = self._dc(host) if rack == self.local_rack and dc == self.local_dc: @@ -383,7 +407,7 @@ def distance(self, host): if not self.used_hosts_per_remote_dc: return HostDistance.IGNORED - dc_hosts = self._dc_live_hosts.get(dc, ()) + dc_hosts = self._filter_zero_token_hosts(self._dc_live_hosts.get(dc, ())) if not dc_hosts: return HostDistance.IGNORED if host in dc_hosts and dc_hosts.index(host) < self.used_hosts_per_remote_dc: @@ -395,14 +419,15 @@ def make_query_plan(self, working_keyspace=None, query=None): pos = self._position self._position += 1 - local_rack_live = self._live_hosts.get((self.local_dc, self.local_rack), ()) + local_rack_live = self._filter_zero_token_hosts(self._live_hosts.get((self.local_dc, self.local_rack), ())) pos = (pos % len(local_rack_live)) if local_rack_live else 0 # Slice the cyclic iterator to start from pos and include the next len(local_live) elements # This ensures we get exactly one full cycle starting from pos for host in islice(cycle(local_rack_live), pos, pos + len(local_rack_live)): yield host - local_live = [host for host in self._dc_live_hosts.get(self.local_dc, ()) if host.rack != self.local_rack] + local_live = [host for host in self._filter_zero_token_hosts(self._dc_live_hosts.get(self.local_dc, ())) + if host.rack != self.local_rack] pos = (pos % len(local_live)) if local_live else 0 for host in islice(cycle(local_live), pos, pos + len(local_live)): yield host @@ -410,6 +435,7 @@ def make_query_plan(self, working_keyspace=None, query=None): # the dict can change, so get candidate DCs iterating over keys of a copy for dc, remote_live in self._dc_live_hosts.copy().items(): if dc != self.local_dc: + remote_live = self._filter_zero_token_hosts(remote_live) for host in remote_live[:self.used_hosts_per_remote_dc]: yield host @@ -491,6 +517,9 @@ def check_supported(self): (self.__class__.__name__, self._cluster_metadata.partitioner)) def distance(self, *args, **kwargs): + host = args[0] if args else kwargs.get('host') + if host is not None and self._is_ignored_zero_token_host(host): + return HostDistance.IGNORED return self._child_policy.distance(*args, **kwargs) def make_query_plan(self, working_keyspace=None, query=None): @@ -499,7 +528,8 @@ def make_query_plan(self, working_keyspace=None, query=None): child = self._child_policy if query is None or query.routing_key is None or keyspace is None: for host in child.make_query_plan(keyspace, query): - yield host + if not self._is_ignored_zero_token_host(host): + yield host return replicas = [] @@ -520,13 +550,15 @@ def make_query_plan(self, working_keyspace=None, query=None): def yield_in_order(hosts): for distance in [HostDistance.LOCAL_RACK, HostDistance.LOCAL, HostDistance.REMOTE]: for replica in hosts: - if replica.is_up and child.distance(replica) == distance: + if (not self._is_ignored_zero_token_host(replica) and + replica.is_up and child.distance(replica) == distance): yield replica # yield replicas: local_rack, local, remote yield from yield_in_order(replicas) # yield rest of the cluster: local_rack, local, remote - yield from yield_in_order([host for host in child.make_query_plan(keyspace, query) if host not in replicas]) + yield from yield_in_order([host for host in child.make_query_plan(keyspace, query) + if host not in replicas and not self._is_ignored_zero_token_host(host)]) def on_up(self, *args, **kwargs): return self._child_policy.on_up(*args, **kwargs) @@ -554,6 +586,8 @@ class WhiteListRoundRobinPolicy(RoundRobinPolicy): attempts are made to private IP addresses remotely """ + _ignore_zero_token_hosts = False + def __init__(self, hosts): """ The `hosts` parameter should be a sequence of hosts to permit @@ -674,6 +708,9 @@ def distance(self, host): :attr:`~HostDistance.IGNORED` if falsy, and defers to the child policy otherwise. """ + if self._is_ignored_zero_token_host(host): + return HostDistance.IGNORED + if self.predicate(host): return self._child_policy.distance(host) else: @@ -699,7 +736,7 @@ def make_query_plan(self, working_keyspace=None, query=None): working_keyspace=working_keyspace, query=query ) for host in child_qp: - if self.predicate(host): + if not self._is_ignored_zero_token_host(host) and self.predicate(host): yield host def check_supported(self): @@ -1305,6 +1342,9 @@ def __init__(self, child_policy): self._child_policy = child_policy def distance(self, *args, **kwargs): + host = args[0] if args else kwargs.get('host') + if host is not None and self._is_ignored_zero_token_host(host): + return HostDistance.IGNORED return self._child_policy.distance(*args, **kwargs) def populate(self, cluster, hosts): @@ -1347,14 +1387,15 @@ def make_query_plan(self, working_keyspace=None, query=None): target_host = self._cluster_metadata.get_host(addr) child = self._child_policy - if target_host and target_host.is_up: + if target_host and target_host.is_up and not self._is_ignored_zero_token_host(target_host): yield target_host for h in child.make_query_plan(keyspace, query): - if h != target_host: + if h != target_host and not self._is_ignored_zero_token_host(h): yield h else: for h in child.make_query_plan(keyspace, query): - yield h + if not self._is_ignored_zero_token_host(h): + yield h # TODO for backward compatibility, remove in next major diff --git a/cassandra/pool.py b/cassandra/pool.py index 2da657256f..d19844917e 100644 --- a/cassandra/pool.py +++ b/cassandra/pool.py @@ -127,6 +127,11 @@ class Host(object): up or down. """ + is_zero_token = False + """ + :const:`True` if the node has no tokens in the system topology tables. + """ + release_version = None """ release_version as queried from the control connection system tables @@ -179,6 +184,7 @@ def __init__(self, endpoint, conviction_policy_factory, datacenter=None, rack=No raise ValueError("host_id may not be None") self.host_id = host_id self.set_location_info(datacenter, rack) + self.is_zero_token = False self.lock = RLock() @property @@ -927,4 +933,3 @@ def open_count(self): def _excess_connection_limit(self): return self.host.sharding_info.shards_count * self.max_excess_connections_per_shard_multiplier - diff --git a/tests/unit/test_control_connection.py b/tests/unit/test_control_connection.py index b8be438bf4..c5a92138f8 100644 --- a/tests/unit/test_control_connection.py +++ b/tests/unit/test_control_connection.py @@ -112,8 +112,10 @@ def __init__(self): self.endpoint_factory = DefaultEndPointFactory().configure(self) self.ssl_options = None - def add_host(self, endpoint, datacenter, rack, signal=False, refresh_nodes=True, host_id=None): + def add_host(self, endpoint, datacenter, rack, signal=False, refresh_nodes=True, host_id=None, + is_zero_token=False): host = Host(endpoint, SimpleConvictionPolicy, datacenter, rack, host_id=host_id) + host.is_zero_token = is_zero_token host, _ = self.metadata.add_or_return_host(host) self.added_hosts.append(host) return host, True @@ -212,6 +214,7 @@ def _assert_zero_token_host_without_token_map_entry(self, endpoint, host_id): assert zero_token_host.host_id == host_id assert zero_token_host.datacenter == "dc1" assert zero_token_host.rack == "rack1" + assert zero_token_host.is_zero_token assert zero_token_host not in self.cluster.metadata.token_map return zero_token_host @@ -456,6 +459,23 @@ def test_refresh_nodes_and_tokens_keeps_zero_token_local_host_without_token_map_ assert [] == self.cluster.added_hosts assert [] == self.cluster.metadata.removed_hosts + def test_refresh_nodes_and_tokens_updates_zero_token_status_when_tokens_change(self): + self.connection.peer_results[1].append( + ["192.168.1.3", "10.0.0.3", "a", "dc1", "rack1", None, "uuid4"] + ) + self.cluster.scheduler.schedule = lambda delay, f, *args, **kwargs: f(*args, **kwargs) + + self.control_connection.refresh_node_list_and_token_map() + zero_token_host = self._assert_zero_token_host_without_token_map_entry( + DefaultEndPoint("192.168.1.3"), "uuid4") + + self.connection.peer_results[1][-1][5] = ["3", "103", "203"] + self.control_connection.refresh_node_list_and_token_map() + + assert not zero_token_host.is_zero_token + assert zero_token_host in self.cluster.metadata.token_map + assert self.cluster.metadata.token_map[zero_token_host] == ["3", "103", "203"] + def test_refresh_nodes_and_tokens_remove_host(self): del self.connection.peer_results[1][1] self.control_connection.refresh_node_list_and_token_map() diff --git a/tests/unit/test_policies.py b/tests/unit/test_policies.py index 6142af1aa1..d362a1effe 100644 --- a/tests/unit/test_policies.py +++ b/tests/unit/test_policies.py @@ -33,13 +33,22 @@ RetryPolicy, WriteType, DowngradingConsistencyRetryPolicy, ConstantReconnectionPolicy, LoadBalancingPolicy, ConvictionPolicy, ReconnectionPolicy, FallthroughRetryPolicy, - IdentityTranslator, EC2MultiRegionTranslator, HostFilterPolicy, ExponentialBackoffRetryPolicy) + IdentityTranslator, EC2MultiRegionTranslator, HostFilterPolicy, ExponentialBackoffRetryPolicy, + DefaultLoadBalancingPolicy) from cassandra.connection import DefaultEndPoint, UnixSocketEndPoint from cassandra.pool import Host from cassandra.query import Statement from cassandra.tablets import Tablets, Tablet +def make_host(address, datacenter="dc1", rack="rack1", is_zero_token=False): + host = Host(DefaultEndPoint(address), SimpleConvictionPolicy, host_id=uuid.uuid4()) + host.set_location_info(datacenter, rack) + host.set_up() + host.is_zero_token = is_zero_token + return host + + class LoadBalancingPolicyTest(unittest.TestCase): def test_non_implemented(self): """ @@ -187,6 +196,26 @@ def test_no_live_nodes(self): qplan = list(policy.make_query_plan()) assert qplan == [] + +@pytest.mark.parametrize("policy", [ + RoundRobinPolicy(), + DCAwareRoundRobinPolicy("dc1"), + RackAwareRoundRobinPolicy("dc1", "rack1"), +]) +def test_zero_token_hosts_ignored_by_round_robin_policies(policy): + host = make_host("127.0.0.1") + zero_token_host = make_host("127.0.0.2", is_zero_token=True) + + policy.populate(Mock(), [host, zero_token_host]) + + assert list(policy.make_query_plan()) == [host] + assert policy.distance(zero_token_host) == HostDistance.IGNORED + + zero_token_host.is_zero_token = False + + assert set(policy.make_query_plan()) == {host, zero_token_host} + assert policy.distance(zero_token_host) != HostDistance.IGNORED + @pytest.mark.parametrize("policy_specialization, constructor_args", [(DCAwareRoundRobinPolicy, ("dc1", )), (RackAwareRoundRobinPolicy, ("dc1", "rack1"))]) class TestRackOrDCAwareRoundRobinPolicy: @@ -850,6 +879,27 @@ def test_statement_keyspace(self): assert replicas + hosts[:2] == qplan cluster.metadata.get_replicas.assert_called_with(statement_keyspace, routing_key) + def test_ignores_zero_token_hosts(self): + host = make_host("127.0.0.1") + zero_token_host = make_host("127.0.0.2", is_zero_token=True) + + cluster = Mock(spec=Cluster) + cluster.metadata = Mock(spec=Metadata) + cluster.metadata._tablets = Mock(spec=Tablets) + cluster.metadata._tablets.get_tablet_for_key.return_value = None + cluster.metadata.get_replicas.return_value = [zero_token_host, host] + + child_policy = Mock() + child_policy.make_query_plan.return_value = [zero_token_host, host] + child_policy.distance.return_value = HostDistance.LOCAL + + policy = TokenAwarePolicy(child_policy, shuffle_replicas=False) + policy.populate(cluster, [host, zero_token_host]) + + query = Statement(routing_key=b"routing_key", keyspace="keyspace_name") + + assert list(policy.make_query_plan(None, query)) == [host] + def test_shuffles_if_given_keyspace_and_routing_key(self): """ Test to validate the hosts are shuffled when `shuffle_replicas` is truthy @@ -1432,6 +1482,14 @@ def test_hosts_with_socket_hostname(self): assert policy.distance(host) == HostDistance.LOCAL + def test_ignores_zero_token_status(self): + policy = WhiteListRoundRobinPolicy(["127.0.0.1"]) + host = make_host("127.0.0.1", is_zero_token=True) + policy.populate(None, [host]) + + assert list(policy.make_query_plan()) == [host] + assert policy.distance(host) == HostDistance.LOCAL + class AddressTranslatorTest(unittest.TestCase): @@ -1567,6 +1625,12 @@ def test_accepted_filter_defers_to_child_policy(self): # second call of _child_policy with count() side effect assert self.hfp.distance(self.accepted_host) == distances[1] + def test_zero_token_host_is_ignored_before_child_policy(self): + host = make_host("acceptme", is_zero_token=True) + + assert self.hfp.distance(host) == HostDistance.IGNORED + assert self.hfp._child_policy.distance.call_count == 0 + class HostFilterPolicyPopulateTest(unittest.TestCase): @@ -1618,6 +1682,30 @@ def test_query_plan_deferred_to_child(self): ) assert qp == hfp._child_policy.make_query_plan.return_value + def test_zero_token_hosts_are_filtered(self): + host = make_host("127.0.0.1") + zero_token_host = make_host("127.0.0.2", is_zero_token=True) + hfp = HostFilterPolicy( + child_policy=Mock(name='child_policy', make_query_plan=Mock(return_value=[zero_token_host, host])), + predicate=lambda host: True + ) + + assert list(hfp.make_query_plan()) == [host] + + def test_default_policy_filters_zero_token_target_and_child_hosts(self): + host = make_host("127.0.0.1") + zero_token_host = make_host("127.0.0.2", is_zero_token=True) + cluster = Mock(spec=Cluster) + cluster.metadata = Mock(spec=Metadata) + cluster.metadata.get_host.return_value = zero_token_host + + child_policy = Mock(name='child_policy', make_query_plan=Mock(return_value=[zero_token_host, host])) + policy = DefaultLoadBalancingPolicy(child_policy) + policy.populate(cluster, [host, zero_token_host]) + query = Mock(target_host=zero_token_host.address, keyspace=None) + + assert list(policy.make_query_plan(query=query)) == [host] + def test_wrap_token_aware(self): cluster = Mock(spec=Cluster) hosts = [Host(DefaultEndPoint("127.0.0.{}".format(i)), SimpleConvictionPolicy, host_id=uuid.uuid4()) for i in range(1, 6)] From 169fc7deb9d94753c44b9664bd345add521fd6cf Mon Sep 17 00:00:00 2001 From: Dmitry Kropachev Date: Tue, 5 May 2026 18:27:37 -0400 Subject: [PATCH 4/6] Fix zero-token host review issues Clear host addition state when pool creation aborts after a partial async start. Keep zero-token filtering compatible with child policies that only implement the public load-balancing API. --- cassandra/cluster.py | 14 ++++++++++---- cassandra/policies.py | 5 ++++- tests/unit/test_cluster.py | 29 ++++++++++++++++++++++++++--- tests/unit/test_policies.py | 34 ++++++++++++++++++++++++++++++++++ 4 files changed, 74 insertions(+), 8 deletions(-) diff --git a/cassandra/cluster.py b/cassandra/cluster.py index 4e5524c293..661cdc07d3 100644 --- a/cassandra/cluster.py +++ b/cassandra/cluster.py @@ -2059,6 +2059,8 @@ def on_add(self, host, refresh_nodes=True): host._currently_handling_node_addition = True have_future = False + add_aborted = False + futures = set() try: self.profile_manager.on_add(host) self.control_connection.on_add(host, refresh_nodes) @@ -2076,12 +2078,14 @@ def on_add(self, host, refresh_nodes=True): futures_lock = Lock() futures_results = [] - futures = set() def future_completed(future): with futures_lock: futures.discard(future) + if add_aborted: + return + try: futures_results.append(future.result()) except Exception as exc: @@ -2116,9 +2120,11 @@ def future_completed(future): if not have_future: self._finalize_add(host) except Exception: - if not have_future: - with host.lock: - host._currently_handling_node_addition = False + add_aborted = True + for future in tuple(futures): + future.cancel() + with host.lock: + host._currently_handling_node_addition = False raise def _finalize_add(self, host, set_up=True): diff --git a/cassandra/policies.py b/cassandra/policies.py index eb0f016b3c..e472251593 100644 --- a/cassandra/policies.py +++ b/cassandra/policies.py @@ -133,7 +133,10 @@ def _is_ignored_zero_token_host(self, host): child_policy = getattr(self, '_child_policy', None) if child_policy is not None: # Preserve child opt-outs through wrapper layers. - return bool(child_policy._is_ignored_zero_token_host(host)) + child_filter = getattr(child_policy, '_is_ignored_zero_token_host', None) + if child_filter is not None: + return bool(child_filter(host)) + return bool(getattr(child_policy, '_ignore_zero_token_hosts', self._ignore_zero_token_hosts)) return bool(self._ignore_zero_token_hosts) diff --git a/tests/unit/test_cluster.py b/tests/unit/test_cluster.py index 49208ac53e..b0eff3962b 100644 --- a/tests/unit/test_cluster.py +++ b/tests/unit/test_cluster.py @@ -11,13 +11,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import unittest - +from concurrent.futures import Future import logging import socket +import unittest +import uuid from unittest.mock import patch, Mock -import uuid from cassandra import ConsistencyLevel, DriverException, Timeout, Unavailable, RequestExecutionException, ReadTimeout, WriteTimeout, CoordinationFailure, ReadFailure, WriteFailure, FunctionFailure, AlreadyExists,\ InvalidRequest, Unauthorized, AuthenticationFailed, OperationTimedOut, UnsupportedOperation, RequestValidationException, ConfigurationException, ProtocolVersion @@ -101,6 +101,29 @@ def test_tuple_for_contact_points(self): assert cp.address == '127.0.0.3' assert cp.port == 9999 + def test_on_add_clears_in_progress_flag_when_later_session_add_fails(self): + cluster = Cluster(protocol_version=4) + host = Host("127.0.0.1", SimpleConvictionPolicy, host_id=uuid.uuid4()) + successful_session = Mock() + successful_session.add_or_renew_pool.return_value = Future() + successful_session.update_created_pools.return_value = set() + failing_session = Mock() + failing_session.add_or_renew_pool.side_effect = RuntimeError("pool add failed") + cluster.sessions = [successful_session, failing_session] + + try: + with pytest.raises(RuntimeError): + cluster.on_add(host, refresh_nodes=False) + + assert not host._currently_handling_node_addition + + with pytest.raises(RuntimeError): + cluster.on_add(host, refresh_nodes=False) + + assert successful_session.add_or_renew_pool.call_count == 2 + finally: + cluster.shutdown() + def test_invalid_contact_point_types(self): with pytest.raises(ValueError): Cluster(contact_points=[None], protocol_version=4, connect_timeout=1) diff --git a/tests/unit/test_policies.py b/tests/unit/test_policies.py index d362a1effe..37577fbe2a 100644 --- a/tests/unit/test_policies.py +++ b/tests/unit/test_policies.py @@ -900,6 +900,40 @@ def test_ignores_zero_token_hosts(self): assert list(policy.make_query_plan(None, query)) == [host] + def test_ignores_zero_token_hosts_with_legacy_child_policy(self): + host = make_host("127.0.0.1") + zero_token_host = make_host("127.0.0.2", is_zero_token=True) + + class LegacyChildPolicy(object): + def populate(self, cluster, hosts): + pass + + def distance(self, host): + return HostDistance.LOCAL + + def make_query_plan(self, working_keyspace=None, query=None): + return [zero_token_host, host] + + def on_up(self, host): + pass + + def on_down(self, host): + pass + + def on_add(self, host): + pass + + def on_remove(self, host): + pass + + cluster = Mock(spec=Cluster) + cluster.metadata = Mock(spec=Metadata) + + policy = TokenAwarePolicy(LegacyChildPolicy()) + policy.populate(cluster, [host, zero_token_host]) + + assert list(policy.make_query_plan()) == [host] + def test_shuffles_if_given_keyspace_and_routing_key(self): """ Test to validate the hosts are shuffled when `shuffle_replicas` is truthy From 882c276ab2465d5940bd87e0742e32c4fcf9811a Mon Sep 17 00:00:00 2001 From: Dmitry Kropachev Date: Tue, 5 May 2026 18:47:43 -0400 Subject: [PATCH 5/6] cluster: wait for all pool futures before marking host up --- cassandra/cluster.py | 8 +++-- tests/unit/test_cluster.py | 62 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 68 insertions(+), 2 deletions(-) diff --git a/cassandra/cluster.py b/cassandra/cluster.py index 661cdc07d3..40daee5cdb 100644 --- a/cassandra/cluster.py +++ b/cassandra/cluster.py @@ -1959,8 +1959,10 @@ def on_up(self, host): future = session.add_or_renew_pool(host, is_host_addition=False) if future is not None: have_future = True - future.add_done_callback(callback) futures.add(future) + + for future in tuple(futures): + future.add_done_callback(callback) except Exception: log.exception("Unexpected failure handling node %s being marked up:", host) for future in futures: @@ -2115,7 +2117,9 @@ def future_completed(future): if future is not None: have_future = True futures.add(future) - future.add_done_callback(future_completed) + + for future in tuple(futures): + future.add_done_callback(future_completed) if not have_future: self._finalize_add(host) diff --git a/tests/unit/test_cluster.py b/tests/unit/test_cluster.py index b0eff3962b..932c25c537 100644 --- a/tests/unit/test_cluster.py +++ b/tests/unit/test_cluster.py @@ -124,6 +124,68 @@ def test_on_add_clears_in_progress_flag_when_later_session_add_fails(self): finally: cluster.shutdown() + def test_on_add_waits_for_all_session_pool_futures_before_marking_host_up(self): + cluster = Cluster(protocol_version=4) + host = Host("127.0.0.1", SimpleConvictionPolicy, host_id=uuid.uuid4()) + completed_future = Future() + completed_future.set_result(True) + pending_future = Future() + first_session = Mock() + first_session.add_or_renew_pool.return_value = completed_future + second_session = Mock() + second_session.add_or_renew_pool.return_value = pending_future + listener = Mock() + cluster.sessions = [first_session, second_session] + cluster.register_listener(listener) + + try: + cluster.on_add(host, refresh_nodes=False) + + assert host.is_up is not True + listener.on_add.assert_not_called() + first_session.update_created_pools.assert_not_called() + second_session.update_created_pools.assert_not_called() + + pending_future.set_result(True) + + assert host.is_up is True + listener.on_add.assert_called_once_with(host) + first_session.update_created_pools.assert_called_once_with() + second_session.update_created_pools.assert_called_once_with() + finally: + cluster.shutdown() + + def test_on_up_waits_for_all_session_pool_futures_before_marking_host_up(self): + cluster = Cluster(protocol_version=4) + host = Host("127.0.0.1", SimpleConvictionPolicy, host_id=uuid.uuid4()) + completed_future = Future() + completed_future.set_result(True) + pending_future = Future() + first_session = Mock() + first_session.add_or_renew_pool.return_value = completed_future + second_session = Mock() + second_session.add_or_renew_pool.return_value = pending_future + listener = Mock() + cluster.sessions = [first_session, second_session] + cluster.register_listener(listener) + + try: + cluster.on_up(host) + + assert host.is_up is not True + listener.on_up.assert_not_called() + first_session.update_created_pools.assert_not_called() + second_session.update_created_pools.assert_not_called() + + pending_future.set_result(True) + + assert host.is_up is True + listener.on_up.assert_called_once_with(host) + first_session.update_created_pools.assert_called_once_with() + second_session.update_created_pools.assert_called_once_with() + finally: + cluster.shutdown() + def test_invalid_contact_point_types(self): with pytest.raises(ValueError): Cluster(contact_points=[None], protocol_version=4, connect_timeout=1) From 83007265c0e6e45ae310d845ad2667bc141f898d Mon Sep 17 00:00:00 2001 From: Dmitry Kropachev Date: Tue, 5 May 2026 19:15:27 -0400 Subject: [PATCH 6/6] policies: hide in-flight hosts from query plans --- cassandra/policies.py | 37 ++++++++++++++++++++------------ tests/unit/test_cluster.py | 43 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 67 insertions(+), 13 deletions(-) diff --git a/cassandra/policies.py b/cassandra/policies.py index e472251593..72f0332a1f 100644 --- a/cassandra/policies.py +++ b/cassandra/policies.py @@ -143,6 +143,17 @@ def _is_ignored_zero_token_host(self, host): def _filter_zero_token_hosts(self, hosts): return tuple(h for h in hosts if not self._is_ignored_zero_token_host(h)) + def _is_in_flight_host(self, host): + return ((getattr(host, '_currently_handling_node_up', False) is True or + getattr(host, '_currently_handling_node_addition', False) is True) and + getattr(host, 'is_up', None) is not True) + + def _is_ignored_query_plan_host(self, host): + return self._is_ignored_zero_token_host(host) or self._is_in_flight_host(host) + + def _filter_query_plan_hosts(self, hosts): + return tuple(h for h in hosts if not self._is_ignored_query_plan_host(h)) + def distance(self, host): """ Returns a measure of how remote a :class:`~.pool.Host` is in @@ -211,7 +222,7 @@ def make_query_plan(self, working_keyspace=None, query=None): pos = self._position self._position += 1 - hosts = self._filter_zero_token_hosts(self._live_hosts) + hosts = self._filter_query_plan_hosts(self._live_hosts) length = len(hosts) if length: pos %= length @@ -303,7 +314,7 @@ def make_query_plan(self, working_keyspace=None, query=None): pos = self._position self._position += 1 - local_live = self._filter_zero_token_hosts(self._dc_live_hosts.get(self.local_dc, ())) + local_live = self._filter_query_plan_hosts(self._dc_live_hosts.get(self.local_dc, ())) pos = (pos % len(local_live)) if local_live else 0 for host in islice(cycle(local_live), pos, pos + len(local_live)): yield host @@ -311,7 +322,7 @@ def make_query_plan(self, working_keyspace=None, query=None): # the dict can change, so get candidate DCs iterating over keys of a copy other_dcs = [dc for dc in self._dc_live_hosts.copy().keys() if dc != self.local_dc] for dc in other_dcs: - remote_live = self._filter_zero_token_hosts(self._dc_live_hosts.get(dc, ())) + remote_live = self._filter_query_plan_hosts(self._dc_live_hosts.get(dc, ())) for host in remote_live[:self.used_hosts_per_remote_dc]: yield host @@ -422,14 +433,14 @@ def make_query_plan(self, working_keyspace=None, query=None): pos = self._position self._position += 1 - local_rack_live = self._filter_zero_token_hosts(self._live_hosts.get((self.local_dc, self.local_rack), ())) + local_rack_live = self._filter_query_plan_hosts(self._live_hosts.get((self.local_dc, self.local_rack), ())) pos = (pos % len(local_rack_live)) if local_rack_live else 0 # Slice the cyclic iterator to start from pos and include the next len(local_live) elements # This ensures we get exactly one full cycle starting from pos for host in islice(cycle(local_rack_live), pos, pos + len(local_rack_live)): yield host - local_live = [host for host in self._filter_zero_token_hosts(self._dc_live_hosts.get(self.local_dc, ())) + local_live = [host for host in self._filter_query_plan_hosts(self._dc_live_hosts.get(self.local_dc, ())) if host.rack != self.local_rack] pos = (pos % len(local_live)) if local_live else 0 for host in islice(cycle(local_live), pos, pos + len(local_live)): @@ -438,7 +449,7 @@ def make_query_plan(self, working_keyspace=None, query=None): # the dict can change, so get candidate DCs iterating over keys of a copy for dc, remote_live in self._dc_live_hosts.copy().items(): if dc != self.local_dc: - remote_live = self._filter_zero_token_hosts(remote_live) + remote_live = self._filter_query_plan_hosts(remote_live) for host in remote_live[:self.used_hosts_per_remote_dc]: yield host @@ -531,7 +542,7 @@ def make_query_plan(self, working_keyspace=None, query=None): child = self._child_policy if query is None or query.routing_key is None or keyspace is None: for host in child.make_query_plan(keyspace, query): - if not self._is_ignored_zero_token_host(host): + if not self._is_ignored_query_plan_host(host): yield host return @@ -553,7 +564,7 @@ def make_query_plan(self, working_keyspace=None, query=None): def yield_in_order(hosts): for distance in [HostDistance.LOCAL_RACK, HostDistance.LOCAL, HostDistance.REMOTE]: for replica in hosts: - if (not self._is_ignored_zero_token_host(replica) and + if (not self._is_ignored_query_plan_host(replica) and replica.is_up and child.distance(replica) == distance): yield replica @@ -561,7 +572,7 @@ def yield_in_order(hosts): yield from yield_in_order(replicas) # yield rest of the cluster: local_rack, local, remote yield from yield_in_order([host for host in child.make_query_plan(keyspace, query) - if host not in replicas and not self._is_ignored_zero_token_host(host)]) + if host not in replicas and not self._is_ignored_query_plan_host(host)]) def on_up(self, *args, **kwargs): return self._child_policy.on_up(*args, **kwargs) @@ -739,7 +750,7 @@ def make_query_plan(self, working_keyspace=None, query=None): working_keyspace=working_keyspace, query=query ) for host in child_qp: - if not self._is_ignored_zero_token_host(host) and self.predicate(host): + if not self._is_ignored_query_plan_host(host) and self.predicate(host): yield host def check_supported(self): @@ -1390,14 +1401,14 @@ def make_query_plan(self, working_keyspace=None, query=None): target_host = self._cluster_metadata.get_host(addr) child = self._child_policy - if target_host and target_host.is_up and not self._is_ignored_zero_token_host(target_host): + if target_host and target_host.is_up and not self._is_ignored_query_plan_host(target_host): yield target_host for h in child.make_query_plan(keyspace, query): - if h != target_host and not self._is_ignored_zero_token_host(h): + if h != target_host and not self._is_ignored_query_plan_host(h): yield h else: for h in child.make_query_plan(keyspace, query): - if not self._is_ignored_zero_token_host(h): + if not self._is_ignored_query_plan_host(h): yield h diff --git a/tests/unit/test_cluster.py b/tests/unit/test_cluster.py index 932c25c537..39df2a8bc4 100644 --- a/tests/unit/test_cluster.py +++ b/tests/unit/test_cluster.py @@ -155,6 +155,27 @@ def test_on_add_waits_for_all_session_pool_futures_before_marking_host_up(self): finally: cluster.shutdown() + def test_on_add_excludes_host_from_query_plan_until_pool_futures_complete(self): + cluster = Cluster(protocol_version=4) + host = Host("127.0.0.1", SimpleConvictionPolicy, datacenter="dc1", rack="rack1", host_id=uuid.uuid4()) + pending_future = Future() + session = Mock() + session.add_or_renew_pool.return_value = pending_future + session.update_created_pools.return_value = set() + cluster.sessions = [session] + + try: + cluster.on_add(host, refresh_nodes=False) + + load_balancer = cluster.profile_manager.default.load_balancing_policy + assert host not in list(load_balancer.make_query_plan()) + + pending_future.set_result(True) + + assert list(load_balancer.make_query_plan()) == [host] + finally: + cluster.shutdown() + def test_on_up_waits_for_all_session_pool_futures_before_marking_host_up(self): cluster = Cluster(protocol_version=4) host = Host("127.0.0.1", SimpleConvictionPolicy, host_id=uuid.uuid4()) @@ -186,6 +207,28 @@ def test_on_up_waits_for_all_session_pool_futures_before_marking_host_up(self): finally: cluster.shutdown() + def test_on_up_excludes_host_from_query_plan_until_pool_futures_complete(self): + cluster = Cluster(protocol_version=4) + host = Host("127.0.0.1", SimpleConvictionPolicy, datacenter="dc1", rack="rack1", host_id=uuid.uuid4()) + host.set_down() + pending_future = Future() + session = Mock() + session.add_or_renew_pool.return_value = pending_future + session.update_created_pools.return_value = set() + cluster.sessions = [session] + + try: + cluster.on_up(host) + + load_balancer = cluster.profile_manager.default.load_balancing_policy + assert host not in list(load_balancer.make_query_plan()) + + pending_future.set_result(True) + + assert list(load_balancer.make_query_plan()) == [host] + finally: + cluster.shutdown() + def test_invalid_contact_point_types(self): with pytest.raises(ValueError): Cluster(contact_points=[None], protocol_version=4, connect_timeout=1)