fix: update FP8 syntax for custom torch._scaled_mm on CPU#201

Merged
chichun-charlie-liu merged 7 commits into main from fp8_cpu_op_fix on Mar 30, 2026

Conversation

@andrea-fasoli (Collaborator) commented Mar 26, 2026

Description of the change

This PR supersedes an earlier fix to the handling of FP8 operations on CPU (#200). That fix caused an error when the model was compiled (the fms-mo tests did not catch this problem).

In PyTorch 2.10, torch._scaled_mm does not support FP8 matmul on CPU with per-channel or per-token quantization (only per-tensor is supported). A fallback to a custom operation already existed at fms_mo/aiu_addons/fp8/fp8_spyre_op.py, but the registration of this fallback operation did not work in PyTorch 2.10, most likely due to an update to allowed signatures for op override.

This PR updates the registration to the modern custom-op syntax.

Related issues or PRs

[internal issue]

How to verify the PR

Example of a test that should pass, run on a pod with 4 AIUs, in PF mode, in a PyTorch 2.10 env (set up env vars according to your case; AFTU = aiu-fms-testing-utils repo):

```shell
torchrun --nproc-per-node 4 ${AFTU_PATH}/scripts/drive_paged_programs.py \
    --model_variant ${FP8_MODEL_PATH} \
    --max_new_tokens 128 \
    --timing per-token \
    --dataset_type sharegpt \
    --dataset_path ${DATASET_PATH} \
    --test_type metrics \
    --program_criteria_json_path ${PROGRAMS_FILE} \
    --programs ${SELECTED_PROGRAM} \
    --attention_type paged_fp8 \
    --save_validation_info_outputs \
    --validation_info_outputs_dir ${OUTPUT_DIR} \
    --prefill_chunk_size 1024 \
    --cross_entropy_threshold 2.6 \
    --failure_rate_threshold 0.1 \
    --prioritize_large_batch_sizes \
    --enforce_homogeneous_prompt_programs \
    --distributed
```

Was the PR tested?

  • I have ensured all unit tests pass
  • I verified compilation is successful
  • I verified inference is functional on CPU with both a compiled and non-compiled model
  • I verified inference is functional on Spyre with a compiled model

Checklist for passing CI/CD:

  • All commits are signed showing "Signed-off-by: Name <email@domain.com>" with git commit --signoff or equivalent
  • PR title and commit messages adhere to Conventional Commits
  • Contribution is formatted with tox -e fix
  • Contribution passes linting with tox -e lint
  • Contribution passes spellcheck with tox -e spellcheck
  • Contribution passes all unit tests with tox -e unit

Signed-off-by: Andrea Fasoli <andrea.fasoli@ibm.com>

@chichun-charlie-liu chichun-charlie-liu left a comment

lgtm

@chichun-charlie-liu chichun-charlie-liu merged commit d498851 into main Mar 30, 2026
14 checks passed
@chichun-charlie-liu chichun-charlie-liu deleted the fp8_cpu_op_fix branch March 30, 2026 17:12
