def precond_update_prob_schedule(
-    max_prob=1.0, min_prob=0.03, decay=0.001, flat_start=250
+    max_prob=1.0, min_prob=0.03, decay=0.001, flat_start=500
):
    """Anneal preconditioner update probability during beginning of training.
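For context, a minimal sketch of what this schedule might compute, assuming it holds max_prob for flat_start steps and then decays exponentially toward min_prob; the actual formula lives outside this hunk, so treat the body below as illustrative only:

import math

def precond_update_prob_schedule_sketch(
    max_prob=1.0, min_prob=0.03, decay=0.001, flat_start=500
):
    # Hypothetical stand-in for the real schedule: flat at max_prob for
    # flat_start steps, then exponential decay, floored at min_prob.
    def _schedule(step):
        if step < flat_start:
            return max_prob
        return max(min_prob, max_prob * math.exp(-decay * (step - flat_start)))
    return _schedule

# sched = precond_update_prob_schedule_sketch()
# sched(0) -> 1.0; sched(10_000) -> 0.03 (hits the min_prob floor)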
@@ -45,6 +45,7 @@ class Kron(torch.optim.Optimizer):
            parameter groups.
        lr (float): Learning rate.
        b1 (float): Momentum parameter.
+       normalize_grads (bool): Whether to normalize incoming gradients layer-wise.
        weight_decay (float): Weight decay (L2 penalty).
        preconditioner_update_probability (callable or float, optional): Probability of
            updating the preconditioner. If None, defaults to a schedule that anneals
@@ -67,6 +68,7 @@ def __init__(
        params,
        lr=0.001,
        b1=0.9,
+       normalize_grads=False,
        weight_decay=0.0,
        preconditioner_update_probability=None,
        max_size_triangular=8192,
@@ -89,6 +91,7 @@ def __init__(
        defaults = dict(
            lr=lr,
            b1=b1,
+           normalize_grads=normalize_grads,
            weight_decay=weight_decay,
            preconditioner_update_probability=preconditioner_update_probability,
            max_size_triangular=max_size_triangular,
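As a usage note, the new flag is just another constructor keyword. A hedged example follows; the model, loss, and import path are placeholders, not part of this PR:

import torch
# assuming: from kron import Kron   (module name is a guess)

model = torch.nn.Linear(10, 10)  # placeholder model
opt = Kron(model.parameters(), lr=0.001, b1=0.9, normalize_grads=True)

loss = model(torch.randn(4, 10)).square().mean()
loss.backward()
opt.step()
opt.zero_grad()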
@@ -104,6 +107,7 @@ def __init__(

        self._tiny = torch.finfo(torch.bfloat16).tiny
        self._prob_step = 0
+       self._update_counter = 0
        self.rng = random.Random(5318008)

    @torch.no_grad()
@@ -118,13 +122,17 @@ def step(self, closure=None):
        total_precond_size = 0
        total_precond_mb = 0

-       # update preconditioners all together
+       # update preconditioners all together deterministically
        update_prob = self.param_groups[0]["preconditioner_update_probability"]
        if callable(update_prob):
            update_prob = update_prob(self._prob_step)
-       do_update = self.rng.random() < update_prob
+       self._update_counter += 1
+       do_update = self._update_counter >= 1 / update_prob
+       if do_update:
+           self._update_counter = 0
        self._prob_step += 1

+       # balance preconditioners roughly every 100 updates
        balance = self.rng.random() < 0.01 and do_update

        for group in self.param_groups:
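This replacement turns a Bernoulli draw into a fixed cadence with the same long-run frequency: with update_prob = 0.25 the counter reaches 1 / 0.25 = 4 on every fourth call, so exactly one update happens per four steps instead of one on average. A standalone sketch of the counter logic, with variable names mirroring the diff and the surrounding optimizer omitted:

def should_update(counter, update_prob):
    # Deterministic analogue of `rng.random() < update_prob`:
    # fire once every ceil(1 / update_prob) steps, then reset.
    counter += 1
    do_update = counter >= 1 / update_prob
    if do_update:
        counter = 0
    return counter, do_update

counter = 0
fired = []
for step in range(8):
    counter, do_update = should_update(counter, update_prob=0.25)
    fired.append(do_update)
# fired == [False, False, False, True, False, False, False, True]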
@@ -174,6 +182,9 @@ def step(self, closure=None):

                state["step"] += 1

+               if group["normalize_grads"]:
+                   grad /= torch.norm(grad) + 1e-12
+
                # Update momentum buffer
                beta = group["b1"]
                bias_correction = 1 - beta ** state["step"]
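The added normalization rescales each parameter's gradient tensor to roughly unit L2 norm, with the 1e-12 guard avoiding division by zero. A quick check of that operation in isolation, using a made-up gradient tensor:

import torch

grad = torch.randn(64, 128) * 3.7    # stand-in gradient tensor
grad /= torch.norm(grad) + 1e-12     # same operation as in the diff
print(torch.norm(grad))              # ~1.0 regardless of the original scale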
@@ -209,13 +220,6 @@ def step(self, closure=None):
                    state["Q"], state["exprs"], debiased_momentum
                ).to(dtype=p.dtype, non_blocking=True)

-                trust_region_fn = lambda x: 0.1 * torch.sign(x) * torch.log(
-                    torch.abs(x) + 1
-                ) + 0.9 * torch.tanh(x)
-                pre_grad = torch.clip(
-                    trust_region_fn(pre_grad / 1.5) * 1.5, min=-2, max=2
-                )
-
                # Apply weight decay and update parameters
                if group["weight_decay"] != 0 and p.dim() >= 2:
                    pre_grad.add_(p, alpha=group["weight_decay"])
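For reference on what the deleted block did: it squashed large preconditioned-gradient entries smoothly and hard-clipped the result to [-2, 2], so with it removed, pre_grad now reaches the weight-decay step unclipped. A small sketch evaluating the deleted mapping at a few sample values (the inputs are illustrative, not from the PR):

import torch

trust_region_fn = lambda x: 0.1 * torch.sign(x) * torch.log(
    torch.abs(x) + 1
) + 0.9 * torch.tanh(x)

x = torch.tensor([0.5, 10.0, 100.0])
y = torch.clip(trust_region_fn(x / 1.5) * 1.5, min=-2, max=2)
# y ≈ [0.48, 1.66, 1.98] -- extreme entries were softly bounded before removal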