
Commit 61d49d4

test rowwise fp32
Differential Revision: D73552660
Pull Request resolved: #2431
1 parent 00417b8 commit 61d49d4

2 files changed: 14 additions & 4 deletions


test/float8/test_base.py

Lines changed: 5 additions & 4 deletions
@@ -34,9 +34,7 @@
     e5m2_dtype,
 )
 from torchao.float8.float8_linear import Float8Linear
-from torchao.float8.float8_linear_utils import (
-    convert_to_float8_training,
-)
+from torchao.float8.float8_linear_utils import convert_to_float8_training
 from torchao.float8.float8_ops import addmm_float8_unwrapped
 from torchao.float8.float8_scaling_utils import (
     get_maybe_axiswise_dim,
@@ -379,12 +377,16 @@ def test_linear_from_config_params(
     )
     @pytest.mark.parametrize("x_shape", [(16, 16), (2, 16, 16), (3, 2, 16, 16)])
     @pytest.mark.parametrize("linear_bias", [True, False])
+    @pytest.mark.parametrize(
+        "linear_dtype", [torch.bfloat16, torch.float16, torch.float32]
+    )
     @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
     @skip_if_rocm("ROCm enablement in progress")
     def test_linear_from_recipe(
         self,
         recipe_name,
         x_shape,
+        linear_dtype: torch.dtype,
         linear_bias: bool,
     ):
         if torch.cuda.get_device_capability() < (9, 0):
@@ -393,7 +395,6 @@ def test_linear_from_recipe(
             )
             pytest.skip()
 
-        linear_dtype = torch.bfloat16
         x = torch.randn(*x_shape, device="cuda", dtype=linear_dtype)
         m_ref = nn.Linear(16, 32, bias=linear_bias, device="cuda", dtype=linear_dtype)
         config = Float8LinearConfig.from_recipe_name(recipe_name)
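The new linear_dtype parametrization multiplies out with the existing shape and bias parameters, so each recipe is now exercised in float32 and float16 in addition to bfloat16. A minimal sketch of the resulting test matrix (illustrative only, not part of the commit):

import itertools

import torch

# Stacked pytest.mark.parametrize decorators expand into the cartesian product
# of their arguments, so every recipe now runs against 3 shapes x 2 bias
# settings x 3 dtypes.
x_shapes = [(16, 16), (2, 16, 16), (3, 2, 16, 16)]
linear_biases = [True, False]
linear_dtypes = [torch.bfloat16, torch.float16, torch.float32]

cases = list(itertools.product(x_shapes, linear_biases, linear_dtypes))
print(len(cases))  # 18 combinations per recipe, up from 6 before this change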

torchao/float8/float8_ops.py

Lines changed: 9 additions & 0 deletions
@@ -54,6 +54,12 @@ def addmm_float8_unwrapped(
         a_inverse_scale = a_inverse_scale.new_ones(())
         b_inverse_scale = a_inverse_scale.new_ones(())
 
+    # work around torch._scaled_mm not having float32 output type
+    # TODO(pytorch/pytorch#156771): remove this once torch._scaled_mm supports float32 output
+    orig_dtype = output_dtype
+    if orig_dtype in (torch.float16, torch.float32) and is_rowwise_scaling:
+        output_dtype = torch.bfloat16
+
     post_bias = None
     if output_dtype == torch.float32:
         # Bias is not supported by _scaled_mm when output is fp32
@@ -76,6 +82,9 @@ def addmm_float8_unwrapped(
     if post_bias is not None:
         output += post_bias
 
+    if orig_dtype in (torch.float16, torch.float32) and is_rowwise_scaling:
+        output = output.to(orig_dtype)
+
     return output
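In isolation, the workaround computes the rowwise-scaled matmul in bfloat16 and casts back afterwards, since torch._scaled_mm does not yet offer a float32 output for rowwise scaling (pytorch/pytorch#156771 tracks that). A minimal sketch of the pattern, with an illustrative helper name and a caller-supplied closure standing in for the actual fp8 matmul (this is not the torchao API):

import torch

def run_with_output_dtype_workaround(scaled_mm_fn, requested_dtype, is_rowwise_scaling):
    # Mirror the commit's condition: for rowwise scaling, fp16/fp32 outputs are
    # produced by computing in bfloat16 and casting back at the end.
    compute_dtype = requested_dtype
    if requested_dtype in (torch.float16, torch.float32) and is_rowwise_scaling:
        compute_dtype = torch.bfloat16

    # scaled_mm_fn is a caller-supplied closure over the operands and scales;
    # it returns a tensor in compute_dtype.
    output = scaled_mm_fn(compute_dtype)

    # Cast back so callers still receive the dtype they asked for.
    if output.dtype != requested_dtype:
        output = output.to(requested_dtype)
    return output

For example, exercising the cast logic with a plain matmul as the stand-in closure:

a = torch.randn(16, 32)
b = torch.randn(32, 8)
out = run_with_output_dtype_workaround(
    lambda dt: (a @ b).to(dt), torch.float32, is_rowwise_scaling=True
)
assert out.dtype == torch.float32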