
Adding the PolarGrad optimizer implementation#100

Open
timlautk wants to merge 11 commits into NVIDIA-NeMo:main from timlautk:origin/timlautk/polargrad

Conversation

@timlautk

@timlautk timlautk commented Feb 7, 2026

Adding the implementation of PolarGrad (arXiv:2505.21799), based on the existing NS iteration implementation in Muon.

@copy-pr-bot

copy-pr-bot bot commented Feb 7, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@greptile-apps

greptile-apps bot commented Feb 7, 2026

Greptile Overview

Greptile Summary

This PR adds a new PolarGrad optimizer under emerging_optimizers/orthogonalized_optimizers, registering it as "polargrad" and re-exporting it in the package __init__.py.

PolarGrad builds on the existing OrthogonalizedOptimizer base: it runs an internal SGD-momentum update (optionally Nesterov), then orthogonalizes each 2D update via Newton–Schulz iteration (muon_utils.newton_schulz) and applies a configurable scaling (Muon-style scaling modes or a "nuclear_norm"-style inner-product scale), with an optional Triton SYRK path when supported.
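The flow described above can be sketched in a few lines of numpy, with an exact SVD standing in for the Newton–Schulz approximation; the function names, hyperparameters, and the `scale_mode` handling here are illustrative, not the PR's actual API:

```python
import numpy as np

def polar_factor(g: np.ndarray) -> tuple[np.ndarray, float]:
    """Exact polar factor U @ Vt of g (via SVD), plus the nuclear norm.

    The PR approximates the polar factor with Newton-Schulz iteration;
    SVD is used here only to keep the sketch self-contained.
    """
    u, s, vt = np.linalg.svd(g, full_matrices=False)
    return u @ vt, float(s.sum())

def polargrad_step(param, grad, momentum_buf, lr=0.02, beta=0.95,
                   scale_mode="nuclear_norm"):
    """One PolarGrad-style update: momentum, orthogonalize, scale, apply."""
    momentum_buf = beta * momentum_buf + grad     # SGD momentum (no Nesterov)
    orth, _ = polar_factor(momentum_buf)          # orthogonalized 2D update
    if scale_mode == "nuclear_norm":
        # Inner-product scale; equals the nuclear norm of momentum_buf
        # when orth is the exact polar factor.
        scale = float((orth * momentum_buf).sum())
    else:
        scale = 1.0                               # Muon-style modes omitted here
    return param - lr * scale * orth, momentum_buf
```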

Confidence Score: 4/5

  • This PR is reasonably safe to merge, with changes scoped to a new optimizer implementation and a single package re-export.
  • The implementation follows existing OrthogonalizedOptimizer/Muon patterns and only touches two files. I did not find additional deterministic runtime-breaking issues introduced in this commit beyond the already-discussed review threads; however, the new optimizer path has limited coverage in this PR set, so downstream integration should still be exercised in CI/training configs.
  • emerging_optimizers/orthogonalized_optimizers/polargrad.py

Important Files Changed

| Filename | Overview |
| --- | --- |
| emerging_optimizers/orthogonalized_optimizers/__init__.py | Adds the PolarGrad re-export so it becomes importable via the orthogonalized_optimizers package. |
| emerging_optimizers/orthogonalized_optimizers/polargrad.py | Introduces the PolarGrad optimizer with Newton–Schulz orthogonalization, scaling options, and an optional Triton SYRK path; no new merge-blocking bugs found beyond already-discussed threads. |

Sequence Diagram

```mermaid
sequenceDiagram
    participant User as Training loop
    participant PG as PolarGrad.step()
    participant OO as OrthogonalizedOptimizer
    participant MU as muon_utils.newton_schulz
    participant TK as triton_kernels (optional)

    User->>PG: optimizer.step()
    PG->>OO: super().__init__(..., scaled_orthogonalize_fn)
    Note over OO: step() uses SGD-momentum (+ optional Nesterov)
    OO->>OO: compute momentum update per param
    OO->>MU: scaled_orthogonalize_fn(grad)
    MU->>TK: (optional) use_syrk kernel path
    TK-->>MU: orthogonalized update
    MU-->>OO: orth_grad
    OO-->>User: apply scaled orthogonalized update
```


@greptile-apps greptile-apps bot left a comment
1 file reviewed, 2 comments


@greptile-apps greptile-apps bot left a comment

1 file reviewed, 1 comment


@timlautk timlautk marked this pull request as draft February 7, 2026 06:34
@timlautk timlautk marked this pull request as ready for review February 7, 2026 06:34
@timlautk timlautk force-pushed the origin/timlautk/polargrad branch from 67cb337 to 9246029 on February 7, 2026 06:35
@greptile-apps greptile-apps bot left a comment

2 files reviewed, 1 comment


Comment on lines +111 to +113
```python
if scale_mode != "nuclear_norm":
    scale_factor = muon.get_muon_scale_factor(grad.size(-2), grad.size(-1), mode=scale_mode)
else:
    scale_factor = (orth_grad * grad).sum()
```

Nuclear norm scaling bug

In the scale_mode == "nuclear_norm" branch, scale_factor = (orth_grad * grad).sum() is the Frobenius inner product ⟨UVᵀ, G⟩, which equals tr(S) (nuclear norm) only if orth_grad is exactly UVᵀ from the SVD of grad. However muon_utils.newton_schulz() explicitly returns an approximate US'Vᵀ (see muon_utils.py:130-132), so this scale will generally be wrong (and can even change sign/magnitude vs the intended nuclear norm). You likely need to compute the scale from grad directly (or adjust the orthogonalization routine to produce a true polar factor if you want to use ⟨orth_grad, grad⟩).
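The gap the reviewer describes is easy to check numerically. Below, the exact polar factor recovers the nuclear norm through the inner product, while a stand-in for an approximate U S'Vᵀ (with a spectrum deliberately not equal to 1) does not; everything here is an illustrative numpy sketch, not the PR's code:

```python
import numpy as np

rng = np.random.default_rng(0)
g = rng.standard_normal((8, 4))
u, s, vt = np.linalg.svd(g, full_matrices=False)

# With the exact polar factor U V^T, the Frobenius inner product
# <U V^T, G> equals tr(S), i.e. the nuclear norm of G.
exact_polar = u @ vt
inner_exact = float((exact_polar * g).sum())

# Newton-Schulz returns roughly U S' V^T with S' != I; the same inner
# product then becomes sum(s * s'), not the nuclear norm.  A crude
# stand-in spectrum (all singular values mapped to 0.7) makes the gap obvious.
s_approx = np.full_like(s, 0.7)
approx = u @ np.diag(s_approx) @ vt
inner_approx = float((approx * g).sum())   # = 0.7 * nuclear norm here
```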

Author

We assume the Newton–Schulz approximation is accurate enough that the nuclear norm can also be approximated this way; computing the nuclear norm of grad exactly via SVD is too expensive.

Contributor

We have the SVD path in MOP; it is usually not the bottleneck at scale. It doesn't hurt to have both, though we may need to consolidate the naming convention at some point.

@timlautk timlautk force-pushed the origin/timlautk/polargrad branch from 9246029 to a874d23 on February 7, 2026 06:43
@greptile-apps greptile-apps bot left a comment

2 files reviewed, 1 comment


Comment on lines +104 to +113
```python
orth_grad = muon_utils.newton_schulz(
    grad,
    steps=num_ns_steps,
    coefficient_type=coefficient_type,
    use_syrk=use_syrk,
)
scale_factor: float | torch.Tensor
if scale_mode != "nuclear_norm":
    scale_factor = muon.get_muon_scale_factor(grad.size(-2), grad.size(-1), mode=scale_mode)
else:
    scale_factor = (orth_grad * grad).sum()
```

Incorrect nuclear-norm scaling

scale_factor = (orth_grad * grad).sum() (line ~114) only equals the nuclear norm if orth_grad is the true polar factor UVᵀ from the SVD of grad. But muon_utils.newton_schulz() explicitly returns an approximation closer to US'Vᵀ (see muon_utils.py:130-132), so this scale will be systematically wrong (including wrong magnitude/sign), changing the intended PolarGrad update. Compute the nuclear-norm scale directly from grad (e.g., via singular values) or adjust the orthogonalization routine to produce the true polar factor before using this inner product.

Author

See above.

timlautk and others added 5 commits February 6, 2026 22:49
Signed-off-by: Tim Lau <timlautk@gmail.com>
Signed-off-by: Hao Wu <skyw@nvidia.com>
Signed-off-by: Tim Lau <timlautk@gmail.com>
Signed-off-by: Hao Wu <skyw@nvidia.com>
Signed-off-by: Tim Lau <timlautk@gmail.com>
* make a coefficient iterator for NS

Signed-off-by: Hao Wu <skyw@nvidia.com>
Signed-off-by: Tim Lau <timlautk@gmail.com>
Signed-off-by: Phoenix <861062923@qq.com>
Signed-off-by: Tim Lau <timlautk@gmail.com>
@timlautk timlautk force-pushed the origin/timlautk/polargrad branch from a874d23 to 330d624 on February 7, 2026 06:52
@greptile-apps greptile-apps bot left a comment

22 files reviewed, 2 comments


Comment on lines +96 to +98
```python
)
use_syrk = False

def scaled_orthogonalize_fn(grad: torch.Tensor) -> torch.Tensor:
```

Nuclear-norm scaling incorrect

In scale_mode == "nuclear_norm", scale_factor = (orth_grad * grad).sum() only equals ||grad||_* if orth_grad is the true polar factor UVᵀ from an SVD of grad. Here orth_grad comes from muon_utils.newton_schulz(), which (per its implementation) is not guaranteed to return the exact polar factor, so this dot product can have incorrect magnitude/sign and will change the intended PolarGrad update. Compute the nuclear norm directly from grad (e.g., via singular values) or change the orthogonalization routine to produce the true polar factor before using this inner product.

@greptile-apps

greptile-apps bot commented Feb 7, 2026

Additional Comments (1)

emerging_optimizers/orthogonalized_optimizers/muon_utils.py
Backward-incompatible steps constraint

newton_schulz() used to raise if steps wasn’t a multiple of len(coefficient_sets), but this PR silently changes behavior to cycling coefficients (or repeat-last for polar_express). That changes the algorithm for existing callers that relied on the old constraint (and makes misconfigured steps hard to notice). If cycle is intended as the default, consider preserving the validation for non-polar_express (or making the new behavior opt-in via an explicit flag/mode) so existing Muon/Scion behavior can’t change just by passing a different steps value.
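The behavioral difference the bot flags can be sketched as a small standalone iterator; the function name, mode strings, and signature below are hypothetical, modeled only on the description above:

```python
import itertools
from typing import Iterator, Sequence

Coeffs = tuple[float, float, float]

def get_coefficient_iterator(
    coefficient_sets: Sequence[Coeffs],
    steps: int,
    mode: str = "strict",
) -> Iterator[Coeffs]:
    """Yield one (a, b, c) coefficient triple per Newton-Schulz step.

    "strict" preserves the old validation (steps must be a multiple of the
    number of coefficient sets), "cycle" silently wraps around, and
    "repeat_last" pads with the final set, as polar_express-style schedules do.
    """
    n = len(coefficient_sets)
    if mode == "strict":
        if steps % n != 0:
            raise ValueError(f"steps={steps} must be a multiple of {n}")
        return itertools.islice(itertools.cycle(coefficient_sets), steps)
    if mode == "cycle":
        return itertools.islice(itertools.cycle(coefficient_sets), steps)
    if mode == "repeat_last":
        padded = itertools.chain(coefficient_sets,
                                 itertools.repeat(coefficient_sets[-1]))
        return itertools.islice(padded, steps)
    raise ValueError(f"unknown mode: {mode}")
```

Making "strict" the default for non-polar_express callers would keep a misconfigured `steps` loud instead of silently changing the algorithm.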

mkhona-nvidia and others added 3 commits February 6, 2026 23:00
Signed-off-by: mikail <mkhona@nvidia.com>
Signed-off-by: Tim Lau <timlautk@gmail.com>
* add a registry

Signed-off-by: Hao Wu <skyw@nvidia.com>
Signed-off-by: Tim Lau <timlautk@gmail.com>
* update docstring for MuonHyperball

Signed-off-by: Hao Wu <skyw@nvidia.com>
Signed-off-by: Tim Lau <timlautk@gmail.com>
@timlautk timlautk force-pushed the origin/timlautk/polargrad branch from e76dfd5 to af13a74 on February 7, 2026 07:02
Signed-off-by: Tim Lau <timlautk@gmail.com>
@greptile-apps greptile-apps bot left a comment

11 files reviewed, 4 comments


Comment on lines 89 to 91
```python
logging.info(f"Iterating through {steps} steps with {mode} mode.")
logging.debug(f"Coefficient sets: {coefficient_sets}")
```

Noisy logging in hot path

get_coefficient_iterator() logs at INFO on every newton_schulz() call (logging.info(f"Iterating through...")). Since Newton–Schulz is invoked per-parameter per-step, this will spam logs and materially slow training in normal runs. This should be downgraded to debug (or removed) so default INFO logging stays usable.

Contributor

I think the AI missed this in a previous PR; no need to address it here, @timlautk. I'll do another cleanup.

@greptile-apps

greptile-apps bot commented Feb 7, 2026

Additional Comments (1)

docs/apidocs/orthogonalized-optimizers.md
PolarGrad missing from API docs

docs/apidocs/orthogonalized-optimizers.md documents Muon/Scion/MOP/MuonHyperball but doesn’t include the new PolarGrad class, so it won’t appear in the generated API docs even though it’s part of emerging_optimizers.orthogonalized_optimizers.

@greptile-apps greptile-apps bot left a comment

2 files reviewed, no comments


Signed-off-by: Tim Lau <timlautk@gmail.com>
@timlautk
Author

timlautk commented Feb 8, 2026

I think a test for PolarGrad should be added to tests/test_orthogonalized_optimizer.py. I will leave it to the admins.

@greptile-apps

greptile-apps bot commented Feb 8, 2026

Greptile encountered an error while reviewing this PR. Please reach out to support@greptile.com for assistance.

@skyw
Contributor

skyw commented Feb 8, 2026

Is the only difference from Muon the scale factor? If so, it would be cleaner to add one more option to the get_muon_scale_factor() function, which would avoid duplicating a lot of code from the Muon class. For example, code like the copied use_syrk block is usually not accepted.

A PolarGrad could be a partial of the Muon class with some arguments fixed.
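The "partial of the Muon class" suggestion works because a class is callable; this toy sketch uses a stand-in `Muon` whose signature is hypothetical and exists only to show the pattern:

```python
import functools

class Muon:
    """Stand-in for the real Muon optimizer; only the call shape matters here."""
    def __init__(self, params, lr=0.02, scale_mode="spectral", use_syrk=False):
        self.params = params
        self.lr = lr
        self.scale_mode = scale_mode
        self.use_syrk = use_syrk

# PolarGrad as Muon with the scale mode fixed, instead of a copied class.
PolarGrad = functools.partial(Muon, scale_mode="nuclear_norm")

opt = PolarGrad(params=[], lr=0.01)
```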

@skyw
Contributor

skyw commented Feb 8, 2026

> I think a test for PolarGrad should be added to tests/test_orthogonalized_optimizer.py. I will leave it to admin.

test_orthogonalized_optimizer.py has grown too large; adding a PolarGrad test there or starting a new file like test_polar_grad.py would both be fine.

Contributor

@mkhona-nvidia mkhona-nvidia left a comment

@timlautk

Do you have a heuristic for how to transfer the learning rate when using PolarGrad?

One of the big wins of Kimi's Muon was the update-RMS-matching LR parametrization. Is there a corresponding heuristic here? Since the nuclear norm now acts as an "effective learning rate" scaling factor that varies from iteration to iteration, I assume there isn't one.
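For context on that heuristic: an m×n semi-orthogonal update has a fixed RMS of sqrt(min(m, n)/(m·n)) = 1/sqrt(max(m, n)), so a sqrt(max(m, n))-style multiplier makes the update RMS shape-independent, which is what the RMS-matching parametrization exploits; a nuclear-norm scale varies per step, so no fixed factor recovers that property. A numpy check of the constant-RMS fact (illustrative, not from the PR):

```python
import numpy as np

rng = np.random.default_rng(1)
m, n = 64, 16
g = rng.standard_normal((m, n))
u, s, vt = np.linalg.svd(g, full_matrices=False)
orth = u @ vt                               # semi-orthogonal update (m x n)

# sum of squares of a semi-orthogonal matrix is min(m, n), so its
# RMS entry is sqrt(min(m, n) / (m * n)) = 1 / sqrt(max(m, n)).
rms = float(np.sqrt((orth ** 2).mean()))

# Scaling by sqrt(max(m, n)) therefore pins the update RMS at 1,
# independent of the matrix shape; a per-step nuclear-norm scale does not.
scaled_rms = float(np.sqrt(((np.sqrt(max(m, n)) * orth) ** 2).mean()))
```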
