Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 28 additions & 2 deletions examples/deepseek/ptq.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,20 +311,40 @@ def calibrate_loop(model):
# disable head that corresponds to lm_head (for the huggingface checkpoint)
mtq_cfg["quant_cfg"]["*head*"] = {"enable": False}

allowed_mla_quant = [None, "per_tensor_fp8"]
allowed_mla_quant = [None, "per_tensor_fp8", "nvfp4"]
assert mla_quant in allowed_mla_quant, f"mla_quant must be {allowed_mla_quant}"

if not mla_quant:
mtq_cfg["quant_cfg"]["*attn*"] = {"enable": False}
elif mla_quant == "per_tensor_fp8":
mtq_cfg["quant_cfg"]["*attn*weight_quantizer"] = {"num_bits": (4, 3), "axis": None}
mtq_cfg["quant_cfg"]["*attn*input_quantizer"] = {"num_bits": (4, 3), "axis": None}
elif mla_quant == "nvfp4": # for DeepSeek-R1-0528-NVFP4-Turbo
mla_linear_layers = ["*wq_a*", "*wq_b*", "*wkv_a*", "*wkv_b*", "*wo*"]
mla_nvfp4_linear_layers = ["*wq_a*", "*wkv_a*", "*wq_b*", "*wo*"]
for layer in mla_linear_layers:
if layer in mla_nvfp4_linear_layers:
# wq_a, wkv_a, wq_b, wo use NVFP4 quantization
mtq_cfg["quant_cfg"][layer + "_quantizer"] = {
"num_bits": (2, 1),
"block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)},
"axis": None,
"enable": True,
}
else:
mtq_cfg["quant_cfg"][layer + "_quantizer"] = {"enable": False}

# Disable BMM quantizers
mtq_cfg["quant_cfg"]["*attn.kv_bmm_quantizer*"] = {"enable": False}
mtq_cfg["quant_cfg"]["*attn.pe_bmm_quantizer*"] = {"enable": False}

if not args.disable_wo_quant and "FP4" in quant_cfg:
mtq_cfg["quant_cfg"]["*wo*weight_quantizer"] = mtq_cfg["quant_cfg"]["*input_quantizer"]
mtq_cfg["quant_cfg"]["*wo*input_quantizer"] = mtq_cfg["quant_cfg"]["*weight_quantizer"]

## ptq
transformer = mtq.quantize(transformer, mtq_cfg, calibrate_loop)

if int(os.environ["LOCAL_RANK"]) == 0:
mtq.print_quant_summary(transformer)

Expand Down Expand Up @@ -407,11 +427,17 @@ def state_dict_filter(state_dict):
parser.add_argument("--disable_fp8_kvcache", action="store_true", help="disable fp8 kvcache.")
parser.add_argument("--disable_wo_quant", action="store_true", help="disable MLA wo quant.")
parser.add_argument("--trust_remote_code", action="store_true", help="trust remote code.")
parser.add_argument(
"--mla_quant",
type=str,
default=None,
help="MLA quantization type: None (disable), per_tensor_fp8, nvfp4",
)

args = parser.parse_args()
model = load_deepseek_model(args.config, args.model_path, args.batch_size)
tokenizer = AutoTokenizer.from_pretrained(
args.model_path, trust_remote_code=args.trust_remote_code
)
model = ptq(model, tokenizer, args.quant_cfg, args.batch_size, args.calib_size)
model = ptq(model, tokenizer, args.quant_cfg, args.batch_size, args.calib_size, args.mla_quant)
save_amax_and_quant_config(model, args.output_path, not args.disable_fp8_kvcache)