From 5c23c6b112ccc3d24de51c4f37421602a6d0959d Mon Sep 17 00:00:00 2001
From: vasiliy
Date: Fri, 20 Jun 2025 07:10:13 -0700
Subject: [PATCH 1/2] Update

[ghstack-poisoned]
---
 test/float8/test_dtensor.py  | 2 ++
 test/float8/test_fsdp2_tp.py | 2 ++
 2 files changed, 4 insertions(+)

diff --git a/test/float8/test_dtensor.py b/test/float8/test_dtensor.py
index 9db046b749..a9ccb35b79 100644
--- a/test/float8/test_dtensor.py
+++ b/test/float8/test_dtensor.py
@@ -67,6 +67,8 @@ def setup_distributed():
     device_mesh = init_device_mesh("cuda", (world_size,))
     # seed must be the same in all processes
     torch.manual_seed(1)
+    local_rank = torch.distributed.get_rank()
+    torch.cuda.set_device(local_rank)
     return device_mesh


diff --git a/test/float8/test_fsdp2_tp.py b/test/float8/test_fsdp2_tp.py
index fa3d30410b..f04b791273 100644
--- a/test/float8/test_fsdp2_tp.py
+++ b/test/float8/test_fsdp2_tp.py
@@ -46,6 +46,8 @@ def setup_distributed():
     )
     # seed must be the same in all processes
     torch.manual_seed(1)
+    local_rank = torch.distributed.get_rank()
+    torch.cuda.set_device(local_rank)
     return device_mesh



From ad2ce6213a6045b99448e4b1e3cd57d87f43cde0 Mon Sep 17 00:00:00 2001
From: vasiliy
Date: Fri, 20 Jun 2025 07:30:21 -0700
Subject: [PATCH 2/2] Update

[ghstack-poisoned]
---
 benchmarks/float8/bench_matmul.py                      | 2 +-
 benchmarks/float8/float8_roofline.py                   | 2 +-
 test/float8/test_base.py                               | 2 +-
 test/float8/test_compile.py                            | 2 +-
 test/float8/test_dtensor.py                            | 2 +-
 test/float8/test_fsdp2/test_fsdp2.py                   | 5 ++++-
 test/float8/test_fsdp2_tp.py                           | 2 +-
 test/float8/test_numerics_integration.py               | 2 +-
 torchao/testing/{float8 => training}/__init__.py       | 0
 torchao/testing/{float8 => training}/dtensor_utils.py  | 0
 torchao/testing/{float8 => training}/fsdp2_utils.py    | 0
 torchao/testing/{float8 => training}/roofline_utils.py | 0
 torchao/testing/{float8 => training}/test_utils.py     | 0
 13 files changed, 11 insertions(+), 8 deletions(-)
 rename torchao/testing/{float8 => training}/__init__.py (100%)
 rename torchao/testing/{float8 => training}/dtensor_utils.py (100%)
 rename torchao/testing/{float8 => training}/fsdp2_utils.py (100%)
 rename torchao/testing/{float8 => training}/roofline_utils.py (100%)
 rename torchao/testing/{float8 => training}/test_utils.py (100%)

diff --git a/benchmarks/float8/bench_matmul.py b/benchmarks/float8/bench_matmul.py
index e3f19d8f49..cf844fa51b 100644
--- a/benchmarks/float8/bench_matmul.py
+++ b/benchmarks/float8/bench_matmul.py
@@ -16,7 +16,7 @@
     get_name_to_shapes_iter,
 )

-from torchao.testing.float8.roofline_utils import get_specs
+from torchao.testing.training.roofline_utils import get_specs


 def benchmark_fn_in_sec(f, *args, **kwargs):
diff --git a/benchmarks/float8/float8_roofline.py b/benchmarks/float8/float8_roofline.py
index f9374f835e..5a8419cde8 100644
--- a/benchmarks/float8/float8_roofline.py
+++ b/benchmarks/float8/float8_roofline.py
@@ -63,7 +63,7 @@
 )
 from torchao.prototype.mx_formats import MXLinearConfig
 from torchao.quantization import quantize_
-from torchao.testing.float8.roofline_utils import (
+from torchao.testing.training.roofline_utils import (
     get_float8_mem_sympy,
     get_gemm_time_sympy,
 )
diff --git a/test/float8/test_base.py b/test/float8/test_base.py
index 8e3efeab60..15099dc2c1 100644
--- a/test/float8/test_base.py
+++ b/test/float8/test_base.py
@@ -55,7 +55,7 @@
     fp8_tensor_statistics,
     tensor_to_scale,
 )
-from torchao.testing.float8.test_utils import get_test_float8_linear_config
+from torchao.testing.training.test_utils import get_test_float8_linear_config
 from torchao.utils import is_MI300, is_ROCM

 random.seed(0)
diff --git a/test/float8/test_compile.py b/test/float8/test_compile.py
index ac5d1f8d96..aaf9d3d3f5 100644
--- a/test/float8/test_compile.py
+++ b/test/float8/test_compile.py
@@ -37,7 +37,7 @@
     hp_tensor_to_float8_dynamic,
 )
 from torchao.float8.float8_tensor import GemmInputRole, LinearMMConfig, ScaledMMConfig
-from torchao.testing.float8.test_utils import get_test_float8_linear_config
+from torchao.testing.training.test_utils import get_test_float8_linear_config


 def _test_compile_base(
diff --git a/test/float8/test_dtensor.py b/test/float8/test_dtensor.py
index a9ccb35b79..e7220bff9f 100644
--- a/test/float8/test_dtensor.py
+++ b/test/float8/test_dtensor.py
@@ -57,7 +57,7 @@
 )
 from torchao.float8.float8_utils import tensor_to_scale
 from torchao.float8.fsdp_utils import WeightWithDynamicFloat8CastTensor
-from torchao.testing.float8.dtensor_utils import ToyModel
+from torchao.testing.training.dtensor_utils import ToyModel

 torch.set_float32_matmul_precision("high")

diff --git a/test/float8/test_fsdp2/test_fsdp2.py b/test/float8/test_fsdp2/test_fsdp2.py
index 6f0cfecf41..b4c7f9fd15 100644
--- a/test/float8/test_fsdp2/test_fsdp2.py
+++ b/test/float8/test_fsdp2/test_fsdp2.py
@@ -43,7 +43,10 @@
 from torchao.float8.float8_scaling_utils import hp_tensor_to_float8_dynamic
 from torchao.float8.float8_tensor import GemmInputRole
 from torchao.float8.fsdp_utils import WeightWithDynamicFloat8CastTensor
-from torchao.testing.float8.fsdp2_utils import check_parity_bf16_mp, check_parity_no_mp
+from torchao.testing.training.fsdp2_utils import (
+    check_parity_bf16_mp,
+    check_parity_no_mp,
+)

 if not is_sm_at_least_89():
     pytest.skip("Unsupported CUDA device capability version", allow_module_level=True)
diff --git a/test/float8/test_fsdp2_tp.py b/test/float8/test_fsdp2_tp.py
index f04b791273..93c7735149 100644
--- a/test/float8/test_fsdp2_tp.py
+++ b/test/float8/test_fsdp2_tp.py
@@ -32,7 +32,7 @@
     Float8ColwiseParallel,
     Float8RowwiseParallel,
 )
-from torchao.testing.float8.dtensor_utils import ToyModel
+from torchao.testing.training.dtensor_utils import ToyModel


 def setup_distributed():
diff --git a/test/float8/test_numerics_integration.py b/test/float8/test_numerics_integration.py
index f25c876189..db02444109 100644
--- a/test/float8/test_numerics_integration.py
+++ b/test/float8/test_numerics_integration.py
@@ -33,7 +33,7 @@
     convert_to_float8_training,
 )
 from torchao.float8.float8_utils import IS_ROCM, compute_error
-from torchao.testing.float8.test_utils import get_test_float8_linear_config
+from torchao.testing.training.test_utils import get_test_float8_linear_config

 torch.manual_seed(0)

diff --git a/torchao/testing/float8/__init__.py b/torchao/testing/training/__init__.py
similarity index 100%
rename from torchao/testing/float8/__init__.py
rename to torchao/testing/training/__init__.py
diff --git a/torchao/testing/float8/dtensor_utils.py b/torchao/testing/training/dtensor_utils.py
similarity index 100%
rename from torchao/testing/float8/dtensor_utils.py
rename to torchao/testing/training/dtensor_utils.py
diff --git a/torchao/testing/float8/fsdp2_utils.py b/torchao/testing/training/fsdp2_utils.py
similarity index 100%
rename from torchao/testing/float8/fsdp2_utils.py
rename to torchao/testing/training/fsdp2_utils.py
diff --git a/torchao/testing/float8/roofline_utils.py b/torchao/testing/training/roofline_utils.py
similarity index 100%
rename from torchao/testing/float8/roofline_utils.py
rename to torchao/testing/training/roofline_utils.py
diff --git a/torchao/testing/float8/test_utils.py b/torchao/testing/training/test_utils.py
similarity index 100%
rename from torchao/testing/float8/test_utils.py
rename to torchao/testing/training/test_utils.py
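
For reference, the standalone script below is a minimal sketch (not part of the patch series) of the per-rank device binding pattern that the first patch adds to the distributed test setup helpers. It assumes a single-node launch via torchrun, where the global rank can double as the local CUDA device index, and the __main__ block is illustrative only.

import os

import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh


def setup_distributed():
    # Launched via e.g. `torchrun --nproc_per_node=2 this_script.py`.
    world_size = int(os.environ["WORLD_SIZE"])
    device_mesh = init_device_mesh("cuda", (world_size,))
    # seed must be the same in all processes
    torch.manual_seed(1)
    # Bind this process to its own GPU; otherwise every rank defaults to
    # device 0 and collectives/allocations can collide.
    # Note: using the global rank as the device index assumes a single node.
    local_rank = dist.get_rank()
    torch.cuda.set_device(local_rank)
    return device_mesh


if __name__ == "__main__":
    mesh = setup_distributed()
    print(f"rank {dist.get_rank()} bound to cuda:{torch.cuda.current_device()}, mesh={mesh}")
    dist.destroy_process_group()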