"""
Muon Optimizer

Implements the Muon optimizer for neural network hidden layers using NumPy.
Muon uses Newton-Schulz orthogonalization iterations for improved convergence.

Reference: https://kellerjordan.github.io/posts/muon/
Author: Adhithya Laxman Ravi Shankar Geetha
Date: 2025.10.21
"""

import numpy as np


class Muon:
    """
    Muon optimizer for hidden layer weight matrices.

    Applies Newton-Schulz orthogonalization to gradients before updates.
    """

    def __init__(
        self, learning_rate: float = 0.02, momentum: float = 0.95, ns_steps: int = 5
    ) -> None:
        """
        Initialize the Muon optimizer.

        Args:
            learning_rate (float): Learning rate for updates.
            momentum (float): Momentum factor.
            ns_steps (int): Number of Newton-Schulz iteration steps.

        >>> optimizer = Muon(learning_rate=0.02, momentum=0.95, ns_steps=5)
        >>> optimizer.momentum
        0.95
        """
        self.learning_rate = learning_rate
        self.momentum = momentum
        self.ns_steps = ns_steps
        self.velocity: dict[int, np.ndarray] = {}

    def newton_schulz_orthogonalize(self, matrix: np.ndarray) -> np.ndarray:
        """
        Orthogonalize a matrix using Newton-Schulz iterations.

        Args:
            matrix (np.ndarray): Input matrix.

        Returns:
            np.ndarray: Orthogonalized matrix.

        >>> optimizer = Muon()
        >>> mat = np.array([[1.0, 0.5], [0.5, 1.0]])
        >>> orth = optimizer.newton_schulz_orthogonalize(mat)
        >>> orth.shape
        (2, 2)
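
        The result should be close to orthogonal; the loose tolerance below
        is a sanity bound for the default five iterations, not a convergence
        guarantee:

        >>> bool(np.allclose(orth @ orth.T, np.eye(2), atol=1e-2))
        True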
        """
        # Work in tall orientation (rows >= cols); transpose back at the end
        if matrix.shape[0] < matrix.shape[1]:
            matrix = matrix.T
            transposed = True
        else:
            transposed = False

        # Normalize so all singular values are at most 1; the cubic
        # Newton-Schulz iteration only converges for singular values
        # in (0, sqrt(3))
        a = matrix / (np.linalg.norm(matrix) + 1e-7)
        for _ in range(self.ns_steps):
            a = 1.5 * a - 0.5 * a @ (a.T @ a)

        return a.T if transposed else a

    def update(
        self, param_id: int, params: np.ndarray, gradients: np.ndarray
    ) -> np.ndarray:
        """
        Update parameters using Muon.

        Args:
            param_id (int): Unique identifier for parameter group.
            params (np.ndarray): Current parameters.
            gradients (np.ndarray): Gradients of parameters.

        Returns:
            np.ndarray: Updated parameters.

        >>> optimizer = Muon(learning_rate=0.1, momentum=0.9)
        >>> params = np.array([[1.0, 2.0], [3.0, 4.0]])
        >>> grads = np.array([[0.1, 0.2], [0.3, 0.4]])
        >>> updated = optimizer.update(0, params, grads)
        >>> updated.shape
        (2, 2)
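
        A velocity buffer is created lazily for each new parameter id:

        >>> 0 in optimizer.velocity
        True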
        """
        # Lazily allocate a zero velocity buffer for unseen parameter ids
        if param_id not in self.velocity:
            self.velocity[param_id] = np.zeros_like(params)

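        # Note: the reference Muon applies Newton-Schulz to the momentum
        # buffer; this variant orthogonalizes the raw gradient first and
        # then accumulates momentum.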
        ortho_grad = self.newton_schulz_orthogonalize(gradients)
        self.velocity[param_id] = self.momentum * self.velocity[param_id] + ortho_grad

        return params - self.learning_rate * self.velocity[param_id]


# Usage example
if __name__ == "__main__":
    import doctest

    doctest.testmod()

    print("Muon Example: Optimizing a 2x2 matrix")

    optimizer = Muon(learning_rate=0.05, momentum=0.9)
    weights = np.array([[1.0, 2.0], [3.0, 4.0]])

    for step in range(10):
        gradients = 0.1 * weights  # Simplified gradient
        weights = optimizer.update(0, weights, gradients)
        if step % 3 == 0:
            print(f"Step {step}: weights =\n{weights}")

    print(f"Final weights:\n{weights}")
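
    # Sanity check (illustrative): after Frobenius normalization every
    # Newton-Schulz step pulls the singular values toward 1, so the deviation
    # from orthogonality should shrink as ns_steps grows.
    rng = np.random.default_rng(0)
    random_matrix = rng.standard_normal((4, 4))
    for steps in (5, 15):
        checker = Muon(ns_steps=steps)
        orth = checker.newton_schulz_orthogonalize(random_matrix)
        deviation = np.max(np.abs(orth @ orth.T - np.eye(4)))
        print(f"ns_steps={steps}: max deviation from orthogonality = {deviation:.2e}")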