diff --git a/benchmarks/float8/bench_matmul.py b/benchmarks/float8/bench_matmul.py index e3f19d8f49..cf844fa51b 100644 --- a/benchmarks/float8/bench_matmul.py +++ b/benchmarks/float8/bench_matmul.py @@ -16,7 +16,7 @@ get_name_to_shapes_iter, ) -from torchao.testing.float8.roofline_utils import get_specs +from torchao.testing.training.roofline_utils import get_specs def benchmark_fn_in_sec(f, *args, **kwargs): diff --git a/benchmarks/float8/float8_roofline.py b/benchmarks/float8/float8_roofline.py index f9374f835e..5a8419cde8 100644 --- a/benchmarks/float8/float8_roofline.py +++ b/benchmarks/float8/float8_roofline.py @@ -63,7 +63,7 @@ ) from torchao.prototype.mx_formats import MXLinearConfig from torchao.quantization import quantize_ -from torchao.testing.float8.roofline_utils import ( +from torchao.testing.training.roofline_utils import ( get_float8_mem_sympy, get_gemm_time_sympy, ) diff --git a/test/float8/test_base.py b/test/float8/test_base.py index 8e3efeab60..15099dc2c1 100644 --- a/test/float8/test_base.py +++ b/test/float8/test_base.py @@ -55,7 +55,7 @@ fp8_tensor_statistics, tensor_to_scale, ) -from torchao.testing.float8.test_utils import get_test_float8_linear_config +from torchao.testing.training.test_utils import get_test_float8_linear_config from torchao.utils import is_MI300, is_ROCM random.seed(0) diff --git a/test/float8/test_compile.py b/test/float8/test_compile.py index ac5d1f8d96..aaf9d3d3f5 100644 --- a/test/float8/test_compile.py +++ b/test/float8/test_compile.py @@ -37,7 +37,7 @@ hp_tensor_to_float8_dynamic, ) from torchao.float8.float8_tensor import GemmInputRole, LinearMMConfig, ScaledMMConfig -from torchao.testing.float8.test_utils import get_test_float8_linear_config +from torchao.testing.training.test_utils import get_test_float8_linear_config def _test_compile_base( diff --git a/test/float8/test_dtensor.py b/test/float8/test_dtensor.py index a9ccb35b79..e7220bff9f 100644 --- a/test/float8/test_dtensor.py +++ b/test/float8/test_dtensor.py @@ -57,7 +57,7 @@ ) from torchao.float8.float8_utils import tensor_to_scale from torchao.float8.fsdp_utils import WeightWithDynamicFloat8CastTensor -from torchao.testing.float8.dtensor_utils import ToyModel +from torchao.testing.training.dtensor_utils import ToyModel torch.set_float32_matmul_precision("high") diff --git a/test/float8/test_fsdp2/test_fsdp2.py b/test/float8/test_fsdp2/test_fsdp2.py index 6f0cfecf41..b4c7f9fd15 100644 --- a/test/float8/test_fsdp2/test_fsdp2.py +++ b/test/float8/test_fsdp2/test_fsdp2.py @@ -43,7 +43,10 @@ from torchao.float8.float8_scaling_utils import hp_tensor_to_float8_dynamic from torchao.float8.float8_tensor import GemmInputRole from torchao.float8.fsdp_utils import WeightWithDynamicFloat8CastTensor -from torchao.testing.float8.fsdp2_utils import check_parity_bf16_mp, check_parity_no_mp +from torchao.testing.training.fsdp2_utils import ( + check_parity_bf16_mp, + check_parity_no_mp, +) if not is_sm_at_least_89(): pytest.skip("Unsupported CUDA device capability version", allow_module_level=True) diff --git a/test/float8/test_fsdp2_tp.py b/test/float8/test_fsdp2_tp.py index f04b791273..93c7735149 100644 --- a/test/float8/test_fsdp2_tp.py +++ b/test/float8/test_fsdp2_tp.py @@ -32,7 +32,7 @@ Float8ColwiseParallel, Float8RowwiseParallel, ) -from torchao.testing.float8.dtensor_utils import ToyModel +from torchao.testing.training.dtensor_utils import ToyModel def setup_distributed(): diff --git a/test/float8/test_numerics_integration.py b/test/float8/test_numerics_integration.py index f25c876189..db02444109 100644 --- a/test/float8/test_numerics_integration.py +++ b/test/float8/test_numerics_integration.py @@ -33,7 +33,7 @@ convert_to_float8_training, ) from torchao.float8.float8_utils import IS_ROCM, compute_error -from torchao.testing.float8.test_utils import get_test_float8_linear_config +from torchao.testing.training.test_utils import get_test_float8_linear_config torch.manual_seed(0) diff --git a/torchao/testing/float8/__init__.py b/torchao/testing/training/__init__.py similarity index 100% rename from torchao/testing/float8/__init__.py rename to torchao/testing/training/__init__.py diff --git a/torchao/testing/float8/dtensor_utils.py b/torchao/testing/training/dtensor_utils.py similarity index 100% rename from torchao/testing/float8/dtensor_utils.py rename to torchao/testing/training/dtensor_utils.py diff --git a/torchao/testing/float8/fsdp2_utils.py b/torchao/testing/training/fsdp2_utils.py similarity index 100% rename from torchao/testing/float8/fsdp2_utils.py rename to torchao/testing/training/fsdp2_utils.py diff --git a/torchao/testing/float8/roofline_utils.py b/torchao/testing/training/roofline_utils.py similarity index 100% rename from torchao/testing/float8/roofline_utils.py rename to torchao/testing/training/roofline_utils.py diff --git a/torchao/testing/float8/test_utils.py b/torchao/testing/training/test_utils.py similarity index 100% rename from torchao/testing/float8/test_utils.py rename to torchao/testing/training/test_utils.py