From e89340d3613d538cf9a35972f2ecc8188392728e Mon Sep 17 00:00:00 2001 From: Dmitry Kropachev Date: Tue, 5 May 2026 17:03:59 -0400 Subject: [PATCH] control-connection: reconnect after discounted down signal --- cassandra/cluster.py | 16 ++++++++------- tests/unit/test_control_connection.py | 29 +++++++++++++++++++++++++-- 2 files changed, 36 insertions(+), 9 deletions(-) diff --git a/cassandra/cluster.py b/cassandra/cluster.py index 5e7a68bc1c..b35b437af4 100644 --- a/cassandra/cluster.py +++ b/cassandra/cluster.py @@ -2020,7 +2020,7 @@ def on_down(self, host, is_host_addition, expect_host_to_be_down=False): Intended for internal use only. """ if self.is_shutdown: - return + return False with host.lock: was_up = host.is_up @@ -2035,14 +2035,15 @@ def on_down(self, host, is_host_addition, expect_host_to_be_down=False): if pool_state: connected |= pool_state['open_count'] > 0 if connected: - return + return False host.set_down() if (not was_up and not expect_host_to_be_down) or host.is_currently_reconnecting(): - return + return False log.warning("Host %s has been marked down", host) self.on_down_potentially_blocking(host, is_host_addition) + return True def on_add(self, host, refresh_nodes=True): if self.is_shutdown: @@ -2134,8 +2135,8 @@ def on_remove(self, host): def signal_connection_failure(self, host, connection_exc, is_host_addition, expect_host_to_be_down=False): is_down = host.signal_connection_failure(connection_exc) if is_down: - self.on_down(host, is_host_addition, expect_host_to_be_down) - return is_down + return self.on_down(host, is_host_addition, expect_host_to_be_down) + return False def add_host(self, endpoint, datacenter=None, rack=None, signal=True, refresh_nodes=True, host_id=None): """ @@ -4226,9 +4227,10 @@ def _signal_error(self): # host may be None if it's already been removed, but that indicates # that errors have already been reported, so we're fine if host: - self._cluster.signal_connection_failure( + is_down = self._cluster.signal_connection_failure( host, self._connection.last_error, is_host_addition=False) - return + if is_down: + return # if the connection is not defunct or the host already left, reconnect # manually diff --git a/tests/unit/test_control_connection.py b/tests/unit/test_control_connection.py index 037d4a8888..1cc7d7a19c 100644 --- a/tests/unit/test_control_connection.py +++ b/tests/unit/test_control_connection.py @@ -13,15 +13,16 @@ # limitations under the License. import unittest +import uuid from concurrent.futures import ThreadPoolExecutor from unittest.mock import Mock, ANY, call from cassandra import OperationTimedOut, SchemaTargetType, SchemaChangeType from cassandra.protocol import ResultMessage, RESULT_KIND_ROWS -from cassandra.cluster import ControlConnection, _Scheduler, ProfileManager, EXEC_PROFILE_DEFAULT, ExecutionProfile +from cassandra.cluster import Cluster, ControlConnection, _Scheduler, ProfileManager, EXEC_PROFILE_DEFAULT, ExecutionProfile from cassandra.pool import Host -from cassandra.connection import EndPoint, DefaultEndPoint, DefaultEndPointFactory +from cassandra.connection import EndPoint, DefaultEndPoint, DefaultEndPointFactory, ConnectionException from cassandra.policies import (SimpleConvictionPolicy, RoundRobinPolicy, ConstantReconnectionPolicy, IdentityTranslator) @@ -301,6 +302,30 @@ def test_wait_for_schema_agreement_none_timeout(self): cc._time = self.time assert cc.wait_for_schema_agreement() + def test_signal_error_reconnects_when_host_down_signal_is_discounted(self): + cluster = Cluster(load_balancing_policy=RoundRobinPolicy(), protocol_version=4) + self.addCleanup(cluster.shutdown) + + host = Host(DefaultEndPoint("127.0.0.1"), SimpleConvictionPolicy, host_id=uuid.uuid4()) + host.set_up() + cluster.metadata.add_or_return_host(host) + + session = Mock() + session.get_pool_state.return_value = {host: {"open_count": 1}} + cluster.sessions.add(session) + + connection_error = ConnectionException("control connection failed", endpoint=host.endpoint) + cluster.control_connection._connection = Mock( + endpoint=host.endpoint, + is_defunct=True, + last_error=connection_error) + cluster.control_connection.reconnect = Mock() + + cluster.control_connection._signal_error() + + assert host.is_up is True + cluster.control_connection.reconnect.assert_called_once_with() + def test_refresh_nodes_and_tokens(self): self.control_connection.refresh_node_list_and_token_map() meta = self.cluster.metadata