1 change: 1 addition & 0 deletions docs/CN/source/index.rst
Original file line number Diff line number Diff line change
@@ -49,6 +49,7 @@ Lightllm integrates the strengths of many open-source solutions, including but not limited to FasterTran
:caption: Deployment Tutorials

DeepSeek R1 Deployment <tutorial/deepseek_deployment>
FP8 KV Quantization and Calibration <tutorial/fp8_kv_quantization>
Multi-Level Cache Deployment <tutorial/multi_level_cache_deployment>
Multimodal Deployment <tutorial/multimodal>
Reward Model Deployment <tutorial/reward_model>
98 changes: 98 additions & 0 deletions docs/CN/source/tutorial/fp8_kv_quantization.rst
@@ -0,0 +1,98 @@
.. _tutorial/fp8_kv_quantization_cn:

FP8 KV Quantization and Calibration Guide
=========================================

This section describes how to use FP8 KV inference in LightLLM, including:

- Running inference with a calibration file (``fp8kv``)
- Quantization granularity differences between the FA3 and FlashInfer backends
- Common errors and troubleshooting tips

Overview
--------

FP8 KV inference in LightLLM requires a prepared calibration file (``kv_cache_calib.json``),
loaded via ``--kv_quant_calibration_config_path``.
You can use the calibration files shipped in ``test/advanced_config/``,
export one with the `LightCompress <https://github.com/ModelTC/LightCompress>`_ tool, or supply your own compatible file.

Backend and Quantization Granularity
------------------------------------

Current behavior:

- ``fa3``: uses ``per_head`` (an independent scale per head)
- ``flashinfer``: uses ``per_tensor`` (one scalar scale each for K and V)

Calibration files are therefore tightly coupled to the backend:

- A ``per_head`` calibration file targets ``fa3`` and should be used with ``fa3`` inference.
- A ``per_tensor`` calibration file targets ``flashinfer`` and should be used with ``flashinfer`` inference.

Mixing calibration files across backends is not recommended.
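The difference between the two granularities can be sketched as follows. This is a minimal illustration of how per-head vs. per-tensor descale factors could be derived from sampled K activations; the function names and the use of the FP8 E4M3 maximum (448.0) are assumptions for the example, not LightLLM's calibration code:

```python
import torch

FP8_E4M3_MAX = 448.0  # largest finite value representable in float8_e4m3fn

def per_head_scales(kv: torch.Tensor) -> torch.Tensor:
    """One descale factor per head (FA3-style); kv: [tokens, heads, head_dim]."""
    return kv.abs().amax(dim=(0, 2)) / FP8_E4M3_MAX  # shape: [heads]

def per_tensor_scale(kv: torch.Tensor) -> torch.Tensor:
    """A single scalar descale factor for the whole tensor (FlashInfer-style)."""
    return kv.abs().amax() / FP8_E4M3_MAX  # shape: [] (scalar)

k_samples = torch.randn(1024, 8, 128)  # sampled K activations for one layer
print(per_head_scales(k_samples).shape)   # torch.Size([8])
print(per_tensor_scale(k_samples).shape)  # torch.Size([])
```

A ``per_head`` file carries one such factor per head per layer (for both K and V), while a ``per_tensor`` file carries just one K scalar and one V scalar per layer.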

Start FP8 Inference with a Calibration File
-------------------------------------------

Inference example (FA3):

.. code-block:: console

$ python -m lightllm.server.api_server \
--model_dir /path/to/model \
--llm_kv_type fp8kv \
--llm_prefill_att_backend fa3 \
--llm_decode_att_backend fa3 \
--kv_quant_calibration_config_path /path/to/kv_cache_calib.json

Inference example (FlashInfer):

.. code-block:: console

$ python -m lightllm.server.api_server \
--model_dir /path/to/model \
--llm_kv_type fp8kv \
--llm_prefill_att_backend flashinfer \
--llm_decode_att_backend flashinfer \
--kv_quant_calibration_config_path /path/to/kv_cache_calib.json

Notes:

- ``fp8kv`` mode requires ``--kv_quant_calibration_config_path``.
- Keep the inference attention backend consistent with the one the calibration file was produced for.

Calibration File Format
-----------------------

The main fields of ``kv_cache_calib.json`` are:

- ``quant_type``: ``per_head`` or ``per_tensor``
- ``num_layers``: number of layers
- ``num_head``: total number of heads
- ``scales_shape``: shape of the scale tensor
- ``scales``: the actual scale values
- ``qmin`` / ``qmax``: FP8 range parameters

When the calibration file is loaded, the model architecture, layer count, head count, and quantization type are validated against it.
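For concreteness, a ``per_head`` calibration payload for a hypothetical 2-layer, 4-head model might look like the sketch below, together with the kind of consistency check described above. The field names follow the list above; all values and the ``validate`` helper are made up for the example and are not LightLLM's actual code:

```python
# Hypothetical per_head calibration data for a 2-layer, 4-head model.
calib = {
    "quant_type": "per_head",
    "num_layers": 2,
    "num_head": 4,
    "scales_shape": [2, 8],              # per layer: 4 K-head scales + 4 V-head scales
    "scales": [[0.02] * 8, [0.03] * 8],
    "qmin": -448.0,
    "qmax": 448.0,
}

def validate(calib: dict, num_layers: int, num_head: int, expected_type: str) -> None:
    # Mirrors the kind of checks described above (a sketch, not LightLLM's loader).
    if calib["quant_type"] != expected_type:
        raise ValueError("quant_type not match")
    if calib["num_layers"] != num_layers or calib["num_head"] != num_head:
        raise ValueError("model shape does not match calibration file")

validate(calib, num_layers=2, num_head=4, expected_type="per_head")
```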

Multi-GPU Notes
---------------

In multi-GPU (TP) deployments, the system automatically slices out the scales for the heads assigned to the current rank.
You still only need to provide a single full ``kv_cache_calib.json``.

Common Issues
-------------

1. Startup error saying ``--kv_quant_calibration_config_path`` is required

You passed ``--llm_kv_type fp8kv`` without providing a calibration file path.

2. ``quant_type not match`` error

Usually the backend and the calibration file type disagree, for example running ``flashinfer`` with a ``per_head`` file.

3. Abnormal output quality after switching backends

Use a calibration file that matches the target backend; do not reuse an incompatible file across backends.
1 change: 1 addition & 0 deletions docs/EN/source/index.rst
@@ -48,6 +48,7 @@ Documentation List
:caption: Deployment Tutorials

DeepSeek R1 Deployment <tutorial/deepseek_deployment>
FP8 KV Quantization and Calibration <tutorial/fp8_kv_quantization>
Multi-Level Cache Deployment <tutorial/multi_level_cache_deployment>
Multimodal Deployment <tutorial/multimodal>
Reward Model Deployment <tutorial/reward_model>
98 changes: 98 additions & 0 deletions docs/EN/source/tutorial/fp8_kv_quantization.rst
@@ -0,0 +1,98 @@
.. _tutorial/fp8_kv_quantization_en:

FP8 KV Quantization and Calibration Guide
=========================================

This chapter describes FP8 KV inference in LightLLM, including:

- Running inference with calibration data (``fp8kv``)
- Quantization granularity differences between FA3 and FlashInfer
- Common errors and troubleshooting

Overview
--------

LightLLM FP8 KV inference requires a prepared calibration file (``kv_cache_calib.json``),
which is loaded via ``--kv_quant_calibration_config_path``.
You can use calibration files provided in ``test/advanced_config/``,
export one with `LightCompress <https://github.com/ModelTC/LightCompress>`_, or use your own compatible file.

Backend and Quantization Granularity
------------------------------------

Current behavior:

- ``fa3``: ``per_head`` scales (independent scale per head)
- ``flashinfer``: ``per_tensor`` scales (one scalar for K and one scalar for V)

Calibration files are backend-dependent:

- ``per_head`` files for ``fa3`` should be used with ``fa3`` inference.
- ``per_tensor`` files for ``flashinfer`` should be used with ``flashinfer`` inference.

Avoid mixing calibration files across different backends.

Start FP8 Inference with Calibration
------------------------------------

Inference mode example (FA3):

.. code-block:: console

$ python -m lightllm.server.api_server \
--model_dir /path/to/model \
--llm_kv_type fp8kv \
--llm_prefill_att_backend fa3 \
--llm_decode_att_backend fa3 \
--kv_quant_calibration_config_path /path/to/kv_cache_calib.json

Inference mode example (FlashInfer):

.. code-block:: console

$ python -m lightllm.server.api_server \
--model_dir /path/to/model \
--llm_kv_type fp8kv \
--llm_prefill_att_backend flashinfer \
--llm_decode_att_backend flashinfer \
--kv_quant_calibration_config_path /path/to/kv_cache_calib.json

Notes:

- ``fp8kv`` requires ``--kv_quant_calibration_config_path``.
- Keep the inference backend consistent with the backend expected by the calibration file.

Calibration File Schema
-----------------------

Key fields in ``kv_cache_calib.json``:

- ``quant_type``: ``per_head`` or ``per_tensor``
- ``num_layers``: number of layers
- ``num_head``: total number of heads
- ``scales_shape``: shape of the scale tensor
- ``scales``: actual scale values
- ``qmin`` / ``qmax``: FP8 numeric range parameters

At load time, LightLLM validates architecture, layer count, head count, and quantization type.
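Tying the schema to the FA3 attention code in this PR: per-head scales are laid out with the K-head entries first and the V-head entries second in each layer's row, which is why the attention state can split them as ``offline_scales[:, :head_num]`` and ``offline_scales[:, head_num:]``. A small sketch of that split (the concrete numbers are illustrative):

```python
import torch

layer_num, head_num = 2, 4
# One row per layer: the first head_num entries are K scales, the rest are V scales.
offline_scales = torch.arange(layer_num * 2 * head_num, dtype=torch.float32)
offline_scales = offline_scales.reshape(layer_num, 2 * head_num)

k_descale = offline_scales[:, :head_num]  # [layer_num, head_num]
v_descale = offline_scales[:, head_num:]  # [layer_num, head_num]
print(k_descale[0].tolist())  # [0.0, 1.0, 2.0, 3.0]
print(v_descale[0].tolist())  # [4.0, 5.0, 6.0, 7.0]
```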

Multi-GPU Note
--------------

In multi-GPU (TP) setups, LightLLM slices the global scales to local rank heads automatically.
You only need to provide one full ``kv_cache_calib.json`` file.
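The slicing idea can be sketched as follows, assuming heads are partitioned contiguously across TP ranks (the actual partitioning logic lives inside LightLLM and may differ in detail; `local_head_slice` is a hypothetical helper for illustration):

```python
def local_head_slice(total_heads: int, world_size: int, rank: int) -> slice:
    # Contiguous head partitioning across TP ranks (illustrative assumption).
    per_rank = total_heads // world_size
    return slice(rank * per_rank, (rank + 1) * per_rank)

# Global per-head K scales for an 8-head model, sliced for each of 4 ranks.
global_k_scales = [0.01, 0.02, 0.03, 0.04, 0.05, 0.06, 0.07, 0.08]
for rank in range(4):
    print(rank, global_k_scales[local_head_slice(8, 4, rank)])
```

Each rank ends up with only the scales for its own heads, which is why a single full calibration file suffices for any TP degree.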

Common Issues
-------------

1. Error says ``--kv_quant_calibration_config_path`` is required

You are using ``--llm_kv_type fp8kv`` without a calibration file path.

2. ``quant_type not match`` error

Usually caused by backend/file mismatch (for example, using a ``per_head`` file with ``flashinfer``).

3. Abnormal quality after backend switch

Use a calibration file that matches the target backend instead of reusing an incompatible file.
6 changes: 6 additions & 0 deletions lightllm/common/basemodel/attention/create_utils.py
@@ -35,6 +35,12 @@
# "fa3": Fp8Fa3AttBackend,
# "flashinfer": Fp8FlashInferAttBackend,
},
"fp8kv_sph": {
"fa3": Fp8Fa3AttBackend,
},
"fp8kv_spt": {
"flashinfer": Fp8FlashInferAttBackend,
},
}

mla_data_type_to_backend = {
58 changes: 16 additions & 42 deletions lightllm/common/basemodel/attention/fa3/fp8.py
@@ -45,24 +45,9 @@ def init_state(self):
torch.arange(batch_size, device=device), self.infer_state.b_q_seq_len
)
# To reduce inference-time computation, initialize k_descale and v_descale outside the inference path
self.k_descale = (
offline_scales[:, :head_num].view(-1, 1, head_num).expand(offline_scales.shape[0], batch_size, head_num)
if offline_scales is not None
else torch.ones(
(mem_manager.layer_num, batch_size, head_num),
dtype=torch.float32,
device=device,
)
)
self.v_descale = (
offline_scales[:, head_num:].view(-1, 1, head_num).expand(offline_scales.shape[0], batch_size, head_num)
if offline_scales is not None
else torch.ones(
(mem_manager.layer_num, batch_size, head_num),
dtype=torch.float32,
device=device,
)
)
self.k_descale = offline_scales[:, :head_num].view(-1, 1, head_num).expand(offline_scales.shape[0], batch_size, head_num)
self.v_descale = offline_scales[:, head_num:].view(-1, 1, head_num).expand(offline_scales.shape[0], batch_size, head_num)


def prefill_att(
self,
@@ -89,19 +74,21 @@ def _fp8_prefill_att(
) -> torch.Tensor:
self.backend: Fp8Fa3AttBackend = self.backend # for typing

q_head_num = q.shape[1]
q_head_dim = q.shape[2]
k_head_num = k.shape[1]
q, q_scale = q_per_head_fp8_quant(
q,
q.reshape(q.shape[0], k_head_num, -1),
self.infer_state.b_seq_len,
self.cu_seqlens_q,
self.mid_token_batch_ids,
token_batch_ids=self.mid_token_batch_ids,
)
k_head_num = k.shape[1]
k_head_dim = k.shape[2]
cache_k = k.view(-1, 1, k_head_num, k_head_dim).view(torch.float8_e4m3fn)
cache_v = v.view(-1, 1, k_head_num, k_head_dim).view(torch.float8_e4m3fn)
layer_index = self.backend._find_layer_index(k=cache_k, v=cache_v, att_state=self)
o = flash_attn_with_kvcache(
q=q,
q=q.reshape(-1, q_head_num, q_head_dim),
k_cache=cache_k,
v_cache=cache_v,
page_table=self.page_table,
@@ -141,24 +128,9 @@ def init_state(self):
head_num = mem_manager.head_num

# To reduce inference-time computation, initialize k_descale and v_descale outside the inference path
self.k_descale = (
offline_scales[:, :head_num].view(-1, 1, head_num).expand(offline_scales.shape[0], batch_size, head_num)
if offline_scales is not None
else torch.ones(
(mem_manager.layer_num, batch_size, head_num),
dtype=torch.float32,
device=device,
)
)
self.v_descale = (
offline_scales[:, head_num:].view(-1, 1, head_num).expand(offline_scales.shape[0], batch_size, head_num)
if offline_scales is not None
else torch.ones(
(mem_manager.layer_num, batch_size, head_num),
dtype=torch.float32,
device=device,
)
)
self.k_descale = offline_scales[:, :head_num].view(-1, 1, head_num).expand(offline_scales.shape[0], batch_size, head_num)
self.v_descale = offline_scales[:, head_num:].view(-1, 1, head_num).expand(offline_scales.shape[0], batch_size, head_num)

return

def copy_for_decode_cuda_graph(self, new_state: "Fp8Fa3DecodeAttState"):
@@ -200,9 +172,11 @@ def _fp8_decode_att(
layer_index = self.backend._find_layer_index(k=cache_k, v=cache_v, att_state=self)

q_head_num = q.shape[1]
q, q_scale = scaled_fp8_quant(q.view(q.shape[0] * k_head_num, -1), use_per_token_if_dynamic=True)
if scaled_fp8_quant is None:
raise ImportError("scaled_fp8_quant is unavailable. Please install vllm to enable FP8 decode attention.")
q, q_scale = scaled_fp8_quant(q.reshape(q.shape[0] * k_head_num, -1), use_per_token_if_dynamic=True)
o = flash_attn_with_kvcache(
q=q.view(-1, q_head_num, k_head_dim),
q=q.reshape(-1, q_head_num, k_head_dim),
k_cache=cache_k,
v_cache=cache_v,
page_table=self.page_table,
18 changes: 8 additions & 10 deletions lightllm/common/basemodel/attention/flashinfer/fp8.py
@@ -20,11 +20,11 @@ def create_att_decode_state(self, infer_state) -> "Fp8FlashInferDecodeAttState":

@dataclasses.dataclass
class Fp8FlashInferPrefillAttState(FlashInferPrefillAttState):
offline_scales: torch.Tensor = None
scales: torch.Tensor = None

def init_state(self):
super().init_state()
self.offline_scales = self.infer_state.mem_manager.scales_list
self.scales = self.infer_state.mem_manager.scales

def prefill_att(
self,
@@ -53,9 +53,8 @@ def _fp8_prefill_att(
k = k.unsqueeze(1).view(torch.float8_e4m3fn)
v = v.unsqueeze(1).view(torch.float8_e4m3fn)
layer_index = self.backend._find_layer_index(k=k, v=v, att_state=self)
offline_scales = self.offline_scales
k_descale = offline_scales[layer_index][0] if offline_scales is not None else None
v_descale = offline_scales[layer_index][1] if offline_scales is not None else None
k_descale = self.scales[layer_index][0]
v_descale = self.scales[layer_index][1]
self.prefill_wrapper.run(
q,
(k, v),
@@ -68,11 +67,11 @@

@dataclasses.dataclass
class Fp8FlashInferDecodeAttState(FlashInferDecodeAttState):
offline_scales: torch.Tensor = None
scales: torch.Tensor = None

def init_state(self):
super().init_state()
self.offline_scales = self.infer_state.mem_manager.scales_list
self.scales = self.infer_state.mem_manager.scales

def copy_for_decode_cuda_graph(self, new_state):
return super().copy_for_decode_cuda_graph(new_state)
@@ -108,11 +107,10 @@ def _fp8_decode_att(

k = k.unsqueeze(1).view(torch.float8_e4m3fn)
v = v.unsqueeze(1).view(torch.float8_e4m3fn)
offline_scales = self.offline_scales
layer_index = self.backend._find_layer_index(k=k, v=v, att_state=self)

k_descale = offline_scales[layer_index][0] if offline_scales is not None else None
v_descale = offline_scales[layer_index][1] if offline_scales is not None else None
k_descale = self.scales[layer_index][0]
v_descale = self.scales[layer_index][1]
self.decode_wrapper.run(
q,
(k, v),
6 changes: 4 additions & 2 deletions lightllm/common/kv_cache_mem_manager/__init__.py
@@ -1,18 +1,20 @@
from .mem_manager import MemoryManager, ReadOnlyStaticsMemoryManager
from .calibration_fp8kv_mem_manager import CalibrationFP8KVMemoryManager
from .export_calibration_mem_manager import ExportCalibrationMemoryManager
from .ppl_int8kv_mem_manager import PPLINT8KVMemoryManager
from .ppl_int4kv_mem_manager import PPLINT4KVMemoryManager
from .deepseek2_mem_manager import Deepseek2MemoryManager
from .deepseek3_2mem_manager import Deepseek3_2MemoryManager
from .fp8_static_per_head_quant_mem_manager import FP8StaticPerHeadQuantMemManager
from .fp8_static_per_tensor_quant_mem_manager import FP8StaticPerTensorQuantMemManager

__all__ = [
"MemoryManager",
"ReadOnlyStaticsMemoryManager",
"CalibrationFP8KVMemoryManager",
"ExportCalibrationMemoryManager",
"PPLINT4KVMemoryManager",
"PPLINT8KVMemoryManager",
"Deepseek2MemoryManager",
"Deepseek3_2MemoryManager",
"FP8StaticPerHeadQuantMemManager",
"FP8StaticPerTensorQuantMemManager",
]