
make dtensor shared test util more generic #2416

Status: merged (6 commits), Jun 24, 2025. The diff below shows changes from 5 of the commits.
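Summary of the change: the shared float8 test helpers move from torchao.testing.float8 to torchao.testing.training, and the float8-specific dtensor tensor-parallel test body is replaced by a generic shared helper, _test_lowp_mlp_tensor_parallelism_base, which takes the low-precision config as an argument. Below is a minimal usage sketch; the helper's signature is inferred from the call sites in test/float8/test_dtensor.py further down, and the wrapper function name is illustrative, not code from this PR.

# Sketch only: assumes an already-initialized distributed process group and the
# signature seen at the call sites below (mesh, config, size, compile, allgather_in_lowp).
from torch.distributed.device_mesh import init_device_mesh

from torchao.float8 import Float8LinearConfig
from torchao.testing.training.dtensor_utils import (
    _test_lowp_mlp_tensor_parallelism_base,
)


def run_tensorwise_tp_smoke_test(world_size: int = 4) -> None:
    mesh = init_device_mesh("cuda", (world_size,))
    config = Float8LinearConfig(emulate=True)
    # tensorwise scaling supports float8 all-gather, hence allgather_in_lowp=True
    _test_lowp_mlp_tensor_parallelism_base(
        mesh, config, size=16, compile=False, allgather_in_lowp=True
    )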
2 changes: 1 addition & 1 deletion benchmarks/float8/bench_matmul.py
@@ -16,7 +16,7 @@
     get_name_to_shapes_iter,
 )

-from torchao.testing.float8.roofline_utils import get_specs
+from torchao.testing.training.roofline_utils import get_specs


 def benchmark_fn_in_sec(f, *args, **kwargs):
2 changes: 1 addition & 1 deletion benchmarks/float8/float8_roofline.py
@@ -63,7 +63,7 @@
 )
 from torchao.prototype.mx_formats import MXLinearConfig
 from torchao.quantization import quantize_
-from torchao.testing.float8.roofline_utils import (
+from torchao.testing.training.roofline_utils import (
     get_float8_mem_sympy,
     get_gemm_time_sympy,
 )
2 changes: 1 addition & 1 deletion test/float8/test_base.py
@@ -55,7 +55,7 @@
     fp8_tensor_statistics,
     tensor_to_scale,
 )
-from torchao.testing.float8.test_utils import get_test_float8_linear_config
+from torchao.testing.training.test_utils import get_test_float8_linear_config
 from torchao.utils import is_MI300, is_ROCM

 random.seed(0)
2 changes: 1 addition & 1 deletion test/float8/test_compile.py
@@ -37,7 +37,7 @@
     hp_tensor_to_float8_dynamic,
 )
 from torchao.float8.float8_tensor import GemmInputRole, LinearMMConfig, ScaledMMConfig
-from torchao.testing.float8.test_utils import get_test_float8_linear_config
+from torchao.testing.training.test_utils import get_test_float8_linear_config


 def _test_compile_base(
160 changes: 23 additions & 137 deletions test/float8/test_dtensor.py
@@ -10,7 +10,6 @@
 TODO(future): make this run in CI
 """

-import copy
 import os

 import pytest
@@ -23,12 +22,6 @@

 from torch.distributed._tensor import DTensor, Replicate, Shard, distribute_tensor
 from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
-from torch.distributed.tensor.parallel import (
-    ColwiseParallel,
-    PrepareModuleInput,
-    RowwiseParallel,
-    parallelize_module,
-)
 from torch.testing._internal.distributed._tensor.common_dtensor import (
     ModelArgs,
     Transformer,
@@ -50,14 +43,11 @@
     LinearMMConfig,
     hp_tensor_and_scale_to_float8,
 )
-from torchao.float8.float8_tensor_parallel import (
-    Float8ColwiseParallel,
-    Float8RowwiseParallel,
-    PrepareFloat8ModuleInput,
-)
 from torchao.float8.float8_utils import tensor_to_scale
 from torchao.float8.fsdp_utils import WeightWithDynamicFloat8CastTensor
-from torchao.testing.float8.dtensor_utils import ToyModel
+from torchao.testing.training.dtensor_utils import (
+    _test_lowp_mlp_tensor_parallelism_base,
+)

 torch.set_float32_matmul_precision("high")

@@ -193,140 +183,36 @@ def _test_dtensor_fp8_autograd(mesh: DeviceMesh, size=16):
     loss.backward()


-def _test_fp8_mlp_tensor_parallelism_base(
-    mesh: DeviceMesh, size=16, compile: bool = False, rowwise: bool = False
-):
-    device = mesh.device_type
-
-    if rowwise:
-        config = Float8LinearConfig.from_recipe_name(Float8LinearRecipeName.ROWWISE)
-        # hack around config being frozen
-        # TODO(future PR): we should make this nicer at the config level
-        object.__setattr__(config, "emulate", True)
-    else:
-        config = Float8LinearConfig(emulate=True)
-
-    toy_model = ToyModel().to(device)
-    toy_model_fp8 = convert_to_float8_training(toy_model, config=config)
-
-    tp_model = copy.deepcopy(toy_model)
-    tp_model = convert_to_float8_training(tp_model, config=config)
-    sp_model = copy.deepcopy(toy_model)
-    sp_model = convert_to_float8_training(sp_model, config=config)
-
-    # For tensorwise scaling, enable float8 all_gather.
-    # For rowwise scaling, keep high precision all_gather. Motivation for
-    # not doing float8 all-gather for rowwise: tensors need to be scaled both ways,
-    # so for float8 all-gather we'd need to send two float8 copies per tensor,
-    # which is similar # bytes over the wire than just doing bfloat16 all-gather.
-    if rowwise:
-        colwise_parallel_cls = ColwiseParallel
-        rowwise_parallel_cls = RowwiseParallel
-        prepare_input_cls = PrepareModuleInput
-    else:
-        colwise_parallel_cls = Float8ColwiseParallel
-        rowwise_parallel_cls = Float8RowwiseParallel
-        prepare_input_cls = PrepareFloat8ModuleInput
-
-    # vanilla TP
-    tp_model = parallelize_module(
-        tp_model,
-        mesh,
-        {
-            "ffn.w1": colwise_parallel_cls(),
-            "ffn.w2": colwise_parallel_cls(),
-            "ffn.out_proj": rowwise_parallel_cls(),
-        },
-    )
-
-    # "sequence parallel" mlp computation
-    sp_model = parallelize_module(
-        sp_model,
-        mesh,
-        {
-            "ffn": prepare_input_cls(
-                input_layouts=Shard(1), desired_input_layouts=Replicate()
-            ),
-            "ffn.w1": colwise_parallel_cls(),
-            "ffn.w2": colwise_parallel_cls(),
-            "ffn.out_proj": rowwise_parallel_cls(
-                output_layouts=Shard(1), use_local_output=False
-            ),
-        },
-    )
-
-    # prepare_input_cls with specific submodule fqn
-    sp_model2 = copy.deepcopy(toy_model)
-    sp_model2 = convert_to_float8_training(sp_model2, config=config)
-
-    if rowwise:
-        prepare_input = prepare_input_cls(
-            input_layouts=Shard(1),
-            desired_input_layouts=Replicate(),
-        )
-    else:
-        prepare_input = prepare_input_cls(
-            input_layouts=Shard(1),
-            desired_input_layouts=Replicate(),
-            fwd_config_submodule_fqn="w2",
-        )
-
-    sp_model2 = parallelize_module(
-        sp_model2,
-        mesh,
-        {
-            "ffn": prepare_input,
-            "ffn.w1": colwise_parallel_cls(),
-            "ffn.w2": colwise_parallel_cls(),
-            "ffn.out_proj": rowwise_parallel_cls(
-                output_layouts=Shard(1), use_local_output=False
-            ),
-        },
-    )
-
-    if compile:
-        tp_model = torch.compile(tp_model)
-        sp_model = torch.compile(sp_model)
-        sp_model2 = torch.compile(sp_model2)
-
-    x_fp32 = torch.rand(size, size * 2, size, device=device, requires_grad=False)
-    x_fp32_tp_input = x_fp32.clone()
-    x_fp32_sp_input = distribute_tensor(x_fp32.clone(), mesh, [Shard(0)])
-
-    tp_out = tp_model(x_fp32_tp_input)
-    tp_out.sum().backward()
-    sp_out = sp_model(x_fp32_sp_input)
-    sp_out.sum().backward()
-    global_out = toy_model_fp8(x_fp32)
-    global_out.sum().backward()
-    torch.testing.assert_close(tp_out, global_out)
-    torch.testing.assert_close(sp_out.full_tensor(), global_out)
-    torch.testing.assert_close(tp_model.ffn.w1.weight.grad, sp_model.ffn.w1.weight.grad)
-    torch.testing.assert_close(
-        tp_model.ffn.out_proj.weight.grad, sp_model.ffn.out_proj.weight.grad
-    )
-
-    sp_out2 = sp_model2(x_fp32_sp_input)
-    sp_out2.sum().backward()
-    torch.testing.assert_close(sp_out2.full_tensor(), global_out)
-    torch.testing.assert_close(
-        tp_model.ffn.w1.weight.grad, sp_model2.ffn.w1.weight.grad
-    )
-    torch.testing.assert_close(
-        tp_model.ffn.out_proj.weight.grad, sp_model2.ffn.out_proj.weight.grad
-    )
-
-
-def _test_fp8_mlp_tensor_parallelism_eager(mesh: DeviceMesh, size=16):
-    _test_fp8_mlp_tensor_parallelism_base(mesh, size, compile=False, rowwise=False)
-    _test_fp8_mlp_tensor_parallelism_base(mesh, size, compile=False, rowwise=True)
-
-
-def _test_fp8_mlp_tensor_parallelism_compile(mesh: DeviceMesh, size=16):
-    _test_fp8_mlp_tensor_parallelism_base(mesh, size, compile=True, rowwise=False)
-    _test_fp8_mlp_tensor_parallelism_base(mesh, size, compile=True, rowwise=True)
+def _test_fp8_mlp_tensor_parallelism_eager(mesh: DeviceMesh, size=16):
+    tensorwise_config = Float8LinearConfig(emulate=True)
+    _test_lowp_mlp_tensor_parallelism_base(
+        mesh, tensorwise_config, size, compile=False, allgather_in_lowp=True
+    )
+
+    rowwise_config = Float8LinearConfig.from_recipe_name(Float8LinearRecipeName.ROWWISE)
+    # hack around config being frozen
+    # TODO(future PR): we should make this nicer at the config level
+    object.__setattr__(rowwise_config, "emulate", True)
+    _test_lowp_mlp_tensor_parallelism_base(
+        mesh, rowwise_config, size, compile=False, allgather_in_lowp=False
+    )
+
+
+def _test_fp8_mlp_tensor_parallelism_compile(mesh: DeviceMesh, size=16):
+    tensorwise_config = Float8LinearConfig(emulate=True)
+    _test_lowp_mlp_tensor_parallelism_base(
+        mesh, tensorwise_config, size, compile=True, allgather_in_lowp=True
+    )
+
+    rowwise_config = Float8LinearConfig.from_recipe_name(Float8LinearRecipeName.ROWWISE)
+    # hack around config being frozen
+    # TODO(future PR): we should make this nicer at the config level
+    object.__setattr__(rowwise_config, "emulate", True)
+    _test_lowp_mlp_tensor_parallelism_base(
+        mesh, rowwise_config, size, compile=True, allgather_in_lowp=False
+    )


 def _test_distribute_fsdp_tensor_subclass(tp_mesh: DeviceMesh):
     torch.manual_seed(42)
     model = Transformer(ModelArgs(dropout_p=0.0, weight_tying=False)).cuda()
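For context, a rough sketch of what the generic helper presumably consolidates, based on the code removed above. This is an assumption about torchao/testing/training/dtensor_utils.py, which this page does not show, not its actual contents; the real helper also covers the sequence-parallel and submodule-fqn variants and the gradient checks. allgather_in_lowp selects the float8-aware tensor-parallel styles (float8 all-gather), otherwise the vanilla DTensor TP styles with high-precision all-gather.

# Hypothetical sketch; _sketch_lowp_mlp_tp is an illustrative name, and a
# DeviceMesh over an initialized process group is assumed.
import copy

import torch
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)

from torchao.float8 import convert_to_float8_training
from torchao.float8.float8_tensor_parallel import (
    Float8ColwiseParallel,
    Float8RowwiseParallel,
)
from torchao.testing.training.dtensor_utils import ToyModel


def _sketch_lowp_mlp_tp(mesh, config, size=16, compile=False, allgather_in_lowp=False):
    device = mesh.device_type
    toy_model = ToyModel().to(device)
    ref_model = convert_to_float8_training(copy.deepcopy(toy_model), config=config)
    tp_model = convert_to_float8_training(copy.deepcopy(toy_model), config=config)

    # float8 all-gather uses the Float8* parallel styles, otherwise vanilla TP
    colwise = Float8ColwiseParallel if allgather_in_lowp else ColwiseParallel
    rowwise = Float8RowwiseParallel if allgather_in_lowp else RowwiseParallel
    tp_model = parallelize_module(
        tp_model,
        mesh,
        {"ffn.w1": colwise(), "ffn.w2": colwise(), "ffn.out_proj": rowwise()},
    )
    if compile:
        tp_model = torch.compile(tp_model)

    # compare the tensor-parallel model against the single-device reference
    x = torch.rand(size, size * 2, size, device=device)
    torch.testing.assert_close(tp_model(x.clone()), ref_model(x))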
5 changes: 4 additions & 1 deletion test/float8/test_fsdp2/test_fsdp2.py
@@ -43,7 +43,10 @@
 from torchao.float8.float8_scaling_utils import hp_tensor_to_float8_dynamic
 from torchao.float8.float8_tensor import GemmInputRole
 from torchao.float8.fsdp_utils import WeightWithDynamicFloat8CastTensor
-from torchao.testing.float8.fsdp2_utils import check_parity_bf16_mp, check_parity_no_mp
+from torchao.testing.training.fsdp2_utils import (
+    check_parity_bf16_mp,
+    check_parity_no_mp,
+)

 if not is_sm_at_least_89():
     pytest.skip("Unsupported CUDA device capability version", allow_module_level=True)
2 changes: 1 addition & 1 deletion test/float8/test_fsdp2_tp.py
@@ -32,7 +32,7 @@
     Float8ColwiseParallel,
     Float8RowwiseParallel,
 )
-from torchao.testing.float8.dtensor_utils import ToyModel
+from torchao.testing.training.dtensor_utils import ToyModel


 def setup_distributed():
2 changes: 1 addition & 1 deletion test/float8/test_numerics_integration.py
@@ -33,7 +33,7 @@
     convert_to_float8_training,
 )
 from torchao.float8.float8_utils import IS_ROCM, compute_error
-from torchao.testing.float8.test_utils import get_test_float8_linear_config
+from torchao.testing.training.test_utils import get_test_float8_linear_config

 torch.manual_seed(0)
30 changes: 0 additions & 30 deletions torchao/testing/float8/dtensor_utils.py

This file was deleted.
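Based on the updated imports above, the helpers previously defined here (e.g. ToyModel) now live in torchao/testing/training/dtensor_utils.py rather than being dropped. A downstream import would change roughly as follows (illustrative only):

# old location, deleted by this PR:
#   from torchao.testing.float8.dtensor_utils import ToyModel
# new location, as used by test/float8/test_fsdp2_tp.py above:
from torchao.testing.training.dtensor_utils import ToyModel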
