Implementation adapted from https://github.com/sail-sg/Adan
"""
+ # Copyright 2022 Garena Online Private Limited
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.

import math
+ from typing import List, Tuple

import torch
+ from torch import Tensor
+ from torch.optim.optimizer import Optimizer

- from torch.optim import Optimizer
+
+ class MultiTensorApply(object):
+     # Thin helper that forwards a chunk size and no-op flag buffer to a multi-tensor op.
+     available = False
+     warned = False
+
+     def __init__(self, chunk_size):
+         try:
+             MultiTensorApply.available = True
+             self.chunk_size = chunk_size
+         except ImportError as err:
+             MultiTensorApply.available = False
+             MultiTensorApply.import_err = err
+
+     def __call__(self, op, noop_flag_buffer, tensor_lists, *args):
+         return op(self.chunk_size, noop_flag_buffer, tensor_lists, *args)


class Adan(Optimizer):
-     """
-     Implements a pytorch variant of Adan
-     Adan was proposed in
-     Adan: Adaptive Nesterov Momentum Algorithm for Faster Optimizing Deep Models[J]. arXiv preprint arXiv:2208.06677, 2022.
+     """Implements a PyTorch variant of Adan.
+
+     Adan was proposed in "Adan: Adaptive Nesterov Momentum Algorithm for Faster Optimizing Deep Models",
     https://arxiv.org/abs/2208.06677
+
     Arguments:
-         params (iterable): iterable of parameters to optimize or dicts defining parameter groups.
-         lr (float, optional): learning rate. (default: 1e-3)
-         betas (Tuple[float, float, flot], optional): coefficients used for computing
-             running averages of gradient and its norm. (default: (0.98, 0.92, 0.99))
-         eps (float, optional): term added to the denominator to improve
-             numerical stability. (default: 1e-8)
-         weight_decay (float, optional): decoupled weight decay (L2 penalty) (default: 0)
-         no_prox (bool): how to perform the decoupled weight decay (default: False)
+         params: Iterable of parameters to optimize or dicts defining parameter groups.
+         lr: Learning rate.
+         betas: Coefficients for the first-moment, gradient-difference, and second-moment running averages.
+         eps: Term added to the denominator to improve numerical stability.
+         weight_decay: Decoupled weight decay (L2 penalty).
+         no_prox: How to perform the decoupled weight decay: if True, scale the parameter by
+             1 - lr * weight_decay before the update; otherwise divide it by 1 + lr * weight_decay after.
+         foreach: If True, use the torch._foreach implementation; faster, but uses slightly more memory.
     """

-     def __init__(
-             self,
+     def __init__(self,
             params,
-             lr=1e-3,
-             betas=(0.98, 0.92, 0.99),
-             eps=1e-8,
-             weight_decay=0.0,
-             no_prox=False,
+             lr: float = 1e-3,
+             betas: Tuple[float, float, float] = (0.98, 0.92, 0.99),
+             eps: float = 1e-8,
+             weight_decay: float = 0.0,
+             no_prox: bool = False,
+             foreach: bool = True,
     ):
         if not 0.0 <= lr:
-             raise ValueError("Invalid learning rate: {}".format(lr))
+             raise ValueError('Invalid learning rate: {}'.format(lr))
         if not 0.0 <= eps:
-             raise ValueError("Invalid epsilon value: {}".format(eps))
+             raise ValueError('Invalid epsilon value: {}'.format(eps))
         if not 0.0 <= betas[0] < 1.0:
-             raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
+             raise ValueError('Invalid beta parameter at index 0: {}'.format(betas[0]))
         if not 0.0 <= betas[1] < 1.0:
-             raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
+             raise ValueError('Invalid beta parameter at index 1: {}'.format(betas[1]))
         if not 0.0 <= betas[2] < 1.0:
-             raise ValueError("Invalid beta parameter at index 2: {}".format(betas[2]))
-         defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, no_prox=no_prox)
-         super(Adan, self).__init__(params, defaults)
+             raise ValueError('Invalid beta parameter at index 2: {}'.format(betas[2]))
+
+         defaults = dict(
+             lr=lr,
+             betas=betas,
+             eps=eps,
+             weight_decay=weight_decay,
+             no_prox=no_prox,
+             foreach=foreach,
+         )
+         super().__init__(params, defaults)
+
+     def __setstate__(self, state):
+         super(Adan, self).__setstate__(state)
+         for group in self.param_groups:
+             group.setdefault('no_prox', False)

     @torch.no_grad()
     def restart_opt(self):
@@ -70,17 +112,23 @@ def restart_opt(self):

     @torch.no_grad()
     def step(self, closure=None):
-         """ Performs a single optimization step.
-         """
+         """Performs a single optimization step."""
         loss = None
         if closure is not None:
             with torch.enable_grad():
                 loss = closure()

         for group in self.param_groups:
+             params_with_grad = []
+             grads = []
+             exp_avgs = []
+             exp_avg_sqs = []
+             exp_avg_diffs = []
+             neg_pre_grads = []
+
             beta1, beta2, beta3 = group['betas']
             # assume same step across group now to simplify things
-             # per parameter step can be easily support by making it tensor, or pass list into kernel
+             # a per-parameter step can easily be supported by making it a tensor, or by passing a list into the kernel
             if 'step' in group:
                 group['step'] += 1
             else:
@@ -93,32 +141,155 @@ def step(self, closure=None):
             for p in group['params']:
                 if p.grad is None:
                     continue
-                 grad = p.grad
+                 params_with_grad.append(p)
+                 grads.append(p.grad)

                 state = self.state[p]
                 if len(state) == 0:
                     state['exp_avg'] = torch.zeros_like(p)
-                     state['exp_avg_diff'] = torch.zeros_like(p)
                     state['exp_avg_sq'] = torch.zeros_like(p)
-                     state['pre_grad'] = grad.clone()
+                     state['exp_avg_diff'] = torch.zeros_like(p)

-                 exp_avg, exp_avg_sq, exp_avg_diff = state['exp_avg'], state['exp_avg_diff'], state['exp_avg_sq']
-                 grad_diff = grad - state['pre_grad']
+                 if 'neg_pre_grad' not in state or group['step'] == 1:
+                     state['neg_pre_grad'] = -p.grad.clone()

-                 exp_avg.lerp_(grad, 1. - beta1)  # m_t
-                 exp_avg_diff.lerp_(grad_diff, 1. - beta2)  # diff_t (v)
-                 update = grad + beta2 * grad_diff
-                 exp_avg_sq.mul_(beta3).addcmul_(update, update, value=1. - beta3)  # n_t
+                 exp_avgs.append(state['exp_avg'])
+                 exp_avg_sqs.append(state['exp_avg_sq'])
+                 exp_avg_diffs.append(state['exp_avg_diff'])
+                 neg_pre_grads.append(state['neg_pre_grad'])

-                 denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction3)).add_(group['eps'])
-                 update = (exp_avg / bias_correction1 + beta2 * exp_avg_diff / bias_correction2).div_(denom)
-                 if group['no_prox']:
-                     p.data.mul_(1 - group['lr'] * group['weight_decay'])
-                     p.add_(update, alpha=-group['lr'])
-                 else:
-                     p.add_(update, alpha=-group['lr'])
-                     p.data.div_(1 + group['lr'] * group['weight_decay'])
+             if not params_with_grad:
+                 continue

-                 state['pre_grad'].copy_(grad)
+             kwargs = dict(
+                 params=params_with_grad,
+                 grads=grads,
+                 exp_avgs=exp_avgs,
+                 exp_avg_sqs=exp_avg_sqs,
+                 exp_avg_diffs=exp_avg_diffs,
+                 neg_pre_grads=neg_pre_grads,
+                 beta1=beta1,
+                 beta2=beta2,
+                 beta3=beta3,
+                 bias_correction1=bias_correction1,
+                 bias_correction2=bias_correction2,
+                 bias_correction3_sqrt=math.sqrt(bias_correction3),
+                 lr=group['lr'],
+                 weight_decay=group['weight_decay'],
+                 eps=group['eps'],
+                 no_prox=group['no_prox'],
+             )
+
+             if group['foreach']:
+                 _multi_tensor_adan(**kwargs)
+             else:
+                 _single_tensor_adan(**kwargs)

         return loss
+
+
+ def _single_tensor_adan(
+         params: List[Tensor],
+         grads: List[Tensor],
+         exp_avgs: List[Tensor],
+         exp_avg_sqs: List[Tensor],
+         exp_avg_diffs: List[Tensor],
+         neg_pre_grads: List[Tensor],
+         *,
+         beta1: float,
+         beta2: float,
+         beta3: float,
+         bias_correction1: float,
+         bias_correction2: float,
+         bias_correction3_sqrt: float,
+         lr: float,
+         weight_decay: float,
+         eps: float,
+         no_prox: bool,
+ ):
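+     # A paraphrased summary of the update applied below (Adan, https://arxiv.org/abs/2208.06677);
+     # bc1/bc2/bc3 are the bias corrections, g_t the current and g_{t-1} the previous gradient:
+     #     m_t = beta1 * m_{t-1} + (1 - beta1) * g_t                                  (exp_avg)
+     #     v_t = beta2 * v_{t-1} + (1 - beta2) * (g_t - g_{t-1})                      (exp_avg_diff)
+     #     n_t = beta3 * n_{t-1} + (1 - beta3) * (g_t + beta2 * (g_t - g_{t-1}))**2   (exp_avg_sq)
+     #     param -= lr * (m_t / bc1 + beta2 * v_t / bc2) / (sqrt(n_t) / sqrt(bc3) + eps)
+     # followed by the decoupled (no_prox=True) or proximal (no_prox=False) weight-decay step.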
+     for i, param in enumerate(params):
+         grad = grads[i]
+         exp_avg = exp_avgs[i]
+         exp_avg_sq = exp_avg_sqs[i]
+         exp_avg_diff = exp_avg_diffs[i]
+         neg_grad_or_diff = neg_pre_grads[i]
+
+         # for memory saving, reuse `neg_grad_or_diff` (holding -g_{t-1}) as an in-place temp buffer;
+         # after this add it holds the gradient difference g_t - g_{t-1}
+         neg_grad_or_diff.add_(grad)
+
+         exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)  # m_t
+         exp_avg_diff.mul_(beta2).add_(neg_grad_or_diff, alpha=1 - beta2)  # diff_t
+
+         neg_grad_or_diff.mul_(beta2).add_(grad)  # now holds g_t + beta2 * (g_t - g_{t-1})
+         exp_avg_sq.mul_(beta3).addcmul_(neg_grad_or_diff, neg_grad_or_diff, value=1 - beta3)  # n_t
+
+         denom = (exp_avg_sq.sqrt() / bias_correction3_sqrt).add_(eps)
+         step_size_diff = lr * beta2 / bias_correction2
+         step_size = lr / bias_correction1
+
+         if no_prox:
+             param.mul_(1 - lr * weight_decay)
+             param.addcdiv_(exp_avg, denom, value=-step_size)
+             param.addcdiv_(exp_avg_diff, denom, value=-step_size_diff)
+         else:
+             param.addcdiv_(exp_avg, denom, value=-step_size)
+             param.addcdiv_(exp_avg_diff, denom, value=-step_size_diff)
+             param.div_(1 + lr * weight_decay)
+
+         neg_grad_or_diff.zero_().add_(grad, alpha=-1.0)  # store -g_t for the next step
+
+
+ def _multi_tensor_adan(
+         params: List[Tensor],
+         grads: List[Tensor],
+         exp_avgs: List[Tensor],
+         exp_avg_sqs: List[Tensor],
+         exp_avg_diffs: List[Tensor],
+         neg_pre_grads: List[Tensor],
+         *,
+         beta1: float,
+         beta2: float,
+         beta3: float,
+         bias_correction1: float,
+         bias_correction2: float,
+         bias_correction3_sqrt: float,
+         lr: float,
+         weight_decay: float,
+         eps: float,
+         no_prox: bool,
+ ):
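+     # Note: this mirrors `_single_tensor_adan`, but batches each step across all tensors in the
+     # group via torch._foreach_* ops, trading a little extra memory (e.g. the `denom` list) for
+     # fewer Python-level loops and kernel launches.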
+     if len(params) == 0:
+         return
+
+     # for memory saving, reuse `neg_pre_grads` (holding -g_{t-1}) as in-place temp buffers;
+     # after this add they hold the gradient differences g_t - g_{t-1}
+     torch._foreach_add_(neg_pre_grads, grads)
+
+     torch._foreach_mul_(exp_avgs, beta1)
+     torch._foreach_add_(exp_avgs, grads, alpha=1 - beta1)  # m_t
+
+     torch._foreach_mul_(exp_avg_diffs, beta2)
+     torch._foreach_add_(exp_avg_diffs, neg_pre_grads, alpha=1 - beta2)  # diff_t
+
+     torch._foreach_mul_(neg_pre_grads, beta2)
+     torch._foreach_add_(neg_pre_grads, grads)  # now holds g_t + beta2 * (g_t - g_{t-1})
+     torch._foreach_mul_(exp_avg_sqs, beta3)
+     torch._foreach_addcmul_(exp_avg_sqs, neg_pre_grads, neg_pre_grads, value=1 - beta3)  # n_t
+
+     denom = torch._foreach_sqrt(exp_avg_sqs)
+     torch._foreach_div_(denom, bias_correction3_sqrt)
+     torch._foreach_add_(denom, eps)
+
+     step_size_diff = lr * beta2 / bias_correction2
+     step_size = lr / bias_correction1
+
+     if no_prox:
+         torch._foreach_mul_(params, 1 - lr * weight_decay)
+         torch._foreach_addcdiv_(params, exp_avgs, denom, value=-step_size)
+         torch._foreach_addcdiv_(params, exp_avg_diffs, denom, value=-step_size_diff)
+     else:
+         torch._foreach_addcdiv_(params, exp_avgs, denom, value=-step_size)
+         torch._foreach_addcdiv_(params, exp_avg_diffs, denom, value=-step_size_diff)
+         torch._foreach_div_(params, 1 + lr * weight_decay)
+
+     torch._foreach_zero_(neg_pre_grads)
+     torch._foreach_add_(neg_pre_grads, grads, alpha=-1.0)  # store -g_t for the next step