25
25
from torch ._inductor .test_case import TestCase as InductorTestCase
26
26
from torch .testing ._internal import common_utils
27
27
28
+ from torchao .dtypes import FbgemmFp8Tensor
28
29
from torchao .dtypes .floatx .float8_layout import Float8AQTTensorImpl , preprocess_scale
29
30
from torchao .float8 .float8_utils import compute_error
30
31
from torchao .quantization import (
@@ -324,19 +325,15 @@ def test_mm_float8dq_per_row(
324
325
325
326
quant_weight = test_linear .weight
326
327
327
- self .assertTrue (hasattr (quant_weight , "original_weight_tensor" ))
328
- weight_impl = quant_weight .original_weight_tensor .tensor_impl
329
-
330
- self .assertTrue (hasattr (weight_impl , "float8_data" ))
331
- self .assertTrue (hasattr (weight_impl , "scale" ))
332
- self .assertFalse (weight_impl .transposed )
328
+ self .assertTrue (hasattr (quant_weight , "float8_data" ))
329
+ self .assertTrue (hasattr (quant_weight , "scale" ))
333
330
334
331
# Verify scale shape for row-wise quantization
335
332
expected_scale_shape = (out_features , 1 )
336
- actual_scale_shape = weight_impl .scale .shape
333
+ actual_scale_shape = quant_weight .scale .shape
337
334
self .assertEqual (actual_scale_shape , expected_scale_shape )
338
335
339
- self .assertEqual (weight_impl .float8_data .shape , (out_features , in_features ))
336
+ self .assertEqual (quant_weight .float8_data .shape , (out_features , in_features ))
340
337
341
338
input_tensor = torch .randn (* input_shape , device = device , dtype = dtype )
342
339
@@ -419,11 +416,11 @@ def test_dequantize_affine_float8_scale_broadcasting(self):
419
416
@unittest .skipIf (
420
417
not is_sm_at_least_89 (), "Requires GPU with compute capability >= 8.9"
421
418
)
422
- @common_utils .parametrize ("granularity" , [PerTensor (), PerRow ()])
423
- def test_float8_tensor_slicing_basic (self , granularity ):
419
+ def test_float8_tensor_slicing_basic_per_tensor (self ):
424
420
"""Test basic slicing operations on Float8 tensors"""
425
421
device = "cuda"
426
422
dtype = torch .bfloat16
423
+ granularity = PerTensor ()
427
424
428
425
# Create and quantize a model
429
426
model = torch .nn .Linear (64 , 32 , bias = False ).to (device ).to (dtype )
@@ -450,6 +447,41 @@ def test_float8_tensor_slicing_basic(self, granularity):
450
447
self .assertTrue (isinstance (sliced_1 , Float8AQTTensorImpl ))
451
448
self .assertTrue (isinstance (sliced_both , Float8AQTTensorImpl ))
452
449
450
+ @unittest .skipIf (not torch .cuda .is_available (), "Need CUDA available" )
451
+ @unittest .skipIf (
452
+ not is_sm_at_least_89 (), "Requires GPU with compute capability >= 8.9"
453
+ )
454
+ def test_float8_tensor_slicing_basic_per_row (self ):
455
+ """Test basic slicing operations on Float8 tensors"""
456
+ device = "cuda"
457
+ dtype = torch .bfloat16
458
+ granularity = PerRow ()
459
+
460
+ # Create and quantize a model
461
+ model = torch .nn .Linear (64 , 32 , bias = False ).to (device ).to (dtype )
462
+ quantize_ (
463
+ model , Float8DynamicActivationFloat8WeightConfig (granularity = granularity )
464
+ )
465
+
466
+ weight = model .weight
467
+
468
+ # Test dimension 0 slicing (rows)
469
+ sliced_0 = weight [10 :20 ]
470
+ self .assertEqual (sliced_0 .shape , (10 , 64 ))
471
+
472
+ # Test dimension 1 slicing (columns)
473
+ sliced_1 = weight [:, 20 :40 ]
474
+ self .assertEqual (sliced_1 .shape , (32 , 20 ))
475
+
476
+ # Test combined slicing
477
+ sliced_both = weight [5 :15 , 10 :30 ]
478
+ self .assertEqual (sliced_both .shape , (10 , 20 ))
479
+
480
+ # Verify the sliced tensors are still Float8 tensors
481
+ self .assertTrue (isinstance (sliced_0 , FbgemmFp8Tensor ))
482
+ self .assertTrue (isinstance (sliced_1 , FbgemmFp8Tensor ))
483
+ self .assertTrue (isinstance (sliced_both , FbgemmFp8Tensor ))
484
+
453
485
@unittest .skipIf (not torch .cuda .is_available (), "Need CUDA available" )
454
486
@unittest .skipIf (
455
487
not is_sm_at_least_89 (), "Requires GPU with compute capability >= 8.9"
@@ -497,27 +529,26 @@ def test_float8_tensor_slicing_per_row(self):
497
529
)
498
530
499
531
original_weight = model .weight # Shape: (32, 64)
500
- original_impl = original_weight .original_weight_tensor .tensor_impl
501
- original_scale = original_impl .scale # Shape: (32, 1)
532
+ original_scale = model .weight .scale # Shape: (32, 1)
502
533
503
534
# Test row slicing (dimension 0)
504
535
sliced_rows = original_weight [10 :20 ] # Shape: (10, 64)
505
- sliced_impl = sliced_rows .original_weight_tensor . tensor_impl
536
+ sliced_scale = sliced_rows .scale
506
537
507
538
# Scale should be sliced to match the rows
508
539
expected_scale_shape = (10 , 1 )
509
- self .assertEqual (sliced_impl . scale .shape , expected_scale_shape )
540
+ self .assertEqual (sliced_scale .shape , expected_scale_shape )
510
541
511
542
# Verify the scale values are correct (should be subset of original)
512
- self .assertTrue (torch .equal (sliced_impl . scale , original_scale [10 :20 ]))
543
+ self .assertTrue (torch .equal (sliced_scale , original_scale [10 :20 ]))
513
544
514
545
# Test column slicing (dimension 1) - scale should not change for per-row
515
546
sliced_cols = original_weight [:, 20 :40 ] # Shape: (32, 20)
516
- sliced_cols_impl = sliced_cols .original_weight_tensor . tensor_impl
547
+ sliced_cols_scale = sliced_cols .scale
517
548
518
549
# Scale shape should remain the same since we're not changing rows
519
- self .assertEqual (sliced_cols_impl . scale .shape , (32 , 1 ))
520
- self .assertTrue (torch .equal (sliced_cols_impl . scale , original_scale ))
550
+ self .assertEqual (sliced_cols_scale .shape , (32 , 1 ))
551
+ self .assertTrue (torch .equal (sliced_cols_scale , original_scale ))
521
552
522
553
@unittest .skipIf (not torch .cuda .is_available (), "Need CUDA available" )
523
554
@unittest .skipIf (
@@ -552,15 +583,15 @@ def test_float8_tensor_slicing_edge_cases(self):
552
583
@unittest .skipIf (
553
584
not is_sm_at_least_89 (), "Requires GPU with compute capability >= 8.9"
554
585
)
555
- @common_utils .parametrize ("granularity" , [PerTensor (), PerRow ()])
556
586
@unittest .skipIf (
557
587
is_sm_version (8 , 9 ),
558
588
"TODO: AssertionError: tensor(-2.1562, device='cuda:0', dtype=torch.bfloat16) not greater than 15" ,
559
589
)
560
- def test_float8_tensor_slicing_functional_correctness (self , granularity ):
590
+ def test_float8_tensor_slicing_functional_correctness_per_tensor (self ):
561
591
"""Test that sliced tensors produce correct results in computations"""
562
592
device = "cuda"
563
593
dtype = torch .bfloat16
594
+ granularity = PerTensor ()
564
595
565
596
# Create reference and quantized models with dimensions that are multiples of 16
566
597
ref_model = (
@@ -630,6 +661,89 @@ def test_float8_tensor_slicing_functional_correctness(self, granularity):
630
661
error = compute_error (ref_output , quant_output )
631
662
self .assertGreater (error , 15 , f"Quantization SQNR too low: { error } " )
632
663
664
+ @unittest .skipIf (not torch .cuda .is_available (), "Need CUDA available" )
665
+ @unittest .skipIf (
666
+ not is_sm_at_least_89 (), "Requires GPU with compute capability >= 8.9"
667
+ )
668
+ @unittest .skipIf (
669
+ is_sm_version (8 , 9 ),
670
+ "TODO: AssertionError: tensor(-2.1562, device='cuda:0', dtype=torch.bfloat16) not greater than 15" ,
671
+ )
672
+ def test_float8_tensor_slicing_functional_correctness_per_row (self ):
673
+ """Test that sliced tensors produce correct results in computations"""
674
+ device = "cuda"
675
+ dtype = torch .bfloat16
676
+ granularity = PerRow ()
677
+
678
+ # Create reference and quantized models with dimensions that are multiples of 16
679
+ ref_model = (
680
+ torch .nn .Linear (64 , 48 , bias = False ).to (device ).to (dtype )
681
+ ) # 48 is divisible by 16
682
+ quant_model = copy .deepcopy (ref_model )
683
+ quantize_ (
684
+ quant_model ,
685
+ Float8DynamicActivationFloat8WeightConfig (granularity = granularity ),
686
+ )
687
+
688
+ # Create input with batch size that works well with slicing
689
+ input_tensor = torch .randn (8 , 64 , device = device , dtype = dtype )
690
+
691
+ ref_weight_slice = ref_model .weight [0 :16 , 0 :32 ]
692
+ quant_weight_slice = quant_model .weight [0 :16 , 0 :32 ]
693
+
694
+ # Verify that the sliced weights maintain Float8 properties
695
+ self .assertTrue (hasattr (quant_weight_slice , "float8_data" ))
696
+ self .assertTrue (hasattr (quant_weight_slice , "scale" ))
697
+ sliced_impl = quant_weight_slice
698
+ self .assertTrue (isinstance (sliced_impl , FbgemmFp8Tensor ))
699
+
700
+ # Verify sliced weight shapes
701
+ self .assertEqual (sliced_impl .float8_data .shape , (16 , 32 ))
702
+
703
+ # Get original quantized weight implementation for scale comparison
704
+ original_quant_impl = quant_model .weight
705
+
706
+ # Verify scale properties based on granularity
707
+ if isinstance (granularity , PerTensor ):
708
+ # Per-tensor: scale should be identical to original (scalar)
709
+ self .assertEqual (sliced_impl .scale .numel (), 1 )
710
+ self .assertTrue (torch .equal (sliced_impl .scale , original_quant_impl .scale ))
711
+ else : # PerRow
712
+ # Per-row: scale should be sliced to match the selected rows (0:16)
713
+ expected_scale_shape = (16 , 1 )
714
+ self .assertEqual (sliced_impl .scale .shape , expected_scale_shape )
715
+ # Verify the scale values are the correct slice from the original
716
+ self .assertTrue (
717
+ torch .equal (sliced_impl .scale , original_quant_impl .scale [0 :16 ])
718
+ )
719
+
720
+ # Verify that sliced quantized data matches the correct slice from original
721
+ original_float8_data_slice = quant_model .weight .float8_data [0 :16 , 0 :32 ]
722
+ self .assertTrue (
723
+ torch .equal (sliced_impl .float8_data , original_float8_data_slice )
724
+ )
725
+
726
+ # Verify that sliced weights can be converted back to float with correct values
727
+ sliced_float_weight = quant_weight_slice .to (dtype )
728
+ self .assertEqual (sliced_float_weight .shape , (16 , 32 ))
729
+ self .assertEqual (sliced_float_weight .dtype , dtype )
730
+
731
+ input_slice = input_tensor [:, 0 :32 ] # (8, 32) to match sliced weight
732
+
733
+ # Compute with sliced weights
734
+ with torch .no_grad ():
735
+ ref_output = torch .nn .functional .linear (input_slice , ref_weight_slice )
736
+ quant_output = torch .nn .functional .linear (input_slice , quant_weight_slice )
737
+
738
+ # Verify shapes
739
+ expected_shape = (8 , 16 ) # batch_size x out_features_sliced
740
+ self .assertEqual (ref_output .shape , expected_shape )
741
+ self .assertEqual (quant_output .shape , expected_shape )
742
+
743
+ # Verify reasonable quantization error
744
+ error = compute_error (ref_output , quant_output )
745
+ self .assertGreater (error , 15 , f"Quantization SQNR too low: { error } " )
746
+
633
747
def test_preprocess_scale_3d_reshape (self ):
634
748
"""Test that preprocess_scale correctly handles 3D scale tensors"""
635
749
device = "cpu" # Use CPU for basic functionality test
0 commit comments