212 changes: 212 additions & 0 deletions tests/pytorch/test_fusible_ops.py
@@ -7,6 +7,7 @@
from collections.abc import Iterable
import io
import math
import random
from typing import Optional

import pytest
@@ -1924,6 +1925,217 @@ def test_dropout(
abs(z_score) < 2.5758
), f"Number of zeros is outside 99% confidence interval ({prob=}, {prob_observed=})"

@pytest.mark.parametrize("bias", (False, True))
@pytest.mark.parametrize("dtype", _dtypes)
@pytest.mark.parametrize("quantization", _quantization_list)
@pytest.mark.parametrize("quantized_compute", (False, True))
@pytest.mark.parametrize("quantized_weight", (False, True))
@pytest.mark.parametrize("input_requires_grad", (False, True))
@pytest.mark.parametrize("weight_requires_grad", (False, True))
def test_grouped_linear(
self,
*,
group_size: int = 4,
bias: bool,
weight_shape: tuple[int, int] = (128, 128),
split_alignment: int = 128,
dtype: torch.dtype,
device: torch.device = "cuda",
quantization: Optional[str],
quantized_compute: bool,
quantized_weight: bool,
input_requires_grad: bool,
weight_requires_grad: bool,
) -> None:
"""Grouped GEMM"""

# Split sizes
split_sizes = [split_alignment * i for i in range(group_size)]
random.shuffle(split_sizes)
split_sizes = torch.tensor(split_sizes, dtype=torch.int, device="cpu")

# Make input and weight shapes consistent
out_features, in_features = weight_shape
in_shape = (split_sizes.sum().item(), in_features)
out_shape = (in_shape[0], out_features)

# Skip invalid configurations
maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype)
maybe_skip_quantization(quantization, dims=out_shape)
if quantization is None and (quantized_compute or quantized_weight):
pytest.skip("Quantization scheme is not specified")
if quantization is not None and not (quantized_compute or quantized_weight):
pytest.skip("Quantization scheme is not used")
if quantization is not None and dtype not in (torch.bfloat16, torch.float16):
pytest.skip("Quantized group GEMM is only supported with BF16/FP16")

# Random data
x_ref, x_test = make_reference_and_test_tensors(
in_shape,
quantization=quantization,
test_dtype=dtype,
test_device=device,
requires_grad=input_requires_grad,
)
dy_ref, dy_test = make_reference_and_test_tensors(
out_shape,
quantization=quantization,
test_dtype=dtype,
test_device=device,
requires_grad=False,
)
ws_ref, ws_test = [], []
bs_ref, bs_test = [], []
for _ in range(group_size):
w_ref, w_test = make_reference_and_test_tensors(
(out_features, in_features),
quantization=quantization,
test_dtype=dtype,
test_device=device,
requires_grad=weight_requires_grad,
)
b_ref, b_test = None, None
if bias:
b_ref, b_test = make_reference_and_test_tensors(
out_features,
test_dtype=dtype,
test_device=device,
requires_grad=weight_requires_grad,
)
ws_ref.append(w_ref)
ws_test.append(w_test)
bs_ref.append(b_ref)
bs_test.append(b_test)

# Plain PyTorch implementation
xs_ref = torch.split(x_ref, split_sizes.tolist())
ys_ref = []
for x, w, b in zip(xs_ref, ws_ref, bs_ref):
ys_ref.append(torch.nn.functional.linear(x, w, bias=b))
y_ref = torch.cat(ys_ref)
if input_requires_grad or weight_requires_grad:
y_ref.backward(dy_ref)

# Construct fusible operation
recipe = make_recipe(quantization)
with te.quantized_model_init(enabled=quantized_weight, recipe=recipe):
op = te_ops.GroupedLinear(
group_size,
in_features,
out_features,
bias=bias,
device=device,
dtype=dtype,
)
with torch.no_grad():
for group_idx in range(group_size):
getattr(op, f"weight{group_idx}").copy_(ws_test[group_idx])
if bias:
getattr(op, f"bias{group_idx}").copy_(bs_test[group_idx])
del ws_test, bs_test
for param in op.parameters():
param.requires_grad_(requires_grad=weight_requires_grad)

# Forward and backward pass with op
with te.autocast(enabled=quantized_compute, recipe=recipe):
y_test = op(x_test, split_sizes)
if input_requires_grad or weight_requires_grad:
y_test.backward(dy_test)

# Expected numerical error
tols = dtype_tols(dtype)
if dtype == torch.float32:
tols = dtype_tols(torch.float16) # TF32 GEMM
if quantized_compute:
tols = quantization_tols(quantization)

# Check results
y_test = y_test.to(dtype=torch.float64, device="cpu")
torch.testing.assert_close(y_test, y_ref, **tols)
if input_requires_grad:
dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
torch.testing.assert_close(dx_test, x_ref.grad, **tols)
else:
assert x_test.grad is None
for group_idx in range(group_size):
w_test = getattr(op, f"weight{group_idx}")
if weight_requires_grad:
dw_test = w_test.grad.to(dtype=torch.float64, device="cpu")
torch.testing.assert_close(dw_test, ws_ref[group_idx].grad, **tols)
else:
assert w_test.grad is None
if bias:
b_test = getattr(op, f"bias{group_idx}")
if weight_requires_grad:
db_test = b_test.grad.to(dtype=torch.float64, device="cpu")
torch.testing.assert_close(db_test, bs_ref[group_idx].grad, **tols)
else:
assert b_test.grad is None
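(For reference, a minimal usage sketch of the te_ops.GroupedLinear operation exercised above, outside the test harness; the shapes, dtype, and the te_ops alias for transformer_engine.pytorch.ops are illustrative assumptions based on this test, not a definitive API reference.)

import torch
import transformer_engine.pytorch.ops as te_ops

# Four independent 128 -> 128 linear layers evaluated as one grouped GEMM.
group_size, in_features, out_features = 4, 128, 128
op = te_ops.GroupedLinear(
    group_size,
    in_features,
    out_features,
    bias=True,
    device="cuda",
    dtype=torch.bfloat16,
)

# Rows of x are partitioned between the groups by split_sizes (an int tensor on CPU),
# mirroring the torch.split reference implementation in the test above; a zero-sized
# group is exercised by the shuffled split sizes there as well.
split_sizes = torch.tensor([128, 256, 0, 128], dtype=torch.int, device="cpu")
x = torch.randn(split_sizes.sum().item(), in_features, device="cuda", dtype=torch.bfloat16)
y = op(x, split_sizes)  # shape: (split_sizes.sum(), out_features)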

@pytest.mark.parametrize(
"input_shape,extra_input_shape",
(
((3, 4, 5), (3, 4, 5)),
((6, 7), ()),
((), (8, 9)),
((10, 11, 12), (11, 1)),
((1, 15), (13, 14, 15)),
),
)
@pytest.mark.parametrize("input_requires_grad", (False, True))
@pytest.mark.parametrize("extra_input_requires_grad", (False, True))
def test_multiply_extra_input(
self,
*,
input_shape: Iterable[int],
extra_input_shape: Iterable[int],
dtype: torch.dtype = torch.float32,
device: torch.device = "cuda",
input_requires_grad: bool,
extra_input_requires_grad: bool,
) -> None:
"""Multiply two tensors"""

# Random data
x1_ref, x1_test = make_reference_and_test_tensors(
input_shape,
test_dtype=dtype,
test_device=device,
requires_grad=input_requires_grad,
)
x2_ref, x2_test = make_reference_and_test_tensors(
extra_input_shape,
test_dtype=dtype,
test_device=device,
requires_grad=extra_input_requires_grad,
)

# Plain PyTorch implementation
y_ref = x1_ref * x2_ref
if input_requires_grad or extra_input_requires_grad:
torch.square(y_ref).sum().backward()

# Implementation with fusible operation
op = te_ops.MultiplyExtraInput()
y_test = op(x1_test, x2_test)
if input_requires_grad or extra_input_requires_grad:
torch.square(y_test).sum().backward()

# Check results
tols = dtype_tols(dtype)
y_test = y_test.to(dtype=torch.float64, device="cpu")
torch.testing.assert_close(y_test, y_ref, **tols)
if input_requires_grad:
dx1_test = x1_test.grad.to(dtype=torch.float64, device="cpu")
torch.testing.assert_close(dx1_test, x1_ref.grad, **tols)
else:
assert x1_test.grad is None
if extra_input_requires_grad:
dx2_test = x2_test.grad.to(dtype=torch.float64, device="cpu")
torch.testing.assert_close(dx2_test, x2_ref.grad, **tols)
else:
assert x2_test.grad is None
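(Similarly, a minimal usage sketch of te_ops.MultiplyExtraInput as exercised above: the extra tensor is passed as a second forward argument and broadcasts like a plain elementwise multiply; the shapes below are illustrative and taken from the parametrization of this test.)

import torch
import transformer_engine.pytorch.ops as te_ops

op = te_ops.MultiplyExtraInput()

# Broadcasting follows the x1 * x2 reference above, e.g. (10, 11, 12) * (11, 1).
x1 = torch.randn(10, 11, 12, device="cuda", requires_grad=True)
x2 = torch.randn(11, 1, device="cuda", requires_grad=True)
y = op(x1, x2)               # same values as x1 * x2
y.square().sum().backward()  # gradients flow to both inputs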


class TestFusedOps:
"""Tests for fused operations"""
2 changes: 2 additions & 0 deletions transformer_engine/pytorch/ops/basic/__init__.py
@@ -24,10 +24,12 @@
from .bias import Bias
from .constant_scale import ConstantScale
from .dropout import Dropout
from .grouped_linear import GroupedLinear
from .identity import Identity
from .l2normalization import L2Normalization
from .layer_norm import LayerNorm
from .make_extra_output import MakeExtraOutput
from .multiply_extra_input import MultiplyExtraInput
from .quantize import Quantize
from .reduce_scatter import ReduceScatter
from .reshape import Reshape