Commit 3aa9361

[optim] Fix bug when default dtype is BF16 (#2286)
* handle error when default dtype is BF16
* skip FP8 optim on unsupported GPUs
1 parent bc68b11 commit 3aa9361
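
The root cause: `torch.tensor(...)` infers its dtype from `torch.get_default_dtype()` when given Python floats, so building the quantization lookup table (`qmap`) under a bfloat16 default silently produced a bfloat16 table. A minimal standalone sketch of the failure mode and of the fix this commit applies (illustrative values, not the repo's code):

```python
import torch

# torch.tensor() on Python floats follows the process-wide default dtype,
# so a bf16 default changes the dtype of any table built this way.
torch.set_default_dtype(torch.bfloat16)
qmap = torch.tensor([-1.0, 0.0, 1.0])
print(qmap.dtype)  # torch.bfloat16 -- the silently downcast table

# The fix in both subclass files: pin the qmap dtype explicitly.
qmap = torch.tensor([-1.0, 0.0, 1.0], dtype=torch.float32)
print(qmap.dtype)  # torch.float32, regardless of the default
```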

File tree

3 files changed: +30, -11 lines

test/test_low_bit_optim.py

Lines changed: 24 additions & 5 deletions
```diff
@@ -37,7 +37,6 @@
 from torchao.optim.subclass_fp8 import OptimStateFp8
 from torchao.testing.utils import skip_if_rocm
 from torchao.utils import (
-    TORCH_VERSION_AT_LEAST_2_4,
     TORCH_VERSION_AT_LEAST_2_5,
     TORCH_VERSION_AT_LEAST_2_7,
     get_available_devices,
@@ -128,8 +127,6 @@ class TestOptim(TestCase):
     @skip_if_rocm("ROCm enablement in progress")
     def test_optim_smoke(self, optim_name, dtype, device):
         if optim_name.endswith("Fp8") and device == "cuda":
-            if not TORCH_VERSION_AT_LEAST_2_4:
-                pytest.skip("FP8 CUDA requires PyTorch >= 2.4")
             if torch.cuda.get_device_capability() < (8, 9):
                 pytest.skip("FP8 CUDA requires compute capability >= 8.9")
 
@@ -166,6 +163,30 @@ def test_optim_smoke(self, optim_name, dtype, device):
         for p1, p2 in zip(model.parameters(), model2.parameters()):
             torch.testing.assert_close(p2, p1)
 
+    @parametrize("optim_name", ["Adam8bit", "Adam4bit", "AdamFp8"])
+    @parametrize("device", _DEVICES)
+    def test_optim_default_dtype_bf16(self, optim_name, device):
+        if optim_name.endswith("Fp8") and device == "cuda":
+            if torch.cuda.get_device_capability() < (8, 9):
+                pytest.skip("FP8 CUDA requires compute capability >= 8.9")
+
+        old_dtype = torch.get_default_dtype()
+        torch.set_default_dtype(torch.bfloat16)
+
+        try:
+            model = nn.Sequential(nn.Linear(32, 256), nn.ReLU(), nn.Linear(256, 32))
+            model.to(device=device)
+            optimizer = getattr(optim, optim_name)(model.parameters())
+
+            x = torch.randn(4, 32, device=device)
+            loss = model(x).sum()
+            loss.backward()
+            optimizer.step()
+            optimizer.zero_grad()
+
+        finally:
+            torch.set_default_dtype(old_dtype)
+
     # aten.slice is required for dcp.load() when world size changes i.e. re-sharding
     # however, it's cumbersome to test it directly, since we would need to run distributed
     # test 2 times with different world size, and persist checkpoint across the 2 runs.
@@ -178,8 +199,6 @@ def test_subclass_slice(self, subclass, shape, device):
         if subclass == OptimStateFp8:
             if device == "cpu" and len(shape) > 1 and not TORCH_VERSION_AT_LEAST_2_5:
                 pytest.skip("fill_cpu not implemented for Float8_e4m3fn for torch<2.5")
-            if device == "cuda" and not TORCH_VERSION_AT_LEAST_2_4:
-                pytest.skip("FP8 CUDA requires PyTorch >= 2.4")
             if device == "cuda" and torch.cuda.get_device_capability() < (8, 9):
                 pytest.skip("FP8 CUDA requires compute capability >= 8.9")
```
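With the `TORCH_VERSION_AT_LEAST_2_4` guard gone, the only remaining FP8 gate in these tests is the device-capability check; presumably the minimum supported PyTorch already covers 2.4 (an inference from the diff, not stated in the commit). A standalone sketch of that gate, using a hypothetical helper name that is not part of torchao:

```python
import torch

def fp8_cuda_supported() -> bool:
    # float8_e4m3fn CUDA kernels need compute capability 8.9+
    # (Ada Lovelace / Hopper or newer), which is what the tests skip on.
    return torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)
```

The new regression test can be selected directly with `pytest test/test_low_bit_optim.py -k test_optim_default_dtype_bf16`.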

torchao/optim/subclass_4bit.py

Lines changed: 3 additions & 3 deletions
```diff
@@ -69,6 +69,7 @@ def __init__(self, codes: Tensor, scale: Tensor, qmap: Tensor, signed: bool, sha
         assert codes.dtype is torch.uint8
         assert codes.ndim == 1  # flattened buffer
         assert scale.ndim == 1
+        assert qmap.dtype is torch.float32
         self.codes = codes
         self.scale = scale
         self.qmap = qmap
@@ -101,9 +102,8 @@ def zeros(cls, shape, signed: bool = True, block_size: int = 128, device=None):
 
         codes = torch.zeros(n_elems // 2, dtype=torch.uint8, device=device)
         scale = torch.zeros(n_elems // block_size, device=device)
-        qmap = torch.tensor(
-            get_qmap_signed() if signed else get_qmap_unsigned(), device=device
-        )
+        qmap_list = get_qmap_signed() if signed else get_qmap_unsigned()
+        qmap = torch.tensor(qmap_list, dtype=torch.float32, device=device)
         return cls(codes, scale, qmap, signed, shape)
 
     def __repr__(self):
```
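Pinning `qmap` to float32 matters because dequantization is a table lookup rescaled per block; if the table itself were stored in bfloat16, every reconstructed optimizer-state value would inherit that precision loss. An illustrative block-wise dequant sketch (not the repo's implementation; the function name is made up):

```python
import torch

def dequant_blockwise(codes: torch.Tensor, scale: torch.Tensor,
                      qmap: torch.Tensor, block_size: int) -> torch.Tensor:
    # Look each uint8 code up in the float32 table, then rescale per block.
    vals = qmap[codes.long()]
    return (vals.view(-1, block_size) * scale.view(-1, 1)).reshape(codes.shape)
```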

torchao/optim/subclass_8bit.py

Lines changed: 3 additions & 3 deletions
```diff
@@ -62,6 +62,7 @@ def __init__(self, codes: Tensor, scale: Tensor, qmap: Tensor, signed: bool):
         """
         assert codes.dtype is torch.uint8
         assert scale.ndim == 1
+        assert qmap.dtype is torch.float32
         self.codes = codes
         self.scale = scale
         self.qmap = qmap
@@ -89,9 +90,8 @@ def dequantize(self, output_dtype=None):
     def zeros(cls, shape, signed: bool = True, block_size: int = 256, device=None):
         codes = torch.zeros(shape, dtype=torch.uint8, device=device)
         scale = torch.zeros(codes.numel() // block_size, device=device)
-        qmap = torch.tensor(
-            get_qmap_signed() if signed else get_qmap_unsigned(), device=device
-        )
+        qmap_list = get_qmap_signed() if signed else get_qmap_unsigned()
+        qmap = torch.tensor(qmap_list, dtype=torch.float32, device=device)
         return cls(codes, scale, qmap, signed)
 
     def __repr__(self):
```
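A quick way to confirm the fixed behavior, assuming the class in `subclass_8bit.py` is named `OptimState8bit` (mirroring the `OptimStateFp8` import in the test file; the `zeros()` signature is taken from the diff above):

```python
import torch
from torchao.optim.subclass_8bit import OptimState8bit

# Even with a bfloat16 default dtype, zeros() now builds a float32 qmap,
# so the new assertion in __init__ holds.
torch.set_default_dtype(torch.bfloat16)
state = OptimState8bit.zeros((256,), signed=True, block_size=256, device="cpu")
assert state.qmap.dtype is torch.float32
```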
