enable tensor parallelism for MXLinear #2434
Merged
21 commits by vkuzo, all titled "Update": 5c23c6b, ad2ce62, 5eb2066, 6e3df57, 75e6fe7, 8bf42da, c0080cd, c6fc48b, 4cc1531, 42083e2, 9d171ad, 09c1c58, e511e7b, 3562a5e, 7a0fd96, 2d1545f, 20b7db2, 7788412, aabeb61, 28f32b9, 1001602
@@ -0,0 +1,96 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.
"""
Test numerics of manually defined float16 TP vs mxfp8 TP of toy models

Note: for now, this does not run in CI.
TODO(future): make this run in CI
"""

import os

import pytest
import torch

from torchao.utils import TORCH_VERSION_AT_LEAST_2_7

if not TORCH_VERSION_AT_LEAST_2_7:
    pytest.skip("Unsupported PyTorch version", allow_module_level=True)

from torch.distributed._tensor import DTensor, Shard, distribute_tensor
from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
from tqdm import tqdm

from torchao.prototype.mx_formats import MXLinearConfig
from torchao.prototype.mx_formats.mx_tensor import MXTensor
from torchao.testing.training.dtensor_utils import (
    _test_lowp_mlp_tensor_parallelism_base,
)

torch.set_float32_matmul_precision("high")


def setup_distributed():
    world_size = int(os.environ.get("WORLD_SIZE", -1))
    device_mesh = init_device_mesh("cuda", (world_size,))
    # seed must be the same in all processes
    torch.manual_seed(1)
    local_rank = torch.distributed.get_rank()
    torch.cuda.set_device(local_rank)
    return device_mesh


def _test_dtensor_cast_to_mxfp8(mesh: DeviceMesh, size=4):
    device = mesh.device_type

    x_fp32 = torch.rand(size, size, device=device)
    x_fp8 = MXTensor.to_mx(x_fp32, torch.float8_e4m3fn, block_size=size // 2)

    dist_x_fp32 = distribute_tensor(x_fp32, mesh, [Shard(0)])
    dist_x_fp8 = MXTensor.to_mx(dist_x_fp32, torch.float8_e4m3fn, block_size=size // 2)
    assert isinstance(dist_x_fp8, DTensor)

    # Verify that the result of to_mx with DTensor matches the slice of the
    # result of to_mx without DTensor. This will fail on numeric op mismatches.
    local_rank = torch.distributed.get_rank()
    world_size = torch.distributed.get_world_size()
    assert size % world_size == 0, "unsupported"
    x_fp8_fp32 = x_fp8.to_dtype(torch.float32)
    rows_per_slice = size // world_size
    slice_start = local_rank * rows_per_slice
    slice_end = (local_rank + 1) * rows_per_slice
    x_fp8_fp32_slice = x_fp8_fp32[slice_start:slice_end]
    torch.testing.assert_close(
        x_fp8_fp32_slice, dist_x_fp8.to_local().to_dtype(torch.float32), atol=0, rtol=0
    )


def _test_mxfp8_mlp_tensor_parallelism(mesh: DeviceMesh, size=16):
    config = MXLinearConfig.from_recipe_name("mxfp8_emulated")
    config.block_size = 16
    _test_lowp_mlp_tensor_parallelism_base(
        mesh, config, size, compile=False, allgather_in_lowp=False
    )
    _test_lowp_mlp_tensor_parallelism_base(
        mesh, config, size, compile=True, allgather_in_lowp=False
    )


if __name__ == "__main__":
    device_mesh = setup_distributed()
    tests = [
        _test_dtensor_cast_to_mxfp8,
        _test_mxfp8_mlp_tensor_parallelism,
    ]

    for test in tqdm(tests, desc="Running tests"):
        try:
            test(device_mesh)
        except Exception as e:
            print(f"Test {test.__name__} failed with error: {e}")
            raise e

    torch.distributed.destroy_process_group()
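As an aside (not part of the PR diff): the comparison in _test_dtensor_cast_to_mxfp8 relies on the fact that with block_size = size // 2 every MX scaling block lies entirely within one row, so casting the full tensor and then slicing rows is bit-identical to casting each Shard(0) row shard locally. A minimal single-process sketch of that invariant, assuming only the MXTensor.to_mx and to_dtype calls used in the test above:

import torch
from torchao.prototype.mx_formats.mx_tensor import MXTensor

size = 4
x = torch.rand(size, size)

# cast the whole tensor, then keep the rows that rank 0 would own under Shard(0)
full_fp8 = MXTensor.to_mx(x, torch.float8_e4m3fn, block_size=size // 2)
full_rows = full_fp8.to_dtype(torch.float32)[: size // 2]

# cast only those rows, as that rank would do on its local shard
local_fp8 = MXTensor.to_mx(x[: size // 2], torch.float8_e4m3fn, block_size=size // 2)
local_rows = local_fp8.to_dtype(torch.float32)

# identical because each scale is computed from a row-local block
torch.testing.assert_close(full_rows, local_rows, atol=0, rtol=0)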
@@ -0,0 +1,17 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.
#!/bin/bash

# terminate script on first error
set -e

if python -c 'import torch;print(torch.cuda.is_available())' | grep -q "False"; then
    echo "Skipping test_dtensor.sh because no CUDA devices are available."
    exit
fi

# integration tests for TP/SP
NCCL_DEBUG=WARN torchrun --nproc_per_node 2 test/prototype/mx_formats/test_mx_dtensor.py
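Usage note (not from the PR): torchrun sets WORLD_SIZE to the --nproc_per_node value, which setup_distributed in the test above reads to build the 1-D device mesh, so the last line of the script is also how to launch the test by hand. For example, on a 4-GPU machine one could presumably run

NCCL_DEBUG=WARN torchrun --nproc_per_node 4 test/prototype/mx_formats/test_mx_dtensor.py

as long as the test sizes (4 and 16 by default) remain divisible by the world size.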
@@ -21,10 +21,10 @@
from typing import Callable, Dict, Union

import torch
from torch.distributed._tensor import DTensor

from torchao.prototype.mx_formats.config import MXGemmKernelChoice
from torchao.prototype.mx_formats.constants import (
    BF16_EXP_BIAS,
    BLOCK_SIZE_DEFAULT,
    DTYPE_FP6_E2M3,
    DTYPE_FP6_E3M2,

@@ -61,7 +61,6 @@

# TODO(later): read from somewhere else?
SBITS, EBITS_F32, MBITS_F32 = 1, 8, 23
EBITS_BF16, MBITS_BF16 = 8, 7
EBITS_F4_E2M1, MBITS_F4_E2M1 = 2, 1
EBITS_F6_E2M3, MBITS_F6_E2M3 = 2, 3
EBITS_F6_E3M2, MBITS_F6_E3M2 = 3, 2

@@ -136,9 +135,7 @@ def _to_mx_rceil(
    )

    # scale and saturated cast the data elements to max of target dtype
    data_lp = torch.clamp(
        data_hp * descale_fp.unsqueeze(1), min=-1 * max_pos, max=max_pos
    )
    data_lp = torch.clamp(data_hp * descale_fp, min=-1 * max_pos, max=max_pos)
    return exponent, data_lp

@@ -159,24 +156,33 @@ def to_mx(
        torch.float,
    ), f"{data_hp.dtype} is not supported yet"
    # TODO(future PR): consider supporting padding
    assert data_hp.numel() % block_size == 0, "unsupported"
    assert data_hp.shape[-1] % block_size == 0, (
        f"the last dimension of shape {data_hp.shape} must be divisible by block_size {block_size}"
    )
    assert data_hp.is_contiguous(), "unsupported"
    assert elem_dtype in SUPPORTED_ELEM_DTYPES, "unsupported"

    # calculate the scale in e8m0 format

    orig_shape = data_hp.shape
    data_hp = data_hp.reshape(-1, block_size)
    data_hp = data_hp.reshape(
        *orig_shape[:-1], orig_shape[-1] // block_size, block_size
    )

    # find max value of the data
    # Note: this only implements the `minimally supported` version of
    # https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
    # section 6.3.
    max_abs = torch.amax(torch.abs(data_hp), 1)

    # Add an epsilon to prevent the log2 function call for returning -inf
    # where the values are zero.
    eps = F32_MIN_NORMAL * (max_abs == 0).type(max_abs.dtype)
    max_abs = torch.amax(torch.abs(data_hp), -1).unsqueeze(-1)

    # We cast to float32 here because
    # in the `max_abs_int32 = max_abs.view(hp_int_dtype)` line below,
    # if tensor parallel is enabled then the resulting shape is 2x larger
    # than it should be under some conditions, likely because of a bug in
    # the `view` op with DTensor and target dtype int16. I reproduce in
    # torchtitan but not in a unit test, so not enough info to file a good
    # issue in pytorch/pytorch. For now, work around. In the future we should
    # debug and fix this properly.
    data_hp = data_hp.to(torch.float32)
[review comment] performance testing showed that with compile on, having this in float32 does not regress performance
    max_abs = max_abs.to(torch.float32)

    # Set X to be the largest power-of-two less than or equal to
    # max_abs(v), divided by the largest power of two representable

@@ -207,17 +213,11 @@ def to_mx(
    if scaling_mode == ScaleCalculationMode.RCEIL:
        scale_e8m0_biased, data_lp = _to_mx_rceil(data_hp, max_abs, max_pos)
    else:
        if data_hp.dtype is torch.float32:
            hp_int_dtype = torch.int32
            hp_mbits = MBITS_F32
            hp_ebits = EBITS_F32
            hp_exp_bias = F32_EXP_BIAS
        else:
            assert data_hp.dtype is torch.bfloat16
            hp_int_dtype = torch.int16
            hp_mbits = MBITS_BF16
            hp_ebits = EBITS_BF16
            hp_exp_bias = BF16_EXP_BIAS
        assert data_hp.dtype is torch.float32
        hp_int_dtype = torch.int32
        hp_mbits = MBITS_F32
        hp_ebits = EBITS_F32
        hp_exp_bias = F32_EXP_BIAS

        # rounding before calculating the largest power of 2
        # X = 2^(floor(log2(rounding(max_abs(v)))-max_exp))

@@ -233,8 +233,12 @@ def to_mx(
        )

        # Calculate the scale for different modes
        max_abs_int32 = (max_abs + eps).view(hp_int_dtype)
        extracted_pow2 = ((max_abs_int32 >> hp_mbits) & 0b11111111) - hp_exp_bias
        max_abs_int32 = max_abs.view(hp_int_dtype)
        # For now, use `torch.bitwise_right_shift` instead of `>>` to support DTensor
        # See https://github.com/pytorch/pytorch/issues/156533.
        extracted_pow2 = (
            (torch.bitwise_right_shift(max_abs_int32, hp_mbits)) & 0b11111111
        ) - hp_exp_bias

        if scaling_mode in (ScaleCalculationMode.FLOOR, ScaleCalculationMode.EVEN):
            scale_e8m0_unbiased = extracted_pow2 - target_max_pow2

@@ -266,9 +270,11 @@
        )

        # For now, calculate the scale in floating point.
        scale_fp32 = (scale_e8m0_biased.to(torch.int32) << MBITS_F32).view(
            torch.float32
        )
        # For now, use `torch.bitwise_left_shift` instead of `<<` to support DTensor
        # See https://github.com/pytorch/pytorch/issues/156533.
        scale_fp32 = (
            torch.bitwise_left_shift(scale_e8m0_biased.to(torch.int32), MBITS_F32)
        ).view(torch.float32)

        # Today, 2**-127 returns 0 in compile+inductor+triton because it is in the
        # float32 denormal range. For now, manually adjust the fp scale. This is

@@ -280,7 +286,7 @@
        scale_fp32 = torch.clamp(scale_fp32, min=F32_MIN_NORMAL)

        # scale and saturated cast the data elements to max of target dtype
        data_lp = data_hp / scale_fp32.unsqueeze(1)
        data_lp = data_hp / scale_fp32

        if (
            elem_dtype in (torch.float8_e4m3fn, torch.float8_e5m2)

@@ -506,7 +512,6 @@ def __new__(
        assert scale_e8m0_bits.dtype == torch.float8_e8m0fnu, (
            f"scale_e8m0_bits.dtype must be `torch.float8_e8m0fnu`, got {scale_e8m0_bits.dtype}"
        )
        assert len(scale_e8m0_bits.shape) == 1, "unsupported"
        assert data_bits.dtype in (
            torch.float8_e4m3fn,
            torch.float8_e5m2,

@@ -597,6 +602,28 @@ def to_mx(
        scale_e8m0_biased, data_lp = to_mx(
            data_hp, elem_dtype, block_size, scaling_mode, pack_fp6
        )
        if isinstance(scale_e8m0_biased, DTensor):
            assert isinstance(data_lp, DTensor), "unsupported"
            local_scale_e8m0_biased = scale_e8m0_biased.to_local()
            local_data_lp = data_lp.to_local()
            inner_mx_tensor = MXTensor(
                local_scale_e8m0_biased,
                local_data_lp,
                elem_dtype,
                block_size,
                data_hp.dtype,
                use_fp4_custom_triton_dequant_kernel,
                gemm_kernel_choice,
                pack_fp6,
            )
            return DTensor.from_local(
                inner_mx_tensor,
                data_lp.device_mesh,
                data_lp.placements,
                run_check=False,
                shape=data_lp.size(),
                stride=data_lp.stride(),
            )
        return MXTensor(
            scale_e8m0_biased,
            data_lp,
[review comment] this was broken before; it is now caught by enforcing that the inner dim is divisible by the block size
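For illustration (a sketch, not part of the PR, assuming the MXTensor.to_mx signature used in the test file above): a tensor whose total element count is a multiple of block_size but whose last dimension is not would previously pass the numel-based check and be reshaped across row boundaries; with the new shape-aware assert it fails loudly instead.

import torch
from torchao.prototype.mx_formats.mx_tensor import MXTensor

# 8 * 24 = 192 elements is divisible by block_size=16, so the old
# `numel % block_size == 0` check passed and reshape(-1, 16) silently
# grouped elements across rows; the new last-dim check rejects this shape
x = torch.randn(8, 24)
try:
    MXTensor.to_mx(x, torch.float8_e4m3fn, block_size=16)
except AssertionError as e:
    print(e)  # last dimension of shape torch.Size([8, 24]) must be divisible by block_size 16

# a last dimension that is a multiple of block_size still works
y = torch.randn(8, 32)
y_mx = MXTensor.to_mx(y, torch.float8_e4m3fn, block_size=16)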