Skip to content

Commit f42ce2c

Browse files
Add Muon optimizer implementation
- Implements Muon optimizer for hidden layer weight matrices - Uses Newton-Schulz orthogonalization iterations - Provides matrix-aware gradient updates with spectral constraints - Includes comprehensive docstrings and type hints - Adds doctests for validation - Provides usage example demonstrating optimization - Follows PEP8 coding standards - Pure NumPy implementation without frameworks - Part of issue #13662
1 parent 76f8f40 commit f42ce2c

File tree

1 file changed

+117
-0
lines changed

1 file changed

+117
-0
lines changed

neural_network/optimizers/muon.py

Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
"""
2+
Muon Optimizer
3+
4+
Implements Muon optimizer for neural network hidden layers using NumPy.
5+
Muon uses Newton-Schulz orthogonalization iterations for improved convergence.
6+
7+
Reference: https://kellerjordan.github.io/posts/muon/
8+
Author: Adhithya Laxman Ravi Shankar Geetha
9+
Date: 2025.10.21
10+
"""
11+
12+
import numpy as np
13+
14+
15+
class Muon:
16+
"""
17+
Muon optimizer for hidden layer weight matrices.
18+
19+
Applies Newton-Schulz orthogonalization to gradients before updates.
20+
"""
21+
22+
def __init__(
23+
self, learning_rate: float = 0.02, momentum: float = 0.95, ns_steps: int = 5
24+
) -> None:
25+
"""
26+
Initialize Muon optimizer.
27+
28+
Args:
29+
learning_rate (float): Learning rate for updates.
30+
momentum (float): Momentum factor.
31+
ns_steps (int): Number of Newton-Schulz iteration steps.
32+
33+
>>> optimizer = Muon(learning_rate=0.02, momentum=0.95, ns_steps=5)
34+
>>> optimizer.momentum
35+
0.95
36+
"""
37+
self.learning_rate = learning_rate
38+
self.momentum = momentum
39+
self.ns_steps = ns_steps
40+
self.velocity: dict[int, np.ndarray] = {}
41+
42+
def newton_schulz_orthogonalize(self, matrix: np.ndarray) -> np.ndarray:
43+
"""
44+
Orthogonalize matrix using Newton-Schulz iterations.
45+
46+
Args:
47+
matrix (np.ndarray): Input matrix.
48+
49+
Returns:
50+
np.ndarray: Orthogonalized matrix.
51+
52+
>>> optimizer = Muon()
53+
>>> mat = np.array([[1.0, 0.5], [0.5, 1.0]])
54+
>>> orth = optimizer.newton_schulz_orthogonalize(mat)
55+
>>> orth.shape
56+
(2, 2)
57+
"""
58+
if matrix.shape[0] < matrix.shape[1]:
59+
matrix = matrix.T
60+
transposed = True
61+
else:
62+
transposed = False
63+
64+
a = matrix.copy()
65+
for _ in range(self.ns_steps):
66+
a = 1.5 * a - 0.5 * a @ (a.T @ a)
67+
68+
return a.T if transposed else a
69+
70+
def update(
71+
self, param_id: int, params: np.ndarray, gradients: np.ndarray
72+
) -> np.ndarray:
73+
"""
74+
Update parameters using Muon.
75+
76+
Args:
77+
param_id (int): Unique identifier for parameter group.
78+
params (np.ndarray): Current parameters.
79+
gradients (np.ndarray): Gradients of parameters.
80+
81+
Returns:
82+
np.ndarray: Updated parameters.
83+
84+
>>> optimizer = Muon(learning_rate=0.1, momentum=0.9)
85+
>>> params = np.array([[1.0, 2.0], [3.0, 4.0]])
86+
>>> grads = np.array([[0.1, 0.2], [0.3, 0.4]])
87+
>>> updated = optimizer.update(0, params, grads)
88+
>>> updated.shape
89+
(2, 2)
90+
"""
91+
if param_id not in self.velocity:
92+
self.velocity[param_id] = np.zeros_like(params)
93+
94+
ortho_grad = self.newton_schulz_orthogonalize(gradients)
95+
self.velocity[param_id] = self.momentum * self.velocity[param_id] + ortho_grad
96+
97+
return params - self.learning_rate * self.velocity[param_id]
98+
99+
100+
# Usage example
101+
if __name__ == "__main__":
102+
import doctest
103+
104+
doctest.testmod()
105+
106+
print("Muon Example: Optimizing a 2x2 matrix")
107+
108+
optimizer = Muon(learning_rate=0.05, momentum=0.9)
109+
weights = np.array([[1.0, 2.0], [3.0, 4.0]])
110+
111+
for step in range(10):
112+
gradients = 0.1 * weights # Simplified gradient
113+
weights = optimizer.update(0, weights, gradients)
114+
if step % 3 == 0:
115+
print(f"Step {step}: weights =\n{weights}")
116+
117+
print(f"Final weights:\n{weights}")

0 commit comments

Comments
 (0)