443 changes: 443 additions & 0 deletions tests/pytorch/test_torch_compile.py

Large diffs are not rendered by default.

113 changes: 111 additions & 2 deletions transformer_engine/pytorch/ops/basic/basic_linear.py
@@ -8,10 +8,13 @@
from collections.abc import Callable, Iterable
import contextlib
import math
from typing import Any, Optional
from typing import Any, Optional, TYPE_CHECKING

import torch

if TYPE_CHECKING:
from ..compile_compat.tensor_info import TensorInfo, PseudoForwardResult

from ...cpp_extensions import general_gemm
from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload
from ...distributed import (
@@ -1057,7 +1060,7 @@ def op_backward(
)

# Clear input tensor if possible
clear_tensor_data(x_local)
# clear_tensor_data(x_local)

# Megatron-LM wgrad fusion
# Note: Return dummy tensor for grad weight if needed.
@@ -1073,3 +1076,109 @@
)

return grad_input, [grad_weight]

def pseudo_forward(
self,
input_info: "TensorInfo",
extra_inputs_info: tuple["TensorInfo", ...],
**kwargs,
) -> "PseudoForwardResult":
"""Compute forward metadata for BasicLinear.

Output shape: input shape with last dim changed from in_features to out_features.
Saves: x_local (input), w (weight) for backward.
"""
from ..compile_compat.tensor_info import TensorInfo, PseudoForwardResult

# Output shape: replace last dim with out_features
output_shape = input_info.shape[:-1] + (self.local_out_features,)
output_info = TensorInfo(
shape=output_shape,
dtype=input_info.dtype,
requires_grad=input_info.requires_grad,
)

# Tensors to save for backward: x_local and w
# x_local has same shape as input (possibly after sequence parallel gather)
# w has shape (out_features, in_features)
# NOTE: We always return the same structure regardless of requires_grad
# because torch.compile requires consistent output shapes from custom ops.
# At runtime, if requires_grad=False, we may save empty/dummy tensors.
tensors_to_save_info = []
tensor_sources = [] # -1 = new tensor, 0 = x, 1+ = params/extra_inputs

# Always add tensors to save (structure must be consistent for torch.compile)
if True: # was: if input_info.requires_grad:
# x_local shape depends on tensor parallel mode
if self.tensor_parallel_mode == "column" and self.sequence_parallel:
# After all-gather, first dim is multiplied by TP size
x_local_shape = (
input_info.shape[0] * self.tensor_parallel_size,
) + input_info.shape[1:]
else:
x_local_shape = input_info.shape

tensors_to_save_info.append(
TensorInfo(
shape=x_local_shape,
dtype=input_info.dtype,
requires_grad=False, # Saved tensors don't track grad
)
)
# x_local aliases the input only for the FIRST op in the pipeline.
# For subsequent ops, x_local is an intermediate output (new tensor).
# The fuser passes _is_first_op in kwargs to indicate this.
is_first_op = kwargs.get("_is_first_op", True)
if is_first_op:
# x_local may or may not alias input depending on quantization.
# In non-FP8 mode, x_local = input (alias). In FP8 mode, x_local is quantized (new).
# Conservatively mark as aliasing input (source=0) to avoid saving twice.
tensor_sources.append(0)
else:
# Intermediate op - x_local is previous op's output (new tensor, not alias)
tensor_sources.append(-1)

# Weight shape - this IS the weight parameter (params[0] for this op)
tensors_to_save_info.append(
TensorInfo(
shape=(self.local_out_features, self.local_in_features),
dtype=self.weight.dtype,
requires_grad=False,
)
)
# w is params[0], so source = 1 (offset by 1 because 0 = input x)
tensor_sources.append(1)

# Context data for backward reconstruction
# Note: dtype and quantizers are determined at backward time from op state
from ...quantization import FP8GlobalStateManager

# Get dtype (same logic as op_forward)
if torch.is_autocast_enabled():
dtype = torch.get_autocast_dtype("cuda")
else:
dtype = self.weight.dtype

ctx_data = {
"input_requires_grad": input_info.requires_grad,
"weight_requires_grad": (
self.weight.requires_grad if input_info.requires_grad else False
),
"num_saved_tensors": len(tensors_to_save_info),
"dtype": dtype,
"with_quantized_compute": FP8GlobalStateManager.is_fp8_enabled(),
# Note: Quantizers are NOT stored here (they're not picklable for torch.compile).
# They will be set in run_backward from the operation directly.
"input_quantizer": None,
"weight_quantizer": None,
"grad_output_quantizer": None,
"grad_input_quantizer": None,
}

return PseudoForwardResult(
output_info=output_info,
tensors_to_save_info=tensors_to_save_info,
extra_outputs_info=[],
ctx_data=ctx_data,
tensor_sources=tensor_sources,
)
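
The pseudo_forward methods added in this PR build TensorInfo and PseudoForwardResult objects from compile_compat/tensor_info.py, which is not shown in this diff. Judging only from the fields accessed above, those containers could plausibly be lightweight dataclasses along the following lines; this is a sketch for orientation, not the file's actual contents.

# Hypothetical sketch of the metadata containers used by pseudo_forward, inferred
# from the fields accessed in this diff; the real tensor_info.py may differ.
from dataclasses import dataclass, field
from typing import Any

import torch


@dataclass
class TensorInfo:
    """Shape/dtype/grad metadata describing a tensor without materializing it."""

    shape: tuple[int, ...]
    dtype: torch.dtype
    requires_grad: bool = False


@dataclass
class PseudoForwardResult:
    """Metadata an op returns from pseudo_forward, consumed by the fuser."""

    output_info: TensorInfo
    tensors_to_save_info: list[TensorInfo]
    extra_outputs_info: list[TensorInfo]
    ctx_data: dict[str, Any]
    tensor_sources: list[int] = field(default_factory=list)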
38 changes: 37 additions & 1 deletion transformer_engine/pytorch/ops/basic/bias.py
@@ -5,10 +5,13 @@
"""Fusible operation for bias."""

from __future__ import annotations
from typing import Optional
from typing import Optional, TYPE_CHECKING

import torch

if TYPE_CHECKING:
from ..compile_compat.tensor_info import TensorInfo, PseudoForwardResult
Comment on lines +12 to +13

Collaborator:
Why not import directly? compile_compat doesn't depend on anything in te.ops.basic, so if we import in the correct order then we shouldn't have circular dependencies.

Collaborator (Author):
There is a lot of poor-quality AI-generated code here, so it's not worth reviewing this PR in detail. I think we need to have some agreement on the high-level design and I will reimplement it from scratch. Maybe I will elaborate more tomorrow on why this PR works the way it does.

import transformer_engine_torch as tex
from ..op import BasicOperation, OperationContext
from ...utils import canonicalize_device, canonicalize_dtype
@@ -142,3 +145,36 @@ def op_backward(
else:
db = dy
return dy, (db,)

def pseudo_forward(
self,
input_info: "TensorInfo",
extra_inputs_info: tuple["TensorInfo", ...],
**kwargs,
) -> "PseudoForwardResult":
"""Compute forward metadata for Bias.

Bias is element-wise addition - output shape equals input shape.
No tensors saved for backward.
"""
from ..compile_compat.tensor_info import TensorInfo, PseudoForwardResult

# Output has same shape as input
output_info = TensorInfo(
shape=input_info.shape,
dtype=input_info.dtype,
requires_grad=input_info.requires_grad,
)

# Note: grad_input_quantizer is set by the fuser based on prev_op
# For now, set to None (will be overwritten by run_backward if needed)
return PseudoForwardResult(
output_info=output_info,
tensors_to_save_info=[],
extra_outputs_info=[],
ctx_data={
"num_saved_tensors": 0,
"grad_input_quantizer": None,
},
tensor_sources=[],
)
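
As a side note on the review thread above about the TYPE_CHECKING guard: a minimal illustration of the two import styles being discussed, assuming (as the reviewer states) that compile_compat imports nothing from te.ops.basic.

# Pattern used in this PR: the names are only needed for type annotations, so they
# are imported under TYPE_CHECKING and referenced as string annotations at runtime.
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from transformer_engine.pytorch.ops.compile_compat.tensor_info import (
        PseudoForwardResult,
        TensorInfo,
    )

# Reviewer's alternative: a plain top-level import, which is safe as long as
# compile_compat never imports back into te.ops.basic (no circular dependency):
# from transformer_engine.pytorch.ops.compile_compat.tensor_info import (
#     PseudoForwardResult,
#     TensorInfo,
# )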
5 changes: 4 additions & 1 deletion transformer_engine/pytorch/ops/basic/identity.py
@@ -5,10 +5,13 @@
"""Fusible operation for identity."""

from __future__ import annotations
from typing import Optional
from typing import Optional, TYPE_CHECKING

import torch

if TYPE_CHECKING:
from ..compile_compat.tensor_info import TensorInfo, PseudoForwardResult

from transformer_engine.pytorch.ops.op import (
BasicOperation,
OperationContext,
43 changes: 43 additions & 0 deletions transformer_engine/pytorch/ops/compile_compat/__init__.py
@@ -0,0 +1,43 @@
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""torch.compile compatibility module for Transformer Engine operations.

This module provides components to make te.ops work with
torch.compile(fullgraph=True) by using custom operators that wrap fusion logic.

Usage:
from transformer_engine.pytorch.ops.compile_compat import TorchCompileCompatibleFuser

# Create fuser OUTSIDE compiled region
fuser = TorchCompileCompatibleFuser([op1, op2, op3])

@torch.compile(fullgraph=True)
def forward(x):
return fuser(x)
"""

# Import and re-export public API
# Note: NoneRecipe is used as sentinel when recipe is None (FP8 disabled)
from .tensor_info import TensorInfo, TensorInfoList, PseudoForwardResult
from .opaque_kwargs import OpaqueKwargs
from .ops_container import OpsContainer
from .operators import fused_forward_impl, fused_backward_impl, NoneRecipe, NONE_RECIPE
from .fuser import TorchCompileCompatibleFuser

__all__ = [
# Main API
"TorchCompileCompatibleFuser",
# Supporting classes
"TensorInfo",
"TensorInfoList",
"PseudoForwardResult",
"OpaqueKwargs",
"OpsContainer",
"NoneRecipe",
"NONE_RECIPE",
# Custom operators (for advanced usage)
"fused_forward_impl",
"fused_backward_impl",
]
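
To make the usage sketched in the module docstring concrete, here is a hedged end-to-end example; the BasicLinear and Bias constructor calls are illustrative and assume the existing te.ops API, which this PR does not change.

# Illustrative usage only; op constructor signatures are assumed from the
# existing te.ops API and may differ in detail.
import torch
from transformer_engine.pytorch import ops as te_ops
from transformer_engine.pytorch.ops.compile_compat import TorchCompileCompatibleFuser

# Build the fuser outside the compiled region (OpsContainer is an opaque object).
fuser = TorchCompileCompatibleFuser([
    te_ops.BasicLinear(1024, 1024),
    te_ops.Bias(1024),
])

@torch.compile(fullgraph=True)
def forward(x):
    return fuser(x)

x = torch.randn(16, 1024, device="cuda", requires_grad=True)
y = forward(x)
y.sum().backward()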
137 changes: 137 additions & 0 deletions transformer_engine/pytorch/ops/compile_compat/fuser.py
@@ -0,0 +1,137 @@
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""torch.compile compatible fuser for Transformer Engine operations."""

from __future__ import annotations
from typing import Any, Optional

import torch

from ...quantization import FP8GlobalStateManager
from ..op import BasicOperation, FusibleOperation

from .opaque_kwargs import opaque_kwargs_from_dicts
from .ops_container import OpsContainer
from .operators import fused_forward_impl, NONE_RECIPE


class TorchCompileCompatibleFuser:
"""Fuser for torch.compile(fullgraph=True) compatibility.

This class wraps a sequence of FusibleOperations and provides a callable
that works with torch.compile without graph breaks. The fusion logic
is hidden inside custom operators.

Usage:
ops = [LinearOp(...), BiasOp(...), ActivationOp(...)]
fuser = TorchCompileCompatibleFuser(ops)

@torch.compile(fullgraph=True)
def forward(x):
return fuser(x)

Note: The fuser must be created OUTSIDE the compiled region, as OpsContainer
is a reference-type opaque object.
"""

def __init__(self, ops: list[FusibleOperation]) -> None:
"""Initialize the fuser with a list of operations.

Args:
ops: List of FusibleOperation instances (can include FusedOperations)
"""
# Flatten to basic operations
basic_ops: list[BasicOperation] = []
for op in ops:
if op.is_fused_op:
basic_ops.extend(op.basic_ops)
else:
basic_ops.append(op)

# Create OpsContainer (outside compiled region)
self.ops_container = OpsContainer(basic_ops)

# Cache num_ops and default kwargs (avoid accessing these in compiled region)
self._num_ops = len(basic_ops)
self._default_kwargs_opaque = opaque_kwargs_from_dicts([{}] * len(basic_ops))

# Flatten parameters for autograd tracking
self._flat_params = [p for op in basic_ops for p in op.parameters()]

# Track extra inputs/outputs
self.num_extra_inputs = sum(op.num_extra_inputs for op in basic_ops)
self.num_extra_outputs = sum(op.num_extra_outputs for op in basic_ops)

# Keep reference to basic ops for module compatibility
self._basic_ops = basic_ops

def __call__(
self,
input: torch.Tensor,
*extra_inputs: torch.Tensor,
basic_op_kwargs: Optional[list[dict[str, Any]]] = None,
) -> torch.Tensor | tuple[torch.Tensor, ...]:
"""Apply the fused operations to input.

Args:
input: Input tensor
*extra_inputs: Extra tensor inputs for operations that need them
basic_op_kwargs: Optional list of kwargs dicts, one per basic operation

Returns:
Output tensor, or tuple of (output, *extra_outputs) if any operation
produces extra outputs.
"""
# Get recipe from global state
# Use NONE_RECIPE singleton when FP8 is disabled (cannot pass None to custom_op)
if FP8GlobalStateManager.is_fp8_enabled():
recipe = FP8GlobalStateManager.get_fp8_recipe()
else:
recipe = NONE_RECIPE

# Create OpaqueKwargs
# Use cached default kwargs to avoid accessing self._num_ops in compiled region
if basic_op_kwargs is None:
kwargs_opaque = self._default_kwargs_opaque
else:
kwargs_opaque = opaque_kwargs_from_dicts(basic_op_kwargs)

# Verify extra inputs count
if len(extra_inputs) != self.num_extra_inputs:
raise ValueError(
f"Expected {self.num_extra_inputs} extra inputs, got {len(extra_inputs)}"
)

# Call the custom op - returns [output, *non_aliased_tensors_to_save, *extra_outputs]
# Aliased tensors are NOT included (reconstructed in backward)
flat_result = fused_forward_impl(
input,
self.ops_container,
recipe,
kwargs_opaque,
self._flat_params,
list(extra_inputs),
)

# Parse flat result
output = flat_result[0]
# non_aliased_tensors_to_save are in the middle (handled by autograd), we skip them
# extra_outputs are at the end
num_extra_outputs = self.num_extra_outputs
if num_extra_outputs > 0:
extra_outputs = flat_result[-num_extra_outputs:]
return (output, *extra_outputs)
return output

def parameters(self):
"""Iterate over all parameters in the fused operations."""
return iter(self._flat_params)

def named_parameters(self, prefix: str = "", recurse: bool = True):
"""Iterate over named parameters."""
for idx, op in enumerate(self._basic_ops):
op_prefix = f"{prefix}op_{idx}." if prefix else f"op_{idx}."
for name, param in op.named_parameters(prefix="", recurse=recurse):
yield op_prefix + name, param
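
Since TorchCompileCompatibleFuser is not an nn.Module, its parameters() and named_parameters() helpers are the only hook for optimizers and manual checkpointing. A short hedged example of how they might be wired up, reusing the fuser from the usage sketch earlier:

# Hedged example; names follow the op_{idx}. prefix scheme produced by
# named_parameters above.
import torch

optimizer = torch.optim.AdamW(fuser.parameters(), lr=1e-4)

# Snapshot parameters by name, e.g. for a manual checkpoint.
state = {name: p.detach().clone() for name, p in fuser.named_parameters()}
print(sorted(state))  # e.g. ['op_0.weight', 'op_1.bias']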