@@ -37,7 +37,6 @@
 from torchao.optim.subclass_fp8 import OptimStateFp8
 from torchao.testing.utils import skip_if_rocm
 from torchao.utils import (
-    TORCH_VERSION_AT_LEAST_2_4,
     TORCH_VERSION_AT_LEAST_2_5,
     TORCH_VERSION_AT_LEAST_2_7,
     get_available_devices,
@@ -128,8 +127,6 @@ class TestOptim(TestCase):
     @skip_if_rocm("ROCm enablement in progress")
     def test_optim_smoke(self, optim_name, dtype, device):
         if optim_name.endswith("Fp8") and device == "cuda":
-            if not TORCH_VERSION_AT_LEAST_2_4:
-                pytest.skip("FP8 CUDA requires PyTorch >= 2.4")
             if torch.cuda.get_device_capability() < (8, 9):
                 pytest.skip("FP8 CUDA requires compute capability >= 8.9")

@@ -166,6 +163,30 @@ def test_optim_smoke(self, optim_name, dtype, device):
         for p1, p2 in zip(model.parameters(), model2.parameters()):
             torch.testing.assert_close(p2, p1)

+    @parametrize("optim_name", ["Adam8bit", "Adam4bit", "AdamFp8"])
+    @parametrize("device", _DEVICES)
+    def test_optim_default_dtype_bf16(self, optim_name, device):
+        if optim_name.endswith("Fp8") and device == "cuda":
+            if torch.cuda.get_device_capability() < (8, 9):
+                pytest.skip("FP8 CUDA requires compute capability >= 8.9")
+
+        old_dtype = torch.get_default_dtype()
+        torch.set_default_dtype(torch.bfloat16)
+
+        try:
+            model = nn.Sequential(nn.Linear(32, 256), nn.ReLU(), nn.Linear(256, 32))
+            model.to(device=device)
+            optimizer = getattr(optim, optim_name)(model.parameters())
+
+            x = torch.randn(4, 32, device=device)
+            loss = model(x).sum()
+            loss.backward()
+            optimizer.step()
+            optimizer.zero_grad()
+
+        finally:
+            torch.set_default_dtype(old_dtype)
+
     # aten.slice is required for dcp.load() when world size changes i.e. re-sharding
     # however, it's cumbersome to test it directly, since we would need to run distributed
     # test 2 times with different world size, and persist checkpoint across the 2 runs.
@@ -178,8 +199,6 @@ def test_subclass_slice(self, subclass, shape, device):
         if subclass == OptimStateFp8:
            if device == "cpu" and len(shape) > 1 and not TORCH_VERSION_AT_LEAST_2_5:
                 pytest.skip("fill_cpu not implemented for Float8_e4m3fn for torch<2.5")
-            if device == "cuda" and not TORCH_VERSION_AT_LEAST_2_4:
-                pytest.skip("FP8 CUDA requires PyTorch >= 2.4")
             if device == "cuda" and torch.cuda.get_device_capability() < (8, 9):
                 pytest.skip("FP8 CUDA requires compute capability >= 8.9")

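For reference, a minimal standalone sketch of the scenario the new `test_optim_default_dtype_bf16` test exercises: constructing a torchao low-bit optimizer while `torch.bfloat16` is the global default dtype. It mirrors the test body rather than adding new behavior, and it assumes `AdamFp8` is importable from `torchao.optim` (the module the test resolves optimizer names against via `getattr(optim, optim_name)`).

```python
import torch
import torch.nn as nn
from torchao.optim import AdamFp8  # assumption: same class the test selects by name

device = "cuda" if torch.cuda.is_available() else "cpu"

old_dtype = torch.get_default_dtype()
torch.set_default_dtype(torch.bfloat16)  # new tensors default to bf16 from here on
try:
    # Model parameters are created in bf16 because of the default dtype above.
    model = nn.Sequential(nn.Linear(32, 256), nn.ReLU(), nn.Linear(256, 32)).to(device)

    # FP8 optimizer state; per the test, FP8 on CUDA needs compute capability >= 8.9.
    optimizer = AdamFp8(model.parameters())

    loss = model(torch.randn(4, 32, device=device)).sum()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
finally:
    torch.set_default_dtype(old_dtype)  # restore the global default, as the test does
```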