tests/pytorch/test_fusible_ops.py (318 additions, 4 deletions)
@@ -2329,13 +2329,13 @@ def test_backward_activation_bias(
backward_ops = model._module_groups[0]._backward_ops
if with_quantization:
assert len(backward_ops) == 2
-assert isinstance(backward_ops[0][0], BackwardActivationBias)
-assert isinstance(backward_ops[1][0], te_ops.Quantize)
+assert isinstance(backward_ops[0][0], te_ops.Quantize)
+assert isinstance(backward_ops[1][0], BackwardActivationBias)
else:
assert len(backward_ops) == 3
-assert isinstance(backward_ops[0][0], act_type)
+assert isinstance(backward_ops[0][0], te_ops.Quantize)
assert isinstance(backward_ops[1][0], te_ops.Bias)
-assert isinstance(backward_ops[2][0], te_ops.Quantize)
+assert isinstance(backward_ops[2][0], act_type)

# Expected numerical error
tols = dtype_tols(dtype)
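For reference, the reordered assertions above read the planned backward ops through private attributes. Below is a minimal sketch of that inspection on a small Sequential model, assuming the same internal `_module_groups[0]._backward_ops` layout the test relies on; these names are private and may change.

```python
import torch
import transformer_engine.pytorch as te

# Sketch only: _module_groups and _backward_ops are private attributes,
# referenced here exactly as in the assertions above.
model = te.ops.Sequential(te.ops.Linear(32, 32, bias=False), te.ops.SiLU())
x = torch.randn(8, 32, device="cuda", requires_grad=True)
y = model(x)
y.backward(torch.ones_like(y))

backward_ops = model._module_groups[0]._backward_ops
print([type(entry[0]).__name__ for entry in backward_ops])
```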
@@ -2930,3 +2930,317 @@ def to_cpu(tensor: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
if bias:
torch.testing.assert_close(to_cpu(ffn1.bias.grad), b1_ref.grad, **tols)
torch.testing.assert_close(to_cpu(ffn2.bias.grad), b2_ref.grad, **tols)


class TestCustomOps:
"""Test with ops that are defined externally"""

def test_custom_basic_op(
self,
*,
shape: Iterable[int] = (7, 5),
dtype: torch.dtype = torch.float32,
device: torch.device = "cuda",
) -> None:
"""Custom basic op"""

class CustomScaleOp(te.ops.BasicOperation):
"""Custom op that applies a learnable scale"""

def __init__(self) -> None:
super().__init__()
self.scale: torch.nn.Parameter
scale = torch.ones((), dtype=dtype, device=device)
scale = torch.nn.Parameter(scale)
self.register_parameter("scale", scale)

def op_forward(
self,
ctx: OperationContext,
input_: torch.Tensor,
prev_op_grad_output_quantizer: Optional[Quantizer],
next_op_input_quantizer: Optional[Quantizer],
) -> torch.Tensor:
ctx.save_for_backward(self.scale, input_)
return self.scale * input_

def op_backward(
self,
ctx: OperationContext,
grad_output: torch.Tensor,
) -> torch.Tensor:
(
scale,
input_,
) = ctx.saved_tensors
grad_scale = torch.inner(input_.reshape(-1), grad_output.reshape(-1))
grad_scale = grad_scale.reshape(())
grad_input = scale * grad_output
return grad_input, (grad_scale,)

# Random data
x_ref, x_test = make_reference_and_test_tensors(
shape,
test_dtype=dtype,
test_device=device,
)
w_ref, w_test = make_reference_and_test_tensors(
(),
test_dtype=dtype,
test_device=device,
)
dy_ref, dy_test = make_reference_and_test_tensors(
shape,
test_dtype=dtype,
test_device=device,
requires_grad=False,
)

# Plain PyTorch implementation
y_ref = w_ref * x_ref
y_ref.backward(dy_ref)

# Implementation with fusible operation
op = CustomScaleOp()
forward = te.ops.Sequential(te.ops.Identity(), op, te.ops.Identity())
with torch.no_grad():
op.scale.copy_(w_test)
del w_test
y_test = forward(x_test)
y_test.backward(dy_test)

# Check results
tols = dtype_tols(dtype)
y_test = y_test.to(dtype=torch.float64, device="cpu")
dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
dw_test = op.scale.grad.to(dtype=torch.float64, device="cpu")
torch.testing.assert_close(y_test, y_ref, **tols)
torch.testing.assert_close(dx_test, x_ref.grad, **tols)
torch.testing.assert_close(dw_test, w_ref.grad, **tols)

def test_custom_forward_fused_op(
self,
*,
shape: Iterable[int] = (7, 11),
dtype: torch.dtype = torch.float32,
device: torch.device = "cuda",
):
"""Custom fused op in forward pass"""

class CustomForwardLinearSiLU(te.ops.FusedOperation):
"""Custom fused op for GEMM + SiLU"""

_enabled = True

def __init__(self, *, linear, silu) -> None:
super().__init__((linear, silu))

def fuser_forward(
self,
basic_op_ctxs: list[OperationContext],
input_: torch.Tensor,
**unused,
) -> torch.Tensor:
weight = self.basic_ops[0].weight
dtype = weight.dtype
device = weight.device

# Perform the compute on CPU to show the fused op can use an arbitrary implementation
x = input_.cpu()
w = weight.cpu()
y = torch.matmul(x, w.T)
z = torch.nn.functional.silu(y)
out = z.to(device=device)

# Save state for linear backward
linear_op_ctx = basic_op_ctxs[0]
linear_op_ctx.save_for_backward(input_, weight)
linear_op_ctx.with_quantized_compute = False
linear_op_ctx.input_quantizer = None
linear_op_ctx.weight_quantizer = None
linear_op_ctx.grad_output_quantizer = None
linear_op_ctx.grad_input_quantizer = None
linear_op_ctx.dtype = dtype
linear_op_ctx.input_requires_grad = True
linear_op_ctx.weight_requires_grad = True

# Save state for SiLU backward
silu_op_ctx = basic_op_ctxs[1]
silu_op_ctx.save_for_backward(y.to(device=device))
silu_op_ctx.dtype = dtype
silu_op_ctx.prev_op_grad_output_quantizer = None

return out, [(), ()]

@staticmethod
def fuse_ops(
ops: list[FusibleOperation],
**unused,
) -> list[FusibleOperation]:
"""Apply fusion the first time this function is called"""
if CustomForwardLinearSiLU._enabled:
CustomForwardLinearSiLU._enabled = False
op = CustomForwardLinearSiLU(linear=ops[0], silu=ops[1])
return [op] + ops[2:]
return ops

# Random data
x_ref, x_test = make_reference_and_test_tensors(
shape,
test_dtype=dtype,
test_device=device,
)
w_ref, w_test = make_reference_and_test_tensors(
(shape[-1], shape[-1]),
test_dtype=dtype,
test_device=device,
)
dy_ref, dy_test = make_reference_and_test_tensors(
shape,
test_dtype=dtype,
test_device=device,
requires_grad=False,
)

# Plain PyTorch implementation
y_ref = torch.nn.functional.linear(x_ref, w_ref)
y_ref = torch.nn.functional.silu(y_ref)
y_ref.backward(dy_ref)

# Implementation with fusible operation
te.ops.register_forward_fusion(CustomForwardLinearSiLU.fuse_ops)
model = te.ops.Sequential(
te.ops.Linear(shape[-1], shape[-1], bias=False),
te.ops.SiLU(),
)
with torch.no_grad():
model[0].weight.copy_(w_test)
del w_test
y_test = model(x_test)
y_test.backward(dy_test)

# Check that forward operations have been fused
forward_ops = model._module_groups[0]._forward_ops
assert len(forward_ops) == 1
assert isinstance(forward_ops[0][0], CustomForwardLinearSiLU)

# Expected numerical error
tols = dtype_tols(dtype)
if dtype == torch.float32:
tols = dtype_tols(torch.float16) # TF32 GEMM

# Check results
y_test = y_test.to(dtype=torch.float64, device="cpu")
dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
dw_test = model[0].weight.grad.to(dtype=torch.float64, device="cpu")
torch.testing.assert_close(y_test, y_ref, **tols)
torch.testing.assert_close(dx_test, x_ref.grad, **tols)
torch.testing.assert_close(dw_test, w_ref.grad, **tols)

def test_custom_backward_fused_op(
self,
*,
shape: Iterable[int] = (13, 5),
dtype: torch.dtype = torch.float32,
device: torch.device = "cuda",
):
"""Custom fused op in backward pass"""

class CustomBackwardLinearScale(te.ops.FusedOperation):
"""Custom fused op for backward linear + scale"""

_enabled: bool = True

def __init__(self, *, scale, linear) -> None:
super().__init__((scale, linear))

def fuser_backward(
self,
basic_op_ctxs: list[OperationContext],
grad_output: torch.Tensor,
**unused,
) -> torch.Tensor:

# Load state from linear forward
linear_op_ctx = basic_op_ctxs[1]
x, w = linear_op_ctx.saved_tensors
dtype = linear_op_ctx.dtype
device = w.device

# Perform compute in FP64 and apply scale before dgrad
# GEMM instead of after
scale = self.basic_ops[0].scale
dy = grad_output.double()
x = x.double()
w = w.double()
dx = torch.matmul(dy, scale * w)
dw = torch.matmul(dy.T, x)
dx = dx.to(dtype=dtype)
dw = dw.to(dtype=dtype)

return dx, [(), (dw,)], [(), ()]

@staticmethod
def fuse_ops(
ops: list[FusibleOperation],
**unused,
) -> list[FusibleOperation]:
"""Apply fusion the first time this function is called"""
if CustomBackwardLinearScale._enabled:
CustomBackwardLinearScale._enabled = False
op = CustomBackwardLinearScale(scale=ops[0], linear=ops[1])
return [op] + ops[2:]
return ops

# Random data
x_ref, x_test = make_reference_and_test_tensors(
shape,
test_dtype=dtype,
test_device=device,
)
w_ref, w_test = make_reference_and_test_tensors(
(shape[-1], shape[-1]),
test_dtype=dtype,
test_device=device,
)
dy_ref, dy_test = make_reference_and_test_tensors(
shape,
test_dtype=dtype,
test_device=device,
requires_grad=False,
)
scale = 1.234

# Plain PyTorch implementation
y_ref = torch.nn.functional.linear(scale * x_ref, w_ref)
y_ref.backward(dy_ref)

# Implementation with fusible operation
te.ops.register_backward_fusion(CustomBackwardLinearScale.fuse_ops, prepend=True)
model = te.ops.Sequential(
te.ops.ConstantScale(scale),
te.ops.Linear(shape[-1], shape[-1], bias=False),
)
with torch.no_grad():
model[1].weight.copy_(w_test)
del w_test
y_test = model(x_test)
y_test.backward(dy_test)

# Check that backward operations have been fused
backward_ops = model._module_groups[0]._backward_ops
assert len(backward_ops) == 1
assert isinstance(backward_ops[0][0], CustomBackwardLinearScale)

# Expected numerical error
tols = dtype_tols(dtype)
if dtype == torch.float32:
tols = dtype_tols(torch.float16) # TF32 GEMM

# Check results
y_test = y_test.to(dtype=torch.float64, device="cpu")
dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
dw_test = model[1].weight.grad.to(dtype=torch.float64, device="cpu")
torch.testing.assert_close(y_test, y_ref, **tols)
torch.testing.assert_close(dx_test, x_ref.grad, **tols)
torch.testing.assert_close(dw_test, w_ref.grad, **tols)
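Taken together, the new tests exercise the pattern external code can follow: subclass `te.ops.BasicOperation` and implement `op_forward`/`op_backward` for a standalone op, or subclass `te.ops.FusedOperation` and register a fusion function via `te.ops.register_forward_fusion` / `te.ops.register_backward_fusion`. Here is a condensed sketch of the basic-op pattern, following the hook signatures used in `test_custom_basic_op` above; the empty parameter-gradient tuple returned from `op_backward` is an assumption for an op with no parameters.

```python
import torch
import transformer_engine.pytorch as te

class ScaleByTwo(te.ops.BasicOperation):
    """Toy custom op: multiply the input by a fixed constant."""

    def op_forward(self, ctx, input_, prev_op_grad_output_quantizer, next_op_input_quantizer):
        # Nothing needs to be saved for backward; the factor is a constant.
        return 2 * input_

    def op_backward(self, ctx, grad_output):
        # Return (grad w.r.t. input, grads w.r.t. parameters); this op has none.
        return 2 * grad_output, ()

model = te.ops.Sequential(te.ops.Identity(), ScaleByTwo(), te.ops.Identity())
x = torch.randn(7, 5, device="cuda", requires_grad=True)
y = model(x)
y.backward(torch.ones_like(y))
```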
transformer_engine/pytorch/ops/__init__.py (6 additions, 4 deletions)
@@ -8,7 +8,9 @@
"""

-from transformer_engine.pytorch.ops.basic import *
-from transformer_engine.pytorch.ops.linear import Linear
-from transformer_engine.pytorch.ops.op import FusibleOperation
-from transformer_engine.pytorch.ops.sequential import Sequential
+from .basic import *
+from .fuser import register_backward_fusion, register_forward_fusion
+from .linear import Linear
+from .op import BasicOperation, FusedOperation, FusibleOperation
+from .sequential import Sequential
+from . import fused
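With the expanded `__init__`, the names the new tests rely on — `BasicOperation`, `FusedOperation`, `FusibleOperation`, `register_forward_fusion`, and `register_backward_fusion` — become importable directly from `transformer_engine.pytorch.ops`. A minimal sketch of wiring up a fusion hook through the new exports; `passthrough_fusion` is a hypothetical hook that simply returns the planned ops unchanged.

```python
from transformer_engine.pytorch.ops import (
    FusibleOperation,
    register_backward_fusion,
    register_forward_fusion,
)

def passthrough_fusion(ops: list[FusibleOperation], **unused) -> list[FusibleOperation]:
    # Hypothetical hook: inspect the planned ops and return them unchanged.
    return ops

# Registered hooks are consulted when Sequential plans its forward and backward
# passes; prepend=True mirrors the backward-fusion test above.
register_forward_fusion(passthrough_fusion)
register_backward_fusion(passthrough_fusion, prepend=True)
```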