
[minor] Refactor TE fused-norm handling in GPTModelExporter#1061

Open
yueshen2016 wants to merge 1 commit into main from yueshen/refactor-te-fused-norm

Conversation


@yueshen2016 yueshen2016 commented Mar 17, 2026

What does this PR do?

Type of change: Refactor (no behavior change)

Extract a _get_fused_norm_weight helper method in GPTModelExporter to consolidate the repeated
TE fused-norm detection logic that previously appeared as three separate multi-condition elif blocks
in _get_transformer_layer_state_dict (attention and MLP paths) and _get_mamba_layer_state_dict.

Changes:

  • Add _get_fused_norm_weight(module) that checks "fused_norm" in self.rules and returns getattr(module, "layer_norm_weight", None)
  • Replace nested hasattr chains with getattr(..., None) chaining, since getattr with a None default safely handles missing attributes
  • Remove redundant isinstance(layer.norm, IdentityOp) in the Mamba elif (guaranteed by being an elif branch)
  • Use walrus operator (:=) to capture norm_weight without repeating the attribute traversal on the call line

Before / After

Before (one of three nearly-identical blocks):

elif (
    hasattr(layer.self_attention, "linear_qkv")
    and hasattr(layer.self_attention.linear_qkv, "layer_norm_weight")
    and layer.self_attention.linear_qkv.layer_norm_weight is not None
    and "fused_norm" in self.rules
):
    self.rules["fused_norm"](layer.self_attention.linear_qkv.layer_norm_weight, layer_id)

After:

elif (
    norm_weight := self._get_fused_norm_weight(
        getattr(layer.self_attention, "linear_qkv", None)
    )
) is not None:
    self.rules["fused_norm"](norm_weight, layer_id)
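
To see that the two forms agree, here is a small self-check on dummy namespaces. Everything below is hypothetical scaffolding: `RULES` stands in for the exporter's `self.rules`, module-level functions stand in for the methods, and `SimpleNamespace` objects mimic the layer structure:

```python
from types import SimpleNamespace

RULES = {"fused_norm": lambda weight, layer_id: None}  # hypothetical rules table

def old_check(layer):
    # The original three-condition hasattr chain.
    return (
        hasattr(layer.self_attention, "linear_qkv")
        and hasattr(layer.self_attention.linear_qkv, "layer_norm_weight")
        and layer.self_attention.linear_qkv.layer_norm_weight is not None
        and "fused_norm" in RULES
    )

def new_check(layer):
    # The refactored form: getattr chaining plus a single None test.
    if "fused_norm" not in RULES:
        return False
    module = getattr(layer.self_attention, "linear_qkv", None)
    return getattr(module, "layer_norm_weight", None) is not None

cases = [
    SimpleNamespace(self_attention=SimpleNamespace()),             # no linear_qkv
    SimpleNamespace(self_attention=SimpleNamespace(
        linear_qkv=SimpleNamespace())),                            # no layer_norm_weight
    SimpleNamespace(self_attention=SimpleNamespace(
        linear_qkv=SimpleNamespace(layer_norm_weight=None))),      # weight is None
    SimpleNamespace(self_attention=SimpleNamespace(
        linear_qkv=SimpleNamespace(layer_norm_weight=1.0))),       # fusable
]
assert [old_check(c) for c in cases] == [new_check(c) for c in cases]
```

Only the last case triggers the fused-norm rule; the three failure modes that previously needed separate conditions all resolve to `None`.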

Testing

No behavior change — existing tests cover all paths.

  • Is this change backward compatible?: ✅ Pure refactor, no API or logic change
  • Did you write any new necessary tests?: N/A
  • Did you update Changelog?: N/A (minor refactor)

🤖 Generated with Claude Code

Summary by CodeRabbit

  • Refactor
    • Centralized and simplified fused-normalization handling in model export, reducing duplicated checks and streamlining control flow while preserving existing behavior and compatibility. Improved maintainability and consistency across export paths.

@yueshen2016 yueshen2016 requested a review from a team as a code owner March 17, 2026 19:05
@yueshen2016 yueshen2016 requested a review from cjluo-nv March 17, 2026 19:05
coderabbitai bot commented Mar 17, 2026

No actionable comments were generated in the recent review. 🎉

ℹ️ Recent review info

  • Configuration used: .coderabbit.yaml
  • Review profile: CHILL
  • Plan: Pro
  • Run ID: ab51eae5-2996-4115-b17a-01bac4a4b285

📥 Commits

Reviewing files that changed from the base of the PR and between 6ac5a43 and 8783a4f.

📒 Files selected for processing (1)
  • modelopt/torch/export/unified_export_megatron.py

📝 Walkthrough

Adds a private helper _get_fused_norm_weight(self, module) and refactors transformer, MLP, MTP/MoE, and Mamba code paths to use it (via a walrus-style assignment) for fused-norm detection, replacing scattered hasattr/None checks while preserving prior behavior.

Changes

Cohort / File(s): Norm-fusion helper consolidation (modelopt/torch/export/unified_export_megatron.py)
Summary: Added _get_fused_norm_weight() to centralize retrieval of layer_norm_weight. Refactored the transformer, MLP, MTP/MoE, and Mamba branches to use the helper with a walrus-assignment pattern, replacing inline hasattr/None checks and applying fused_norm only when the weight is present.

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

🚥 Pre-merge checks | ✅ 3 passed | ❌ 1 failed

❌ Failed checks (1 warning)

  • Docstring Coverage ⚠️ Warning: Docstring coverage is 25.00%, below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them to satisfy the coverage threshold.

✅ Passed checks (3 passed)

  • Description Check ✅ Passed: Check skipped; CodeRabbit's high-level summary is enabled.
  • Title Check ✅ Passed: The title accurately describes the main refactoring: consolidating fused-norm handling logic in GPTModelExporter through a new helper method.
  • Security Anti-Patterns ✅ Passed: Pure refactor adding a helper method with safe operations; no security-sensitive patterns introduced.


@yueshen2016 yueshen2016 requested a review from ChenhanYu March 17, 2026 19:06
codecov bot commented Mar 17, 2026

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 70.31%. Comparing base (00fa5bd) to head (8783a4f).

Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1061      +/-   ##
==========================================
+ Coverage   70.30%   70.31%   +0.01%     
==========================================
  Files         227      227              
  Lines       25847    25847              
==========================================
+ Hits        18172    18175       +3     
+ Misses       7675     7672       -3     

☔ View full report in Codecov by Sentry.

Extract `_get_fused_norm_weight` helper to consolidate the repeated
three-condition check (`"fused_norm" in self.rules`, attribute
navigation, `layer_norm_weight is not None`) that appeared in
`_get_transformer_layer_state_dict` (attention and MLP paths) and
`_get_mamba_layer_state_dict`.

- Replace double `hasattr` chains with `getattr(..., None)` chaining
- Remove redundant `isinstance(layer.norm, IdentityOp)` in Mamba elif
  (already guaranteed by being in the elif branch)
- Use walrus operator to capture `norm_weight` without repeating the
  attribute access on the call site

No behavior change.

Signed-off-by: James Shen <yueshen@nvidia.com>
@yueshen2016 yueshen2016 force-pushed the yueshen/refactor-te-fused-norm branch from 6ac5a43 to 8783a4f Compare March 17, 2026 21:57