enable tensor parallelism for MXLinear #2434

Merged · 21 commits · Jun 24, 2025
96 changes: 96 additions & 0 deletions test/prototype/mx_formats/test_mx_dtensor.py
@@ -0,0 +1,96 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.
"""
Test numerics of manually defined float16 TP vs mxfp8 TP of toy models

Note: for now, this does not run in CI.
TODO(future): make this run in CI
"""

import os

import pytest
import torch

from torchao.utils import TORCH_VERSION_AT_LEAST_2_7

if not TORCH_VERSION_AT_LEAST_2_7:
pytest.skip("Unsupported PyTorch version", allow_module_level=True)

from torch.distributed._tensor import DTensor, Shard, distribute_tensor
from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
from tqdm import tqdm

from torchao.prototype.mx_formats import MXLinearConfig
from torchao.prototype.mx_formats.mx_tensor import MXTensor
from torchao.testing.training.dtensor_utils import (
_test_lowp_mlp_tensor_parallelism_base,
)

torch.set_float32_matmul_precision("high")


def setup_distributed():
world_size = int(os.environ.get("WORLD_SIZE", -1))
device_mesh = init_device_mesh("cuda", (world_size,))
# seed must be the same in all processes
torch.manual_seed(1)
local_rank = torch.distributed.get_rank()
torch.cuda.set_device(local_rank)
return device_mesh


def _test_dtensor_cast_to_mxfp8(mesh: DeviceMesh, size=4):
device = mesh.device_type

x_fp32 = torch.rand(size, size, device=device)
x_fp8 = MXTensor.to_mx(x_fp32, torch.float8_e4m3fn, block_size=size // 2)

dist_x_fp32 = distribute_tensor(x_fp32, mesh, [Shard(0)])
dist_x_fp8 = MXTensor.to_mx(dist_x_fp32, torch.float8_e4m3fn, block_size=size // 2)
assert isinstance(dist_x_fp8, DTensor)

# Verify that the result of to_mx with DTensor matches the slice of the
# result of to_mx without DTensor. This will fail on numeric op mismatches.
local_rank = torch.distributed.get_rank()
world_size = torch.distributed.get_world_size()
assert size % world_size == 0, "unsupported"
x_fp8_fp32 = x_fp8.to_dtype(torch.float32)
rows_per_slice = size // world_size
slice_start = local_rank * rows_per_slice
slice_end = (local_rank + 1) * rows_per_slice
x_fp8_fp32_slice = x_fp8_fp32[slice_start:slice_end]
torch.testing.assert_close(
x_fp8_fp32_slice, dist_x_fp8.to_local().to_dtype(torch.float32), atol=0, rtol=0
)


def _test_mxfp8_mlp_tensor_parallelism(mesh: DeviceMesh, size=16):
config = MXLinearConfig.from_recipe_name("mxfp8_emulated")
config.block_size = 16
_test_lowp_mlp_tensor_parallelism_base(
mesh, config, size, compile=False, allgather_in_lowp=False
)
_test_lowp_mlp_tensor_parallelism_base(
mesh, config, size, compile=True, allgather_in_lowp=False
)


if __name__ == "__main__":
device_mesh = setup_distributed()
tests = [
_test_dtensor_cast_to_mxfp8,
_test_mxfp8_mlp_tensor_parallelism,
]

for test in tqdm(tests, desc="Running tests"):
try:
test(device_mesh)
except Exception as e:
print(f"Test {test.__name__} failed with error: {e}")
raise e

torch.distributed.destroy_process_group()
17 changes: 17 additions & 0 deletions test/prototype/mx_formats/test_mx_dtensor.sh
@@ -0,0 +1,17 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.
#!/bin/bash

# terminate script on first error
set -e

if python -c 'import torch;print(torch.cuda.is_available())' | grep -q "False"; then
echo "Skipping test_dtensor.sh because no CUDA devices are available."
exit
fi

# integration tests for TP/SP
NCCL_DEBUG=WARN torchrun --nproc_per_node 2 test/prototype/mx_formats/test_mx_dtensor.py
4 changes: 2 additions & 2 deletions test/prototype/mx_formats/test_mx_linear.py
@@ -190,8 +190,8 @@ def test_linear_eager_emulated_vs_real_gemm(recipe_name, mkn):
# TODO(future): enable compile support
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_activation_checkpointing():
input_shape = (2, 4)
grad_shape = (2, 8)
input_shape = (16, 4)
Contributor Author commented: this was broken before; it was caught by enforcing that the inner dim is divisible by the block size (see the sketch after this hunk).

grad_shape = (16, 8)
elem_dtype = torch.float8_e4m3fn

m = nn.Sequential(
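To illustrate the constraint the comment above refers to, here is a minimal sketch (not part of the diff) of the check that the updated `to_mx` enforces on the inner dimension. The block size of 32 and the tensor shapes are illustrative assumptions, not values taken from this test.

```python
import torch

from torchao.prototype.mx_formats.mx_tensor import MXTensor

block_size = 32  # assumed block size; run on a CUDA device

# Fine: the inner (last) dim, 64, is divisible by block_size.
x_ok = torch.randn(16, 64, device="cuda", dtype=torch.bfloat16)
x_mx = MXTensor.to_mx(x_ok, torch.float8_e4m3fn, block_size=block_size)

# The old test shapes had an inner dim of 4, which is not divisible by 32,
# so the stricter assert in to_mx now rejects shapes like this:
x_bad = torch.randn(2, 4, device="cuda", dtype=torch.bfloat16)
# MXTensor.to_mx(x_bad, torch.float8_e4m3fn, block_size=block_size)  # AssertionError
```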
2 changes: 1 addition & 1 deletion test/prototype/mx_formats/test_mx_tensor.py
@@ -72,7 +72,7 @@ def assert_sqnr_gt_threshold(orig, new, threshold):
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES)
def test_hello_world(elem_dtype):
data = torch.randn(4, 4, device="cuda", dtype=torch.bfloat16)
data = torch.randn(8, 8, device="cuda", dtype=torch.bfloat16)
block_size = 4
_test_mx(data, elem_dtype, block_size)

11 changes: 5 additions & 6 deletions torchao/prototype/mx_formats/kernels.py
@@ -1056,7 +1056,7 @@ def pack_uint6(uint8_data: torch.Tensor) -> torch.Tensor:

# effective mx block size since we're packing 2 fp4 into 1 uint8
packed_mx_block_size = 3 * mx_block_size // 4
packed_shape = [uint8_data.shape[0], packed_mx_block_size]
packed_shape = [*uint8_data.shape[:-1], packed_mx_block_size]
n_mx_blocks = uint8_data.numel() // mx_block_size

grid = lambda meta: (triton.cdiv(n_mx_blocks, meta["BLOCK_SIZE_IN"]),)
@@ -1102,15 +1102,12 @@ def _triton_calculate_scale(x, axis):
bf16_mbits = 7
bf16_exp_bias = 127
fp32_mbits = 23
# We use a small epsilon to avoid division by zero
epsilon = 1e-10

# Find the maximum absolute value for each row
max_abs = tl.max(x, axis=axis)

# Calculate the e8m0 scale by extracting the exponent (floor)
# TODO(future PR): support other exponent extraction types (ceil, RNE)
max_abs = max_abs + epsilon
max_abs = max_abs.to(tl.bfloat16)
max_abs_int16 = max_abs.to(tl.int16, bitcast=True)
extracted_pow2 = ((max_abs_int16 >> bf16_mbits) & 0b11111111) - bf16_exp_bias
@@ -1340,7 +1337,9 @@ def triton_to_mxfp8_dim1(

# Create scale tensors
col_scale = torch.empty(
(n_cols * n_rows // inner_block_size, 1), dtype=torch.uint8, device=x.device
(n_cols, n_rows // inner_block_size, 1),
dtype=torch.uint8,
device=x.device,
)

# Calculate grid dimensions based on tile size
@@ -1377,7 +1376,7 @@ def triton_to_mxfp8_dim1_reference(
scale_e8m0_dim1, x_hp_d1_normalized = to_mx(
x_hp_d1, torch.float8_e4m3fn, block_size
)
scale_e8m0_dim1 = scale_e8m0_dim1.unsqueeze(1).view(torch.float8_e8m0fnu)
scale_e8m0_dim1 = scale_e8m0_dim1.view(torch.float8_e8m0fnu)
return (
x_hp_d1_normalized.t(),
scale_e8m0_dim1,
91 changes: 59 additions & 32 deletions torchao/prototype/mx_formats/mx_tensor.py
@@ -21,10 +21,10 @@
from typing import Callable, Dict, Union

import torch
from torch.distributed._tensor import DTensor

from torchao.prototype.mx_formats.config import MXGemmKernelChoice
from torchao.prototype.mx_formats.constants import (
BF16_EXP_BIAS,
BLOCK_SIZE_DEFAULT,
DTYPE_FP6_E2M3,
DTYPE_FP6_E3M2,
@@ -61,7 +61,6 @@

# TODO(later): read from somewhere else?
SBITS, EBITS_F32, MBITS_F32 = 1, 8, 23
EBITS_BF16, MBITS_BF16 = 8, 7
EBITS_F4_E2M1, MBITS_F4_E2M1 = 2, 1
EBITS_F6_E2M3, MBITS_F6_E2M3 = 2, 3
EBITS_F6_E3M2, MBITS_F6_E3M2 = 3, 2
@@ -136,9 +135,7 @@ def _to_mx_rceil(
)

# scale and saturated cast the data elements to max of target dtype
data_lp = torch.clamp(
data_hp * descale_fp.unsqueeze(1), min=-1 * max_pos, max=max_pos
)
data_lp = torch.clamp(data_hp * descale_fp, min=-1 * max_pos, max=max_pos)
return exponent, data_lp


@@ -159,24 +156,33 @@ def to_mx(
torch.float,
), f"{data_hp.dtype} is not supported yet"
# TODO(future PR): consider supporting padding
assert data_hp.numel() % block_size == 0, "unsupported"
assert data_hp.shape[-1] % block_size == 0, (
f"the last dimension of shape {data_hp.shape} must be divisible by block_size {block_size}"
)
assert data_hp.is_contiguous(), "unsupported"
assert elem_dtype in SUPPORTED_ELEM_DTYPES, "unsupported"

# calculate the scale in e8m0 format

orig_shape = data_hp.shape
data_hp = data_hp.reshape(-1, block_size)
data_hp = data_hp.reshape(
*orig_shape[:-1], orig_shape[-1] // block_size, block_size
)

# find max value of the data
# Note: this only implements the `minimally supported` version of
# https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
# section 6.3.
max_abs = torch.amax(torch.abs(data_hp), 1)

# Add an epsilon to prevent the log2 function call for returning -inf
# where the values are zero.
eps = F32_MIN_NORMAL * (max_abs == 0).type(max_abs.dtype)
max_abs = torch.amax(torch.abs(data_hp), -1).unsqueeze(-1)

# We cast to float32 here because
# in the `max_abs_int32 = max_abs.view(hp_int_dtype)` line below,
# if tensor parallel is enabled then the resulting shape is 2x larger
# than it should be under some conditions, likely because of a bug in
# the `view` op with DTensor and target dtype int16. I reproduce in
# torchtitan but not in a unit test, so not enough info to file a good
# issue in pytorch/pytorch. For now, work around. In the future we should
# debug and fix this properly.
data_hp = data_hp.to(torch.float32)
Contributor Author commented: performance testing showed that, with compile on, keeping this in float32 does not regress performance.

max_abs = max_abs.to(torch.float32)

# Set X to be the largest power-of-two less than or equal to
# max_abs(v), divided by the largest power of two representable
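The reshape and amax changes in this hunk keep the leading dimensions and add a trailing block dimension, so the per-block scale broadcasts against the data without an explicit `unsqueeze(1)`; the cast to float32 is the DTensor `view` workaround described in the comment above. A standalone sketch of the shape bookkeeping on plain tensors (shapes and names are illustrative, not from the PR):

```python
import torch

block_size = 32
x = torch.randn(4, 128, dtype=torch.bfloat16)            # orig_shape = (4, 128)

# New layout: keep the leading dims, split the last dim into blocks.
x_blocks = x.reshape(4, 128 // block_size, block_size)   # (4, 4, 32)

# Per-block max-abs, with the reduced dim kept for broadcasting.
max_abs = torch.amax(torch.abs(x_blocks), dim=-1).unsqueeze(-1)  # (4, 4, 1)

# Cast both operands to float32 before the bit-cast based exponent
# extraction, mirroring the workaround above.
x_blocks_fp32 = x_blocks.to(torch.float32)
max_abs_fp32 = max_abs.to(torch.float32)

# A placeholder per-block scale now divides the data directly; no
# unsqueeze is needed because the shapes already line up.
scale_fp32 = torch.ones_like(max_abs_fp32)               # (4, 4, 1), illustrative
data_lp = x_blocks_fp32 / scale_fp32                     # (4, 4, 32)
```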
@@ -207,17 +213,11 @@ def to_mx(
if scaling_mode == ScaleCalculationMode.RCEIL:
scale_e8m0_biased, data_lp = _to_mx_rceil(data_hp, max_abs, max_pos)
else:
if data_hp.dtype is torch.float32:
hp_int_dtype = torch.int32
hp_mbits = MBITS_F32
hp_ebits = EBITS_F32
hp_exp_bias = F32_EXP_BIAS
else:
assert data_hp.dtype is torch.bfloat16
hp_int_dtype = torch.int16
hp_mbits = MBITS_BF16
hp_ebits = EBITS_BF16
hp_exp_bias = BF16_EXP_BIAS
assert data_hp.dtype is torch.float32
hp_int_dtype = torch.int32
hp_mbits = MBITS_F32
hp_ebits = EBITS_F32
hp_exp_bias = F32_EXP_BIAS

# rounding before calculating the largest power of 2
# X = 2^(floor(log2(rounding(max_abs(v)))-max_exp))
@@ -233,8 +233,12 @@
)

# Calculate the scale for different modes
max_abs_int32 = (max_abs + eps).view(hp_int_dtype)
extracted_pow2 = ((max_abs_int32 >> hp_mbits) & 0b11111111) - hp_exp_bias
max_abs_int32 = max_abs.view(hp_int_dtype)
# For now, use `torch.bitwise_right_shift` instead of `>>` to support DTensor
# See https://github.com/pytorch/pytorch/issues/156533.
extracted_pow2 = (
(torch.bitwise_right_shift(max_abs_int32, hp_mbits)) & 0b11111111
) - hp_exp_bias

if scaling_mode in (ScaleCalculationMode.FLOOR, ScaleCalculationMode.EVEN):
scale_e8m0_unbiased = extracted_pow2 - target_max_pow2
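On plain tensors, `torch.bitwise_right_shift` behaves exactly like `>>` on the bit-cast int32 view; the functional form is used only so the op dispatches correctly with DTensor inputs (see the linked pytorch issue). A minimal sketch of the exponent extraction, assuming float32 inputs and the constants MBITS_F32 = 23 and F32_EXP_BIAS = 127 used elsewhere in this file:

```python
import torch

MBITS_F32 = 23
F32_EXP_BIAS = 127

max_abs = torch.tensor([0.75, 1.0, 6.0], dtype=torch.float32)

# Reinterpret the float32 bits as int32 and pull out the 8 exponent bits.
max_abs_int32 = max_abs.view(torch.int32)
extracted_pow2 = (
    torch.bitwise_right_shift(max_abs_int32, MBITS_F32) & 0b11111111
) - F32_EXP_BIAS

print(extracted_pow2)  # tensor([-1,  0,  2]), i.e. floor(log2(max_abs))
```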
@@ -266,9 +270,11 @@
)

# For now, calculate the scale in floating point.
scale_fp32 = (scale_e8m0_biased.to(torch.int32) << MBITS_F32).view(
torch.float32
)
# For now, use `torch.bitwise_left_shift` instead of `<<` to support DTensor
# See https://github.com/pytorch/pytorch/issues/156533.
scale_fp32 = (
torch.bitwise_left_shift(scale_e8m0_biased.to(torch.int32), MBITS_F32)
).view(torch.float32)

# Today, 2**-127 returns 0 in compile+inductor+triton because it is in the
# float32 denormal range. For now, manually adjust the fp scale. This is
@@ -280,7 +286,7 @@
scale_fp32 = torch.clamp(scale_fp32, min=F32_MIN_NORMAL)

# scale and saturated cast the data elements to max of target dtype
data_lp = data_hp / scale_fp32.unsqueeze(1)
data_lp = data_hp / scale_fp32

if (
elem_dtype in (torch.float8_e4m3fn, torch.float8_e5m2)
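A worked sketch of the scale construction above on plain tensors: the biased e8m0 exponent is shifted into the float32 exponent field with `torch.bitwise_left_shift`, and the 2**-127 case (biased exponent 0), whose bit pattern reads back as 0.0, is clamped up to the smallest normal float32. The example values are illustrative.

```python
import torch

MBITS_F32 = 23
F32_MIN_NORMAL = 2 ** (-126)  # smallest normal float32

# Biased e8m0 exponents: 0 encodes 2**-127, 127 encodes 1.0, 129 encodes 4.0.
scale_e8m0_biased = torch.tensor([0, 127, 129], dtype=torch.uint8)

# Shift the biased exponent into the float32 exponent field (mantissa = 0).
scale_fp32 = (
    torch.bitwise_left_shift(scale_e8m0_biased.to(torch.int32), MBITS_F32)
).view(torch.float32)

# Biased exponent 0 bit-casts to 0.0, so clamp it up to the smallest normal.
scale_fp32 = torch.clamp(scale_fp32, min=F32_MIN_NORMAL)

print(scale_fp32)  # tensor([1.1755e-38, 1.0000e+00, 4.0000e+00])
```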
@@ -506,7 +512,6 @@ def __new__(
assert scale_e8m0_bits.dtype == torch.float8_e8m0fnu, (
f"scale_e8m0_bits.dtype must be `torch.float8_e8m0fnu`, got {scale_e8m0_bits.dtype}"
)
assert len(scale_e8m0_bits.shape) == 1, "unsupported"
assert data_bits.dtype in (
torch.float8_e4m3fn,
torch.float8_e5m2,
@@ -597,6 +602,28 @@ def to_mx(
scale_e8m0_biased, data_lp = to_mx(
data_hp, elem_dtype, block_size, scaling_mode, pack_fp6
)
if isinstance(scale_e8m0_biased, DTensor):
assert isinstance(data_lp, DTensor), "unsupported"
local_scale_e8m0_biased = scale_e8m0_biased.to_local()
local_data_lp = data_lp.to_local()
inner_mx_tensor = MXTensor(
local_scale_e8m0_biased,
local_data_lp,
elem_dtype,
block_size,
data_hp.dtype,
use_fp4_custom_triton_dequant_kernel,
gemm_kernel_choice,
pack_fp6,
)
return DTensor.from_local(
inner_mx_tensor,
data_lp.device_mesh,
data_lp.placements,
run_check=False,
shape=data_lp.size(),
stride=data_lp.stride(),
)
return MXTensor(
scale_e8m0_biased,
data_lp,
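With the new DTensor branch above, calling `MXTensor.to_mx` on a sharded DTensor returns a DTensor whose local tensor is an MXTensor, which is what the new dtensor test exercises. A hedged usage sketch that mirrors the added test file; the 2-GPU mesh, file name, and launcher invocation are assumptions for illustration:

```python
# Launch (illustrative): torchrun --nproc_per_node 2 mx_dtensor_sketch.py
import torch
from torch.distributed._tensor import DTensor, Shard, distribute_tensor
from torch.distributed.device_mesh import init_device_mesh

from torchao.prototype.mx_formats.mx_tensor import MXTensor

mesh = init_device_mesh("cuda", (2,))
torch.manual_seed(1)  # seed must match on every rank
torch.cuda.set_device(torch.distributed.get_rank())

x_fp32 = torch.rand(4, 4, device="cuda")
dist_x_fp32 = distribute_tensor(x_fp32, mesh, [Shard(0)])

# The cast dispatches through DTensor; the result wraps a local MXTensor shard.
dist_x_mx = MXTensor.to_mx(dist_x_fp32, torch.float8_e4m3fn, block_size=2)
assert isinstance(dist_x_mx, DTensor)
assert isinstance(dist_x_mx.to_local(), MXTensor)

torch.distributed.destroy_process_group()
```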