Commit 5cae4d0

Add Float8RowwiseTensor
Summary:
Splits out the float8 rowwise quantized path (both activation and weight) of AQT into Float8RowwiseTensor.

Next: could potentially incorporate the per-tensor activation path there as well.
Next: we can split the per-tensor weight path into another Tensor as well, so we can deprecate the AQT path for float8.

Test Plan:
python test/dtypes/test_affine_quantized_float.py
python test/quantization/quantize_/test_float8_rowwise_tensor.py

Reviewers:

Subscribers:

Tasks:

Tags:

stack-info: PR: #2463, branch: jerryzh168/stack/9
1 parent 84bae22 commit 5cae4d0

File tree

7 files changed: +809 −109 lines changed
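
Before the diff itself, a minimal usage sketch of the path this commit introduces, based on the tests added below. It assumes a CUDA GPU with compute capability >= 8.9; PerRow is the granularity exported by torchao.quantization.

import torch
from torchao.quantization import (
    Float8DynamicActivationFloat8WeightConfig,
    Float8RowwiseTensor,
    PerRow,
    quantize_,
)

# Quantize a linear layer with per-row float8 dynamic activation + weight.
model = torch.nn.Linear(64, 32, bias=False).to("cuda").to(torch.bfloat16)
quantize_(model, Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()))

# After this commit the weight is a Float8RowwiseTensor that exposes the
# quantized data and one scale per output row directly, with no
# original_weight_tensor.tensor_impl indirection (the attributes the
# updated tests assert on).
assert isinstance(model.weight, Float8RowwiseTensor)
assert model.weight.float8_data.shape == (32, 64)
assert model.weight.scale.shape == (32, 1)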

test/dtypes/test_affine_quantized_float.py
136 additions, 62 deletions

@@ -29,6 +29,7 @@
 from torchao.float8.float8_utils import compute_error
 from torchao.quantization import (
     Float8DynamicActivationFloat8WeightConfig,
+    Float8RowwiseTensor,
     float8_dynamic_activation_float8_weight,
     float8_weight_only,
     quantize_,
@@ -324,19 +325,15 @@ def test_mm_float8dq_per_row(

         quant_weight = test_linear.weight

-        self.assertTrue(hasattr(quant_weight, "original_weight_tensor"))
-        weight_impl = quant_weight.original_weight_tensor.tensor_impl
-
-        self.assertTrue(hasattr(weight_impl, "float8_data"))
-        self.assertTrue(hasattr(weight_impl, "scale"))
-        self.assertFalse(weight_impl.transposed)
+        self.assertTrue(hasattr(quant_weight, "float8_data"))
+        self.assertTrue(hasattr(quant_weight, "scale"))

         # Verify scale shape for row-wise quantization
         expected_scale_shape = (out_features, 1)
-        actual_scale_shape = weight_impl.scale.shape
+        actual_scale_shape = quant_weight.scale.shape
         self.assertEqual(actual_scale_shape, expected_scale_shape)

-        self.assertEqual(weight_impl.float8_data.shape, (out_features, in_features))
+        self.assertEqual(quant_weight.float8_data.shape, (out_features, in_features))

         input_tensor = torch.randn(*input_shape, device=device, dtype=dtype)

@@ -357,7 +354,7 @@ def test_mm_float8dq_per_row(
     @common_utils.parametrize("float8_dtype", [torch.float8_e4m3fn, torch.float8_e5m2])
     @common_utils.parametrize("output_dtype", [torch.float32, torch.bfloat16])
     @common_utils.parametrize("block_size", [None, (1, 32), (2, 16), (4, 8)])
-    def test_dequantize_affine_float8(self, float8_dtype, output_dtype, block_size):
+    def test__dequantize_affine_float8(self, float8_dtype, output_dtype, block_size):
         """Test _dequantize_affine_float8 with various configurations"""

         device = "cuda"
@@ -387,7 +384,7 @@ def test_dequantize_affine_float8(self, float8_dtype, output_dtype, block_size):
     @unittest.skipIf(
         not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
     )
-    def test_dequantize_affine_float8_scale_broadcasting(self):
+    def test__dequantize_affine_float8_scale_broadcasting(self):
         """Test that scale broadcasting works correctly for block-wise quantization"""
         device = "cuda"
         # Create input tensor with known block structure
@@ -419,11 +416,11 @@ def test_dequantize_affine_float8_scale_broadcasting(self):
     @unittest.skipIf(
         not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
     )
-    @common_utils.parametrize("granularity", [PerTensor(), PerRow()])
-    def test_float8_tensor_slicing_basic(self, granularity):
+    def test_float8_tensor_slicing_basic_per_tensor(self):
         """Test basic slicing operations on Float8 tensors"""
         device = "cuda"
         dtype = torch.bfloat16
+        granularity = PerTensor()

         # Create and quantize a model
         model = torch.nn.Linear(64, 32, bias=False).to(device).to(dtype)
@@ -450,6 +447,41 @@ def test_float8_tensor_slicing_basic(self, granularity):
         self.assertTrue(isinstance(sliced_1, Float8AQTTensorImpl))
         self.assertTrue(isinstance(sliced_both, Float8AQTTensorImpl))

+    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    @unittest.skipIf(
+        not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
+    )
+    def test_float8_tensor_slicing_basic_per_row(self):
+        """Test basic slicing operations on Float8 tensors"""
+        device = "cuda"
+        dtype = torch.bfloat16
+        granularity = PerRow()
+
+        # Create and quantize a model
+        model = torch.nn.Linear(64, 32, bias=False).to(device).to(dtype)
+        quantize_(
+            model, Float8DynamicActivationFloat8WeightConfig(granularity=granularity)
+        )
+
+        weight = model.weight
+
+        # Test dimension 0 slicing (rows)
+        sliced_0 = weight[10:20]
+        self.assertEqual(sliced_0.shape, (10, 64))
+
+        # Test dimension 1 slicing (columns)
+        sliced_1 = weight[:, 20:40]
+        self.assertEqual(sliced_1.shape, (32, 20))
+
+        # Test combined slicing
+        sliced_both = weight[5:15, 10:30]
+        self.assertEqual(sliced_both.shape, (10, 20))
+
+        # Verify the sliced tensors are still Float8 tensors
+        self.assertTrue(isinstance(sliced_0, Float8RowwiseTensor))
+        self.assertTrue(isinstance(sliced_1, Float8RowwiseTensor))
+        self.assertTrue(isinstance(sliced_both, Float8RowwiseTensor))
+
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
     @unittest.skipIf(
         not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
@@ -497,27 +529,26 @@ def test_float8_tensor_slicing_per_row(self):
         )

         original_weight = model.weight  # Shape: (32, 64)
-        original_impl = original_weight.original_weight_tensor.tensor_impl
-        original_scale = original_impl.scale  # Shape: (32, 1)
+        original_scale = model.weight.scale  # Shape: (32, 1)

         # Test row slicing (dimension 0)
         sliced_rows = original_weight[10:20]  # Shape: (10, 64)
-        sliced_impl = sliced_rows.original_weight_tensor.tensor_impl
+        sliced_scale = sliced_rows.scale

         # Scale should be sliced to match the rows
         expected_scale_shape = (10, 1)
-        self.assertEqual(sliced_impl.scale.shape, expected_scale_shape)
+        self.assertEqual(sliced_scale.shape, expected_scale_shape)

         # Verify the scale values are correct (should be subset of original)
-        self.assertTrue(torch.equal(sliced_impl.scale, original_scale[10:20]))
+        self.assertTrue(torch.equal(sliced_scale, original_scale[10:20]))

         # Test column slicing (dimension 1) - scale should not change for per-row
         sliced_cols = original_weight[:, 20:40]  # Shape: (32, 20)
-        sliced_cols_impl = sliced_cols.original_weight_tensor.tensor_impl
+        sliced_cols_scale = sliced_cols.scale

         # Scale shape should remain the same since we're not changing rows
-        self.assertEqual(sliced_cols_impl.scale.shape, (32, 1))
-        self.assertTrue(torch.equal(sliced_cols_impl.scale, original_scale))
+        self.assertEqual(sliced_cols_scale.shape, (32, 1))
+        self.assertTrue(torch.equal(sliced_cols_scale, original_scale))

     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
     @unittest.skipIf(
@@ -552,15 +583,15 @@ def test_float8_tensor_slicing_edge_cases(self):
     @unittest.skipIf(
         not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
     )
-    @common_utils.parametrize("granularity", [PerTensor(), PerRow()])
     @unittest.skipIf(
         is_sm_version(8, 9),
         "TODO: AssertionError: tensor(-2.1562, device='cuda:0', dtype=torch.bfloat16) not greater than 15",
     )
-    def test_float8_tensor_slicing_functional_correctness(self, granularity):
+    def test_float8_tensor_slicing_functional_correctness_per_tensor(self):
         """Test that sliced tensors produce correct results in computations"""
         device = "cuda"
         dtype = torch.bfloat16
+        granularity = PerTensor()

         # Create reference and quantized models with dimensions that are multiples of 16
         ref_model = (
@@ -630,6 +661,89 @@ def test_float8_tensor_slicing_functional_correctness(self, granularity):
         error = compute_error(ref_output, quant_output)
         self.assertGreater(error, 15, f"Quantization SQNR too low: {error}")

+    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    @unittest.skipIf(
+        not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
+    )
+    @unittest.skipIf(
+        is_sm_version(8, 9),
+        "TODO: AssertionError: tensor(-2.1562, device='cuda:0', dtype=torch.bfloat16) not greater than 15",
+    )
+    def test_float8_tensor_slicing_functional_correctness_per_row(self):
+        """Test that sliced tensors produce correct results in computations"""
+        device = "cuda"
+        dtype = torch.bfloat16
+        granularity = PerRow()
+
+        # Create reference and quantized models with dimensions that are multiples of 16
+        ref_model = (
+            torch.nn.Linear(64, 48, bias=False).to(device).to(dtype)
+        )  # 48 is divisible by 16
+        quant_model = copy.deepcopy(ref_model)
+        quantize_(
+            quant_model,
+            Float8DynamicActivationFloat8WeightConfig(granularity=granularity),
+        )
+
+        # Create input with batch size that works well with slicing
+        input_tensor = torch.randn(8, 64, device=device, dtype=dtype)
+
+        ref_weight_slice = ref_model.weight[0:16, 0:32]
+        quant_weight_slice = quant_model.weight[0:16, 0:32]
+
+        # Verify that the sliced weights maintain Float8 properties
+        self.assertTrue(hasattr(quant_weight_slice, "float8_data"))
+        self.assertTrue(hasattr(quant_weight_slice, "scale"))
+        sliced_impl = quant_weight_slice
+        self.assertTrue(isinstance(sliced_impl, Float8RowwiseTensor))
+
+        # Verify sliced weight shapes
+        self.assertEqual(sliced_impl.float8_data.shape, (16, 32))
+
+        # Get original quantized weight implementation for scale comparison
+        original_quant_impl = quant_model.weight
+
+        # Verify scale properties based on granularity
+        if isinstance(granularity, PerTensor):
+            # Per-tensor: scale should be identical to original (scalar)
+            self.assertEqual(sliced_impl.scale.numel(), 1)
+            self.assertTrue(torch.equal(sliced_impl.scale, original_quant_impl.scale))
+        else:  # PerRow
+            # Per-row: scale should be sliced to match the selected rows (0:16)
+            expected_scale_shape = (16, 1)
+            self.assertEqual(sliced_impl.scale.shape, expected_scale_shape)
+            # Verify the scale values are the correct slice from the original
+            self.assertTrue(
+                torch.equal(sliced_impl.scale, original_quant_impl.scale[0:16])
+            )
+
+        # Verify that sliced quantized data matches the correct slice from original
+        original_float8_data_slice = quant_model.weight.float8_data[0:16, 0:32]
+        self.assertTrue(
+            torch.equal(sliced_impl.float8_data, original_float8_data_slice)
+        )
+
+        # Verify that sliced weights can be converted back to float with correct values
+        sliced_float_weight = quant_weight_slice.to(dtype)
+        self.assertEqual(sliced_float_weight.shape, (16, 32))
+        self.assertEqual(sliced_float_weight.dtype, dtype)
+
+        input_slice = input_tensor[:, 0:32]  # (8, 32) to match sliced weight
+
+        # Compute with sliced weights
+        with torch.no_grad():
+            ref_output = torch.nn.functional.linear(input_slice, ref_weight_slice)
+            quant_output = torch.nn.functional.linear(input_slice, quant_weight_slice)
+
+        # Verify shapes
+        expected_shape = (8, 16)  # batch_size x out_features_sliced
+        self.assertEqual(ref_output.shape, expected_shape)
+        self.assertEqual(quant_output.shape, expected_shape)
+
+        # Verify reasonable quantization error
+        error = compute_error(ref_output, quant_output)
+        self.assertGreater(error, 15, f"Quantization SQNR too low: {error}")
+
     def test_preprocess_scale_3d_reshape(self):
         """Test that preprocess_scale correctly handles 3D scale tensors"""
         device = "cpu"  # Use CPU for basic functionality test
@@ -675,46 +789,6 @@ def test_preprocess_scale_3d_reshape(self):
         expected_shape = (8, 1)  # Flattened (2*2*2, 1)
         self.assertEqual(result.shape, expected_shape)

-    @common_utils.parametrize("float8_dtype", [torch.float8_e4m3fn, torch.float8_e5m2])
-    @common_utils.parametrize("hp_dtype", [torch.float32, torch.bfloat16])
-    def test_quantize_dequantize_fp8_inductor(self, float8_dtype, hp_dtype):
-        quantize_affine_float8 = torch.ops.torchao.quantize_affine_float8
-        dequantize_affine_float8 = torch.ops.torchao.dequantize_affine_float8
-        input = torch.randn(10, 10)
-        with torch.no_grad():
-            torch._dynamo.reset()
-            expected_scale = torch.tensor(2.0)
-            expected_quantized = quantize_affine_float8(
-                input,
-                expected_scale,
-                float8_dtype=float8_dtype,
-            )
-            expected_dequantized = dequantize_affine_float8(
-                expected_quantized,
-                expected_scale,
-                output_dtype=hp_dtype,
-            )
-            test_q, (code_q,) = torch._inductor.utils.run_and_get_code(
-                torch.compile(quantize_affine_float8),
-                input,
-                expected_scale,
-                float8_dtype=float8_dtype,
-            )
-            torch.testing.FileCheck().check(
-                "torch.ops.torchao.quantize_affine_float8.default"
-            ).run(code_q)
-            test_dq, (code_dq,) = torch._inductor.utils.run_and_get_code(
-                torch.compile(dequantize_affine_float8),
-                test_q,
-                expected_scale,
-                hp_dtype,
-            )
-            torch.testing.FileCheck().check(
-                "torch.ops.torchao.dequantize_affine_float8.default"
-            ).run(code_dq)
-            torch.testing.assert_close(expected_quantized, test_q)
-            torch.testing.assert_close(expected_dequantized, test_dq)
-

 common_utils.instantiate_parametrized_tests(TestAffineQuantizedFloat8Compile)

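Not part of the commit, but for context on what the rowwise path computes: a standalone sketch of per-row float8 quantization in plain PyTorch. The amax-based scale recipe and the helper names here are assumptions for illustration; torchao's actual kernels may use a different recipe.

import torch

def quantize_rowwise_fp8(w: torch.Tensor, dtype: torch.dtype = torch.float8_e4m3fn):
    # One scale per output row: the row's amax divided by the fp8 format's
    # max representable value (assumed recipe, not necessarily torchao's).
    fp8_max = torch.finfo(dtype).max
    scale = w.abs().amax(dim=1, keepdim=True).float() / fp8_max  # shape (rows, 1)
    scale = scale.clamp_min(1e-12)  # guard against all-zero rows
    float8_data = (w.float() / scale).clamp(-fp8_max, fp8_max).to(dtype)
    return float8_data, scale

def dequantize_rowwise_fp8(float8_data, scale, out_dtype=torch.bfloat16):
    # Broadcasting the (rows, 1) scale restores each row's magnitude.
    return (float8_data.float() * scale).to(out_dtype)

w = torch.randn(32, 64, dtype=torch.bfloat16)
q, s = quantize_rowwise_fp8(w)
assert s.shape == (32, 1)  # matches the scale shapes asserted in the tests above
w_hat = dequantize_rowwise_fp8(q, s)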