From 5c23c6b112ccc3d24de51c4f37421602a6d0959d Mon Sep 17 00:00:00 2001
From: vasiliy
Date: Fri, 20 Jun 2025 07:10:13 -0700
Subject: [PATCH 1/3] Update

[ghstack-poisoned]
---
 test/float8/test_dtensor.py  | 2 ++
 test/float8/test_fsdp2_tp.py | 2 ++
 2 files changed, 4 insertions(+)

diff --git a/test/float8/test_dtensor.py b/test/float8/test_dtensor.py
index 9db046b749..a9ccb35b79 100644
--- a/test/float8/test_dtensor.py
+++ b/test/float8/test_dtensor.py
@@ -67,6 +67,8 @@ def setup_distributed():
     device_mesh = init_device_mesh("cuda", (world_size,))
     # seed must be the same in all processes
     torch.manual_seed(1)
+    local_rank = torch.distributed.get_rank()
+    torch.cuda.set_device(local_rank)
     return device_mesh

diff --git a/test/float8/test_fsdp2_tp.py b/test/float8/test_fsdp2_tp.py
index fa3d30410b..f04b791273 100644
--- a/test/float8/test_fsdp2_tp.py
+++ b/test/float8/test_fsdp2_tp.py
@@ -46,6 +46,8 @@ def setup_distributed():
     )
     # seed must be the same in all processes
     torch.manual_seed(1)
+    local_rank = torch.distributed.get_rank()
+    torch.cuda.set_device(local_rank)
     return device_mesh

From ad2ce6213a6045b99448e4b1e3cd57d87f43cde0 Mon Sep 17 00:00:00 2001
From: vasiliy
Date: Fri, 20 Jun 2025 07:30:21 -0700
Subject: [PATCH 2/3] Update

[ghstack-poisoned]
---
 benchmarks/float8/bench_matmul.py                      | 2 +-
 benchmarks/float8/float8_roofline.py                   | 2 +-
 test/float8/test_base.py                               | 2 +-
 test/float8/test_compile.py                            | 2 +-
 test/float8/test_dtensor.py                            | 2 +-
 test/float8/test_fsdp2/test_fsdp2.py                   | 5 ++++-
 test/float8/test_fsdp2_tp.py                           | 2 +-
 test/float8/test_numerics_integration.py               | 2 +-
 torchao/testing/{float8 => training}/__init__.py       | 0
 torchao/testing/{float8 => training}/dtensor_utils.py  | 0
 torchao/testing/{float8 => training}/fsdp2_utils.py    | 0
 torchao/testing/{float8 => training}/roofline_utils.py | 0
 torchao/testing/{float8 => training}/test_utils.py     | 0
 13 files changed, 11 insertions(+), 8 deletions(-)
 rename torchao/testing/{float8 => training}/__init__.py (100%)
 rename torchao/testing/{float8 => training}/dtensor_utils.py (100%)
 rename torchao/testing/{float8 => training}/fsdp2_utils.py (100%)
 rename torchao/testing/{float8 => training}/roofline_utils.py (100%)
 rename torchao/testing/{float8 => training}/test_utils.py (100%)

diff --git a/benchmarks/float8/bench_matmul.py b/benchmarks/float8/bench_matmul.py
index e3f19d8f49..cf844fa51b 100644
--- a/benchmarks/float8/bench_matmul.py
+++ b/benchmarks/float8/bench_matmul.py
@@ -16,7 +16,7 @@
     get_name_to_shapes_iter,
 )

-from torchao.testing.float8.roofline_utils import get_specs
+from torchao.testing.training.roofline_utils import get_specs


 def benchmark_fn_in_sec(f, *args, **kwargs):
diff --git a/benchmarks/float8/float8_roofline.py b/benchmarks/float8/float8_roofline.py
index f9374f835e..5a8419cde8 100644
--- a/benchmarks/float8/float8_roofline.py
+++ b/benchmarks/float8/float8_roofline.py
@@ -63,7 +63,7 @@
 )
 from torchao.prototype.mx_formats import MXLinearConfig
 from torchao.quantization import quantize_
-from torchao.testing.float8.roofline_utils import (
+from torchao.testing.training.roofline_utils import (
     get_float8_mem_sympy,
     get_gemm_time_sympy,
 )
diff --git a/test/float8/test_base.py b/test/float8/test_base.py
index 8e3efeab60..15099dc2c1 100644
--- a/test/float8/test_base.py
+++ b/test/float8/test_base.py
@@ -55,7 +55,7 @@
     fp8_tensor_statistics,
     tensor_to_scale,
 )
-from torchao.testing.float8.test_utils import get_test_float8_linear_config
+from torchao.testing.training.test_utils import get_test_float8_linear_config
 from torchao.utils import is_MI300, is_ROCM

 random.seed(0)
diff --git a/test/float8/test_compile.py b/test/float8/test_compile.py
index ac5d1f8d96..aaf9d3d3f5 100644
--- a/test/float8/test_compile.py
+++ b/test/float8/test_compile.py
@@ -37,7 +37,7 @@
     hp_tensor_to_float8_dynamic,
 )
 from torchao.float8.float8_tensor import GemmInputRole, LinearMMConfig, ScaledMMConfig
-from torchao.testing.float8.test_utils import get_test_float8_linear_config
+from torchao.testing.training.test_utils import get_test_float8_linear_config


 def _test_compile_base(
diff --git a/test/float8/test_dtensor.py b/test/float8/test_dtensor.py
index a9ccb35b79..e7220bff9f 100644
--- a/test/float8/test_dtensor.py
+++ b/test/float8/test_dtensor.py
@@ -57,7 +57,7 @@
 )
 from torchao.float8.float8_utils import tensor_to_scale
 from torchao.float8.fsdp_utils import WeightWithDynamicFloat8CastTensor
-from torchao.testing.float8.dtensor_utils import ToyModel
+from torchao.testing.training.dtensor_utils import ToyModel

 torch.set_float32_matmul_precision("high")
diff --git a/test/float8/test_fsdp2/test_fsdp2.py b/test/float8/test_fsdp2/test_fsdp2.py
index 6f0cfecf41..b4c7f9fd15 100644
--- a/test/float8/test_fsdp2/test_fsdp2.py
+++ b/test/float8/test_fsdp2/test_fsdp2.py
@@ -43,7 +43,10 @@
 from torchao.float8.float8_scaling_utils import hp_tensor_to_float8_dynamic
 from torchao.float8.float8_tensor import GemmInputRole
 from torchao.float8.fsdp_utils import WeightWithDynamicFloat8CastTensor
-from torchao.testing.float8.fsdp2_utils import check_parity_bf16_mp, check_parity_no_mp
+from torchao.testing.training.fsdp2_utils import (
+    check_parity_bf16_mp,
+    check_parity_no_mp,
+)

 if not is_sm_at_least_89():
     pytest.skip("Unsupported CUDA device capability version", allow_module_level=True)
diff --git a/test/float8/test_fsdp2_tp.py b/test/float8/test_fsdp2_tp.py
index f04b791273..93c7735149 100644
--- a/test/float8/test_fsdp2_tp.py
+++ b/test/float8/test_fsdp2_tp.py
@@ -32,7 +32,7 @@
     Float8ColwiseParallel,
     Float8RowwiseParallel,
 )
-from torchao.testing.float8.dtensor_utils import ToyModel
+from torchao.testing.training.dtensor_utils import ToyModel


 def setup_distributed():
diff --git a/test/float8/test_numerics_integration.py b/test/float8/test_numerics_integration.py
index f25c876189..db02444109 100644
--- a/test/float8/test_numerics_integration.py
+++ b/test/float8/test_numerics_integration.py
@@ -33,7 +33,7 @@
     convert_to_float8_training,
 )
 from torchao.float8.float8_utils import IS_ROCM, compute_error
-from torchao.testing.float8.test_utils import get_test_float8_linear_config
+from torchao.testing.training.test_utils import get_test_float8_linear_config

 torch.manual_seed(0)
diff --git a/torchao/testing/float8/__init__.py b/torchao/testing/training/__init__.py
similarity index 100%
rename from torchao/testing/float8/__init__.py
rename to torchao/testing/training/__init__.py
diff --git a/torchao/testing/float8/dtensor_utils.py b/torchao/testing/training/dtensor_utils.py
similarity index 100%
rename from torchao/testing/float8/dtensor_utils.py
rename to torchao/testing/training/dtensor_utils.py
diff --git a/torchao/testing/float8/fsdp2_utils.py b/torchao/testing/training/fsdp2_utils.py
similarity index 100%
rename from torchao/testing/float8/fsdp2_utils.py
rename to torchao/testing/training/fsdp2_utils.py
diff --git a/torchao/testing/float8/roofline_utils.py b/torchao/testing/training/roofline_utils.py
similarity index 100%
rename from torchao/testing/float8/roofline_utils.py
rename to torchao/testing/training/roofline_utils.py
diff --git a/torchao/testing/float8/test_utils.py b/torchao/testing/training/test_utils.py
similarity index 100%
rename from torchao/testing/float8/test_utils.py
rename to torchao/testing/training/test_utils.py

From 5eb2066824cadea7028939993c0d1ab35ae09c7a Mon Sep 17 00:00:00 2001
From: vasiliy
Date: Fri, 20 Jun 2025 07:52:19 -0700
Subject: [PATCH 3/3] Update

[ghstack-poisoned]
---
 test/float8/test_dtensor.py               | 160 ++++------------------
 torchao/testing/training/dtensor_utils.py | 138 +++++++++++++++++++
 2 files changed, 161 insertions(+), 137 deletions(-)

diff --git a/test/float8/test_dtensor.py b/test/float8/test_dtensor.py
index e7220bff9f..5509eb1cc2 100644
--- a/test/float8/test_dtensor.py
+++ b/test/float8/test_dtensor.py
@@ -10,7 +10,6 @@
 TODO(future): make this run in CI
 """

-import copy
 import os

 import pytest
@@ -23,12 +22,6 @@
 from torch.distributed._tensor import DTensor, Replicate, Shard, distribute_tensor
 from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
-from torch.distributed.tensor.parallel import (
-    ColwiseParallel,
-    PrepareModuleInput,
-    RowwiseParallel,
-    parallelize_module,
-)
 from torch.testing._internal.distributed._tensor.common_dtensor import (
     ModelArgs,
     Transformer,
@@ -50,14 +43,11 @@
     LinearMMConfig,
     hp_tensor_and_scale_to_float8,
 )
-from torchao.float8.float8_tensor_parallel import (
-    Float8ColwiseParallel,
-    Float8RowwiseParallel,
-    PrepareFloat8ModuleInput,
-)
 from torchao.float8.float8_utils import tensor_to_scale
 from torchao.float8.fsdp_utils import WeightWithDynamicFloat8CastTensor
-from torchao.testing.training.dtensor_utils import ToyModel
+from torchao.testing.training.dtensor_utils import (
+    _test_lowp_mlp_tensor_parallelism_base,
+)

 torch.set_float32_matmul_precision("high")

@@ -193,140 +183,36 @@ def _test_dtensor_fp8_autograd(mesh: DeviceMesh, size=16):
     loss.backward()


-def _test_fp8_mlp_tensor_parallelism_base(
-    mesh: DeviceMesh, size=16, compile: bool = False, rowwise: bool = False
-):
-    device = mesh.device_type
-
-    if rowwise:
-        config = Float8LinearConfig.from_recipe_name(Float8LinearRecipeName.ROWWISE)
-        # hack around config being frozen
-        # TODO(future PR): we should make this nicer at the config level
-        object.__setattr__(config, "emulate", True)
-    else:
-        config = Float8LinearConfig(emulate=True)
-
-    toy_model = ToyModel().to(device)
-    toy_model_fp8 = convert_to_float8_training(toy_model, config=config)
-
-    tp_model = copy.deepcopy(toy_model)
-    tp_model = convert_to_float8_training(tp_model, config=config)
-    sp_model = copy.deepcopy(toy_model)
-    sp_model = convert_to_float8_training(sp_model, config=config)
-
-    # For tensorwise scaling, enable float8 all_gather.
-    # For rowwise scaling, keep high precision all_gather. Motivation for
-    # not doing float8 all-gather for rowwise: tensors need to be scaled both ways,
-    # so for float8 all-gather we'd need to send two float8 copies per tensor,
-    # which is similar # bytes over the wire than just doing bfloat16 all-gather.
-    if rowwise:
-        colwise_parallel_cls = ColwiseParallel
-        rowwise_parallel_cls = RowwiseParallel
-        prepare_input_cls = PrepareModuleInput
-    else:
-        colwise_parallel_cls = Float8ColwiseParallel
-        rowwise_parallel_cls = Float8RowwiseParallel
-        prepare_input_cls = PrepareFloat8ModuleInput
-
-    # vanilla TP
-    tp_model = parallelize_module(
-        tp_model,
-        mesh,
-        {
-            "ffn.w1": colwise_parallel_cls(),
-            "ffn.w2": colwise_parallel_cls(),
-            "ffn.out_proj": rowwise_parallel_cls(),
-        },
+def _test_fp8_mlp_tensor_parallelism_eager(mesh: DeviceMesh, size=16):
+    tensorwise_config = Float8LinearConfig(emulate=True)
+    _test_lowp_mlp_tensor_parallelism_base(
+        mesh, tensorwise_config, size, compile=False, allgather_in_lowp=True
     )

-    # "sequence parallel" mlp computation
-    sp_model = parallelize_module(
-        sp_model,
-        mesh,
-        {
-            "ffn": prepare_input_cls(
-                input_layouts=Shard(1), desired_input_layouts=Replicate()
-            ),
-            "ffn.w1": colwise_parallel_cls(),
-            "ffn.w2": colwise_parallel_cls(),
-            "ffn.out_proj": rowwise_parallel_cls(
-                output_layouts=Shard(1), use_local_output=False
-            ),
-        },
+    rowwise_config = Float8LinearConfig.from_recipe_name(Float8LinearRecipeName.ROWWISE)
+    # hack around config being frozen
+    # TODO(future PR): we should make this nicer at the config level
+    object.__setattr__(rowwise_config, "emulate", True)
+    _test_lowp_mlp_tensor_parallelism_base(
+        mesh, rowwise_config, size, compile=False, allgather_in_lowp=False
     )

-    # prepare_input_cls with specific submodule fqn
-    sp_model2 = copy.deepcopy(toy_model)
-    sp_model2 = convert_to_float8_training(sp_model2, config=config)
-    if rowwise:
-        prepare_input = prepare_input_cls(
-            input_layouts=Shard(1),
-            desired_input_layouts=Replicate(),
-        )
-    else:
-        prepare_input = prepare_input_cls(
-            input_layouts=Shard(1),
-            desired_input_layouts=Replicate(),
-            fwd_config_submodule_fqn="w2",
-        )
-
-    sp_model2 = parallelize_module(
-        sp_model2,
-        mesh,
-        {
-            "ffn": prepare_input,
-            "ffn.w1": colwise_parallel_cls(),
-            "ffn.w2": colwise_parallel_cls(),
-            "ffn.out_proj": rowwise_parallel_cls(
-                output_layouts=Shard(1), use_local_output=False
-            ),
-        },
-    )
-
-    if compile:
-        tp_model = torch.compile(tp_model)
-        sp_model = torch.compile(sp_model)
-        sp_model2 = torch.compile(sp_model2)
-
-    x_fp32 = torch.rand(size, size * 2, size, device=device, requires_grad=False)
-    x_fp32_tp_input = x_fp32.clone()
-    x_fp32_sp_input = distribute_tensor(x_fp32.clone(), mesh, [Shard(0)])
-
-    tp_out = tp_model(x_fp32_tp_input)
-    tp_out.sum().backward()
-    sp_out = sp_model(x_fp32_sp_input)
-    sp_out.sum().backward()
-    global_out = toy_model_fp8(x_fp32)
-    global_out.sum().backward()
-    torch.testing.assert_close(tp_out, global_out)
-    torch.testing.assert_close(sp_out.full_tensor(), global_out)
-    torch.testing.assert_close(tp_model.ffn.w1.weight.grad, sp_model.ffn.w1.weight.grad)
-    torch.testing.assert_close(
-        tp_model.ffn.out_proj.weight.grad, sp_model.ffn.out_proj.weight.grad
+def _test_fp8_mlp_tensor_parallelism_compile(mesh: DeviceMesh, size=16):
+    tensorwise_config = Float8LinearConfig(emulate=True)
+    _test_lowp_mlp_tensor_parallelism_base(
+        mesh, tensorwise_config, size, compile=True, allgather_in_lowp=True
     )

-    sp_out2 = sp_model2(x_fp32_sp_input)
-    sp_out2.sum().backward()
-    torch.testing.assert_close(sp_out2.full_tensor(), global_out)
-    torch.testing.assert_close(
-        tp_model.ffn.w1.weight.grad, sp_model2.ffn.w1.weight.grad
-    )
-    torch.testing.assert_close(
-        tp_model.ffn.out_proj.weight.grad, sp_model2.ffn.out_proj.weight.grad
+    rowwise_config = Float8LinearConfig.from_recipe_name(Float8LinearRecipeName.ROWWISE)
+    # hack around config being frozen
+    # TODO(future PR): we should make this nicer at the config level
+    object.__setattr__(rowwise_config, "emulate", True)
+    _test_lowp_mlp_tensor_parallelism_base(
+        mesh, rowwise_config, size, compile=True, allgather_in_lowp=False
     )


-def _test_fp8_mlp_tensor_parallelism_eager(mesh: DeviceMesh, size=16):
-    _test_fp8_mlp_tensor_parallelism_base(mesh, size, compile=False, rowwise=False)
-    _test_fp8_mlp_tensor_parallelism_base(mesh, size, compile=False, rowwise=True)
-
-
-def _test_fp8_mlp_tensor_parallelism_compile(mesh: DeviceMesh, size=16):
-    _test_fp8_mlp_tensor_parallelism_base(mesh, size, compile=True, rowwise=False)
-    _test_fp8_mlp_tensor_parallelism_base(mesh, size, compile=True, rowwise=True)
-
-
 def _test_distribute_fsdp_tensor_subclass(tp_mesh: DeviceMesh):
     torch.manual_seed(42)
     model = Transformer(ModelArgs(dropout_p=0.0, weight_tying=False)).cuda()
diff --git a/torchao/testing/training/dtensor_utils.py b/torchao/testing/training/dtensor_utils.py
index 84e4095263..7ac0360363 100644
--- a/torchao/testing/training/dtensor_utils.py
+++ b/torchao/testing/training/dtensor_utils.py
@@ -3,9 +3,27 @@
 #
 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.
+import copy
+
+import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from torch.distributed._tensor import Replicate, Shard, distribute_tensor
+from torch.distributed.device_mesh import DeviceMesh
+from torch.distributed.tensor.parallel import (
+    ColwiseParallel,
+    PrepareModuleInput,
+    RowwiseParallel,
+    parallelize_module,
+)
+
+from torchao.float8 import Float8LinearConfig
+from torchao.float8.float8_linear_utils import convert_to_float8_training
+from torchao.float8.float8_tensor_parallel import (
+    Float8ColwiseParallel,
+    Float8RowwiseParallel,
+    PrepareFloat8ModuleInput,
+)


 class FeedForward(nn.Module):
@@ -28,3 +46,123 @@ def __init__(self):

     def forward(self, x):
         return self.ffn(x)
+
+
+def _test_lowp_mlp_tensor_parallelism_base(
+    mesh: DeviceMesh,
+    config: Float8LinearConfig,
+    size=16,
+    compile: bool = False,
+    allgather_in_lowp: bool = False,
+):
+    device = mesh.device_type
+
+    toy_model = ToyModel().to(device)
+    toy_model_fp8 = convert_to_float8_training(toy_model, config=config)
+
+    tp_model = copy.deepcopy(toy_model)
+    tp_model = convert_to_float8_training(tp_model, config=config)
+    sp_model = copy.deepcopy(toy_model)
+    sp_model = convert_to_float8_training(sp_model, config=config)
+
+    # For tensorwise scaling, enable float8 all_gather.
+    # For rowwise scaling, keep the all_gather in high precision. Motivation:
+    # rowwise tensors need to be scaled both ways, so a float8 all-gather would
+    # send two float8 copies per tensor, which is about as many bytes over the
+    # wire as a single bfloat16 all-gather.
+    if not allgather_in_lowp:
+        colwise_parallel_cls = ColwiseParallel
+        rowwise_parallel_cls = RowwiseParallel
+        prepare_input_cls = PrepareModuleInput
+    else:
+        colwise_parallel_cls = Float8ColwiseParallel
+        rowwise_parallel_cls = Float8RowwiseParallel
+        prepare_input_cls = PrepareFloat8ModuleInput
+
+    # vanilla TP
+    tp_model = parallelize_module(
+        tp_model,
+        mesh,
+        {
+            "ffn.w1": colwise_parallel_cls(),
+            "ffn.w2": colwise_parallel_cls(),
+            "ffn.out_proj": rowwise_parallel_cls(),
+        },
+    )
+
+    # "sequence parallel" mlp computation
+    sp_model = parallelize_module(
+        sp_model,
+        mesh,
+        {
+            "ffn": prepare_input_cls(
+                input_layouts=Shard(1), desired_input_layouts=Replicate()
+            ),
+            "ffn.w1": colwise_parallel_cls(),
+            "ffn.w2": colwise_parallel_cls(),
+            "ffn.out_proj": rowwise_parallel_cls(
+                output_layouts=Shard(1), use_local_output=False
+            ),
+        },
+    )
+
+    # prepare_input_cls with specific submodule fqn
+    sp_model2 = copy.deepcopy(toy_model)
+    sp_model2 = convert_to_float8_training(sp_model2, config=config)
+
+    if not allgather_in_lowp:
+        prepare_input = prepare_input_cls(
+            input_layouts=Shard(1),
+            desired_input_layouts=Replicate(),
+        )
+    else:
+        prepare_input = prepare_input_cls(
+            input_layouts=Shard(1),
+            desired_input_layouts=Replicate(),
+            fwd_config_submodule_fqn="w2",
+        )
+
+    sp_model2 = parallelize_module(
+        sp_model2,
+        mesh,
+        {
+            "ffn": prepare_input,
+            "ffn.w1": colwise_parallel_cls(),
+            "ffn.w2": colwise_parallel_cls(),
+            "ffn.out_proj": rowwise_parallel_cls(
+                output_layouts=Shard(1), use_local_output=False
+            ),
+        },
+    )
+
+    if compile:
+        tp_model = torch.compile(tp_model)
+        sp_model = torch.compile(sp_model)
+        sp_model2 = torch.compile(sp_model2)
+
+    x_fp32 = torch.rand(size, size * 2, size, device=device, requires_grad=False)
+    x_fp32_tp_input = x_fp32.clone()
+    x_fp32_sp_input = distribute_tensor(x_fp32.clone(), mesh, [Shard(0)])
+
+    tp_out = tp_model(x_fp32_tp_input)
+    tp_out.sum().backward()
+    sp_out = sp_model(x_fp32_sp_input)
+    sp_out.sum().backward()
+    global_out = toy_model_fp8(x_fp32)
+    global_out.sum().backward()
+    torch.testing.assert_close(tp_out, global_out)
+    torch.testing.assert_close(sp_out.full_tensor(), global_out)
+    torch.testing.assert_close(tp_model.ffn.w1.weight.grad, sp_model.ffn.w1.weight.grad)
+    torch.testing.assert_close(
+        tp_model.ffn.out_proj.weight.grad, sp_model.ffn.out_proj.weight.grad
+    )
+
+    sp_out2 = sp_model2(x_fp32_sp_input)
+    sp_out2.sum().backward()
+    torch.testing.assert_close(sp_out2.full_tensor(), global_out)
+    torch.testing.assert_close(
+        tp_model.ffn.w1.weight.grad, sp_model2.ffn.w1.weight.grad
+    )
+    torch.testing.assert_close(
+        tp_model.ffn.out_proj.weight.grad, sp_model2.ffn.out_proj.weight.grad
+    )
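
Usage note (editor's addition, not part of the patch series above): the sketch below shows one way the relocated helper could be driven end to end after patch 3, mirroring setup_distributed() from test/float8/test_dtensor.py and the new _test_fp8_mlp_tensor_parallelism_eager() wrapper. The file name usage_sketch.py, the 2-GPU launch command, and importing Float8LinearRecipeName from torchao.float8.config are assumptions, not something the patches establish.

    # usage_sketch.py -- hypothetical driver, launch with:
    #   torchrun --nproc_per_node=2 usage_sketch.py
    import os

    import torch
    from torch.distributed.device_mesh import init_device_mesh

    from torchao.float8 import Float8LinearConfig
    from torchao.float8.config import Float8LinearRecipeName  # assumed import path
    from torchao.testing.training.dtensor_utils import (
        _test_lowp_mlp_tensor_parallelism_base,
    )

    # mirrors setup_distributed() in test/float8/test_dtensor.py; torchrun sets
    # WORLD_SIZE, and init_device_mesh initializes the default process group
    world_size = int(os.environ["WORLD_SIZE"])
    device_mesh = init_device_mesh("cuda", (world_size,))
    # seed must be the same in all processes
    torch.manual_seed(1)
    # as in patch 1: bind each rank to its own GPU before any collectives run
    torch.cuda.set_device(torch.distributed.get_rank())

    # tensorwise scaling: one scale per tensor, so float8 all-gather pays off
    tensorwise_config = Float8LinearConfig(emulate=True)
    _test_lowp_mlp_tensor_parallelism_base(
        device_mesh, tensorwise_config, size=16, compile=False, allgather_in_lowp=True
    )

    # rowwise scaling: keep the all-gather in high precision (see the comment
    # in dtensor_utils.py above for the bytes-over-the-wire argument)
    rowwise_config = Float8LinearConfig.from_recipe_name(Float8LinearRecipeName.ROWWISE)
    # hack around the config being frozen, as the patch itself does
    object.__setattr__(rowwise_config, "emulate", True)
    _test_lowp_mlp_tensor_parallelism_base(
        device_mesh, rowwise_config, size=16, compile=False, allgather_in_lowp=False
    )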