Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions lmdeploy/cli/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,10 +152,12 @@ def model_format(parser, default: str = None):
return parser.add_argument('--model-format',
type=str,
default=default,
choices=['hf', 'awq', 'gptq', 'fp8', 'mxfp4'],
choices=['hf', 'awq', 'gptq', 'compressed-tensors', 'fp8', 'mxfp4'],
help='The format of input model. `hf` means `hf_llama`, '
'`awq` represents the quantized model by AWQ,'
' and `gptq` refers to the quantized model by GPTQ')
'`awq` and `gptq` refer to 4-bit grouped quantization, '
'`compressed-tensors` refers to pack-quantized grouped int4 checkpoints and is '
'usually auto-detected from the model config, `fp8` refers to blocked fp8 '
'checkpoints, and `mxfp4` refers to MXFP4 expert weights.')

@staticmethod
def revision(parser, default: str = None):
Expand Down
13 changes: 9 additions & 4 deletions lmdeploy/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,10 +196,15 @@ class TurbomindEngineConfig:
The `auto` option will use FP16 precision for FP32 and FP16
models, and BF16 precision for BF16 models.
model_format: the layout of the deployed model. It can be one
of the following values [hf, awq, gptq],`hf` meaning
huggingface model(.bin, .safetensors), `awq` and `gptq` meaning
the quantized model by AWQ and GPTQ, respectively. If it is not
specified, i.e. None, it will be extracted from the input model
of the following values [hf, awq, gptq, compressed-tensors,
fp8, mxfp4]. `hf` means a Hugging Face model (.bin,
.safetensors), `awq` and `gptq` mean grouped 4-bit
weight-only checkpoints, `compressed-tensors` means
pack-quantized grouped int4 checkpoints and is usually
auto-detected from the input model config, `fp8` means
blocked fp8 checkpoints, and `mxfp4` means MXFP4 expert
weights. If it is not specified, i.e. None, it will be
extracted from the input model
tp: the number of GPU cards used in tensor parallelism,
default to 1
session_len: the max session length of a sequence, default to
Expand Down
59 changes: 44 additions & 15 deletions lmdeploy/turbomind/deploy/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,43 @@
from .source_model.base import INPUT_MODELS
from .target_model.base import OUTPUT_MODELS, BaseOutputModel

SUPPORTED_FORMATS = ['hf', 'awq', 'gptq', 'fp8', None]
SUPPORTED_FORMATS = ['hf', 'awq', 'gptq', 'compressed-tensors', 'fp8', 'mxfp4', None]
logger = get_logger('lmdeploy')

_DEFAULT_GROUP_SIZES = {
'awq': 128,
'gptq': 128,
'compressed-tensors': 128,
'fp8': 128,
'mxfp4': 32,
}

_SUPPORTED_GROUP_SIZES = {
'awq': frozenset({128}),
'gptq': frozenset({128}),
'compressed-tensors': frozenset({32, 128}),
'fp8': frozenset({128}),
'mxfp4': frozenset({32}),
}


def _validate_quant_group_size(model_format: str | None, group_size: int | None) -> int | None:
"""Normalize and validate quantized group sizes.

The low-level int4 kernels can be shared across formats, but we only expose the format/group-size combinations that
are verified end to end.
"""
if group_size in (None, 0):
group_size = _DEFAULT_GROUP_SIZES.get(model_format, group_size)

supported_group_sizes = _SUPPORTED_GROUP_SIZES.get(model_format)
if supported_group_sizes is not None and group_size not in supported_group_sizes:
supported = ', '.join(map(str, sorted(supported_group_sizes)))
raise ValueError(f'Unsupported group_size={group_size} for model_format="{model_format}". '
f'Supported group_size values: {supported}.')

return group_size


