Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 102 additions & 0 deletions cassandra/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,10 @@ def _connection_reduce_fn(val,import_fn):

_NOT_SET = object()

# TLS session cache defaults
_DEFAULT_TLS_SESSION_CACHE_SIZE = 100
_DEFAULT_TLS_SESSION_CACHE_TTL = 3600 # 1 hour in seconds


class NoHostAvailable(Exception):
"""
Expand Down Expand Up @@ -875,6 +879,72 @@ def default_retry_policy(self, policy):
.. versionadded:: 3.17.0
"""

tls_session_cache_enabled = True
"""
Enable or disable TLS session caching for faster reconnections.
When enabled (default), TLS sessions are cached and reused for subsequent
connections to the same endpoint, reducing handshake latency.

Set to False to disable session caching entirely.

.. versionadded:: 3.30.0
"""

tls_session_cache_size = _DEFAULT_TLS_SESSION_CACHE_SIZE
"""
Maximum number of TLS sessions to cache. Default is 100.
When the cache is full, the least recently used session is evicted.

.. versionadded:: 3.30.0
"""

tls_session_cache_ttl = _DEFAULT_TLS_SESSION_CACHE_TTL
"""
Time-to-live for cached TLS sessions in seconds. Default is 3600 (1 hour).
Sessions older than this value will not be reused.

.. versionadded:: 3.30.0
"""

tls_session_cache_options = None
"""
Advanced TLS session cache configuration. Can be set to:

- An instance of :class:`~cassandra.tls.TLSSessionCacheOptions` for
fine-grained control over session caching behavior (e.g., cache_by_host_only option).
- An instance of :class:`~cassandra.tls.TLSSessionCache` (or a custom subclass)
for complete control over session caching implementation.

If None (default), a cache is created using :attr:`~.tls_session_cache_size`
and :attr:`~.tls_session_cache_ttl` when SSL/TLS is enabled.

This option takes precedence over the individual tls_session_cache_* parameters.

Example with options::

from cassandra.tls import TLSSessionCacheOptions

# Cache by host only (ignoring port)
options = TLSSessionCacheOptions(
max_size=200,
ttl=7200,
cache_by_host_only=True
)
cluster = Cluster(ssl_context=ssl_context, tls_session_cache_options=options)

Example with custom cache::

from cassandra.tls import TLSSessionCache

class MyCustomCache(TLSSessionCache):
# Custom implementation
pass

cluster = Cluster(ssl_context=ssl_context, tls_session_cache_options=MyCustomCache())

.. versionadded:: 3.30.0
"""

sockopts = None
"""
An optional list of tuples which will be used as arguments to
Expand Down Expand Up @@ -1204,6 +1274,10 @@ def __init__(self,
idle_heartbeat_timeout=30,
no_compact=False,
ssl_context=None,
tls_session_cache_enabled=True,
tls_session_cache_size=_DEFAULT_TLS_SESSION_CACHE_SIZE,
tls_session_cache_ttl=_DEFAULT_TLS_SESSION_CACHE_TTL,
tls_session_cache_options=None,
endpoint_factory=None,
application_name=None,
application_version=None,
Expand Down Expand Up @@ -1420,6 +1494,33 @@ def __init__(self,

self.ssl_options = ssl_options
self.ssl_context = ssl_context
self.tls_session_cache_enabled = tls_session_cache_enabled
self.tls_session_cache_size = tls_session_cache_size
self.tls_session_cache_ttl = tls_session_cache_ttl
self.tls_session_cache_options = tls_session_cache_options

# Initialize TLS session cache if SSL is enabled and caching is enabled
self._tls_session_cache = None
if (ssl_context or ssl_options) and tls_session_cache_enabled:
from cassandra.tls import TLSSessionCache, TLSSessionCacheOptions

if tls_session_cache_options is not None:
# Check if it's a TLSSessionCache instance (use directly)
# or TLSSessionCacheOptions (use create_cache())
if isinstance(tls_session_cache_options, TLSSessionCache):
self._tls_session_cache = tls_session_cache_options
else:
# Assume it's TLSSessionCacheOptions
self._tls_session_cache = tls_session_cache_options.create_cache()
else:
# Create default cache from individual parameters
cache_options = TLSSessionCacheOptions(
max_size=tls_session_cache_size,
ttl=tls_session_cache_ttl,
cache_by_host_only=False
)
self._tls_session_cache = cache_options.create_cache()

self.sockopts = sockopts
self.cql_version = cql_version
self.max_schema_agreement_wait = max_schema_agreement_wait
Expand Down Expand Up @@ -1661,6 +1762,7 @@ def _make_connection_kwargs(self, endpoint, kwargs_dict):
kwargs_dict.setdefault('sockopts', self.sockopts)
kwargs_dict.setdefault('ssl_options', self.ssl_options)
kwargs_dict.setdefault('ssl_context', self.ssl_context)
kwargs_dict.setdefault('tls_session_cache', self._tls_session_cache)
kwargs_dict.setdefault('cql_version', self.cql_version)
kwargs_dict.setdefault('protocol_version', self.protocol_version)
kwargs_dict.setdefault('user_type_map', self._user_types)
Expand Down
54 changes: 52 additions & 2 deletions cassandra/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,15 @@ def socket_family(self):
"""
return socket.AF_UNSPEC

