Commit 44e6a2e

fix float8 moe training dtype bug
1 parent 994a4ba commit 44e6a2e

3 files changed: +60 −11 lines changed
Lines changed: 1 addition & 0 deletions

@@ -0,0 +1 @@
+torchrun --nproc_per_node=2 --local-ranks-filter=0 -m pytest test/prototype/moe_training/test_fsdp.py
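The new one-line script above launches the FSDP test for converted MoE layers on two ranks; the --local-ranks-filter=0 flag limits torchrun's log output to rank 0. As a rough idea of the kind of harness it drives, here is a hypothetical skeleton (not the contents of test_fsdp.py); it assumes the FSDP2 fully_shard / MixedPrecisionPolicy API from torch.distributed._composable.fsdp and uses a bf16 param_dtype, since mixed-precision all-gathers are the path where the wrapper's data dtype can differ from the dtype it should report:

# Hypothetical FSDP2 smoke-test skeleton, launched via torchrun as in the script above.
import torch
import torch.distributed as dist
from torch import nn
from torch.distributed._composable.fsdp import MixedPrecisionPolicy, fully_shard


def main() -> None:
    dist.init_process_group(backend="nccl")
    torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

    # Placeholder model; the real test converts an MoE layer with MoETrainingConfig first.
    model = nn.Sequential(nn.Linear(128, 128), nn.Linear(128, 128)).cuda()

    # bf16 compute params over fp32 sharded params: the dtype-mismatch case this commit targets.
    fully_shard(model, mp_policy=MixedPrecisionPolicy(param_dtype=torch.bfloat16))

    x = torch.randn(8, 128, device="cuda")
    model(x).sum().backward()

    dist.destroy_process_group()


if __name__ == "__main__":
    main()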

torchao/prototype/moe_training/conversion_utils.py

Lines changed: 14 additions & 3 deletions
@@ -1,3 +1,9 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD 3-Clause license found in the
+# LICENSE file in the root directory of this source tree.
+import logging
 from typing import Callable, Optional
 
 from torch import nn
@@ -8,6 +14,8 @@
     register_quantize_module_handler,
 )
 
+logger: logging.Logger = logging.getLogger(__name__)
+
 
 class MoETrainingConfig(AOBaseConfig):
     """
@@ -76,7 +84,7 @@ def _swap_params(
                 f"Does not support a root nn.Parameter with children: {module}"
             )
         if not isinstance(module.data, ScaledGroupedMMTensor):
-            new_data = ScaledGroupedMMTensor(module.data)
+            new_data = ScaledGroupedMMTensor(module.data, module.data.dtype)
             return nn.Parameter(new_data, requires_grad=module.requires_grad)
         return module
 
@@ -102,10 +110,13 @@ def post_order_traversal(
            for param_name, param in module.named_parameters(recurse=False):
                if not isinstance(param.data, ScaledGroupedMMTensor):
                    new_param = nn.Parameter(
-                        ScaledGroupedMMTensor(param), requires_grad=param.requires_grad
+                        ScaledGroupedMMTensor(param.data, param.data.dtype),
+                        requires_grad=param.requires_grad,
                    )
                    setattr(module, param_name, new_param)
-                    print(f"Swapped {cur_fqn}.{param_name} to ScaledGroupedMMTensor")
+                    logger.info(
+                        f"Swapped {cur_fqn}.{param_name} to ScaledGroupedMMTensor"
+                    )
 
    post_order_traversal(root_module)
    return root_module
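The conversion-side change is mechanical: both swap sites now pass the parameter's own dtype to ScaledGroupedMMTensor explicitly instead of letting the wrapper infer it, and the print call becomes logger.info. A minimal standalone sketch of the new wrapping pattern (illustration only, mirroring the updated call in _swap_params):

# Mirrors the updated call sites: the wrapper is constructed with an explicit
# dtype taken from the parameter's data, then re-wrapped as an nn.Parameter.
import torch
from torch import nn

from torchao.prototype.moe_training.tensor import ScaledGroupedMMTensor

param = nn.Parameter(torch.randn(256, 512, dtype=torch.float32))
new_param = nn.Parameter(
    ScaledGroupedMMTensor(param.data, param.data.dtype),
    requires_grad=param.requires_grad,
)
assert new_param.dtype == torch.float32  # the wrapper reports the dtype it was given

Because MoETrainingConfig is registered through register_quantize_module_handler, this swap is normally reached via torchao's quantize_ entry point rather than called directly.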

torchao/prototype/moe_training/tensor.py

Lines changed: 45 additions & 8 deletions
@@ -1,3 +1,10 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD 3-Clause license found in the
+# LICENSE file in the root directory of this source tree.
+
+import logging
 from typing import Any, Optional, Tuple
 
 import torch
@@ -6,6 +13,9 @@
 
 from torchao.prototype.moe_training import _scaled_grouped_mm
 
+logger: logging.Logger = logging.getLogger(__name__)
+
+
 _ops_to_preserve_subclass = {
     torch.ops.aten.empty_like.default,
     torch.ops.aten.new_zeros.default,
@@ -34,14 +44,15 @@ class ScaledGroupedMMTensor(torch.Tensor):
     def __new__(
         cls,
         tensor: torch.Tensor,
+        dtype: torch.dtype,
     ):
         return torch.Tensor._make_wrapper_subclass(
             cls,
             tensor.size(),
             strides=tensor.stride(),
             storage_offset=tensor.storage_offset(),
             memory_format=suggest_memory_format(tensor),
-            dtype=tensor.dtype,
+            dtype=dtype,
             layout=tensor.layout,
             device=tensor.device,
             pin_memory=tensor.is_pinned(),
@@ -51,8 +62,10 @@ def __new__(
     def __init__(
         self,
         tensor: torch.Tensor,
+        dtype: torch.dtype,
     ):
         self._data = tensor
+        self._dtype = dtype
 
     @classmethod
     def __torch_function__(cls, func, types, args, kwargs={}):
@@ -84,10 +97,19 @@ def __torch_function__(cls, func, types, args, kwargs={}):
     def __torch_dispatch__(cls, func, types, args, kwargs={}):
         # detach is special case
         if func == torch.ops.aten.detach.default:
-            return ScaledGroupedMMTensor(args[0]._data)
+            return ScaledGroupedMMTensor(args[0]._data, args[0]._dtype)
 
         # unwrap args and kwargs
-        unwrap = lambda tensor: tensor._data
+        dtype: Optional[torch.dtype] = None
+
+        def unwrap(t):
+            nonlocal dtype
+            if dtype is None:
+                dtype = t._dtype
+            else:
+                assert t._dtype == dtype
+            return t._data
+
         args, kwargs = pytree.tree_map_only(
             ScaledGroupedMMTensor, unwrap, (args, kwargs or {})
         )
@@ -102,12 +124,27 @@ def __torch_dispatch__(cls, func, types, args, kwargs={}):
         # wrap outputs back into ScaledGroupedMMTensor for ops that do preserve subclass
         return pytree.tree_map_only(
             torch.Tensor,
-            lambda x: ScaledGroupedMMTensor(x),
+            lambda x: ScaledGroupedMMTensor(x, dtype),
             out,
         )
 
+    def __repr__(self):
+        return f"ScaledGroupedMMTensor(data={self._data}, dtype={self._dtype})"
+
+    def __tensor_flatten__(self):
+        return ["_data"], {"dtype": self._dtype}
+
+    @staticmethod
+    def __tensor_unflatten__(inner_tensors, flatten_spec, outer_size, outer_stride):
+        return ScaledGroupedMMTensor(
+            inner_tensors["_data"],
+            flatten_spec["dtype"],
+        )
+
     def fsdp_pre_all_gather(self, mesh):
-        return (self._data,), ()
+        all_gather_inputs = (self._data,)
+        all_gather_metadata = ()
+        return all_gather_inputs, all_gather_metadata
 
     def fsdp_post_all_gather(
         self,
@@ -118,6 +155,6 @@ def fsdp_post_all_gather(
         out: Optional[torch.Tensor] = None,
     ):
         (data,) = all_gather_outputs
-        return ScaledGroupedMMTensor(
-            data,
-        ), (data,)
+        out = ScaledGroupedMMTensor(data, param_dtype)
+        inner_tensors = (data,)
+        return out, inner_tensors
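The tensor-side changes thread one piece of state, the wrapper's dtype, through every construction path: __new__ hands it to _make_wrapper_subclass, __torch_dispatch__ recovers it from the unwrapped inputs and reapplies it when wrapping outputs, __tensor_flatten__/__tensor_unflatten__ carry it as metadata, and fsdp_post_all_gather now builds the unsharded tensor with FSDP's param_dtype. A short sketch of the resulting behavior (assumes the class from this commit is importable; these CPU-only checks are illustrative, not part of the project's tests):

import torch

from torchao.prototype.moe_training.tensor import ScaledGroupedMMTensor

# The wrapper's advertised dtype is decoupled from the dtype of the data it holds.
data = torch.randn(16, 32, dtype=torch.bfloat16)
t = ScaledGroupedMMTensor(data, torch.float32)
assert t.dtype == torch.float32 and t._data.dtype == torch.bfloat16

# Ops routed through __torch_dispatch__ (detach is the special case above)
# propagate the declared dtype to their outputs.
assert t.detach().dtype == torch.float32

# __tensor_flatten__ / __tensor_unflatten__ round-trip the dtype as metadata,
# which is what the traceable-subclass machinery and FSDP rely on.
names, spec = t.__tensor_flatten__()
t2 = ScaledGroupedMMTensor.__tensor_unflatten__(
    {name: getattr(t, name) for name in names}, spec, t.size(), t.stride()
)
assert t2.dtype == torch.float32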
