@@ -14,7 +14,7 @@
 import torch
 import torch.nn as nn
 
-from torchao.utils import TORCH_VERSION_AT_LEAST_2_5
+from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, is_sm_89
 
 if not TORCH_VERSION_AT_LEAST_2_5:
     pytest.skip("Unsupported PyTorch version", allow_module_level=True)
@@ -531,6 +531,21 @@ def test_inference_mode(self):
         with torch.inference_mode(mode=True):
             m(x)
 
+    @unittest.skipIf(not is_sm_89(), "CUDA arch 8.9 not available")
+    def test_quantize(self):
+        x = torch.randn(32, 32, device="cuda")
+        m = nn.Sequential(nn.Linear(32, 32)).cuda()
+        m = convert_to_float8_training(m)
+        assert isinstance(m[0], Float8Linear), "Module is not a Float8Linear"
+        from torchao.quantization.quant_api import float8_weight_only, quantize_
+
+        quantize_(m, float8_weight_only())
+        assert (
+            m[0].weight.tensor_impl.float8_data.dtype == torch.float8_e4m3fn
+        ), "Post quantization dtype should be torch.float8_e4m3fn"
+        with torch.no_grad():
+            m(x)
+
 
 class TestScaledMM:
     @unittest.skipIf(
@@ -576,7 +591,7 @@ def test_scaled_mm_vs_emulated(self, base_dtype, use_fast_accum):
         if base_dtype in {torch.bfloat16, torch.float16}:
             atol, rtol = 7e-2, 7e-2
         else:
-            atol, rtol = 2e-3, 2e-3
+            atol, rtol = 3e-3, 3e-3
         torch.testing.assert_close(out_scaled_mm, out_emulated, atol=atol, rtol=rtol)
 
     @unittest.skipIf(not is_cuda_8_9, "CUDA not available")
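
For reference, the flow the new test_quantize case exercises — converting a module for float8 training and then applying weight-only float8 quantization — can be sketched as a standalone script. This is a minimal sketch, not part of the PR; it assumes a CUDA device with compute capability 8.9+, assumes convert_to_float8_training is importable from torchao.float8, and uses only the APIs already named in the diff (convert_to_float8_training, quantize_, float8_weight_only).

import torch
import torch.nn as nn

# Import path assumed here; the test module imports this helper from torchao's float8 package.
from torchao.float8 import convert_to_float8_training
from torchao.quantization.quant_api import float8_weight_only, quantize_

m = nn.Sequential(nn.Linear(32, 32)).cuda()

# Swap nn.Linear for Float8Linear so matmuls use float8 scaling during training.
m = convert_to_float8_training(m)

# Quantize weights to float8 (e4m3) for weight-only inference, as asserted in the test.
quantize_(m, float8_weight_only())

x = torch.randn(32, 32, device="cuda")
with torch.no_grad():
    y = m(x)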