
[bugfix] Fix abnormal grad_norm under GRPO LoRA + DeepSpeed ZeRO-0 (fix #6815)#8341

Open
alphadl wants to merge 1 commit into modelscope:main from alphadl:fix/6815-grad-norm-zero0-reduce

Conversation


@alphadl alphadl commented Mar 14, 2026

#6815

Under DeepSpeed ZeRO-0 (or plain DDP), the logged grad_norm was abnormally large (e.g. ~1656) compared to ZeRO-3 (e.g. ~0.025) with the same loss. This made monitoring and debugging difficult and was reported by multiple users.

Root cause: With ZeRO-0/DDP, each rank may report a different gradient norm (e.g. local view before or after reduce). ZeRO-3 reports a global norm. The trainer was logging the per-rank value without reduction, so the logged grad_norm was inconsistent across ZeRO stages.
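To make the inconsistency concrete, here is a toy numeric illustration (not code from the PR; the gradient values are invented) of how per-rank local norms can disagree with each other and with the single global norm that ZeRO-3 reports:

```python
import math

# Two ranks each see only their own local view of the gradients (hypothetical values).
rank0_grads = [3.0, 4.0]
rank1_grads = [6.0, 8.0]

local_norm0 = math.hypot(*rank0_grads)   # rank 0's local view: 5.0
local_norm1 = math.hypot(*rank1_grads)   # rank 1's local view: 10.0

# ZeRO-3 instead reports one global norm over all parameters.
global_norm = math.sqrt(sum(g * g for g in rank0_grads + rank1_grads))

print(local_norm0, local_norm1)   # 5.0 10.0 -> logging either rank alone is ambiguous
print(round(global_norm, 4))      # 11.1803

# This PR averages the per-rank norms so every rank logs the same value.
print((local_norm0 + local_norm1) / 2)   # 7.5
```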

Changes:

  • swift/trainers/mixin.py:
    • Add _get_reduced_grad_norm_for_logging(grad_norm) to reduce grad_norm across processes when not ZeRO-3 and world_size > 1 (all-reduce with ReduceOp.AVG), so that the logged value is consistent and comparable across ZeRO stages.
    • In _maybe_log_save_evaluate, use _get_reduced_grad_norm_for_logging(args[0]) when building logs['grad_norm'] (transformers >= 4.38).
  • tests/train/test_grad_norm_reduce.py: Add unit tests for _get_reduced_grad_norm_for_logging (None/float/tensor, single-process, ZeRO-3 no reduce, ZeRO-0 with all_reduce).

Behavior:

  • Single process or ZeRO-3: no change (existing behavior).
  • Multi-process + ZeRO-0/1/2: grad_norm is all-reduced (average) before logging, so logs show a single consistent value.
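The branching described above can be sketched as follows. This is a minimal, framework-free sketch of the control flow only: the real helper in swift/trainers/mixin.py calls torch.distributed.all_reduce with ReduceOp.AVG, whereas here a plain all_reduce_avg callable is injected so the sketch runs without a distributed backend, and all parameter names are illustrative.

```python
from typing import Callable, Optional

def reduced_grad_norm_for_logging(
    grad_norm,
    *,
    world_size: int,
    is_zero3: bool,
    all_reduce_avg: Callable[[float], float],
) -> Optional[float]:
    """Sketch of the PR's logic: average grad_norm across ranks for logging,
    except under ZeRO-3 (already a global norm) or single-process runs.

    `all_reduce_avg` stands in for torch.distributed.all_reduce(ReduceOp.AVG).
    """
    if grad_norm is None:
        return None
    # Tensors expose .item(); plain floats pass through unchanged.
    value = grad_norm.item() if hasattr(grad_norm, 'item') else float(grad_norm)
    if is_zero3 or world_size <= 1:
        return value  # existing behavior: no reduction needed
    try:
        return all_reduce_avg(value)  # average the per-rank norms
    except Exception:
        return value  # fall back to the un-reduced local value

# Single process: the value passes through untouched.
print(reduced_grad_norm_for_logging(2.0, world_size=1, is_zero3=False,
                                    all_reduce_avg=lambda v: v))  # 2.0

# "ZeRO-0 with 2 ranks": pretend the other rank saw 4.0, so the average is 3.0.
print(reduced_grad_norm_for_logging(2.0, world_size=2, is_zero3=False,
                                    all_reduce_avg=lambda v: (v + 4.0) / 2))  # 3.0
```

The injected callable is only a stand-in for testability; the same four cases the PR's unit tests cover (None, float, single process, distributed with reduction) fall directly out of this structure.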

@gemini-code-assist (Contributor) commented

Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request addresses an issue where grad_norm logging was inconsistent and abnormally high under DeepSpeed ZeRO-0 or plain DDP, making monitoring difficult. The core change introduces a new utility to average the grad_norm across all processes in distributed training environments (except ZeRO-3, which already handles global norms), ensuring that the logged value is consistent and comparable regardless of the DeepSpeed configuration. This significantly improves the reliability of gradient norm monitoring for users.

Highlights

  • Consistent Gradient Norm Logging: Implemented a mechanism to ensure grad_norm is consistently logged across different DeepSpeed ZeRO stages and DDP setups by performing an all-reduce (average) operation.
  • New Utility Method: Introduced _get_reduced_grad_norm_for_logging in swift/trainers/mixin.py to conditionally reduce gradient norms across processes.
  • Logging Integration: Modified the _maybe_log_save_evaluate method to utilize the new gradient norm reduction logic for accurate logging when transformers version is 4.38 or higher.
  • Comprehensive Unit Tests: Added a dedicated test file tests/train/test_grad_norm_reduce.py to validate the behavior of the new gradient norm reduction utility under various scenarios.


Changelog
  • swift/trainers/mixin.py
    • Imported get_dist_setting for distributed environment checks.
    • Added a new method _get_reduced_grad_norm_for_logging to perform an all-reduce (average) on grad_norm across processes for consistent logging, specifically for non-ZeRO-3 distributed setups.
    • Updated the _maybe_log_save_evaluate method to use _get_reduced_grad_norm_for_logging when logging grad_norm for transformers versions 4.38 and above.
  • tests/train/test_grad_norm_reduce.py
    • Added a new test file containing unit tests for the _get_reduced_grad_norm_for_logging method, covering cases like None input, float input, single-process tensor, distributed ZeRO-3 (no reduce), and distributed ZeRO-0 (with all-reduce).


alphadl commented Mar 14, 2026

Please take a look when you have time. This fixes abnormal grad_norm logging under GRPO LoRA + DeepSpeed ZeRO-0 by all-reducing grad_norm across processes when not ZeRO-3 (fixes #6815). cc @Jintao-Huang @hjh0119

@gemini-code-assist bot left a comment

Code Review

This pull request fixes an issue with abnormally large grad_norm logging when using DeepSpeed ZeRO-0 by introducing a reduction step. The changes look solid, with a new utility function _get_reduced_grad_norm_for_logging to average the gradient norm across processes and comprehensive unit tests to cover various scenarios. My only suggestion is to improve the exception handling in the new function to log any errors during reduction, which will help in debugging potential issues instead of silently falling back to the old behavior.

Comment on lines +686 to +687
except Exception:
    return grad_norm.item()
Severity: medium

The broad except Exception: without logging can hide issues during gradient norm reduction. If an error occurs, it will silently fall back to using the un-reduced gradient norm, which could be misleading for monitoring. It's better to log the exception to make debugging easier.

Suggested change:

-except Exception:
-    return grad_norm.item()
+except Exception as e:
+    logger.warning(f'Failed to reduce grad_norm for logging: {e}. Returning un-reduced value.')
+    return grad_norm.item()
