Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 16 additions & 8 deletions cassandra/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -1975,6 +1975,10 @@ def on_up(self, host):
with host.lock:
host.set_up()
host._currently_handling_node_up = False
for listener in self.listeners:
listener.on_up(host)
for session in tuple(self.sessions):
session.update_created_pools()

# for testing purposes
return futures
Expand Down Expand Up @@ -2020,7 +2024,7 @@ def on_down(self, host, is_host_addition, expect_host_to_be_down=False):
Intended for internal use only.
"""
if self.is_shutdown:
return
return False

with host.lock:
was_up = host.is_up
Expand All @@ -2035,14 +2039,15 @@ def on_down(self, host, is_host_addition, expect_host_to_be_down=False):
if pool_state:
connected |= pool_state['open_count'] > 0
if connected:
return
return False

host.set_down()
if (not was_up and not expect_host_to_be_down) or host.is_currently_reconnecting():
return
return False
log.warning("Host %s has been marked down", host)

self.on_down_potentially_blocking(host, is_host_addition)
return True

def on_add(self, host, refresh_nodes=True):
if self.is_shutdown:
Expand Down Expand Up @@ -2134,8 +2139,8 @@ def on_remove(self, host):
def signal_connection_failure(self, host, connection_exc, is_host_addition, expect_host_to_be_down=False):
is_down = host.signal_connection_failure(connection_exc)
if is_down:
self.on_down(host, is_host_addition, expect_host_to_be_down)
return is_down
return self.on_down(host, is_host_addition, expect_host_to_be_down)
return False

def add_host(self, endpoint, datacenter=None, rack=None, signal=True, refresh_nodes=True, host_id=None):
"""
Expand Down Expand Up @@ -3315,7 +3320,9 @@ def update_created_pools(self):
# we don't eagerly set is_up on previously ignored hosts. None is included here
# to allow us to attempt connections to hosts that have gone from ignored to something
# else.
if distance != HostDistance.IGNORED and host.is_up in (True, None):
if (distance != HostDistance.IGNORED and
host.is_up in (True, None) and
not getattr(host, '_currently_handling_node_up', False)):
future = self.add_or_renew_pool(host, False)
elif distance != pool.host_distance:
# the distance has changed
Expand Down Expand Up @@ -4226,9 +4233,10 @@ def _signal_error(self):
# host may be None if it's already been removed, but that indicates
# that errors have already been reported, so we're fine
if host:
self._cluster.signal_connection_failure(
is_down = self._cluster.signal_connection_failure(
host, self._connection.last_error, is_host_addition=False)
return
if is_down:
return

# if the connection is not defunct or the host already left, reconnect
# manually
Expand Down
44 changes: 44 additions & 0 deletions tests/unit/test_cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

import logging
import socket
from concurrent.futures import Future

from unittest.mock import patch, Mock
import uuid
Expand Down Expand Up @@ -229,6 +230,27 @@ def test_connection_factory_passes_compression_kwarg(self):
assert factory.call_args.kwargs['compression'] == expected
assert cluster.compression == expected

def test_on_up_without_pool_futures_notifies_listeners(self):
cluster = Cluster(load_balancing_policy=RoundRobinPolicy(), protocol_version=4)
self.addCleanup(cluster.shutdown)

host = Host("127.0.0.1", SimpleConvictionPolicy, host_id=uuid.uuid4())
host.set_down()
cluster.metadata.add_or_return_host(host)

session = Mock()
session.add_or_renew_pool.return_value = None
cluster.sessions.add(session)

listener = Mock()
cluster.register_listener(listener)

cluster.on_up(host)

assert host.is_up is True
listener.on_up.assert_called_once_with(host)
session.update_created_pools.assert_called_once_with()


class SchedulerTest(unittest.TestCase):
# TODO: this suite could be expanded; for now just adding a test covering a ticket
Expand Down Expand Up @@ -339,6 +361,28 @@ def test_set_keyspace_escapes_quotes(self, *_):
assert query == 'USE simple_ks', (
"Simple keyspace names should not be quoted, got: %r" % query)

def test_update_created_pools_skips_host_with_node_up_in_progress(self):
cluster = Cluster(load_balancing_policy=RoundRobinPolicy(), protocol_version=4)
self.addCleanup(cluster.shutdown)

host = Host("127.0.0.1", SimpleConvictionPolicy, host_id=uuid.uuid4())
cluster.metadata.add_or_return_host(host)
cluster.profile_manager.populate(cluster, [host])
cluster.profile_manager.on_up(host)

completed = Future()
completed.set_result(True)

with patch.object(Session, "add_or_renew_pool", return_value=completed) as add_or_renew_pool:
session = Session(cluster, [host])
add_or_renew_pool.reset_mock()

session._pools = {}
host._currently_handling_node_up = True

assert session.update_created_pools() == set()
add_or_renew_pool.assert_not_called()

class ProtocolVersionTests(unittest.TestCase):

def test_protocol_downgrade_test(self):
Expand Down
29 changes: 27 additions & 2 deletions tests/unit/test_control_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,16 @@
# limitations under the License.

import unittest
import uuid

from concurrent.futures import ThreadPoolExecutor
from unittest.mock import Mock, ANY, call

from cassandra import OperationTimedOut, SchemaTargetType, SchemaChangeType
from cassandra.protocol import ResultMessage, RESULT_KIND_ROWS
from cassandra.cluster import ControlConnection, _Scheduler, ProfileManager, EXEC_PROFILE_DEFAULT, ExecutionProfile
from cassandra.cluster import Cluster, ControlConnection, _Scheduler, ProfileManager, EXEC_PROFILE_DEFAULT, ExecutionProfile
from cassandra.pool import Host
from cassandra.connection import EndPoint, DefaultEndPoint, DefaultEndPointFactory
from cassandra.connection import EndPoint, DefaultEndPoint, DefaultEndPointFactory, ConnectionException
from cassandra.policies import (SimpleConvictionPolicy, RoundRobinPolicy,
ConstantReconnectionPolicy, IdentityTranslator)

Expand Down Expand Up @@ -301,6 +302,30 @@ def test_wait_for_schema_agreement_none_timeout(self):
cc._time = self.time
assert cc.wait_for_schema_agreement()

def test_signal_error_reconnects_when_host_down_signal_is_discounted(self):
cluster = Cluster(load_balancing_policy=RoundRobinPolicy(), protocol_version=4)
self.addCleanup(cluster.shutdown)

host = Host(DefaultEndPoint("127.0.0.1"), SimpleConvictionPolicy, host_id=uuid.uuid4())
host.set_up()
cluster.metadata.add_or_return_host(host)

session = Mock()
session.get_pool_state.return_value = {host: {"open_count": 1}}
cluster.sessions.add(session)

connection_error = ConnectionException("control connection failed", endpoint=host.endpoint)
cluster.control_connection._connection = Mock(
endpoint=host.endpoint,
is_defunct=True,
last_error=connection_error)
cluster.control_connection.reconnect = Mock()

cluster.control_connection._signal_error()

assert host.is_up is True
cluster.control_connection.reconnect.assert_called_once_with()

def test_refresh_nodes_and_tokens(self):
self.control_connection.refresh_node_list_and_token_map()
meta = self.cluster.metadata
Expand Down
Loading