Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 75 additions & 13 deletions cassandra/cqltypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,10 @@
varint_pack, varint_unpack, point_be, point_le,
vints_pack, vints_unpack, uvint_unpack, uvint_pack)
from cassandra import util
from cassandra.cython_deps import HAVE_NUMPY

if HAVE_NUMPY:
import numpy as np

Comment on lines +56 to 57
_little_endian_flag = 1 # we always serialize LE
import ipaddress
Expand Down Expand Up @@ -1430,6 +1434,9 @@ class VectorType(_CassandraType):
typename = 'org.apache.cassandra.db.marshal.VectorType'
vector_size = 0
subtype = None
_vector_struct = None # Cached struct.Struct for bulk deserialization
_struct_format_map = {} # Populated after FloatType etc. are defined
_numpy_dtype = None # Cached numpy dtype string for large vector deserialization

@classmethod
def serial_size(cls):
Expand All @@ -1441,7 +1448,16 @@ def apply_parameters(cls, params, names):
assert len(params) == 2
subtype = lookup_casstype(params[0])
vsize = params[1]
return type('%s(%s)' % (cls.cass_parameterized_type_with([]), vsize), (cls,), {'vector_size': vsize, 'subtype': subtype})
# Cache a struct.Struct for bulk deserialization of known numeric types
vector_struct = None
numpy_dtype = None
for base_type, fmt_char in cls._struct_format_map.items():
if subtype is base_type or (isinstance(subtype, type) and issubclass(subtype, base_type)):
vector_struct = struct.Struct(f'>{vsize}{fmt_char}')
numpy_dtype = cls._numpy_dtype_map.get(fmt_char)
break
return type('%s(%s)' % (cls.cass_parameterized_type_with([]), vsize), (cls,),
{'vector_size': vsize, 'subtype': subtype, '_vector_struct': vector_struct, '_numpy_dtype': numpy_dtype})

@classmethod
def deserialize(cls, byts, protocol_version):
Expand All @@ -1452,25 +1468,55 @@ def deserialize(cls, byts, protocol_version):
raise ValueError(
"Expected vector of type {0} and dimension {1} to have serialized size {2}; observed serialized size of {3} instead"\
.format(cls.subtype.typename, cls.vector_size, expected_byte_size, len(byts)))
indexes = (serialized_size * x for x in range(0, cls.vector_size))
return [cls.subtype.deserialize(byts[idx:idx + serialized_size], protocol_version) for idx in indexes]

# Optimization: bulk deserialization for common numeric types
# For small vectors: use cached struct.Struct (avoids per-call format string allocation)
# For large vectors with numpy: use numpy.frombuffer (1.3-1.5x faster for 128+ elements)
# Threshold at 32 elements balances simplicity with performance
if cls._vector_struct is not None:
if HAVE_NUMPY and cls.vector_size >= 32 and cls._numpy_dtype is not None:
return np.frombuffer(byts, dtype=cls._numpy_dtype, count=cls.vector_size).tolist()
return list(cls._vector_struct.unpack(byts))
Comment on lines +1476 to +1479
# Fallback: element-by-element deserialization for other fixed-size types
result = [None] * cls.vector_size
subtype_deserialize = cls.subtype.deserialize
offset = 0
for i in range(cls.vector_size):
result[i] = subtype_deserialize(byts[offset:offset + serialized_size], protocol_version)
offset += serialized_size
return result

# Variable-size subtype path
result = [None] * cls.vector_size
idx = 0
rv = []
while (len(rv) < cls.vector_size):
byts_len = len(byts)
subtype_deserialize = cls.subtype.deserialize

for i in range(cls.vector_size):
if idx >= byts_len:
raise ValueError("Error reading additional data during vector deserialization after successfully adding {} elements"\
.format(i))

try:
size, bytes_read = uvint_unpack(byts[idx:])
idx += bytes_read
rv.append(cls.subtype.deserialize(byts[idx:idx + size], protocol_version))
idx += size
except:
except (IndexError, KeyError):
raise ValueError("Error reading additional data during vector deserialization after successfully adding {} elements"\
.format(len(rv)))
.format(i))
Comment on lines 1500 to +1504

idx += bytes_read

# If we have any additional data in the serialized vector treat that as an error as well
if idx < len(byts):
if idx + size > byts_len:
raise ValueError("Error reading additional data during vector deserialization after successfully adding {} elements"\
.format(i))

result[i] = subtype_deserialize(byts[idx:idx + size], protocol_version)
idx += size

# Check for additional data
if idx < byts_len:
raise ValueError("Additional bytes remaining after vector deserialization completed")
return rv

return result

@classmethod
def serialize(cls, v, protocol_version):
Expand All @@ -1481,6 +1527,9 @@ def serialize(cls, v, protocol_version):
.format(cls.vector_size, cls.subtype.typename, v_length))

serialized_size = cls.subtype.serial_size()
# Bulk serialization for known numeric types (symmetric with struct.unpack in deserialize)
if cls._vector_struct is not None and serialized_size is not None:
return cls._vector_struct.pack(*v)
buf = io.BytesIO()
for item in v:
item_bytes = cls.subtype.serialize(item, protocol_version)
Expand All @@ -1492,3 +1541,16 @@ def serialize(cls, v, protocol_version):
@classmethod
def cql_parameterized_type(cls):
    # Render this vector type as its CQL syntax,
    # e.g. "org.apache.cassandra.db.marshal.VectorType<float, 3>".
    inner = cls.subtype.cql_parameterized_type()
    return f"{cls.typename}<{inner}, {cls.vector_size}>"


# Populate VectorType._struct_format_map now that all types are defined.
# Keys are the fixed-size numeric CQL type classes; values are the struct
# format characters used to build a big-endian struct.Struct for bulk
# (de)serialization of whole vectors in one call.
VectorType._struct_format_map = {
FloatType: 'f',
DoubleType: 'd',
Int32Type: 'i',
LongType: 'q',
ShortType: 'h',
}

# Map struct format chars to numpy dtype strings for large vector deserialization.
# Dtypes are explicitly big-endian ('>') to match the serialized byte order
# implied by the '>'-prefixed struct format above.
VectorType._numpy_dtype_map = {'f': '>f4', 'd': '>f8', 'i': '>i4', 'q': '>i8', 'h': '>i2'}
Loading