+ # Copyright 2023 [Dawn Of Eve]
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+
+ #     http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
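+
+ # AdaBelief ("AdaBelief Optimizer: Adapting Stepsizes by the Belief in
+ # Observed Gradients", Zhuang et al., NeurIPS 2020) adapts the step size by
+ # the variance of the gradient around its EMA (the "belief") rather than by
+ # the raw second moment of the gradient.
+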
+ from dataclasses import dataclass
+
+ import torch
+
+ from .adam import Adam, AdamConfig
+
+ __all__ = ['Adabelief', 'AdabeliefConfig']
+
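+ # Extends AdamConfig; only these defaults are Adabelief-specific.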
+ @dataclass
+ class AdabeliefConfig(AdamConfig):
+     lr: float = 3e-4
+     nesterov: bool = True
+
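+ # Adabelief differs from Adam only in the adaptivity term: instead of an EMA
+ # of grad**2 it tracks an EMA of (grad - momentum)**2, i.e. the variance of
+ # the gradient around its running mean.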
+ class Adabelief(Adam):
+     def __init__(self, params, config: AdabeliefConfig = AdabeliefConfig()):
+         super().__init__(params, config)
+         self.config = config
+
+     @Adam.amsgrad
+     def adaptivity(self, state, grad):
+         step = state['step']
+         v = state['adaptivity']
+         m = state['momentum']
+         beta_2 = self.config.beta_2
+         bias_correction = self.config.bias_correction
+
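+         # Belief update, in place on the state tensor:
+         #   v_t = beta_2 * v_{t-1} + (1 - beta_2) * (g_t - m_t)**2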
+         v.mul_(beta_2).addcmul_(grad - m, grad - m, value=(1 - beta_2))
+
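+         # Adam-style bias correction; `step` is assumed zero-based, hence
+         # the (step + 1) exponent.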
+         if bias_correction:
+             v_hat = v.div(1 - beta_2 ** (step + 1))
+         else:
+             v_hat = v
+
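+         # v was updated in place above, so this write-back is redundant but
+         # keeps the state mutation explicit.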
+         state['adaptivity'] = v
+         return torch.sqrt(v_hat + self.config.eps)
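+
+ # Minimal usage sketch (an assumption, not part of this module: it presumes
+ # the Adam base class implements the standard torch.optim.Optimizer
+ # interface, so `step()` and `zero_grad()` are inherited):
+ #
+ #     model = torch.nn.Linear(10, 1)
+ #     opt = Adabelief(model.parameters(), AdabeliefConfig(lr=1e-3))
+ #     loss = model(torch.randn(8, 10)).pow(2).mean()
+ #     loss.backward()
+ #     opt.step()
+ #     opt.zero_grad()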