
fix gemma4 num attention head bugs #7975

Open
mingxiang1006 wants to merge 14 commits into deepspeedai:master from mingxiang1006:master

Conversation

@mingxiang1006

The error occurs because `num_attention_heads` does not live directly on `Gemma4Config`; it sits on one of its sub-modules (the Gemma4 text, vision, or audio config), so there is no top-level attribute to grab. This causes a runtime error during DeepSpeed launch.
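As an illustration (not the actual DeepSpeed code; the checkpoint id is a placeholder), a minimal Python sketch of the failing access, assuming the Hugging Face `transformers` config layout:

```python
from transformers import AutoConfig

# Placeholder checkpoint id; the point is the config shape, not the weights.
hf_model_config = AutoConfig.from_pretrained("some-org/gemma4")

# This is the access that fails at launch, because the attribute lives on a
# sub-config rather than on the top-level Gemma4Config:
#   hf_model_config.num_attention_heads  # AttributeError
# Where it actually lives (text branch shown; vision/audio are analogous):
num_heads = hf_model_config.text_config.num_attention_heads
```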

@chatgpt-codex-connector

Codex usage limits have been reached for code reviews. Please check with the admins of this repo to increase the limits by adding credits.
Credits must be used to enable repository-wide code reviews.

@delock
Collaborator

delock commented Apr 15, 2026

Hi @mingxiang1006, thanks for the fix. I have two questions:

  1. Should text_config be used as a fallback when hf_model_config does not have the needed keys, rather than being used unconditionally?
  2. I see vision_config and text_config have different numbers of attention heads. If text_config is picked, does that mean only text-related weights are used during training/inference?

@mingxiang1006
Author

> Hi @mingxiang1006, thanks for the fix. I have two questions:
>
> 1. Should text_config be used as a fallback when hf_model_config does not have the needed keys, rather than being used unconditionally?
> 2. I see vision_config and text_config have different numbers of attention heads. If text_config is picked, does that mean only text-related weights are used during training/inference?

Hi @delock, it was a temporary fix. I agree with your suggestion: we should fall back to text_config if hf_model_config does not have the key.

Yes, this needs further thought on when to use the text config versus the vision config.
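A minimal sketch of the fallback being discussed, under the assumption that the top-level attribute should win when present (the helper name is hypothetical, not the actual patch):

```python
def get_num_attention_heads(hf_model_config):
    # Prefer the top-level attribute when the config carries it directly.
    if hasattr(hf_model_config, "num_attention_heads"):
        return hf_model_config.num_attention_heads
    # Otherwise fall back to the text sub-config (Gemma4-style layout).
    text_config = getattr(hf_model_config, "text_config", None)
    if text_config is not None:
        return text_config.num_attention_heads
    raise AttributeError("num_attention_heads not found on config or text_config")
```

delock's later suggestion, `hf_model_config.get_text_config()`, collapses this into one call, since for text-only configs it returns the config itself, so the fallback comes for free.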

@delock
Collaborator

delock commented Apr 15, 2026

> Hi @mingxiang1006, thanks for the fix. I have two questions:
>
> 1. Should text_config be used as a fallback when hf_model_config does not have the needed keys, rather than being used unconditionally?
> 2. I see vision_config and text_config have different numbers of attention heads. If text_config is picked, does that mean only text-related weights are used during training/inference?
>
> Hi @delock, it was a temporary fix. I agree with your suggestion: we should fall back to text_config if hf_model_config does not have the key.
>
> Yes, this needs further thought on when to use the text config versus the vision config.

Hi, we can start by making text_config a fallback path.

As for picking between the text and vision config, does the modeling code know which one is being used? It might be okay to stay with text_config for the time being, because Ulysses SP is more likely to work on text than on vision, but I want a better understanding of the mechanism behind Gemma4.
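To make the concern concrete, a hypothetical Gemma4-style config could report different head counts per modality (values illustrative, not from the actual model; `hf_model_config` continues the sketch above):

```python
# Illustrative values only, not taken from a real Gemma4 checkpoint.
text_heads = hf_model_config.text_config.num_attention_heads      # e.g. 16
vision_heads = hf_model_config.vision_config.num_attention_heads  # e.g. 8
# Picking text_config therefore fixes the head count for one modality only;
# sequence-parallel sharding math based on it would not match the vision tower.
```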

@sfc-gh-truwase requested a review from delock April 20, 2026 10:44
@delock
Collaborator

delock commented Apr 21, 2026

Hi @mingxiang1006, can you update the PR to use text_config as the fallback path? I'll start the merge process once this is done. Also remember to sign off each of your commits for the DCO check. Thanks!

mingxiang1006 and others added 4 commits April 21, 2026 09:51
Signed-off-by: ming.lee <ming.lee@inceptionai.ai>
Signed-off-by: ming.lee <ming.lee@inceptionai.ai>
…impl.cu (deepspeedai#7973)

## Summary

- Fix `warning #68-D: integer conversion resulted in a change of sign`
by using unsigned literal `1U` for all bit-shift expressions storing to
unsigned types (`csrc/fp_quantizer/fp_quantize_impl.cu`)
- Fix `warning #62-D: shift count is negative` by removing unused
`mantisa_mask` dead code in `apply_dequantization` and
`apply_selective_dequantization`

The `_sign_mask` computation `1 << (_mantisa_bits + _exponent_bits)` in
`apply_quantization` shifts a signed `int` literal by 31 bits
(`_mantisa_bits=23, _exponent_bits=8`), which is undefined behavior in
C++. Using `1U` makes the shift well-defined. For consistency and
defensive programming, the same `1` → `1U` change is applied to all
similar patterns in `round()`, `apply_dequantization`, and
`apply_selective_dequantization`.

The `mantisa_mask` variable in both dequantization functions was
copy-pasted from the quantization function but **never used** in the
dequantization code paths. Its initialization `mantisa_mask <<=
(_mantisa_bits - q_mantisa_bits)` always produces a negative shift count
because in these functions `_mantisa_bits` (quantized format, small:
1-7) < `q_mantisa_bits` (output format, large: 7 or 10).

> **Note:** The issue suggested the template argument order in
`launch_dequantization` / `launch_selective_dequantization` might be
wrong, but analysis of the function body confirms the original order is
correct — `quantized_bits = _mantisa_bits + _exponent_bits + 1`
correctly computes the quantized format total bits, and `dst_mantisa <<
(q_mantisa_bits - _mantisa_bits)` correctly left-shifts the quantized
mantissa into the output format's mantissa field. The warnings came
solely from the unused dead code.

Fixes deepspeedai#7971

## Before / After

<details>
<summary>Before (18+ warnings)</summary>

```
$ nvcc -c fp_quantize_impl.cu -DBF16_AVAILABLE --expt-relaxed-constexpr \
       -gencode arch=compute_86,code=sm_86 -std=c++17

fp_quantize_impl.cu(82):  warning #68-D: integer conversion resulted in a change of sign
fp_quantize_impl.cu(244): warning #62-D: shift count is negative   (x9, one per template instantiation)
fp_quantize_impl.cu(426): warning #62-D: shift count is negative   (x9, one per template instantiation)
```

The `mantisa_mask` variable causing the shift warnings is declared but
never used in either dequantization function.

</details>

<details>
<summary>After (0 warnings from this file)</summary>

```
$ nvcc -c fp_quantize_impl.cu -DBF16_AVAILABLE --expt-relaxed-constexpr \
       -gencode arch=compute_86,code=sm_86 -std=c++17 2>&1 | grep -E '#62-D|#68-D'

(no output — all warnings eliminated)
```

Compilation succeeds with exit code 0. Only unrelated `#821-D` warnings
remain from `memory_access_utils.h`.

</details>

## Changes

### `csrc/fp_quantizer/fp_quantize_impl.cu`

1. **Lines 38, 40, 42** (`round()`) — `1` → `1U`: Consistent unsigned
shifts in `mantisa_mask`, `offset`, and exponent overflow check. Not UB
today (`1 << 23` fits in `int`), but prevents future issues and silences
potential sign-conversion warnings.

2. **Line 82** (`apply_quantization`) — `1` → `1U`: Fix actual UB — `1
<< 31` on signed `int` is undefined behavior.

3. **Line 237** (`apply_dequantization`) — `1` → `1U` in `_sign_mask`:
Consistent with `apply_quantization`. Not UB with current template args
(`1 << 7`), but defensive.

4. **Line 416** (`apply_selective_dequantization`) — `1` → `1U` in
`_sign_mask`: Same as above.

5. **Lines 243-244** — Remove unused `mantisa_mask` in
`apply_dequantization`: Copy-pasted from the quantization function but
never referenced in the dequantization code path.

6. **Lines 425-426** — Remove unused `mantisa_mask` in
`apply_selective_dequantization`: Same dead code as above.

## Test plan

- [x] `nvcc` compilation with `-DBF16_AVAILABLE` — 0 `#62-D` / `#68-D`
warnings (was 18+)
- [x] Verified `mantisa_mask` (no underscore prefix) is unused in both
dequantization functions by grepping all occurrences — only used in
`apply_quantization` (line 142) and `round` (lines 38-42)
- [x] Verified template parameter order in `launch_dequantization` and
`launch_selective_dequantization` is correct by tracing all usages of
`_mantisa_bits`, `_exponent_bits`, `q_mantisa_bits` in function bodies
- [x] All `1` → `1U` changes are semantically identical for non-negative
shift counts; the only behavioral fix is line 82 where `1 << 31` was UB

Signed-off-by: Cursx <674760201@qq.com>
Signed-off-by: ming.lee <ming.lee@inceptionai.ai>
Author

@mingxiang1006 left a comment


fallback to text config

Cursx and others added 4 commits April 21, 2026 05:58
## Summary

- Fix duplicate/wrong `-gencode=` flags in both JIT and non-JIT
compilation paths (`op_builder/builder.py`)
- Fix `TORCH_CUDA_ARCH_LIST` env-var restore logic in
`OpBuilder.jit_load()`

DeepSpeed's `compute_capability_args()` generates its own `-gencode`
flags, but PyTorch (`load()` in JIT mode, `BuildExtension` in non-JIT
mode) *also* reads `TORCH_CUDA_ARCH_LIST` and generates `-gencode`
flags. This causes two problems:

1. **JIT mode**: `jit_load()` set `TORCH_CUDA_ARCH_LIST=""`, which
PyTorch treats as *unset* and falls back to auto-detection — resulting
in every flag appearing **twice**.
2. **Non-JIT mode**: subclasses that override `filter_ccs()` (e.g.
`FPQuantizerBuilder`, `EvoformerAttnBuilder`) remove certain archs, but
`BuildExtension` re-reads the **unfiltered** `TORCH_CUDA_ARCH_LIST` and
adds them back — **undermining the filter**.

The fix synchronises `TORCH_CUDA_ARCH_LIST` with the filtered arch list
in `compute_capability_args()`, for both JIT and non-JIT paths.
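A rough sketch of that synchronisation (simplified; the function signature and the filtering step shown here are assumptions, not the literal `op_builder/builder.py` code):

```python
import os

def compute_capability_args(ccs, jit_mode):
    # ccs: filtered compute capabilities, e.g. ["8.0", "8.6"] after a
    # builder's filter_ccs() has run.
    # Keep PyTorch's view consistent with the filtered list so that
    # BuildExtension / load() cannot re-add archs that were filtered out.
    os.environ["TORCH_CUDA_ARCH_LIST"] = ";".join(ccs)

    if jit_mode:
        # JIT path: torch.utils.cpp_extension.load() generates the -gencode
        # flags from TORCH_CUDA_ARCH_LIST, so return none to avoid duplicates.
        return []

    # Non-JIT path: emit per-builder -gencode flags as before.
    return [
        f"-gencode=arch=compute_{cc.replace('.', '')},code=sm_{cc.replace('.', '')}"
        for cc in ccs
    ]
```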

Fixes deepspeedai#7972

## Before / After

<details>
<summary>Before (buggy behavior)</summary>

**JIT mode** — `TORCH_CUDA_ARCH_LIST` cleared to `""`, PyTorch
auto-detects and adds flags, DeepSpeed also adds the same flags:

```
nvcc ... -gencode=arch=compute_80,code=compute_80 -gencode=arch=compute_80,code=sm_80
     ... -gencode=arch=compute_80,code=sm_80 -gencode=arch=compute_80,code=compute_80
```

Plus a spurious warning:
```
UserWarning: TORCH_CUDA_ARCH_LIST is not set, all archs for visible cards are included for compilation.
```

**Non-JIT mode** — `FPQuantizerBuilder.filter_ccs()` removes `< 8.0`,
but `BuildExtension` re-adds them from the unfiltered env var:

```
# FPQuantizer compiled for sm_70 even though filter_ccs() removed it
nvcc ... -gencode=arch=compute_80,code=sm_80   # from DeepSpeed (correct)
     ... -gencode=arch=compute_70,code=sm_70   # from BuildExtension (wrong!)
```

</details>

<details>
<summary>After (fixed behavior)</summary>

**JIT mode** — `TORCH_CUDA_ARCH_LIST` is set to the detected
architectures, PyTorch generates flags once, no duplicates:

```
nvcc ... -gencode=arch=compute_80,code=sm_80 -gencode=arch=compute_80,code=compute_80
```

No spurious warning. Env var is properly restored/removed after build.

**Non-JIT mode** — `TORCH_CUDA_ARCH_LIST` is updated to the filtered
list. Each extension keeps its own `-gencode` flags, and
`BuildExtension` reads the filtered env var:

```
# FPQuantizer: only sm_80+ as intended
nvcc ... -gencode=arch=compute_80,code=sm_80   # from DeepSpeed
     ... -gencode=arch=compute_80,code=sm_80   # from BuildExtension (harmless dup)
```

> **Note:** in multi-builder `setup.py` builds, the last builder's
filtered arch list wins for `TORCH_CUDA_ARCH_LIST`. This may cause
harmless duplicates for some extensions, but will never reintroduce
archs that any builder's `filter_ccs()` removed — a strict improvement
over the current behavior where the unfiltered original is always used.

</details>

## Changes

- `op_builder/builder.py`
  - `CUDAOpBuilder.compute_capability_args()`:
    - Always sync `TORCH_CUDA_ARCH_LIST` with the filtered arch list
    - JIT mode: return `[]` (PyTorch generates flags via `load()`)
    - Non-JIT mode: return `-gencode` args as before (per-builder flags in `extra_compile_args`)
  - `OpBuilder.jit_load()`: simplified stash/restore — properly `del` the env var if it was not originally set
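A sketch of the corrected stash/restore behaviour (the helper name is hypothetical; the key point is deleting the variable when it was not originally set, rather than restoring an empty string):

```python
import os

def stash_arch_list(new_value):
    # Remember whether the variable existed at all, not just its value.
    stashed = os.environ.get("TORCH_CUDA_ARCH_LIST")
    os.environ["TORCH_CUDA_ARCH_LIST"] = new_value

    def restore():
        if stashed is None:
            # Originally unset: remove it, since "" makes PyTorch fall back
            # to auto-detection and emit the spurious warning.
            os.environ.pop("TORCH_CUDA_ARCH_LIST", None)
        else:
            os.environ["TORCH_CUDA_ARCH_LIST"] = stashed

    return restore
```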

---------

Signed-off-by: Cursx <674760201@qq.com>
Signed-off-by: ming.lee <ming.lee@inceptionai.ai>
Follow-up to deepspeedai#7973 for deepspeedai#7971

The naming of q_mantisa_bits and mantisa_bits was swapped. The
invocation set:

```
q_mantisa_bits = mantisa
_mantisa_bits = CONST_Q_MANTISA_BITS
_exponent_bits = CONST_Q_EXPONENT_BITS
```

so correct them by swapping the names back.

I noticed that the code needs a thorough review because multiple places
look suspicious:
```
	// Why the default args? They seem to not even be matching (16 != 3+4+1)
          int total_q_bits = 16,
          int q_mantisa_bits = 3,
          int q_exponent_bits = 4>

	// Why recompute if there is a total_q_bits template?
    constexpr int quantized_bits = q_mantisa_bits + q_exponent_bits + 1;
    // Likely wrong: total_q_bits < mantisa_bits --> negative bits? Likely caused by wrong naming
    constexpr int q_exponent_bits = total_q_bits - mantisa_bits - 1;

    // should likely use a `q_` prefix not `_`
    constexpr uint16_t _mantisa_mask = (1 << q_mantisa_bits) - 1;
    constexpr uint16_t _exponent_mask = ((1 << q_exponent_bits) - 1) << q_mantisa_bits;
    constexpr uint16_t _sign_mask = 1U << (q_mantisa_bits + q_exponent_bits);
```

cc @Cursx

Signed-off-by: Alexander Grund <alexander.grund@tu-dresden.de>
Signed-off-by: ming.lee <ming.lee@inceptionai.ai>
Currently the CI full test shows a [CUDA reinit
error](https://github.com/deepspeedai/DeepSpeed/actions/runs/24444633640/job/71417719445).
This PR includes the following fixes:

- Fix `compute_capability_args()` in JIT mode to read
`TORCH_CUDA_ARCH_LIST` before calling
`torch.cuda.get_device_capability()`, and restore the JIT builder state
after `jit_load()`. This also adds regression tests for the explicit-arch,
bad-fork, and restore paths.
- Delay initialization of CUDA streams in DeepCompile

After this fix, the full test
[passed](https://github.com/deepspeedai/DeepSpeed/actions/runs/24508304055/job/71632434455)
again.
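The ordering matters because calling into the CUDA runtime from a forked worker is what raises the reinit error; a hedged sketch of the idea (the function name is assumed, not the actual patch):

```python
import os
import torch

def detect_arch_list():
    # Honor an explicitly set TORCH_CUDA_ARCH_LIST first, so a forked
    # subprocess never has to touch the CUDA runtime at all.
    explicit = os.environ.get("TORCH_CUDA_ARCH_LIST")
    if explicit:
        return explicit
    # Only query the device when no explicit list was provided.
    major, minor = torch.cuda.get_device_capability()
    return f"{major}.{minor}"
```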

---------

Signed-off-by: Masahiro Tanaka <mtanaka@anyscale.com>
Signed-off-by: ming.lee <ming.lee@inceptionai.ai>
@delock
Collaborator

delock commented Apr 22, 2026

> fallback to text config

Hi @mingxiang1006, following this document, the following line should be the cleanest way to pick the text config; can you try whether it works for your case?
`arch_cfg = hf_model_config.get_text_config()`

Also, can you fix the DCO CI test error by signing off your commits? Thanks!

```python
local_seq_length = seq_length // mpu.get_sequence_parallel_world_size()
global_seq_length = seq_length

## Fix here
```
Collaborator

@delock Apr 22, 2026


Can you also remove `## Fix here`, since this comment looks vague?

Signed-off-by: ming.lee <ming.lee@inceptionai.ai>
@delock
Collaborator

delock commented Apr 28, 2026

> > fallback to text config
>
> Hi @mingxiang1006, following this document, the following line should be the cleanest way to pick the text config; can you try whether it works for your case? `arch_cfg = hf_model_config.get_text_config()`
>
> Also, can you fix the DCO CI test error by signing off your commits? Thanks!

Hi @mingxiang1006, can you sign off your commits to fix the DCO CI error, and try whether `hf_model_config.get_text_config()` works? Thanks!

Signed-off-by: ming.lee <ming.lee@inceptionai.ai>
@delock
Collaborator

delock commented Apr 29, 2026

Hi @mingxiang1006, the DCO check is required for contributors to certify their contributions to open-source projects. To pass it, each commit from a contributor needs to be signed off. The following link describes how to resolve the DCO check error for this specific PR; can you follow the steps there? Otherwise the merge process is blocked. Thanks for your help!

https://github.com/deepspeedai/DeepSpeed/pull/7975/checks?check_run_id=73365745940

@mingxiang1006
Author

> > fallback to text config
>
> Hi @mingxiang1006, following this document, the following line should be the cleanest way to pick the text config; can you try whether it works for your case? `arch_cfg = hf_model_config.get_text_config()`
> Also, can you fix the DCO CI test error by signing off your commits? Thanks!
>
> Hi @mingxiang1006, can you sign off your commits to fix the DCO CI error, and try whether `hf_model_config.get_text_config()` works? Thanks!

Hi @delock, this version is indeed cleaner, but that means it will only grab the text config and ignore the others. I appreciate your feedback.

Signed-off-by: ming.lee <ming.lee@inceptionai.ai>
@delock
Collaborator

delock commented May 1, 2026

Hi @mingxiang1006, thanks for signing off your latest commit. I believe the past commits also need sign-off, followed by a force push. The purpose of DCO is to make sure every one of your commits is signed off to the open-source software.

If you follow the DCO check link (https://github.com/deepspeedai/DeepSpeed/pull/7975/checks?check_run_id=73523193924) you can find the following instructions. Following them will make the DCO check pass:

> To add your Signed-off-by line to every commit in this branch:
>
> 1. Ensure you have a local copy of your branch by checking out the pull request locally via command line.
> 2. In your local branch, run: `git rebase HEAD~14 --signoff`
> 3. Force push your changes to overwrite the branch: `git push --force-with-lease origin master`

Thanks if you can fix the DCO error following the instructions above. You will see the DCO check turn green almost immediately once it's done. Thanks!

@mingxiang1006
Author

> Hi @mingxiang1006, thanks for signing off your latest commit. I believe the past commits also need sign-off, followed by a force push. The purpose of DCO is to make sure every one of your commits is signed off to the open-source software.
>
> If you follow the DCO check link (https://github.com/deepspeedai/DeepSpeed/pull/7975/checks?check_run_id=73523193924) you can find the instructions to add your Signed-off-by line to every commit: check out the pull request locally, run `git rebase HEAD~14 --signoff`, then force push with `git push --force-with-lease origin master`.
>
> Thanks if you can fix the DCO error following the instructions above. You will see the DCO check turn green almost immediately once it's done. Thanks!

Hi @delock, I just did that; hope this is resolved. Thank you.

@delock
Collaborator

delock commented May 2, 2026

Hi Mingxiang, it seems there are other test errors as well. I'll open a new branch and invite you to test on it; hope this works!

@delock
Collaborator

delock commented May 6, 2026

@mingxiang1006 the new PR is created: #7990. It contains exactly your change, along with a unit test with multimodality, and with the DCO issue fixed. I'd appreciate it if you can tell me whether this new PR branch fixes the Gemma4 issue!

