
Commit 7c8301a

Make betas and weight_decay Adam(W) hyperparameters configurable (#1282)
This change exposes the `betas` and `weight_decay` hyperparameters for Adam(W) to the user by adding them to the `Optimizer` dataclass. The default behavior is unchanged, since I used the previously hard-coded values as the defaults.
1 parent b422463 · commit 7c8301a

2 files changed (+12, -2 lines)

torchtitan/components/optimizer.py

Lines changed: 5 additions & 2 deletions
@@ -267,7 +267,10 @@ def build_optimizers(
     )
     name = job_config.optimizer.name
     lr = job_config.optimizer.lr
+    beta1 = job_config.optimizer.beta1
+    beta2 = job_config.optimizer.beta2
     eps = job_config.optimizer.eps
+    weight_decay = job_config.optimizer.weight_decay

     optim_implementation = job_config.optimizer.implementation
     assert optim_implementation in ["fused", "foreach", "for-loop"]
@@ -277,9 +280,9 @@ def build_optimizers(

     optimizer_kwargs = {
         "lr": lr,
+        "betas": (beta1, beta2),
         "eps": eps,
-        "betas": (0.9, 0.95),
-        "weight_decay": 0.1,
+        "weight_decay": weight_decay,
         "fused": fused,
         "foreach": foreach,
     }
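
For context, a minimal sketch of where these values end up: the kwargs assembled above are handed straight to the PyTorch optimizer constructor. This is not the real `build_optimizers` (which also handles parameter groups, the fused/foreach/for-loop selection, and the optimizer container); the concrete numbers are just this PR's defaults.

    # Illustrative only -- not torchtitan's actual build_optimizers.
    # Shows how the newly configurable values flow into torch.optim.AdamW.
    import torch
    import torch.nn as nn

    model = nn.Linear(16, 16)  # stand-in for the real model parts

    lr, beta1, beta2 = 8e-4, 0.9, 0.95   # would come from job_config.optimizer
    eps, weight_decay = 1e-8, 0.1

    optimizer_kwargs = {
        "lr": lr,
        "betas": (beta1, beta2),
        "eps": eps,
        "weight_decay": weight_decay,
        "fused": False,    # torchtitan derives fused/foreach from optimizer.implementation
        "foreach": True,
    }
    optimizer = torch.optim.AdamW(model.parameters(), **optimizer_kwargs)

With the defaults above, this is behaviorally identical to the previous hard-coded `betas=(0.9, 0.95)` and `weight_decay=0.1`.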

torchtitan/config_manager.py

Lines changed: 7 additions & 0 deletions
@@ -126,9 +126,16 @@ class Optimizer:
     lr: float = 8e-4
     """Learning rate to use"""

+    beta1: float = 0.9
+    beta2: float = 0.95
+    """Exponential moving average hyperparameters to use"""
+
     eps: float = 1e-8
     """Epsilon value to use"""

+    weight_decay: float = 0.1
+    """Weight decay to use"""
+
     implementation: Literal["for-loop", "foreach", "fused"] = "fused"
     """
     Specify which optimizer implementation to use:
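
Put together, the user-facing slice of the config now looks roughly like the sketch below (only fields shown in the diff are reproduced; docstrings and unrelated fields such as `name` are omitted). In practice these values are populated by torchtitan's config manager from the job's config file or matching CLI overrides rather than constructed by hand, so the direct instantiation at the end is only for illustration.

    # Condensed view of the Optimizer config dataclass after this change.
    from dataclasses import dataclass
    from typing import Literal

    @dataclass
    class Optimizer:
        lr: float = 8e-4
        beta1: float = 0.9          # first EMA coefficient (previously hard-coded)
        beta2: float = 0.95         # second EMA coefficient (previously hard-coded)
        eps: float = 1e-8
        weight_decay: float = 0.1   # previously hard-coded
        implementation: Literal["for-loop", "foreach", "fused"] = "fused"

    # Overriding the formerly fixed values is now purely a configuration concern:
    cfg = Optimizer(beta2=0.999, weight_decay=0.01)
    print((cfg.beta1, cfg.beta2), cfg.weight_decay)  # (0.9, 0.999) 0.01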
