@@ -56,11 +56,10 @@ class Kron(torch.optim.Optimizer):
             to set all preconditioners to be triangular, 'one_diag' sets the largest
             or last dim to be diagonal per layer, and 'all_diag' sets all preconditioners
             to be diagonal.
+        momentum_into_precond_update (bool): Whether to send momentum into the
+            preconditioner update instead of raw gradients.
         mu_dtype (torch.dtype, optional): Dtype of the momentum accumulator.
         precond_dtype (torch.dtype, optional): Dtype of the preconditioner.
-        trust_region_scale (float): Trust region on preconditioned grads. Normally this
-            doesn't need to be changed but if things seem unstable you can try reducing
-            this to 1.5.
     """

     def __init__(
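A minimal usage sketch for the new flag (the import path and hyperparameters below are assumptions for illustration, not part of this diff):

    import torch
    from kron_torch import Kron  # assumed import path

    model = torch.nn.Linear(10, 10)
    # Default: the bias-corrected momentum is fed into the preconditioner update.
    opt = Kron(model.parameters(), lr=3e-4, momentum_into_precond_update=True)
    # Pass False to feed raw gradients into the preconditioner update instead.
    opt_raw = Kron(model.parameters(), lr=3e-4, momentum_into_precond_update=False)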
@@ -73,9 +72,9 @@ def __init__(
         max_size_triangular=8192,
         min_ndim_triangular=2,
         memory_save_mode=None,
+        momentum_into_precond_update=True,
         mu_dtype=None,
         precond_dtype=None,
-        trust_region_scale=1.5,
     ):
         if not 0.0 <= lr:
             raise ValueError(f"Invalid learning rate: {lr}")
@@ -95,11 +94,11 @@ def __init__(
             max_size_triangular=max_size_triangular,
             min_ndim_triangular=min_ndim_triangular,
             memory_save_mode=memory_save_mode,
+            momentum_into_precond_update=momentum_into_precond_update,
             precond_lr=0.1,  # precond lr hardcoded to 0.1
             precond_init_scale=1.0,  # precond init scale hardcoded to 1.0
             mu_dtype=mu_dtype,
             precond_dtype=precond_dtype,
-            trust_region_scale=trust_region_scale,
         )
         super(Kron, self).__init__(params, defaults)

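Because the flag is stored in the defaults dict handed to super().__init__, it can also be overridden per parameter group; a hedged sketch (the model attribute names are hypothetical):

    opt = Kron(
        [
            {"params": model.backbone.parameters()},  # inherits the default (True)
            {"params": model.head.parameters(), "momentum_into_precond_update": False},
        ],
        lr=3e-4,
    )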
@@ -129,8 +128,11 @@ def step(self, closure=None):
         balance = self.rng.random() < 0.01 and do_update

         for group in self.param_groups:
-            precond_dtype = group.get("precond_dtype", torch.float32)
             mu_dtype = group.get("mu_dtype")
+            precond_dtype = group.get("precond_dtype", torch.float32)
+            momentum_into_precond_update = group.get(
+                "momentum_into_precond_update", True
+            )

             for p in group["params"]:
                 if p.grad is None:
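The group.get(..., True) fallback keeps behavior unchanged when a param group dict lacks the new key, for example one carried over from an older run; a tiny stand-in illustration (the dict below is not the optimizer's real state):

    old_group = {"lr": 3e-4, "mu_dtype": None}           # predates the new option
    old_group.get("momentum_into_precond_update", True)  # -> True, the prior behavior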
@@ -197,7 +199,7 @@ def step(self, closure=None):
                         state["Q"],
                         state["exprs"],
                         torch.randn_like(debiased_momentum, dtype=precond_dtype),
-                        debiased_momentum,
+                        debiased_momentum if momentum_into_precond_update else grad,
                         group["precond_lr"],
                         self._tiny,
                     )
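A standalone sketch of what the changed argument selects, assuming a standard bias-corrected EMA momentum buffer (variable names here are illustrative, not the optimizer's internals):

    import torch

    b1, step_count = 0.9, 10
    grad = torch.randn(32, 16)
    momentum_buffer = torch.zeros_like(grad)
    momentum_buffer.mul_(b1).add_(grad, alpha=1 - b1)           # EMA of gradients
    debiased_momentum = momentum_buffer / (1 - b1**step_count)  # bias-corrected estimate
    momentum_into_precond_update = True
    # The flag only changes which vector the preconditioner update sees:
    precond_target = debiased_momentum if momentum_into_precond_update else grad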
@@ -210,9 +212,8 @@ def step(self, closure=None):
                 trust_region_fn = lambda x: 0.1 * torch.sign(x) * torch.log(
                     torch.abs(x) + 1
                 ) + 0.9 * torch.tanh(x)
-                pre_grad = (
-                    trust_region_fn(pre_grad / group["trust_region_scale"])
-                    * group["trust_region_scale"]
+                pre_grad = torch.clip(
+                    trust_region_fn(pre_grad / 1.5) * 1.5, min=-2, max=2
                 )

                 # Apply weight decay and update parameters
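A self-contained sketch of the new trust-region mapping, with the previously configurable scale now fixed at 1.5 and a hard cap at ±2 (the helper name is hypothetical):

    import torch

    def soft_trust_region(x, scale=1.5):
        y = x / scale
        f = 0.1 * torch.sign(y) * torch.log(torch.abs(y) + 1) + 0.9 * torch.tanh(y)
        return torch.clip(f * scale, min=-2, max=2)

    x = torch.tensor([0.1, 1.0, 10.0, 1000.0])
    print(soft_trust_region(x))  # small entries pass through almost unchanged; outliers are compressed and capped

Near zero the map has slope one, so typical preconditioned gradient entries are barely touched; large entries grow only logarithmically and are clipped at ±2, which is consistent with the diff dropping the trust_region_scale hyperparameter.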