
Commit 29be4b2

fix dtype bug
1 parent fb0122e commit 29be4b2

3 files changed: +44 -12 lines changed


torchao/prototype/moe_training/conversion_utils.py

Lines changed: 3 additions & 2 deletions
@@ -84,7 +84,7 @@ def _swap_params(
                 f"Does not support a root nn.Parameter with children: {module}"
             )
         if not isinstance(module.data, ScaledGroupedMMTensor):
-            new_data = ScaledGroupedMMTensor(module.data)
+            new_data = ScaledGroupedMMTensor(module.data, module.data.dtype)
             return nn.Parameter(new_data, requires_grad=module.requires_grad)
         return module
 
@@ -110,7 +110,8 @@ def post_order_traversal(
         for param_name, param in module.named_parameters(recurse=False):
             if not isinstance(param.data, ScaledGroupedMMTensor):
                 new_param = nn.Parameter(
-                    ScaledGroupedMMTensor(param), requires_grad=param.requires_grad
+                    ScaledGroupedMMTensor(param.data, param.data.dtype),
+                    requires_grad=param.requires_grad,
                 )
                 setattr(module, param_name, new_param)
                 logger.info(
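
With this change, the swapped parameter advertises the same dtype as the tensor it wraps. A minimal sketch of the behavior at the wrap site; the nn.Linear module below is only for illustration, not from the commit:

    import torch
    from torch import nn
    from torchao.prototype.moe_training.tensor import ScaledGroupedMMTensor

    # Hypothetical module standing in for an expert weight; any bf16 parameter works.
    linear = nn.Linear(8, 8, dtype=torch.bfloat16)

    # Mirrors what _swap_params now does: pass the data tensor and its dtype explicitly.
    new_data = ScaledGroupedMMTensor(linear.weight.data, linear.weight.data.dtype)
    new_param = nn.Parameter(new_data, requires_grad=linear.weight.requires_grad)

    assert new_param.dtype == torch.bfloat16  # wrapper dtype matches the wrapped data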

torchao/prototype/moe_training/scaled_grouped_mm.py

Lines changed: 1 addition & 1 deletion
@@ -40,7 +40,7 @@ def _scaled_grouped_mm(
         offs (int32 torch.Tensor): The offsets to use to mark the starting index of each group along dim0 of the A tensor.
         out_dtype (Optional[torch.dtype]): The dtype of the output tensor. Currently only torch.bfloat16 is supported.
     """
-    logger.debug("Using differentiable _scaled_grouped_mm")
+    logger.info("Using differentiable _scaled_grouped_mm")
     return _Float8GroupedMM.apply(
         A,
         B_t,
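
Since the message is now logged at INFO rather than DEBUG, it shows up under a default logging setup. A small sketch, assuming standard library logging configuration (not part of the commit):

    import logging

    # With INFO enabled, each call into _scaled_grouped_mm emits
    # "Using differentiable _scaled_grouped_mm".
    logging.basicConfig(level=logging.INFO)

    # To also see the new per-dispatch debug logs added in tensor.py below:
    logging.getLogger("torchao.prototype.moe_training.tensor").setLevel(logging.DEBUG)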

torchao/prototype/moe_training/tensor.py

Lines changed: 40 additions & 9 deletions
@@ -4,6 +4,7 @@
 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.
 
+import logging
 from typing import Any, Optional, Tuple
 
 import torch
@@ -12,6 +13,9 @@
 
 from torchao.prototype.moe_training import _scaled_grouped_mm
 
+logger: logging.Logger = logging.getLogger(__name__)
+
+
 _ops_to_preserve_subclass = {
     torch.ops.aten.empty_like.default,
     torch.ops.aten.new_zeros.default,
@@ -40,14 +44,15 @@ class ScaledGroupedMMTensor(torch.Tensor):
     def __new__(
         cls,
         tensor: torch.Tensor,
+        dtype: torch.dtype,
     ):
         return torch.Tensor._make_wrapper_subclass(
             cls,
             tensor.size(),
             strides=tensor.stride(),
             storage_offset=tensor.storage_offset(),
             memory_format=suggest_memory_format(tensor),
-            dtype=tensor.dtype,
+            dtype=dtype,
             layout=tensor.layout,
             device=tensor.device,
             pin_memory=tensor.is_pinned(),
@@ -57,8 +62,10 @@ def __new__(
     def __init__(
         self,
         tensor: torch.Tensor,
+        dtype: torch.dtype,
     ):
         self._data = tensor
+        self._dtype = dtype
 
     @classmethod
     def __torch_function__(cls, func, types, args, kwargs={}):
@@ -87,12 +94,22 @@ def __torch_function__(cls, func, types, args, kwargs={}):
 
     @classmethod
     def __torch_dispatch__(cls, func, types, args, kwargs={}):
+        logger.debug(f"{func.__name__}, args={args}, kwargs={kwargs}")
         # detach is special case
         if func == torch.ops.aten.detach.default:
-            return ScaledGroupedMMTensor(args[0]._data)
+            return ScaledGroupedMMTensor(args[0]._data, args[0]._dtype)
 
         # unwrap args and kwargs
-        unwrap = lambda tensor: tensor._data
+        dtype: Optional[torch.dtype] = None
+
+        def unwrap(t):
+            nonlocal dtype
+            if dtype is None:
+                dtype = t._dtype
+            else:
+                assert t._dtype == dtype
+            return t._data
+
         args, kwargs = pytree.tree_map_only(
             ScaledGroupedMMTensor, unwrap, (args, kwargs or {})
         )
@@ -107,13 +124,27 @@ def __torch_dispatch__(cls, func, types, args, kwargs={}):
         # wrap outputs back into ScaledGroupedMMTensor for ops that do preserve subclass
         return pytree.tree_map_only(
             torch.Tensor,
-            lambda x: ScaledGroupedMMTensor(x),
+            lambda x: ScaledGroupedMMTensor(x, dtype),
             out,
         )
 
+    def __repr__(self):
+        return f"ScaledGroupedMMTensor(data={self._data}, dtype={self._dtype})"
+
+    def __tensor_flatten__(self):
+        return ["_data"], {"dtype": self._dtype}
+
+    @staticmethod
+    def __tensor_unflatten__(inner_tensors, flatten_spec, outer_size, outer_stride):
+        return ScaledGroupedMMTensor(
+            inner_tensors["_data"],
+            flatten_spec["dtype"],
+        )
+
     def fsdp_pre_all_gather(self, mesh):
-        metadata = ()
-        return (self._data,), metadata
+        all_gather_inputs = (self._data,)
+        all_gather_metadata = ()
+        return all_gather_inputs, all_gather_metadata
 
     def fsdp_post_all_gather(
         self,
@@ -124,6 +155,6 @@ def fsdp_post_all_gather(
         out: Optional[torch.Tensor] = None,
     ):
         (data,) = all_gather_outputs
-        return ScaledGroupedMMTensor(
-            data,
-        ), (data,)
+        out = ScaledGroupedMMTensor(data, param_dtype)
+        inner_tensors = (data,)
+        return out, inner_tensors
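
The effect of threading dtype through the subclass is that the wrapper's advertised dtype is decoupled from the inner data's dtype and survives detach, flatten/unflatten, and the FSDP all-gather path. A minimal standalone sketch of the new two-argument constructor and the flatten round trip (not from the commit):

    import torch
    from torchao.prototype.moe_training.tensor import ScaledGroupedMMTensor

    data = torch.randn(16, 32, dtype=torch.bfloat16)

    # The wrapper dtype may differ from the inner data dtype, e.g. fsdp_post_all_gather
    # wraps the all-gather output with param_dtype.
    t = ScaledGroupedMMTensor(data, torch.float32)
    assert t.dtype == torch.float32          # advertised dtype
    assert t._data.dtype == torch.bfloat16   # underlying storage dtype

    # __tensor_flatten__/__tensor_unflatten__ let the subclass be taken apart and
    # reconstructed (e.g. by torch.compile) without losing the dtype.
    inner_names, spec = t.__tensor_flatten__()   # (["_data"], {"dtype": torch.float32})
    rebuilt = ScaledGroupedMMTensor.__tensor_unflatten__(
        {"_data": data}, spec, t.size(), t.stride()
    )
    assert rebuilt.dtype == torch.float32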
