13 changes: 8 additions & 5 deletions tests/pytorch/distributed/run_numerics_exact.py
@@ -60,31 +60,34 @@ def get_nvfp4_quantizer_factory():
"""

def factory(role):
if role == "linear_input":
if ":" not in role:
raise ValueError(f"Invalid role: {role}, expected format: '<bucket>:<scope>'")
bucket, _ = role.split(":", 1)
if bucket == "input":
return quantization_nvfp4.NVFP4QuantizerRef(
dtype=utils.Fp4Formats.E2M1,
quant_tile_shape=(1, 16),
pow_2_scales=False,
with_rht=True, # RHT enabled for input
)
elif role == "linear_weight":
elif bucket == "weight":
return quantization_nvfp4.NVFP4QuantizerRef(
dtype=utils.Fp4Formats.E2M1,
quant_tile_shape=(16, 16), # 2D quantization for weight
pow_2_scales=False,
with_rht=False,
)
elif role == "linear_output":
elif bucket == "output":
# Output quantization not used
return None
elif role == "linear_grad_output":
elif bucket == "grad_output":
return quantization_nvfp4.NVFP4QuantizerRef(
dtype=utils.Fp4Formats.E2M1,
quant_tile_shape=(1, 16),
pow_2_scales=False,
with_rht=True, # RHT enabled for grad_output
)
elif role == "linear_grad_input":
elif bucket == "grad_input":
# Grad input quantization not used
return None
else:
13 changes: 8 additions & 5 deletions tests/pytorch/nvfp4/test_nvfp4_module_exact.py
@@ -80,31 +80,34 @@ def get_nvfp4_quantizer_factory(with_rht: bool = False, with_2d_quantization: bo
"""

def factory(role):
if role == "linear_input":
if ":" not in role:
raise ValueError(f"Invalid role: {role}, expected format: '<bucket>:<scope>'")
bucket, _ = role.split(":", 1)
if bucket == "input":
return quantization_nvfp4.NVFP4QuantizerRef(
dtype=utils.Fp4Formats.E2M1,
quant_tile_shape=(1, 16),
pow_2_scales=False,
with_rht=with_rht,
)
elif role == "linear_weight":
elif bucket == "weight":
return quantization_nvfp4.NVFP4QuantizerRef(
dtype=utils.Fp4Formats.E2M1,
quant_tile_shape=(16, 16) if with_2d_quantization else (1, 16),
pow_2_scales=False,
with_rht=False,
)
elif role == "linear_output":
elif bucket == "output":
# Output quantization not used
return None
elif role == "linear_grad_output":
elif bucket == "grad_output":
return quantization_nvfp4.NVFP4QuantizerRef(
dtype=utils.Fp4Formats.E2M1,
quant_tile_shape=(1, 16),
pow_2_scales=False,
with_rht=with_rht,
)
elif role == "linear_grad_input":
elif bucket == "grad_input":
# Grad input quantization not used
return None
else:
55 changes: 35 additions & 20 deletions tests/pytorch/test_custom_recipe.py
@@ -90,9 +90,12 @@ def test_custom_recipe_sanity(module_type):

# Single factory: map roles to quantizers
def quantizer_factory(role):
if role in ("linear_input", "linear_weight", "linear_output"):
if ":" not in role:
raise ValueError(f"Invalid role: {role}, expected format: '<bucket>:<scope>'")
bucket, _ = role.split(":", 1)
if bucket in ("input", "weight", "output"):
return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device="cuda")
if role in ("linear_grad_output", "linear_grad_input"):
if bucket in ("grad_output", "grad_input"):
return Float8CurrentScalingQuantizer(tex.DType.kFloat8E5M2, device="cuda")
return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device="cuda")

@@ -127,9 +130,12 @@ def test_custom_recipe_grouped_linear_sanity():
inp = torch.randn(batch, in_features, device="cuda", dtype=torch.bfloat16, requires_grad=True)

def quantizer_factory(role):
if role in ("linear_input", "linear_weight", "linear_output"):
if ":" not in role:
raise ValueError(f"Invalid role: {role}, expected format: '<bucket>:<scope>'")
bucket, _ = role.split(":", 1)
if bucket in ("input", "weight", "output"):
return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device="cuda")
if role in ("linear_grad_output", "linear_grad_input"):
if bucket in ("grad_output", "grad_input"):
return Float8CurrentScalingQuantizer(tex.DType.kFloat8E5M2, device="cuda")
return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device="cuda")

@@ -189,9 +195,12 @@ def test_custom_recipe_matches_current_scaling():

# Custom: single factory returning quantizers per role to match Float8CurrentScaling
def quantizer_factory(role):
if role in ("linear_input", "linear_weight", "linear_output"):
if ":" not in role:
raise ValueError(f"Invalid role: {role}, expected format: '<bucket>:<scope>'")
bucket, _ = role.split(":", 1)
if bucket in ("input", "weight", "output"):
return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device="cuda")
if role in ("linear_grad_output", "linear_grad_input"):
if bucket in ("grad_output", "grad_input"):
return Float8CurrentScalingQuantizer(tex.DType.kFloat8E5M2, device="cuda")
return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device="cuda")

@@ -246,9 +255,12 @@ def test_custom_recipe_ops_linear_2_1_layout():
inp = torch.randn(batch, in_features, device="cuda", dtype=torch.bfloat16, requires_grad=True)

def quantizer_factory(role):
if role in ("linear_input", "linear_weight", "linear_output"):
if ":" not in role:
raise ValueError(f"Invalid role: {role}, expected format: '<bucket>:<scope>'")
bucket, _ = role.split(":", 1)
if bucket in ("input", "weight", "output"):
return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device="cuda")
if role in ("linear_grad_output", "linear_grad_input"):
if bucket in ("grad_output", "grad_input"):
return Float8CurrentScalingQuantizer(tex.DType.kFloat8E5M2, device="cuda")
return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device="cuda")

@@ -278,19 +290,22 @@ def test_custom_recipe_factory_invocation_counts_and_cycling():

# Counters per role
counts = {
"linear_input": 0,
"linear_weight": 0,
"linear_output": 0,
"linear_grad_output": 0,
"linear_grad_input": 0,
"input:linear": 0,
"weight:linear": 0,
"output:linear": 0,
"grad_output:linear": 0,
"grad_input:linear": 0,
}

def quantizer_factory(role):
if role in counts:
counts[role] += 1
if role in ("linear_input", "linear_weight", "linear_output"):
if ":" not in role:
raise ValueError(f"Invalid role: {role}, expected format: '<bucket>:<scope>'")
bucket, _ = role.split(":", 1)
if bucket in ("input", "weight", "output"):
return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device=torch.device("cuda"))
if role in ("linear_grad_output", "linear_grad_input"):
if bucket in ("grad_output", "grad_input"):
return Float8CurrentScalingQuantizer(tex.DType.kFloat8E5M2, device=torch.device("cuda"))
return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device=torch.device("cuda"))

@@ -304,11 +319,11 @@ def quantizer_factory(role):
loss.backward()

# Single GEMM: forward should request input, weight, output; backward grad_output, grad_input
assert counts["linear_input"] == 1
assert counts["linear_weight"] == 1
assert counts["linear_output"] == 1
assert counts["linear_grad_output"] == 1
assert counts["linear_grad_input"] == 1
assert counts["input:linear"] == 1
assert counts["weight:linear"] == 1
assert counts["output:linear"] == 1
assert counts["grad_output:linear"] == 1
assert counts["grad_input:linear"] == 1


def test_factories_return_distinct_instances_and_buffers():
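Note: the factories in these tests key only on the bucket, but the new "<bucket>:<scope>" format also lets a single factory vary behaviour per module type. A minimal, self-contained sketch (not part of this PR; quantizer construction is stubbed with strings so no TE imports are needed):

def scope_aware_factory(role: str):
    # Illustrative only: dispatch on both bucket and scope.
    if ":" not in role:
        raise ValueError(f"Invalid role: {role}, expected format: '<bucket>:<scope>'")
    bucket, scope = role.split(":", 1)
    if bucket == "weight" and scope == "grouped_linear":
        return "per-expert weight quantizer"  # stand-in for a cheaper per-tensor scheme
    if bucket == "weight":
        return "dense weight quantizer"  # stand-in for e.g. 2D-tiled quantization
    if bucket in ("input", "output", "grad_output", "grad_input"):
        return "activation/gradient quantizer"  # stand-in
    return None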
4 changes: 2 additions & 2 deletions transformer_engine/common/recipe/__init__.py
@@ -503,8 +503,8 @@ class CustomRecipe(Recipe):
Where `role` is one of the following strings, e.g. for te.Linear
(stable public contract):

- forward: "linear_input", "linear_weight", "linear_output"
- backward: "linear_grad_output", "linear_grad_input"
- forward: "input:linear", "weight:linear", "output:linear"
- backward: "grad_output:linear", "grad_input:linear"
"""

qfactory: Callable[..., Any]
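A minimal sketch of wiring a role-based factory into CustomRecipe under the contract documented above. The quantizer choices are left as placeholders so nothing beyond the qfactory field shown above is assumed; the tests in this PR show real quantizer construction. The import path is assumed from the file edited here (transformer_engine/common/recipe/__init__.py).

from transformer_engine.common.recipe import CustomRecipe

def my_qfactory(role: str):
    bucket, scope = role.split(":", 1)  # e.g. ("grad_output", "linear")
    if bucket in ("input", "weight", "output"):
        return ...  # forward quantizer of your choice (see the tests in this PR)
    if bucket in ("grad_output", "grad_input"):
        return ...  # backward quantizer of your choice
    return None

recipe = CustomRecipe(qfactory=my_qfactory)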
@@ -23,9 +23,13 @@ def current_scaling_ref_quantizer_factory(role):
with autocast(recipe=custom_recipe):
output = model(input)
"""
if role in ("linear_input", "linear_weight"):
if ":" not in role:
raise ValueError(f"Invalid role: {role}, expected format: '<bucket>:<scope>'")
bucket, _ = role.split(":", 1)

if bucket in ("input", "weight"):
dtype = torch.float8_e4m3fn
elif role in ("linear_output", "linear_grad_output"):
elif bucket in ("output", "grad_output"):
dtype = torch.float8_e5m2
else:
return None
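For quick reference, the dtype selection implemented above, keyed by bucket (the factory returns a reference quantizer built with this dtype rather than the dtype itself, and any bucket not listed falls through to return None):

import torch

dtype_by_bucket = {
    "input": torch.float8_e4m3fn,
    "weight": torch.float8_e4m3fn,
    "output": torch.float8_e5m2,
    "grad_output": torch.float8_e5m2,
    # e.g. "grad_input" is absent: the factory returns None for it
}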
10 changes: 7 additions & 3 deletions transformer_engine/pytorch/custom_recipes/quantization_nvfp4.py
@@ -23,21 +23,25 @@ def nvfp4_ref_rht_2d_quantizer_factory(role):
with autocast(fp8_recipe=custom_recipe):
output = model(input)
"""
if role == "linear_input":
if ":" not in role:
raise ValueError(f"Invalid role: {role}, expected format: '<bucket>:<scope>'")
bucket, _ = role.split(":", 1)

if bucket == "input":
return NVFP4QuantizerRef(
dtype=utils.Fp4Formats.E2M1,
quant_tile_shape=(1, 16),
pow_2_scales=False,
with_rht=True,
)
if role == "linear_weight":
if bucket == "weight":
return NVFP4QuantizerRef(
dtype=utils.Fp4Formats.E2M1,
quant_tile_shape=(16, 16),
pow_2_scales=False,
with_rht=False,
)
if role == "linear_grad_output":
if bucket == "grad_output":
return NVFP4QuantizerRef(
dtype=utils.Fp4Formats.E2M1,
quant_tile_shape=(1, 16),
19 changes: 19 additions & 0 deletions transformer_engine/pytorch/module/base.py
@@ -732,15 +732,34 @@ def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None:
num_fp8_tensors = self.fp8_meta["num_gemms"] * 3 if fwd else self.fp8_meta["num_gemms"] * 2

# Initialize recipe state and quantizers
roles = self.get_quantizer_roles(fwd=fwd, num_quantizers=num_fp8_tensors)
if roles is not None:
assert (
len(roles) == num_fp8_tensors
), f"Recipe roles must match number of quantizers ({len(roles)=} vs {num_fp8_tensors=})"
recipe_state = RecipeState.create(
recipe,
mode=("forward" if fwd else "backward"),
num_quantizers=num_fp8_tensors,
roles=roles,
)

self.fp8_meta[fp8_meta_tensor_key] = recipe_state
self.quantizers[fp8_meta_tensor_key] = recipe_state.make_quantizers()

def get_quantizer_roles(
self,
*,
fwd: bool,
num_quantizers: int,
) -> Optional[List[str]]:
"""Return an ordered list of role strings for quantizers.

The returned list must have length `num_quantizers`.
Returning `None` means "no explicit roles".
"""
return None

def _update_weight_quantizers(self) -> None:
"""Update the quantizers for the weight tensors."""
weight_tensors = self._get_weight_tensors()
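To make the length contract of the new get_quantizer_roles hook concrete: for a single-GEMM module in the forward pass, num_fp8_tensors is 3, so an override must return exactly three role strings. A small sketch with illustrative values, not code from this PR:

num_gemms = 1
fwd = True
num_fp8_tensors = num_gemms * 3 if fwd else num_gemms * 2  # -> 3 here

# e.g. te.Linear's override (added later in this PR) returns:
roles = ["input:linear", "weight:linear", "output:linear"]
assert len(roles) == num_fp8_tensors  # mirrors the assert in set_meta_tensor above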
16 changes: 16 additions & 0 deletions transformer_engine/pytorch/module/grouped_linear.py
@@ -724,6 +724,22 @@ def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None:
if recipe.float8_current_scaling():
self._customize_quantizers_float8_current_scaling(fwd, recipe)

def get_quantizer_roles(
self,
*,
fwd: bool,
num_quantizers: int,
) -> Optional[List[str]]:
"""Role strings for quantizers used by `GroupedLinear`.

For grouped GEMMs we repeat the same pattern for each GEMM in order.
"""
if fwd:
base = ("input:grouped_linear", "weight:grouped_linear", "output:grouped_linear")
else:
base = ("grad_output:grouped_linear", "grad_input:grouped_linear")
return [base[i % len(base)] for i in range(num_quantizers)]

def reset_parameters(self, defer_init=False):
super().reset_parameters(defer_init=defer_init)

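To make the cycling concrete: with num_gemms=2 in the forward pass, num_quantizers is 6 and the base pattern repeats once per GEMM (an expansion of the list comprehension in get_quantizer_roles above):

base = ("input:grouped_linear", "weight:grouped_linear", "output:grouped_linear")
roles = [base[i % len(base)] for i in range(6)]
# ['input:grouped_linear', 'weight:grouped_linear', 'output:grouped_linear',
#  'input:grouped_linear', 'weight:grouped_linear', 'output:grouped_linear']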
20 changes: 20 additions & 0 deletions transformer_engine/pytorch/module/layernorm_linear.py
@@ -1412,6 +1412,26 @@ def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None:
elif recipe.nvfp4():
self._customize_quantizers_nvfp4(fwd, recipe)

def get_quantizer_roles(
self,
*,
fwd: bool,
num_quantizers: int,
) -> Optional[List[str]]:
"""Role strings for quantizers used by `LayerNormLinear`."""
if fwd:
base = (
"input:layernorm_linear",
"weight:layernorm_linear",
"output:layernorm_linear",
)
else:
base = (
"grad_output:layernorm_linear",
"grad_input:layernorm_linear",
)
return [base[i % len(base)] for i in range(num_quantizers)]

def reset_layer_norm_parameters(self) -> None:
"""Init LN params"""
warnings.warn(
13 changes: 13 additions & 0 deletions transformer_engine/pytorch/module/layernorm_mlp.py
@@ -1968,6 +1968,19 @@ def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None:
elif recipe.nvfp4():
self._customize_quantizers_nvfp4(fwd, recipe)

def get_quantizer_roles(
self,
*,
fwd: bool,
num_quantizers: int,
) -> Optional[List[str]]:
"""Role strings for quantizers used by `LayerNormMLP`."""
if fwd:
base = ("input:layernorm_mlp", "weight:layernorm_mlp", "output:layernorm_mlp")
else:
base = ("grad_output:layernorm_mlp", "grad_input:layernorm_mlp")
return [base[i % len(base)] for i in range(num_quantizers)]

def reset_layer_norm_parameters(self) -> None:
"""Init LN params"""
warnings.warn(
13 changes: 13 additions & 0 deletions transformer_engine/pytorch/module/linear.py
@@ -1308,6 +1308,19 @@ def __init__(
if name in self.weight_names or name in self.bias_names:
param.skip_backward_post_hook = True

def get_quantizer_roles(
self,
*,
fwd: bool,
num_quantizers: int,
) -> Optional[List[str]]:
"""Role strings for quantizers used by `Linear`."""
if fwd:
base = ("input:linear", "weight:linear", "output:linear")
else:
base = ("grad_output:linear", "grad_input:linear")
return [base[i % len(base)] for i in range(num_quantizers)]

def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None:
"""Init scales and amaxes for fwd | bwd."""
super().set_meta_tensor(fwd, recipe)
11 changes: 11 additions & 0 deletions transformer_engine/pytorch/ops/basic/basic_linear.py
@@ -270,6 +270,17 @@ def num_quantizers(self, mode: str) -> int:
return 1
return 0

def get_quantizer_roles(self, mode: str) -> Optional[list[str]]:
if mode == "forward":
# BasicLinear owns input and weight quantizers.
# Output quantizer is provided by the next op (as its input quantizer).
return ["input:linear", "weight:linear"]
if mode == "backward":
# BasicLinear owns grad_output quantizer.
# Grad_input quantizer is provided by the previous op (as its grad_output quantizer).
return ["grad_output:linear"]
return None

def reset_parameters(self) -> None:
"""Initialize parameter buffers and values"""

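One consequence of the ownership comments in get_quantizer_roles above: in the fusible-ops path a factory never sees the "output" or "grad_input" buckets from BasicLinear, because the tensor at an op boundary is quantized by the next op's "input" quantizer (forward) or the previous op's "grad_output" quantizer (backward). A plain-Python illustration, not the real ops API:

# Two chained GEMM ops: each contributes only the quantizers it owns.
per_op_roles = {
    "forward": [["input:linear", "weight:linear"], ["input:linear", "weight:linear"]],
    "backward": [["grad_output:linear"], ["grad_output:linear"]],
}
all_forward_roles = [r for op_roles in per_op_roles["forward"] for r in op_roles]
# ['input:linear', 'weight:linear', 'input:linear', 'weight:linear'] -- no "output:linear"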