@@ -13,7 +13,7 @@
 
 """
 
-from typing import cast, List, Optional, Tuple, Union
+from typing import cast, Callable, List, Optional, Tuple, Union
 
 import torch
 from torch import Tensor
@@ -64,6 +64,7 @@ def __init__(
             lr: Union[float, Tensor] = 1e-3,
             betas: Tuple[float, float] = (0.9, 0.9999),
             eps: float = 1e-6,
+            clip_exp: Optional[float] = 0.333,
             weight_decay: float = 0.0,
             decoupled: bool = False,
             *,
@@ -95,6 +96,7 @@ def __init__(
             betas=betas,
             eps=eps,
             weight_decay=weight_decay,
+            clip_exp=clip_exp,
             decoupled=decoupled,
             maximize=maximize,
             foreach=foreach,
@@ -111,6 +113,7 @@ def __setstate__(self, state):
             group.setdefault("foreach", None)
             group.setdefault("capturable", False)
             group.setdefault("differentiable", False)
+            group.setdefault("clip_exp", None)
             for p in group["params"]:
                 p_state = self.state.get(p, [])
                 if len(p_state) != 0 and not torch.is_tensor(p_state["step"]):
@@ -141,9 +144,7 @@ def _init_group(
             has_complex |= torch.is_complex(p)
             params_with_grad.append(p)
             if p.grad.is_sparse:
-                raise RuntimeError(
-                    "ADOPT does not support sparse gradients"
-                )
+                raise RuntimeError("ADOPT does not support sparse gradients")
             grads.append(p.grad)
 
             state = self.state[p]
@@ -153,36 +154,24 @@ def _init_group(
                 # Deliberately host `step` on CPU if both capturable and fused are off.
                 # This is because kernel launches are costly on CUDA and XLA.
                 state["step"] = (
-                    torch.zeros(
-                        (),
-                        dtype=_get_scalar_dtype(),
-                        device=p.grad.device,
-                    )
+                    torch.zeros((), dtype=_get_scalar_dtype(), device=p.grad.device)
                     if group["capturable"]
                     else torch.tensor(0.0, dtype=_get_scalar_dtype())
                 )
                 # Exponential moving average of gradient values
-                state["exp_avg"] = torch.zeros_like(
-                    p.grad, memory_format=torch.preserve_format
-                )
+                state["exp_avg"] = torch.zeros_like(p.grad, memory_format=torch.preserve_format)
                 # Exponential moving average of squared gradient values
-                state["exp_avg_sq"] = torch.zeros_like(
-                    p.grad, memory_format=torch.preserve_format
-                )
+                state["exp_avg_sq"] = torch.zeros_like(p.grad, memory_format=torch.preserve_format)
 
             exp_avgs.append(state["exp_avg"])
             exp_avg_sqs.append(state["exp_avg_sq"])
 
             if group["differentiable"] and state["step"].requires_grad:
-                raise RuntimeError(
-                    "`requires_grad` is not supported for `step` in differentiable mode"
-                )
+                raise RuntimeError("`requires_grad` is not supported for `step` in differentiable mode")
 
             # Foreach without capturable does not support a tensor lr
             if group["foreach"] and torch.is_tensor(group["lr"]) and not group["capturable"]:
-                raise RuntimeError(
-                    "lr as a Tensor is not supported for capturable=False and foreach=True"
-                )
+                raise RuntimeError("lr as a Tensor is not supported for capturable=False and foreach=True")
 
             state_steps.append(state["step"])
         return has_complex
@@ -231,6 +220,7 @@ def step(self, closure=None):
                 beta2=beta2,
                 lr=group["lr"],
                 weight_decay=group["weight_decay"],
+                clip_exp=group["clip_exp"],
                 decoupled=group["decoupled"],
                 eps=group["eps"],
                 maximize=group["maximize"],
@@ -258,6 +248,7 @@ def _single_tensor_adopt(
         beta2: float,
         lr: Union[float, Tensor],
         weight_decay: float,
+        clip_exp: Optional[float],
         decoupled: bool,
         eps: float,
         maximize: bool,
@@ -282,20 +273,12 @@ def _single_tensor_adopt(
         if capturable and not _is_compiling():
             from torch.optim.optimizer import _get_capturable_supported_devices
             capturable_supported_devices = _get_capturable_supported_devices()
-            assert (
-                param.device.type == step_t.device.type
-                and param.device.type in capturable_supported_devices
-            ), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."
+            assert param.device.type == step_t.device.type and param.device.type in capturable_supported_devices, \
+                f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."
 
         # update step
         step_t += 1
 
-        if weight_decay != 0:
-            if decoupled:
-                param.add_(param, alpha=-lr * weight_decay)
-            else:
-                grad = grad.add(param, alpha=weight_decay)
-
         if torch.is_complex(param):
             grad = torch.view_as_real(grad)
             if exp_avg is not None:
@@ -304,17 +287,25 @@ def _single_tensor_adopt(
                 exp_avg_sq = torch.view_as_real(exp_avg_sq)
             param = torch.view_as_real(param)
 
+        if weight_decay != 0 and not decoupled:
+            grad = grad.add(param, alpha=weight_decay)
+
         step = step_t if capturable or differentiable else _get_value(step_t)
         if step == 1:
             exp_avg_sq.addcmul_(grad, grad.conj())
             continue
 
+        if weight_decay != 0 and decoupled:
+            param.add_(param, alpha=-lr * weight_decay)
+
         denom = torch.clamp(exp_avg_sq.sqrt(), eps)
-        if step == 2:
-            exp_avg.addcdiv_(grad, denom)
-        else:
-            exp_avg.mul_(beta1).addcdiv_(grad, denom, value=1 - beta1)
+        normed_grad = grad.div(denom)
+
+        if clip_exp is not None:
+            clip_val = (step - 1) ** clip_exp
+            normed_grad.clamp_(-clip_val, clip_val)
 
+        exp_avg.lerp_(normed_grad, 1 - beta1)
         param.add_(exp_avg, alpha=-lr)
 
         exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2)
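For orientation, a minimal standalone sketch (not part of the commit) of what the new per-tensor branch does once step > 1: normalize the gradient by the clamped second-moment estimate, optionally clip it to ±(step − 1)**clip_exp, blend it into exp_avg via lerp_, then apply the step. The toy tensors and hyperparameter values below are illustrative assumptions, not code from the diff.

import torch

# Toy reproduction of the clipped ADOPT update for one parameter tensor.
# In the real optimizer, exp_avg / exp_avg_sq live in per-parameter state.
param = torch.randn(4)
grad = torch.randn(4)
exp_avg = torch.zeros(4)
exp_avg_sq = torch.ones(4)   # assume step 1 already seeded this with grad * grad
step, lr, beta1, beta2, eps, clip_exp = 2, 1e-3, 0.9, 0.9999, 1e-6, 0.333

denom = torch.clamp(exp_avg_sq.sqrt(), eps)
normed_grad = grad.div(denom)
if clip_exp is not None:
    clip_val = (step - 1) ** clip_exp      # bound grows with step, so clipping relaxes over training
    normed_grad.clamp_(-clip_val, clip_val)

exp_avg.lerp_(normed_grad, 1 - beta1)      # EMA of the normalized, clipped gradient
param.add_(exp_avg, alpha=-lr)             # parameter update
exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2)

Note that the old `step == 2` special case (seeding exp_avg with the full normalized gradient) is gone: the EMA now always blends with weight 1 - beta1, and decoupled weight decay is applied only after the first step.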
@@ -334,6 +325,7 @@ def _multi_tensor_adopt(
         beta2: float,
         lr: Union[float, Tensor],
         weight_decay: float,
+        clip_exp: Optional[float],
         decoupled: bool,
         eps: float,
         maximize: bool,
@@ -355,8 +347,7 @@ def _multi_tensor_adopt(
             supports_xla=False
         )
         assert all(
-            p.device.type == step.device.type
-            and p.device.type in capturable_supported_devices
+            p.device.type == step.device.type and p.device.type in capturable_supported_devices
             for p, step in zip(params, state_steps)
         ), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."
 
@@ -382,9 +373,7 @@ def _multi_tensor_adopt(
 
         # Handle complex parameters
         if has_complex:
-            _view_as_real(
-                device_params, device_grads, device_exp_avgs, device_exp_avg_sqs
-            )
+            _view_as_real(device_params, device_grads, device_exp_avgs, device_exp_avg_sqs)
 
         if maximize:
             device_grads = torch._foreach_neg(device_grads)  # type: ignore[assignment]
@@ -394,44 +383,38 @@ def _multi_tensor_adopt(
         # and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just
         # wrapped it once now. The alpha is required to assure we go to the right overload.
         if not _is_compiling() and device_state_steps[0].is_cpu:
-            torch._foreach_add_(
-                device_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0
-            )
+            torch._foreach_add_(device_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0)
         else:
             torch._foreach_add_(device_state_steps, 1)
 
-        if weight_decay != 0:
-            if decoupled:
-                torch._foreach_add_(device_params, device_params, alpha=-lr * weight_decay)
+        if weight_decay != 0 and not decoupled:
+            # Re-use the intermediate memory (device_grads) already allocated for maximize
+            if maximize:
+                torch._foreach_add_(device_grads, device_params, alpha=weight_decay)
             else:
-                # Re-use the intermediate memory (device_grads) already allocated for maximize
-                if maximize:
-                    torch._foreach_add_(device_grads, device_params, alpha=weight_decay)
-                else:
-                    device_grads = torch._foreach_add(  # type: ignore[assignment]
-                        device_grads, device_params, alpha=weight_decay
-                    )
+                device_grads = torch._foreach_add(device_grads, device_params, alpha=weight_decay)
 
         if device_state_steps[0] == 1:
             torch._foreach_addcmul_(device_exp_avg_sqs, device_grads, device_grads)
             continue
 
+        if weight_decay != 0 and decoupled:
+            torch._foreach_add_(device_params, device_params, alpha=-lr * weight_decay)
+
         exp_avg_sq_sqrt = torch._foreach_sqrt(device_exp_avg_sqs)
-        exp_avg_sq_sqrt = torch._foreach_maximum(exp_avg_sq_sqrt, eps)
+        torch._foreach_maximum_(exp_avg_sq_sqrt, eps)
+        normed_grad = torch._foreach_div(device_grads, exp_avg_sq_sqrt)
 
-        if device_state_steps[0] == 2:
-            torch._foreach_addcdiv_(device_exp_avgs, device_grads, exp_avg_sq_sqrt)
-        else:
-            torch._foreach_mul_(device_exp_avgs, beta1)
-            torch._foreach_addcdiv_(
-                device_exp_avgs, device_grads, exp_avg_sq_sqrt, value=1 - beta1
-            )
+        if clip_exp is not None:
+            clip_val = (device_state_steps[0] - 1) ** clip_exp
+            torch._foreach_maximum_(normed_grad, -clip_val)
+            torch._foreach_minimum_(normed_grad, clip_val)
 
+        torch._foreach_lerp_(device_exp_avgs, normed_grad, 1 - beta1)
         torch._foreach_add_(device_params, device_exp_avgs, alpha=-lr)
+
         torch._foreach_mul_(device_exp_avg_sqs, beta2)
-        torch._foreach_addcmul_(
-            device_exp_avg_sqs, device_grads, device_grads, value=1 - beta2
-        )
+        torch._foreach_addcmul_(device_exp_avg_sqs, device_grads, device_grads, value=1 - beta2)
 
 
 #@_disable_dynamo_if_unsupported(single_tensor_fn=_single_tensor_adopt)  # FIXME internal context mgr, can't use
@@ -454,6 +437,7 @@ def adopt(
         beta2: float,
         lr: Union[float, Tensor],
         weight_decay: float,
+        clip_exp: Optional[float],
         decoupled: bool,
         eps: float,
         maximize: bool,
@@ -490,6 +474,7 @@ def adopt(
         beta2=beta2,
         lr=lr,
         weight_decay=weight_decay,
+        clip_exp=clip_exp,
         decoupled=decoupled,
         eps=eps,
         maximize=maximize,
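For completeness, a hedged usage sketch of the new argument. The import path and the Adopt class name below are assumptions based on where this module appears to live; they are not taken from the commit. Passing clip_exp=None disables the clipping branch, while the new constructor default is 0.333.

import torch
import torch.nn.functional as F
from timm.optim.adopt import Adopt  # assumed import path; adjust to your install

model = torch.nn.Linear(10, 2)
# clip_exp is the knob added by this commit; None skips the clamp entirely.
optimizer = Adopt(model.parameters(), lr=1e-3, betas=(0.9, 0.9999), clip_exp=0.333)

x, y = torch.randn(8, 10), torch.randn(8, 2)
loss = F.mse_loss(model(x), y)
loss.backward()
optimizer.step()
optimizer.zero_grad()

Older checkpoints restored through __setstate__ get clip_exp=None by default, so their update rule is unchanged unless the group is edited explicitly.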