Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion benchmarks/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def setup(options):
try:
session.execute("""
CREATE KEYSPACE %s
WITH replication = { 'class': 'SimpleStrategy', 'replication_factor': '2' }
WITH replication = { 'class': 'NetworkTopologyStrategy', 'replication_factor': '2' }
""" % options.keyspace)

log.debug("Setting keyspace...")
Expand Down
6 changes: 3 additions & 3 deletions cassandra/cqlengine/management.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def _get_context(keyspaces, connections):

def create_keyspace_simple(name, replication_factor, durable_writes=True, connections=None):
"""
Creates a keyspace with SimpleStrategy for replica placement
Creates a keyspace with NetworkTopologyStrategy for replica placement

If the keyspace already exists, it will not be modified.

Expand All @@ -66,11 +66,11 @@ def create_keyspace_simple(name, replication_factor, durable_writes=True, connec
*There are plans to guard schema-modifying functions with an environment-driven conditional.*

:param str name: name of keyspace to create
:param int replication_factor: keyspace replication factor, used with :attr:`~.SimpleStrategy`
:param int replication_factor: keyspace replication factor, used with :attr:`~.NetworkTopologyStrategy`
:param bool durable_writes: Write log is bypassed if set to False
:param list connections: List of connection names
"""
_create_keyspace(name, durable_writes, 'SimpleStrategy',
_create_keyspace(name, durable_writes, 'NetworkTopologyStrategy',
{'replication_factor': replication_factor}, connections=connections)


Expand Down
179 changes: 148 additions & 31 deletions cassandra/io/asyncioreactor.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@
asyncio.run_coroutine_threadsafe
except AttributeError:
raise ImportError(
'Cannot use asyncioreactor without access to '
'asyncio.run_coroutine_threadsafe (added in 3.4.6 and 3.5.1)'
"Cannot use asyncioreactor without access to "
"asyncio.run_coroutine_threadsafe (added in 3.4.6 and 3.5.1)"
)


Expand All @@ -38,12 +38,12 @@ class AsyncioTimer(object):

@property
def end(self):
raise NotImplementedError('{} is not compatible with TimerManager and '
'does not implement .end()')
raise NotImplementedError(
"{} is not compatible with TimerManager and does not implement .end()"
)

def __init__(self, timeout, callback, loop):
delayed = self._call_delayed_coro(timeout=timeout,
callback=callback)
delayed = self._call_delayed_coro(timeout=timeout, callback=callback)
self._handle = asyncio.run_coroutine_threadsafe(delayed, loop=loop)

@staticmethod
Expand All @@ -63,17 +63,51 @@ def cancel(self):
def finish(self):
# connection.Timer method not implemented here because we can't inspect
# the Handle returned from call_later
raise NotImplementedError('{} is not compatible with TimerManager and '
'does not implement .finish()')
raise NotImplementedError(
"{} is not compatible with TimerManager and does not implement .finish()"
)


class _AsyncioProtocol(asyncio.Protocol):
"""
Protocol adapter for asyncio SSL connections. Bridges asyncio's
transport/protocol API back to AsyncioConnection's buffer processing.
"""

def __init__(self, connection):
self._connection = connection
self.transport = None

def connection_made(self, transport):
self.transport = transport

def data_received(self, data):
conn = self._connection
conn._iobuf.write(data)
if conn._iobuf.tell():
conn.process_io_buffer()

def connection_lost(self, exc):
conn = self._connection
if exc:
log.debug("Connection %s lost: %s", conn, exc)
conn.defunct(exc)
else:
log.debug("Connection %s closed by server", conn)
conn.close()

def eof_received(self):
return False


class AsyncioConnection(Connection):
"""
An experimental implementation of :class:`.Connection` that uses the
``asyncio`` module in the Python standard library for its event loop.
An implementation of :class:`.Connection` that uses the ``asyncio``
module in the Python standard library for its event loop.

Note that it requires ``asyncio`` features that were only introduced in the
3.4 line in 3.4.6, and in the 3.5 line in 3.5.1.
Supports SSL connections via asyncio's native TLS transport, which
avoids the incompatibility between ``ssl.SSLSocket`` and asyncio's
low-level socket methods (``sock_sendall``, ``sock_recv``).
"""

_loop = None
Expand All @@ -88,26 +122,102 @@ class AsyncioConnection(Connection):
def __init__(self, *args, **kwargs):
Connection.__init__(self, *args, **kwargs)
self._background_tasks = set()
self._transport = None
self._protocol = _AsyncioProtocol(self) if self.ssl_context else None
self._using_ssl = bool(self.ssl_context)
self._ssl_ready = asyncio.Event() if self.ssl_context else None

self._connect_socket()
self._socket.setblocking(0)
loop_args = dict()
if sys.version_info[0] == 3 and sys.version_info[1] < 10:
loop_args['loop'] = self._loop
loop_args["loop"] = self._loop
self._write_queue = asyncio.Queue(**loop_args)
self._write_queue_lock = asyncio.Lock(**loop_args)

# see initialize_reactor -- loop is running in a separate thread, so we
# have to use a threadsafe call
self._read_watcher = asyncio.run_coroutine_threadsafe(
self.handle_read(), loop=self._loop
)
if self._using_ssl:
# For SSL: set up asyncio transport/protocol, then start writer
self._read_watcher = asyncio.run_coroutine_threadsafe(
self._setup_ssl_and_run(), loop=self._loop
)
else:
# For non-SSL: use low-level sock_sendall/sock_recv as before
self._read_watcher = asyncio.run_coroutine_threadsafe(
self.handle_read(), loop=self._loop
)
self._write_watcher = asyncio.run_coroutine_threadsafe(
self.handle_write(), loop=self._loop
)
self._send_options_message()

def _connect_socket(self):
"""
Override base class to skip SSL wrapping of the socket.
For SSL connections, the plain TCP socket is connected here, and TLS
is set up later via asyncio's native SSL transport in _setup_ssl_and_run().
"""
sockerr = None
addresses = self._get_socket_addresses()
for af, socktype, proto, _, sockaddr in addresses:
try:
self._socket = self._socket_impl.socket(af, socktype, proto)
# Do NOT wrap with ssl_context here -- asyncio will handle TLS
self._socket.settimeout(self.connect_timeout)
self._initiate_connection(sockaddr)
self._socket.settimeout(None)

local_addr = self._socket.getsockname()
log.debug("Connection %s: '%s' -> '%s'", id(self), local_addr, sockaddr)
sockerr = None
break
except socket.error as err:
if self._socket:
self._socket.close()
self._socket = None
sockerr = err

if sockerr:
raise socket.error(
sockerr.errno,
"Tried connecting to %s. Last error: %s"
% ([a[4] for a in addresses], sockerr.strerror or sockerr),
)

async def _setup_ssl_and_run(self):
"""
Upgrade the plain TCP connection to TLS using asyncio's native SSL
transport, then continuously read data via the protocol callbacks.
"""
try:
ssl_context = self.ssl_context
server_hostname = None
if self.ssl_options:
server_hostname = self.ssl_options.get("server_hostname", None)
if not server_hostname:
# asyncio's create_connection requires server_hostname when
# ssl= is set, even if check_hostname is False
server_hostname = self.endpoint.address

transport, protocol = await self._loop.create_connection(
lambda: self._protocol,
sock=self._socket,
ssl=ssl_context,
server_hostname=server_hostname,
)
self._transport = transport

if self._check_hostname:
self._validate_hostname()

self._ssl_ready.set()
except Exception as exc:
log.debug("SSL setup failed for %s: %s", self, exc)
self.defunct(exc)
# Unblock handle_write so it can observe the defunct state and exit
self._ssl_ready.set()
return

@classmethod
def initialize_reactor(cls):
Expand All @@ -126,8 +236,9 @@ def initialize_reactor(cls):
cls._loop = asyncio.new_event_loop()
# daemonize so the loop will be shut down on interpreter
# shutdown
cls._loop_thread = Thread(target=cls._loop.run_forever,
daemon=True, name="asyncio_thread")
cls._loop_thread = Thread(
target=cls._loop.run_forever, daemon=True, name="asyncio_thread"
)
cls._loop_thread.start()

@classmethod
Expand All @@ -142,17 +253,18 @@ def close(self):

# close from the loop thread to avoid races when removing file
# descriptors
asyncio.run_coroutine_threadsafe(
self._close(), loop=self._loop
)
asyncio.run_coroutine_threadsafe(self._close(), loop=self._loop)

async def _close(self):
log.debug("Closing connection (%s) to %s" % (id(self), self.endpoint))
if self._write_watcher:
self._write_watcher.cancel()
if self._read_watcher:
self._read_watcher.cancel()
if self._socket:
if self._transport:
self._transport.close()
self._transport = None
elif self._socket:
self._loop.remove_writer(self._socket.fileno())
self._loop.remove_reader(self._socket.fileno())
self._socket.close()
Expand All @@ -172,15 +284,12 @@ def push(self, data):
if len(data) > buff_size:
chunks = []
for i in range(0, len(data), buff_size):
chunks.append(data[i:i + buff_size])
chunks.append(data[i : i + buff_size])
else:
chunks = [data]

if self._loop_thread != threading.current_thread():
asyncio.run_coroutine_threadsafe(
self._push_msg(chunks),
loop=self._loop
)
asyncio.run_coroutine_threadsafe(self._push_msg(chunks), loop=self._loop)
else:
# avoid races/hangs by just scheduling this, not using threadsafe
task = self._loop.create_task(self._push_msg(chunks))
Expand All @@ -194,13 +303,22 @@ async def _push_msg(self, chunks):
for chunk in chunks:
self._write_queue.put_nowait(chunk)


async def handle_write(self):
# For SSL connections, wait until the TLS handshake completes
if self._ssl_ready:
await self._ssl_ready.wait()
if self.is_defunct:
return
while True:
try:
next_msg = await self._write_queue.get()
if next_msg:
await self._loop.sock_sendall(self._socket, next_msg)
if self._transport:
# SSL: use asyncio transport (handles TLS transparently)
self._transport.write(next_msg)
else:
# Non-SSL: use low-level socket API
await self._loop.sock_sendall(self._socket, next_msg)
except socket.error as err:
log.debug("Exception in send for %s: %s", self, err)
self.defunct(err)
Expand All @@ -223,8 +341,7 @@ async def handle_read(self):
await asyncio.sleep(0)
continue
except socket.error as err:
log.debug("Exception during socket recv for %s: %s",
self, err)
log.debug("Exception during socket recv for %s: %s", self, err)
self.defunct(err)
return # leave the read loop
except asyncio.CancelledError:
Expand Down
2 changes: 1 addition & 1 deletion docs/scylla-specific.rst
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ New Error Types
session = cluster.connect()
session.execute("""
CREATE KEYSPACE IF NOT EXISTS keyspace1
WITH replication = {'class': 'SimpleStrategy', 'replication_factor': '1'}
WITH replication = {'class': 'NetworkTopologyStrategy', 'replication_factor': '1'}
""")
session.execute("USE keyspace1")
Expand Down
2 changes: 1 addition & 1 deletion examples/concurrent_executions/execute_async_with_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
session = cluster.connect()

session.execute(("CREATE KEYSPACE IF NOT EXISTS examples "
"WITH replication = {'class': 'SimpleStrategy', 'replication_factor': '1' }"))
"WITH replication = {'class': 'NetworkTopologyStrategy', 'replication_factor': '1' }"))
session.execute("USE examples")
session.execute("CREATE TABLE IF NOT EXISTS tbl_sample_kv (id uuid, value text, PRIMARY KEY (id))")
prepared_insert = session.prepare("INSERT INTO tbl_sample_kv (id, value) VALUES (?, ?)")
Expand Down
2 changes: 1 addition & 1 deletion examples/concurrent_executions/execute_with_threads.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
session = cluster.connect()

