CPU overhead optimizations for te autocast #2957
base: main
Changes from all commits
@@ -60,6 +60,14 @@ class MMParams:
    use_split_accumulator: bool = True

    def __repr__(self) -> str:
        cached = self.__dict__.get("_cached_repr")
        if cached is not None:
            return cached
        result = f"MMParams(use_split_accumulator={self.use_split_accumulator})"
        object.__setattr__(self, "_cached_repr", result)
        return result


@dataclass(frozen=True)
class QParams:

@@ -77,20 +85,41 @@ class QParams:
    fp4_2d_quantization: bool = False

    def __repr__(self) -> str:
        return (
        cached = self.__dict__.get("_cached_repr")
        if cached is not None:
            return cached
        result = (
            f"Qparams(\npower_2_scale={self.power_2_scale},\n"
            f"amax_epsilon={self.amax_epsilon},\n"
            f"random_hadamard_transform={self.random_hadamard_transform},\n"
            f"stochastic_rounding={self.stochastic_rounding},\n"
            f"fp4_2d_quantization={self.fp4_2d_quantization}\n)"
        )
        object.__setattr__(self, "_cached_repr", result)
        return result


class Recipe:
    """
    Base recipe class.
    """

    # Cached string representation. Lazily populated by ``__repr__`` in
    # subclasses and invalidated by ``__setattr__`` whenever any attribute
    # changes. This makes repeated ``str(recipe)`` calls (e.g. on the hot
    # path in ``FP8GlobalStateManager.get_unique_autocast_key``) essentially
    # free after the first call.
    _cached_repr: Optional[str] = None
Collaborator

Three problems: …

What if we concentrated the caching logic in the base class:

class Recipe:
    def __init__(self) -> None:
        self._cached_repr: Optional[str] = None

    @abc.abstractmethod
    def _make_repr(self) -> str:
        ...

    def __repr__(self) -> str:
        if self._cached_repr is None:
            self._cached_repr = self._make_repr()
        return self._cached_repr

    ...

class DelayedScaling(Recipe):
    def _make_repr(self) -> str:
        return f"..."
    def __setattr__(self, name: str, value: Any) -> None:
        # Invalidate the cached repr on any attribute mutation. We avoid
        # recursion by checking the name and always routing the actual
        # assignment through ``object.__setattr__`` (which also works for
        # pydantic frozen dataclasses that override ``__setattr__``).
        if name != "_cached_repr":
            object.__setattr__(self, "_cached_repr", None)
        object.__setattr__(self, name, value)

    @classmethod
    def nvfp4(cls):
        """Whether the given recipe is NVFP4 1D block scaling."""
@@ -228,7 +257,10 @@ def __post_init__(self) -> None:
            ), "Delayed scaling only supports backward_override=None."

    def __repr__(self) -> str:
        return (
        cached = self.__dict__.get("_cached_repr")
        if cached is not None:
            return cached
        result = (
            f"recipe_type={self.__class__.__name__}, "
            f"margin={self.margin}, "
            f"format={str(self.fp8_format).split('.')[1]}, "

@@ -238,6 +270,8 @@ def __repr__(self) -> str:
            f"fp8_mha={self.fp8_mha}, "
            f"backward_override={self.backward_override}"
        )
        object.__setattr__(self, "_cached_repr", result)
        return result


@dataclass()

@@ -276,7 +310,10 @@ def __post_init__(self) -> None:
            ), "NVTE_BACKWARD_OVERRIDE must be unset or one of: 'high_precision', 'dequantized'."

    def __repr__(self) -> str:
        return (
        cached = self.__dict__.get("_cached_repr")
        if cached is not None:
            return cached
        result = (
            f"recipe_type={self.__class__.__name__}, "
            f"format={str(self.fp8_format).split('.')[1]}, "
            f"fp8_quant_fwd_inp={self.fp8_quant_fwd_inp}, "

@@ -289,6 +326,8 @@ def __repr__(self) -> str:
            f"fp8_mha={self.fp8_mha}, "
            f"backward_override={self.backward_override}"
        )
        object.__setattr__(self, "_cached_repr", result)
        return result


@dataclass()

@@ -334,12 +373,17 @@ def __post_init__(self) -> None:
            ), "NVTE_BACKWARD_OVERRIDE must be unset or one of: 'high_precision', 'dequantized'."

    def __repr__(self) -> str:
        return (
        cached = self.__dict__.get("_cached_repr")
        if cached is not None:
            return cached
        result = (
            f"recipe_type={self.__class__.__name__}, "
            f"margin={self.margin}, "
            f"format={str(self.fp8_format).split('.')[1]}, "
            f"backward_override={self.backward_override}"
        )
        object.__setattr__(self, "_cached_repr", result)
        return result


@dataclass()

@@ -415,7 +459,10 @@ def __post_init__(self) -> None:
            ), "NVTE_BACKWARD_OVERRIDE must be unset or one of: 'high_precision', 'dequantized'."

    def __repr__(self) -> str:
        return (
        cached = self.__dict__.get("_cached_repr")
        if cached is not None:
            return cached
        result = (
            f"recipe_type={self.__class__.__name__}, "
            f"format={str(self.fp8_format).split('.')[1]}, "
            f"fp8_quant_fwd_inp={self.fp8_quant_fwd_inp}, "

@@ -431,6 +478,8 @@ def __repr__(self) -> str:
            f"fp8_mha={self.fp8_mha}, "
            f"backward_override={self.backward_override}"
        )
        object.__setattr__(self, "_cached_repr", result)
        return result


@dataclass()

@@ -527,7 +576,10 @@ def __post_init__(self) -> None:
            )

    def __repr__(self) -> str:
        return (
        cached = self.__dict__.get("_cached_repr")
        if cached is not None:
            return cached
        result = (
            f"recipe_type={self.__class__.__name__}, "
            f"fp4_format={str(self.fp4_format).split('.')[1]}, "
            f"fp8_format={str(self.fp8_format).split('.')[1]}, "

@@ -538,6 +590,8 @@ def __repr__(self) -> str:
            f"fp4_quant_fwd_weight={self.fp4_quant_fwd_weight}, "
            f"fp4_quant_bwd_grad={self.fp4_quant_bwd_grad}, "
        )
        object.__setattr__(self, "_cached_repr", result)
        return result


@dataclass()

@@ -584,8 +638,13 @@ def __post_init__(self) -> None:
            ), "NVTE_BACKWARD_OVERRIDE must be unset or one of: 'high_precision', 'dequantized'."

    def __repr__(self) -> str:
        return (
        cached = self.__dict__.get("_cached_repr")
        if cached is not None:
            return cached
        result = (
            f"recipe_type={self.__class__.__name__}, "
            f"qfactory={self.qfactory}, "
            f"backward_override={self.backward_override}"
        )
        object.__setattr__(self, "_cached_repr", result)
        return result
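As an aside, the benefit of the repr caching is easy to sanity-check outside the library with a standalone timing sketch. The recipe classes below are stand-ins, not TransformerEngine's.

```python
import timeit
from dataclasses import dataclass


@dataclass
class UncachedRecipe:
    margin: int = 0
    fmt: str = "E4M3"

    def __repr__(self) -> str:
        return f"recipe_type={type(self).__name__}, margin={self.margin}, format={self.fmt}"


class CachedRecipe(UncachedRecipe):
    def __repr__(self) -> str:
        cached = self.__dict__.get("_cached_repr")
        if cached is not None:
            return cached
        result = super().__repr__()
        object.__setattr__(self, "_cached_repr", result)
        return result


uncached, cached = UncachedRecipe(), CachedRecipe()
for name, obj in [("uncached", uncached), ("cached", cached)]:
    total = timeit.timeit(lambda: str(obj), number=100_000)
    # total seconds for 100k calls -> nanoseconds per call
    print(f"{name}: {total * 1e4:.1f} ns/call")
```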
@@ -580,9 +580,8 @@ def reduce_and_update_fp8_tensors(
            amax_history, scale, get_fp8_max(recipe, forward), recipe
        )

    @classmethod
    @staticmethod
    def get_unique_autocast_key(
        cls,
        recipe: Optional[Recipe] = None,
        group: Optional[dist_group_type] = None,
    ):

@@ -591,7 +590,13 @@ def get_unique_autocast_key(
        Object identity is sufficient since autocast contexts never outlive a single
        training session.
        """
        return str((str(recipe), id(group) if group is not None else None))
        # directly getting the cached repr is about 40 ns faster than str(recipe)
        # on grace systems.
Comment on lines +593 to +594

Collaborator

This is good to mention in the PR description, but not that useful in the code itself. Profiling becomes outdated once we move on to the next architecture.
        recipe_repr = recipe.__dict__.get("_cached_repr") if recipe is not None else None
        if recipe_repr is None:
            recipe_repr = str(recipe)
        group_id = id(group) if group is not None else 0
        return f"{recipe_repr}|{group_id}"
Comment on lines +595 to +599

Contributor

The new key format …

Collaborator

Suggested change
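For comparison, the old and new key constructions side by side in a standalone sketch (hypothetical helper functions, not the project's code): the old key wrapped both parts in a tuple repr, while the new key joins them with a literal `|` and maps a missing group to `0` instead of `None`.

```python
def old_key(recipe_repr, group):
    return str((recipe_repr, id(group) if group is not None else None))


def new_key(recipe_repr, group):
    group_id = id(group) if group is not None else 0
    return f"{recipe_repr}|{group_id}"


class Group:  # stand-in for a distributed process group
    pass


g = Group()
print(old_key("margin=0, format=E4M3", None))  # ('margin=0, format=E4M3', None)
print(new_key("margin=0, format=E4M3", None))  # margin=0, format=E4M3|0
print(new_key("margin=0, format=E4M3", g))     # margin=0, format=E4M3|<id of g>
```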
    @classmethod
    def autocast_enter(
@@ -805,14 +810,13 @@ def quantized_model_init(
        qstate.high_precision_init_val = _high_precision_init_val


@contextmanager
def fp8_autocast(
    enabled: bool = True,
    calibrating: bool = False,
    fp8_recipe: Optional[Recipe] = None,
    fp8_group: Optional[dist_group_type] = None,
    _graph: bool = False,
) -> None:
) -> "autocast":
    """
    .. warning::

@@ -828,25 +832,16 @@ def fp8_autocast(
        stacklevel=2,
    )

    # Call new implementation.
    with autocast(
    return autocast(
        enabled=enabled,
        calibrating=calibrating,
        recipe=fp8_recipe,
        amax_reduction_group=fp8_group,
        _graph=_graph,
    ):
        yield
    )
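Since `fp8_autocast` now simply constructs and returns an `autocast` instance instead of wrapping it in a generator, caller-side usage is unchanged. A usage sketch, assuming a CUDA build of TransformerEngine and that `autocast` is exported from `transformer_engine.pytorch` as the deprecation path implies:

```python
import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling

recipe = DelayedScaling()
layer = te.Linear(128, 128).cuda()
inp = torch.randn(64, 128, device="cuda")

# Deprecated spelling: fp8_autocast now just builds and returns an autocast object.
with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
    out = layer(inp)

# New spelling, same behaviour from the caller's perspective.
with te.autocast(enabled=True, recipe=recipe):
    out = layer(inp)
```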
@contextmanager
def autocast(
    enabled: bool = True,
    calibrating: bool = False,
    recipe: Optional["Recipe"] = None,
    amax_reduction_group: Optional["dist_group_type"] = None,
    _graph: bool = False,
) -> None:
class autocast:
    """
    Context manager for quantization schemes like FP8 or FP4.

@@ -885,24 +880,53 @@ def autocast(
    are reduced at the end of each training step.
    """

    if enabled:
        check_recipe_support(recipe)
    # Class-based context manager (instead of ``@contextmanager`` from contextlib)
    # to avoid the ~0.5us / invocation overhead of contextlib's generator-driven
    # ``GeneratorContextManager``. ``__slots__`` further avoids per-instance
    # dict allocation.
Comment on lines +883 to +886

Collaborator

Why are we mentioning the context manager here? It makes sense for this PR, but once the code is merged it will be completely random. This comment should explain what we are doing with …
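For context on the overhead claim in that comment block, here is a small self-contained timing sketch comparing a `contextlib`-based context manager with a plain class-based one. It is illustrative only; absolute numbers vary by machine and Python version.

```python
import timeit
from contextlib import contextmanager


@contextmanager
def generator_cm():
    # contextlib drives this through a GeneratorContextManager object.
    yield


class class_cm:
    __slots__ = ()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        return None


def use_generator():
    with generator_cm():
        pass


def use_class():
    with class_cm():
        pass


n = 1_000_000
for fn in (use_generator, use_class):
    per_call_us = timeit.timeit(fn, number=n) / n * 1e6
    print(f"{fn.__name__}: {per_call_us:.2f} us per enter/exit")
```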
    __slots__ = (
        "_enabled",
        "_calibrating",
        "_recipe",
        "_amax_reduction_group",
        "_graph",
        "_fp8_state",
    )

    # Save current state so we always restore it on exit.
    fp8_state = FP8GlobalStateManager.get_autocast_state()
    def __init__(
        self,
        enabled: bool = True,
        calibrating: bool = False,
        recipe: Optional["Recipe"] = None,
        amax_reduction_group: Optional["dist_group_type"] = None,
        _graph: bool = False,
    ) -> None:
        self._enabled = enabled
        self._calibrating = calibrating
        self._recipe = recipe
        self._amax_reduction_group = amax_reduction_group
        self._graph = _graph
        self._fp8_state = None

    def __enter__(self) -> "autocast":
        if self._enabled:
            check_recipe_support(self._recipe)
        # Save current state so we always restore it on exit.
        self._fp8_state = FP8GlobalStateManager.get_autocast_state()
        FP8GlobalStateManager.autocast_enter(
            enabled=self._enabled,
            calibrating=self._calibrating,
            fp8_recipe=self._recipe,
            fp8_group=self._amax_reduction_group,
            _graph=self._graph,
        )
        return self

    FP8GlobalStateManager.autocast_enter(
        enabled=enabled,
        calibrating=calibrating,
        fp8_recipe=recipe,
        fp8_group=amax_reduction_group,
        _graph=_graph,
    )
    try:
        yield
    finally:
        FP8GlobalStateManager.set_autocast_state(fp8_state)
        FP8GlobalStateManager.autocast_exit(enabled, _graph=_graph)
    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        FP8GlobalStateManager.set_autocast_state(self._fp8_state)
        FP8GlobalStateManager.autocast_exit(self._enabled, _graph=self._graph)
        # Do not suppress exceptions.
        return None
Comment on lines +911 to +929

Contributor

The old generator-based implementation raised an error if the same context manager object was entered a second time; with this class, re-entering the same instance silently overwrites the saved state:

ctx = autocast(enabled=True, recipe=recipe)
with ctx:          # _fp8_state = pre_context_state
    with ctx:      # _fp8_state = state_inside_first_block ← overwrites!
        pass       # __exit__: restores state_inside_first_block
    # _fp8_state is now state_inside_first_block
# __exit__: restores state_inside_first_block, NOT pre_context_state ← bug

Adding a guard in __enter__ would prevent this:

def __enter__(self) -> "autocast":
    if self._fp8_state is not None:
        raise RuntimeError("autocast context manager cannot be entered more than once concurrently")
    ...

Comment on lines +928 to +929

Collaborator

Nit: The function already returns `None` implicitly.

Suggested change
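A toy, self-contained version of the re-entrancy hazard and the proposed guard (simplified global state, not TE code), which could double as a regression-test sketch:

```python
_GLOBAL_STATE = {"enabled": False}


class toy_autocast:
    __slots__ = ("_saved",)

    def __init__(self):
        self._saved = None

    def __enter__(self):
        if self._saved is not None:
            raise RuntimeError("toy_autocast cannot be entered more than once concurrently")
        self._saved = dict(_GLOBAL_STATE)   # save state before mutating it
        _GLOBAL_STATE["enabled"] = True
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        _GLOBAL_STATE.update(self._saved)   # always restore the saved state
        self._saved = None                  # allow sequential reuse of the instance


ctx = toy_autocast()
with ctx:
    pass
with ctx:  # sequential reuse of the same instance still works
    try:
        with ctx:  # concurrent re-entry is rejected instead of corrupting state
            pass
    except RuntimeError as e:
        print(e)
```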
def _update_amax_history(amax_history: torch.Tensor) -> torch.Tensor:
`_cached_repr` stored outside declared dataclass fields

`MMParams` is `@dataclass(frozen=True)`. Storing `_cached_repr` via `object.__setattr__` bypasses the frozen guard correctly in CPython, but `_cached_repr` is not a declared dataclass field — it won't appear in `dataclasses.fields()`, `dataclasses.asdict()`, `dataclasses.astuple()`, or `copy.replace()`. If downstream code serializes or copies an `MMParams` instance, the cached repr would be lost silently. Documenting this with a comment or declaring it as `field(init=False, repr=False, compare=False)` would make the intent clearer. The same applies to `QParams`.

I see that this is why we're doing the funny accesses with `__dict__`. I agree that bypassing `frozen=True` is iffy, so I wonder if we could set `_cached_repr` in `__post_init__`? If the class is frozen, its repr must also be frozen and I don't see a benefit in lazy evaluation.