
Commit e675ffd

mxfp8 training: add TP sharding strategy for dim1 kernel (#2436)
Registers a DTensor sharding strategy for the triton_to_mxfp8_dim1 custom op and updates the mxfp8 dim1 cast path in mx_linear.py to unwrap/rewrap DTensor inputs, so the Triton dim1 cast kernel works under tensor parallelism. Test model and input sizes are bumped so that shapes remain divisible by the mxfp8 block size of 32.
1 parent c57226b commit e675ffd

6 files changed: +95, -49 lines changed

test/float8/test_dtensor.py

Lines changed: 2 additions & 2 deletions
@@ -183,7 +183,7 @@ def _test_dtensor_fp8_autograd(mesh: DeviceMesh, size=16):
     loss.backward()
 
 
-def _test_fp8_mlp_tensor_parallelism_eager(mesh: DeviceMesh, size=16):
+def _test_fp8_mlp_tensor_parallelism_eager(mesh: DeviceMesh, size=32):
     tensorwise_config = Float8LinearConfig(emulate=True)
     _test_lowp_mlp_tensor_parallelism_base(
         mesh, tensorwise_config, size, compile=False, allgather_in_lowp=True
@@ -198,7 +198,7 @@ def _test_fp8_mlp_tensor_parallelism_eager(mesh: DeviceMesh, size=16):
     )
 
 
-def _test_fp8_mlp_tensor_parallelism_compile(mesh: DeviceMesh, size=16):
+def _test_fp8_mlp_tensor_parallelism_compile(mesh: DeviceMesh, size=32):
     tensorwise_config = Float8LinearConfig(emulate=True)
     _test_lowp_mlp_tensor_parallelism_base(
         mesh, tensorwise_config, size, compile=True, allgather_in_lowp=True

test/float8/test_fsdp2_tp.py

Lines changed: 5 additions & 3 deletions
@@ -34,6 +34,8 @@
 )
 from torchao.testing.training.dtensor_utils import ToyModel
 
+torch.set_float32_matmul_precision("high")
+
 
 def setup_distributed():
     world_size = int(os.environ.get("WORLD_SIZE", -1))
@@ -61,7 +63,7 @@ def _test_fp8_mlp_tensor_parallelism_base(
         enable_fsdp_float8_all_gather=True,
     )
 
-    toy_model = ToyModel().to(device)
+    toy_model = ToyModel(size).to(device)
 
     tp_model = copy.deepcopy(toy_model)
     tp_model = convert_to_float8_training(tp_model, config=config)
@@ -94,11 +96,11 @@ def _test_fp8_mlp_tensor_parallelism_base(
     # TODO(future PR): test numerics, and add more cases
 
 
-def _test_fp8_mlp_tensor_parallelism_eager(mesh: DeviceMesh, size=16):
+def _test_fp8_mlp_tensor_parallelism_eager(mesh: DeviceMesh, size=32):
    _test_fp8_mlp_tensor_parallelism_base(mesh, size, compile=False)
 
 
-def _test_fp8_mlp_tensor_parallelism_compile(mesh: DeviceMesh, size=16):
+def _test_fp8_mlp_tensor_parallelism_compile(mesh: DeviceMesh, size=32):
    _test_fp8_mlp_tensor_parallelism_base(mesh, size, compile=True)
 
 
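
For context, the sketch below shows the non-distributed core of what _test_fp8_mlp_tensor_parallelism_base builds before any TP/FSDP sharding is applied: the resized ToyModel converted to float8 training. This is a minimal sketch, not code from this commit; it assumes a CUDA device, and emulate=True keeps the float8 matmuls emulated so no fp8-capable hardware is required.

import torch
from torchao.float8 import Float8LinearConfig, convert_to_float8_training
from torchao.testing.training.dtensor_utils import ToyModel

size = 32
model = ToyModel(size).to("cuda")
config = Float8LinearConfig(emulate=True)
model = convert_to_float8_training(model, config=config)

# same global input shape the TP tests now use: (2, size * 2, size)
x = torch.rand(2, size * 2, size, device="cuda")
model(x).sum().backward()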

test/prototype/mx_formats/test_mx_dtensor.py

Lines changed: 17 additions & 2 deletions
@@ -68,9 +68,9 @@ def _test_dtensor_cast_to_mxfp8(mesh: DeviceMesh, size=4):
     )
 
 
-def _test_mxfp8_mlp_tensor_parallelism(mesh: DeviceMesh, size=16):
+def _test_mxfp8_mlp_tensor_parallelism(mesh: DeviceMesh, size=128):
     config = MXLinearConfig.from_recipe_name("mxfp8_emulated")
-    config.block_size = 16
+    config.block_size = 32
     _test_lowp_mlp_tensor_parallelism_base(
         mesh, config, size, compile=False, allgather_in_lowp=False
     )
@@ -79,11 +79,26 @@ def _test_mxfp8_mlp_tensor_parallelism(mesh: DeviceMesh, size=128):
     )
 
 
+def _test_mxfp8_mlp_tensor_parallelism_dim1_triton(mesh: DeviceMesh, size=128):
+    config = MXLinearConfig.from_recipe_name("mxfp8_emulated")
+    config.block_size = 32
+    config.use_fp8_dim1_cast_triton_kernel = True
+    _test_lowp_mlp_tensor_parallelism_base(
+        mesh, config, size, compile=False, allgather_in_lowp=False
+    )
+    # TODO(future PR): enable compile here, currently seeing
+    # https://www.internalfb.com/phabricator/paste/view/P1851219639
+    # _test_lowp_mlp_tensor_parallelism_base(
+    #     mesh, config, size, compile=True, allgather_in_lowp=False
+    # )
+
+
 if __name__ == "__main__":
     device_mesh = setup_distributed()
     tests = [
         _test_dtensor_cast_to_mxfp8,
         _test_mxfp8_mlp_tensor_parallelism,
+        _test_mxfp8_mlp_tensor_parallelism_dim1_triton,
     ]
 
     for test in tqdm(tests, desc="Running tests"):
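
The new test drives the dim1 Triton cast through the existing test helper. As a hedged, single-process illustration (not code from this commit), the same config flag can be applied to an ordinary module with quantize_; this assumes a CUDA GPU with Triton available, bfloat16 inputs, and layer dimensions divisible by the block size of 32.

import torch
import torch.nn as nn
from torchao.prototype.mx_formats.config import MXLinearConfig
from torchao.quantization import quantize_

config = MXLinearConfig.from_recipe_name("mxfp8_emulated")
config.block_size = 32
config.use_fp8_dim1_cast_triton_kernel = True

m = nn.Sequential(nn.Linear(128, 256, bias=False)).to("cuda", torch.bfloat16)
quantize_(m, config=config)

x = torch.randn(64, 128, device="cuda", dtype=torch.bfloat16)
m(x).sum().backward()  # backward exercises the dim1 cast kernel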

torchao/prototype/mx_formats/kernels.py

Lines changed: 12 additions & 1 deletion
@@ -8,6 +8,8 @@
 
 import numpy as np
 import torch
+from torch.distributed.tensor import Replicate, Shard
+from torch.distributed.tensor.experimental import register_sharding
 from torch.utils._triton import has_triton
 
 from torchao.prototype.custom_fp_utils import (
@@ -1315,7 +1317,6 @@ def triton_to_mxfp8_dim1(
     * `col_scale`: the `e8m0` values of `x_scale` used to cast `x` to mxfp8 across dim1
     """
     assert x.is_contiguous(), "`x` must be contiguous"
-    assert x.dtype == torch.bfloat16
     assert inner_block_size <= 32
 
     # Get tensor shape
@@ -1363,6 +1364,16 @@ def triton_to_mxfp8_dim1(
         col_scale.view(torch.float8_e8m0fnu),
     )
 
+@register_sharding(torch.ops.torchao.triton_to_mxfp8_dim1.default)
+def custom_triton_to_mxfp8_dim1_sharding(x, inner_block_size=32):
+    replicate = ([Replicate(), Replicate()], [Replicate(), None])
+    # Note that the data is returned transposed, which is why
+    # we flip the sharding dim below
+    shard_dim0 = ([Shard(1), Shard(1)], [Shard(0), None])
+    shard_dim1 = ([Shard(0), Shard(0)], [Shard(1), None])
+    acceptable_shardings = [replicate, shard_dim0, shard_dim1]
+    return acceptable_shardings
+
 def triton_to_mxfp8_dim1_reference(
     x_hp: torch.Tensor, block_size
 ) -> Tuple[torch.Tensor, torch.Tensor]:
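
The strategies returned by custom_triton_to_mxfp8_dim1_sharding are what DTensor's sharding propagation consults when the op is called with DTensor arguments. The sketch below is a hypothetical usage example, not code from this commit: it assumes a torchrun launch on 2+ CUDA GPUs with torchao's Triton kernels available, and it relies only on the rule registered above, which maps a Shard(0) input to Shard(1) on both (transposed) outputs so no redistribution is needed.

import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Shard, distribute_tensor

from torchao.prototype.mx_formats import kernels  # noqa: F401, registers the op and its sharding

dist.init_process_group("nccl")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())
mesh = init_device_mesh("cuda", (dist.get_world_size(),))

torch.manual_seed(0)
x = torch.randn(256, 128, device="cuda", dtype=torch.bfloat16)
x_dt = distribute_tensor(x, mesh, [Shard(0)])  # rows sharded across ranks

# dispatch consults the registered sharding strategies to pick output placements
data_dt, scale_dt = torch.ops.torchao.triton_to_mxfp8_dim1(x_dt, 32)
print(data_dt.placements, scale_dt.placements)

dist.destroy_process_group()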

torchao/prototype/mx_formats/mx_linear.py

Lines changed: 49 additions & 31 deletions
@@ -12,6 +12,7 @@
 
 import torch
 import torch.nn.functional as F
+from torch.distributed._tensor import DTensor
 
 from torchao.prototype.mx_formats.config import (
     MXGemmKernelChoice,
@@ -25,6 +26,46 @@
 )
 
 
+def _triton_to_mxfp8_dim1_wrapper(
+    a, block_size, elem_dtype, hp_dtype, gemm_kernel_choice
+):
+    a_data, a_scale = triton_to_mxfp8_dim1(a, block_size)
+    if isinstance(a_data, DTensor):
+        assert isinstance(a_scale, DTensor)
+        a_data_local = a_data.to_local()
+        a_scale_local = a_scale.to_local()
+        inner = MXTensor(
+            a_scale_local,
+            a_data_local.t(),
+            elem_dtype,
+            block_size,
+            hp_dtype,
+            False,
+            gemm_kernel_choice,
+            False,
+        )
+        mx_tensor = DTensor.from_local(
+            inner,
+            a_data.device_mesh,
+            a_data.placements,
+            run_check=False,
+            shape=a_data.t().size(),
+            stride=a_data.t().stride(),
+        )
+    else:
+        mx_tensor = MXTensor(
+            a_scale,
+            a_data.t(),
+            elem_dtype,
+            block_size,
+            hp_dtype,
+            False,
+            gemm_kernel_choice,
+            False,
+        )
+    return mx_tensor
+
+
 @torch._dynamo.allow_in_graph
 class mx_mm(torch.autograd.Function):
     # There are three gemms in a forward + backward of a Linear layer:
@@ -95,20 +136,9 @@ def backward(ctx, grad_output_hp: torch.Tensor):
         )
 
         if use_fp8_dim1_cast_triton_kernel:
-            weight_mx_dim1_data, weight_mx_dim1_scale = triton_to_mxfp8_dim1(
-                weight_hp, block_size
+            weight_mx_dim1 = _triton_to_mxfp8_dim1_wrapper(
+                weight_hp, block_size, w_elem_dtype, weight_hp.dtype, gemm_kernel_choice
             )
-            weight_mx_dim1 = MXTensor(
-                weight_mx_dim1_scale.reshape(-1),
-                weight_mx_dim1_data.t(),
-                w_elem_dtype,
-                block_size,
-                weight_hp.dtype,
-                False,
-                gemm_kernel_choice,
-                False,
-            )
-
         else:
             weight_hp_t_c = weight_hp.t().contiguous()
             weight_mx_dim1 = MXTensor.to_mx(
@@ -124,18 +154,12 @@ def backward(ctx, grad_output_hp: torch.Tensor):
 
         # input_t @ grad_output = grad_weight
         if use_fp8_dim1_cast_triton_kernel:
-            grad_output_mx_dim1_data, grad_output_mx_dim1_scale = triton_to_mxfp8_dim1(
-                grad_output_hp_r, block_size
-            )
-            grad_output_mx_dim1 = MXTensor(
-                grad_output_mx_dim1_scale.reshape(-1),
-                grad_output_mx_dim1_data.t(),
-                grad_elem_dtype,
+            grad_output_mx_dim1 = _triton_to_mxfp8_dim1_wrapper(
+                grad_output_hp_r,
                 block_size,
+                grad_elem_dtype,
                 grad_output_hp_r.dtype,
-                False,
                 gemm_kernel_choice,
-                False,
             )
         else:
             grad_output_mx_dim1 = MXTensor.to_mx(
@@ -146,18 +170,12 @@ def backward(ctx, grad_output_hp: torch.Tensor):
             )
 
         if use_fp8_dim1_cast_triton_kernel:
-            input_t_mx_dim0_tmp_data, input_t_mx_dim0_tmp_scale = triton_to_mxfp8_dim1(
-                input_hp_r, block_size
-            )
-            input_t_mx_dim0_tmp = MXTensor(
-                input_t_mx_dim0_tmp_scale.reshape(-1),
-                input_t_mx_dim0_tmp_data.t(),
-                in_elem_dtype,
+            input_t_mx_dim0_tmp = _triton_to_mxfp8_dim1_wrapper(
+                input_hp_r,
                 block_size,
+                in_elem_dtype,
                 input_hp_r.dtype,
-                False,
                 gemm_kernel_choice,
-                False,
             )
             input_t_mx_dim0 = input_t_mx_dim0_tmp.t()
         else:
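
The DTensor branch of _triton_to_mxfp8_dim1_wrapper follows a general unwrap/rewrap pattern: take the local shard, build the wrapped local result, then reconstruct the DTensor with DTensor.from_local, passing global shape/stride explicitly (in the wrapper's case, those of the transposed mxfp8 data). Below is a stripped-down, hypothetical sketch of that pattern using a plain fp8 cast as a stand-in for the MXTensor construction; it assumes a torchrun launch on CUDA GPUs and is not code from this commit.

import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import DTensor, Shard, distribute_tensor


def rewrap_local_cast(a_dt: DTensor) -> DTensor:
    local = a_dt.to_local()                    # shard owned by this rank
    local_out = local.to(torch.float8_e4m3fn)  # stand-in for the local wrap step
    return DTensor.from_local(
        local_out,
        a_dt.device_mesh,
        a_dt.placements,
        run_check=False,
        shape=a_dt.size(),      # global metadata for the rebuilt DTensor; the real
        stride=a_dt.stride(),   # wrapper passes the transposed data's size/stride
    )


if __name__ == "__main__":
    dist.init_process_group("nccl")
    torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())
    mesh = init_device_mesh("cuda", (dist.get_world_size(),))
    x_dt = distribute_tensor(
        torch.randn(64, 32, device="cuda", dtype=torch.bfloat16), mesh, [Shard(0)]
    )
    y_dt = rewrap_local_cast(x_dt)
    dist.destroy_process_group()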

torchao/testing/training/dtensor_utils.py

Lines changed: 10 additions & 10 deletions
@@ -32,11 +32,11 @@
 class FeedForward(nn.Module):
     """MLP based model"""
 
-    def __init__(self):
+    def __init__(self, size):
         super(FeedForward, self).__init__()
-        self.w1 = nn.Linear(16, 32, bias=False)
-        self.w2 = nn.Linear(16, 32, bias=False)
-        self.out_proj = nn.Linear(32, 16, bias=False)
+        self.w1 = nn.Linear(size, size * 2, bias=False)
+        self.w2 = nn.Linear(size, size * 2, bias=False)
+        self.out_proj = nn.Linear(size * 2, size, bias=False)
 
     def forward(self, x):
         x = F.silu(self.w1(x)) * self.w2(x)
@@ -45,9 +45,9 @@ def forward(self, x):
 
 
 class ToyModel(nn.Module):
-    def __init__(self):
+    def __init__(self, size):
         super(ToyModel, self).__init__()
-        self.ffn = FeedForward()
+        self.ffn = FeedForward(size)
 
     def forward(self, x):
         return self.ffn(x)
@@ -56,7 +56,7 @@ def forward(self, x):
 def _test_lowp_mlp_tensor_parallelism_base(
     mesh: DeviceMesh,
     config: Union[Float8LinearConfig, MXLinearConfig],
-    size=16,
+    size=32,
     compile: bool = False,
     allgather_in_lowp: bool = False,
 ):
@@ -67,7 +67,7 @@ def _test_lowp_mlp_tensor_parallelism_base(
     if isinstance(config, MXLinearConfig):
         convert_model_func = quantize_
 
-    toy_model = ToyModel().to(device)
+    toy_model = ToyModel(size).to(device)
     toy_model_fp8 = copy.deepcopy(toy_model)
     convert_model_func(toy_model_fp8, config=config)
 
@@ -151,8 +151,8 @@ def _test_lowp_mlp_tensor_parallelism_base(
         sp_model = torch.compile(sp_model)
         sp_model2 = torch.compile(sp_model2)
 
-    x_fp32 = torch.rand(size, size * 2, size, device=device, requires_grad=False)
-    go_fp32 = torch.rand(size, size * 2, size, device=device, requires_grad=False)
+    x_fp32 = torch.rand(2, size * 2, size, device=device, requires_grad=False)
+    go_fp32 = torch.rand(2, size * 2, size, device=device, requires_grad=False)
     x_fp32_tp_input = x_fp32.clone()
     go_fp32_tp = go_fp32.clone()
     x_fp32_sp_input = distribute_tensor(x_fp32.clone(), mesh, [Shard(0)])
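
With size threaded through, the toy MLP's layer shapes and the new test input shape stay consistent: w1 and w2 map size -> size * 2, out_proj maps size * 2 back to size, and the tests now feed a (2, size * 2, size) batch. A quick single-device sanity sketch (not code from this commit):

import torch
from torchao.testing.training.dtensor_utils import ToyModel

size = 32
model = ToyModel(size)
x = torch.rand(2, size * 2, size)
y = model(x)
assert y.shape == (2, size * 2, size)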
