Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 66 additions & 7 deletions transformer_engine/common/recipe/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,14 @@ class MMParams:

use_split_accumulator: bool = True

def __repr__(self) -> str:
    """Return the canonical string form, memoized on first use.

    The string is stashed in the instance ``__dict__`` via
    ``object.__setattr__`` so the frozen-dataclass guard is not
    tripped; subsequent calls return the stored string directly.
    """
    memo = self.__dict__.get("_cached_repr")
    if memo is None:
        memo = f"MMParams(use_split_accumulator={self.use_split_accumulator})"
        object.__setattr__(self, "_cached_repr", memo)
    return memo
Comment on lines +63 to +69
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 _cached_repr stored outside declared dataclass fields

MMParams is @dataclass(frozen=True). Storing _cached_repr via object.__setattr__ bypasses the frozen guard correctly in CPython, but _cached_repr is not a declared dataclass field — it won't appear in dataclasses.fields(), dataclasses.asdict(), dataclasses.astuple(), or copy.replace(). If downstream code serializes or copies an MMParams instance, the cached repr would be lost silently. Documenting this with a comment or declaring it as field(init=False, repr=False, compare=False) would make the intent clearer. The same applies to QParams.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see that this is why we're doing the funny accesses with __dict__. I agree that bypassing frozen=True is iffy, so I wonder if we could set _cached_repr in __post_init__? If the class is frozen, its repr must also be frozen and I don't see a benefit in lazy evaluation.



@dataclass(frozen=True)
class QParams:
Expand All @@ -77,20 +85,41 @@ class QParams:
fp4_2d_quantization: bool = False

def __repr__(self) -> str:
    """Return the string form of this QParams, cached after the first call.

    The cache is stored directly in the instance ``__dict__`` through
    ``object.__setattr__`` because the dataclass is frozen. NOTE(review):
    ``_cached_repr`` is not a declared dataclass field, so it is invisible
    to ``dataclasses.fields``/``asdict`` and is silently dropped by
    copy/replace — acceptable for a pure cache, but worth knowing.
    """
    cached = self.__dict__.get("_cached_repr")
    if cached is not None:
        return cached
    # "Qparams" (lower-case p) is kept byte-for-byte: this string feeds
    # autocast cache keys, so changing it would change behavior.
    result = (
        f"Qparams(\npower_2_scale={self.power_2_scale},\n"
        f"amax_epsilon={self.amax_epsilon},\n"
        f"random_hadamard_transform={self.random_hadamard_transform},\n"
        f"stochastic_rounding={self.stochastic_rounding},\n"
        f"fp4_2d_quantization={self.fp4_2d_quantization}\n)"
    )
    object.__setattr__(self, "_cached_repr", result)
    return result


class Recipe:
"""
Base recipe class.
"""

# Cached string representation. Lazily populated by ``__repr__`` in
# subclasses and invalidated by ``__setattr__`` whenever any attribute
# changes. This makes repeated ``str(recipe)`` calls (e.g. on the hot
# path in ``FP8GlobalStateManager.get_unique_autocast_key``) essentially
# free after the first call.
_cached_repr: Optional[str] = None
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Three problems:

  • _cached_repr is being set as a class attr, not an instance attr.
  • Accessing _cached_repr via __dict__ is non-standard and bug-prone.
  • Splitting the cache logic between the base class and child classes results in code duplication and more risk of bugs, especially if it involves non-standard __dict__ accesses.

What if we concentrated the caching logic in the base class:

class Recipe:

    def __init__(self) -> None:
        self._cached_repr: Optional[str] = None

    @abc.abstractmethod
    def _make_repr(self) -> str:
        ...

    def __repr__(self) -> str:
        if self._cached_repr is None:
            self._cached_repr = self._make_repr()
        return self._cached_repr

    ...

class DelayedScaling(Recipe):

    def _make_repr(self) -> str:
        return f"..."


