
Commit e7c0cae

Redefine lr_scheduler behavior (#1284)
## Context

The current warmup-stable-decay lr_scheduler behavior is not intuitive. For example, in `debug_model.toml`, the configuration is:

```
[lr_scheduler]
warmup_steps = 2  # lr scheduler warm up, normally 20% of the train steps
decay_ratio = 0.8  # lr scheduler decay ratio, 80% of the train steps
decay_type = "linear"
lr_min = 0.0
```

So we expect warmup_steps = 2, decay_steps = 8, total_steps = 10, and we get the current learning-rate curve (blue line). There are two issues:

1. The maximum learning rate never reaches 1 (the expected max is 1, since we are computing an adjustment ratio here).
2. Intuitively, the user would expect the learning rate to increase by 0.5 per step during warmup (assuming max_lr = 1) and to decrease by 1/8 per step during decay. In the blue line, however, the lr increases by 1/3 and decreases by 1/9, which is counter-intuitive.

Thus we propose a standard lr_scheduler behavior (the red line) that aligns with the user's intuition and with the meaning of the parameter names.

![learning_rate_schedule](https://github.com/user-attachments/assets/1c2be9e0-6043-4310-b09f-9b06a024abf9)

## Standard lr_scheduler behavior

- Warmup stage: LR increases by 1/{warmup_steps} per step.
- Stable stage: length of stable stage = total_train_steps + 1 - warmup_steps - decay_steps. We manually add one step to the stable stage (a fake step 11), which prevents the step-10 learning rate from dropping to 0 when decay is enabled.
- Decay stage: LR decreases by 1/{decay_steps} per step.
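Below is a minimal, self-contained sketch of the proposed per-step adjustment factor. The function name `wsd_adjustment` and the standalone form are illustrative only (not the torchtitan implementation), and it assumes linear decay with lr_min = 0:

```python
# Hypothetical sketch of the proposed warmup-stable-decay adjustment factor.
# `wsd_adjustment` is an illustrative name, not part of torchtitan.
def wsd_adjustment(step: int, warmup_steps: int, stable_steps: int, decay_steps: int) -> float:
    """Multiplier applied to the base LR at a 0-indexed `step` (linear decay, lr_min = 0)."""
    if step < warmup_steps:
        # warmup: 1/warmup_steps, 2/warmup_steps, ..., 1.0
        return (step + 1) / warmup_steps
    if step < warmup_steps + stable_steps:
        # stable: hold at the max LR
        return 1.0
    # decay: drop by 1/decay_steps per step
    k = step + 1 - (warmup_steps + stable_steps)
    return max(0.0, 1.0 - k / decay_steps)


# debug_model.toml example: training_steps = 10, warmup_steps = 2, decay_ratio = 0.8
# => decay_steps = 8, stable_steps = 10 + 1 - 2 - 8 = 1 (the virtual extra step),
# giving factors 0.5, 1.0, 1.0, 7/8, 6/8, ..., 1/8 over steps 0..9 (the red line).
```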
1 parent 31a5411 commit e7c0cae

File tree: 2 files changed, +296 −3 lines changed


tests/unit_tests/test_lr_scheduler.py

Lines changed: 280 additions & 0 deletions
@@ -0,0 +1,280 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest
from unittest.mock import MagicMock

import torch
from torch.optim import Adam

from torchtitan.components.lr_scheduler import build_lr_schedulers
from torchtitan.components.optimizer import OptimizersContainer


class TestLRScheduler(unittest.TestCase):
    def setUp(self):
        # Create a simple model with parameters
        self.model = torch.nn.Linear(10, 10)
        # Create an optimizer
        self.optimizer = Adam(self.model.parameters(), lr=0.1)
        # Create an optimizer container
        self.optimizer_container = MagicMock(spec=OptimizersContainer)
        self.optimizer_container.__iter__.return_value = iter([self.optimizer])
        self.optimizer_container.__len__.return_value = 1

    def create_job_config(
        self,
        training_steps=10,
        warmup_steps=None,
        decay_ratio=None,
        decay_type=None,
        lr_min=None,
    ):
        # Create a job config with the specified parameters
        from torchtitan.config_manager import ConfigManager

        args = [
            "--training.steps",
            str(training_steps),
        ]

        args += (
            ["--lr_scheduler.warmup_steps", str(warmup_steps)]
            if warmup_steps is not None
            else []
        )
        args += (
            ["--lr_scheduler.decay_ratio", str(decay_ratio)]
            if decay_ratio is not None
            else []
        )
        args += (
            ["--lr_scheduler.decay_type", decay_type] if decay_type is not None else []
        )
        args += ["--lr_scheduler.lr_min", str(lr_min)] if lr_min is not None else []

        config_manager = ConfigManager()
        # Create base config with parameters passed directly
        config = config_manager.parse_args(args)

        return config
    def test_linear_warmup_decay(self):
        """Test the linear warmup followed by linear decay schedule."""
        # Create a job config with 10 steps, 2 warmup steps, and linear decay
        config = self.create_job_config(
            training_steps=10,
            warmup_steps=2,
            decay_ratio=None,  # Use default decay: start decay immediately
            decay_type=None,
            lr_min=None,
        )

        # Build the lr scheduler
        lr_scheduler = build_lr_schedulers(self.optimizer_container, config)

        # Expected adjustment factors for each step
        expected_factors = [
            0.5,  # Step 0: 50% of max LR (warmup)
            1.0,  # Step 1: 100% of max LR (warmup complete)
            1.0,  # Step 2: manually added stable step, prevents LR from dropping to 0 at the last step
            7.0 / 8.0,  # Step 3: 7/8 of max LR
            6.0 / 8.0,  # Step 4: 3/4 of max LR
            5.0 / 8.0,  # Step 5: 5/8 of max LR
            4.0 / 8.0,  # Step 6: 1/2 of max LR
            3.0 / 8.0,  # Step 7: 3/8 of max LR
            2.0 / 8.0,  # Step 8: 1/4 of max LR
            1.0 / 8.0,  # Step 9: 1/8 of max LR
        ]

        # Check the learning rate at each step
        for i, factor in enumerate(expected_factors):
            # The LambdaLR multiplies the base lr by the factor
            expected_lr = 0.1 * factor
            self.assertAlmostEqual(
                self.optimizer.param_groups[0]["lr"],
                expected_lr,
                places=6,
                msg=f"Step {i}: Expected LR {expected_lr}, got {self.optimizer.param_groups[0]['lr']}",
            )
            lr_scheduler.step()
    def test_warmup_stable_decay(self):
        """Test warmup followed by stable phase and then decay."""
        # Create a job config with 10 steps, 2 warmup steps, 3 stable steps, and 5 decay steps
        config = self.create_job_config(
            training_steps=10,
            warmup_steps=2,
            decay_ratio=0.5,  # 50% of steps for decay
            decay_type="linear",
            lr_min=0.0,
        )

        # Build the lr scheduler
        lr_scheduler = build_lr_schedulers(self.optimizer_container, config)

        # Expected adjustment factors for each step
        expected_factors = [
            0.5,  # Step 0: 50% of max LR (warmup)
            1.0,  # Step 1: 100% of max LR (warmup complete)
            1.0,  # Step 2: Stable phase
            1.0,  # Step 3: Stable phase
            1.0,  # Step 4: Stable phase
            1.0,  # Step 5: manually added stable step, prevents LR from dropping to 0 at the last step
            0.8,  # Step 6: Linear decay starts (80% of max LR)
            0.6,  # Step 7: 60% of max LR
            0.4,  # Step 8: 40% of max LR
            0.2,  # Step 9: 20% of max LR
        ]

        # Check the learning rate at each step
        for i, factor in enumerate(expected_factors):
            expected_lr = 0.1 * factor
            self.assertAlmostEqual(
                self.optimizer.param_groups[0]["lr"],
                expected_lr,
                places=6,
                msg=f"Step {i}: Expected LR {expected_lr}, got {self.optimizer.param_groups[0]['lr']}",
            )
            lr_scheduler.step()
    def test_min_lr(self):
        """Test that the learning rate doesn't go below the minimum."""
        # Create a job config with a minimum learning rate
        config = self.create_job_config(
            training_steps=10,
            warmup_steps=2,
            decay_ratio=None,
            decay_type="linear",
            lr_min=0.2,  # 20% of base LR as minimum
        )

        # Build the lr scheduler
        lr_scheduler = build_lr_schedulers(self.optimizer_container, config)

        # Step through all steps
        for _ in range(10):
            lr_scheduler.step()

        # After all steps, LR should be at minimum (0.1 * 0.2 = 0.02)
        self.assertAlmostEqual(self.optimizer.param_groups[0]["lr"], 0.02, places=6)
    def test_warmup_exceeds_training(self):
        """Test when warmup steps exceed training steps."""
        # Create a job config where warmup steps > training steps
        config = self.create_job_config(
            training_steps=5,
            warmup_steps=10,  # More than training steps
            decay_ratio=None,
            decay_type="linear",
            lr_min=0.0,
        )

        # Build the lr scheduler - should adjust warmup steps
        lr_scheduler = build_lr_schedulers(self.optimizer_container, config)

        # Expected adjustment factors for each step (warmup is clamped to 5 steps)
        expected_factors = [
            0.2,  # Step 0: 20% of max LR (warmup)
            0.4,  # Step 1: 40% of max LR (warmup)
            0.6,  # Step 2: 60% of max LR (warmup)
            0.8,  # Step 3: 80% of max LR (warmup)
            1.0,  # Step 4: 100% of max LR (warmup complete)
        ]

        # Check the learning rate at each step
        for i, factor in enumerate(expected_factors):
            expected_lr = 0.1 * factor
            self.assertAlmostEqual(
                self.optimizer.param_groups[0]["lr"],
                expected_lr,
                places=6,
                msg=f"Step {i}: Expected LR {expected_lr}, got {self.optimizer.param_groups[0]['lr']}",
            )
            lr_scheduler.step()
    def test_warmup_stable_only(self):
        """Test warmup followed by stable phase only, with no decay phase."""
        # Create a job config with 10 steps, 2 warmup steps, and no decay phase
        config = self.create_job_config(
            training_steps=10,
            warmup_steps=2,
            decay_ratio=0.0,  # 0% of steps for decay (no decay)
            decay_type="linear",
            lr_min=0.0,
        )

        # Build the lr scheduler
        lr_scheduler = build_lr_schedulers(self.optimizer_container, config)

        # Expected adjustment factors for each step
        expected_factors = [
            0.5,  # Step 0: 50% of max LR (warmup)
            1.0,  # Step 1: 100% of max LR (warmup complete)
            1.0,  # Step 2: Stable phase (includes the manually added stable step)
            1.0,  # Step 3: Stable phase
            1.0,  # Step 4: Stable phase
            1.0,  # Step 5: Stable phase
            1.0,  # Step 6: Stable phase
            1.0,  # Step 7: Stable phase
            1.0,  # Step 8: Stable phase
            1.0,  # Step 9: Stable phase
        ]

        # Check the learning rate at each step
        for i, factor in enumerate(expected_factors):
            expected_lr = 0.1 * factor
            self.assertAlmostEqual(
                self.optimizer.param_groups[0]["lr"],
                expected_lr,
                places=6,
                msg=f"Step {i}: Expected LR {expected_lr}, got {self.optimizer.param_groups[0]['lr']}",
            )
            lr_scheduler.step()
    def test_warmup_plus_decay_exceeds_training(self):
        """Test when warmup + decay steps exceed training steps."""
        # Create a job config where warmup + decay steps > training steps
        # Expected behavior: warmup steps = 5, decay steps = 5
        config = self.create_job_config(
            training_steps=10,
            warmup_steps=5,
            decay_ratio=0.8,  # 80% of steps for decay (8 steps)
            decay_type="linear",
            lr_min=0.0,
        )

        # Build the lr scheduler - should adjust decay steps
        lr_scheduler = build_lr_schedulers(self.optimizer_container, config)

        # Expected adjustment factors for each step
        expected_factors = [
            0.2,  # Step 0: 20% of max LR (warmup)
            0.4,  # Step 1: 40% of max LR (warmup)
            0.6,  # Step 2: 60% of max LR (warmup)
            0.8,  # Step 3: 80% of max LR (warmup)
            1.0,  # Step 4: 100% of max LR (warmup complete)
            1.0,  # Step 5: manually added stable step, prevents LR from dropping to 0 at the last step
            0.8,  # Step 6: Linear decay starts (80% of max LR)
            0.6,  # Step 7: 60% of max LR
            0.4,  # Step 8: 40% of max LR
            0.2,  # Step 9: 20% of max LR
        ]

        # Check the learning rate at each step
        for i, factor in enumerate(expected_factors):
            expected_lr = 0.1 * factor
            self.assertAlmostEqual(
                self.optimizer.param_groups[0]["lr"],
                expected_lr,
                places=6,
                msg=f"Step {i}: Expected LR {expected_lr}, got {self.optimizer.param_groups[0]['lr']}",
            )
            lr_scheduler.step()


if __name__ == "__main__":
    unittest.main()

torchtitan/components/lr_scheduler.py

Lines changed: 16 additions & 3 deletions
@@ -102,6 +102,14 @@ def build_lr_schedulers(
     """
     training_steps = job_config.training.steps
     warmup_steps = int(job_config.lr_scheduler.warmup_steps)
+
+    if warmup_steps > training_steps:
+        logger.warning(
+            f"Warmup steps ({warmup_steps}) exceed total training steps ({training_steps}). "
+            f"Adjusting warmup steps to {training_steps}."
+        )
+        warmup_steps = training_steps
+
     if job_config.lr_scheduler.decay_ratio is not None:
         decay_steps = round(training_steps * job_config.lr_scheduler.decay_ratio)
         if warmup_steps + decay_steps > training_steps:
@@ -113,7 +121,8 @@ def build_lr_schedulers(
             decay_steps = training_steps - warmup_steps
     else:
         decay_steps = training_steps - warmup_steps
-    stable_steps = training_steps - warmup_steps - decay_steps
+    # Add a virtual last step to prevent the learning rate from dropping to 0
+    stable_steps = training_steps + 1 - warmup_steps - decay_steps
     lr_decay_type = job_config.lr_scheduler.decay_type
     lr_min = job_config.lr_scheduler.lr_min
 
@@ -146,13 +155,17 @@ def linear_warmup_stable_decay(
             # linear warmup
             # 0-indexed step, hence + 1 adjustments
             current_step += 1
-            curr_adjustment = float(current_step / (warmup_steps + 1))
+            assert (
+                warmup_steps != 0
+            ), "warmup_steps must not be zero to reach this branch"
+            curr_adjustment = float(current_step / warmup_steps)
         elif current_step < warmup_stable_steps:
             curr_adjustment = 1.0
         else:
             # 0-indexed step, hence + 1 adjustments
             current_step += 1
-            progress = float(current_step - warmup_stable_steps) / (decay_steps + 1)
+            assert decay_steps != 0, "decay_steps must not be zero to reach this branch"
+            progress = float(current_step - warmup_stable_steps) / decay_steps
 
             if lr_decay_type == "linear":
                 curr_adjustment = 1 - progress
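As a quick sanity check of the updated schedule, one can plug equivalent math into plain `torch.optim.lr_scheduler.LambdaLR`. This is a standalone sketch that mirrors the arithmetic in the diff above rather than calling the torchtitan builder; the helper `factor` is hypothetical and assumes linear decay with lr_min = 0:

```python
import torch
from torch.optim import Adam
from torch.optim.lr_scheduler import LambdaLR

# Same settings as test_warmup_stable_decay: 10 steps, 2 warmup, decay_ratio = 0.5
training_steps, warmup_steps, decay_ratio = 10, 2, 0.5
decay_steps = round(training_steps * decay_ratio)                # 5
stable_steps = training_steps + 1 - warmup_steps - decay_steps   # 4 (includes the virtual step)

def factor(step: int) -> float:
    # Mirrors the post-change warmup/stable/decay math (linear decay, lr_min = 0).
    if step < warmup_steps:
        return (step + 1) / warmup_steps
    if step < warmup_steps + stable_steps:
        return 1.0
    return 1.0 - (step + 1 - warmup_steps - stable_steps) / decay_steps

opt = Adam([torch.nn.Parameter(torch.zeros(1))], lr=0.1)
sched = LambdaLR(opt, lr_lambda=factor)

lrs = []
for _ in range(training_steps):
    lrs.append(round(opt.param_groups[0]["lr"], 6))
    opt.step()
    sched.step()

print(lrs)  # [0.05, 0.1, 0.1, 0.1, 0.1, 0.1, 0.08, 0.06, 0.04, 0.02]
```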
