Skip to content

[BUG] Incorrect zero-point packing shape for GPTQ MatMulNBits #20

@AyoubMDL

Description

@AyoubMDL

Hello, when quantizing a model with 4-bit group-wise GPTQ, the zero-point packing shape is not the one expected by ORT MatMulNBits.

Error

2026-01-25 22:06:17.750132480 [E:onnxruntime:, sequential_executor.cc:572 ExecuteKernel] Non-zero status code returned while running MatMulNBits node. Name:'' Status Message: Input 'zero_points' is expected to have shape {256} or {128,2}, got {512,1}

To reproduce

import numpy as np
import onnx
import onnxruntime
from onnxruntime.quantization import CalibrationDataReader
from quark.onnx import Config, ModelQuantizer
from quark.onnx.quantization.config import get_default_config


def get_matmul_model(path):
    """Build a float model with two chained MatMul nodes and save it to *path*."""
    graph_text = """
                < ir_version: 10, opset_import: ["" : 21] >
                test_model (float[N, 512] X) => (float [N, ?] Y)
                <float[512, 128] W1, float[128, 32] W2>
                {
                    x1 = MatMul(X, W1)
                    Y = MatMul(x1, W2)
                }
            """
    model = onnx.parser.parse_model(graph_text)
    # Seeded generator so the reproduction is deterministic; draw W1 first,
    # then W2, matching the declaration order in the graph text.
    rng = np.random.default_rng(0)
    weights = [
        onnx.numpy_helper.from_array(
            rng.normal(size=shape).astype(np.float32), name=name
        )
        for name, shape in (("W1", (512, 128)), ("W2", (128, 32)))
    ]
    model.graph.initializer.extend(weights)
    onnx.checker.check_model(model, full_check=True)
    onnx.save(model, path)


class DataReader(CalibrationDataReader):
    """Calibration reader that serves a single input tensor named "X" once."""

    def __init__(self, input_tensor):
        # Single-element batch list; `index` is the read cursor.
        self.data = [input_tensor]
        self.input_name = "X"
        self.index = 0

    def get_next(self):
        # Guard clause: exhausted readers return None, per the
        # CalibrationDataReader protocol.
        if self.index >= len(self.data):
            return None
        feed = {self.input_name: self.data[self.index]}
        self.index += 1
        return feed

    def __len__(self):
        return len(self.data)

    def rewind(self):
        # Reset the cursor so calibration can iterate the data again.
        self.index = 0


def prepare_config():
    """Return a Quark Config for 4-bit group-wise GPTQ MatMulNBits quantization."""
    base = get_default_config("MATMUL_NBITS")
    base.extra_options["MatMulNBitsParams"]["Algorithm"] = "GPTQ"
    base.extra_options["GPTQParams"] = {
        "GroupSize": 128,
        "Bits": 4,
        "PerChannel": True,
    }
    return Config(global_quant_config=base)


def prepare_data(input_tensor):
    """Wrap *input_tensor* in a calibration DataReader and return it."""
    return DataReader(input_tensor)


def prepare_quantizer(quant_config):
    """Instantiate and return a ModelQuantizer built from *quant_config*."""
    return ModelQuantizer(quant_config)


def quantize_static(quantizer, input_model_path, output_model_path, data_reader):
    """Quantize the model at *input_model_path* and return the output path.

    Delegates to ``quantizer.quantize_model`` and echoes where the
    quantized model was written.
    """
    quantizer.quantize_model(input_model_path, output_model_path, data_reader)
    print("Quantized the ONNX model and saved it at:", output_model_path)
    return output_model_path


def infer_quantized_model(quantized_model_path, input_tensor):
    """Run the quantized model on *input_tensor* and return the session outputs."""
    session = onnxruntime.InferenceSession(quantized_model_path)

    # Single-input, single-output model: take the first of each.
    feed = {session.get_inputs()[0].name: input_tensor}
    fetches = [session.get_outputs()[0].name]

    return session.run(fetches, feed)


def tensor_quantize_gptq():
    """End-to-end reproduction: build, GPTQ-quantize, and run the model."""
    source_path = "matmul_model.onnx"
    get_matmul_model(source_path)

    sample = np.random.rand(4, 512).astype(np.float32)
    reader = prepare_data(sample)
    quantizer = prepare_quantizer(prepare_config())
    quantized_path = quantize_static(
        quantizer, source_path, "matmul_model_quantized.onnx", reader
    )

    # Running inference is where the reported zero_points shape error surfaces.
    result = infer_quantized_model(quantized_path, sample)
    print(result)

# Execute the reproduction end to end on import/run of this script.
tensor_quantize_gptq()

I can provide a fix that corrects the zero-point shape (reference: https://github.com/AyoubMDL/onnx_quantize/blob/aa7664446b25ce7ae9874a754cb491dc1f2b1a02/src/onnx_quantize/qrules/_common.py#L155)

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions