
Commit ac14d92

[float8 moe training] fix bug affecting mixed precision training (#2451)
fix float8 moe training dtype bug
1 parent 3a5819e commit ac14d92

File tree

3 files changed: +61, -11 lines

Lines changed: 1 addition & 0 deletions

@@ -0,0 +1 @@
+torchrun --nproc_per_node=2 --local-ranks-filter=0 -m pytest test/prototype/moe_training/test_fsdp.py

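The command above launches the MoE training FSDP test on 2 ranks; --local-ranks-filter=0 limits log output to rank 0. For context, the bug fixed here shows up under FSDP2 mixed precision, where the gathered parameters used for compute run in a lower dtype than the fp32 master weights. Below is a minimal sketch of that kind of setup, assuming PyTorch 2.6+ (fully_shard and MixedPrecisionPolicy exported from torch.distributed.fsdp) and a hypothetical TinyExperts module; it is illustrative only, not the contents of test_fsdp.py.

# Sketch only: a grouped-expert module sharded with FSDP2 mixed precision.
# TinyExperts is a hypothetical stand-in; run under torchrun with an
# initialized process group (e.g. torchrun --nproc_per_node=2 this_file.py).
import torch
import torch.nn as nn
from torch.distributed.fsdp import MixedPrecisionPolicy, fully_shard


class TinyExperts(nn.Module):
    def __init__(self, num_experts: int = 4, dim: int = 64):
        super().__init__()
        # fp32 master weight; under mixed precision the all-gathered copy
        # used for compute is bf16, so a wrapper subclass around this data
        # must advertise the bf16 dtype rather than the fp32 storage dtype.
        self.w = nn.Parameter(torch.randn(num_experts, dim, dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (num_experts, tokens_per_expert, dim)
        return torch.einsum("etd,edf->etf", x, self.w)


def shard_with_mixed_precision(model: nn.Module) -> nn.Module:
    policy = MixedPrecisionPolicy(
        param_dtype=torch.bfloat16,   # dtype of gathered params / compute
        reduce_dtype=torch.float32,   # dtype for gradient reduction
    )
    fully_shard(model, mp_policy=policy)
    return model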
torchao/prototype/moe_training/conversion_utils.py

Lines changed: 14 additions & 3 deletions

@@ -1,3 +1,9 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD 3-Clause license found in the
+# LICENSE file in the root directory of this source tree.
+import logging
 from typing import Callable, Optional
 
 from torch import nn
@@ -8,6 +14,8 @@
     register_quantize_module_handler,
 )
 
+logger: logging.Logger = logging.getLogger(__name__)
+
 
 class MoETrainingConfig(AOBaseConfig):
     """
@@ -76,7 +84,7 @@ def _swap_params(
                 f"Does not support a root nn.Parameter with children: {module}"
             )
         if not isinstance(module.data, ScaledGroupedMMTensor):
-            new_data = ScaledGroupedMMTensor(module.data)
+            new_data = ScaledGroupedMMTensor(module.data, module.data.dtype)
             return nn.Parameter(new_data, requires_grad=module.requires_grad)
         return module

@@ -102,10 +110,13 @@ def post_order_traversal(
         for param_name, param in module.named_parameters(recurse=False):
             if not isinstance(param.data, ScaledGroupedMMTensor):
                 new_param = nn.Parameter(
-                    ScaledGroupedMMTensor(param), requires_grad=param.requires_grad
+                    ScaledGroupedMMTensor(param.data, param.data.dtype),
+                    requires_grad=param.requires_grad,
                 )
                 setattr(module, param_name, new_param)
-                print(f"Swapped {cur_fqn}.{param_name} to ScaledGroupedMMTensor")
+                logger.info(
+                    f"Swapped {cur_fqn}.{param_name} to ScaledGroupedMMTensor"
+                )

     post_order_traversal(root_module)
     return root_module
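
For reference, MoETrainingConfig is applied through torchao's quantize_ entry point together with a module filter, which is what ultimately reaches the _swap_params / post_order_traversal code above. A minimal usage sketch, assuming the quantize_ API from torchao.quantization and a hypothetical "experts" FQN substring for the target modules:

# Usage sketch for MoETrainingConfig. The "experts" substring is a
# hypothetical naming convention for the target model, not a requirement.
from torchao.quantization.quant_api import quantize_

from torchao.prototype.moe_training.conversion_utils import MoETrainingConfig


def convert_moe_experts(model, target_fqns=("experts",)):
    def module_filter_fn(module, cur_fqn: str) -> bool:
        # Only convert modules whose fully qualified name matches a target.
        return any(target in cur_fqn for target in target_fqns)

    # The registered handler for MoETrainingConfig swaps matching parameters
    # to ScaledGroupedMMTensor, now constructed with an explicit dtype.
    quantize_(model, MoETrainingConfig(), filter_fn=module_filter_fn)
    return model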

torchao/prototype/moe_training/tensor.py

Lines changed: 46 additions & 8 deletions

@@ -1,3 +1,10 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD 3-Clause license found in the
+# LICENSE file in the root directory of this source tree.
+
+import logging
 from typing import Any, Optional, Tuple
 
 import torch
@@ -6,6 +13,9 @@
 
 from torchao.prototype.moe_training import _scaled_grouped_mm
 
+logger: logging.Logger = logging.getLogger(__name__)
+
+
 _ops_to_preserve_subclass = {
     torch.ops.aten.empty_like.default,
     torch.ops.aten.new_zeros.default,
@@ -34,14 +44,15 @@ class ScaledGroupedMMTensor(torch.Tensor):
     def __new__(
         cls,
         tensor: torch.Tensor,
+        dtype: torch.dtype,
     ):
         return torch.Tensor._make_wrapper_subclass(
             cls,
             tensor.size(),
             strides=tensor.stride(),
             storage_offset=tensor.storage_offset(),
             memory_format=suggest_memory_format(tensor),
-            dtype=tensor.dtype,
+            dtype=dtype,
             layout=tensor.layout,
             device=tensor.device,
             pin_memory=tensor.is_pinned(),
@@ -51,11 +62,14 @@ def __new__(
     def __init__(
         self,
         tensor: torch.Tensor,
+        dtype: torch.dtype,
    ):
         self._data = tensor
+        self._dtype = dtype
 
     @classmethod
     def __torch_function__(cls, func, types, args, kwargs={}):
+        logger.info(f"{func.__name__}, args: {args}, kwargs: {kwargs}")
         # override the grouped mm op to use the differentiable _scaled_grouped_mm
         if func.__name__ == cls.grouped_mm_func_name:
             # Use torchao scaled grouped mm with dynamic quant for
@@ -84,10 +98,19 @@ def __torch_function__(cls, func, types, args, kwargs={}):
     def __torch_dispatch__(cls, func, types, args, kwargs={}):
         # detach is special case
         if func == torch.ops.aten.detach.default:
-            return ScaledGroupedMMTensor(args[0]._data)
+            return ScaledGroupedMMTensor(args[0]._data, args[0]._dtype)
 
         # unwrap args and kwargs
-        unwrap = lambda tensor: tensor._data
+        dtype: Optional[torch.dtype] = None
+
+        def unwrap(t):
+            nonlocal dtype
+            if dtype is None:
+                dtype = t._dtype
+            else:
+                assert t._dtype == dtype
+            return t._data
+
         args, kwargs = pytree.tree_map_only(
             ScaledGroupedMMTensor, unwrap, (args, kwargs or {})
         )
@@ -102,12 +125,27 @@ def __torch_dispatch__(cls, func, types, args, kwargs={}):
         # wrap outputs back into ScaledGroupedMMTensor for ops that do preserve subclass
         return pytree.tree_map_only(
             torch.Tensor,
-            lambda x: ScaledGroupedMMTensor(x),
+            lambda x: ScaledGroupedMMTensor(x, dtype),
             out,
         )
 
+    def __repr__(self):
+        return f"ScaledGroupedMMTensor(data={self._data}, dtype={self._dtype})"
+
+    def __tensor_flatten__(self):
+        return ["_data"], {"_dtype": self._dtype}
+
+    @staticmethod
+    def __tensor_unflatten__(inner_tensors, flatten_spec, outer_size, outer_stride):
+        return ScaledGroupedMMTensor(
+            inner_tensors["_data"],
+            flatten_spec["_dtype"],
+        )
+
     def fsdp_pre_all_gather(self, mesh):
-        return (self._data,), ()
+        all_gather_inputs = (self._data,)
+        all_gather_metadata = ()
+        return all_gather_inputs, all_gather_metadata
 
     def fsdp_post_all_gather(
         self,
@@ -118,6 +156,6 @@ def fsdp_post_all_gather(
         out: Optional[torch.Tensor] = None,
     ):
         (data,) = all_gather_outputs
-        return ScaledGroupedMMTensor(
-            data,
-        ), (data,)
+        output = ScaledGroupedMMTensor(data, param_dtype)
+        inner_tensors = (data,)
+        return output, inner_tensors
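
The essence of the fix: the wrapper's advertised dtype (what autograd, FSDP mixed precision, and downstream ops see) is now carried explicitly instead of being inferred from the wrapped data, so a bf16-facing parameter can keep fp32 storage underneath. A standalone toy sketch of that mechanism (not torchao code), using _make_wrapper_subclass:

# Toy illustration of decoupling the advertised dtype from the storage dtype.
# DtypeCarryingTensor is a made-up example class, not ScaledGroupedMMTensor.
import torch


class DtypeCarryingTensor(torch.Tensor):
    def __new__(cls, data: torch.Tensor, dtype: torch.dtype):
        # The outer tensor reports `dtype`; the payload keeps its own dtype.
        return torch.Tensor._make_wrapper_subclass(
            cls,
            data.size(),
            strides=data.stride(),
            dtype=dtype,
            device=data.device,
            requires_grad=data.requires_grad,
        )

    def __init__(self, data: torch.Tensor, dtype: torch.dtype):
        self._data = data
        self._dtype = dtype

    def __repr__(self):
        return f"DtypeCarryingTensor(outer={self.dtype}, inner={self._data.dtype})"


fp32_master = torch.randn(4, 4)
wrapped = DtypeCarryingTensor(fp32_master, torch.bfloat16)
print(wrapped.dtype)        # torch.bfloat16 -- what mixed precision expects
print(wrapped._data.dtype)  # torch.float32  -- the underlying storage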
