
[BE] Make internal torchao.float8 functions private #2321


Open · wants to merge 2 commits into base: main
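In short, this PR renames the internal torchao.float8 helpers touched in the files below by adding a leading underscore, and updates their benchmark and test call sites to match. A minimal sketch of the corresponding import change, covering only the names that appear in these diffs (the full rename set may be larger):

```python
# Before: internal helpers were exposed under public-looking names.
# from torchao.float8.float8_tensor import hp_tensor_and_scale_to_float8
# from torchao.float8.float8_scaling_utils import get_maybe_axiswise_dim, hp_tensor_to_float8_dynamic
# from torchao.float8.float8_ops import addmm_float8_unwrapped
# from torchao.float8.float8_utils import pad_tensor_for_matmul

# After: the same helpers are underscore-prefixed to mark them as private.
from torchao.float8.float8_tensor import _hp_tensor_and_scale_to_float8
from torchao.float8.float8_scaling_utils import (
    _get_maybe_axiswise_dim,
    _hp_tensor_to_float8_dynamic,
)
from torchao.float8.float8_ops import _addmm_float8_unwrapped
from torchao.float8.float8_utils import _pad_tensor_for_matmul
```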
16 changes: 8 additions & 8 deletions benchmarks/float8/bench_padding.py
@@ -16,9 +16,9 @@
GemmInputRole,
LinearMMConfig,
ScaledMMConfig,
- hp_tensor_and_scale_to_float8,
+ _hp_tensor_and_scale_to_float8,
)
- from torchao.float8.float8_utils import pad_tensor_for_matmul
+ from torchao.float8.float8_utils import _pad_tensor_for_matmul

# estimating TOPs for matmuls in fp32, fp16, fp8
# assuming A * B = C, with A being M * K, B being K * N, C being M * N
@@ -63,14 +63,14 @@ def do_fp8_matmul(A, B, fp8_dtype, out_dtype):
a_config = LinearMMConfig(a_config, a_config, a_config)
b_config = LinearMMConfig(b_config, b_config, b_config)

- a_fp8 = hp_tensor_and_scale_to_float8(
+ a_fp8 = _hp_tensor_and_scale_to_float8(
A,
scale_a,
fp8_dtype,
a_config,
GemmInputRole.INPUT,
)
- b_fp8 = hp_tensor_and_scale_to_float8(
+ b_fp8 = _hp_tensor_and_scale_to_float8(
B,
scale_b,
fp8_dtype,
@@ -84,8 +84,8 @@ def do_fp8_matmul(A, B, fp8_dtype, out_dtype):
def do_fp8_pad_first_matmul(A, B, fp8_dtype, out_dtype):
# Breaks with compile due to trying to pad on fp8 dtype
# return do_fp8_matmul(A, B, fp8_dtype, out_dtype)
- A_pad = pad_tensor_for_matmul(A, dims=1) # mem copy
- B_pad = pad_tensor_for_matmul(B, dims=0) # mem copy
+ A_pad = _pad_tensor_for_matmul(A, dims=1) # mem copy
+ B_pad = _pad_tensor_for_matmul(B, dims=0) # mem copy

scale_a = torch.tensor([1], device="cuda", dtype=torch.float32)
scale_b = torch.tensor([1], device="cuda", dtype=torch.float32)
@@ -105,8 +105,8 @@ def do_hp_matmul(A, B):


def do_aligned_bf16_matmul(A, B):
- A_pad = pad_tensor_for_matmul(A, dims=1)
- B_pad = pad_tensor_for_matmul(B, dims=0)
+ A_pad = _pad_tensor_for_matmul(A, dims=1)
+ B_pad = _pad_tensor_for_matmul(B, dims=0)
return torch.matmul(A_pad, B_pad)


68 changes: 34 additions & 34 deletions test/float8/test_base.py
@@ -37,17 +37,17 @@
from torchao.float8.float8_linear_utils import (
convert_to_float8_training,
)
- from torchao.float8.float8_ops import addmm_float8_unwrapped
+ from torchao.float8.float8_ops import _addmm_float8_unwrapped
from torchao.float8.float8_scaling_utils import (
- get_maybe_axiswise_dim,
- hp_tensor_to_float8_dynamic,
+ _get_maybe_axiswise_dim,
+ _hp_tensor_to_float8_dynamic,
)
from torchao.float8.float8_tensor import (
Float8Tensor,
GemmInputRole,
LinearMMConfig,
ScaledMMConfig,
- hp_tensor_and_scale_to_float8,
+ _hp_tensor_and_scale_to_float8,
)
from torchao.float8.float8_utils import (
FP8_TYPES,
@@ -76,7 +76,7 @@ def test_preserves_dtype(self) -> None:
for hp_dtype, lp_dtype in itertools.product(hp_dtypes, lp_dtypes):
x1_hp = torch.randn(4, 4, dtype=hp_dtype)
x1_s = tensor_to_scale(x1_hp, lp_dtype)
- x2_lp = hp_tensor_and_scale_to_float8(x1_hp, x1_s, lp_dtype)
+ x2_lp = _hp_tensor_and_scale_to_float8(x1_hp, x1_s, lp_dtype)
x3_hp = x2_lp.to_original_precision()
assert x3_hp.dtype == hp_dtype

@@ -86,7 +86,7 @@ def test_differentiable_casts(self) -> None:
x = torch.randn(1).requires_grad_()
grad = torch.randn(1)
x_s = tensor_to_scale(x, f8_dtype)
- x_f8 = hp_tensor_and_scale_to_float8(x, x_s, f8_dtype)
+ x_f8 = _hp_tensor_and_scale_to_float8(x, x_s, f8_dtype)
x_f8_hp = x_f8.to_original_precision()
x_f8_hp.backward(grad)
# the gradient should be unchanged through both casts
@@ -95,7 +95,7 @@ def test_differentiable_casts(self) -> None:
def test_split_cat(self):
a = torch.rand(16, 16, dtype=torch.bfloat16)
scale = tensor_to_scale(a, e4m3_dtype)
- fp8_a = hp_tensor_and_scale_to_float8(a, scale, e4m3_dtype)
+ fp8_a = _hp_tensor_and_scale_to_float8(a, scale, e4m3_dtype)

splits = torch.split(fp8_a, 16)
catted = torch.cat(splits, dim=0)
@@ -104,14 +104,14 @@ def test_split_cat(self):
def test_index_put(self):
a = torch.rand(16, dtype=torch.bfloat16)
scale_a = tensor_to_scale(a, e4m3_dtype)
- fp8_a = hp_tensor_and_scale_to_float8(a, scale_a, e4m3_dtype)
+ fp8_a = _hp_tensor_and_scale_to_float8(a, scale_a, e4m3_dtype)

index = torch.randint(0, 15, (16,), dtype=torch.long)

b = torch.rand(16, 16, dtype=torch.bfloat16)
scale_b = tensor_to_scale(b, e4m3_dtype)
- fp8_b = hp_tensor_and_scale_to_float8(b, scale_a, e4m3_dtype)
- fp8_b_bad = hp_tensor_and_scale_to_float8(b, scale_b, e4m3_dtype)
+ fp8_b = _hp_tensor_and_scale_to_float8(b, scale_a, e4m3_dtype)
+ fp8_b_bad = _hp_tensor_and_scale_to_float8(b, scale_b, e4m3_dtype)

with pytest.raises(AssertionError):
b[index] = fp8_a
@@ -122,7 +122,7 @@ def test_index_put(self):
def test_copy_(self):
a = torch.rand(16, dtype=torch.bfloat16)
scale_a = tensor_to_scale(a, e4m3_dtype)
- fp8_a = hp_tensor_and_scale_to_float8(a, scale_a, e4m3_dtype)
+ fp8_a = _hp_tensor_and_scale_to_float8(a, scale_a, e4m3_dtype)

b = torch.empty(16, dtype=torch.bfloat16)
b.copy_(fp8_a) # Should work
@@ -143,10 +143,10 @@ def test_transpose(self):
a = torch.rand((16, 16), dtype=torch.bfloat16)
for axiswise_dim in (None, 0, -1):
scale_a = tensor_to_scale(a, e4m3_dtype)
- fp8_a = hp_tensor_and_scale_to_float8(
+ fp8_a = _hp_tensor_and_scale_to_float8(
a, scale_a, e4m3_dtype, axiswise_dim=axiswise_dim
)
- fp8_b = hp_tensor_and_scale_to_float8(
+ fp8_b = _hp_tensor_and_scale_to_float8(
a, scale_a, e4m3_dtype, axiswise_dim=axiswise_dim
)

@@ -166,7 +166,7 @@ def test_axiswise_dynamic_cast(
):
a = torch.randn(*shape, dtype=torch.bfloat16)
linear_mm_config = LinearMMConfig()
- a_fp8 = hp_tensor_to_float8_dynamic(
+ a_fp8 = _hp_tensor_to_float8_dynamic(
a,
e4m3_dtype,
linear_mm_config,
@@ -183,7 +183,7 @@ def test_axiswise_reshape(self):
linear_mm_config = LinearMMConfig()

# if we scale across dim0, we can only reshape to [3, -1]
- a_fp8_d0 = hp_tensor_to_float8_dynamic(
+ a_fp8_d0 = _hp_tensor_to_float8_dynamic(
a,
e4m3_dtype,
linear_mm_config,
@@ -207,7 +207,7 @@ def test_axiswise_reshape(self):
a_fp8_d0.reshape(-1, 7)

# if we scale across dim2, we can only reshape to [-1, 7]
- a_fp8_d2 = hp_tensor_to_float8_dynamic(
+ a_fp8_d2 = _hp_tensor_to_float8_dynamic(
a,
e4m3_dtype,
linear_mm_config,
@@ -247,23 +247,23 @@ def test_axiswise_gemm(self, a_shape, a_granularity, b_granularity):

linear_mm_config = LinearMMConfig()

- a_fp8 = hp_tensor_to_float8_dynamic(
+ a_fp8 = _hp_tensor_to_float8_dynamic(
a,
e4m3_dtype,
linear_mm_config,
gemm_input_role=GemmInputRole.INPUT,
scaling_granularity=a_granularity,
- axiswise_dim=get_maybe_axiswise_dim(-1, a_granularity),
+ axiswise_dim=_get_maybe_axiswise_dim(-1, a_granularity),
)
a_fp8 = a_fp8.reshape(-1, a_shape[-1])

- b_fp8 = hp_tensor_to_float8_dynamic(
+ b_fp8 = _hp_tensor_to_float8_dynamic(
b,
e4m3_dtype,
linear_mm_config,
gemm_input_role=GemmInputRole.WEIGHT,
scaling_granularity=b_granularity,
- axiswise_dim=get_maybe_axiswise_dim(-1, b_granularity),
+ axiswise_dim=_get_maybe_axiswise_dim(-1, b_granularity),
)

c_fp8_compute = torch.mm(a_fp8, b_fp8.t())
@@ -528,10 +528,10 @@ def test_scaled_mm_vs_emulated(self, base_dtype, use_fast_accum):
a_scale = tensor_to_scale(a, input_dtype).float()
b_scale = tensor_to_scale(b, input_dtype).float()

- a_fp8 = hp_tensor_and_scale_to_float8(a, a_scale, input_dtype)
- b_fp8 = hp_tensor_and_scale_to_float8(b, b_scale, input_dtype)
+ a_fp8 = _hp_tensor_and_scale_to_float8(a, a_scale, input_dtype)
+ b_fp8 = _hp_tensor_and_scale_to_float8(b, b_scale, input_dtype)

- out_scaled_mm = addmm_float8_unwrapped(
+ out_scaled_mm = _addmm_float8_unwrapped(
a_fp8._data,
a_fp8._scale,
b_fp8._data,
@@ -569,14 +569,14 @@ def test_different_configs_error(self):
ScaledMMConfig(True, False, False, False),
ScaledMMConfig(True, False, False, False),
)
- a = hp_tensor_and_scale_to_float8(
+ a = _hp_tensor_and_scale_to_float8(
x_fp32,
x_scale,
fp8_dtype,
linear_config_a,
GemmInputRole.INPUT,
)
- b = hp_tensor_and_scale_to_float8(
+ b = _hp_tensor_and_scale_to_float8(
x_fp32,
x_scale,
fp8_dtype,
@@ -608,10 +608,10 @@ def test_pad_inner_dim(self, base_dtype, use_fast_accum):
a_scale = tensor_to_scale(a, input_dtype).float()
b_scale = tensor_to_scale(b, input_dtype).float()

- a_fp8 = hp_tensor_and_scale_to_float8(
+ a_fp8 = _hp_tensor_and_scale_to_float8(
a, a_scale, input_dtype, None, GemmInputRole.INPUT
)
- b_fp8 = hp_tensor_and_scale_to_float8(
+ b_fp8 = _hp_tensor_and_scale_to_float8(
b, b_scale, input_dtype, None, GemmInputRole.WEIGHT
)

@@ -628,14 +628,14 @@ def test_pad_inner_dim(self, base_dtype, use_fast_accum):
scaled_mm_config, scaled_mm_config, scaled_mm_config
)

- a_fp8 = hp_tensor_and_scale_to_float8(
+ a_fp8 = _hp_tensor_and_scale_to_float8(
a,
a_scale,
input_dtype,
pad_config,
GemmInputRole.INPUT,
)
- b_fp8 = hp_tensor_and_scale_to_float8(
+ b_fp8 = _hp_tensor_and_scale_to_float8(
b,
b_scale,
input_dtype,
@@ -651,14 +651,14 @@ def test_pad_inner_dim(self, base_dtype, use_fast_accum):
emulated_scaled_mm_config,
emulated_scaled_mm_config,
)
- a_fp8 = hp_tensor_and_scale_to_float8(
+ a_fp8 = _hp_tensor_and_scale_to_float8(
a,
a_scale,
input_dtype,
emulated_config,
GemmInputRole.INPUT,
)
- b_fp8 = hp_tensor_and_scale_to_float8(
+ b_fp8 = _hp_tensor_and_scale_to_float8(
b,
b_scale,
input_dtype,
@@ -813,19 +813,19 @@ def test_fp8_tensor_statistics(self):

# Overflow caused by a too large scaling factor
s_overflow = torch.tensor(1e9)
- fp8_overflow = hp_tensor_and_scale_to_float8(x1_hp, s_overflow, lp_dtype)
+ fp8_overflow = _hp_tensor_and_scale_to_float8(x1_hp, s_overflow, lp_dtype)
(zero_cnt, max_cnt) = fp8_tensor_statistics(fp8_overflow, lp_dtype)
self.assertEqual((zero_cnt, max_cnt), (0, tensor_len))

# Underflow caused by a too small scaling factor
s_underflow = torch.tensor(1e-9)
- fp8_underflow = hp_tensor_and_scale_to_float8(x1_hp, s_underflow, lp_dtype)
+ fp8_underflow = _hp_tensor_and_scale_to_float8(x1_hp, s_underflow, lp_dtype)
(zero_cnt, max_cnt) = fp8_tensor_statistics(fp8_underflow, lp_dtype)
self.assertEqual((zero_cnt, max_cnt), (tensor_len, 0))

# Both overflow and underflow
x2_hp = torch.cat((x1_hp * 1e9, x1_hp * 1.0, x1_hp * 1e-9), 0)
- fp8_over_underflow = hp_tensor_and_scale_to_float8(
+ fp8_over_underflow = _hp_tensor_and_scale_to_float8(
x2_hp, torch.tensor(1.0), lp_dtype
)
(zero_cnt, max_cnt) = fp8_tensor_statistics(fp8_over_underflow, lp_dtype)
8 changes: 4 additions & 4 deletions test/float8/test_compile.py
@@ -34,7 +34,7 @@
)
from torchao.float8.float8_linear import Float8Linear
from torchao.float8.float8_scaling_utils import (
- hp_tensor_to_float8_dynamic,
+ _hp_tensor_to_float8_dynamic,
)
from torchao.float8.float8_tensor import GemmInputRole, LinearMMConfig, ScaledMMConfig
from torchao.testing.float8.test_utils import get_test_float8_linear_config
@@ -221,7 +221,7 @@ def __init__(self, graph_break: bool):
self.graph_break = graph_break

def forward(self, x):
- x_fp8 = hp_tensor_to_float8_dynamic(
+ x_fp8 = _hp_tensor_to_float8_dynamic(
x,
e4m3_dtype,
LinearMMConfig(),
@@ -373,15 +373,15 @@ def test_dynamic_scale_numeric_parity(
float8_config.pad_inner_dim,
),
)
- float8_eager = hp_tensor_to_float8_dynamic(
+ float8_eager = _hp_tensor_to_float8_dynamic(
hp_tensor1,
e4m3_dtype,
linear_mm_config,
gemm_input_role=GemmInputRole.WEIGHT,
round_scales_to_power_of_2=float8_config.round_scales_to_power_of_2,
)
torch._dynamo.reset()
- float8_compile = torch.compile(hp_tensor_to_float8_dynamic)(
+ float8_compile = torch.compile(_hp_tensor_to_float8_dynamic)(
hp_tensor2,
e4m3_dtype,
linear_mm_config,