Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 11 additions & 4 deletions fastdeploy/cache_manager/v1/cache_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,17 +282,21 @@ def initialize_kv_cache(
logger.info(f"Initializing kv cache for all layers. num_layers={self._num_layers}")
cache_kvs_list = []

# Quantized KV cache (int8/fp8/etc.) uses uint8 storage (1 byte per element).
# Non-quantized cache uses the model's compute dtype (e.g., bfloat16).
cache_dtype = "uint8" if kv_cache_quant_type is not None else self.model_config.dtype

for i in range(self._num_layers):
# Generate cache names
cache_names = self._get_cache_names(i)

logger.info(f"..creating kv cache for layer {i}: key:{key_cache_shape}, value:{value_cache_shape}")

# Create key cache and value cache
key_cache = paddle.full(shape=key_cache_shape, fill_value=0, dtype=self.model_config.dtype)
key_cache = paddle.full(shape=key_cache_shape, fill_value=0, dtype=cache_dtype)
self.cache_kvs_map[cache_names["key"]] = key_cache

val_cache = paddle.full(shape=value_cache_shape, fill_value=0, dtype=self.model_config.dtype)
val_cache = paddle.full(shape=value_cache_shape, fill_value=0, dtype=cache_dtype)
self.cache_kvs_map[cache_names["value"]] = val_cache
cache_kvs_list.extend([key_cache, val_cache])

Expand Down Expand Up @@ -360,13 +364,16 @@ def initialize_mtp_kv_cache(
)
cache_kvs_list = []

# Quantized KV cache uses uint8 storage; non-quantized uses model compute dtype.
cache_dtype = "uint8" if kv_cache_quant_type is not None else self.model_config.dtype

for i in range(layer_offset, layer_offset + num_mtp_layers):
cache_names = self._get_cache_names(i)

key_cache = paddle.full(shape=key_cache_shape, fill_value=0, dtype=self.model_config.dtype)
key_cache = paddle.full(shape=key_cache_shape, fill_value=0, dtype=cache_dtype)
self.cache_kvs_map[cache_names["key"]] = key_cache

val_cache = paddle.full(shape=value_cache_shape, fill_value=0, dtype=self.model_config.dtype)
val_cache = paddle.full(shape=value_cache_shape, fill_value=0, dtype=cache_dtype)
self.cache_kvs_map[cache_names["value"]] = val_cache
cache_kvs_list.extend([key_cache, val_cache])

Expand Down
Loading