Commit 3737967

[wip] sharding strategy for dim1 kernel

Summary:
Test Plan:
Reviewers:
Subscribers:
Tasks:
Tags:

ghstack-source-id: 3b7f85d
ghstack-comment-id: 3001601112
Pull Request resolved: #2436

1 parent 483e3e4 commit 3737967

5 files changed, +26 -20 lines changed

test/float8/test_dtensor.py

Lines changed: 2 additions & 2 deletions

@@ -183,7 +183,7 @@ def _test_dtensor_fp8_autograd(mesh: DeviceMesh, size=16):
     loss.backward()


-def _test_fp8_mlp_tensor_parallelism_eager(mesh: DeviceMesh, size=16):
+def _test_fp8_mlp_tensor_parallelism_eager(mesh: DeviceMesh, size=32):
     tensorwise_config = Float8LinearConfig(emulate=True)
     _test_lowp_mlp_tensor_parallelism_base(
         mesh, tensorwise_config, size, compile=False, allgather_in_lowp=True
@@ -198,7 +198,7 @@ def _test_fp8_mlp_tensor_parallelism_eager(mesh: DeviceMesh, size=16):
     )


-def _test_fp8_mlp_tensor_parallelism_compile(mesh: DeviceMesh, size=16):
+def _test_fp8_mlp_tensor_parallelism_compile(mesh: DeviceMesh, size=32):
     tensorwise_config = Float8LinearConfig(emulate=True)
     _test_lowp_mlp_tensor_parallelism_base(
         mesh, tensorwise_config, size, compile=True, allgather_in_lowp=True

test/float8/test_fsdp2_tp.py

Lines changed: 3 additions & 3 deletions

@@ -61,7 +61,7 @@ def _test_fp8_mlp_tensor_parallelism_base(
         enable_fsdp_float8_all_gather=True,
     )

-    toy_model = ToyModel().to(device)
+    toy_model = ToyModel(size).to(device)

     tp_model = copy.deepcopy(toy_model)
     tp_model = convert_to_float8_training(tp_model, config=config)
@@ -94,11 +94,11 @@ def _test_fp8_mlp_tensor_parallelism_base(
     # TODO(future PR): test numerics, and add more cases


-def _test_fp8_mlp_tensor_parallelism_eager(mesh: DeviceMesh, size=16):
+def _test_fp8_mlp_tensor_parallelism_eager(mesh: DeviceMesh, size=32):
     _test_fp8_mlp_tensor_parallelism_base(mesh, size, compile=False)


-def _test_fp8_mlp_tensor_parallelism_compile(mesh: DeviceMesh, size=16):
+def _test_fp8_mlp_tensor_parallelism_compile(mesh: DeviceMesh, size=32):
     _test_fp8_mlp_tensor_parallelism_base(mesh, size, compile=True)


test/prototype/mx_formats/test_mx_dtensor.py

Lines changed: 7 additions & 6 deletions

@@ -68,21 +68,22 @@ def _test_dtensor_cast_to_mxfp8(mesh: DeviceMesh, size=4):
     )


-def _test_mxfp8_mlp_tensor_parallelism(mesh: DeviceMesh, size=16):
+def _test_mxfp8_mlp_tensor_parallelism(mesh: DeviceMesh, size=128):
     config = MXLinearConfig.from_recipe_name("mxfp8_emulated")
-    config.block_size = 16
+    config.block_size = 32
+    config.use_fp8_dim1_cast_triton_kernel = True
     _test_lowp_mlp_tensor_parallelism_base(
         mesh, config, size, compile=False, allgather_in_lowp=False
     )
-    _test_lowp_mlp_tensor_parallelism_base(
-        mesh, config, size, compile=True, allgather_in_lowp=False
-    )
+    # _test_lowp_mlp_tensor_parallelism_base(
+    #     mesh, config, size, compile=True, allgather_in_lowp=False
+    # )


 if __name__ == "__main__":
     device_mesh = setup_distributed()
     tests = [
-        _test_dtensor_cast_to_mxfp8,
+        # _test_dtensor_cast_to_mxfp8,
         _test_mxfp8_mlp_tensor_parallelism,
     ]

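The configuration exercised by this test can be reproduced on a single GPU without DTensor. The sketch below is an illustration only, not part of this diff: the import paths are inferred from the torchao source layout, and it assumes a CUDA device with Triton available.

import torch
import torch.nn as nn
from torchao.quantization import quantize_
from torchao.prototype.mx_formats import MXLinearConfig

# Same recipe as the test above: emulated mxfp8, block size 32, with the
# dim1 (columnwise) cast routed through the Triton kernel.
config = MXLinearConfig.from_recipe_name("mxfp8_emulated")
config.block_size = 32
config.use_fp8_dim1_cast_triton_kernel = True

m = nn.Sequential(nn.Linear(128, 256, bias=False)).cuda().bfloat16()
quantize_(m, config=config)

x = torch.randn(16, 128, device="cuda", dtype=torch.bfloat16, requires_grad=True)
m(x).sum().backward()  # the backward pass also exercises the dim1 (columnwise) cast path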

torchao/prototype/mx_formats/kernels.py

Lines changed: 6 additions & 1 deletion

@@ -1315,7 +1315,8 @@ def triton_to_mxfp8_dim1(
     * `col_scale`: the `e8m0` values of `x_scale` used to cast `x` to mxfp8 across dim1
     """
     assert x.is_contiguous(), "`x` must be contiguous"
-    assert x.dtype == torch.bfloat16
+    # TODO(before land): maybe gate by FakeTensor below?
+    # assert x.dtype == torch.bfloat16
     assert inner_block_size <= 32

     # Get tensor shape
@@ -1362,6 +1363,10 @@ def triton_to_mxfp8_dim1(
         output_col_major.t(),
         col_scale.view(torch.float8_e8m0fnu),
     )
+
+print('ASDFASDFASDF')
+from torchao import triton_to_mxfp8_dim1
+print(triton_to_mxfp8_dim1)

 def triton_to_mxfp8_dim1_reference(
     x_hp: torch.Tensor, block_size
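Based on the asserts and return values visible in this hunk, the kernel takes a contiguous (normally bf16) tensor plus an inner_block_size of at most 32, and returns the transposed fp8 data together with e8m0 column scales. A minimal usage sketch (an assumption, not part of this diff; requires a CUDA device with Triton):

import torch
from torchao.prototype.mx_formats.kernels import triton_to_mxfp8_dim1

# Contiguous bf16 input; the dtype assert is the one this hunk comments out
# (see the TODO about FakeTensor above).
x = torch.randn(256, 512, dtype=torch.bfloat16, device="cuda")

data_t, col_scale = triton_to_mxfp8_dim1(x, inner_block_size=32)
print(data_t.shape)     # mxfp8 data, returned transposed (column-major over the original x)
print(col_scale.dtype)  # torch.float8_e8m0fnu, one scale per 32-element block along dim1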

torchao/testing/training/dtensor_utils.py

Lines changed: 8 additions & 8 deletions

@@ -32,11 +32,11 @@
 class FeedForward(nn.Module):
     """MLP based model"""

-    def __init__(self):
+    def __init__(self, size):
         super(FeedForward, self).__init__()
-        self.w1 = nn.Linear(16, 32, bias=False)
-        self.w2 = nn.Linear(16, 32, bias=False)
-        self.out_proj = nn.Linear(32, 16, bias=False)
+        self.w1 = nn.Linear(size, size * 2, bias=False)
+        self.w2 = nn.Linear(size, size * 2, bias=False)
+        self.out_proj = nn.Linear(size * 2, size, bias=False)

     def forward(self, x):
         x = F.silu(self.w1(x)) * self.w2(x)
@@ -45,9 +45,9 @@ def forward(self, x):


 class ToyModel(nn.Module):
-    def __init__(self):
+    def __init__(self, size):
         super(ToyModel, self).__init__()
-        self.ffn = FeedForward()
+        self.ffn = FeedForward(size)

     def forward(self, x):
         return self.ffn(x)
@@ -56,7 +56,7 @@ def forward(self, x):
 def _test_lowp_mlp_tensor_parallelism_base(
     mesh: DeviceMesh,
     config: Union[Float8LinearConfig, MXLinearConfig],
-    size=16,
+    size=32,
     compile: bool = False,
     allgather_in_lowp: bool = False,
 ):
@@ -67,7 +67,7 @@ def _test_lowp_mlp_tensor_parallelism_base(
     if isinstance(config, MXLinearConfig):
         convert_model_func = quantize_

-    toy_model = ToyModel().to(device)
+    toy_model = ToyModel(size).to(device)
     toy_model_fp8 = copy.deepcopy(toy_model)
     convert_model_func(toy_model_fp8, config=config)
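ToyModel and FeedForward are now parameterized by size, so every caller (including the FSDP2 + TP test above) passes the feature dimension explicitly. A quick shape check of the new constructor, assuming the utilities are importable from a torchao source checkout:

import torch
from torchao.testing.training.dtensor_utils import ToyModel

size = 32  # matches the new default in _test_lowp_mlp_tensor_parallelism_base
model = ToyModel(size)
out = model(torch.randn(8, size))
print(out.shape)  # torch.Size([8, 32]); w1/w2 expand to size * 2, out_proj maps back to size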
