Skip to content

Commit d1db6b7

Browse files
authored
Add CoreML-stable RMSNorm for llama eager paths (pytorch#19523) (pytorch#19523)
Differential Revision: D104862210 Pull Request resolved: pytorch#19523
1 parent a8cfe2b commit d1db6b7

2 files changed

Lines changed: 83 additions & 60 deletions

File tree

examples/apple/coreml/llama/llama_transformer.py

Lines changed: 1 addition & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
import torch
1616
import torch.nn.functional as F
17-
from executorch.examples.models.llama.norm import RMSNorm
17+
from executorch.examples.models.llama.norm import RMSNorm, RMSNormCoreML # noqa: F401
1818

1919
from executorch.examples.models.llama.rope import (
2020
hf_apply_rotary_emb,
@@ -109,65 +109,6 @@ def __post_init__(self):
109109
self.head_dim = self.dim // self.n_heads
110110

111111

112-
class CoreMLRMSNorm(torch.nn.Module):
113-
def __init__(self, dim: int, eps: float = 1e-6):
114-
"""
115-
Initialize the RMSNorm normalization layer.
116-
117-
Args:
118-
dim (int): The dimension of the input tensor.
119-
eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
120-
121-
Attributes:
122-
eps (float): A small value added to the denominator for numerical stability.
123-
weight (nn.Parameter): Learnable scaling parameter.
124-
125-
"""
126-
super().__init__()
127-
self.dim = dim
128-
self.eps = eps
129-
self.weight = nn.Parameter(torch.ones(dim))
130-
131-
def _norm(self, x):
132-
"""
133-
Apply the RMSNorm normalization to the input tensor.
134-
135-
Args:
136-
x (torch.Tensor): The input tensor.
137-
138-
Returns:
139-
torch.Tensor: The normalized tensor.
140-
141-
"""
142-
# CoreML ignores casts to FP32, so existing implementation of RMSNorm was not stable
143-
# We instead use (x * sqrt(n)) / norm(x, dim=-1)
144-
# Using torch.norm and preserving this op in CoreML improves stability
145-
# Note, we ignore eps, but could add it by using torch.norm(torch.concat(x, sqrt(n*eps))) in the denominator
146-
# In future, we want to add CoreML support for the functional RMSNorm op
147-
# We have yet to do large scale evaluations on the numeric stability of this solution, but note that
148-
# it appears better than what exists currently (removing FP32 casts and using FP16)
149-
rms_norm_eps0 = (
150-
x
151-
* torch.sqrt(torch.tensor(self.dim, dtype=x.dtype))
152-
* torch.reciprocal(torch.linalg.vector_norm(x, dim=-1, keepdim=True))
153-
)
154-
return rms_norm_eps0
155-
156-
def forward(self, x):
157-
"""
158-
Forward pass through the RMSNorm layer.
159-
160-
Args:
161-
x (torch.Tensor): The input tensor.
162-
163-
Returns:
164-
torch.Tensor: The output tensor after applying RMSNorm.
165-
166-
"""
167-
output = self._norm(x)
168-
return output * self.weight
169-
170-
171112
class Rope(torch.nn.Module):
172113
def __init__(self, params: ModelArgs):
173114
super().__init__()

examples/models/llama/norm.py

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,54 @@ def __init__(self, dim: int, eps: float = 1e-6):
5757
self.weight.requires_grad = False
5858

5959

60+
class RMSNormCoreML(torch.nn.Module):
61+
def __init__(self, dim: int, eps: float = 1e-6):
62+
"""
63+
CoreML-friendly RMSNorm — uses `torch.linalg.vector_norm` so the op is
64+
preserved in the CoreML graph for numerical stability.
65+
66+
Args:
67+
dim (int): The dimension of the input tensor.
68+
eps (float, optional): Floor on the L2-norm denominator
69+
(`clamp_min(‖x‖₂, √(dim·eps))`). Prevents `0/0 = NaN` on
70+
zero-padded positions and matches standard RMSNorm's
71+
`rsqrt(mean(x²) + eps)` semantics on a zero input. Must be > 0.
72+
73+
Attributes:
74+
eps (float): Floor coefficient consumed by `_norm`.
75+
weight (nn.Parameter): Learnable scaling parameter.
76+
"""
77+
super().__init__()
78+
assert eps > 0, (
79+
"RMSNormCoreML requires eps > 0; eps=0 collapses the denominator "
80+
"floor and produces NaN on zero-padded positions"
81+
)
82+
self.dim = dim
83+
self.eps = eps
84+
self.weight = nn.Parameter(torch.ones(dim))
85+
86+
def _norm(self, x):
87+
# Floor the denominator to avoid 0 / 0 = NaN on zero-padded positions
88+
# (chunked prefill in StaticAttentionIOManager pads each chunk to
89+
# input_len with zeros). Use sqrt(dim * eps) so the floor matches
90+
# standard RMSNorm's eps semantics (`rsqrt(mean(x²) + eps)`) and is
91+
# large enough to survive fp16 (1e-6 alone underflows in fp16).
92+
floor_val = torch.sqrt(torch.tensor(self.dim * self.eps, dtype=x.dtype))
93+
norm_val = torch.clamp_min(
94+
torch.linalg.vector_norm(x, dim=-1, keepdim=True), floor_val
95+
)
96+
rms_norm_eps0 = (
97+
x
98+
* torch.sqrt(torch.tensor(self.dim, dtype=x.dtype))
99+
* torch.reciprocal(norm_val)
100+
)
101+
return rms_norm_eps0
102+
103+
def forward(self, x):
104+
output = self._norm(x)
105+
return output * self.weight
106+
107+
60108
class RMSNormWithInputScale(torch.nn.Module):
61109
def __init__(self, dim: int, eps: float = 1e-5):
62110
super().__init__()
@@ -83,3 +131,37 @@ def forward(self, hidden_states: torch.Tensor, gate: torch.Tensor) -> torch.Tens
83131
hidden_states = self.weight * hidden_states.to(input_dtype)
84132
hidden_states = hidden_states * F.silu(gate.to(torch.float32))
85133
return hidden_states.to(input_dtype)
134+
135+
136+
def replace_rms_norm_for_coreml_(model: torch.nn.Module) -> torch.nn.Module:
137+
"""In-place: walk `model` and swap every RMSNorm-family module for RMSNormCoreML.
138+
139+
Mirrors the post-construction transform pattern used by torchao's
140+
`quantize_(model, config)`: instead of threading a `use_coreml_norm` flag
141+
through every norm construction site, build the model with the standard
142+
norms and then call this once before CoreML export. Trained scale weights
143+
are preserved.
144+
145+
Swaps these classes (everything else is left alone):
146+
* `RMSNorm` (this module)
147+
* `ScalelessRMSNorm` (this module — no-op weight)
148+
* `torch.nn.RMSNorm` (used for affine q_norm/k_norm in StaticAttention)
149+
"""
150+
for name, mod in list(model.named_modules()):
151+
if not isinstance(mod, (RMSNorm, ScalelessRMSNorm, torch.nn.RMSNorm)):
152+
continue
153+
# All three carry the normalized dim either as `dim` or in `normalized_shape[-1]`.
154+
dim = getattr(mod, "dim", None) or mod.normalized_shape[-1]
155+
eps = getattr(mod, "eps", 1e-6) or 1e-6
156+
new = RMSNormCoreML(dim, eps=eps)
157+
# Preserve trained scale (no-op for ScalelessRMSNorm).
158+
if getattr(mod, "weight", None) is not None:
159+
new.weight = mod.weight
160+
# Locate parent module via the dotted name and rebind the attribute.
161+
if "." in name:
162+
parent_name, attr = name.rsplit(".", 1)
163+
parent = model.get_submodule(parent_name)
164+
else:
165+
parent, attr = model, name
166+
setattr(parent, attr, new)
167+
return model

0 commit comments

Comments
 (0)