 from torchao.float8.float8_linear_utils import (
     convert_to_float8_training,
 )
-from torchao.float8.float8_ops import addmm_float8_unwrapped
+from torchao.float8.float8_ops import _addmm_float8_unwrapped
 from torchao.float8.float8_scaling_utils import (
-    get_maybe_axiswise_dim,
-    hp_tensor_to_float8_dynamic,
+    _get_maybe_axiswise_dim,
+    _hp_tensor_to_float8_dynamic,
 )
 from torchao.float8.float8_tensor import (
     Float8Tensor,
     GemmInputRole,
     LinearMMConfig,
     ScaledMMConfig,
-    hp_tensor_and_scale_to_float8,
+    _hp_tensor_and_scale_to_float8,
 )
 from torchao.float8.float8_utils import (
     FP8_TYPES,
@@ -76,7 +76,7 @@ def test_preserves_dtype(self) -> None:
         for hp_dtype, lp_dtype in itertools.product(hp_dtypes, lp_dtypes):
             x1_hp = torch.randn(4, 4, dtype=hp_dtype)
             x1_s = tensor_to_scale(x1_hp, lp_dtype)
-            x2_lp = hp_tensor_and_scale_to_float8(x1_hp, x1_s, lp_dtype)
+            x2_lp = _hp_tensor_and_scale_to_float8(x1_hp, x1_s, lp_dtype)
             x3_hp = x2_lp.to_original_precision()
             assert x3_hp.dtype == hp_dtype
 
@@ -86,7 +86,7 @@ def test_differentiable_casts(self) -> None:
             x = torch.randn(1).requires_grad_()
             grad = torch.randn(1)
             x_s = tensor_to_scale(x, f8_dtype)
-            x_f8 = hp_tensor_and_scale_to_float8(x, x_s, f8_dtype)
+            x_f8 = _hp_tensor_and_scale_to_float8(x, x_s, f8_dtype)
             x_f8_hp = x_f8.to_original_precision()
             x_f8_hp.backward(grad)
             # the gradient should be unchanged through both casts
@@ -95,7 +95,7 @@ def test_differentiable_casts(self) -> None:
     def test_split_cat(self):
         a = torch.rand(16, 16, dtype=torch.bfloat16)
         scale = tensor_to_scale(a, e4m3_dtype)
-        fp8_a = hp_tensor_and_scale_to_float8(a, scale, e4m3_dtype)
+        fp8_a = _hp_tensor_and_scale_to_float8(a, scale, e4m3_dtype)
 
         splits = torch.split(fp8_a, 16)
         catted = torch.cat(splits, dim=0)
@@ -104,14 +104,14 @@ def test_split_cat(self):
     def test_index_put(self):
         a = torch.rand(16, dtype=torch.bfloat16)
         scale_a = tensor_to_scale(a, e4m3_dtype)
-        fp8_a = hp_tensor_and_scale_to_float8(a, scale_a, e4m3_dtype)
+        fp8_a = _hp_tensor_and_scale_to_float8(a, scale_a, e4m3_dtype)
 
         index = torch.randint(0, 15, (16,), dtype=torch.long)
 
         b = torch.rand(16, 16, dtype=torch.bfloat16)
         scale_b = tensor_to_scale(b, e4m3_dtype)
-        fp8_b = hp_tensor_and_scale_to_float8(b, scale_a, e4m3_dtype)
-        fp8_b_bad = hp_tensor_and_scale_to_float8(b, scale_b, e4m3_dtype)
+        fp8_b = _hp_tensor_and_scale_to_float8(b, scale_a, e4m3_dtype)
+        fp8_b_bad = _hp_tensor_and_scale_to_float8(b, scale_b, e4m3_dtype)
 
         with pytest.raises(AssertionError):
             b[index] = fp8_a
@@ -122,7 +122,7 @@ def test_index_put(self):
     def test_copy_(self):
         a = torch.rand(16, dtype=torch.bfloat16)
         scale_a = tensor_to_scale(a, e4m3_dtype)
-        fp8_a = hp_tensor_and_scale_to_float8(a, scale_a, e4m3_dtype)
+        fp8_a = _hp_tensor_and_scale_to_float8(a, scale_a, e4m3_dtype)
 
         b = torch.empty(16, dtype=torch.bfloat16)
         b.copy_(fp8_a)  # Should work
@@ -143,10 +143,10 @@ def test_transpose(self):
         a = torch.rand((16, 16), dtype=torch.bfloat16)
         for axiswise_dim in (None, 0, -1):
             scale_a = tensor_to_scale(a, e4m3_dtype)
-            fp8_a = hp_tensor_and_scale_to_float8(
+            fp8_a = _hp_tensor_and_scale_to_float8(
                 a, scale_a, e4m3_dtype, axiswise_dim=axiswise_dim
             )
-            fp8_b = hp_tensor_and_scale_to_float8(
+            fp8_b = _hp_tensor_and_scale_to_float8(
                 a, scale_a, e4m3_dtype, axiswise_dim=axiswise_dim
             )
 
@@ -166,7 +166,7 @@ def test_axiswise_dynamic_cast(
     ):
         a = torch.randn(*shape, dtype=torch.bfloat16)
         linear_mm_config = LinearMMConfig()
-        a_fp8 = hp_tensor_to_float8_dynamic(
+        a_fp8 = _hp_tensor_to_float8_dynamic(
             a,
             e4m3_dtype,
             linear_mm_config,
@@ -183,7 +183,7 @@ def test_axiswise_reshape(self):
         linear_mm_config = LinearMMConfig()
 
         # if we scale across dim0, we can only reshape to [3, -1]
-        a_fp8_d0 = hp_tensor_to_float8_dynamic(
+        a_fp8_d0 = _hp_tensor_to_float8_dynamic(
             a,
             e4m3_dtype,
             linear_mm_config,
@@ -207,7 +207,7 @@ def test_axiswise_reshape(self):
             a_fp8_d0.reshape(-1, 7)
 
         # if we scale across dim2, we can only reshape to [-1, 7]
-        a_fp8_d2 = hp_tensor_to_float8_dynamic(
+        a_fp8_d2 = _hp_tensor_to_float8_dynamic(
             a,
             e4m3_dtype,
             linear_mm_config,
@@ -247,23 +247,23 @@ def test_axiswise_gemm(self, a_shape, a_granularity, b_granularity):
 
         linear_mm_config = LinearMMConfig()
 
-        a_fp8 = hp_tensor_to_float8_dynamic(
+        a_fp8 = _hp_tensor_to_float8_dynamic(
             a,
             e4m3_dtype,
             linear_mm_config,
             gemm_input_role=GemmInputRole.INPUT,
             scaling_granularity=a_granularity,
-            axiswise_dim=get_maybe_axiswise_dim(-1, a_granularity),
+            axiswise_dim=_get_maybe_axiswise_dim(-1, a_granularity),
         )
         a_fp8 = a_fp8.reshape(-1, a_shape[-1])
 
-        b_fp8 = hp_tensor_to_float8_dynamic(
+        b_fp8 = _hp_tensor_to_float8_dynamic(
             b,
             e4m3_dtype,
             linear_mm_config,
             gemm_input_role=GemmInputRole.WEIGHT,
             scaling_granularity=b_granularity,
-            axiswise_dim=get_maybe_axiswise_dim(-1, b_granularity),
+            axiswise_dim=_get_maybe_axiswise_dim(-1, b_granularity),
         )
 
         c_fp8_compute = torch.mm(a_fp8, b_fp8.t())
@@ -528,10 +528,10 @@ def test_scaled_mm_vs_emulated(self, base_dtype, use_fast_accum):
         a_scale = tensor_to_scale(a, input_dtype).float()
         b_scale = tensor_to_scale(b, input_dtype).float()
 
-        a_fp8 = hp_tensor_and_scale_to_float8(a, a_scale, input_dtype)
-        b_fp8 = hp_tensor_and_scale_to_float8(b, b_scale, input_dtype)
+        a_fp8 = _hp_tensor_and_scale_to_float8(a, a_scale, input_dtype)
+        b_fp8 = _hp_tensor_and_scale_to_float8(b, b_scale, input_dtype)
 
-        out_scaled_mm = addmm_float8_unwrapped(
+        out_scaled_mm = _addmm_float8_unwrapped(
             a_fp8._data,
             a_fp8._scale,
             b_fp8._data,
@@ -569,14 +569,14 @@ def test_different_configs_error(self):
             ScaledMMConfig(True, False, False, False),
             ScaledMMConfig(True, False, False, False),
         )
-        a = hp_tensor_and_scale_to_float8(
+        a = _hp_tensor_and_scale_to_float8(
            x_fp32,
            x_scale,
            fp8_dtype,
            linear_config_a,
            GemmInputRole.INPUT,
        )
-        b = hp_tensor_and_scale_to_float8(
+        b = _hp_tensor_and_scale_to_float8(
            x_fp32,
            x_scale,
            fp8_dtype,
@@ -608,10 +608,10 @@ def test_pad_inner_dim(self, base_dtype, use_fast_accum):
         a_scale = tensor_to_scale(a, input_dtype).float()
         b_scale = tensor_to_scale(b, input_dtype).float()
 
-        a_fp8 = hp_tensor_and_scale_to_float8(
+        a_fp8 = _hp_tensor_and_scale_to_float8(
             a, a_scale, input_dtype, None, GemmInputRole.INPUT
         )
-        b_fp8 = hp_tensor_and_scale_to_float8(
+        b_fp8 = _hp_tensor_and_scale_to_float8(
             b, b_scale, input_dtype, None, GemmInputRole.WEIGHT
         )
 
@@ -628,14 +628,14 @@ def test_pad_inner_dim(self, base_dtype, use_fast_accum):
             scaled_mm_config, scaled_mm_config, scaled_mm_config
         )
 
-        a_fp8 = hp_tensor_and_scale_to_float8(
+        a_fp8 = _hp_tensor_and_scale_to_float8(
             a,
             a_scale,
             input_dtype,
             pad_config,
             GemmInputRole.INPUT,
         )
-        b_fp8 = hp_tensor_and_scale_to_float8(
+        b_fp8 = _hp_tensor_and_scale_to_float8(
             b,
             b_scale,
             input_dtype,
@@ -651,14 +651,14 @@ def test_pad_inner_dim(self, base_dtype, use_fast_accum):
             emulated_scaled_mm_config,
             emulated_scaled_mm_config,
         )
-        a_fp8 = hp_tensor_and_scale_to_float8(
+        a_fp8 = _hp_tensor_and_scale_to_float8(
             a,
             a_scale,
             input_dtype,
             emulated_config,
             GemmInputRole.INPUT,
         )
-        b_fp8 = hp_tensor_and_scale_to_float8(
+        b_fp8 = _hp_tensor_and_scale_to_float8(
             b,
             b_scale,
             input_dtype,
@@ -813,19 +813,19 @@ def test_fp8_tensor_statistics(self):
 
             # Overflow caused by a too large scaling factor
             s_overflow = torch.tensor(1e9)
-            fp8_overflow = hp_tensor_and_scale_to_float8(x1_hp, s_overflow, lp_dtype)
+            fp8_overflow = _hp_tensor_and_scale_to_float8(x1_hp, s_overflow, lp_dtype)
             (zero_cnt, max_cnt) = fp8_tensor_statistics(fp8_overflow, lp_dtype)
             self.assertEqual((zero_cnt, max_cnt), (0, tensor_len))
 
             # Underflow caused by a too small scaling factor
             s_underflow = torch.tensor(1e-9)
-            fp8_underflow = hp_tensor_and_scale_to_float8(x1_hp, s_underflow, lp_dtype)
+            fp8_underflow = _hp_tensor_and_scale_to_float8(x1_hp, s_underflow, lp_dtype)
             (zero_cnt, max_cnt) = fp8_tensor_statistics(fp8_underflow, lp_dtype)
             self.assertEqual((zero_cnt, max_cnt), (tensor_len, 0))
 
             # Both overflow and underflow
             x2_hp = torch.cat((x1_hp * 1e9, x1_hp * 1.0, x1_hp * 1e-9), 0)
-            fp8_over_underflow = hp_tensor_and_scale_to_float8(
+            fp8_over_underflow = _hp_tensor_and_scale_to_float8(
                 x2_hp, torch.tensor(1.0), lp_dtype
             )
             (zero_cnt, max_cnt) = fp8_tensor_statistics(fp8_over_underflow, lp_dtype)
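
Editor's note: every hunk in this diff is a mechanical rename; each helper keeps its signature and behavior and only gains a leading underscore to mark it as internal to torchao.float8. A minimal sketch of the renamed round-trip the tests above exercise is shown below. It assumes a torchao build that already ships the underscore-prefixed names, and it assumes tensor_to_scale is exported from torchao.float8.float8_utils (the import block in the first hunk is truncated before that line).

import torch

# Assumption: this torchao build exposes the private, underscore-prefixed
# helpers; older releases export the same functions without the prefix.
from torchao.float8.float8_tensor import _hp_tensor_and_scale_to_float8
from torchao.float8.float8_utils import tensor_to_scale  # assumed location

x_hp = torch.randn(4, 4, dtype=torch.bfloat16)        # high-precision input
scale = tensor_to_scale(x_hp, torch.float8_e4m3fn)    # per-tensor scale
x_fp8 = _hp_tensor_and_scale_to_float8(x_hp, scale, torch.float8_e4m3fn)
x_roundtrip = x_fp8.to_original_precision()
assert x_roundtrip.dtype == x_hp.dtype                 # dtype survives the cast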