fix(zero): guard ds_grads_remaining #7904
Conversation
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: d997d509c8
```diff
 @torch.compiler.disable
 def _post_backward_module_hook(module, inputs):
-    module.ds_grads_remaining = 0
+    _ensure_ds_grads_remaining(module)
```
Reset ds_grads_remaining at each forward-pre hook
In `_post_backward_module_hook`, replacing the per-forward reset with `_ensure_ds_grads_remaining(module)` keeps any stale `ds_grads_remaining` value from previous iterations instead of clearing it. This counter is decremented in `PostBackwardFunctionModule.backward`, and the nearby comment already notes that some backward hooks may not trigger in edge cases. When that happens, a leftover positive count now carries into the next forward pass, so the counter can stop reaching 0 and `post_sub_module_backward_function` is skipped indefinitely, which can block parameter release and lead to persistent memory growth.
@ailuntz thanks for the PR. Given the large number of modifications this flag has had in the past, can you please explain the issue that triggered this PR?
Summary
- Guard `ds_grads_remaining` to avoid unexpected errors in edge cases.

Testing