@property
def tls_session_cache_key(self):
"""
Returns the cache key components for TLS session caching.
This is a tuple that uniquely identifies this endpoint for TLS session purposes.
Subclasses may override this to include additional components (e.g., SNI server name).
"""
return (self.address, self.port)

def resolve(self):
"""
Resolve the endpoint to an address/port. This is called
Expand Down Expand Up @@ -275,6 +284,14 @@ def port(self):
def ssl_options(self):
return self._ssl_options

@property
def tls_session_cache_key(self):
"""
Returns the cache key including server_name for SNI endpoints.
This prevents cache collisions when multiple SNI endpoints use the same proxy.
"""
return (self.address, self.port, self._server_name)

def resolve(self):
try:
resolved_addresses = socket.getaddrinfo(self._proxy_address, self._port,
Expand Down Expand Up @@ -349,6 +366,14 @@ def port(self):
def socket_family(self):
return socket.AF_UNIX

@property
def tls_session_cache_key(self):
"""
Returns the cache key for Unix socket endpoints.
Since Unix sockets don't have a port, only the path is used.
"""
return (self._unix_socket_path,)

def resolve(self):
return self.address, None

Expand Down Expand Up @@ -687,6 +712,7 @@ class Connection(object):
endpoint = None
ssl_options = None
ssl_context = None
tls_session_cache = None
last_error = None

# The current number of operations that are in flight. More precisely,
Expand Down Expand Up @@ -763,14 +789,15 @@ def __init__(self, host='127.0.0.1', port=9042, authenticator=None,
ssl_options=None, sockopts=None, compression: Union[bool, str] = True,
cql_version=None, protocol_version=ProtocolVersion.MAX_SUPPORTED, is_control_connection=False,
user_type_map=None, connect_timeout=None, allow_beta_protocol_version=False, no_compact=False,
ssl_context=None, owning_pool=None, shard_id=None, total_shards=None,
ssl_context=None, tls_session_cache=None, owning_pool=None, shard_id=None, total_shards=None,
on_orphaned_stream_released=None, application_info: Optional[ApplicationInfoBase] = None):
# TODO next major rename host to endpoint and remove port kwarg.
self.endpoint = host if isinstance(host, EndPoint) else DefaultEndPoint(host, port)

self.authenticator = authenticator
self.ssl_options = ssl_options.copy() if ssl_options else {}
self.ssl_context = ssl_context
self.tls_session_cache = tls_session_cache
self.sockopts = sockopts
self.compression = compression
self.cql_version = cql_version
Expand Down Expand Up @@ -913,7 +940,21 @@ def _wrap_socket_from_context(self):
server_hostname = self.endpoint.address
opts['server_hostname'] = server_hostname

return self.ssl_context.wrap_socket(self._socket, **opts)
# Try to get a cached TLS session for resumption
# Note: Session resumption works with both TLS 1.2 and TLS 1.3
# Python's ssl module handles both transparently via SSLSession objects
if self.tls_session_cache:
cached_session = self.tls_session_cache.get_session(self.endpoint)
if cached_session:
opts['session'] = cached_session
log.debug("Using cached TLS session for %s", self.endpoint)

ssl_socket = self.ssl_context.wrap_socket(self._socket, **opts)

# Note: Session is NOT stored here - it will be stored after successful connection
# in _connect_socket() to ensure we only cache sessions for successful connections

return ssl_socket

def _initiate_connection(self, sockaddr):
if self.features.shard_id is not None:
Expand Down Expand Up @@ -968,6 +1009,15 @@ def _connect_socket(self):
# run that here.
if self._check_hostname:
self._validate_hostname()

# Store the TLS session after successful connection
# This ensures we only cache sessions for connections that actually succeeded
if self.tls_session_cache and self.ssl_context and hasattr(self._socket, 'session'):
if self._socket.session:
self.tls_session_cache.set_session(self.endpoint, self._socket.session)
if hasattr(self._socket, 'session_reused') and self._socket.session_reused:
log.debug("TLS session was reused for %s", self.endpoint)

sockerr = None
break
except socket.error as err:
Expand Down
14 changes: 14 additions & 0 deletions cassandra/io/eventletreactor.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,13 +109,27 @@ def _wrap_socket_from_context(self):
# This is necessary for SNI
self._socket.set_tlsext_host_name(self.ssl_options['server_hostname'].encode('ascii'))

# Apply cached TLS session for resumption (PyOpenSSL)
if self.tls_session_cache:
cached_session = self.tls_session_cache.get_session(self.endpoint)
if cached_session:
self._socket.set_session(cached_session)
log.debug("Using cached TLS session for %s", self.endpoint)

def _initiate_connection(self, sockaddr):
if self.uses_legacy_ssl_options:
super(EventletConnection, self)._initiate_connection(sockaddr)
else:
self._socket.connect(sockaddr)
if self.ssl_context or self.ssl_options:
self._socket.do_handshake()
# Store TLS session after successful handshake (PyOpenSSL)
if self.tls_session_cache:
session = self._socket.get_session()
if session:
self.tls_session_cache.set_session(self.endpoint, session)
if self._socket.session_reused():
log.debug("TLS session was reused for %s", self.endpoint)

def _match_hostname(self):
if self.uses_legacy_ssl_options:
Expand Down
20 changes: 19 additions & 1 deletion cassandra/io/twistedreactor.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,11 +139,12 @@ def _on_loop_timer(self):

@implementer(IOpenSSLClientConnectionCreator)
class _SSLCreator(object):
def __init__(self, endpoint, ssl_context, ssl_options, check_hostname, timeout):
def __init__(self, endpoint, ssl_context, ssl_options, check_hostname, timeout, tls_session_cache=None):
self.endpoint = endpoint
self.ssl_options = ssl_options
self.check_hostname = check_hostname
self.timeout = timeout
self.tls_session_cache = tls_session_cache

if ssl_context:
self.context = ssl_context
Expand Down Expand Up @@ -171,11 +172,27 @@ def info_callback(self, connection, where, ret):
transport = connection.get_app_data()
transport.failVerification(Failure(ConnectionException("Hostname verification failed", self.endpoint)))

# Store TLS session after successful handshake (PyOpenSSL)
if self.tls_session_cache:
session = connection.get_session()
if session:
self.tls_session_cache.set_session(self.endpoint, session)
if connection.session_reused():
log.debug("TLS session was reused for %s", self.endpoint)

def clientConnectionForTLS(self, tlsProtocol):
connection = SSL.Connection(self.context, None)
connection.set_app_data(tlsProtocol)
if self.ssl_options and "server_hostname" in self.ssl_options:
connection.set_tlsext_host_name(self.ssl_options['server_hostname'].encode('ascii'))

# Apply cached TLS session for resumption (PyOpenSSL)
if self.tls_session_cache:
cached_session = self.tls_session_cache.get_session(self.endpoint)
if cached_session:
connection.set_session(cached_session)
log.debug("Using cached TLS session for %s", self.endpoint)

return connection


Expand Down Expand Up @@ -241,6 +258,7 @@ def add_connection(self):
self.ssl_options,
self._check_hostname,
self.connect_timeout,
tls_session_cache=self.tls_session_cache,
)

endpoint = SSL4ClientEndpoint(
Expand Down
Loading
Loading