9 changes: 6 additions & 3 deletions examples/pytorch/quantized_model_init/fully_shard.py
@@ -43,6 +43,7 @@
import transformer_engine.pytorch as te
from transformer_engine.pytorch.module.base import TransformerEngineBaseModule
from transformer_engine.pytorch.quantized_tensor import QuantizedTensor
+from transformer_engine.common.recipe import MXFP8BlockScaling

# ── Configuration ────────────────────────────────────────────────────
HIDDEN_SIZE = 256
@@ -85,7 +86,9 @@ def main():
    # avoiding the precision loss of dequantizing from FP8.
    # We set DTYPE to float32 since these weights will actually be initialized as FP8,
    # but we want to seed the optimizer states (which will be in FP32) with the FP32 values.
-    with te.quantized_model_init(enabled=True, preserve_high_precision_init_val=True):
+    with te.quantized_model_init(
+        recipe=MXFP8BlockScaling(), enabled=True, preserve_high_precision_init_val=True
+    ):
        model = torch.nn.Sequential(
            *[
                te.TransformerLayer(
Expand Down Expand Up @@ -154,7 +157,7 @@ def main():
for step in range(NUM_STEPS):
optimizer.zero_grad(set_to_none=True)

with te.autocast(enabled=True):
with te.autocast(enabled=True, recipe=MXFP8BlockScaling()):
Contributor:
P2 Repeated recipe instantiation per call

`MXFP8BlockScaling()` is instantiated three separate times: once in `quantized_model_init`, once here in the training loop, and once in the post-checkpoint step. `MXFP8BlockScaling` is currently stateless, so this is functionally safe, but the conventional pattern in TE examples is to create a single recipe object at the top and reuse it, as sketched below. That avoids any risk if the recipe gains per-instance state in the future and keeps the configuration change localized to one place.

Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!
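For reference, here is a minimal sketch of the suggested pattern, assuming the same TE APIs already used in this diff (the `RECIPE` name and the `...` bodies are placeholders, not code from this example):

```python
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import MXFP8BlockScaling

# Instantiate the recipe once at module scope and reuse it everywhere.
RECIPE = MXFP8BlockScaling()


def main():
    # Same call as in this diff, but pointing at the shared recipe object.
    with te.quantized_model_init(
        recipe=RECIPE, enabled=True, preserve_high_precision_init_val=True
    ):
        ...  # model construction

    # The training loop and the post-checkpoint step reuse the same object,
    # so a recipe change only needs to be made in one place.
    with te.autocast(enabled=True, recipe=RECIPE):
        ...  # forward pass
```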

            output = model(x)

            loss = F.mse_loss(output, target)
@@ -187,7 +190,7 @@ def main():

    # Verify training continues after checkpoint load.
    optimizer.zero_grad(set_to_none=True)
-    with te.autocast(enabled=True):
+    with te.autocast(enabled=True, recipe=MXFP8BlockScaling()):
        output = model(x)
        loss = F.mse_loss(output, target)
        loss.backward()