Adding the PolarGrad optimizer implementation #100

timlautk wants to merge 11 commits into NVIDIA-NeMo:main
Conversation
Greptile Overview

Greptile Summary: This PR adds a new PolarGrad optimizer implementation.
Confidence Score: 4/5
Sequence Diagram

```mermaid
sequenceDiagram
    participant User as Training loop
    participant PG as PolarGrad.step()
    participant OO as OrthogonalizedOptimizer
    participant MU as muon_utils.newton_schulz
    participant TK as triton_kernels (optional)
    User->>PG: optimizer.step()
    PG->>OO: super().__init__(..., scaled_orthogonalize_fn)
    Note over OO: step() uses SGD-momentum (+ optional Nesterov)
    OO->>OO: compute momentum update per param
    OO->>MU: scaled_orthogonalize_fn(grad)
    MU->>TK: (optional) use_syrk kernel path
    TK-->>MU: orthogonalized update
    MU-->>OO: orth_grad
    OO-->>User: apply scaled orthogonalized update
```
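To make the flow concrete, here is a minimal, self-contained sketch of what the diagram describes. The inlined Newton-Schulz loop and the shape-based scale are stand-ins for `muon_utils.newton_schulz` and `muon.get_muon_scale_factor`, not the exact PR code:

```python
import torch


def scaled_orthogonalize(
    grad: torch.Tensor,
    num_ns_steps: int = 5,
    scale_mode: str = "nuclear_norm",
) -> torch.Tensor:
    """Orthogonalize a 2D gradient and rescale it, PolarGrad-style (sketch)."""
    # Stand-in for muon_utils.newton_schulz: quintic Newton-Schulz iterations
    # that push grad toward an approximate polar factor U S' V^T.
    # (Real implementations transpose tall matrices first for efficiency.)
    x = grad / (grad.norm() + 1e-7)
    a, b, c = 3.4445, -4.7750, 2.0315  # Muon's quintic NS coefficients
    for _ in range(num_ns_steps):
        s = x @ x.mT
        x = a * x + (b * s + c * (s @ s)) @ x
    orth_grad = x
    if scale_mode == "nuclear_norm":
        # PolarGrad scaling: <orth_grad, grad> approximates ||grad||_* when
        # orth_grad is close to the exact polar factor (see the review
        # discussion below).
        scale = (orth_grad * grad).sum()
    else:
        # Muon-style shape-based scaling, standing in for get_muon_scale_factor.
        scale = max(1.0, grad.size(-2) / grad.size(-1)) ** 0.5
    return scale * orth_grad
```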
timlautk force-pushed from 67cb337 to 9246029
```python
if scale_mode != "nuclear_norm":
    scale_factor = muon.get_muon_scale_factor(grad.size(-2), grad.size(-1), mode=scale_mode)
else:
    scale_factor = (orth_grad * grad).sum()
```
Nuclear norm scaling bug
In the scale_mode == "nuclear_norm" branch, scale_factor = (orth_grad * grad).sum() is the Frobenius inner product ⟨UVᵀ, G⟩, which equals tr(S) (nuclear norm) only if orth_grad is exactly UVᵀ from the SVD of grad. However muon_utils.newton_schulz() explicitly returns an approximate US'Vᵀ (see muon_utils.py:130-132), so this scale will generally be wrong (and can even change sign/magnitude vs the intended nuclear norm). You likely need to compute the scale from grad directly (or adjust the orthogonalization routine to produce a true polar factor if you want to use ⟨orth_grad, grad⟩).
We assume the approximation is accurate enough that the nuclear norm can also be approximated this way. Computing the nuclear norm of grad exactly via SVD is too expensive.
We have the SVD path in mop; it is usually not the bottleneck at scale, so it doesn't hurt to have both. We may need to consolidate the naming convention at some point, though.
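For reference, a small sketch contrasting the two options discussed above. `nuclear_norm_scale` is a hypothetical helper, not part of the PR:

```python
import torch


def nuclear_norm_scale(grad: torch.Tensor) -> torch.Tensor:
    # Exact nuclear norm: the sum of grad's singular values. svdvals skips
    # forming U and V, so it is cheaper than a full torch.linalg.svd.
    return torch.linalg.svdvals(grad).sum()


def nuclear_norm_scale_approx(grad: torch.Tensor, orth_grad: torch.Tensor) -> torch.Tensor:
    # The PR's approximation: <orth_grad, grad> equals ||grad||_* only when
    # orth_grad is the exact polar factor U V^T. Newton-Schulz returns
    # U S' V^T with S' only approximately identity, so the two can differ.
    return (orth_grad * grad).sum()
```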
timlautk force-pushed from 9246029 to a874d23
```python
orth_grad = muon_utils.newton_schulz(
    grad,
    steps=num_ns_steps,
    coefficient_type=coefficient_type,
    use_syrk=use_syrk,
)
scale_factor: float | torch.Tensor
if scale_mode != "nuclear_norm":
    scale_factor = muon.get_muon_scale_factor(grad.size(-2), grad.size(-1), mode=scale_mode)
else:
    scale_factor = (orth_grad * grad).sum()
```
Incorrect nuclear-norm scaling
scale_factor = (orth_grad * grad).sum() (line ~114) only equals the nuclear norm if orth_grad is the true polar factor UVᵀ from the SVD of grad. But muon_utils.newton_schulz() explicitly returns an approximation closer to US'Vᵀ (see muon_utils.py:130-132), so this scale will be systematically wrong (including wrong magnitude/sign), changing the intended PolarGrad update. Compute the nuclear-norm scale directly from grad (e.g., via singular values) or adjust the orthogonalization routine to produce the true polar factor before using this inner product.
* make a coefficient iterator for NS

Signed-off-by: Hao Wu <skyw@nvidia.com>
Signed-off-by: Phoenix <861062923@qq.com>
Signed-off-by: Tim Lau <timlautk@gmail.com>
timlautk force-pushed from a874d23 to 330d624
```python
)
use_syrk = False


def scaled_orthogonalize_fn(grad: torch.Tensor) -> torch.Tensor:
```
Nuclear-norm scaling incorrect
In scale_mode == "nuclear_norm", scale_factor = (orth_grad * grad).sum() only equals ||grad||_* if orth_grad is the true polar factor UVᵀ from an SVD of grad. Here orth_grad comes from muon_utils.newton_schulz(), which (per its implementation) is not guaranteed to return the exact polar factor, so this dot product can have incorrect magnitude/sign and will change the intended PolarGrad update. Compute the nuclear norm directly from grad (e.g., via singular values) or change the orthogonalization routine to produce the true polar factor before using this inner product.
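One way to quantify the "accurate enough" claim from the earlier thread is to compare the inner-product estimate against the exact nuclear norm on random matrices. A self-contained sketch; the inlined iteration mirrors Muon's quintic NS, not the exact PR code:

```python
import torch


def newton_schulz_sketch(g: torch.Tensor, steps: int = 5) -> torch.Tensor:
    # Mirrors Muon's quintic Newton-Schulz iteration (approximate polar factor).
    x = g / (g.norm() + 1e-7)
    a, b, c = 3.4445, -4.7750, 2.0315
    for _ in range(steps):
        s = x @ x.mT
        x = a * x + (b * s + c * (s @ s)) @ x
    return x


g = torch.randn(512, 2048)
exact = torch.linalg.svdvals(g).sum()          # true ||g||_*
approx = (newton_schulz_sketch(g) * g).sum()   # PR's inner-product estimate
print(f"exact: {exact:.1f}  approx: {approx:.1f}  rel. err: {abs(approx - exact) / exact:.2%}")
```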
* add a registry

* update docstring for MuonHyperball

Signed-off-by: mikail <mkhona@nvidia.com>
Signed-off-by: Hao Wu <skyw@nvidia.com>
Signed-off-by: Tim Lau <timlautk@gmail.com>
timlautk force-pushed from e76dfd5 to af13a74
```python
logging.info(f"Iterating through {steps} steps with {mode} mode.")
logging.debug(f"Coefficient sets: {coefficient_sets}")
```
Noisy logging in hot path
get_coefficient_iterator() logs at INFO on every newton_schulz() call (logging.info(f"Iterating through...")). Since Newton–Schulz is invoked per-parameter per-step, this will spam logs and materially slow training in normal runs. This should be downgraded to debug (or removed) so default INFO logging stays usable.
I think the AI missed this in a previous PR. No need to address it here, @timlautk.
I'll do another cleanup.
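For that follow-up, a minimal sketch of the fix; the `get_coefficient_iterator` signature is inferred from the snippet above and may not match the real muon_utils function:

```python
import logging

logger = logging.getLogger(__name__)


def get_coefficient_iterator(coefficient_sets, steps, mode):
    # Downgrade both messages to DEBUG: this runs per parameter per step,
    # so INFO-level logging here floods the default training logs.
    logger.debug("Iterating through %s steps with %s mode.", steps, mode)
    logger.debug("Coefficient sets: %s", coefficient_sets)
    ...
```

Using %-style lazy formatting also avoids building the message string at all when DEBUG is disabled, unlike an f-string.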
I think a test for PolarGrad should be added to test_orthogonalized_optimizer.py.
Greptile encountered an error while reviewing this PR. Please reach out to support@greptile.com for assistance.
Is the only difference from Muon the scale factor? If so, it would be cleaner to add one more option to the get_muon_scale_factor() function, which would avoid duplicating a lot of code from the Muon class. For example, code like the copied use_syrk block is usually not accepted. PolarGrad could be a functools.partial of the Muon class that fixes some arguments, as sketched below.
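A rough sketch of that suggestion. `Muon` is this repo's class, but the `scale_mode` kwarg and the constructor arguments are assumptions based on this PR, and `model` is a placeholder:

```python
import functools

# If PolarGrad differs from Muon only in how the scale factor is computed,
# a thin alias that pins scale_mode avoids a parallel class entirely.
PolarGrad = functools.partial(Muon, scale_mode="nuclear_norm")

# Hypothetical usage: behaves like Muon with the nuclear-norm scaling baked in.
optimizer = PolarGrad(model.parameters(), lr=1e-3)
```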
test_orthogonalized_optimizer.py has grown too much. Adding the PolarGrad test to it or starting a new file like test_polar_grad.py would both be fine.
Do you have a heuristic for how to transfer the learning rate when using PolarGrad?
One of the big wins of Kimi's Muon was the update-RMS-matching LR parametrization. Is there a corresponding heuristic here? I am assuming that since the nuclear norm is now an "effective learning rate" scaling factor that varies from iteration to iteration, there isn't one.
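For context on the question: the update-RMS-matching rule from Kimi's Muon report rescales each orthogonalized update by a fixed per-shape factor. A sketch of that heuristic and why it has no direct nuclear-norm analogue:

```python
import math


def muon_rms_matched_scale(m: int, n: int) -> float:
    # An (m x n) orthogonalized update U V^T has Frobenius norm
    # sqrt(min(m, n)), hence RMS = 1 / sqrt(max(m, n)). Multiplying by
    # 0.2 * sqrt(max(m, n)) pins the update RMS near 0.2, roughly matching
    # a typical AdamW update so AdamW learning rates transfer.
    return 0.2 * math.sqrt(max(m, n))

# Under nuclear-norm scaling the multiplier is ||grad||_*, which changes
# every iteration with the gradient spectrum, so no fixed per-shape factor
# plays the same role.
```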
Adding the implementation of PolarGrad (arXiv:2505.21799), based on the existing NS iteration implementation in Muon.