
Commit 8f73e84

fix bug with float8 training + FSDP2 + TP (#1327)
Summary:

The combination of float8 training + FSDP2 + TP recently broke. This PR:

1. adds a test case so the combination is covered in CI
2. fixes the breakage by ensuring we check for `Float8Tensor` properly when it is wrapped in `DTensor`

Note 1: most of the code in `distributed_utils.py` was dead code from before we switched to DTensor, so it is deleted in this PR.

Note 2: we already have extensive testing for FSDP2 and TP/SP in separate files. Testing the two features together lives in a new file to keep complexity and test runtime manageable.

Note 3: we really should make these distributed test cases run in CI; right now they are still local-only.

Note 4: a couple of interesting future follow-ups:
- in FSDP2 with float8-all-gather, perhaps we should return DTensor(Float8Tensor) instead of Float8Tensor, to stay consistent with how FSDP2 wraps weights without float8-all-gather
- in DTensor, it would be nice if `isinstance(t, Float8Tensor)` returned True when `t` is a DTensor wrapping a Float8Tensor - food for thought for composability. Having this would let us simplify some of the float8 modeling code.

Test Plan:

```
// tests added in this PR
./test/float8/test_dtensor.sh

// all tests
./test/float8/test_everything.sh

// torchtitan command fails before this PR and passes after
with-proxy CONFIG_FILE="./train_configs/debug_model.toml" ./run_llama_train.sh --float8.enable_float8_linear --float8.enable_fsdp_float8_all_gather --training.tensor_parallel_degree 2
```

Reviewers:

Subscribers:

Tasks:

Tags:
1 parent 7489c7d commit 8f73e84
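For context, below is a minimal, single-process sketch of the check this commit switches `float8_linear.py` to use; it is illustrative only and not part of the diff. It assumes the `(tensor, scale, dtype)` argument order for `hp_tensor_and_scale_to_float8` and reuses the `e4m3_dtype` / `tensor_to_scale` helpers already imported in `test_dtensor.py`. The DTensor-wrapped case requires an initialized process group and is exercised by the new `test/float8/test_fsdp2_tp.py` instead.

```python
# Illustrative sketch (not part of this commit): why the isinstance check was
# replaced. A bare Float8Tensor is detected either way, but once FSDP2/TP wrap
# the weight in a DTensor, only the recursive helper still detects it.
import torch

from torchao.float8.distributed_utils import tensor_already_casted_to_fp8
from torchao.float8.float8_tensor import Float8Tensor, hp_tensor_and_scale_to_float8
from torchao.float8.float8_utils import e4m3_dtype, tensor_to_scale

w = torch.randn(32, 16)
scale = tensor_to_scale(w, e4m3_dtype)
w_fp8 = hp_tensor_and_scale_to_float8(w, scale, e4m3_dtype)

assert isinstance(w_fp8, Float8Tensor)      # works for a bare Float8Tensor
assert tensor_already_casted_to_fp8(w_fp8)  # also True
assert not tensor_already_casted_to_fp8(w)  # high-precision tensor: False

# When FSDP2 float8-all-gather + TP hand the module a DTensor whose local
# tensor is a Float8Tensor, isinstance(weight, Float8Tensor) is False, while
# tensor_already_casted_to_fp8(weight) unwraps DTensor._local_tensor
# recursively and still returns True -- this is the fix in float8_linear.py.
```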

File tree: 8 files changed, +177 -144 lines changed

test/float8/test_dtensor.py

Lines changed: 1 addition & 24 deletions
```diff
@@ -15,8 +15,6 @@
 
 import pytest
 import torch
-import torch.nn as nn
-import torch.nn.functional as F
 
 from torchao.utils import TORCH_VERSION_AT_LEAST_2_5
 
@@ -49,6 +47,7 @@
 )
 from torchao.float8.float8_utils import e4m3_dtype, tensor_to_scale
 from torchao.float8.fsdp_utils import WeightWithDynamicFloat8CastTensor
+from torchao.testing.float8.dtensor_utils import ToyModel
 
 
 def setup_distributed():
@@ -59,28 +58,6 @@ def setup_distributed():
     return device_mesh
 
 
-class FeedForward(nn.Module):
-    """MLP based model"""
-
-    def __init__(self):
-        super(FeedForward, self).__init__()
-        self.w1 = nn.Linear(16, 32, bias=False)
-        self.w2 = nn.Linear(16, 32, bias=False)
-        self.out_proj = nn.Linear(32, 16, bias=False)
-
-    def forward(self, x):
-        return self.out_proj(F.silu(self.w1(x)) * self.w2(x))
-
-
-class ToyModel(nn.Module):
-    def __init__(self):
-        super(ToyModel, self).__init__()
-        self.ffn = FeedForward()
-
-    def forward(self, x):
-        return self.ffn(x)
-
-
 def _test_scaled_mm(mesh: DeviceMesh, size=16):
     device = mesh.device_type
     fp8_dtype = e4m3_dtype
```

test/float8/test_dtensor.sh

Lines changed: 4 additions & 0 deletions
```diff
@@ -8,4 +8,8 @@ if python -c 'import torch;print(torch.cuda.is_available())' | grep -q "False";
 exit
 fi
 
+# integration tests for TP/SP
 NCCL_DEBUG=WARN torchrun --nproc_per_node 2 test/float8/test_dtensor.py
+
+# integration smoke tests for FSDP2 + TP
+NCCL_DEBUG=WARN torchrun --nproc_per_node 4 test/float8/test_fsdp2_tp.py
```

test/float8/test_fsdp2_tp.py

Lines changed: 121 additions & 0 deletions
```diff
@@ -0,0 +1,121 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD 3-Clause license found in the
+# LICENSE file in the root directory of this source tree.
+"""
+Test numerics of manually defined float16 TP vs float8 TP of toy models
+
+Note: for now, this does not run in CI.
+TODO(future): make this run in CI
+"""
+
+import copy
+import os
+
+import pytest
+import torch
+
+from torchao.utils import TORCH_VERSION_AT_LEAST_2_5
+
+if not TORCH_VERSION_AT_LEAST_2_5:
+    pytest.skip("Unsupported PyTorch version", allow_module_level=True)
+
+from torch.distributed._composable.fsdp import fully_shard
+from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
+from torch.distributed.tensor.parallel import parallelize_module
+from tqdm import tqdm
+
+from torchao.float8 import Float8LinearConfig
+from torchao.float8.float8_linear_utils import convert_to_float8_training
+from torchao.float8.float8_tensor_parallel import (
+    Float8ColwiseParallel,
+    Float8RowwiseParallel,
+)
+from torchao.testing.float8.dtensor_utils import ToyModel
+
+
+def setup_distributed():
+    world_size = int(os.environ.get("WORLD_SIZE", -1))
+
+    # https://pytorch.org/tutorials/recipes/distributed_device_mesh.html
+    device_mesh = init_device_mesh(
+        "cuda",
+        (world_size // 2, 2),
+        mesh_dim_names=("dp", "tp"),
+    )
+    # seed must be the same in all processes
+    torch.manual_seed(1)
+    return device_mesh
+
+
+def _test_fp8_mlp_tensor_parallelism_base(
+    mesh: DeviceMesh, size=16, compile: bool = False
+):
+    device = mesh.device_type
+
+    config = Float8LinearConfig(
+        emulate=True,
+        enable_fsdp_float8_all_gather=True,
+    )
+
+    toy_model = ToyModel().to(device)
+
+    tp_model = copy.deepcopy(toy_model)
+    tp_model = convert_to_float8_training(tp_model, config=config)
+
+    # apply TP
+    tp_model = parallelize_module(
+        tp_model,
+        mesh["tp"],
+        {
+            "ffn.w1": Float8ColwiseParallel(),
+            "ffn.w2": Float8ColwiseParallel(),
+            "ffn.out_proj": Float8RowwiseParallel(),
+        },
+    )
+
+    if compile:
+        tp_model = torch.compile(tp_model)
+
+    # apply FSDP
+    fsdp_config = {"mesh": mesh["dp"]}
+    tp_model = fully_shard(tp_model, **fsdp_config)
+
+    x_fp32 = torch.rand(size, size * 2, size, device=device, requires_grad=False)
+    x_fp32_tp_input = x_fp32.clone()
+
+    tp_out = tp_model(x_fp32_tp_input)
+    tp_out.sum().backward()
+    torch.cuda.synchronize()
+
+    # TODO(future PR): test numerics, and add more cases
+
+
+def _test_fp8_mlp_tensor_parallelism_eager(mesh: DeviceMesh, size=16):
+    _test_fp8_mlp_tensor_parallelism_base(mesh, size, compile=False)
+
+
+def _test_fp8_mlp_tensor_parallelism_compile(mesh: DeviceMesh, size=16):
+    _test_fp8_mlp_tensor_parallelism_base(mesh, size, compile=True)
+
+
+if __name__ == "__main__":
+    # float8 only works on CUDA H100 so we only test cuda and we follow
+    # other test files to not use TestCase but instead just add the test
+    # cases in the main func.
+    device_mesh = setup_distributed()
+
+    tests = [
+        _test_fp8_mlp_tensor_parallelism_eager,
+        _test_fp8_mlp_tensor_parallelism_compile,
+    ]
+
+    for test in tqdm(tests, desc="Running tests"):
+        try:
+            test(device_mesh)
+        except Exception as e:
+            print(f"Test {test.__name__} failed with error: {e}")
+            raise e
+
+    torch.distributed.destroy_process_group()
```

torchao/float8/distributed_utils.py

Lines changed: 16 additions & 101 deletions
```diff
@@ -3,110 +3,25 @@
 #
 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.
-from typing import Any
 
 import torch
-from fairscale.nn.model_parallel.initialize import get_model_parallel_group
+import torch.distributed._functional_collectives as funcol
+from torch.distributed._tensor import DTensor
 
-# from float8_tensor import Float8Tensor
 from torchao.float8.float8_tensor import Float8Tensor
 
-# additional differentiable distributed primitives for SP which are not in
-# the Fairscale codebase
 
-
-def _gather_along_first_dim(input_: torch.Tensor):
-    # same as https://github.com/facebookresearch/fairscale/blob/main/fairscale/nn/model_parallel/mappings.py#L67,
-    # but gather along first dim instead of last dim
-    group = get_model_parallel_group()
-
-    # Bypass the function if we are using only 1 GPU.
-    if torch.distributed.get_world_size(group=group) == 1:
-        return input_
-
-    # Size and dimension.
-    first_dim = 0
-    rank = torch.distributed.get_rank(group=group)
-    world_size = torch.distributed.get_world_size(group=group)
-
-    # If the input is a float8 tensor, we need to do the transformation on the
-    # inner tensor and then return a new wrapper.
-    def _transform(t):
-        # tensors must be contiguous for all_gather to work
-        input_contig = t.contiguous()
-
-        tensor_list = [torch.empty_like(input_contig) for _ in range(world_size)]
-        tensor_list[rank] = input_contig
-        torch.distributed.all_gather(tensor_list, input_contig, group=group)
-
-        # Note: torch.cat already creates a contiguous tensor.
-        output = torch.cat(tensor_list, dim=first_dim).contiguous()
-        return output
-
-    if isinstance(input_, Float8Tensor):
-        new_data = input_._data
-        new_data = new_data.view(torch.int8)
-        new_data = _transform(new_data)
-        new_data = new_data.view(input_._data.dtype)
-        output = Float8Tensor(new_data, input_._scale, input_._orig_dtype)
-    else:
-        output = _transform(input_)
-
-    return output
-
-
-def _reduce_scatter(ctx: Any, input_: torch.Tensor):
-    group = get_model_parallel_group()
-    world_size = torch.distributed.get_world_size(group)
-
-    assert input_.shape[0] % world_size == 0
-    output_shape = (input_.shape[0] // world_size, *input_.shape[1:])
-    output = torch.empty(*output_shape, device=input_.device, dtype=input_.dtype)
-
-    torch.distributed.reduce_scatter_tensor(output, input_, group=group)
-    return output
-
-
-def _split_along_first_dim(input_: torch.Tensor):
-    # this is needed for testing
-
-    # like fairscale.nn.model_parallel.mappings._split, but
-    # along the first dim instead of last dim
-
-    group = get_model_parallel_group()
-    local_rank = torch.distributed.get_rank(group)
-    world_size = torch.distributed.get_world_size(group)
-
-    assert input_.shape[0] % world_size == 0
-    input_list = torch.split(input_, input_.shape[0] // world_size)
-    return input_list[local_rank]
-
-
-class _AllGatherFloat8FwReduceScatterBw(torch.autograd.Function):
-    @staticmethod
-    def forward(ctx, input_):
-        return _gather_along_first_dim(input_)
-
-    @staticmethod
-    def backward(ctx, grad_output):
-        return _reduce_scatter(ctx, grad_output)
-
-
-class _ReduceScatterFwAllGatherFloat8Bw(torch.autograd.Function):
-    @staticmethod
-    def forward(ctx, input_):
-        return _reduce_scatter(ctx, input_)
-
-    @staticmethod
-    def backward(ctx, grad_output):
-        return _gather_along_first_dim(grad_output)
-
-
-class _AllGatherFwSplitBw(torch.autograd.Function):
-    @staticmethod
-    def forward(ctx, input_):
-        return _gather_along_first_dim(input_)
-
-    @staticmethod
-    def backward(ctx, grad_output):
-        return _split_along_first_dim(grad_output)
+def tensor_already_casted_to_fp8(tensor: torch.Tensor) -> bool:
+    """
+    Check if the tensor is already casted to fp8, works if the local
+    tensor is wrapped in DTensor.
+    """
+    if isinstance(tensor, Float8Tensor):
+        return True
+    elif isinstance(tensor, DTensor):
+        # TODO: shall we stick to public API and directly use tensor.to_local() here?
+        return tensor_already_casted_to_fp8(tensor._local_tensor)
+    elif isinstance(tensor, funcol.AsyncCollectiveTensor):
+        return tensor_already_casted_to_fp8(tensor.elem)
+
+    return False
```

torchao/float8/float8_linear.py

Lines changed: 3 additions & 2 deletions
```diff
@@ -13,6 +13,7 @@
 import torch.utils.checkpoint as checkpoint
 
 from torchao.float8.config import Float8LinearConfig, ScalingGranularity, ScalingType
+from torchao.float8.distributed_utils import tensor_already_casted_to_fp8
 from torchao.float8.float8_scaling_utils import (
     NoopFwToFloat8E5M2BwDelayed,
     NoopFwToFloat8E5M2BwDynamic,
@@ -469,7 +470,7 @@ def cast_input_to_float8(
         return input_fp8
 
     def get_weight_scale(self, weight: torch.Tensor) -> Optional[torch.Tensor]:
-        if isinstance(weight, Float8Tensor):
+        if tensor_already_casted_to_fp8(weight):
             return None
         if self.scaling_type_weight is ScalingType.DELAYED:
             scale_fn_name = self.config.delayed_scaling_config.scale_fn_name
@@ -497,7 +498,7 @@ def cast_weight_to_float8_t(
         is_amax_initialized: bool,
         weight_scale: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-        if isinstance(weight, Float8Tensor):
+        if tensor_already_casted_to_fp8(weight):
             return weight.t()
         weight_fp8 = hp_tensor_and_scale_to_float8(
             weight,
```

torchao/float8/float8_scaling_utils.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -13,12 +13,12 @@
 import torch
 
 from torchao.float8.config import ScalingGranularity
+from torchao.float8.distributed_utils import tensor_already_casted_to_fp8
 from torchao.float8.float8_tensor import (
     Float8Tensor,
     GemmInputRole,
     LinearMMConfig,
     hp_tensor_and_scale_to_float8,
-    tensor_already_casted_to_fp8,
 )
 from torchao.float8.float8_utils import (
     amax_history_to_scale,
```

torchao/float8/float8_tensor.py

Lines changed: 0 additions & 16 deletions
```diff
@@ -7,7 +7,6 @@
 from typing import Dict, NamedTuple, Optional
 
 import torch
-import torch.distributed._functional_collectives as funcol
 from torch.distributed._tensor import DTensor
 
 from torchao.float8.float8_utils import (
@@ -121,21 +120,6 @@ def choose_scaled_mm_config(
     raise AssertionError(f"unexpected a_role {a_role} and b_role {b_role}")
 
 
-def tensor_already_casted_to_fp8(tensor: torch.Tensor) -> bool:
-    """
-    Check if the tensor is already casted to fp8
-    """
-    if isinstance(tensor, Float8Tensor):
-        return True
-    elif isinstance(tensor, DTensor):
-        # TODO: shall we stick to public API and directly use tensor.to_local() here?
-        return tensor_already_casted_to_fp8(tensor._local_tensor)
-    elif isinstance(tensor, funcol.AsyncCollectiveTensor):
-        return tensor_already_casted_to_fp8(tensor.elem)
-
-    return False
-
-
 @torch._dynamo.allow_in_graph
 class _ToFloat8ConstrFunc(torch.autograd.Function):
     """
```

0 commit comments