
Problem with memory offloading for AWQ quantizations #795

@DataSnake

Description

Describe the bug

  • When quantizing a model with the nvfp4_awq, int4_awq, or w4a8_awq formats on a GPU that can't fit the entire model in VRAM at once, quantization finishes correctly (as shown by the sample outputs the script generates before and after quantization), but the export step fails because some tensors are still on the meta device.
  • This does not happen with quantization set to regular nvfp4 or w4a8_nvfp4_fp8, even if all other settings are identical.
  • It also doesn't happen when sufficient VRAM is available for the whole model; only when the model is partially offloaded.
  • Traceback from attempting to quantize a model to nvfp4_awq:
Traceback (most recent call last):
  File "/app/Model-Optimizer/examples/llm_ptq/hf_ptq.py", line 1025, in <module>
    main(args)
  File "/app/Model-Optimizer/examples/llm_ptq/hf_ptq.py", line 1004, in main
    quantize_main(
  File "/app/Model-Optimizer/examples/llm_ptq/hf_ptq.py", line 808, in quantize_main
    export_quantized(args, full_model, language_model, model_type, tokenizer, default_padding_side)
  File "/app/Model-Optimizer/examples/llm_ptq/hf_ptq.py", line 538, in export_quantized
    export_hf_checkpoint(
  File "/app/Model-Optimizer/modelopt/torch/export/unified_export_hf.py", line 780, in export_hf_checkpoint
    raise e
  File "/app/Model-Optimizer/modelopt/torch/export/unified_export_hf.py", line 749, in export_hf_checkpoint
    post_state_dict, hf_quant_config = _export_transformers_checkpoint(model, dtype)
                                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/app/Model-Optimizer/modelopt/torch/export/unified_export_hf.py", line 638, in _export_transformers_checkpoint
    requantize_resmooth_fused_llm_layers(model)
  File "/app/Model-Optimizer/modelopt/torch/export/unified_export_hf.py", line 241, in requantize_resmooth_fused_llm_layers
    fuse_prequant_to_linear(model)
  File "/app/Model-Optimizer/modelopt/torch/export/quant_utils.py", line 1087, in fuse_prequant_to_linear
    linear_fuse_into.weight * pre_quant_scale.view(-1, 1)
    ~~~~~~~~~~~~~~~~~~~~~~~~^~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  File "/usr/local/lib/python3.12/dist-packages/torch/_prims_common/wrappers.py", line 309, in _fn
    result = fn(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_compile.py", line 53, in inner
    return disable_fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/eval_frame.py", line 1044, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_prims_common/wrappers.py", line 149, in _fn
    result = fn(**bound.arguments)
             ^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_refs/__init__.py", line 1107, in _ref
    output = prim(a, b)
             ^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_refs/__init__.py", line 1714, in mul
    return prims.mul(a, b)
           ^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_ops.py", line 841, in __call__
    return self._op(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_library/fake_impl.py", line 109, in meta_kernel
    return fake_impl_holder.kernel(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_library/utils.py", line 22, in __call__
    return self.func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/library.py", line 1425, in inner
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_library/custom_ops.py", line 627, in fake_impl
    return self._abstract_fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_prims/__init__.py", line 404, in _prim_elementwise_meta
    utils.check_same_device(*args_, allow_cpu_scalar_tensors=True)
  File "/usr/local/lib/python3.12/dist-packages/torch/_prims_common/__init__.py", line 867, in check_same_device
    raise RuntimeError(msg)
RuntimeError: Tensor on device cuda:0 is not on the expected device meta!
  • The device mismatch can be worked around by editing modelopt/torch/export/quant_utils.py: change line 1087 to linear_fuse_into.weight * pre_quant_scale.to(linear_fuse_into.weight.device).view(-1, 1) and line 1120 to layernorm_module.weight * getattr(modules[0].input_quantizer.to(layernorm_module.weight.device), "_pre_quant_scale"). Even with that change, the script still crashes with a different error that is also caused by tensors being on the meta device when they shouldn't be (a minimal sketch of the underlying device mismatch follows this traceback):
Traceback (most recent call last):
  File "/app/Model-Optimizer/examples/llm_ptq/hf_ptq.py", line 1025, in <module>
    main(args)
  File "/app/Model-Optimizer/examples/llm_ptq/hf_ptq.py", line 1004, in main
    quantize_main(
  File "/app/Model-Optimizer/examples/llm_ptq/hf_ptq.py", line 808, in quantize_main
    export_quantized(args, full_model, language_model, model_type, tokenizer, default_padding_side)
  File "/app/Model-Optimizer/examples/llm_ptq/hf_ptq.py", line 538, in export_quantized
    export_hf_checkpoint(
  File "/app/Model-Optimizer/modelopt/torch/export/unified_export_hf.py", line 780, in export_hf_checkpoint
    raise e
  File "/app/Model-Optimizer/modelopt/torch/export/unified_export_hf.py", line 749, in export_hf_checkpoint
    post_state_dict, hf_quant_config = _export_transformers_checkpoint(model, dtype)
                                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/app/Model-Optimizer/modelopt/torch/export/unified_export_hf.py", line 664, in _export_transformers_checkpoint
    _process_quantized_modules(model, dtype, is_modelopt_qlora)
  File "/app/Model-Optimizer/modelopt/torch/export/unified_export_hf.py", line 534, in _process_quantized_modules
    _export_quantized_weight(sub_module, dtype)
  File "/app/Model-Optimizer/modelopt/torch/export/unified_export_hf.py", line 403, in _export_quantized_weight
    quantizer_attrs.weight_scale, get_weight_scaling_factor(sub_module, weight_name)
                                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/app/Model-Optimizer/modelopt/torch/export/quant_utils.py", line 286, in get_weight_scaling_factor
    weight_scaling_factor_2.to(weight.device),
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
NotImplementedError: Cannot copy out of meta tensor; no data!
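
For context, here is a minimal, self-contained sketch (plain PyTorch, not ModelOpt code) of the two failure modes above, assuming the offloaded weight is the tensor left on the meta device (consistent with the first error message); weight and scale are hypothetical stand-ins:

```python
import torch

# Stand-ins for the tensors in the tracebacks (names are hypothetical):
# "weight" plays the role of a parameter left on the meta device by offloading,
# "scale" plays the role of a pre_quant_scale living on a real device.
weight = torch.empty(4, 8, device="meta")  # shape/dtype only, no data
scale = torch.ones(4)                      # per-output-channel scale on CPU

# 1) Mixing devices in an elementwise op reproduces the first crash.
try:
    weight * scale.view(-1, 1)
except RuntimeError as e:
    print("device mismatch:", e)

# 2) Copying *out of* a meta tensor reproduces the second crash: there is no
#    data to copy. Moving a real tensor *onto* meta is allowed, which is why
#    the .to(linear_fuse_into.weight.device) edit gets past the first error
#    but not the second one.
try:
    weight.to("cpu")
except NotImplementedError as e:
    print("meta copy:", e)

print(scale.to(weight.device).device)  # prints "meta": moving onto meta works
```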

Steps/Code to reproduce bug

  • Run examples/llm_ptq/hf_ptq.py to quantize a model that is too large to fit into the available VRAM into an AWQ format. For a quick test, this can also be reproduced with a small model by setting gpu_max_mem_percentage to a small value.
  • Example command: python hf_ptq.py --pyt_ckpt_path Qwen/Qwen3-0.6B --qformat nvfp4_awq --dataset cnn_dailymail --calib_size 64 --export_path ~/Qwen3-0.6B-NVFP4-AWQ --use_seq_device_map --gpu_max_mem_percentage 0.05
  • This is on a 5060 Ti 16GB. To reproduce the error on a card with more VRAM, use either a larger model or a smaller value for gpu_max_mem_percentage (a quick check for leftover meta tensors is sketched below).
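
For debugging, a quick check (generic PyTorch, not a ModelOpt API) that could be dropped into hf_ptq.py right before export_hf_checkpoint is called; model here stands for whatever model object the script is about to export:

```python
# Hypothetical debugging aid: list parameters the export path will still see on
# the meta device. Any hits here will break the fuse/scale math in the tracebacks.
meta_params = [
    name for name, p in model.named_parameters() if p.device.type == "meta"
]
print(f"{len(meta_params)} parameters on meta, first few: {meta_params[:5]}")
```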

Expected behavior

The model should be quantized, with some tensors offloaded to system RAM because there isn't enough VRAM to hold everything, and then exported in Hugging Face format to the specified directory.

Who can help?

  • ?

System information

  • Container used (if applicable): Docker container using the image nvcr.io/nvidia/tensorrt-llm/release:1.2.0rc4
  • OS (e.g., Ubuntu 22.04, CentOS 7, Windows 10): Ubuntu 24.04.3 LTS
  • CPU architecture (x86_64, aarch64): x86_64
  • GPU name (e.g. H100, A100, L40S): NVIDIA GeForce RTX 5060 Ti
  • GPU memory size: 15.9 GB
  • Number of GPUs: 2
  • Library versions (if applicable):
    • Python: 3.12.3
    • ModelOpt version or commit hash: 0.41.0rc2.dev30+gc1956b8e2
    • CUDA: 13.0
    • PyTorch: 2.9.0a0+145a3a7bda.nv25.10
    • Transformers: 4.56.0
    • TensorRT-LLM: 1.2.0rc4
    • ONNXRuntime: 1.22.0
    • TensorRT: 10.13.3.9
  • Any other details that may help: while the number of GPUs is listed as 2, the second is a GT 1030 that just runs my monitor. I have CUDA_DEVICE_ORDER set to PCI_BUS_ID and CUDA_VISIBLE_DEVICES set to 0 so that the only GPU it actually tries to use is my 5060 Ti.
