
Commit a130365

[wip] zero dim support for float8 training
Summary:
Test Plan:
Reviewers:
Subscribers:
Tasks:
Tags:
1 parent b65e513 commit a130365

File tree (2 files changed: +41, -1 lines)

test/float8/test_base.py
torchao/float8/float8_utils.py


test/float8/test_base.py

Lines changed: 37 additions & 0 deletions
@@ -531,6 +531,43 @@ def test_inference_mode(self):
         with torch.inference_mode(mode=True):
             m(x)

+    @unittest.skip(
+        "TODO enable this test after https://github.com/pytorch/pytorch/pull/140967 lands in CI"
+    )
+    @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
+    @unittest.skipIf(not is_cuda_8_9, "CUDA 8.9 not available")
+    @pytest.mark.parametrize(
+        "recipe_name",
+        [
+            Float8LinearRecipeName.ALL_TENSORWISE,
+            # TODO(future PR): enable axiswise recipes
+        ],
+    )
+    def test_zero_dim(self, recipe_name):
+        # Note: we only test M == 0 because we can assume that K == 0 and N == 0
+        # are not important
+        M, K, N = 0, 64, 128
+
+        x0_ref = torch.randn(M, K, device="cuda", dtype=torch.bfloat16).requires_grad_()
+        x0_fp8 = copy.deepcopy(x0_ref)
+        config = recipe_name_to_linear_config(recipe_name)
+
+        m_ref = nn.Sequential(nn.Linear(K, N, device="cuda", dtype=torch.bfloat16))
+        m_fp8 = copy.deepcopy(m_ref)
+        m_fp8 = convert_to_float8_training(m_fp8, config=config)
+
+        y_ref = m_ref(x0_ref)
+        y_ref.sum().backward()
+
+        y_fp8 = m_fp8(x0_fp8)
+        y_fp8.sum().backward()
+
+        assert torch.allclose(y_ref, y_fp8, rtol=0, atol=0)
+        assert torch.allclose(
+            m_ref[0].weight.grad, m_fp8[0].weight.grad, rtol=0, atol=0
+        )
+        assert torch.allclose(x0_ref.grad, x0_fp8.grad, rtol=0, atol=0)
+

 class TestScaledMM:
     @unittest.skipIf(

torchao/float8/float8_utils.py

Lines changed: 4 additions & 1 deletion
@@ -99,7 +99,10 @@ def tensor_to_amax(
     axiswise_dim: Optional[int] = None,
 ) -> torch.Tensor:
     if scaling_granularity is ScalingGranularity.TENSORWISE:
-        amax = torch.max(torch.abs(x))
+        if x.numel() > 0:
+            amax = torch.max(torch.abs(x))
+        else:
+            amax = torch.tensor(EPS, device=x.device, dtype=x.dtype)
     else:
         assert scaling_granularity is ScalingGranularity.AXISWISE, "unsupported"
         assert axiswise_dim is not None, "unsupported"

0 commit comments
