guard fuser grad checks on non-leaf nodes #2919
CarlosGomes98 wants to merge 8 commits into NVIDIA:main
Conversation
I'm not sure if this is really addressing the root cause of the issue. Two problems:

- We aren't actually protecting against setting `requires_grad` on non-leaf nodes. We're just skipping the `requires_grad` logic when `torch.is_grad_enabled() == False`.
- Do we even want to skip setting `requires_grad` on non-leaf nodes? The backward expects grads from each of the outputs, so we need `requires_grad` for autograd to do the right thing.
I think the right solution is smarter logic when setting `requires_grad_`. Maybe something like:

```python
x_requires_grad = fuser.first_op_requiring_backward < fuser._num_basic_ops
if x_requires_grad != x.requires_grad:
    x = x.detach()
    if x_requires_grad:
        x.requires_grad_()
    # Or maybe only detach if x is a non-leaf node?
    # Need to check if the CPU overhead of checking
    # is worth saving the CPU overhead of detaching.
...
return x
```

Another approach would be changing our ops to always return leaf nodes. For example, here is the forward pass of `MakeExtraOutput`:
This would be changed to:

```python
out = input_.detach()
return out, [(out,)]
```
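Both suggestions rest on the same underlying torch behavior, sketched below in a small self-contained example (hedged: tensor names are illustrative, not from the codebase). Toggling `requires_grad` off on a non-leaf tensor raises a `RuntimeError`, while `detach()` returns a fresh leaf tensor whose flag can be mutated safely.

```python
# Sketch of the torch behavior behind both suggestions above.
import torch

x = torch.ones(2, requires_grad=True)
y = x * 2                      # non-leaf: produced by an op, has a grad_fn
assert not y.is_leaf

try:
    y.requires_grad_(False)    # disallowed on non-leaf tensors
    raised = False
except RuntimeError:
    raised = True
assert raised

z = y.detach()                 # leaf tensor with requires_grad=False
assert z.is_leaf and not z.requires_grad
z.requires_grad_()             # safe: z is a leaf
assert z.requires_grad
```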
Signed-off-by: CarlosGomes98 <carlosmiguel.gomes@live.com.pt>

for more information, see https://pre-commit.ci
Greptile Summary

This bug fix guards the fuser's `requires_grad` checks so they only run when grad mode is enabled.

Confidence Score: 5/5 — Safe to merge. The change is a small, well-scoped, targeted bug fix with no regressions identified. No files require special attention.
Flowchart

```mermaid
%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A["OperationFuser.__call__()"] --> B{"is_grad_enabled?"}
    B -- "True" --> C["args = (..., set_output_requires_grad=True, ...)"]
    B -- "False" --> D["args = (..., set_output_requires_grad=False, ...)"]
    C --> E["_OperationFuserAutogradFunction.apply(*args)"]
    D --> F["_OperationFuserAutogradFunction.forward(None, *args)"]
    E --> G["forward() — tensors detached by autograd machinery\nset_output_requires_grad=True\nrequires_grad_() calls are SAFE"]
    F --> H["forward() — inner ops may produce non-leaf tensors\nset_output_requires_grad=False\nrequires_grad_() calls are SKIPPED"]
    G --> I["return output tensor(s)"]
    H --> I
```
Reviews (2). Last reviewed commit: "Revert cudnn-frontend submodule bump"
```python
    input,
    self,
    basic_op_kwargs,
    is_grad_enabled,  # set_output_requires_grad
```
Implicit coupling between `is_grad_enabled` and `set_output_requires_grad`

`is_grad_enabled` is passed directly as `set_output_requires_grad`, hardcoding an equivalence between the two concepts. The guard in `forward` is really about "are the output tensors leaf nodes (safe to mutate with `requires_grad_`) or not", which happens to correlate with grad being enabled today. If a future caller ever needs grad enabled but the outputs are already non-leaf (or vice versa), the coupling breaks silently. A dedicated flag computed from actual tensor leaf-ness, or at minimum a local variable with an explanatory name, would make the intent more resilient:

```python
# Outputs produced by inner ops in the no-grad path may be non-leaf
# tensors; setting requires_grad_ on non-leaf tensors raises a RuntimeError.
set_output_requires_grad = is_grad_enabled
```

This is already implicit in the existing comment, so at a minimum a short inline note at the call site explaining why `is_grad_enabled` serves as the proxy would help future maintainers.
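One possible shape for such leaf-ness-driven logic, sketched as a hypothetical helper (`safe_set_requires_grad` is an illustrative name, not part of the codebase):

```python
# Hypothetical helper: derive the detach/toggle decision from the tensor's
# actual leaf-ness instead of a grad-mode proxy.
import torch

def safe_set_requires_grad(out: torch.Tensor, want_grad: bool) -> torch.Tensor:
    if out.requires_grad != want_grad:
        if not out.is_leaf:
            # Only leaf tensors may have requires_grad toggled in place;
            # detach() returns a leaf with requires_grad=False.
            out = out.detach()
        out.requires_grad_(want_grad)
    return out

x = torch.ones(2, requires_grad=True)
y = safe_set_requires_grad(x * 2, False)   # non-leaf input -> detached leaf
assert y.is_leaf and not y.requires_grad
```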
/te-ci pytorch
Description
Pass an explicit flag that controls whether the fuser forward pass sets `requires_grad` on outputs. This is required so that in `no_grad` mode we don't try to mutate this flag on non-leaf nodes, which torch does not allow.
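A minimal sketch of the guard described here (`finish_output` is an illustrative name, not the actual fuser code): `torch.is_grad_enabled()` is false inside a `torch.no_grad()` block, so the `requires_grad_()` mutation is skipped on that path.

```python
import torch

def finish_output(out: torch.Tensor) -> torch.Tensor:
    # Only mutate the flag when grad mode is on; in no_grad mode the
    # output tensor is left untouched.
    if torch.is_grad_enabled():
        out.requires_grad_()
    return out

with torch.no_grad():
    a = finish_output(torch.ones(2))   # skipped: is_grad_enabled() is False
b = finish_output(torch.ones(2))       # leaf tensor, safe to mutate

print(a.requires_grad, b.requires_grad)   # prints: False True
```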
Fixes # (issue)
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: