
Commit 2e5001e

[wip] zero dim support for float8 training
Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags:
1 parent a558f7e commit 2e5001e

File tree

2 files changed: 41 additions, 1 deletion


test/float8/test_base.py

Lines changed: 37 additions & 0 deletions
@@ -552,6 +552,43 @@ def test_quantize(self):
         with torch.no_grad():
             m(x)
 
+    @unittest.skip(
+        "TODO enable this test after https://github.com/pytorch/pytorch/pull/140967 lands in CI"
+    )
+    @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
+    @unittest.skipIf(not is_sm_at_least_89, "CUDA 8.9 not available")
+    @pytest.mark.parametrize(
+        "recipe_name",
+        [
+            Float8LinearRecipeName.ALL_TENSORWISE,
+            # TODO(future PR): enable axiswise recipes
+        ],
+    )
+    def test_zero_dim(self, recipe_name):
+        # Note: we only test M == 0 because we can assume that K == 0 and N == 0
+        # are not important
+        M, K, N = 0, 64, 128
+
+        x0_ref = torch.randn(M, K, device="cuda", dtype=torch.bfloat16).requires_grad_()
+        x0_fp8 = copy.deepcopy(x0_ref)
+        config = recipe_name_to_linear_config(recipe_name)
+
+        m_ref = nn.Sequential(nn.Linear(K, N, device="cuda", dtype=torch.bfloat16))
+        m_fp8 = copy.deepcopy(m_ref)
+        m_fp8 = convert_to_float8_training(m_fp8, config=config)
+
+        y_ref = m_ref(x0_ref)
+        y_ref.sum().backward()
+
+        y_fp8 = m_fp8(x0_fp8)
+        y_fp8.sum().backward()
+
+        assert torch.allclose(y_ref, y_fp8, rtol=0, atol=0)
+        assert torch.allclose(
+            m_ref[0].weight.grad, m_fp8[0].weight.grad, rtol=0, atol=0
+        )
+        assert torch.allclose(x0_ref.grad, x0_fp8.grad, rtol=0, atol=0)
+
 
 class TestScaledMM:
     @unittest.skipIf(
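
Aside (not part of the commit): the exact rtol=0/atol=0 comparisons in this test are well defined for M == 0 because the forward output and the input gradient are empty tensors and the weight gradient reduces over an empty batch, so quantization cannot introduce any numerical error. A minimal standalone sketch of that shape reasoning, using only plain nn.Linear (no CUDA or torchao APIs, default float32):

import torch
import torch.nn as nn

# Illustrative only: mirrors the shapes used by test_zero_dim above.
M, K, N = 0, 64, 128
x = torch.randn(M, K, requires_grad=True)
m = nn.Linear(K, N)

y = m(x)  # shape (0, N): empty, so any allclose against it is trivially True
y.sum().backward()

assert y.shape == (0, N)
assert x.grad.shape == (0, K)          # also empty
assert torch.all(m.weight.grad == 0)   # reduction over an empty batch is exactly zero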

torchao/float8/float8_utils.py

Lines changed: 4 additions & 1 deletion
@@ -99,7 +99,10 @@ def tensor_to_amax(
     axiswise_dim: Optional[int] = None,
 ) -> torch.Tensor:
     if scaling_granularity is ScalingGranularity.TENSORWISE:
-        amax = torch.max(torch.abs(x))
+        if x.numel() > 0:
+            amax = torch.max(torch.abs(x))
+        else:
+            amax = torch.tensor(EPS, device=x.device, dtype=x.dtype)
     else:
         assert scaling_granularity is ScalingGranularity.AXISWISE, "unsupported"
         assert axiswise_dim is not None, "unsupported"
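
Aside (an illustration, not part of the diff): the numel() guard is needed because a full torch.max reduction over an empty tensor raises a RuntimeError in PyTorch, so the tensorwise path falls back to EPS to keep the downstream amax/scale computation well defined. A minimal sketch of the behavior, using a stand-in EPS value rather than the constant imported in float8_utils.py:

import torch

EPS = 1e-12  # stand-in for the EPS constant defined in torchao/float8/float8_utils.py

x = torch.empty(0, 64)  # M == 0 input, so x.numel() == 0

try:
    torch.max(torch.abs(x))  # pre-change code path
except RuntimeError as err:
    print("max() over an empty tensor fails:", err)

# Patched behavior: keep amax well defined for empty inputs.
if x.numel() > 0:
    amax = torch.max(torch.abs(x))
else:
    amax = torch.tensor(EPS, device=x.device, dtype=x.dtype)
print(amax)  # a tiny positive scalar instead of an error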