def __setattr__(self, name: str, value: Any) -> None:
    """Assign *name*, invalidating the cached repr on any real mutation.

    Writing ``_cached_repr`` itself must not re-trigger invalidation
    (that would recurse), so it is routed straight through
    ``object.__setattr__``. All other assignments first reset the cache
    to ``None`` and then store the new value the same way, which also
    works for subclasses that override ``__setattr__``.
    """
    if name == "_cached_repr":
        object.__setattr__(self, name, value)
        return
    # Any attribute change may alter the repr, so drop the cache first.
    object.__setattr__(self, "_cached_repr", None)
    object.__setattr__(self, name, value)

@classmethod
def nvfp4(cls):
"""Whether the given recipe is NVFP4 1D block scaling."""
Expand Down Expand Up @@ -228,7 +257,10 @@ def __post_init__(self) -> None:
), "Delayed scaling only supports backward_override=None."

def __repr__(self) -> str:
return (
cached = self.__dict__.get("_cached_repr")
if cached is not None:
return cached
result = (
f"recipe_type={self.__class__.__name__}, "
f"margin={self.margin}, "
f"format={str(self.fp8_format).split('.')[1]}, "
Expand All @@ -238,6 +270,8 @@ def __repr__(self) -> str:
f"fp8_mha={self.fp8_mha}, "
f"backward_override={self.backward_override}"
)
object.__setattr__(self, "_cached_repr", result)
return result


@dataclass()
Expand Down Expand Up @@ -276,7 +310,10 @@ def __post_init__(self) -> None:
), "NVTE_BACKWARD_OVERRIDE must be unset or one of: 'high_precision', 'dequantized'."

def __repr__(self) -> str:
return (
cached = self.__dict__.get("_cached_repr")
if cached is not None:
return cached
result = (
f"recipe_type={self.__class__.__name__}, "
f"format={str(self.fp8_format).split('.')[1]}, "
f"fp8_quant_fwd_inp={self.fp8_quant_fwd_inp}, "
Expand All @@ -289,6 +326,8 @@ def __repr__(self) -> str:
f"fp8_mha={self.fp8_mha}, "
f"backward_override={self.backward_override}"
)
object.__setattr__(self, "_cached_repr", result)
return result


@dataclass()
Expand Down Expand Up @@ -334,12 +373,17 @@ def __post_init__(self) -> None:
), "NVTE_BACKWARD_OVERRIDE must be unset or one of: 'high_precision', 'dequantized'."

def __repr__(self) -> str:
    """Return the recipe's string form, cached in ``_cached_repr``.

    ``Recipe.__setattr__`` resets the cache to ``None`` whenever any
    attribute is mutated, so the cached string is only reused while the
    recipe is unchanged. This keeps repeated ``str(recipe)`` calls
    (e.g. autocast-key construction) cheap after the first call.
    """
    cached = self.__dict__.get("_cached_repr")
    if cached is not None:
        return cached
    result = (
        f"recipe_type={self.__class__.__name__}, "
        f"margin={self.margin}, "
        f"format={str(self.fp8_format).split('.')[1]}, "
        f"backward_override={self.backward_override}"
    )
    object.__setattr__(self, "_cached_repr", result)
    return result


@dataclass()
Expand Down Expand Up @@ -415,7 +459,10 @@ def __post_init__(self) -> None:
), "NVTE_BACKWARD_OVERRIDE must be unset or one of: 'high_precision', 'dequantized'."

def __repr__(self) -> str:
return (
cached = self.__dict__.get("_cached_repr")
if cached is not None:
return cached
result = (
f"recipe_type={self.__class__.__name__}, "
f"format={str(self.fp8_format).split('.')[1]}, "
f"fp8_quant_fwd_inp={self.fp8_quant_fwd_inp}, "
Expand All @@ -431,6 +478,8 @@ def __repr__(self) -> str:
f"fp8_mha={self.fp8_mha}, "
f"backward_override={self.backward_override}"
)
object.__setattr__(self, "_cached_repr", result)
return result


@dataclass()
Expand Down Expand Up @@ -527,7 +576,10 @@ def __post_init__(self) -> None:
)

