#!/usr/bin/env python
# -*- coding: utf-8 -*-
import math
import os

import torch
import torch.distributed as dist
import torch.nn as nn
from torch import Tensor
def zeropower_via_newtonschulz5(G: Tensor, steps: int) -> Tensor:
    """
    Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
    quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
    of minimizing steps, it turns out to be empirically effective to keep increasing the slope at
    zero even beyond the point where the iteration no longer converges all the way to one everywhere
    on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T
    where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model
    performance at all relative to UV^T, where USV^T = G is the SVD.
    """
    assert G.ndim >= 2  # batched Muon implementation by @scottjmaddox, and put into practice in the record by @YouJiacheng
    a, b, c = (3.4445, -4.7750, 2.0315)
    X = G.bfloat16()
    if G.size(-2) > G.size(-1):
        X = X.mT

    # Ensure spectral norm is at most 1
    X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7)
    # Perform the NS iterations
    for _ in range(steps):
        A = X @ X.mT
        B = b * A + c * A @ A  # quintic computation strategy adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng
        X = a * X + B @ X

    if G.size(-2) > G.size(-1):
        X = X.mT
    return X
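

# Illustrative sanity check (a hypothetical helper, not used elsewhere): for a full-rank square
# input, the singular values of the returned X should land roughly in the (0.5, 1.5) band that the
# docstring above describes, i.e. X is approximately orthogonal.
def _demo_newtonschulz_orthogonalization(n: int = 64, steps: int = 5) -> None:
    G = torch.randn(n, n)
    X = zeropower_via_newtonschulz5(G, steps=steps).float()
    S = torch.linalg.svdvals(X)
    print(f"singular values after {steps} NS steps: "
          f"min={S.min().item():.3f}, max={S.max().item():.3f}")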


class Muon(torch.optim.Optimizer):
    """
    Adam-style optimizer whose matrix-shaped updates are orthogonalized with the
    Newton-Schulz iteration above before being applied.
    """
    def __init__(self, params, lr=0.001, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, ns_steps=5):
        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, ns_steps=ns_steps)
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        """
        Performs a single optimization step.

        Args:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            # Re-enable grad so the closure can run a backward pass despite @torch.no_grad()
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad
                state = self.state[p]

                # Initialize state
                if len(state) == 0:
                    state['step'] = 0
                    state['exp_avg'] = torch.zeros_like(p)
                    state['exp_avg_sq'] = torch.zeros_like(p)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                beta1, beta2 = group['betas']

                state['step'] += 1
                bias_correction1 = 1 - beta1 ** state['step']
                bias_correction2 = 1 - beta2 ** state['step']

                # Update momentum and squared gradient
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

                # Compute the update
                denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])
                step_size = group['lr'] / bias_correction1

                # Orthogonalize the update for matrix-shaped parameters; 1D parameters
                # (biases, norms) keep the plain Adam update
                update = exp_avg / denom
                if update.ndim >= 2:
                    update = zeropower_via_newtonschulz5(update, steps=group['ns_steps'])

                # Apply the update
                p.add_(update, alpha=-step_size)

                # Apply decoupled weight decay: p <- p - lr * weight_decay * p
                if group['weight_decay'] != 0:
                    p.add_(p, alpha=-group['lr'] * group['weight_decay'])

        return loss
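

# Minimal usage sketch; the tiny MLP, batch shapes, and hyperparameters below are hypothetical
# stand-ins chosen only to illustrate the optimizer's interface.
if __name__ == "__main__":
    model = nn.Sequential(nn.Linear(32, 64), nn.ReLU(), nn.Linear(64, 10))
    opt = Muon(model.parameters(), lr=1e-3, weight_decay=0.01, ns_steps=5)

    x = torch.randn(16, 32)               # dummy input batch
    target = torch.randint(0, 10, (16,))  # dummy class labels
    loss = nn.functional.cross_entropy(model(x), target)
    loss.backward()
    opt.step()        # 2D weight matrices receive orthogonalized updates; 1D biases use plain Adam steps
    opt.zero_grad()
    print(f"initial loss: {loss.item():.4f}")

    _demo_newtonschulz_orthogonalization()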