Commit f2e4d5f

[wip] zero dim support for float8 training
Summary:
Test Plan:
Reviewers:
Subscribers:
Tasks:
Tags:
1 parent b65e513 · commit f2e4d5f

File tree: 2 files changed, +41 −1 lines changed

test/float8/test_base.py

Lines changed: 37 additions & 0 deletions
@@ -531,8 +531,45 @@ def test_inference_mode(self):
         with torch.inference_mode(mode=True):
             m(x)
 
+    @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
+    @unittest.skipIf(not is_cuda_8_9, "CUDA 8.9 not available")
+    @pytest.mark.parametrize(
+        "recipe_name",
+        [
+            Float8LinearRecipeName.ALL_TENSORWISE,
+            # TODO(future PR): enable axiswise recipes
+        ],
+    )
+    def test_zero_dim(self, recipe_name):
+        # Note: we only test M == 0 because we can assume that K == 0 and N == 0
+        # are not important
+        M, K, N = 0, 64, 128
+
+        x0_ref = torch.randn(M, K, device="cuda", dtype=torch.bfloat16).requires_grad_()
+        x0_fp8 = copy.deepcopy(x0_ref)
+        x1 = torch.randn(32, K, device="cuda", dtype=torch.bfloat16)
+        config = recipe_name_to_linear_config(recipe_name)
+
+        m_ref = nn.Sequential(nn.Linear(K, N, device="cuda", dtype=torch.bfloat16))
+        m_fp8 = copy.deepcopy(m_ref)
+        m_fp8 = convert_to_float8_training(m_fp8, config=config)
+
+        y_ref = m_ref(x0_ref)
+        y_ref.sum().backward()
+
+        y_fp8 = m_fp8(x0_fp8)
+        y_fp8.sum().backward()
+
+        assert torch.allclose(
+            y_ref, y_fp8, rtol=0, atol=0)
+        assert torch.allclose(
+            m_ref[0].weight.grad, m_fp8[0].weight.grad, rtol=0, atol=0)
+        assert torch.allclose(
+            x0_ref.grad, x0_fp8.grad, rtol=0, atol=0)
+
 
 class TestScaledMM:
+
     @unittest.skipIf(
         not is_cuda_8_9,
         "CUDA not available",

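For intuition, here is a minimal standalone sketch (not part of the commit) of the reference behavior the new test compares the float8 path against. It uses plain float32 on CPU for simplicity, whereas the real test uses bfloat16 on CUDA and a convert_to_float8_training copy of the module: an nn.Linear applied to an input with M == 0 rows yields an output with 0 rows, and the empty batch produces all-zero weight gradients.

import torch
import torch.nn as nn

# Hypothetical standalone check, mirroring the shapes used by test_zero_dim.
M, K, N = 0, 64, 128
x = torch.randn(M, K, requires_grad=True)
m = nn.Linear(K, N)

y = m(x)            # shape (0, 128): zero rows propagate through the matmul
y.sum().backward()  # sum of an empty tensor is 0.0, so backward is well defined

print(y.shape)                    # torch.Size([0, 128])
print(m.weight.grad.abs().max())  # tensor(0.): an empty batch contributes no gradient
print(x.grad.shape)               # torch.Size([0, 64])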
torchao/float8/float8_utils.py

Lines changed: 4 additions & 1 deletion
@@ -99,7 +99,10 @@ def tensor_to_amax(
     axiswise_dim: Optional[int] = None,
 ) -> torch.Tensor:
     if scaling_granularity is ScalingGranularity.TENSORWISE:
-        amax = torch.max(torch.abs(x))
+        if x.numel() > 0:
+            amax = torch.max(torch.abs(x))
+        else:
+            amax = torch.tensor(EPS, device=x.device, dtype=x.dtype)
     else:
         assert scaling_granularity is ScalingGranularity.AXISWISE, "unsupported"
         assert axiswise_dim is not None, "unsupported"
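The reason for the new branch: torch.max with no reduction dim raises a RuntimeError on a zero-element tensor, so the tensor-wise amax of an M == 0 activation needs an explicit fallback. A minimal standalone sketch of the same logic (the function name is hypothetical, and the EPS value here is an assumption standing in for torchao's float8_utils EPS constant):

import torch

EPS = 1e-12  # assumption: small positive constant, mirroring torchao.float8.float8_utils EPS

def tensorwise_amax(x: torch.Tensor) -> torch.Tensor:
    # torch.max() without a reduction dim errors out on zero-element inputs,
    # so fall back to EPS to keep the downstream scale finite and nonzero.
    if x.numel() > 0:
        return torch.max(torch.abs(x))
    return torch.tensor(EPS, device=x.device, dtype=x.dtype)

print(tensorwise_amax(torch.randn(4, 8)))   # usual path: max absolute value
print(tensorwise_amax(torch.empty(0, 8)))   # zero-dim path -> tensor(1.0000e-12)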