def __repr__(self) -> str:
return (
cached = self.__dict__.get("_cached_repr")
if cached is not None:
return cached
result = (
f"recipe_type={self.__class__.__name__}, "
f"fp4_format={str(self.fp4_format).split('.')[1]}, "
f"fp8_format={str(self.fp8_format).split('.')[1]}, "
Expand All @@ -538,6 +590,8 @@ def __repr__(self) -> str:
f"fp4_quant_fwd_weight={self.fp4_quant_fwd_weight}, "
f"fp4_quant_bwd_grad={self.fp4_quant_bwd_grad}, "
)
object.__setattr__(self, "_cached_repr", result)
return result


@dataclass()
Expand Down Expand Up @@ -584,8 +638,13 @@ def __post_init__(self) -> None:
), "NVTE_BACKWARD_OVERRIDE must be unset or one of: 'high_precision', 'dequantized'."

def __repr__(self) -> str:
    """Return the recipe's string form, cached in ``_cached_repr``.

    The cache lives in the instance ``__dict__`` and is invalidated by
    ``Recipe.__setattr__`` on any attribute mutation, so a stale string
    is never returned after the recipe changes.
    """
    cached = self.__dict__.get("_cached_repr")
    if cached is not None:
        return cached
    result = (
        f"recipe_type={self.__class__.__name__}, "
        f"qfactory={self.qfactory}, "
        f"backward_override={self.backward_override}"
    )
    object.__setattr__(self, "_cached_repr", result)
    return result
90 changes: 57 additions & 33 deletions transformer_engine/pytorch/quantization.py
Original file line number Diff line number Diff line change
Expand Up @@ -580,9 +580,8 @@ def reduce_and_update_fp8_tensors(
amax_history, scale, get_fp8_max(recipe, forward), recipe
)