session.execute(("CREATE KEYSPACE IF NOT EXISTS examples "
"WITH replication = {'class': 'SimpleStrategy', 'replication_factor': '1' }"))
"WITH replication = {'class': 'NetworkTopologyStrategy', 'replication_factor': '1' }"))
session.execute("USE examples")
session.execute("CREATE TABLE IF NOT EXISTS tbl_sample_kv (id uuid, value text, PRIMARY KEY (id))")
prepared_insert = session.prepare("INSERT INTO tbl_sample_kv (id, value) VALUES (?, ?)")
Expand Down
2 changes: 1 addition & 1 deletion examples/example_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def main():
log.info("creating keyspace...")
session.execute("""
CREATE KEYSPACE IF NOT EXISTS %s
WITH replication = { 'class': 'SimpleStrategy', 'replication_factor': '2' }
WITH replication = { 'class': 'NetworkTopologyStrategy', 'replication_factor': '2' }
""" % KEYSPACE)

log.info("setting keyspace...")
Expand Down
4 changes: 2 additions & 2 deletions tests/integration/cqlengine/connections/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,9 +76,9 @@ def setUpClass(cls):
super(SeveralConnectionsTest, cls).setUpClass()
cls.setup_cluster = TestCluster()
cls.setup_session = cls.setup_cluster.connect()
ddl = "CREATE KEYSPACE {0} WITH replication = {{'class': 'SimpleStrategy', 'replication_factor': '{1}'}}".format(cls.keyspace1, 1)
ddl = "CREATE KEYSPACE {0} WITH replication = {{'class': 'NetworkTopologyStrategy', 'replication_factor': '{1}'}}".format(cls.keyspace1, 1)
execute_with_long_wait_retry(cls.setup_session, ddl)
ddl = "CREATE KEYSPACE {0} WITH replication = {{'class': 'SimpleStrategy', 'replication_factor': '{1}'}}".format(cls.keyspace2, 1)
ddl = "CREATE KEYSPACE {0} WITH replication = {{'class': 'NetworkTopologyStrategy', 'replication_factor': '{1}'}}".format(cls.keyspace2, 1)
execute_with_long_wait_retry(cls.setup_session, ddl)

