[PyTorch] Support user-defined op fusions #2597
Conversation
Refactor fusion functions to remove index bookkeeping. Refactor fused ops to use consistent operation order. Signed-off-by: Tim Moon <tmoon@nvidia.com>
for more information, see https://pre-commit.ci
Greptile Summary
This PR adds support for user-defined operation fusions by introducing the register_forward_fusion and register_backward_fusion functions.
Confidence Score: 5/5
Sequence Diagram

sequenceDiagram
    participant User
    participant RegisterAPI as register_forward_fusion/<br/>register_backward_fusion
    participant OperationFuser
    participant FusionFunc as Custom Fusion Function
    participant FusedOp as Fused Operation

    User->>RegisterAPI: register custom fusion function
    RegisterAPI->>OperationFuser: append to forward_fusion_functions<br/>or backward_fusion_functions

    Note over User,FusedOp: During model execution
    User->>OperationFuser: __call__(input, ...)
    OperationFuser->>OperationFuser: maybe_fuse_ops()
    loop For each fusion function
        OperationFuser->>FusionFunc: func(ops, recipe=recipe)
        FusionFunc->>FusionFunc: scan ops with sliding window
        alt Pattern matches
            FusionFunc->>FusedOp: create fused op
            FusedOp-->>FusionFunc: fused operation
        else No match
            FusionFunc->>FusionFunc: shift window
        end
        FusionFunc-->>OperationFuser: updated ops list
    end
    OperationFuser->>OperationFuser: store _forward_ops and _backward_ops

    Note over OperationFuser,FusedOp: Forward pass
    loop For each forward op
        OperationFuser->>FusedOp: fuser_forward(ctxs, input, ...)
        FusedOp->>FusedOp: execute fused computation
        FusedOp->>FusedOp: save state to contexts
        FusedOp-->>OperationFuser: output
    end

    Note over OperationFuser,FusedOp: Backward pass
    loop For each backward op (reversed)
        OperationFuser->>FusedOp: fuser_backward(ctxs, grad_output, ...)
        FusedOp->>FusedOp: restore from contexts
        FusedOp->>FusedOp: execute fused backward
        FusedOp-->>OperationFuser: grad_input, grad_params
    end
    OperationFuser-->>User: outputs and gradients
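Concretely, the flow in the diagram suggests a forward fusion function of roughly this shape: a plain list of ops in, an updated list out, scanned with a sliding window. The sketch below is illustrative only; the module path transformer_engine.pytorch.ops, the Bias/GELU op names, and the FusedBiasGelu placeholder class are assumptions rather than confirmed API.

```python
# Minimal sketch of a user-defined forward fusion. Names marked below are
# assumptions for illustration, not necessarily the real library API.
from transformer_engine.pytorch import ops as te_ops  # assumed import path


class FusedBiasGelu:
    """Hypothetical fused op standing in for a (Bias, GELU) pair.

    A real implementation would subclass the library's fused-operation base
    class and implement fuser_forward (and fuser_backward if needed).
    """

    def __init__(self, bias_op, gelu_op):
        self.bias_op = bias_op
        self.gelu_op = gelu_op


def fuse_bias_gelu(ops, *, recipe=None):
    """Replace each adjacent (Bias, GELU) pair with one FusedBiasGelu."""
    fused = []
    window = list(ops)
    while window:
        # Sliding window: try to match a two-op pattern at the front.
        if (
            len(window) >= 2
            and isinstance(window[0], te_ops.Bias)   # assumed basic-op class
            and isinstance(window[1], te_ops.GELU)   # assumed basic-op class
        ):
            fused.append(FusedBiasGelu(window[0], window[1]))
            window = window[2:]
        else:
            # No match: emit the front op and shift the window by one.
            fused.append(window[0])
            window = window[1:]
    return fused


# Register so the OperationFuser tries this fusion when building _forward_ops.
te_ops.register_forward_fusion(fuse_bias_gelu)
```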
Signed-off-by: Tim Moon <tmoon@nvidia.com>
for more information, see https://pre-commit.ci
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
/te-ci pytorch L1
Description
This PR adds the register_forward_fusion and register_backward_fusion functions to the op fuser API, allowing users to register custom fusions (a short backward-side registration sketch follows the checklist below).
Type of change
Changes
Checklist:
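For the backward side, the sketch below guesses that register_backward_fusion accepts the same (ops, *, recipe=None) callable signature as the forward case; the module path and the no-op body are again assumptions for illustration.

```python
from transformer_engine.pytorch import ops as te_ops  # assumed import path


def fuse_dgelu_dbias(ops, *, recipe=None):
    # Placeholder: a real backward fusion would scan `ops` with the same
    # sliding-window pattern as the forward example and splice in a fused
    # backward op. Returning the list unchanged is a valid no-op fusion.
    return list(ops)


te_ops.register_backward_fusion(fuse_dgelu_dbias)
```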