Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
212 changes: 212 additions & 0 deletions benchmarks/test_deserializer_cache_benchmark.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,212 @@
# Copyright ScyllaDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Benchmarks for find_deserializer / make_deserializers with and without caching.

Run with: pytest benchmarks/test_deserializer_cache_benchmark.py -v

Requires the ``pytest-benchmark`` plugin and Cython extensions to be built.
Skipped automatically when either dependency is unavailable.
"""

import pytest

pytest.importorskip("pytest_benchmark")
pytest.importorskip("cassandra.deserializers")

from cassandra import cqltypes
from cassandra.deserializers import (
find_deserializer,
make_deserializers,
)
Comment on lines +24 to +33
Comment on lines +29 to +33


# ---------------------------------------------------------------------------
# Reference: original uncached implementations (copied from master)
# ---------------------------------------------------------------------------

# Name -> class mapping mirrored from the cassandra.deserializers module;
# populated lazily by _init_classes() on first lookup.
_classes = {}


def _init_classes():
    """Populate the module-level ``_classes`` lookup table on first use.

    Scans the ``cassandra.deserializers`` module and records every class
    it exposes, keyed by attribute name. Subsequent calls are no-ops.
    """
    if _classes:
        return
    from cassandra import deserializers as mod

    for attr_name in dir(mod):
        candidate = getattr(mod, attr_name)
        if isinstance(candidate, type):
            _classes[attr_name] = candidate


def find_deserializer_uncached(cqltype):
    """Original implementation without caching."""
    _init_classes()

    # Fast path: a deserializer class named after the cqltype exists.
    direct_cls = _classes.get("Des" + cqltype.__name__)
    if direct_cls is not None:
        return direct_cls(cqltype)

    from cassandra import deserializers as _des

    # Fallback chain, probed in the same order as the original if/elif
    # cascade — the order matters because some bases subclass others.
    fallbacks = (
        (cqltypes.ListType, _des.DesListType),
        (cqltypes.SetType, _des.DesSetType),
        (cqltypes.MapType, _des.DesMapType),
        (cqltypes.UserType, _des.DesUserType),
        (cqltypes.TupleType, _des.DesTupleType),
        (cqltypes.DynamicCompositeType, _des.DesDynamicCompositeType),
        (cqltypes.CompositeType, _des.DesCompositeType),
        (cqltypes.ReversedType, _des.DesReversedType),
        (cqltypes.FrozenType, _des.DesFrozenType),
    )
    for base_type, des_cls in fallbacks:
        if issubclass(cqltype, base_type):
            return des_cls(cqltype)

    return _des.GenericDeserializer(cqltype)


def make_deserializers_uncached(ctypes):
    """Original implementation without caching."""
    from cassandra.deserializers import obj_array

    per_type = [find_deserializer_uncached(ctype) for ctype in ctypes]
    return obj_array(per_type)


# ---------------------------------------------------------------------------
# Test type sets
# ---------------------------------------------------------------------------

# Five primitive column types — exercises the small, common-case row shape.
SIMPLE_TYPES = [
    cqltypes.Int32Type,
    cqltypes.UTF8Type,
    cqltypes.BooleanType,
    cqltypes.DoubleType,
    cqltypes.LongType,
]

# Ten primitive types (a superset of SIMPLE_TYPES) — exercises a wider row
# touching more distinct deserializer classes.
MIXED_TYPES = [
    cqltypes.Int32Type,
    cqltypes.UTF8Type,
    cqltypes.BooleanType,
    cqltypes.DoubleType,
    cqltypes.LongType,
    cqltypes.FloatType,
    cqltypes.TimestampType,
    cqltypes.UUIDType,
    cqltypes.InetAddressType,
    cqltypes.DecimalType,
]


# ---------------------------------------------------------------------------
# Correctness tests
# ---------------------------------------------------------------------------


class TestDeserializerCacheCorrectness:
    """Verify the cached implementation returns equivalent deserializers."""

    @pytest.mark.parametrize("cqltype", SIMPLE_TYPES + MIXED_TYPES)
    def test_find_deserializer_returns_correct_type(self, cqltype):
        via_cache = find_deserializer(cqltype)
        reference = find_deserializer_uncached(cqltype)
        # Same deserializer class name as the original uncached lookup.
        assert type(via_cache).__name__ == type(reference).__name__

    def test_find_deserializer_cache_hit_same_object(self):
        first = find_deserializer(cqltypes.Int32Type)
        second = find_deserializer(cqltypes.Int32Type)
        # A warm cache must hand back the identical instance.
        assert first is second

    def test_make_deserializers_returns_correct_length(self):
        deserializers = make_deserializers(SIMPLE_TYPES)
        assert len(deserializers) == len(SIMPLE_TYPES)

    def test_make_deserializers_cache_hit_same_object(self):
        first = make_deserializers(SIMPLE_TYPES)
        second = make_deserializers(SIMPLE_TYPES)
        # A warm cache must hand back the identical array object.
        assert first is second


# ---------------------------------------------------------------------------
# Benchmarks
# ---------------------------------------------------------------------------


class TestFindDeserializerBenchmark:
    """Benchmark find_deserializer cached vs uncached."""

    # --- Single simple type ---

    @pytest.mark.benchmark(group="find_deser_simple")
    def test_uncached_simple(self, benchmark):
        target_type = cqltypes.Int32Type
        benchmark(find_deserializer_uncached, target_type)

    @pytest.mark.benchmark(group="find_deser_simple")
    def test_cached_simple(self, benchmark):
        target_type = cqltypes.Int32Type
        # Prime the cache so every timed call takes the hit path.
        find_deserializer(target_type)
        benchmark(find_deserializer, target_type)


class TestMakeDeserializersBenchmark:
    """Benchmark make_deserializers cached vs uncached."""

    # --- 5 simple types ---

    @pytest.mark.benchmark(group="make_deser_5types")
    def test_uncached_5types(self, benchmark):
        benchmark(make_deserializers_uncached, SIMPLE_TYPES)

    @pytest.mark.benchmark(group="make_deser_5types")
    def test_cached_5types(self, benchmark):
        # Prime the cache so every timed call takes the hit path.
        make_deserializers(SIMPLE_TYPES)
        benchmark(make_deserializers, SIMPLE_TYPES)

    # --- 10 mixed types ---

    @pytest.mark.benchmark(group="make_deser_10types")
    def test_uncached_10types(self, benchmark):
        benchmark(make_deserializers_uncached, MIXED_TYPES)

    @pytest.mark.benchmark(group="make_deser_10types")
    def test_cached_10types(self, benchmark):
        # Prime the cache so every timed call takes the hit path.
        make_deserializers(MIXED_TYPES)
        benchmark(make_deserializers, MIXED_TYPES)
65 changes: 62 additions & 3 deletions cassandra/deserializers.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -440,16 +440,54 @@ cdef class GenericDeserializer(Deserializer):
#--------------------------------------------------------------------------
# Helper utilities

# Maximum number of entries in each deserializer cache. In practice the
# caches are bounded by the number of distinct column-type signatures in
# the schema (typically dozens to low hundreds), but parameterized types
# created via apply_parameters() for unprepared queries are *not*
# interned, so repeated simple queries could accumulate entries. The cap
# prevents unbounded growth in such edge cases. When the cap is hit the
# affected cache is cleared wholesale — there is no LRU bookkeeping.
cdef int _CACHE_MAX_SIZE = 256

# Cache make_deserializers results keyed on the tuple of cqltype objects.
# Using the cqltype objects themselves (rather than id()) as keys ensures
# the dict holds strong references, preventing GC and id() reuse issues
# with non-singleton parameterized types.
cdef dict _make_deserializers_cache = {}
Comment on lines +451 to +455

def make_deserializers(cqltypes):
    """Create an array of Deserializers for each given cqltype in cqltypes.

    Results are memoized in ``_make_deserializers_cache``, keyed on the
    tuple of cqltype objects. When the cache reaches ``_CACHE_MAX_SIZE``
    entries it is cleared wholesale before inserting the new result, which
    bounds memory without LRU bookkeeping.
    """
    # NOTE: the previous implementation declared an unused local
    # ``cdef Deserializer[::1] deserializers`` left over from the
    # pre-caching version; it has been removed.
    cdef tuple key = tuple(cqltypes)
    try:
        return _make_deserializers_cache[key]
    except KeyError:
        pass
    result = obj_array([find_deserializer(ct) for ct in cqltypes])
    if len(_make_deserializers_cache) >= _CACHE_MAX_SIZE:
        _make_deserializers_cache.clear()
    _make_deserializers_cache[key] = result
    return result
Comment on lines +459 to +468


# Module namespace used for Des* class lookups by name.
# NOTE(review): find_deserializer below tests ``name in globals()``
# directly; confirm this ``classes`` alias is still referenced anywhere,
# otherwise it can be dropped.
cdef dict classes = globals()

# Cache deserializer instances keyed on the cqltype object itself to avoid
# repeated class lookups and object creation on every result set.
# Using the object as key (rather than id()) holds a strong reference,
# preventing GC and id() reuse issues with parameterized types.
#
# Note: if a Des* class is overridden at runtime (e.g. DesBytesType =
# DesBytesTypeByteArray for cqlsh), callers must invoke
# clear_deserializer_caches() to flush stale entries so that subsequent
# find_deserializer() calls pick up the new class.
cdef dict _deserializer_cache = {}

cpdef Deserializer find_deserializer(cqltype):
"""Find a deserializer for a cqltype"""
try:
return <Deserializer>_deserializer_cache[cqltype]
except KeyError:
pass
Comment on lines +486 to +489

name = 'Des' + cqltype.__name__

if name in globals():
Expand Down Expand Up @@ -477,7 +515,28 @@ cpdef Deserializer find_deserializer(cqltype):
else:
cls = GenericDeserializer

return cls(cqltype)
cdef Deserializer result = cls(cqltype)
if len(_deserializer_cache) >= _CACHE_MAX_SIZE:
_deserializer_cache.clear()
_deserializer_cache[cqltype] = result
return result


def clear_deserializer_caches():
    """Clear the find_deserializer and make_deserializers caches.

    Call this after overriding a Des* class at runtime (e.g.
    ``deserializers.DesBytesType = deserializers.DesBytesTypeByteArray``)
    so that subsequent lookups pick up the new class instead of returning
    stale cached instances.
    """
    for cache in (_deserializer_cache, _make_deserializers_cache):
        cache.clear()


def get_deserializer_cache_sizes():
    """Return ``(find_cache_size, make_cache_size)`` for diagnostic use."""
    find_size = len(_deserializer_cache)
    make_size = len(_make_deserializers_cache)
    return find_size, make_size


def obj_array(list objs):
Expand Down
Loading
Loading