def get_input_model_registered_name(model_path: str, model_format: str):
"""Get the registered name of a model. The name will be used to access the
Expand All @@ -24,7 +58,7 @@ def get_input_model_registered_name(model_path: str, model_format: str):
Args:
model_path (str): the path of the input model
model_format (str): the format of the model, which can be one of
['hf', 'awq', 'gptq']
['hf', 'awq', 'gptq', 'compressed-tensors', 'fp8', 'mxfp4']
"""
arch = get_model_arch(model_path)[0]
register_name = SUPPORTED_ARCHS[arch]
Expand All @@ -39,9 +73,9 @@ def get_output_model_registered_name_and_config(model_path: str, model_format: s
Args:
model_path (str): the path of the input model
model_format (str): the format of the model, which can be one of
['hf', 'awq', 'gptq']
['hf', 'awq', 'gptq', 'compressed-tensors', 'fp8', 'mxfp4']
dtype (str): the data type of the model's weights and activations
group_size (int): the size of group used by awq model
group_size (int): the quantization group size used by grouped formats
"""
register_name = 'tm'

Expand Down Expand Up @@ -74,18 +108,19 @@ def get_output_model_registered_name_and_config(model_path: str, model_format: s

session_len = _get_and_verify_max_len(model_config, None)

group_size = _validate_quant_group_size(model_format, group_size)

if model_format in ['awq', 'gptq', 'compressed-tensors']:
weight_type = 'int4'
dtype = 'float16' # force float16 for int4 quantized weights
group_size = 128 if group_size == 0 else group_size
if model_format == 'compressed-tensors':
# TurboMind reuses the AWQ int4 export path for pack-quantized
# compressed-tensors weights after the format-specific checks above.
model_format = 'awq'
elif model_format == 'fp8':
weight_type = 'fp8'
group_size = 128
elif model_format == 'mxfp4':
weight_type = 'e2m1'
group_size = 32

expert_weight_type = weight_type

Expand Down Expand Up @@ -165,7 +200,7 @@ def get_tm_model(model_path,
the input model
engine_config(TurbomindEngineConfig): user input engine config
group_size(int): refers to the group_size if the input model
is a w4a16(awq or gptq) quantized model
is a grouped quantized model
out_dir(str): the output directory where to save to turbomind model.
If it is None, the turbomind model won't be saved
"""
Expand Down Expand Up @@ -210,13 +245,7 @@ def get_tm_model(model_path,
engine_config.model_format = quant_method
group_size = _group_size

if engine_config.model_format in ['awq', 'gptq', 'compressed-tensors']:
# Compatible to awq models that are quantized by lmdeploy (<=v0.3.0)
if not group_size:
group_size = 128
assert group_size == 128, (f'model format is "{engine_config.model_format}" '
f'but group_size is {group_size}. Currently, only 128 '
'is supported')
group_size = _validate_quant_group_size(engine_config.model_format, group_size)

input_model_name = get_input_model_registered_name(model_path, engine_config.model_format)

Expand Down
83 changes: 49 additions & 34 deletions lmdeploy/turbomind/deploy/parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,14 +30,9 @@ def pack_u4_row(x: torch.Tensor) -> torch.Tensor:
return a.squeeze(dim=-1)


def generate_zero_point(g):
weight_shapes = g('weight_shape')
result = []
for weight_shape in weight_shapes:
row, col = weight_shape
tensor = torch.full((row, col // 128), 8, dtype=torch.uint8)
result.append(tensor)
return (*result, )
def generate_zero_point(scales):
    """Synthesize symmetric int4 zero-points (mid-point value 8), one uint8 tensor
    matching the shape of each exported scale tensor."""
    zeros = []
    for scale in scales:
        zeros.append(torch.full(scale.shape, 8, dtype=torch.uint8))
    return tuple(zeros)


class Parameter:
Expand All @@ -61,12 +56,51 @@ def __call__(cls, f, g, i):


class QuantWeightOnly(Parameter):
    """Weight-only int4 parameter handling both AWQ/GPTQ-style exports and
    compressed-tensors (pack-quantized) exports."""

    AWQ_KEYS = '.qweight', '.scales', '.qzeros'
    COMPRESSED_KEYS = '.weight_packed', '.weight_scale', '.weight_zero_point'
    KEYS = AWQ_KEYS + COMPRESSED_KEYS

    @classmethod
    def take(cls, keys: list[str]):
        # Detect which layout is present; AWQ-style keys take precedence.
        if any(k.endswith(cls.AWQ_KEYS[0]) for k in keys):
            suffixes = cls.AWQ_KEYS
        elif any(k.endswith(cls.COMPRESSED_KEYS[0]) for k in keys):
            suffixes = cls.COMPRESSED_KEYS
        else:
            return False

        # str.endswith accepts a tuple of suffixes; remove the matched keys.
        taken = [k for k in keys if k.endswith(suffixes)]
        for k in taken:
            keys.remove(k)
        return taken

    def __init__(self, xs):
        # Remember the detected layout and whether explicit zero-points were exported.
        self.compressed_tensors = any(k.endswith(self.COMPRESSED_KEYS[0]) for k in xs)
        self.has_zero_point = any(k.endswith(self.COMPRESSED_KEYS[2]) for k in xs)

    def _get(self, g, kind: str):
        # Translate AWQ-style kinds to the compressed-tensors export names when needed.
        if self.compressed_tensors:
            kind = {
                'qweight': 'weight_packed',
                'scales': 'weight_scale',
                'qzeros': 'weight_zero_point',
            }[kind]
        return g(kind)

    def __call__(self, f, g, i):
        f(i, self._get(g, 'qweight'), 'qweight', pack_u4_row)
        scales = self._get(g, 'scales')
        f(i, scales, 'scales', to_half, apply_gs=['w2'])
        if self.compressed_tensors and not self.has_zero_point:
            # Symmetric checkpoints ship no zero-points; synthesize mid-point ones.
            zeros = generate_zero_point(scales)
        else:
            zeros = self._get(g, 'qzeros')
        f(i, zeros, 'zeros', to_half, apply_gs=['w2'])


class WeightScaleInv(Parameter):
Expand All @@ -78,23 +112,6 @@ def __call__(self, f, g, i):
f(i, g('weight'), 'weight', identity)


class CompressedWeight(Parameter):
KEYS = '.weight_packed', '.weight_scale', '.weight_zero_point'

def __init__(self, xs):
self.has_zero_point = False
if any(key.endswith(self.KEYS[2]) for key in xs):
self.has_zero_point = True

def __call__(self, f, g, i):
f(i, g('weight_packed'), 'qweight', pack_u4_row)
f(i, g('weight_scale'), 'scales', to_half, apply_gs=['w2'])
if self.has_zero_point:
f(i, g('weight_zero_point'), 'zeros', to_half, apply_gs=['w2'])
else:
f(i, generate_zero_point(g), 'zeros', to_half, apply_gs=['w2'])


class Mxfp4Weight(Parameter):
KEYS = '.blocks', '.scales'

Expand Down Expand Up @@ -129,13 +146,11 @@ def get_params(keys: list[str], bias=0):
ps = []
if PLora.take(keys):
ps.append(PLora())
if QuantWeightOnly.take(keys):
ps.append(QuantWeightOnly())
xs = QuantWeightOnly.take(keys)
if xs:
ps.append(QuantWeightOnly(xs))
if WeightScaleInv.take(keys):
ps.append(WeightScaleInv())
xs = CompressedWeight.take(keys)
if xs:
ps.append(CompressedWeight(xs))
if Mxfp4Weight.take(keys):
ps.append(Mxfp4Weight())
if Weight.take(keys):
Expand Down
32 changes: 32 additions & 0 deletions lmdeploy/turbomind/deploy/source_model/qwen.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,35 @@ def _awq_dequant(self, prefix: str):
w = dequantize_gemm(qweight, qzeros, scales, 4, group_size)
return w.t() # [in, out] → [out, in] (PyTorch convention)

@staticmethod
def _compressed_tensors_dequant(weight_packed, weight_scale):
"""Dequantize a compressed-tensors (pack-quantized, symmetric int4)
weight to fp16.

Args:
weight_packed: int32 tensor of shape (out_features, in_features//8).
weight_scale: bf16/fp16 tensor of shape (out_features, in_features//group_size).
Returns:
fp16 tensor of shape (out_features, in_features).
"""
out_features = weight_packed.shape[0]
num_groups = weight_scale.shape[1]
in_features = weight_packed.shape[1] * 8
group_size = in_features // num_groups

# Reinterpret the packed int32 buffer as bytes and unpack two nibbles
# per byte directly into the final fp16 tensor. This avoids creating
# eight temporary fp16 tensors before applying scales.
packed_bytes = weight_packed.contiguous().view(torch.uint8).reshape(out_features, -1)
weight = torch.empty((out_features, in_features), device=weight_packed.device, dtype=torch.float16)
weight[:, 0::2] = (packed_bytes & 0xF).to(torch.float16)
weight[:, 1::2] = (packed_bytes >> 4).to(torch.float16)

scales = weight_scale.to(torch.float16).unsqueeze(-1)
weight = weight.view(out_features, num_groups, group_size)
weight.sub_(8.0).mul_(scales)
return weight.reshape(out_features, in_features)

def linear_attn(self, i: int, kind: str):
if not kind:
return self.filter(r'linear_attn\.', i)
Expand All @@ -329,6 +358,9 @@ def linear_attn(self, i: int, kind: str):
if tensor is None and kind == 'weight':
if f'{prefix}.qweight' in self.params:
tensor = self._awq_dequant(prefix)
elif f'{prefix}.weight_packed' in self.params:
tensor = self._compressed_tensors_dequant(self.params[f'{prefix}.weight_packed'],
self.params[f'{prefix}.weight_scale'])
if tensor is not None:
tensor = self.transform(tensor, kind)
result.append(tensor) # keep None to preserve alignment
Expand Down
52 changes: 51 additions & 1 deletion src/turbomind/kernels/gemm/kernel/sm70_884_4.cu
Original file line number Diff line number Diff line change
Expand Up @@ -46,11 +46,61 @@ void Registry::sm70_884_4()
Add<C::Type< 16, 256, 64, 1, 4, 1, D, S, 2, true, 1, 128>>();
Add<C::Type< 16, 256, 32, 1, 4, 1, D, S, 2, true, 1, 128>>();
Add<C::Type< 16, 128, 32, 1, 4, 1, D, S, 2, true, 1, 128>>();
Add<C::Type< 16, 256, 32, 1, 4, 1, D, S, 2, true, 1, 128>>();
Add<C::Type< 8, 128, 64, 1, 4, 1, D, S, 2, true, 1, 128>>();
// clang-format on
}

if constexpr (1) {
// clang-format off
using C = Config_U4_d<kColMajor>;
Add<C::Type<128, 256, 16, 2, 4, 1, D, D, 2, true, 1, 32, 128, 128>>();
Add<C::Type<128, 128, 16, 2, 2, 1, D, D, 2, true, 1, 32, 64, 128>>();
Add<C::Type<128, 128, 16, 2, 2, 1, D, S, 2, true, 1, 32, 64, 128>>();
Add<C::Type< 96, 128, 32, 2, 2, 1, D, S, 2, true, 1, 32, 48, 128>>();
Add<C::Type< 64, 128, 32, 2, 2, 1, D, D, 2, true, 1, 32, 32, 128>>();
Add<C::Type< 64, 128, 32, 2, 2, 1, D, S, 2, true, 1, 32, 32, 128>>();
Add<C::Type< 64, 128, 16, 1, 4, 1, D, S, 2, true, 1, 32, 32, 128>>();
Add<C::Type< 64, 256, 16, 1, 4, 1, D, S, 2, true, 1, 32, 64, 128>>();
Add<C::Type< 32, 128, 32, 1, 4, 1, D, S, 2, true, 1, 32>>();
Add<C::Type< 32, 256, 32, 1, 4, 1, D, S, 2, true, 1, 32, 32, 128>>();
Add<C::Type< 16, 128, 32, 1, 4, 1, D, S, 2, true, 1, 32>>();
Add<C::Type< 16, 256, 32, 1, 4, 1, D, S, 2, true, 1, 32>>();
Add<C::Type< 8, 128, 64, 1, 4, 1, D, S, 2, true, 1, 32>>();
Add<C::Type< 8, 128, 32, 1, 4, 1, D, S, 2, true, 1, 32>>();
Add<C::Type< 8, 256, 64, 1, 4, 1, D, S, 2, true, 1, 32>>();
Add<C::Type< 48, 128, 32, 1, 4, 1, D, S, 2, true, 1, 32>>();
Add<C::Type< 16, 256, 64, 1, 4, 1, D, S, 2, true, 1, 32>>();
Add<C::Type< 16, 128, 64, 1, 4, 1, D, S, 2, true, 1, 32>>();
Add<C::Type< 8, 256, 32, 1, 4, 1, D, S, 2, true, 1, 32>>();
Add<C::Type< 32, 128, 64, 1, 4, 1, D, S, 2, true, 1, 32>>();
Add<C::Type< 64, 256, 32, 1, 4, 1, D, S, 2, true, 1, 32, 64, 128>>();
// clang-format on
}

if constexpr (1) {
// clang-format off
using C = Config_U4_g<kColMajor>;
Add<C::Type<128, 256, 16, 2, 4, 1, D, D, 2, 0 , 1, 32, 128, 128>>();
Add<C::Type<128, 128, 16, 2, 2, 1, D, D, 2, true, 1, 32, 64, 128>>();
Add<C::Type< 64, 128, 32, 1, 4, 1, D, S, 2, true, 1, 32, 32, 128>>();
Add<C::Type< 64, 256, 16, 1, 4, 1, D, S, 2, true, 1, 32, 64, 128>>();
Add<C::Type< 32, 128, 32, 1, 4, 1, D, S, 2, true, 1, 32>>();
Add<C::Type< 32, 256, 32, 1, 4, 1, D, S, 2, true, 1, 32>>();
Add<C::Type< 16, 256, 64, 1, 4, 1, D, S, 2, true, 1, 32>>();
Add<C::Type< 16, 256, 32, 1, 4, 1, D, S, 2, true, 1, 32>>();
Add<C::Type< 16, 128, 32, 1, 4, 1, D, S, 2, true, 1, 32>>();
Add<C::Type< 8, 128, 64, 1, 4, 1, D, S, 2, true, 1, 32>>();
Add<C::Type< 8, 128, 32, 1, 4, 1, D, S, 2, true, 1, 32>>();
Add<C::Type< 8, 128, 128, 1, 4, 1, D, S, 2, true, 1, 32>>();
Add<C::Type< 48, 128, 32, 1, 4, 1, D, S, 2, true, 1, 32>>();
Add<C::Type< 16, 128, 64, 1, 4, 1, D, S, 2, true, 1, 32>>();
Add<C::Type< 8, 256, 64, 1, 4, 1, D, S, 2, true, 1, 32>>();
Add<C::Type< 8, 256, 32, 1, 4, 1, D, S, 2, true, 1, 32>>();
Add<C::Type< 32, 256, 64, 1, 4, 1, D, S, 2, true, 1, 32>>();
Add<C::Type< 64, 256, 32, 1, 4, 1, D, S, 2, true, 1, 32, 64, 128>>();
// clang-format on
}

if constexpr (1) {
// clang-format off
using C = Config_MXF4<kColMajor, 0>;
Expand Down
Loading
Loading