@classmethod
Expand Down
2 changes: 1 addition & 1 deletion tests/integration/long/test_failure_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ def test_write_failures_from_coordinator(self):
self._perform_cql_statement(
"""
CREATE KEYSPACE testksfail
WITH replication = {'class': 'SimpleStrategy', 'replication_factor': '3'}
WITH replication = {'class': 'NetworkTopologyStrategy', 'replication_factor': '3'}
""", consistency_level=ConsistencyLevel.ALL, expected_exception=None)

# create table
Expand Down
2 changes: 1 addition & 1 deletion tests/integration/long/test_policies.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def test_should_rethrow_on_unvailable_with_default_policy_if_cas(self):
cluster = TestCluster(execution_profiles={EXEC_PROFILE_DEFAULT: ep})
session = cluster.connect()

session.execute("CREATE KEYSPACE test_retry_policy_cas WITH replication = {'class':'SimpleStrategy','replication_factor': 3};")
session.execute("CREATE KEYSPACE test_retry_policy_cas WITH replication = {'class':'NetworkTopologyStrategy','replication_factor': 3};")
session.execute("CREATE TABLE test_retry_policy_cas.t (id int PRIMARY KEY, data text);")
session.execute('INSERT INTO test_retry_policy_cas.t ("id", "data") VALUES (%(0)s, %(1)s)', {'0': 42, '1': 'testing'})

Expand Down
Loading
Loading