Skip to content

Commit 4ee2ee1

Browse files
authored
[optim] Fix low-bit optim when used with FSDP2+CPUOffload (#2195)
1 parent 8b96bcd commit 4ee2ee1

File tree

2 files changed

+23
-7
lines changed

2 files changed

+23
-7
lines changed

test/test_low_bit_optim.py

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,11 @@
1111
import pytest
1212
import torch
1313
from torch import nn
14-
from torch.distributed._composable.fsdp import fully_shard
14+
from torch.distributed._composable.fsdp import (
15+
fully_shard,
16+
CPUOffloadPolicy,
17+
OffloadPolicy,
18+
)
1519
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
1620
from torch.testing._internal.common_fsdp import FSDPTest
1721
from torch.testing._internal.common_utils import (
@@ -427,16 +431,21 @@ def world_size(self) -> int:
427431
@skip_if_lt_x_gpu(_FSDP_WORLD_SIZE)
428432
@skip_if_rocm("ROCm enablement in progress")
429433
def test_fsdp2(self):
430-
optim_classes = [optim.AdamW8bit, optim.AdamW4bit]
434+
# list (optimizer, offload policy) pairs explicitly to avoid testing every combination of the two
435+
args_list = [
436+
(optim.AdamW8bit, OffloadPolicy),
437+
(optim.AdamW4bit, OffloadPolicy),
438+
(optim.AdamW8bit, CPUOffloadPolicy),
439+
]
431440
if torch.cuda.get_device_capability() >= (8, 9):
432-
optim_classes.append(optim.AdamWFp8)
441+
args_list.append((optim.AdamWFp8, OffloadPolicy))
433442

434443
self.run_subtests(
435-
{"optim_cls": optim_classes},
444+
{"args": args_list},
436445
self._test_fsdp2,
437446
)
438447

439-
def _test_fsdp2(self, optim_cls):
448+
def _test_fsdp2(self, args):
440449
import torch.distributed as dist
441450
import torch.distributed.checkpoint as dcp
442451
import torch.utils._pytree as pytree
@@ -447,6 +456,8 @@ def _test_fsdp2(self, optim_cls):
447456
TransformerBlock,
448457
)
449458

459+
optim_cls, offload_policy = args
460+
450461
batch_size = 3
451462
vocab_size = 1024
452463
seq_len = 64
@@ -466,8 +477,8 @@ def _test_fsdp2(self, optim_cls):
466477
fsdp_model = copy.deepcopy(base_model)
467478
for m in fsdp_model.modules():
468479
if isinstance(m, TransformerBlock):
469-
fully_shard(m)
470-
fully_shard(fsdp_model)
480+
fully_shard(m, offload_policy=offload_policy)
481+
fully_shard(fsdp_model, offload_policy=offload_policy)
471482
fsdp_optim = optim_cls(fsdp_model.parameters(), lr=1e-2)
472483

473484
torch.manual_seed(42 + self.rank + 1)

torchao/optim/adam.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,11 @@ def _new_buffer(self, p: Tensor, signed: bool):
8383
stride=p.stride(),
8484
)
8585

86+
# when there is CPU offload, p.device is cpu, but device_mesh.device_type is cuda.
87+
# DTensor.from_local() will move local_tensor to device_mesh.device_type.
88+
# hence, we need to manually move it back to p.device (CPU when offloading).
89+
# https://github.com/pytorch/pytorch/blob/bc4cf1c1/torch/distributed/tensor/_api.py#L410-L415
90+
out = out.to(p.device)
8691
return out
8792

8893
@torch.no_grad()

0 commit comments

Comments
 (0)