@classmethod
@staticmethod
def get_unique_autocast_key(
cls,
recipe: Optional[Recipe] = None,
group: Optional[dist_group_type] = None,
):
Expand All @@ -591,7 +590,13 @@ def get_unique_autocast_key(
Object identity is sufficient since autocast contexts never outlive a single
training session.
"""
return str((str(recipe), id(group) if group is not None else None))
# directly getting the cached repr is about 40 ns faster than str(recipe)
# on grace systems.
Comment on lines +593 to +594
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is good to mention in the PR description, but not that useful in the code itself. Profiling becomes outdated once we move on to the next architecture.

recipe_repr = recipe.__dict__.get("_cached_repr") if recipe is not None else None
if recipe_repr is None:
recipe_repr = str(recipe)
group_id = id(group) if group is not None else 0
return f"{recipe_repr}|{group_id}"
Comment on lines +595 to +599
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Key format change could produce ambiguous keys

The new key format f"{recipe_repr}|{group_id}" uses | as a separator without escaping. If a future recipe's __repr__ ever emits a | character, two distinct (recipe, group) pairs could map to the same string. The old str(tuple) format was unambiguous because it quoted the recipe repr. A safer pattern uses a separator that cannot appear in repr output, or encodes the parts deterministically.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
recipe_repr = recipe.__dict__.get("_cached_repr") if recipe is not None else None
if recipe_repr is None:
recipe_repr = str(recipe)
group_id = id(group) if group is not None else 0
return f"{recipe_repr}|{group_id}"
group_id = id(group) if group is not None else None
return f"recipe=({str(recipe)}),group={group_id}"


@classmethod
def autocast_enter(
Expand Down Expand Up @@ -805,14 +810,13 @@ def quantized_model_init(
qstate.high_precision_init_val = _high_precision_init_val


@contextmanager
def fp8_autocast(
enabled: bool = True,
calibrating: bool = False,
fp8_recipe: Optional[Recipe] = None,
fp8_group: Optional[dist_group_type] = None,
_graph: bool = False,
) -> None:
) -> "autocast":
"""
.. warning::

Expand All @@ -828,25 +832,16 @@ def fp8_autocast(
stacklevel=2,
)

# Call new implementation.
with autocast(
return autocast(
enabled=enabled,
calibrating=calibrating,
recipe=fp8_recipe,
amax_reduction_group=fp8_group,
_graph=_graph,
):
yield
)


@contextmanager
def autocast(
enabled: bool = True,
calibrating: bool = False,
recipe: Optional["Recipe"] = None,
amax_reduction_group: Optional["dist_group_type"] = None,
_graph: bool = False,
) -> None:
class autocast:
"""
Context manager for quantization schemes like FP8 or FP4.

Expand Down Expand Up @@ -885,24 +880,53 @@ def autocast(
are reduced at the end of each training step.
"""

if enabled:
check_recipe_support(recipe)
# Class-based context manager (instead of ``@contextmanager`` from contextlib)
# to avoid the ~0.5us / invocation overhead of contextlib's generator-driven
# ``GeneratorContextManager``. ``__slots__`` further avoids per-instance
# dict allocation.
Comment on lines +883 to +886
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why are we mentioning the context manager here? It makes sense for this PR, but once the code is merged it will be completely random. This comment should explain what we are doing with __slots__, and we should explain the custom context manager logic in __enter__ and __exit__.

__slots__ = (
"_enabled",
"_calibrating",
"_recipe",
"_amax_reduction_group",
"_graph",
"_fp8_state",
)

# Save current state so we always restore it on exit.
fp8_state = FP8GlobalStateManager.get_autocast_state()
def __init__(
self,
enabled: bool = True,
calibrating: bool = False,
recipe: Optional["Recipe"] = None,
amax_reduction_group: Optional["dist_group_type"] = None,
_graph: bool = False,
) -> None:
self._enabled = enabled
self._calibrating = calibrating
self._recipe = recipe
self._amax_reduction_group = amax_reduction_group
self._graph = _graph
self._fp8_state = None

def __enter__(self) -> "autocast":
if self._enabled:
check_recipe_support(self._recipe)
# Save current state so we always restore it on exit.
self._fp8_state = FP8GlobalStateManager.get_autocast_state()
FP8GlobalStateManager.autocast_enter(
enabled=self._enabled,
calibrating=self._calibrating,
fp8_recipe=self._recipe,
fp8_group=self._amax_reduction_group,
_graph=self._graph,
)
return self

FP8GlobalStateManager.autocast_enter(
enabled=enabled,
calibrating=calibrating,
fp8_recipe=recipe,
fp8_group=amax_reduction_group,
_graph=_graph,
)
try:
yield
finally:
FP8GlobalStateManager.set_autocast_state(fp8_state)
FP8GlobalStateManager.autocast_exit(enabled, _graph=_graph)
def __exit__(self, exc_type, exc_val, exc_tb) -> None:
FP8GlobalStateManager.set_autocast_state(self._fp8_state)
FP8GlobalStateManager.autocast_exit(self._enabled, _graph=self._graph)
# Do not suppress exceptions.
return None
Comment on lines +911 to +929
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Nested reuse of the same instance silently corrupts state

The old generator-based implementation raised RuntimeError: generator already executing if you tried to enter the same context manager object twice concurrently. The new class-based implementation silently accepts nested reuse, but the second __enter__ call overwrites self._fp8_state with the state captured inside the first context, so the outer __exit__ restores the wrong state permanently.

ctx = autocast(enabled=True, recipe=recipe)
with ctx:           # _fp8_state = pre_context_state
    with ctx:       # _fp8_state = state_inside_first_block  ← overwrites!
        pass        # __exit__: restores state_inside_first_block
    # _fp8_state is now state_inside_first_block
# __exit__: restores state_inside_first_block, NOT pre_context_state  ← bug

Adding a guard in __enter__ would preserve the old safety behavior:

def __enter__(self) -> "autocast":
    if self._fp8_state is not None:
        raise RuntimeError("autocast context manager cannot be entered more than once concurrently")
    ...

Comment on lines +928 to +929
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: The function already returns None and the comment is trivially true (all Python outside of a try statement is not suppressing exceptions).

Suggested change
# Do not suppress exceptions.
return None



def _update_amax_history(amax_history: torch.Tensor) -> torch.Tensor:
Expand Down
Loading