Skip to content

Conversation

@timmoon10
Copy link
Collaborator

Description

This PR adds the register_forward_fusion and register_backward_fusion functions to the op fuser API, allowing users to register custom fusions.

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • Add function to register custom op fusions
  • Refactor op fuser to have consistent op order in forward and backward pass
  • Refactor op fusion functions to avoid index bookkeeping
  • Add tests for user-defined ops

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Refactor fusion functions to remove index bookkeeping. Refactor fused ops to use consistent operation order.

Signed-off-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
@timmoon10 timmoon10 requested review from ksivaman and pggPL January 14, 2026 08:28
@timmoon10 timmoon10 added the enhancement New feature or request label Jan 14, 2026
@timmoon10

This comment was marked as outdated.

@greptile-apps
Copy link
Contributor

greptile-apps bot commented Jan 14, 2026

Greptile Summary

This PR adds support for user-defined operation fusions by introducing register_forward_fusion and register_backward_fusion functions to the public API. The implementation includes a significant refactoring of the op fusion system to improve consistency and maintainability.

Key Changes:

  • Added register_forward_fusion() and register_backward_fusion() functions in fuser.py to allow users to register custom fusion functions
  • Refactored all fusion functions from standalone functions to static methods on the fused operation classes (e.g., fuse_forward_linear_bias_activation()ForwardLinearBiasActivation.fuse_forward_ops())
  • Simplified fusion function signatures to accept list[FusibleOperation] instead of list[tuple[FusibleOperation, list[int]]], removing index bookkeeping complexity
  • Fixed indexing bugs in backward fusion operations (backward_linear_add.py, backward_linear_scale.py) where context indices were incorrect
  • Ensured consistent operation order between forward and backward passes by having both work with the same basic operation sequence
  • Added comprehensive tests for custom basic operations, custom forward fusions, and custom backward fusions

Impact:
The refactoring provides a cleaner, more maintainable API for operation fusion while fixing several indexing bugs. Users can now define and register their own fusion patterns, extending TransformerEngine's optimization capabilities.

Confidence Score: 5/5

  • This PR is safe to merge with no critical issues found
  • The PR implements a well-designed feature with comprehensive testing. The refactoring fixes existing indexing bugs in backward operations and significantly improves code maintainability by removing complex index bookkeeping. The new API is clean and follows good design patterns. All changes are backwards compatible since the old fusion functions were internal implementation details.
  • No files require special attention

Important Files Changed

Filename Overview
transformer_engine/pytorch/ops/fuser.py Added register_forward_fusion and register_backward_fusion functions to allow users to register custom op fusions. Refactored _fuse_ops method to have cleaner signature.
transformer_engine/pytorch/ops/init.py Exposed register_forward_fusion and register_backward_fusion APIs and additional operation classes for public use.
transformer_engine/pytorch/ops/fused/init.py Refactored to register fusion functions using new register_forward_fusion and register_backward_fusion APIs instead of direct function exports.
transformer_engine/pytorch/ops/fused/backward_linear_add.py Fixed context indexing bug (line 48), converted fusion function to static method, removed index bookkeeping from fusion logic.
tests/pytorch/test_fusible_ops.py Added comprehensive tests for custom user-defined ops including basic ops, forward fusions, and backward fusions. Updated existing test expectations for op ordering.

Sequence Diagram

sequenceDiagram
    participant User
    participant RegisterAPI as register_forward_fusion/<br/>register_backward_fusion
    participant OperationFuser
    participant FusionFunc as Custom Fusion Function
    participant FusedOp as Fused Operation

    User->>RegisterAPI: register custom fusion function
    RegisterAPI->>OperationFuser: append to forward_fusion_functions<br/>or backward_fusion_functions

    Note over User,FusedOp: During model execution

    User->>OperationFuser: __call__(input, ...)
    OperationFuser->>OperationFuser: maybe_fuse_ops()
    
    loop For each fusion function
        OperationFuser->>FusionFunc: func(ops, recipe=recipe)
        FusionFunc->>FusionFunc: scan ops with sliding window
        alt Pattern matches
            FusionFunc->>FusedOp: create fused op
            FusedOp-->>FusionFunc: fused operation
        else No match
            FusionFunc->>FusionFunc: shift window
        end
        FusionFunc-->>OperationFuser: updated ops list
    end

    OperationFuser->>OperationFuser: store _forward_ops and _backward_ops
    
    Note over OperationFuser,FusedOp: Forward pass
    
    loop For each forward op
        OperationFuser->>FusedOp: fuser_forward(ctxs, input, ...)
        FusedOp->>FusedOp: execute fused computation
        FusedOp->>FusedOp: save state to contexts
        FusedOp-->>OperationFuser: output
    end
    
    Note over OperationFuser,FusedOp: Backward pass
    
    loop For each backward op (reversed)
        OperationFuser->>FusedOp: fuser_backward(ctxs, grad_output, ...)
        FusedOp->>FusedOp: restore from contexts
        FusedOp->>FusedOp: execute fused backward
        FusedOp-->>OperationFuser: grad_input, grad_params
    end
    
    OperationFuser-->>User: outputs and gradients
Loading

greptile-apps[bot]

This comment was marked as outdated.

@timmoon10

This comment was marked as outdated.

greptile-apps[bot]

This comment was marked as outdated.

Signed-off-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
@timmoon10

This comment was marked as outdated.

@timmoon10 timmoon10 closed this Jan 15, 2026
@timmoon10 timmoon10 reopened this Jan 15, 2026
@timmoon10

This comment was marked as outdated.

Signed-off-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
@timmoon10

This comment was marked as outdated.

1 similar comment
@timmoon10
Copy link
Collaborator Author

/te-ci pytorch L1

@timmoon10 timmoon10 merged commit 7259276 into NVIDIA:main Jan 25, 2026
27 of 32 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

enhancement New feature or request

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant