@@ -57,6 +57,54 @@ def __init__(self, dim: int, eps: float = 1e-6):
5757 self .weight .requires_grad = False
5858
5959
class RMSNormCoreML(torch.nn.Module):
    """CoreML-friendly RMSNorm built on `torch.linalg.vector_norm`.

    Computes `x * sqrt(dim) / max(||x||_2, sqrt(dim * eps)) * weight` over the
    last dimension. Using `vector_norm` keeps the op preserved in the CoreML
    graph for numerical stability.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        """
        Args:
            dim (int): The dimension of the input tensor (size of the last axis).
            eps (float, optional): Floor on the L2-norm denominator
                (`clamp_min(||x||_2, sqrt(dim * eps))`). Prevents `0/0 = NaN` on
                zero-padded positions and matches standard RMSNorm's
                `rsqrt(mean(x^2) + eps)` semantics on a zero input. Must be > 0.

        Attributes:
            dim (int): Normalized dimension, consumed by `_norm`.
            eps (float): Floor coefficient consumed by `_norm`.
            weight (nn.Parameter): Learnable scaling parameter, initialised to ones.

        Raises:
            ValueError: If `eps` is not strictly positive.
        """
        super().__init__()
        # Raise explicitly instead of `assert`: assertions are stripped under
        # `python -O`, which would silently re-enable the NaN this class exists
        # to prevent. `not eps > 0` also rejects NaN.
        if not eps > 0:
            raise ValueError(
                "RMSNormCoreML requires eps > 0; eps=0 collapses the denominator "
                "floor and produces NaN on zero-padded positions"
            )
        self.dim = dim
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        """Return `x` scaled to unit RMS over its last dimension (no weight applied)."""
        # Floor the denominator to avoid 0 / 0 = NaN on zero-padded positions
        # (chunked prefill in StaticAttentionIOManager pads each chunk to
        # input_len with zeros). Use sqrt(dim * eps) so the floor matches
        # standard RMSNorm's eps semantics (`rsqrt(mean(x^2) + eps)`); a bare
        # 1e-6 would sit below fp16's smallest normal value.
        # Constants are created with x's dtype AND device so they never mix
        # devices with the activations.
        floor_val = torch.sqrt(
            torch.tensor(self.dim * self.eps, dtype=x.dtype, device=x.device)
        )
        norm_val = torch.clamp_min(
            torch.linalg.vector_norm(x, dim=-1, keepdim=True), floor_val
        )
        # x / rms(x) == x * sqrt(dim) / ||x||_2
        sqrt_dim = torch.sqrt(torch.tensor(self.dim, dtype=x.dtype, device=x.device))
        return x * sqrt_dim * torch.reciprocal(norm_val)

    def forward(self, x):
        """Normalize `x` over its last dimension and apply the learned scale."""
        output = self._norm(x)
        return output * self.weight
106+
107+
60108class RMSNormWithInputScale (torch .nn .Module ):
61109 def __init__ (self , dim : int , eps : float = 1e-5 ):
62110 super ().__init__ ()
@@ -83,3 +131,37 @@ def forward(self, hidden_states: torch.Tensor, gate: torch.Tensor) -> torch.Tens
83131 hidden_states = self .weight * hidden_states .to (input_dtype )
84132 hidden_states = hidden_states * F .silu (gate .to (torch .float32 ))
85133 return hidden_states .to (input_dtype )
134+
135+
def replace_rms_norm_for_coreml_(model: torch.nn.Module) -> torch.nn.Module:
    """In-place: walk `model` and swap every RMSNorm-family module for RMSNormCoreML.

    Mirrors the post-construction transform pattern used by torchao's
    `quantize_(model, config)`: instead of threading a `use_coreml_norm` flag
    through every norm construction site, build the model with the standard
    norms and then call this once before CoreML export. Trained scale weights
    are preserved (the replacement shares the original `weight` Parameter).

    Swaps these classes (everything else is left alone):
      * `RMSNorm` (this module)
      * `ScalelessRMSNorm` (this module — no-op weight)
      * `torch.nn.RMSNorm` (used for affine q_norm/k_norm in StaticAttention)

    A norm instance that is reachable through several attribute paths (shared
    module) is replaced by a single shared `RMSNormCoreML` at every path.

    Args:
        model: Root module to transform.

    Returns:
        `model` itself after in-place mutation. If `model` is itself one of the
        swapped norm classes it cannot be rebound in place, so the replacement
        `RMSNormCoreML` is returned instead — callers should use the return value.
    """
    norm_types = (RMSNorm, ScalelessRMSNorm, torch.nn.RMSNorm)
    # id(original norm) -> its replacement, so shared norms stay shared.
    # The snapshot list below keeps the originals alive, so ids stay unique.
    replacements = {}
    root = model

    # Snapshot the walk because the tree is mutated during iteration.
    # remove_duplicate=False yields every path to a shared module; the default
    # de-duplication would rebind only the first path and leave the others
    # pointing at the un-converted norm.
    for name, mod in list(model.named_modules(remove_duplicate=False)):
        if not isinstance(mod, norm_types):
            continue

        new = replacements.get(id(mod))
        if new is None:
            # All three carry the normalized dim either as `dim` or in
            # `normalized_shape[-1]`.
            # NOTE(review): only the last entry of `normalized_shape` is used;
            # a multi-dim `torch.nn.RMSNorm` would not be reproduced exactly.
            dim = getattr(mod, "dim", None) or mod.normalized_shape[-1]
            # `eps` may be missing, None or 0; RMSNormCoreML needs eps > 0.
            eps = getattr(mod, "eps", 1e-6) or 1e-6
            new = RMSNormCoreML(dim, eps=eps)
            # Preserve trained scale (no-op for ScalelessRMSNorm).
            if getattr(mod, "weight", None) is not None:
                new.weight = mod.weight
            replacements[id(mod)] = new

        if not name:
            # The root has no parent attribute to rebind; hand the replacement
            # back through the return value instead of registering a child
            # under the empty name.
            root = new
            continue

        # Locate parent module via the dotted name and rebind the attribute.
        # For a top-level child parent_name is "" and get_submodule("") is `model`.
        parent_name, _, attr = name.rpartition(".")
        parent = model.get_submodule(parent_name)
        setattr(parent, attr, new)

    return root
0 commit comments