Commit cda9685

Merge pull request #35 from bhavnicksm/main
Add AdaBelief; Update README
2 parents fa58a43 + c80dc6e commit cda9685

2 files changed, +67 -13 lines changed

README.md

Lines changed: 15 additions & 13 deletions
@@ -55,20 +55,22 @@ optimizer.step()
 
 # Supported Optimisers
 
-| Optimiser | Paper |
-|:---------: |:-----: |
-| **SGD** | https://paperswithcode.com/method/sgd |
-| **Momentum** | https://paperswithcode.com/method/sgd-with-momentum |
-| **NAG** | https://jlmelville.github.io/mize/nesterov.html |
+| Optimiser | Paper |
+|:---------: |:-----: |
+| **SGD** | https://paperswithcode.com/method/sgd |
+| **Momentum** | https://paperswithcode.com/method/sgd-with-momentum |
+| **NAG** | https://jlmelville.github.io/mize/nesterov.html |
 | **Adagrad** | https://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf |
-| **RMSProp** | https://paperswithcode.com/method/rmsprop |
-| **Adam** | https://arxiv.org/abs/1412.6980v9 |
-| **Adamax** | https://arxiv.org/abs/1412.6980v9 |
-| **AdamW** | https://arxiv.org/abs/1711.05101v3 |
-| **Adadelta** | https://arxiv.org/abs/1212.5701v1 |
-| **AMSGrad** | https://arxiv.org/abs/1904.09237v1 |
-| **RAdam** | https://arxiv.org/abs/1908.03265v4 |
-| **Lion** | https://arxiv.org/abs/2302.06675 |
+| **RMSProp** | https://paperswithcode.com/method/rmsprop |
+| **Adam** | https://arxiv.org/abs/1412.6980v9 |
+| **Adamax** | https://arxiv.org/abs/1412.6980v9 |
+| **AdamW** | https://arxiv.org/abs/1711.05101v3 |
+| **Adadelta** | https://arxiv.org/abs/1212.5701v1 |
+| **AMSGrad** | https://arxiv.org/abs/1904.09237v1 |
+| **RAdam** | https://arxiv.org/abs/1908.03265v4 |
+| **Lion** | https://arxiv.org/abs/2302.06675 |
+| **AdaBelief**| https://arxiv.org/pdf/2010.07468v5.pdf |
+| **NAdam** | http://cs229.stanford.edu/proj2015/054_report.pdf |
 
 # Acknowledgements
 

src/nadir/adabelief.py

Lines changed: 52 additions & 0 deletions
@@ -0,0 +1,52 @@
+### Copyright 2023 [Dawn Of Eve]
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Dict, Any, Optional
+from dataclasses import dataclass
+
+import torch
+
+from .adam import Adam, AdamConfig
+
+__all__ = ['Adabelief', 'AdabeliefConfig']
+
+@dataclass
+class AdabeliefConfig(AdamConfig):
+    lr : float = 3E-4
+    nesterov : bool = True
+
+class Adabelief(Adam):
+    def __init__ (self, params, config : AdabeliefConfig = AdabeliefConfig()):
+        super().__init__(params, config)
+        self.config = config
+
+    @Adam.amsgrad
+    def adaptivity(self,
+                   state,
+                   grad):
+
+        step = state['step']
+        v = state['adaptivity']
+        m = state['momentum']
+        beta_2 = self.config.beta_2
+        bias_correction = self.config.bias_correction
+
+        v.mul_(beta_2).addcmul_(grad - m, grad - m, value = (1 - beta_2))
+
+        if bias_correction:
+            v_hat = v.div(1 - beta_2**(step + 1))
+        else:
+            v_hat = v
+
+        state['adaptivity'] = v
+        return torch.sqrt(v_hat + self.config.eps)
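
For readers checking the arithmetic in `adaptivity`, here is a minimal scalar sketch of the same update in plain Python. It is not part of the commit. The function name `adabelief_denominator` and the default values for `beta_2` and `eps` are assumptions chosen for illustration. The sketch also assumes that `state['momentum']` already reflects the current gradient and that `state['step']` is zero-based, which the `step + 1` exponent suggests but the diff does not confirm.

```python
import math


def adabelief_denominator(grad, m, v, step,
                          beta_2=0.999, eps=1e-8, bias_correction=True):
    """Scalar analogue of the `adaptivity` method in the diff (illustrative only).

    grad: current gradient
    m:    first-moment (momentum) estimate, assumed already updated
    v:    running average of the squared residual (grad - m)**2
    step: zero-based step counter (assumption)
    Returns the updated v and the denominator sqrt(v_hat + eps).
    """
    residual = grad - m
    # Exponential moving average of the squared residual,
    # mirroring v.mul_(beta_2).addcmul_(grad - m, grad - m, value=1 - beta_2).
    v = beta_2 * v + (1.0 - beta_2) * residual * residual

    # Optional bias correction, using step + 1 as in the diff.
    if bias_correction:
        v_hat = v / (1.0 - beta_2 ** (step + 1))
    else:
        v_hat = v

    # eps is added inside the square root, as in the diff.
    return v, math.sqrt(v_hat + eps)
```

When the gradient equals the momentum estimate, the residual is zero and the denominator falls to `sqrt(eps)`; Adam's second moment would instead accumulate `grad * grad` at that point, which is the only difference this method introduces over its parent class.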
