
Commit ed6ec9c

Promote Low Bit Optim out of prototype (#1864)
* Remove prototype profiler
* Promote low bit optim out of prototype
* change name
* update
* move test file
* push
1 parent f64d5a1 commit ed6ec9c

File tree

15 files changed: +1614 additions, −51 deletions

README.md

Lines changed: 3 additions & 3 deletions
@@ -115,13 +115,13 @@ swap_linear_with_semi_sparse_linear(model, {"seq.0": SemiSparseLinear})
 ADAM takes 2x as much memory as the model params so we can quantize the optimizer state to either 8 or 4 bit effectively reducing the optimizer VRAM requirements by 2x or 4x respectively over an fp16 baseline

 ```python
-from torchao.prototype.low_bit_optim import AdamW8bit, AdamW4bit, AdamWFp8
+from torchao.optim import AdamW8bit, AdamW4bit, AdamWFp8
 optim = AdamW8bit(model.parameters()) # replace with Adam4bit and AdamFp8 for the 4 / fp8 versions
 ```

-In practice, we are a tiny bit slower than expertly written kernels but the implementations for these optimizers were written in a **few hundred lines of PyTorch code** and compiled so please use them or copy-paste them for your quantized optimizers. Benchmarks [here](https://github.com/pytorch/ao/tree/main/torchao/prototype/low_bit_optim)
+In practice, we are a tiny bit slower than expertly written kernels but the implementations for these optimizers were written in a **few hundred lines of PyTorch code** and compiled so please use them or copy-paste them for your quantized optimizers. Benchmarks [here](https://github.com/pytorch/ao/tree/main/torchao/optim)

-We also have support for [single GPU CPU offloading](https://github.com/pytorch/ao/tree/main/torchao/prototype/low_bit_optim#optimizer-cpu-offload) where both the gradients (same size as weights) and the optimizers will be efficiently sent to the CPU. This alone can **reduce your VRAM requirements by 60%**
+We also have support for [single GPU CPU offloading](https://github.com/pytorch/ao/tree/main/torchao/optim#optimizer-cpu-offload) where both the gradients (same size as weights) and the optimizers will be efficiently sent to the CPU. This alone can **reduce your VRAM requirements by 60%**

 ```python
 optim = CPUOffloadOptimizer(model.parameters(), torch.optim.AdamW, fused=True)
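
For context, a minimal end-to-end sketch of the promoted API under the new `torchao.optim` import path. The toy model, learning rate, and CUDA device are illustrative assumptions, not part of the diff:

```python
import torch
import torch.nn as nn
from torchao.optim import AdamW8bit  # new import path after this commit

# Illustrative toy model; the 8-bit optimizer state roughly halves optimizer
# VRAM compared to an fp16 baseline, as described in the README.
model = nn.Sequential(nn.Linear(1024, 1024), nn.ReLU(), nn.Linear(1024, 1024)).cuda()
optimizer = AdamW8bit(model.parameters(), lr=1e-3)

for _ in range(10):
    x = torch.randn(8, 1024, device="cuda")
    loss = model(x).sum()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
```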

benchmarks/benchmark_low_bit_adam.py

Lines changed: 6 additions & 8 deletions
@@ -34,7 +34,7 @@
 from torchvision.transforms import v2
 from tqdm import tqdm

-from torchao.prototype import low_bit_optim
+from torchao import optim
 from torchao.utils import get_available_devices

 _DEVICE = get_available_devices()[-1]
@@ -43,9 +43,9 @@
 OPTIM_MAP = dict(
     AdamW=partial(torch.optim.AdamW, fused=True),
     AdamW8bitBnb=bnb.optim.AdamW8bit,
-    AdamW8bitAo=low_bit_optim.AdamW8bit,
-    AdamWFp8Ao=low_bit_optim.AdamWFp8,
-    AdamW4bitAo=low_bit_optim.AdamW4bit,
+    AdamW8bitAo=optim.AdamW8bit,
+    AdamWFp8Ao=optim.AdamWFp8,
+    AdamW4bitAo=optim.AdamW4bit,
 )

 try:
@@ -249,12 +249,10 @@ def evaluate_model(model, args):
 optim_cls = OPTIM_MAP[args.optim]

 if args.optim_cpu_offload == "ao":
-    optim_cls = partial(
-        low_bit_optim.CPUOffloadOptimizer, optimizer_class=optim_cls
-    )
+    optim_cls = partial(optim.CPUOffloadOptimizer, optimizer_class=optim_cls)
 elif args.optim_cpu_offload == "ao_offload_grads":
     optim_cls = partial(
-        low_bit_optim.CPUOffloadOptimizer,
+        optim.CPUOffloadOptimizer,
         optimizer_class=optim_cls,
         offload_gradients=True,
     )
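
The `partial` wrapping above keeps the benchmark's call site unchanged regardless of offload mode. A hedged sketch of that composition under the new module path; the toy model and optimizer choice are illustrative, and CPU offload assumes the parameters live on GPU:

```python
from functools import partial

import torch
import torch.nn as nn
from torchao import optim

# Pick a base optimizer constructor, then optionally wrap it so optimizer state
# (and optionally gradients) is kept on CPU.
optim_cls = partial(torch.optim.AdamW, fused=True)
optim_cls = partial(optim.CPUOffloadOptimizer, optimizer_class=optim_cls)

model = nn.Linear(64, 64).cuda()  # illustrative model; params stay on GPU
optimizer = optim_cls(model.parameters())  # same call site as an unwrapped optimizer
```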

benchmarks/quantized_training/pretrain_llama2.py

Lines changed: 6 additions & 7 deletions
@@ -22,14 +22,13 @@
 from torch.utils.checkpoint import checkpoint
 from tqdm import tqdm

-from torchao import quantize_
+from torchao import optim, quantize_
 from torchao._models.llama.model import (
     ModelArgs,
     RMSNorm,
     Transformer,
     transformer_configs,
 )
-from torchao.prototype import low_bit_optim
 from torchao.prototype.quantized_training import (
     bitnet_training,
     int8_mixed_precision_training,
@@ -190,10 +189,10 @@ def insert_rmsnorm(module: torch.nn.Module):
 print(f"No. of buffers: {sum(p.numel() for p in model.buffers()):,}")
 torch.cuda.reset_peak_memory_stats() # don't count memory occupied by unquantized weights

-# only use optimizers from torchao.prototype.low_bit_optim to support quantized training
+# only use optimizers from torchao.optim to support quantized training
 if args.optim == "AdamW":
     args.optim = "_AdamW"
-optim = getattr(low_bit_optim, args.optim)(
+optimizer = getattr(optim, args.optim)(
     model.parameters(),
     lr=args.lr,
     weight_decay=args.weight_decay,
@@ -228,15 +227,15 @@ def insert_rmsnorm(module: torch.nn.Module):
 if step % args.log_interval == 0:
     log_dict = dict(
         loss=loss.item(),
-        lr=optim.param_groups[0]["lr"],
+        lr=optimizer.param_groups[0]["lr"],
         max_memory_allocated=torch.cuda.max_memory_allocated() / 1e9,
         max_memory_reserved=torch.cuda.max_memory_reserved() / 1e9,
     )
     run.log(log_dict, step=step)
     pbar.set_postfix(loss=log_dict["loss"])

-optim.step()
-optim.zero_grad()
+optimizer.step()
+optimizer.zero_grad()

 step += 1
 pbar.update()
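
Beyond the import change, the local variable is renamed from `optim` to `optimizer` because the script now imports the `optim` module itself. A tiny sketch of the shadowing this avoids; the model and hyperparameters are hypothetical:

```python
import torch.nn as nn
from torchao import optim

model = nn.Linear(8, 8)

# With the rename, the module name stays bound to the module:
optimizer = getattr(optim, "_AdamW")(model.parameters(), lr=1e-4)

# Had the local been called `optim`, the assignment would rebind that name to the
# optimizer instance, and any later getattr(optim, ...) lookup would break.
```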

test/prototype/test_low_bit_optim.py renamed to test/test_low_bit_optim.py

Lines changed: 23 additions & 29 deletions
@@ -17,15 +17,15 @@
 )

 from packaging.version import Version
-from torchao.prototype import low_bit_optim
-from torchao.prototype.low_bit_optim.quant_utils import (
+from torchao import optim
+from torchao.optim.quant_utils import (
     _fp32_to_bf16_sr,
     quantize_4bit_with_qmap,
     quantize_8bit_with_qmap,
 )
-from torchao.prototype.low_bit_optim.subclass_4bit import OptimState4bit
-from torchao.prototype.low_bit_optim.subclass_8bit import OptimState8bit
-from torchao.prototype.low_bit_optim.subclass_fp8 import OptimStateFp8
+from torchao.optim.subclass_4bit import OptimState4bit
+from torchao.optim.subclass_8bit import OptimState8bit
+from torchao.optim.subclass_fp8 import OptimStateFp8
 from torchao.testing.utils import skip_if_rocm
 from torchao.utils import (
     TORCH_VERSION_AT_LEAST_2_4,
@@ -125,29 +125,29 @@ def test_optim_smoke(self, optim_name, dtype, device):

 model = nn.Sequential(nn.Linear(32, 256), nn.ReLU(), nn.Linear(256, 32))
 model.to(device=device, dtype=dtype)
-optim = getattr(low_bit_optim, optim_name)(model.parameters())
+optimizer = getattr(optim, optim_name)(model.parameters())

 x = torch.randn(4, 32, device=device, dtype=dtype)
 loss = model(x).sum()
 loss.backward()
-optim.step()
-optim.zero_grad()
+optimizer.step()
+optimizer.zero_grad()

 # test serialization. also test the case CUDA optim loads CPU state dict
 with tempfile.NamedTemporaryFile() as f:
-    torch.save(optim.state_dict(), f.name)
+    torch.save(optimizer.state_dict(), f.name)
     state_dict = torch.load(f.name, map_location="cpu")

 model2 = copy.deepcopy(model)
-optim2 = getattr(low_bit_optim, optim_name)(model2.parameters())
+optim2 = getattr(optim, optim_name)(model2.parameters())
 optim2.load_state_dict(state_dict)

 for _ in range(2):
     x = torch.randn(4, 32, device=device, dtype=dtype)

     model(x).sum().backward()
-    optim.step()
-    optim.zero_grad()
+    optimizer.step()
+    optimizer.zero_grad()

     model2(x).sum().backward()
     optim2.step()
@@ -201,9 +201,7 @@ def test_optim_8bit_correctness(self, optim_name):
 block_size = 256 if Version(bnb.__version__) >= Version("0.44.0") else 2048

 optim1 = getattr(bnb.optim, optim_name)(model1.parameters())
-optim2 = getattr(low_bit_optim, optim_name)(
-    model2.parameters(), block_size=block_size
-)
+optim2 = getattr(optim, optim_name)(model2.parameters(), block_size=block_size)

 for _ in range(2):
     x = torch.randn(4, 32, device=device)
@@ -240,7 +238,7 @@ def test_optim_4bit_correctness(self, optim_name):
     optim1 = lpmm.optim.AdamW(model1.parameters())
 else:
     raise ValueError(f"Unsupported {optim_name} optimizer for lpmm")
-optim2 = getattr(low_bit_optim, optim_name)(model2.parameters())
+optim2 = getattr(optim, optim_name)(model2.parameters())

 for _ in range(2):
     x = torch.randn(4, 32, device=device)
@@ -286,7 +284,7 @@ def test_optim_cpu_offload_correctness(self, offload_grad, grad_accum):
 model2 = copy.deepcopy(model1)

 optim1 = torch.optim.AdamW(model1.parameters())
-optim2 = low_bit_optim.CPUOffloadOptimizer(
+optim2 = optim.CPUOffloadOptimizer(
     model2.parameters(),
     torch.optim.AdamW,
     offload_gradients=offload_grad,
@@ -335,9 +333,7 @@ def test_optim_cpu_offload_save_load(self):
     nn.Linear(32, 1024, bias=True), nn.ReLU(), nn.Linear(1024, 128, bias=True)
 )
 model1.to(device)
-optim1 = low_bit_optim.CPUOffloadOptimizer(
-    model1.parameters(), torch.optim.AdamW
-)
+optim1 = optim.CPUOffloadOptimizer(model1.parameters(), torch.optim.AdamW)

 for _ in range(2):
     x = torch.randn(4, 32, device=device)
@@ -352,9 +348,7 @@ def test_optim_cpu_offload_save_load(self):

 # resume training
 model2 = copy.deepcopy(model1)
-optim2 = low_bit_optim.CPUOffloadOptimizer(
-    model2.parameters(), torch.optim.AdamW
-)
+optim2 = optim.CPUOffloadOptimizer(model2.parameters(), torch.optim.AdamW)
 optim2.load_state_dict(state_dict)

 for _ in range(2):
@@ -381,7 +375,7 @@ def test_optim_bf16_stochastic_round_correctness(self):
 # small LR so that weight update is small
 # when bf16_stochastic_round=False, the test will fail after 1 iteration
 optim1 = torch.optim.AdamW(model1.parameters(), lr=1e-5)
-optim2 = low_bit_optim._AdamW(
+optim2 = optim._AdamW(
     model2.parameters(),
     lr=1e-5,
     bf16_stochastic_round=True,
@@ -424,9 +418,9 @@ def world_size(self) -> int:
 @skip_if_lt_x_gpu(_FSDP_WORLD_SIZE)
 @skip_if_rocm("ROCm enablement in progress")
 def test_fsdp2(self):
-    optim_classes = [low_bit_optim.AdamW8bit, low_bit_optim.AdamW4bit]
+    optim_classes = [optim.AdamW8bit, optim.AdamW4bit]
     if torch.cuda.get_device_capability() >= (8, 9):
-        optim_classes.append(low_bit_optim.AdamWFp8)
+        optim_classes.append(optim.AdamWFp8)

     self.run_subtests(
         {"optim_cls": optim_classes},
@@ -545,13 +539,13 @@ def test_uneven_shard(self):

 # currently all of our low-bit Adam/AdamW share the same implementation.
 # thus, we only need to test for 1 optimizer class.
-optim = low_bit_optim.AdamW8bit(model.parameters())
+optimizer = optim.AdamW8bit(model.parameters())

 for _ in range(2):
     inputs = torch.randn(2, in_dim, device="cuda")
     model(inputs).sum().backward()
-    optim.step()
-    optim.zero_grad()
+    optimizer.step()
+    optimizer.zero_grad()


 instantiate_parametrized_tests(TestQuantize)
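
The renamed test file keeps the same coverage. For reference, a condensed sketch of the CPU-offload save/load pattern it exercises, written against the new import path; it assumes a CUDA device, since `CPUOffloadOptimizer` keeps parameters on GPU while holding optimizer state on CPU:

```python
import copy

import torch
import torch.nn as nn
from torchao import optim

model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128)).cuda()
optim1 = optim.CPUOffloadOptimizer(model1.parameters(), torch.optim.AdamW)

for _ in range(2):
    model1(torch.randn(4, 32, device="cuda")).sum().backward()
    optim1.step()
    optim1.zero_grad()

state_dict = optim1.state_dict()

# resume training: a fresh wrapper reloads the saved optimizer state
model2 = copy.deepcopy(model1)
optim2 = optim.CPUOffloadOptimizer(model2.parameters(), torch.optim.AdamW)
optim2.load_state_dict(state_dict)
```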

torchao/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -48,11 +48,12 @@
     quantize_,
 )

-from . import dtypes, testing
+from . import dtypes, optim, testing

 __all__ = [
     "dtypes",
     "autoquant",
+    "optim",
     "quantize_",
     "testing",
     "ops",
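
With `optim` imported in `torchao/__init__.py` and listed in `__all__`, the optimizers are reachable straight from the top-level package. A small sketch, assuming a torchao build that contains this commit:

```python
import torchao

# Promoted out of prototype: no more torchao.prototype.low_bit_optim prefix.
print(torchao.optim.AdamW8bit)
print("optim" in torchao.__all__)  # True after this change
```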