
Commit 4bb52b3

make internal torchao.float8 functions private

1 parent: d963a88
18 files changed: +196, -188 lines
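
This commit prefixes a set of internal float8 helpers with a leading underscore, so any code that imported them under the old, public-looking names must follow the rename. A minimal sketch of the import update for such a caller, assuming a torchao build that includes this commit; the module paths and both the old and new names are taken directly from the hunks below:

# old names, removed by this commit
# from torchao.float8.float8_tensor import hp_tensor_and_scale_to_float8
# from torchao.float8.float8_utils import pad_tensor_for_matmul
# from torchao.float8.float8_ops import addmm_float8_unwrapped
# from torchao.float8.float8_scaling_utils import (
#     get_maybe_axiswise_dim,
#     hp_tensor_to_float8_dynamic,
# )

# new underscore-prefixed names, marking the helpers as private
from torchao.float8.float8_tensor import _hp_tensor_and_scale_to_float8
from torchao.float8.float8_utils import _pad_tensor_for_matmul
from torchao.float8.float8_ops import _addmm_float8_unwrapped
from torchao.float8.float8_scaling_utils import (
    _get_maybe_axiswise_dim,
    _hp_tensor_to_float8_dynamic,
)

The leading underscore is the usual Python convention for module-internal names; code outside torchao should treat these helpers as subject to change without notice.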

benchmarks/float8/bench_padding.py (8 additions, 8 deletions)

@@ -16,9 +16,9 @@
     GemmInputRole,
     LinearMMConfig,
     ScaledMMConfig,
-    hp_tensor_and_scale_to_float8,
+    _hp_tensor_and_scale_to_float8,
 )
-from torchao.float8.float8_utils import pad_tensor_for_matmul
+from torchao.float8.float8_utils import _pad_tensor_for_matmul

 # estimating TOPs for matmuls in fp32, fp16, fp8
 # assuming A * B = C, with A being M * K, B being K * N, C being M * N
@@ -63,14 +63,14 @@ def do_fp8_matmul(A, B, fp8_dtype, out_dtype):
     a_config = LinearMMConfig(a_config, a_config, a_config)
     b_config = LinearMMConfig(b_config, b_config, b_config)

-    a_fp8 = hp_tensor_and_scale_to_float8(
+    a_fp8 = _hp_tensor_and_scale_to_float8(
         A,
         scale_a,
         fp8_dtype,
         a_config,
         GemmInputRole.INPUT,
     )
-    b_fp8 = hp_tensor_and_scale_to_float8(
+    b_fp8 = _hp_tensor_and_scale_to_float8(
         B,
         scale_b,
         fp8_dtype,
@@ -84,8 +84,8 @@ def do_fp8_matmul(A, B, fp8_dtype, out_dtype):
 def do_fp8_pad_first_matmul(A, B, fp8_dtype, out_dtype):
     # Breaks with compile due to trying to pad on fp8 dtype
     # return do_fp8_matmul(A, B, fp8_dtype, out_dtype)
-    A_pad = pad_tensor_for_matmul(A, dims=1)  # mem copy
-    B_pad = pad_tensor_for_matmul(B, dims=0)  # mem copy
+    A_pad = _pad_tensor_for_matmul(A, dims=1)  # mem copy
+    B_pad = _pad_tensor_for_matmul(B, dims=0)  # mem copy

     scale_a = torch.tensor([1], device="cuda", dtype=torch.float32)
     scale_b = torch.tensor([1], device="cuda", dtype=torch.float32)
@@ -105,8 +105,8 @@ def do_hp_matmul(A, B):


 def do_aligned_bf16_matmul(A, B):
-    A_pad = pad_tensor_for_matmul(A, dims=1)
-    B_pad = pad_tensor_for_matmul(B, dims=0)
+    A_pad = _pad_tensor_for_matmul(A, dims=1)
+    B_pad = _pad_tensor_for_matmul(B, dims=0)
     return torch.matmul(A_pad, B_pad)
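
The padding helper keeps its signature and behavior under the new name; the benchmark simply pads the shared inner dimension of both operands and then runs an ordinary matmul. A minimal standalone sketch of that pattern, mirroring do_aligned_bf16_matmul above; the shapes and the CUDA device are illustrative assumptions, and because the helper is now private, this usage really only belongs inside torchao's own benchmarks and tests:

import torch

from torchao.float8.float8_utils import _pad_tensor_for_matmul

# A is M x K, B is K x N; pad the shared K dimension of each operand
# (dim 1 of A, dim 0 of B), as the benchmark's aligned-matmul path does.
A = torch.randn(1024, 1000, device="cuda", dtype=torch.bfloat16)
B = torch.randn(1000, 2048, device="cuda", dtype=torch.bfloat16)

A_pad = _pad_tensor_for_matmul(A, dims=1)  # mem copy
B_pad = _pad_tensor_for_matmul(B, dims=0)  # mem copy

C = torch.matmul(A_pad, B_pad)  # C is M x N; only the inner dim was padded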

test/float8/test_base.py (34 additions, 34 deletions)

@@ -37,17 +37,17 @@
 from torchao.float8.float8_linear_utils import (
     convert_to_float8_training,
 )
-from torchao.float8.float8_ops import addmm_float8_unwrapped
+from torchao.float8.float8_ops import _addmm_float8_unwrapped
 from torchao.float8.float8_scaling_utils import (
-    get_maybe_axiswise_dim,
-    hp_tensor_to_float8_dynamic,
+    _get_maybe_axiswise_dim,
+    _hp_tensor_to_float8_dynamic,
 )
 from torchao.float8.float8_tensor import (
     Float8Tensor,
     GemmInputRole,
     LinearMMConfig,
     ScaledMMConfig,
-    hp_tensor_and_scale_to_float8,
+    _hp_tensor_and_scale_to_float8,
 )
 from torchao.float8.float8_utils import (
     FP8_TYPES,
@@ -76,7 +76,7 @@ def test_preserves_dtype(self) -> None:
         for hp_dtype, lp_dtype in itertools.product(hp_dtypes, lp_dtypes):
             x1_hp = torch.randn(4, 4, dtype=hp_dtype)
             x1_s = tensor_to_scale(x1_hp, lp_dtype)
-            x2_lp = hp_tensor_and_scale_to_float8(x1_hp, x1_s, lp_dtype)
+            x2_lp = _hp_tensor_and_scale_to_float8(x1_hp, x1_s, lp_dtype)
             x3_hp = x2_lp.to_original_precision()
             assert x3_hp.dtype == hp_dtype

@@ -86,7 +86,7 @@ def test_differentiable_casts(self) -> None:
             x = torch.randn(1).requires_grad_()
             grad = torch.randn(1)
             x_s = tensor_to_scale(x, f8_dtype)
-            x_f8 = hp_tensor_and_scale_to_float8(x, x_s, f8_dtype)
+            x_f8 = _hp_tensor_and_scale_to_float8(x, x_s, f8_dtype)
             x_f8_hp = x_f8.to_original_precision()
             x_f8_hp.backward(grad)
             # the gradient should be unchanged through both casts
@@ -95,7 +95,7 @@ def test_differentiable_casts(self) -> None:
     def test_split_cat(self):
         a = torch.rand(16, 16, dtype=torch.bfloat16)
         scale = tensor_to_scale(a, e4m3_dtype)
-        fp8_a = hp_tensor_and_scale_to_float8(a, scale, e4m3_dtype)
+        fp8_a = _hp_tensor_and_scale_to_float8(a, scale, e4m3_dtype)

         splits = torch.split(fp8_a, 16)
         catted = torch.cat(splits, dim=0)
@@ -104,14 +104,14 @@ def test_split_cat(self):
     def test_index_put(self):
         a = torch.rand(16, dtype=torch.bfloat16)
         scale_a = tensor_to_scale(a, e4m3_dtype)
-        fp8_a = hp_tensor_and_scale_to_float8(a, scale_a, e4m3_dtype)
+        fp8_a = _hp_tensor_and_scale_to_float8(a, scale_a, e4m3_dtype)

         index = torch.randint(0, 15, (16,), dtype=torch.long)

         b = torch.rand(16, 16, dtype=torch.bfloat16)
         scale_b = tensor_to_scale(b, e4m3_dtype)
-        fp8_b = hp_tensor_and_scale_to_float8(b, scale_a, e4m3_dtype)
-        fp8_b_bad = hp_tensor_and_scale_to_float8(b, scale_b, e4m3_dtype)
+        fp8_b = _hp_tensor_and_scale_to_float8(b, scale_a, e4m3_dtype)
+        fp8_b_bad = _hp_tensor_and_scale_to_float8(b, scale_b, e4m3_dtype)

         with pytest.raises(AssertionError):
             b[index] = fp8_a
@@ -122,7 +122,7 @@ def test_index_put(self):
     def test_copy_(self):
         a = torch.rand(16, dtype=torch.bfloat16)
         scale_a = tensor_to_scale(a, e4m3_dtype)
-        fp8_a = hp_tensor_and_scale_to_float8(a, scale_a, e4m3_dtype)
+        fp8_a = _hp_tensor_and_scale_to_float8(a, scale_a, e4m3_dtype)

         b = torch.empty(16, dtype=torch.bfloat16)
         b.copy_(fp8_a)  # Should work
@@ -143,10 +143,10 @@ def test_transpose(self):
         a = torch.rand((16, 16), dtype=torch.bfloat16)
         for axiswise_dim in (None, 0, -1):
             scale_a = tensor_to_scale(a, e4m3_dtype)
-            fp8_a = hp_tensor_and_scale_to_float8(
+            fp8_a = _hp_tensor_and_scale_to_float8(
                 a, scale_a, e4m3_dtype, axiswise_dim=axiswise_dim
             )
-            fp8_b = hp_tensor_and_scale_to_float8(
+            fp8_b = _hp_tensor_and_scale_to_float8(
                 a, scale_a, e4m3_dtype, axiswise_dim=axiswise_dim
             )

@@ -166,7 +166,7 @@ def test_axiswise_dynamic_cast(
     ):
         a = torch.randn(*shape, dtype=torch.bfloat16)
         linear_mm_config = LinearMMConfig()
-        a_fp8 = hp_tensor_to_float8_dynamic(
+        a_fp8 = _hp_tensor_to_float8_dynamic(
             a,
             e4m3_dtype,
             linear_mm_config,
@@ -183,7 +183,7 @@ def test_axiswise_reshape(self):
         linear_mm_config = LinearMMConfig()

         # if we scale across dim0, we can only reshape to [3, -1]
-        a_fp8_d0 = hp_tensor_to_float8_dynamic(
+        a_fp8_d0 = _hp_tensor_to_float8_dynamic(
             a,
             e4m3_dtype,
             linear_mm_config,
@@ -207,7 +207,7 @@ def test_axiswise_reshape(self):
             a_fp8_d0.reshape(-1, 7)

         # if we scale across dim2, we can only reshape to [-1, 7]
-        a_fp8_d2 = hp_tensor_to_float8_dynamic(
+        a_fp8_d2 = _hp_tensor_to_float8_dynamic(
             a,
             e4m3_dtype,
             linear_mm_config,
@@ -247,23 +247,23 @@ def test_axiswise_gemm(self, a_shape, a_granularity, b_granularity):

         linear_mm_config = LinearMMConfig()

-        a_fp8 = hp_tensor_to_float8_dynamic(
+        a_fp8 = _hp_tensor_to_float8_dynamic(
             a,
             e4m3_dtype,
             linear_mm_config,
             gemm_input_role=GemmInputRole.INPUT,
             scaling_granularity=a_granularity,
-            axiswise_dim=get_maybe_axiswise_dim(-1, a_granularity),
+            axiswise_dim=_get_maybe_axiswise_dim(-1, a_granularity),
         )
         a_fp8 = a_fp8.reshape(-1, a_shape[-1])

-        b_fp8 = hp_tensor_to_float8_dynamic(
+        b_fp8 = _hp_tensor_to_float8_dynamic(
             b,
             e4m3_dtype,
             linear_mm_config,
             gemm_input_role=GemmInputRole.WEIGHT,
             scaling_granularity=b_granularity,
-            axiswise_dim=get_maybe_axiswise_dim(-1, b_granularity),
+            axiswise_dim=_get_maybe_axiswise_dim(-1, b_granularity),
         )

         c_fp8_compute = torch.mm(a_fp8, b_fp8.t())
@@ -528,10 +528,10 @@ def test_scaled_mm_vs_emulated(self, base_dtype, use_fast_accum):
         a_scale = tensor_to_scale(a, input_dtype).float()
         b_scale = tensor_to_scale(b, input_dtype).float()

-        a_fp8 = hp_tensor_and_scale_to_float8(a, a_scale, input_dtype)
-        b_fp8 = hp_tensor_and_scale_to_float8(b, b_scale, input_dtype)
+        a_fp8 = _hp_tensor_and_scale_to_float8(a, a_scale, input_dtype)
+        b_fp8 = _hp_tensor_and_scale_to_float8(b, b_scale, input_dtype)

-        out_scaled_mm = addmm_float8_unwrapped(
+        out_scaled_mm = _addmm_float8_unwrapped(
             a_fp8._data,
             a_fp8._scale,
             b_fp8._data,
@@ -569,14 +569,14 @@ def test_different_configs_error(self):
             ScaledMMConfig(True, False, False, False),
             ScaledMMConfig(True, False, False, False),
         )
-        a = hp_tensor_and_scale_to_float8(
+        a = _hp_tensor_and_scale_to_float8(
             x_fp32,
             x_scale,
             fp8_dtype,
             linear_config_a,
             GemmInputRole.INPUT,
         )
-        b = hp_tensor_and_scale_to_float8(
+        b = _hp_tensor_and_scale_to_float8(
             x_fp32,
             x_scale,
             fp8_dtype,
@@ -608,10 +608,10 @@ def test_pad_inner_dim(self, base_dtype, use_fast_accum):
         a_scale = tensor_to_scale(a, input_dtype).float()
         b_scale = tensor_to_scale(b, input_dtype).float()

-        a_fp8 = hp_tensor_and_scale_to_float8(
+        a_fp8 = _hp_tensor_and_scale_to_float8(
             a, a_scale, input_dtype, None, GemmInputRole.INPUT
         )
-        b_fp8 = hp_tensor_and_scale_to_float8(
+        b_fp8 = _hp_tensor_and_scale_to_float8(
             b, b_scale, input_dtype, None, GemmInputRole.WEIGHT
         )

@@ -628,14 +628,14 @@ def test_pad_inner_dim(self, base_dtype, use_fast_accum):
             scaled_mm_config, scaled_mm_config, scaled_mm_config
         )

-        a_fp8 = hp_tensor_and_scale_to_float8(
+        a_fp8 = _hp_tensor_and_scale_to_float8(
             a,
             a_scale,
             input_dtype,
             pad_config,
             GemmInputRole.INPUT,
         )
-        b_fp8 = hp_tensor_and_scale_to_float8(
+        b_fp8 = _hp_tensor_and_scale_to_float8(
             b,
             b_scale,
             input_dtype,
@@ -651,14 +651,14 @@ def test_pad_inner_dim(self, base_dtype, use_fast_accum):
             emulated_scaled_mm_config,
             emulated_scaled_mm_config,
         )
-        a_fp8 = hp_tensor_and_scale_to_float8(
+        a_fp8 = _hp_tensor_and_scale_to_float8(
             a,
             a_scale,
             input_dtype,
             emulated_config,
             GemmInputRole.INPUT,
         )
-        b_fp8 = hp_tensor_and_scale_to_float8(
+        b_fp8 = _hp_tensor_and_scale_to_float8(
             b,
             b_scale,
             input_dtype,
@@ -813,19 +813,19 @@ def test_fp8_tensor_statistics(self):

             # Overflow caused by a too large scaling factor
             s_overflow = torch.tensor(1e9)
-            fp8_overflow = hp_tensor_and_scale_to_float8(x1_hp, s_overflow, lp_dtype)
+            fp8_overflow = _hp_tensor_and_scale_to_float8(x1_hp, s_overflow, lp_dtype)
             (zero_cnt, max_cnt) = fp8_tensor_statistics(fp8_overflow, lp_dtype)
             self.assertEqual((zero_cnt, max_cnt), (0, tensor_len))

             # Underflow caused by a too small scaling factor
             s_underflow = torch.tensor(1e-9)
-            fp8_underflow = hp_tensor_and_scale_to_float8(x1_hp, s_underflow, lp_dtype)
+            fp8_underflow = _hp_tensor_and_scale_to_float8(x1_hp, s_underflow, lp_dtype)
             (zero_cnt, max_cnt) = fp8_tensor_statistics(fp8_underflow, lp_dtype)
             self.assertEqual((zero_cnt, max_cnt), (tensor_len, 0))

             # Both overflow and underflow
             x2_hp = torch.cat((x1_hp * 1e9, x1_hp * 1.0, x1_hp * 1e-9), 0)
-            fp8_over_underflow = hp_tensor_and_scale_to_float8(
+            fp8_over_underflow = _hp_tensor_and_scale_to_float8(
                 x2_hp, torch.tensor(1.0), lp_dtype
             )
             (zero_cnt, max_cnt) = fp8_tensor_statistics(fp8_over_underflow, lp_dtype)
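
The call sites above keep their signatures; only the helper names change. A minimal round-trip sketch in the spirit of test_preserves_dtype, using the private name from this commit; the hand-rolled amax-based scale stands in for the tensor_to_scale helper used throughout this file, and torch.float8_e4m3fn stands in for the file's e4m3_dtype:

import torch

from torchao.float8.float8_tensor import _hp_tensor_and_scale_to_float8

x_hp = torch.randn(16, 16, dtype=torch.bfloat16)

# amax-based scale: largest representable fp8 value over the tensor's max magnitude
scale = torch.finfo(torch.float8_e4m3fn).max / x_hp.abs().max().float()

x_fp8 = _hp_tensor_and_scale_to_float8(x_hp, scale, torch.float8_e4m3fn)
x_roundtrip = x_fp8.to_original_precision()

assert x_roundtrip.dtype == x_hp.dtype  # the cast round-trips the original dtype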

test/float8/test_compile.py (4 additions, 4 deletions)

@@ -34,7 +34,7 @@
 )
 from torchao.float8.float8_linear import Float8Linear
 from torchao.float8.float8_scaling_utils import (
-    hp_tensor_to_float8_dynamic,
+    _hp_tensor_to_float8_dynamic,
 )
 from torchao.float8.float8_tensor import GemmInputRole, LinearMMConfig, ScaledMMConfig
 from torchao.testing.float8.test_utils import get_test_float8_linear_config
@@ -221,7 +221,7 @@ def __init__(self, graph_break: bool):
         self.graph_break = graph_break

     def forward(self, x):
-        x_fp8 = hp_tensor_to_float8_dynamic(
+        x_fp8 = _hp_tensor_to_float8_dynamic(
             x,
             e4m3_dtype,
             LinearMMConfig(),
@@ -373,15 +373,15 @@ def test_dynamic_scale_numeric_parity(
             float8_config.pad_inner_dim,
         ),
     )
-    float8_eager = hp_tensor_to_float8_dynamic(
+    float8_eager = _hp_tensor_to_float8_dynamic(
         hp_tensor1,
         e4m3_dtype,
         linear_mm_config,
         gemm_input_role=GemmInputRole.WEIGHT,
         round_scales_to_power_of_2=float8_config.round_scales_to_power_of_2,
     )
     torch._dynamo.reset()
-    float8_compile = torch.compile(hp_tensor_to_float8_dynamic)(
+    float8_compile = torch.compile(_hp_tensor_to_float8_dynamic)(
         hp_tensor2,
         e4m3_dtype,
         linear_mm_config,
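
test_dynamic_scale_numeric_parity runs the renamed dynamic cast both eagerly and under torch.compile and checks that the two paths agree. A minimal sketch of that comparison with the private helper, assuming a CUDA device, torch.float8_e4m3fn in place of the file's e4m3_dtype, and a plain assert_close on the round-tripped values (the exact comparison used by the test sits outside this hunk):

import torch

from torchao.float8.float8_scaling_utils import _hp_tensor_to_float8_dynamic
from torchao.float8.float8_tensor import GemmInputRole, LinearMMConfig

hp_tensor1 = torch.randn(32, 32, device="cuda", dtype=torch.bfloat16)
hp_tensor2 = hp_tensor1.clone()
linear_mm_config = LinearMMConfig()

# eager dynamic cast to float8
float8_eager = _hp_tensor_to_float8_dynamic(
    hp_tensor1,
    torch.float8_e4m3fn,
    linear_mm_config,
    gemm_input_role=GemmInputRole.WEIGHT,
)

# the same cast traced through torch.compile
torch._dynamo.reset()
float8_compile = torch.compile(_hp_tensor_to_float8_dynamic)(
    hp_tensor2,
    torch.float8_e4m3fn,
    linear_mm_config,
    gemm_input_role=GemmInputRole.WEIGHT,
)

# the eager and compiled casts should produce matching values
torch.testing.assert_close(
    float8_eager.to_original_precision(),
    float8_compile.to_original_precision(),
)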

0 commit comments

Comments
 (0)