From 2d812af1bb6a91dfffffb19a2fd400c23e20bc9e Mon Sep 17 00:00:00 2001 From: Daniel Vega-Myhre Date: Thu, 5 Jun 2025 11:31:35 -0700 Subject: [PATCH 1/2] make internal torchao.float8 functions private --- benchmarks/float8/bench_padding.py | 16 ++--- test/float8/test_base.py | 68 +++++++++---------- test/float8/test_compile.py | 8 +-- test/float8/test_dtensor.py | 18 ++--- test/float8/test_fsdp2/test_fsdp2.py | 4 +- torchao/dtypes/floatx/float8_layout.py | 4 +- torchao/float8/distributed_utils.py | 6 +- torchao/float8/float8_linear.py | 54 +++++++-------- torchao/float8/float8_linear_utils.py | 6 +- torchao/float8/float8_ops.py | 62 ++++++++--------- torchao/float8/float8_scaling_utils.py | 16 ++--- torchao/float8/float8_tensor.py | 4 +- torchao/float8/float8_tensor_parallel.py | 20 +++--- torchao/float8/float8_utils.py | 16 +++-- torchao/float8/fsdp_utils.py | 8 +-- torchao/float8/inference.py | 12 ++-- .../float8nocompile_linear_test.py | 4 +- .../kernels/fp8_dynamic_tensorwise_test.py | 58 ++++++++-------- 18 files changed, 196 insertions(+), 188 deletions(-) diff --git a/benchmarks/float8/bench_padding.py b/benchmarks/float8/bench_padding.py index eed8a5b542..73b65a4b7e 100644 --- a/benchmarks/float8/bench_padding.py +++ b/benchmarks/float8/bench_padding.py @@ -16,9 +16,9 @@ GemmInputRole, LinearMMConfig, ScaledMMConfig, - hp_tensor_and_scale_to_float8, + _hp_tensor_and_scale_to_float8, ) -from torchao.float8.float8_utils import pad_tensor_for_matmul +from torchao.float8.float8_utils import _pad_tensor_for_matmul # estimating TOPs for matmuls in fp32, fp16, fp8 # assuming A * B = C, with A being M * K, B being K * N, C being M * N @@ -63,14 +63,14 @@ def do_fp8_matmul(A, B, fp8_dtype, out_dtype): a_config = LinearMMConfig(a_config, a_config, a_config) b_config = LinearMMConfig(b_config, b_config, b_config) - a_fp8 = hp_tensor_and_scale_to_float8( + a_fp8 = _hp_tensor_and_scale_to_float8( A, scale_a, fp8_dtype, a_config, GemmInputRole.INPUT, ) - b_fp8 = hp_tensor_and_scale_to_float8( + b_fp8 = _hp_tensor_and_scale_to_float8( B, scale_b, fp8_dtype, @@ -84,8 +84,8 @@ def do_fp8_matmul(A, B, fp8_dtype, out_dtype): def do_fp8_pad_first_matmul(A, B, fp8_dtype, out_dtype): # Breaks with compile due to trying to pad on fp8 dtype # return do_fp8_matmul(A, B, fp8_dtype, out_dtype) - A_pad = pad_tensor_for_matmul(A, dims=1) # mem copy - B_pad = pad_tensor_for_matmul(B, dims=0) # mem copy + A_pad = _pad_tensor_for_matmul(A, dims=1) # mem copy + B_pad = _pad_tensor_for_matmul(B, dims=0) # mem copy scale_a = torch.tensor([1], device="cuda", dtype=torch.float32) scale_b = torch.tensor([1], device="cuda", dtype=torch.float32) @@ -105,8 +105,8 @@ def do_hp_matmul(A, B): def do_aligned_bf16_matmul(A, B): - A_pad = pad_tensor_for_matmul(A, dims=1) - B_pad = pad_tensor_for_matmul(B, dims=0) + A_pad = _pad_tensor_for_matmul(A, dims=1) + B_pad = _pad_tensor_for_matmul(B, dims=0) return torch.matmul(A_pad, B_pad) diff --git a/test/float8/test_base.py b/test/float8/test_base.py index 8e3efeab60..e37c80c51c 100644 --- a/test/float8/test_base.py +++ b/test/float8/test_base.py @@ -37,17 +37,17 @@ from torchao.float8.float8_linear_utils import ( convert_to_float8_training, ) -from torchao.float8.float8_ops import addmm_float8_unwrapped +from torchao.float8.float8_ops import _addmm_float8_unwrapped from torchao.float8.float8_scaling_utils import ( - get_maybe_axiswise_dim, - hp_tensor_to_float8_dynamic, + _get_maybe_axiswise_dim, + _hp_tensor_to_float8_dynamic, ) from torchao.float8.float8_tensor import ( 
Float8Tensor, GemmInputRole, LinearMMConfig, ScaledMMConfig, - hp_tensor_and_scale_to_float8, + _hp_tensor_and_scale_to_float8, ) from torchao.float8.float8_utils import ( FP8_TYPES, @@ -76,7 +76,7 @@ def test_preserves_dtype(self) -> None: for hp_dtype, lp_dtype in itertools.product(hp_dtypes, lp_dtypes): x1_hp = torch.randn(4, 4, dtype=hp_dtype) x1_s = tensor_to_scale(x1_hp, lp_dtype) - x2_lp = hp_tensor_and_scale_to_float8(x1_hp, x1_s, lp_dtype) + x2_lp = _hp_tensor_and_scale_to_float8(x1_hp, x1_s, lp_dtype) x3_hp = x2_lp.to_original_precision() assert x3_hp.dtype == hp_dtype @@ -86,7 +86,7 @@ def test_differentiable_casts(self) -> None: x = torch.randn(1).requires_grad_() grad = torch.randn(1) x_s = tensor_to_scale(x, f8_dtype) - x_f8 = hp_tensor_and_scale_to_float8(x, x_s, f8_dtype) + x_f8 = _hp_tensor_and_scale_to_float8(x, x_s, f8_dtype) x_f8_hp = x_f8.to_original_precision() x_f8_hp.backward(grad) # the gradient should be unchanged through both casts @@ -95,7 +95,7 @@ def test_differentiable_casts(self) -> None: def test_split_cat(self): a = torch.rand(16, 16, dtype=torch.bfloat16) scale = tensor_to_scale(a, e4m3_dtype) - fp8_a = hp_tensor_and_scale_to_float8(a, scale, e4m3_dtype) + fp8_a = _hp_tensor_and_scale_to_float8(a, scale, e4m3_dtype) splits = torch.split(fp8_a, 16) catted = torch.cat(splits, dim=0) @@ -104,14 +104,14 @@ def test_split_cat(self): def test_index_put(self): a = torch.rand(16, dtype=torch.bfloat16) scale_a = tensor_to_scale(a, e4m3_dtype) - fp8_a = hp_tensor_and_scale_to_float8(a, scale_a, e4m3_dtype) + fp8_a = _hp_tensor_and_scale_to_float8(a, scale_a, e4m3_dtype) index = torch.randint(0, 15, (16,), dtype=torch.long) b = torch.rand(16, 16, dtype=torch.bfloat16) scale_b = tensor_to_scale(b, e4m3_dtype) - fp8_b = hp_tensor_and_scale_to_float8(b, scale_a, e4m3_dtype) - fp8_b_bad = hp_tensor_and_scale_to_float8(b, scale_b, e4m3_dtype) + fp8_b = _hp_tensor_and_scale_to_float8(b, scale_a, e4m3_dtype) + fp8_b_bad = _hp_tensor_and_scale_to_float8(b, scale_b, e4m3_dtype) with pytest.raises(AssertionError): b[index] = fp8_a @@ -122,7 +122,7 @@ def test_index_put(self): def test_copy_(self): a = torch.rand(16, dtype=torch.bfloat16) scale_a = tensor_to_scale(a, e4m3_dtype) - fp8_a = hp_tensor_and_scale_to_float8(a, scale_a, e4m3_dtype) + fp8_a = _hp_tensor_and_scale_to_float8(a, scale_a, e4m3_dtype) b = torch.empty(16, dtype=torch.bfloat16) b.copy_(fp8_a) # Should work @@ -143,10 +143,10 @@ def test_transpose(self): a = torch.rand((16, 16), dtype=torch.bfloat16) for axiswise_dim in (None, 0, -1): scale_a = tensor_to_scale(a, e4m3_dtype) - fp8_a = hp_tensor_and_scale_to_float8( + fp8_a = _hp_tensor_and_scale_to_float8( a, scale_a, e4m3_dtype, axiswise_dim=axiswise_dim ) - fp8_b = hp_tensor_and_scale_to_float8( + fp8_b = _hp_tensor_and_scale_to_float8( a, scale_a, e4m3_dtype, axiswise_dim=axiswise_dim ) @@ -166,7 +166,7 @@ def test_axiswise_dynamic_cast( ): a = torch.randn(*shape, dtype=torch.bfloat16) linear_mm_config = LinearMMConfig() - a_fp8 = hp_tensor_to_float8_dynamic( + a_fp8 = _hp_tensor_to_float8_dynamic( a, e4m3_dtype, linear_mm_config, @@ -183,7 +183,7 @@ def test_axiswise_reshape(self): linear_mm_config = LinearMMConfig() # if we scale across dim0, we can only reshape to [3, -1] - a_fp8_d0 = hp_tensor_to_float8_dynamic( + a_fp8_d0 = _hp_tensor_to_float8_dynamic( a, e4m3_dtype, linear_mm_config, @@ -207,7 +207,7 @@ def test_axiswise_reshape(self): a_fp8_d0.reshape(-1, 7) # if we scale across dim2, we can only reshape to [-1, 7] - a_fp8_d2 = 
hp_tensor_to_float8_dynamic( + a_fp8_d2 = _hp_tensor_to_float8_dynamic( a, e4m3_dtype, linear_mm_config, @@ -247,23 +247,23 @@ def test_axiswise_gemm(self, a_shape, a_granularity, b_granularity): linear_mm_config = LinearMMConfig() - a_fp8 = hp_tensor_to_float8_dynamic( + a_fp8 = _hp_tensor_to_float8_dynamic( a, e4m3_dtype, linear_mm_config, gemm_input_role=GemmInputRole.INPUT, scaling_granularity=a_granularity, - axiswise_dim=get_maybe_axiswise_dim(-1, a_granularity), + axiswise_dim=_get_maybe_axiswise_dim(-1, a_granularity), ) a_fp8 = a_fp8.reshape(-1, a_shape[-1]) - b_fp8 = hp_tensor_to_float8_dynamic( + b_fp8 = _hp_tensor_to_float8_dynamic( b, e4m3_dtype, linear_mm_config, gemm_input_role=GemmInputRole.WEIGHT, scaling_granularity=b_granularity, - axiswise_dim=get_maybe_axiswise_dim(-1, b_granularity), + axiswise_dim=_get_maybe_axiswise_dim(-1, b_granularity), ) c_fp8_compute = torch.mm(a_fp8, b_fp8.t()) @@ -528,10 +528,10 @@ def test_scaled_mm_vs_emulated(self, base_dtype, use_fast_accum): a_scale = tensor_to_scale(a, input_dtype).float() b_scale = tensor_to_scale(b, input_dtype).float() - a_fp8 = hp_tensor_and_scale_to_float8(a, a_scale, input_dtype) - b_fp8 = hp_tensor_and_scale_to_float8(b, b_scale, input_dtype) + a_fp8 = _hp_tensor_and_scale_to_float8(a, a_scale, input_dtype) + b_fp8 = _hp_tensor_and_scale_to_float8(b, b_scale, input_dtype) - out_scaled_mm = addmm_float8_unwrapped( + out_scaled_mm = _addmm_float8_unwrapped( a_fp8._data, a_fp8._scale, b_fp8._data, @@ -569,14 +569,14 @@ def test_different_configs_error(self): ScaledMMConfig(True, False, False, False), ScaledMMConfig(True, False, False, False), ) - a = hp_tensor_and_scale_to_float8( + a = _hp_tensor_and_scale_to_float8( x_fp32, x_scale, fp8_dtype, linear_config_a, GemmInputRole.INPUT, ) - b = hp_tensor_and_scale_to_float8( + b = _hp_tensor_and_scale_to_float8( x_fp32, x_scale, fp8_dtype, @@ -608,10 +608,10 @@ def test_pad_inner_dim(self, base_dtype, use_fast_accum): a_scale = tensor_to_scale(a, input_dtype).float() b_scale = tensor_to_scale(b, input_dtype).float() - a_fp8 = hp_tensor_and_scale_to_float8( + a_fp8 = _hp_tensor_and_scale_to_float8( a, a_scale, input_dtype, None, GemmInputRole.INPUT ) - b_fp8 = hp_tensor_and_scale_to_float8( + b_fp8 = _hp_tensor_and_scale_to_float8( b, b_scale, input_dtype, None, GemmInputRole.WEIGHT ) @@ -628,14 +628,14 @@ def test_pad_inner_dim(self, base_dtype, use_fast_accum): scaled_mm_config, scaled_mm_config, scaled_mm_config ) - a_fp8 = hp_tensor_and_scale_to_float8( + a_fp8 = _hp_tensor_and_scale_to_float8( a, a_scale, input_dtype, pad_config, GemmInputRole.INPUT, ) - b_fp8 = hp_tensor_and_scale_to_float8( + b_fp8 = _hp_tensor_and_scale_to_float8( b, b_scale, input_dtype, @@ -651,14 +651,14 @@ def test_pad_inner_dim(self, base_dtype, use_fast_accum): emulated_scaled_mm_config, emulated_scaled_mm_config, ) - a_fp8 = hp_tensor_and_scale_to_float8( + a_fp8 = _hp_tensor_and_scale_to_float8( a, a_scale, input_dtype, emulated_config, GemmInputRole.INPUT, ) - b_fp8 = hp_tensor_and_scale_to_float8( + b_fp8 = _hp_tensor_and_scale_to_float8( b, b_scale, input_dtype, @@ -813,19 +813,19 @@ def test_fp8_tensor_statistics(self): # Overflow caused by a too large scaling factor s_overflow = torch.tensor(1e9) - fp8_overflow = hp_tensor_and_scale_to_float8(x1_hp, s_overflow, lp_dtype) + fp8_overflow = _hp_tensor_and_scale_to_float8(x1_hp, s_overflow, lp_dtype) (zero_cnt, max_cnt) = fp8_tensor_statistics(fp8_overflow, lp_dtype) self.assertEqual((zero_cnt, max_cnt), (0, tensor_len)) # Underflow 
caused by a too small scaling factor s_underflow = torch.tensor(1e-9) - fp8_underflow = hp_tensor_and_scale_to_float8(x1_hp, s_underflow, lp_dtype) + fp8_underflow = _hp_tensor_and_scale_to_float8(x1_hp, s_underflow, lp_dtype) (zero_cnt, max_cnt) = fp8_tensor_statistics(fp8_underflow, lp_dtype) self.assertEqual((zero_cnt, max_cnt), (tensor_len, 0)) # Both overflow and underflow x2_hp = torch.cat((x1_hp * 1e9, x1_hp * 1.0, x1_hp * 1e-9), 0) - fp8_over_underflow = hp_tensor_and_scale_to_float8( + fp8_over_underflow = _hp_tensor_and_scale_to_float8( x2_hp, torch.tensor(1.0), lp_dtype ) (zero_cnt, max_cnt) = fp8_tensor_statistics(fp8_over_underflow, lp_dtype) diff --git a/test/float8/test_compile.py b/test/float8/test_compile.py index ac5d1f8d96..56cb297008 100644 --- a/test/float8/test_compile.py +++ b/test/float8/test_compile.py @@ -34,7 +34,7 @@ ) from torchao.float8.float8_linear import Float8Linear from torchao.float8.float8_scaling_utils import ( - hp_tensor_to_float8_dynamic, + _hp_tensor_to_float8_dynamic, ) from torchao.float8.float8_tensor import GemmInputRole, LinearMMConfig, ScaledMMConfig from torchao.testing.float8.test_utils import get_test_float8_linear_config @@ -221,7 +221,7 @@ def __init__(self, graph_break: bool): self.graph_break = graph_break def forward(self, x): - x_fp8 = hp_tensor_to_float8_dynamic( + x_fp8 = _hp_tensor_to_float8_dynamic( x, e4m3_dtype, LinearMMConfig(), @@ -373,7 +373,7 @@ def test_dynamic_scale_numeric_parity( float8_config.pad_inner_dim, ), ) - float8_eager = hp_tensor_to_float8_dynamic( + float8_eager = _hp_tensor_to_float8_dynamic( hp_tensor1, e4m3_dtype, linear_mm_config, @@ -381,7 +381,7 @@ def test_dynamic_scale_numeric_parity( round_scales_to_power_of_2=float8_config.round_scales_to_power_of_2, ) torch._dynamo.reset() - float8_compile = torch.compile(hp_tensor_to_float8_dynamic)( + float8_compile = torch.compile(_hp_tensor_to_float8_dynamic)( hp_tensor2, e4m3_dtype, linear_mm_config, diff --git a/test/float8/test_dtensor.py b/test/float8/test_dtensor.py index 9db046b749..eb97df5317 100644 --- a/test/float8/test_dtensor.py +++ b/test/float8/test_dtensor.py @@ -43,12 +43,12 @@ e4m3_dtype, ) from torchao.float8.float8_linear_utils import convert_to_float8_training -from torchao.float8.float8_scaling_utils import NoopFwToFloat8BwDynamic +from torchao.float8.float8_scaling_utils import _NoopFwToFloat8BwDynamic from torchao.float8.float8_tensor import ( Float8Tensor, GemmInputRole, LinearMMConfig, - hp_tensor_and_scale_to_float8, + _hp_tensor_and_scale_to_float8, ) from torchao.float8.float8_tensor_parallel import ( Float8ColwiseParallel, @@ -92,10 +92,10 @@ def _test_scaled_mm(mesh: DeviceMesh, size=16): x_scale = tensor_to_scale(x_fp32, fp8_dtype).float() y_scale = tensor_to_scale(y_fp32, fp8_dtype).float() - x_fp8 = hp_tensor_and_scale_to_float8( + x_fp8 = _hp_tensor_and_scale_to_float8( x_fp32, x_scale, fp8_dtype, None, GemmInputRole.INPUT ) - y_fp8 = hp_tensor_and_scale_to_float8( + y_fp8 = _hp_tensor_and_scale_to_float8( y_fp32, y_scale, fp8_dtype, None, GemmInputRole.WEIGHT ) @@ -122,7 +122,7 @@ def _test_fp8_redistribute(mesh: DeviceMesh, size=16): x_scale = tensor_to_scale(x_fp32, fp8_dtype).float() - x_fp8 = hp_tensor_and_scale_to_float8(x_fp32, x_scale, fp8_dtype) + x_fp8 = _hp_tensor_and_scale_to_float8(x_fp32, x_scale, fp8_dtype) dist_x_fp8 = DTensor.from_local(x_fp8, mesh, [Shard(0)], run_check=False) out_dist = dist_x_fp8.redistribute(placements=[Replicate()]) @@ -150,7 +150,7 @@ def _test_dtensor_cast_to_fp8(mesh: DeviceMesh, 
size=16): dist_x_scale = tensor_to_scale(dist_x_fp32, fp8_dtype).float() assert isinstance(dist_x_scale, DTensor) - dist_x_fp8 = hp_tensor_and_scale_to_float8(dist_x_fp32, dist_x_scale, fp8_dtype) + dist_x_fp8 = _hp_tensor_and_scale_to_float8(dist_x_fp32, dist_x_scale, fp8_dtype) assert isinstance(dist_x_fp8, DTensor) @@ -169,14 +169,14 @@ def _test_dtensor_fp8_autograd(mesh: DeviceMesh, size=16): dist_weight_scale = tensor_to_scale(dist_wight_fp32, fp8_dtype).float() dist_target = distribute_tensor(target, mesh, [Shard(0)]) - dist_x_fp8 = hp_tensor_and_scale_to_float8( + dist_x_fp8 = _hp_tensor_and_scale_to_float8( dist_x_fp32, dist_x_scale, fp8_dtype, None, GemmInputRole.INPUT, ) - dist_weight_fp8 = hp_tensor_and_scale_to_float8( + dist_weight_fp8 = _hp_tensor_and_scale_to_float8( dist_wight_fp32, dist_weight_scale, fp8_dtype, @@ -185,7 +185,7 @@ def _test_dtensor_fp8_autograd(mesh: DeviceMesh, size=16): ) out = torch.nn.functional.linear(dist_x_fp8, dist_weight_fp8) - out = NoopFwToFloat8BwDynamic.apply(out, LinearMMConfig(), fp8_dtype) + out = _NoopFwToFloat8BwDynamic.apply(out, LinearMMConfig(), fp8_dtype) assert isinstance(out, DTensor), f"Expected DTensor, got {type(out)}" loss = torch.sum(torch.abs(out - dist_target)) loss.backward() diff --git a/test/float8/test_fsdp2/test_fsdp2.py b/test/float8/test_fsdp2/test_fsdp2.py index 6f0cfecf41..1c348d55a3 100644 --- a/test/float8/test_fsdp2/test_fsdp2.py +++ b/test/float8/test_fsdp2/test_fsdp2.py @@ -40,7 +40,7 @@ from torchao.float8.config import CastConfig, Float8LinearConfig, ScalingType from torchao.float8.float8_linear_utils import convert_to_float8_training -from torchao.float8.float8_scaling_utils import hp_tensor_to_float8_dynamic +from torchao.float8.float8_scaling_utils import _hp_tensor_to_float8_dynamic from torchao.float8.float8_tensor import GemmInputRole from torchao.float8.fsdp_utils import WeightWithDynamicFloat8CastTensor from torchao.testing.float8.fsdp2_utils import check_parity_bf16_mp, check_parity_no_mp @@ -337,7 +337,7 @@ def test_amax_allreduce_device_mesh(self): # rank 2 and 4 are doing nothing but waiting for the 1st stage torch.manual_seed(42 + self.rank) hp_tensor = torch.randn(768, 32, device="cuda") - hp_tensor_to_float8_dynamic( + _hp_tensor_to_float8_dynamic( hp_tensor, torch.float8_e4m3fn, Float8LinearConfig( diff --git a/torchao/dtypes/floatx/float8_layout.py b/torchao/dtypes/floatx/float8_layout.py index 799832a5ea..dfdb8f6fd3 100644 --- a/torchao/dtypes/floatx/float8_layout.py +++ b/torchao/dtypes/floatx/float8_layout.py @@ -19,8 +19,8 @@ from torchao.dtypes.utils import AQTTensorImpl, Layout, get_out_shape from torchao.float8.inference import ( Float8MMConfig, + _addmm_float8_unwrapped_inference, _is_rowwise_scaled, - addmm_float8_unwrapped_inference, preprocess_data, ) from torchao.utils import _is_float8_type, fill_defaults @@ -413,7 +413,7 @@ def _linear_fp8_act_fp8_weight_impl( inpt_data, w_data = preprocess_data(inpt_data, w_data.T, scaled_mm_config) # Perform the computation - return addmm_float8_unwrapped_inference( + return _addmm_float8_unwrapped_inference( inpt_data, input_scale, w_data, diff --git a/torchao/float8/distributed_utils.py b/torchao/float8/distributed_utils.py index cd1560fabd..3d6e42bd6b 100644 --- a/torchao/float8/distributed_utils.py +++ b/torchao/float8/distributed_utils.py @@ -11,7 +11,7 @@ from torchao.float8.float8_tensor import Float8Tensor -def tensor_already_casted_to_fp8(tensor: torch.Tensor) -> bool: +def _tensor_already_casted_to_fp8(tensor: torch.Tensor) -> bool: 
""" Check if the tensor is already casted to fp8, works if the local tensor is wrapped in DTensor. @@ -20,8 +20,8 @@ def tensor_already_casted_to_fp8(tensor: torch.Tensor) -> bool: return True elif isinstance(tensor, DTensor): # TODO: shall we stick to public API and directly use tensor.to_local() here? - return tensor_already_casted_to_fp8(tensor._local_tensor) + return _tensor_already_casted_to_fp8(tensor._local_tensor) elif isinstance(tensor, funcol.AsyncCollectiveTensor): - return tensor_already_casted_to_fp8(tensor.elem) + return _tensor_already_casted_to_fp8(tensor.elem) return False diff --git a/torchao/float8/float8_linear.py b/torchao/float8/float8_linear.py index c926ede40f..5e19de3da2 100644 --- a/torchao/float8/float8_linear.py +++ b/torchao/float8/float8_linear.py @@ -13,16 +13,16 @@ import torch.utils.checkpoint as checkpoint from torchao.float8.config import Float8LinearConfig, ScalingGranularity, ScalingType -from torchao.float8.distributed_utils import tensor_already_casted_to_fp8 +from torchao.float8.distributed_utils import _tensor_already_casted_to_fp8 from torchao.float8.float8_scaling_utils import ( - get_maybe_axiswise_dim, - hp_tensor_to_float8_dynamic, + _get_maybe_axiswise_dim, + _hp_tensor_to_float8_dynamic, ) from torchao.float8.float8_tensor import ( GemmInputRole, LinearMMConfig, ScaledMMConfig, - hp_tensor_and_scale_to_float8, + _hp_tensor_and_scale_to_float8, ) from torchao.float8.float8_utils import tensor_to_scale from torchao.float8.fsdp_utils import WeightWithDynamicFloat8CastTensor @@ -33,7 +33,7 @@ def _get_weight_scale( scaling_type_weight: ScalingType, config: Float8LinearConfig, ) -> Optional[torch.Tensor]: - if tensor_already_casted_to_fp8(weight): + if _tensor_already_casted_to_fp8(weight): return None assert scaling_type_weight is ScalingType.DYNAMIC return tensor_to_scale(weight, config.cast_config_weight.target_dtype) @@ -45,9 +45,9 @@ def _cast_weight_to_float8_t( linear_mm_config: LinearMMConfig, weight_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: - if tensor_already_casted_to_fp8(weight): + if _tensor_already_casted_to_fp8(weight): return weight.t() - weight_fp8 = hp_tensor_and_scale_to_float8( + weight_fp8 = _hp_tensor_and_scale_to_float8( weight, weight_scale, config.cast_config_weight.target_dtype, @@ -58,7 +58,7 @@ def _cast_weight_to_float8_t( @torch._dynamo.allow_in_graph -class matmul_with_hp_or_float8_args(torch.autograd.Function): +class _matmul_with_hp_or_float8_args(torch.autograd.Function): """ Like torch.matmul, but with the arguments in either high precision or float8. 
* if the arguments are in high precision, they are cast to float8 according @@ -80,35 +80,35 @@ def forward( c = config - if tensor_already_casted_to_fp8(input_hp): + if _tensor_already_casted_to_fp8(input_hp): input_maybe_fp8 = input_hp elif c.cast_config_input.scaling_type is ScalingType.DISABLED: input_maybe_fp8 = input_hp else: - input_maybe_fp8 = hp_tensor_to_float8_dynamic( + input_maybe_fp8 = _hp_tensor_to_float8_dynamic( input_hp, c.cast_config_input.target_dtype, linear_mm_config, gemm_input_role=GemmInputRole.INPUT, scaling_granularity=c.cast_config_input.scaling_granularity, - axiswise_dim=get_maybe_axiswise_dim( + axiswise_dim=_get_maybe_axiswise_dim( -1, c.cast_config_input.scaling_granularity ), round_scales_to_power_of_2=c.round_scales_to_power_of_2, ) - if tensor_already_casted_to_fp8(weight_hp_t): + if _tensor_already_casted_to_fp8(weight_hp_t): weight_maybe_fp8_t = weight_hp_t elif c.cast_config_weight.scaling_type is ScalingType.DISABLED: weight_maybe_fp8_t = weight_hp_t else: - weight_maybe_fp8_t = hp_tensor_to_float8_dynamic( + weight_maybe_fp8_t = _hp_tensor_to_float8_dynamic( weight_hp_t, c.cast_config_weight.target_dtype, linear_mm_config, gemm_input_role=GemmInputRole.WEIGHT, scaling_granularity=c.cast_config_weight.scaling_granularity, - axiswise_dim=get_maybe_axiswise_dim( + axiswise_dim=_get_maybe_axiswise_dim( 0, c.cast_config_weight.scaling_granularity ), round_scales_to_power_of_2=c.round_scales_to_power_of_2, @@ -136,25 +136,25 @@ def backward(ctx, grad_output): # calculate grad_input # - if tensor_already_casted_to_fp8(grad_output_reshaped): + if _tensor_already_casted_to_fp8(grad_output_reshaped): # TODO(future PR): this var name is axiswise-specific, fix it grad_output_reshaped_maybe_fp8_dim0 = grad_output_reshaped elif c.cast_config_grad_output.scaling_type is ScalingType.DISABLED: grad_output_reshaped_maybe_fp8_dim0 = grad_output_reshaped else: - grad_output_reshaped_maybe_fp8_dim0 = hp_tensor_to_float8_dynamic( + grad_output_reshaped_maybe_fp8_dim0 = _hp_tensor_to_float8_dynamic( grad_output_reshaped, c.cast_config_grad_output.target_dtype, ctx.linear_mm_config, gemm_input_role=GemmInputRole.GRAD_OUTPUT, scaling_granularity=c.cast_config_grad_output.scaling_granularity, - axiswise_dim=get_maybe_axiswise_dim( + axiswise_dim=_get_maybe_axiswise_dim( -1, c.cast_config_grad_output.scaling_granularity ), round_scales_to_power_of_2=c.round_scales_to_power_of_2, ) - if tensor_already_casted_to_fp8(weight_hp_t): + if _tensor_already_casted_to_fp8(weight_hp_t): # TODO(future PR): var name is axiswise specific, fix it weight_t_maybe_fp8_dim0 = weight_hp_t elif c.cast_config_weight_for_grad_input.scaling_type is ScalingType.DISABLED: @@ -175,13 +175,13 @@ def backward(ctx, grad_output): # to be solved to have a chance to reuse max(abs(weight, dim=...)) # from the forward to get max(abs(weight)) here without reading # the entire tensor. 
- weight_t_maybe_fp8_dim0 = hp_tensor_to_float8_dynamic( + weight_t_maybe_fp8_dim0 = _hp_tensor_to_float8_dynamic( weight_hp_t, c.cast_config_weight_for_grad_input.target_dtype, ctx.linear_mm_config, gemm_input_role=GemmInputRole.WEIGHT, scaling_granularity=c.cast_config_weight_for_grad_input.scaling_granularity, - axiswise_dim=get_maybe_axiswise_dim( + axiswise_dim=_get_maybe_axiswise_dim( -1, c.cast_config_weight_for_grad_input.scaling_granularity ), round_scales_to_power_of_2=c.round_scales_to_power_of_2, @@ -202,7 +202,7 @@ def backward(ctx, grad_output): # calculate grad_weight # - if tensor_already_casted_to_fp8(grad_output_reshaped): + if _tensor_already_casted_to_fp8(grad_output_reshaped): # TODO(future PR): var name is axiswise specific, fix it grad_output_reshaped_maybe_fp8_dim1 = grad_output_reshaped elif ( @@ -211,31 +211,31 @@ def backward(ctx, grad_output): ): grad_output_reshaped_maybe_fp8_dim1 = grad_output_reshaped else: - grad_output_reshaped_maybe_fp8_dim1 = hp_tensor_to_float8_dynamic( + grad_output_reshaped_maybe_fp8_dim1 = _hp_tensor_to_float8_dynamic( grad_output_reshaped, c.cast_config_grad_output_for_grad_weight.target_dtype, ctx.linear_mm_config, gemm_input_role=GemmInputRole.GRAD_OUTPUT, scaling_granularity=c.cast_config_grad_output_for_grad_weight.scaling_granularity, - axiswise_dim=get_maybe_axiswise_dim( + axiswise_dim=_get_maybe_axiswise_dim( 0, c.cast_config_grad_output_for_grad_weight.scaling_granularity ), round_scales_to_power_of_2=c.round_scales_to_power_of_2, ) - if tensor_already_casted_to_fp8(input_hp_reshaped): + if _tensor_already_casted_to_fp8(input_hp_reshaped): # TODO(future PR): var name is axiswise specific, fix it input_reshaped_maybe_fp8_dim1 = input_hp_reshaped elif c.cast_config_input_for_grad_weight.scaling_type is ScalingType.DISABLED: input_reshaped_maybe_fp8_dim1 = input_hp_reshaped else: - input_reshaped_maybe_fp8_dim1 = hp_tensor_to_float8_dynamic( + input_reshaped_maybe_fp8_dim1 = _hp_tensor_to_float8_dynamic( input_hp_reshaped, c.cast_config_input_for_grad_weight.target_dtype, ctx.linear_mm_config, gemm_input_role=GemmInputRole.INPUT, scaling_granularity=c.cast_config_input_for_grad_weight.scaling_granularity, - axiswise_dim=get_maybe_axiswise_dim( + axiswise_dim=_get_maybe_axiswise_dim( 0, c.cast_config_input_for_grad_weight.scaling_granularity ), round_scales_to_power_of_2=c.round_scales_to_power_of_2, @@ -349,7 +349,7 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: weight_maybe_fp8_t = weight_fp8_t - output = matmul_with_hp_or_float8_args.apply( + output = _matmul_with_hp_or_float8_args.apply( input, weight_maybe_fp8_t, self.linear_mm_config, diff --git a/torchao/float8/float8_linear_utils.py b/torchao/float8/float8_linear_utils.py index 230bfd881f..4d1d464908 100644 --- a/torchao/float8/float8_linear_utils.py +++ b/torchao/float8/float8_linear_utils.py @@ -15,7 +15,7 @@ log.addHandler(logging.NullHandler()) -def swap_linear_layers( +def _swap_linear_layers( module: nn.Module, from_float_func: Callable[[nn.Linear], nn.Linear], *, @@ -81,6 +81,10 @@ def post_order_traversal( return root_module +# for BC, confirmed there are users of this util function +swap_linear_layers = _swap_linear_layers + + def convert_to_float8_training( module: nn.Module, *, diff --git a/torchao/float8/float8_ops.py b/torchao/float8/float8_ops.py index 4071d83e4f..704d305d40 100644 --- a/torchao/float8/float8_ops.py +++ b/torchao/float8/float8_ops.py @@ -8,8 +8,8 @@ import torch from torch.utils._pytree import tree_map -from 
torchao.float8.float8_tensor import Float8Tensor, choose_scaled_mm_config -from torchao.float8.float8_utils import is_row_major, pad_tensor_for_matmul +from torchao.float8.float8_tensor import Float8Tensor, _choose_scaled_mm_config +from torchao.float8.float8_utils import _is_row_major, _pad_tensor_for_matmul aten = torch.ops.aten c10d_functional = torch.ops.c10d_functional @@ -22,7 +22,7 @@ # Cublas defines scale to always mean a multiplicative factor for the respective matrices # For a,b going from fp8 -> fp32 we multiple by the inverse of the scale # For output going from fp32 -> fp8 we multiply by the scale -def addmm_float8_unwrapped( +def _addmm_float8_unwrapped( a_data: torch.Tensor, a_scale: torch.Tensor, b_data: torch.Tensor, @@ -112,7 +112,7 @@ def decorator(func): aten.reshape.default, ] ) -def float8_desugar_op(aten_op, args, kwargs=None): +def _float8_desugar_op(aten_op, args, kwargs=None): _assert_tensorwise_scale(aten_op, args[0]._scale) new_data = aten_op(args[0]._data, *args[1:], **kwargs) return Float8Tensor( @@ -129,7 +129,7 @@ def float8_desugar_op(aten_op, args, kwargs=None): aten.detach.default, ] ) -def float8_desugar_data_and_scale_op(aten_op, args, kwargs=None): +def _float8_desugar_data_and_scale_op(aten_op, args, kwargs=None): new_data = aten_op(args[0]._data, *args[1:], **kwargs) new_scale = aten_op(args[0]._scale, *args[1:], **kwargs) return Float8Tensor( @@ -147,7 +147,7 @@ def float8_desugar_data_and_scale_op(aten_op, args, kwargs=None): aten.transpose.int, ] ) -def float8_transpose(aten_op, args, kwargs=None): +def _float8_transpose(aten_op, args, kwargs=None): new_data = aten_op(args[0]._data, *args[1:], **kwargs) if args[0]._scale.ndim > 1: new_scale = aten_op(args[0]._scale, *args[1:], **kwargs) @@ -176,7 +176,7 @@ def float8_transpose(aten_op, args, kwargs=None): @implements([aten.view.default]) -def float8_view(aten_op, args, kwargs=None): +def _float8_view(aten_op, args, kwargs=None): t, new_shape = args[0], args[1] # if the new shape is the same as old, return an equivalent tensor @@ -194,7 +194,7 @@ def float8_view(aten_op, args, kwargs=None): if len(args[0]._scale.shape) < 2: # tensorwise scaling - return float8_desugar_op(aten_op, args, kwargs) + return _float8_desugar_op(aten_op, args, kwargs) # for now, only support reshaping to [-1, dim] or [dim, -1] axiswise_dim = t._axiswise_dim @@ -231,7 +231,7 @@ def float8_view(aten_op, args, kwargs=None): @implements([aten.split.Tensor]) -def float8_split(aten_op, args, kwargs=None): +def _float8_split(aten_op, args, kwargs=None): new_data_tensors = aten_op(args[0]._data, *args[1:], **kwargs) _assert_tensorwise_scale(aten_op, args[0]._scale) @@ -250,7 +250,7 @@ def make_float8(data): # Errors cant `cat_cuda float8 e4m3fn` @implements([aten.cat.default]) -def float8_cat(aten_op, args, kwargs=None): +def _float8_cat(aten_op, args, kwargs=None): chunked_tensors: Tuple[Float8Tensor] = args[0] orig_dtype = chunked_tensors[0]._orig_dtype @@ -287,7 +287,7 @@ def float8_cat(aten_op, args, kwargs=None): @implements([aten.sum.dim_IntList]) -def float8_cast_up_op(aten_op, args, kwargs=None): +def _float8_cast_up_op(aten_op, args, kwargs=None): """Be careful with this function, this is a "fallback" op that casts the output of the op to the original precision. And performs the op. 
@@ -307,12 +307,12 @@ def unwrap(x): return aten_op(*new_args, **new_kwargs) -def preprocess_addmm(a: Float8Tensor, b: Float8Tensor): +def _preprocess_addmm(a: Float8Tensor, b: Float8Tensor): a_data = a._data a_scale = a._scale b_data = b._data - scaled_mm_config = choose_scaled_mm_config( + scaled_mm_config = _choose_scaled_mm_config( a._gemm_input_role, a._linear_mm_config, b._gemm_input_role, @@ -323,12 +323,12 @@ def preprocess_addmm(a: Float8Tensor, b: Float8Tensor): assert a._data.size(1) == b._data.size(0), ( f"Inner dims must match for mm, got {a._data.size(1)} and {b._data.size(0)}" ) - a_data = pad_tensor_for_matmul(a_data, dims=1) - b_data = pad_tensor_for_matmul(b_data, dims=0) + a_data = _pad_tensor_for_matmul(a_data, dims=1) + b_data = _pad_tensor_for_matmul(b_data, dims=0) - if not is_row_major(a_data.stride()): + if not _is_row_major(a_data.stride()): a_data = a_data.contiguous() - if is_row_major(b_data.stride()): + if _is_row_major(b_data.stride()): b_data = b_data.t().contiguous().t() b_scale = b._scale @@ -349,7 +349,7 @@ def preprocess_addmm(a: Float8Tensor, b: Float8Tensor): @implements([aten.mm.default, aten.matmul.default]) -def float8_mm(aten_op, args, kwargs=None): +def _float8_mm(aten_op, args, kwargs=None): a = args[0] b = args[1] @@ -358,9 +358,9 @@ def float8_mm(aten_op, args, kwargs=None): type(a), type(b) ) ) - a_data, a_scale, b_data, b_scale = preprocess_addmm(a, b) + a_data, a_scale, b_data, b_scale = _preprocess_addmm(a, b) output_dtype = a._orig_dtype - scaled_mm_config = choose_scaled_mm_config( + scaled_mm_config = _choose_scaled_mm_config( a._gemm_input_role, a._linear_mm_config, b._gemm_input_role, @@ -370,7 +370,7 @@ def float8_mm(aten_op, args, kwargs=None): return torch.mm(a._data.float() / a._scale, b._data.float() / b._scale).to( output_dtype ) - tensor_out = addmm_float8_unwrapped( + tensor_out = _addmm_float8_unwrapped( a_data, a_scale, b_data, @@ -384,7 +384,7 @@ def float8_mm(aten_op, args, kwargs=None): @implements([aten.addmm.default]) -def float8_addmm(aten_op, args, kwargs=None): +def _float8_addmm(aten_op, args, kwargs=None): assert ( isinstance(args[0], torch.Tensor) and isinstance(args[1], Float8Tensor) @@ -393,10 +393,10 @@ def float8_addmm(aten_op, args, kwargs=None): bias = args[0] a = args[1] b = args[2] - a_data, a_scale, b_data, b_scale = preprocess_addmm(a, b) + a_data, a_scale, b_data, b_scale = _preprocess_addmm(a, b) output_dtype = a._orig_dtype assert bias.dtype == output_dtype, "bias dtype must match output dtype" - scaled_mm_config = choose_scaled_mm_config( + scaled_mm_config = _choose_scaled_mm_config( a._gemm_input_role, a._linear_mm_config, b._gemm_input_role, @@ -407,7 +407,7 @@ def float8_addmm(aten_op, args, kwargs=None): output_dtype ) return out + bias - tensor_out = addmm_float8_unwrapped( + tensor_out = _addmm_float8_unwrapped( a_data, a_scale, b_data, @@ -421,13 +421,13 @@ def float8_addmm(aten_op, args, kwargs=None): @implements([aten.is_same_size.default]) -def float8_is_same_size(aten_op, args, kwargs=None): +def _float8_is_same_size(aten_op, args, kwargs=None): _assert_tensorwise_scale(aten_op, args[0]._scale) return args[0].shape == args[1].shape @implements([aten._to_copy.default]) -def autocast_to_copy(aten_op, args, kwargs=None): +def _autocast_to_copy(aten_op, args, kwargs=None): """This gets called when running matmul under autocast when the input is a Float8Tensor, presenting as a fp32 tensor. 
@@ -456,7 +456,7 @@ def autocast_to_copy(aten_op, args, kwargs=None): _c10d_functional.all_gather_into_tensor.default, ] ) -def allgather_fp8(aten_op, args, kwargs=None): +def _allgather_fp8(aten_op, args, kwargs=None): """ override funcol with FP8 handling """ @@ -479,7 +479,7 @@ def allgather_fp8(aten_op, args, kwargs=None): @implements([c10d_functional.wait_tensor.default, _c10d_functional.wait_tensor.default]) -def wait_tensor_fp8(aten_op, args, kwargs=None): +def _wait_tensor_fp8(aten_op, args, kwargs=None): _assert_tensorwise_scale(aten_op, args[0]._scale) fp8_input = args[0] assert isinstance(fp8_input, Float8Tensor) @@ -496,7 +496,7 @@ def wait_tensor_fp8(aten_op, args, kwargs=None): @implements([aten.index_put_.default]) -def index_put_fp8(aten_op, args, kwargs=None): +def _index_put_fp8(aten_op, args, kwargs=None): fp8_self = args[0] fp8_values = args[2] assert isinstance(fp8_self, Float8Tensor) @@ -519,7 +519,7 @@ def index_put_fp8(aten_op, args, kwargs=None): @implements([aten.copy_.default]) -def copy_fp8(aten_op, args, kwargs=None): +def _copy_fp8(aten_op, args, kwargs=None): # For a copy op with Float8Tensors involved, only the following combinations are allowed: # 1. self is a high precision (hp) tensor, src is a Float8Tensor: # in this case src is upcasted and unscaled to go into the hp tensor diff --git a/torchao/float8/float8_scaling_utils.py b/torchao/float8/float8_scaling_utils.py index 31f2db6b4e..77440d5c12 100644 --- a/torchao/float8/float8_scaling_utils.py +++ b/torchao/float8/float8_scaling_utils.py @@ -13,12 +13,12 @@ import torch from torchao.float8.config import ScalingGranularity -from torchao.float8.distributed_utils import tensor_already_casted_to_fp8 +from torchao.float8.distributed_utils import _tensor_already_casted_to_fp8 from torchao.float8.float8_tensor import ( Float8Tensor, GemmInputRole, LinearMMConfig, - hp_tensor_and_scale_to_float8, + _hp_tensor_and_scale_to_float8, ) from torchao.float8.float8_utils import ( tensor_to_scale, @@ -26,7 +26,7 @@ # TODO(danielvegamyhre): refactor to accept Float8LinearConfig directly -def hp_tensor_to_float8_dynamic( +def _hp_tensor_to_float8_dynamic( hp_tensor: torch.Tensor, float8_dtype: torch.dtype, linear_mm_config: LinearMMConfig, @@ -62,7 +62,7 @@ def hp_tensor_to_float8_dynamic( axiswise_dim, round_scales_to_power_of_2, ) - return hp_tensor_and_scale_to_float8( + return _hp_tensor_and_scale_to_float8( hp_tensor, scale, float8_dtype, @@ -72,7 +72,7 @@ def hp_tensor_to_float8_dynamic( ) -def get_maybe_axiswise_dim( +def _get_maybe_axiswise_dim( axiswise_dim: int, scaling_granularity: ScalingGranularity, ) -> Optional[int]: @@ -88,7 +88,7 @@ def get_maybe_axiswise_dim( @torch._dynamo.allow_in_graph -class NoopFwToFloat8BwDynamic(torch.autograd.Function): +class _NoopFwToFloat8BwDynamic(torch.autograd.Function): """ Forward: no-op Backward: convert to float8_e5m2 with dynamic scaling @@ -107,10 +107,10 @@ def forward( @staticmethod def backward(ctx, gradY): - if tensor_already_casted_to_fp8(gradY): + if _tensor_already_casted_to_fp8(gradY): return gradY, None, None gradY_scale = tensor_to_scale(gradY, ctx.target_dtype) - fp8_tensor = hp_tensor_and_scale_to_float8( + fp8_tensor = _hp_tensor_and_scale_to_float8( gradY, gradY_scale, ctx.target_dtype, diff --git a/torchao/float8/float8_tensor.py b/torchao/float8/float8_tensor.py index 6b5177e1fe..194d568df9 100644 --- a/torchao/float8/float8_tensor.py +++ b/torchao/float8/float8_tensor.py @@ -94,7 +94,7 @@ class GemmInputRole(enum.Enum): # choose which 
scaled_mm_config to use based on gemm inputs -def choose_scaled_mm_config( +def _choose_scaled_mm_config( a_role: GemmInputRole, a_linear_mm_config: LinearMMConfig, b_role: GemmInputRole, @@ -209,7 +209,7 @@ def backward(ctx, g): return g, None, None -def hp_tensor_and_scale_to_float8( +def _hp_tensor_and_scale_to_float8( hp_tensor: torch.Tensor, s: torch.Tensor, float8_dtype: torch.dtype, diff --git a/torchao/float8/float8_tensor_parallel.py b/torchao/float8/float8_tensor_parallel.py index 36ae6d587e..40d1a7a6e6 100644 --- a/torchao/float8/float8_tensor_parallel.py +++ b/torchao/float8/float8_tensor_parallel.py @@ -14,10 +14,10 @@ ) from torchao.float8.config import ScalingType, e4m3_dtype -from torchao.float8.distributed_utils import tensor_already_casted_to_fp8 +from torchao.float8.distributed_utils import _tensor_already_casted_to_fp8 from torchao.float8.float8_scaling_utils import ( - NoopFwToFloat8BwDynamic, - hp_tensor_to_float8_dynamic, + _hp_tensor_to_float8_dynamic, + _NoopFwToFloat8BwDynamic, ) from torchao.float8.float8_tensor import GemmInputRole @@ -56,8 +56,8 @@ def _prepare_input_fn( input_tensor, device_mesh, input_layouts, run_check=False ) - if not tensor_already_casted_to_fp8(input_tensor): - input_tensor = hp_tensor_to_float8_dynamic( + if not _tensor_already_casted_to_fp8(input_tensor): + input_tensor = _hp_tensor_to_float8_dynamic( input_tensor, mod.config.cast_config_input.target_dtype, mod.linear_mm_config, @@ -80,7 +80,7 @@ def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_me ) # DTensor(torch.Tensor) # fwd noop bwd cast to DTensor(Float8Tensor) - outputs = NoopFwToFloat8BwDynamic.apply( + outputs = _NoopFwToFloat8BwDynamic.apply( outputs, mod.linear_mm_config, mod.config.cast_config_grad_output.target_dtype, @@ -120,8 +120,8 @@ def _prepare_input_fn( input_tensor, device_mesh, input_layouts, run_check=False ) - if not tensor_already_casted_to_fp8(input_tensor): - input_tensor = hp_tensor_to_float8_dynamic( + if not _tensor_already_casted_to_fp8(input_tensor): + input_tensor = _hp_tensor_to_float8_dynamic( input_tensor, mod.config.cast_config_input.target_dtype, mod.linear_mm_config, @@ -143,7 +143,7 @@ def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_me outputs = outputs.redistribute(placements=output_layouts, async_op=True) # fwd noop bwd cast to DTensor(Float8Tensor) - outputs = NoopFwToFloat8BwDynamic.apply( + outputs = _NoopFwToFloat8BwDynamic.apply( outputs, mod.linear_mm_config, mod.config.cast_config_grad_output.target_dtype, @@ -229,7 +229,7 @@ def _prepare_input_arg(self, input, mesh, input_layout, desired_layout): input, mesh, (input_layout,), run_check=False ) - dt_inp = hp_tensor_to_float8_dynamic( + dt_inp = _hp_tensor_to_float8_dynamic( dt_inp, e4m3_dtype, self.linear_mm_config, diff --git a/torchao/float8/float8_utils.py b/torchao/float8/float8_utils.py index 625fb29235..0663899db0 100644 --- a/torchao/float8/float8_utils.py +++ b/torchao/float8/float8_utils.py @@ -171,8 +171,8 @@ def fp8_tensor_statistics( return (num_zero, num_max) -def is_row_major(stride): - assert len(stride) == 2, "is_row_major only supports 2D tensors" +def _is_row_major(stride): + assert len(stride) == 2, "_is_row_major only supports 2D tensors" return stride[0] > stride[1] and stride[1] == 1 @@ -196,7 +196,7 @@ def _get_min_alignment(size: int, alignment_value: int) -> int: return (1 + ((size - 1) // alignment_value)) * alignment_value -def pad_tensor_for_matmul( +def _pad_tensor_for_matmul( tensor: torch.Tensor, 
dims: Union[int, Iterable[int]] ) -> torch.Tensor: """ @@ -211,11 +211,11 @@ def pad_tensor_for_matmul( Usage: ``` - >>> pad_tensor_for_matmul(torch.randn((10, 10)), dims=0).shape + >>> _pad_tensor_for_matmul(torch.randn((10, 10)), dims=0).shape torch.Size([16, 10]) - >>> pad_tensor_for_matmul(torch.randn((10, 10)), dims=1).shape + >>> _pad_tensor_for_matmul(torch.randn((10, 10)), dims=1).shape torch.Size([10, 16]) - >>> pad_tensor_for_matmul(torch.randn((10, 10)), dims=(0, 1)).shape + >>> _pad_tensor_for_matmul(torch.randn((10, 10)), dims=(0, 1)).shape torch.Size([16, 16]) ``` """ @@ -236,6 +236,10 @@ def pad_tensor_for_matmul( return torch.nn.functional.pad(tensor, (0, pad_dim2, 0, pad_dim1)) +# for BC, confirmed there are users using this util function +pad_tensor_for_matmul = _pad_tensor_for_matmul + + def _round_scale_down_to_power_of_2(scale: torch.Tensor): assert scale.dtype == torch.float32, "scale must be float32 tensor" return torch.exp2(torch.floor(torch.log2(scale))) diff --git a/torchao/float8/fsdp_utils.py b/torchao/float8/fsdp_utils.py index 7b24dc2b53..a63fb2abc3 100644 --- a/torchao/float8/fsdp_utils.py +++ b/torchao/float8/fsdp_utils.py @@ -13,13 +13,13 @@ from torch._prims_common import suggest_memory_format from torchao.float8.float8_scaling_utils import ( - hp_tensor_to_float8_dynamic, + _hp_tensor_to_float8_dynamic, ) from torchao.float8.float8_tensor import ( Float8Tensor, GemmInputRole, LinearMMConfig, - hp_tensor_and_scale_to_float8, + _hp_tensor_and_scale_to_float8, ) from torchao.float8.float8_utils import EPS @@ -217,7 +217,7 @@ def __repr__(self): def fsdp_pre_all_gather(self, mesh): if self._precomputed_scale is not None: - float8_tensor = hp_tensor_and_scale_to_float8( + float8_tensor = _hp_tensor_and_scale_to_float8( self._tensor, self._precomputed_scale, self._dtype, @@ -225,7 +225,7 @@ def fsdp_pre_all_gather(self, mesh): GemmInputRole.WEIGHT, ) else: - float8_tensor = hp_tensor_to_float8_dynamic( + float8_tensor = _hp_tensor_to_float8_dynamic( self._tensor, self._dtype, self._linear_mm_config, diff --git a/torchao/float8/inference.py b/torchao/float8/inference.py index d6e650aa6e..c581004343 100644 --- a/torchao/float8/inference.py +++ b/torchao/float8/inference.py @@ -11,7 +11,7 @@ import torch -from torchao.float8.float8_utils import is_row_major, pad_tensor_for_matmul +from torchao.float8.float8_utils import _is_row_major, _pad_tensor_for_matmul from torchao.quantization.granularity import ( PerRow, PerTensor, @@ -57,16 +57,16 @@ def preprocess_data( assert a_data.size(1) == b_data.size(0), ( f"Inner dims must match for mm, got {a_data.size(1)} and {b_data.size(0)}" ) - a_data = pad_tensor_for_matmul(a_data, dims=1) - b_data = pad_tensor_for_matmul(b_data, dims=0) - if not is_row_major(a_data.stride()): + a_data = _pad_tensor_for_matmul(a_data, dims=1) + b_data = _pad_tensor_for_matmul(b_data, dims=0) + if not _is_row_major(a_data.stride()): a_data = a_data.contiguous() - if is_row_major(b_data.stride()): + if _is_row_major(b_data.stride()): b_data = b_data.t().contiguous().t() return a_data, b_data -def addmm_float8_unwrapped_inference( +def _addmm_float8_unwrapped_inference( a_data: Tensor, a_scale: Tensor, b_data: Tensor, diff --git a/torchao/prototype/float8nocompile/float8nocompile_linear_test.py b/torchao/prototype/float8nocompile/float8nocompile_linear_test.py index 7df5ce768c..032a21829c 100644 --- a/torchao/prototype/float8nocompile/float8nocompile_linear_test.py +++ b/torchao/prototype/float8nocompile/float8nocompile_linear_test.py @@ -7,7 
+7,7 @@ import torch from torchao.float8.config import Float8LinearConfig -from torchao.float8.float8_linear import matmul_with_hp_or_float8_args +from torchao.float8.float8_linear import _matmul_with_hp_or_float8_args from torchao.float8.float8_tensor import LinearMMConfig, ScaledMMConfig from torchao.prototype.float8nocompile.float8nocompile_linear import ( matmul_with_args_in_hp, @@ -72,7 +72,7 @@ def test_matmul_with_args_in_hp(input_shape: tuple[int, int]): ) # prod forward. expects transposed weight. - out_prod = matmul_with_hp_or_float8_args.apply( + out_prod = _matmul_with_hp_or_float8_args.apply( prod_input_bf16, prod_weight_bf16.t(), linear_mm_config, config ) diff --git a/torchao/prototype/float8nocompile/kernels/fp8_dynamic_tensorwise_test.py b/torchao/prototype/float8nocompile/kernels/fp8_dynamic_tensorwise_test.py index 2348877d5c..4e87380968 100644 --- a/torchao/prototype/float8nocompile/kernels/fp8_dynamic_tensorwise_test.py +++ b/torchao/prototype/float8nocompile/kernels/fp8_dynamic_tensorwise_test.py @@ -6,9 +6,9 @@ import pytest import torch -from torchao.float8.float8_scaling_utils import hp_tensor_to_float8_dynamic +from torchao.float8.float8_scaling_utils import _hp_tensor_to_float8_dynamic from torchao.float8.float8_tensor import LinearMMConfig -from torchao.float8.float8_utils import is_row_major +from torchao.float8.float8_utils import _is_row_major from torchao.prototype.float8nocompile.kernels.fp8_dynamic_tensorwise import ( KernelAlgorithm, hp_to_fp8_col_major, @@ -37,7 +37,7 @@ def test_fp8_hp_to_fp8_row_major(input_shape: tuple[int, int], algo: KernelAlgor y_bf16 = input_bf16.clone().detach().to(device) # production implementation - x_fp8_row_major = hp_tensor_to_float8_dynamic( + x_fp8_row_major = _hp_tensor_to_float8_dynamic( x_bf16, torch.float8_e4m3fn, LinearMMConfig(), @@ -64,8 +64,8 @@ def test_fp8_hp_to_fp8_row_major(input_shape: tuple[int, int], algo: KernelAlgor assert x_fp8_row_major.stride() == y_fp8_row_major.stride() # check memory layout - assert is_row_major(x_fp8_row_major.stride()) - assert is_row_major(y_fp8_row_major.stride()) + assert _is_row_major(x_fp8_row_major.stride()) + assert _is_row_major(y_fp8_row_major.stride()) # check underlying memory layout assert ( @@ -100,7 +100,7 @@ def test_fp8_hp_to_fp8_row_major_t(input_shape: tuple[int, int], algo: KernelAlg y_bf16 = input_bf16.clone().detach().to(device) # production implementation - x_fp8_row_major = hp_tensor_to_float8_dynamic( + x_fp8_row_major = _hp_tensor_to_float8_dynamic( x_bf16, torch.float8_e4m3fn, LinearMMConfig(), @@ -128,8 +128,8 @@ def test_fp8_hp_to_fp8_row_major_t(input_shape: tuple[int, int], algo: KernelAlg assert x_fp8_row_major_t.stride() == y_fp8_row_major_t.stride() # check memory layout - assert is_row_major(x_fp8_row_major_t.stride()) - assert is_row_major(y_fp8_row_major_t.stride()) + assert _is_row_major(x_fp8_row_major_t.stride()) + assert _is_row_major(y_fp8_row_major_t.stride()) # check underlying memory layout assert ( @@ -162,7 +162,7 @@ def test_fp8_hp_to_fp8_col_major(input_shape: tuple[int, int], algo: KernelAlgor y_bf16 = input_bf16.clone().detach().to(device) # production implementation - x_fp8_row_major = hp_tensor_to_float8_dynamic( + x_fp8_row_major = _hp_tensor_to_float8_dynamic( x_bf16, torch.float8_e4m3fn, LinearMMConfig(), @@ -190,8 +190,8 @@ def test_fp8_hp_to_fp8_col_major(input_shape: tuple[int, int], algo: KernelAlgor assert x_fp8_col_major.stride() == y_fp8_col_major.stride() # check memory layout - assert not 
is_row_major(x_fp8_col_major.stride()) - assert not is_row_major(y_fp8_col_major.stride()) + assert not _is_row_major(x_fp8_col_major.stride()) + assert not _is_row_major(y_fp8_col_major.stride()) # check underlying memory layout assert ( @@ -224,7 +224,7 @@ def test_fp8_hp_to_fp8_col_major_t(input_shape: tuple[int, int], algo: KernelAlg y_bf16 = input_bf16.clone().detach().to(device) # production implementation - x_fp8_row_major = hp_tensor_to_float8_dynamic( + x_fp8_row_major = _hp_tensor_to_float8_dynamic( x_bf16, torch.float8_e4m3fn, LinearMMConfig(), @@ -252,8 +252,8 @@ def test_fp8_hp_to_fp8_col_major_t(input_shape: tuple[int, int], algo: KernelAlg assert x_fp8_col_major_t.stride() == y_fp8_col_major_t.stride() # check memory layout - assert not is_row_major(x_fp8_col_major_t.stride()) - assert not is_row_major(y_fp8_col_major_t.stride()) + assert not _is_row_major(x_fp8_col_major_t.stride()) + assert not _is_row_major(y_fp8_col_major_t.stride()) # check underlying memory layout assert ( @@ -288,7 +288,7 @@ def test_fp8_hp_to_fp8_row_and_col_major( y_bf16 = input_bf16.clone().detach().to(device) # production implementation - x_fp8_row_major = hp_tensor_to_float8_dynamic( + x_fp8_row_major = _hp_tensor_to_float8_dynamic( x_bf16, torch.float8_e4m3fn, LinearMMConfig(), @@ -320,10 +320,10 @@ def test_fp8_hp_to_fp8_row_and_col_major( assert x_fp8_col_major.stride() == y_fp8_col_major.stride() # check memory layout - assert is_row_major(x_fp8_row_major.stride()) - assert is_row_major(y_fp8_row_major.stride()) - assert not is_row_major(x_fp8_col_major.stride()) - assert not is_row_major(y_fp8_col_major.stride()) + assert _is_row_major(x_fp8_row_major.stride()) + assert _is_row_major(y_fp8_row_major.stride()) + assert not _is_row_major(x_fp8_col_major.stride()) + assert not _is_row_major(y_fp8_col_major.stride()) # check underlying memory layout assert ( @@ -362,7 +362,7 @@ def test_fp8_hp_to_fp8_row_major_t_and_non_t( y_bf16 = input_bf16.clone().detach().to(device) # production implementation - x_fp8_row_major = hp_tensor_to_float8_dynamic( + x_fp8_row_major = _hp_tensor_to_float8_dynamic( x_bf16, torch.float8_e4m3fn, LinearMMConfig(), @@ -394,10 +394,10 @@ def test_fp8_hp_to_fp8_row_major_t_and_non_t( assert x_fp8_row_major_t.stride() == y_fp8_row_major_t.stride() # check memory layout - assert is_row_major(x_fp8_row_major.stride()) - assert is_row_major(y_fp8_row_major.stride()) - assert is_row_major(x_fp8_row_major_t.stride()) - assert is_row_major(y_fp8_row_major_t.stride()) + assert _is_row_major(x_fp8_row_major.stride()) + assert _is_row_major(y_fp8_row_major.stride()) + assert _is_row_major(x_fp8_row_major_t.stride()) + assert _is_row_major(y_fp8_row_major_t.stride()) # check underlying memory layout assert ( @@ -436,7 +436,7 @@ def test_fp8_hp_to_fp8_col_major_t_and_non_t( y_bf16 = input_bf16.clone().detach().to(device) # production implementation - x_fp8_row_major = hp_tensor_to_float8_dynamic( + x_fp8_row_major = _hp_tensor_to_float8_dynamic( x_bf16, torch.float8_e4m3fn, LinearMMConfig(), @@ -469,10 +469,10 @@ def test_fp8_hp_to_fp8_col_major_t_and_non_t( assert x_fp8_col_major_t.stride() == y_fp8_col_major_t.stride() # check memory layout - assert not is_row_major(x_fp8_col_major.stride()) - assert not is_row_major(y_fp8_col_major.stride()) - assert not is_row_major(x_fp8_col_major_t.stride()) - assert not is_row_major(y_fp8_col_major_t.stride()) + assert not _is_row_major(x_fp8_col_major.stride()) + assert not _is_row_major(y_fp8_col_major.stride()) + assert not 
_is_row_major(x_fp8_col_major_t.stride()) + assert not _is_row_major(y_fp8_col_major_t.stride()) # check underlying memory layout assert ( From 655fbd7e9e6294b0939a1c84fae4536f168cb659 Mon Sep 17 00:00:00 2001 From: Daniel Vega-Myhre Date: Thu, 5 Jun 2025 12:25:25 -0700 Subject: [PATCH 2/2] update import --- test/quantization/test_qat.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/quantization/test_qat.py b/test/quantization/test_qat.py index 323802757d..74199b65ae 100644 --- a/test/quantization/test_qat.py +++ b/test/quantization/test_qat.py @@ -18,7 +18,7 @@ from torchao import quantize_ from torchao.float8.config import ScalingGranularity -from torchao.float8.float8_scaling_utils import hp_tensor_to_float8_dynamic +from torchao.float8.float8_scaling_utils import _hp_tensor_to_float8_dynamic from torchao.float8.float8_tensor import LinearMMConfig from torchao.quantization.granularity import ( PerAxis, @@ -1703,7 +1703,7 @@ def test_float8_rowwise_fake_quantize(self): x = torch.randn(32, 64) axiswise_dim = 0 out = _Float8RowwiseFakeQuantize.apply(x, dtype, axiswise_dim) - out_expected = hp_tensor_to_float8_dynamic( + out_expected = _hp_tensor_to_float8_dynamic( x, dtype, LinearMMConfig(),
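The rename is otherwise mechanical, but the backward-compatibility aliases added in `float8_linear_utils.py` and `float8_utils.py` (`swap_linear_layers = _swap_linear_layers`, `pad_tensor_for_matmul = _pad_tensor_for_matmul`) are the one deliberate design choice: helpers with confirmed external users keep their old public names while everything else becomes private. The standalone sketch below illustrates that alias pattern, plus an optional module-level `__getattr__` variant that would additionally emit a `DeprecationWarning` on legacy access; the module layout, the simplified `_pad_tensor_for_matmul` stand-in, and the warning variant are illustrative assumptions, not part of this patch or of torchao's actual API.

```python
# bc_alias_sketch.py
# Minimal, standalone sketch of the backward-compatibility pattern used in this
# patch: the implementation is renamed to a private, underscore-prefixed name,
# and the old public name is kept as an alias so existing imports keep working.
# The helpers below are stand-ins, not the real torchao code.
import warnings


def _pad_tensor_for_matmul(shape, alignment=16):
    """Private implementation (renamed with a leading underscore)."""
    # Round each dim up to the next multiple of `alignment`; this only mirrors
    # the intent of the real torch-based helper, it is not its implementation.
    return tuple((1 + (d - 1) // alignment) * alignment for d in shape)


# The pattern used in the patch: a plain module-level alias, kept because
# external users of the old public name were confirmed.
pad_tensor_for_matmul = _pad_tensor_for_matmul


def _swap_linear_layers():
    """Stand-in for the other helper the patch aliases for BC."""
    return "swapped"


# Optional variant (not in the patch): resolve legacy names lazily via a
# PEP 562 module-level __getattr__ and warn on first use.
_DEPRECATED = {"swap_linear_layers": "_swap_linear_layers"}


def __getattr__(name):
    if name in _DEPRECATED:
        warnings.warn(
            f"{name} is deprecated; import {_DEPRECATED[name]} instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        return globals()[_DEPRECATED[name]]
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")


if __name__ == "__main__":
    # Old call sites keep working through the plain alias.
    print(pad_tensor_for_matmul((10, 10)))  # -> (16, 16)
```

Under either variant, existing imports such as `from torchao.float8.float8_utils import pad_tensor_for_matmul` continue to resolve, while new code is steered toward the underscore-prefixed helpers or the public entry points like `convert_to_float8_training` that the patch leaves unchanged.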