
Commit 97e55e8

fix dtype bug
1 parent fb0122e commit 97e55e8
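
The change threads an explicit dtype through ScaledGroupedMMTensor: the wrapper no longer infers its dtype from the inner tensor, and every construction site now passes one in. A minimal sketch of the new two-argument constructor, assuming the class is importable from torchao.prototype.moe_training.tensor as the file paths below suggest:

import torch

from torchao.prototype.moe_training.tensor import ScaledGroupedMMTensor

# After this commit the wrapper's advertised dtype comes from the explicit
# second argument rather than from the inner tensor.
data = torch.randn(16, 32, dtype=torch.bfloat16)
wrapped = ScaledGroupedMMTensor(data, data.dtype)
assert wrapped.dtype == torch.bfloat16  # dtype taken from the explicit argument
assert wrapped._data is data            # inner data stored as-is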

File tree

3 files changed: +47 additions, -12 deletions


torchao/prototype/moe_training/conversion_utils.py

Lines changed: 3 additions & 2 deletions

@@ -84,7 +84,7 @@ def _swap_params(
             f"Does not support a root nn.Parameter with children: {module}"
         )
         if not isinstance(module.data, ScaledGroupedMMTensor):
-            new_data = ScaledGroupedMMTensor(module.data)
+            new_data = ScaledGroupedMMTensor(module.data, module.data.dtype)
             return nn.Parameter(new_data, requires_grad=module.requires_grad)
         return module
 
@@ -110,7 +110,8 @@ def post_order_traversal(
         for param_name, param in module.named_parameters(recurse=False):
             if not isinstance(param.data, ScaledGroupedMMTensor):
                 new_param = nn.Parameter(
-                    ScaledGroupedMMTensor(param), requires_grad=param.requires_grad
+                    ScaledGroupedMMTensor(param.data, param.data.dtype),
+                    requires_grad=param.requires_grad,
                 )
                 setattr(module, param_name, new_param)
                 logger.info(
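
Both hunks fix the same call pattern: the conversion path now wraps a parameter's .data (rather than the nn.Parameter itself) and passes that data's dtype explicitly. A rough sketch of the swap step, simplified from the helpers above and assuming the same import path:

import torch
import torch.nn as nn

from torchao.prototype.moe_training.tensor import ScaledGroupedMMTensor

def swap_param(param: nn.Parameter) -> nn.Parameter:
    # Mirror the updated conversion code: wrap param.data with its own dtype.
    if isinstance(param.data, ScaledGroupedMMTensor):
        return param
    new_data = ScaledGroupedMMTensor(param.data, param.data.dtype)
    return nn.Parameter(new_data, requires_grad=param.requires_grad)

linear = nn.Linear(8, 8, dtype=torch.bfloat16)
linear.weight = swap_param(linear.weight)
print(type(linear.weight.data))  # expected: ScaledGroupedMMTensor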

torchao/prototype/moe_training/scaled_grouped_mm.py

Lines changed: 1 addition & 1 deletion

@@ -40,7 +40,7 @@ def _scaled_grouped_mm(
         offs (int32 torch.Tensor): The offsets to use to mark the starting index of each group along dim0 of the A tensor.
         out_dtype (Optional[torch.dtype]): The dtype of the output tensor. Currently only torch.bfloat16 is supported.
     """
-    logger.debug("Using differentiable _scaled_grouped_mm")
+    logger.info("Using differentiable _scaled_grouped_mm")
     return _Float8GroupedMM.apply(
         A,
         B_t,
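
The docstring fragment above covers the offs and out_dtype arguments, and A and B_t appear in the .apply call below it. A rough usage sketch under those assumptions; the shape convention (row-major A of shape (total_tokens, K), per-expert B_t of shape (num_experts, K, N)) is an assumption, not taken from this file, and float8 grouped GEMM needs suitable hardware:

import torch

from torchao.prototype.moe_training import _scaled_grouped_mm

# Two expert groups of 4 tokens each along dim0 of A; offs marks the group
# boundaries as described in the docstring above.
A = torch.randn(8, 16, dtype=torch.bfloat16, device="cuda", requires_grad=True)
B_t = torch.randn(2, 16, 32, dtype=torch.bfloat16, device="cuda", requires_grad=True)
offs = torch.tensor([4, 8], dtype=torch.int32, device="cuda")

out = _scaled_grouped_mm(A, B_t, offs=offs, out_dtype=torch.bfloat16)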

torchao/prototype/moe_training/tensor.py

Lines changed: 43 additions & 9 deletions
@@ -4,6 +4,7 @@
 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.
 
+import logging
 from typing import Any, Optional, Tuple
 
 import torch

@@ -12,6 +13,9 @@
 
 from torchao.prototype.moe_training import _scaled_grouped_mm
 
+logger: logging.Logger = logging.getLogger(__name__)
+
+
 _ops_to_preserve_subclass = {
     torch.ops.aten.empty_like.default,
     torch.ops.aten.new_zeros.default,
@@ -40,14 +44,18 @@ class ScaledGroupedMMTensor(torch.Tensor):
     def __new__(
         cls,
         tensor: torch.Tensor,
+        dtype: torch.dtype,
     ):
+        logger.info(
+            f"__new__ tensor={tensor}, tensor.dtype={tensor.dtype}, dtype={dtype}"
+        )
         return torch.Tensor._make_wrapper_subclass(
             cls,
             tensor.size(),
             strides=tensor.stride(),
             storage_offset=tensor.storage_offset(),
             memory_format=suggest_memory_format(tensor),
-            dtype=tensor.dtype,
+            dtype=dtype,
             layout=tensor.layout,
             device=tensor.device,
             pin_memory=tensor.is_pinned(),
@@ -57,8 +65,10 @@ def __new__(
     def __init__(
         self,
         tensor: torch.Tensor,
+        dtype: torch.dtype,
     ):
         self._data = tensor
+        self._dtype = dtype
 
     @classmethod
     def __torch_function__(cls, func, types, args, kwargs={}):
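
With the explicit argument, the dtype the wrapper advertises is decoupled from the dtype of the tensor it stores, which is the point of the fix. A small illustration of that separation (the fp32/bf16 pairing is made up for the example):

import torch

from torchao.prototype.moe_training.tensor import ScaledGroupedMMTensor

inner = torch.randn(8, 8, dtype=torch.bfloat16)
t = ScaledGroupedMMTensor(inner, torch.float32)

print(t.dtype)        # torch.float32  -- from the explicit dtype argument
print(t._data.dtype)  # torch.bfloat16 -- the wrapped data keeps its own dtype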
@@ -87,12 +97,22 @@ def __torch_function__(cls, func, types, args, kwargs={}):
 
     @classmethod
     def __torch_dispatch__(cls, func, types, args, kwargs={}):
+        logger.debug(f"{func.__name__}, args={args}, kwargs={kwargs}")
         # detach is special case
         if func == torch.ops.aten.detach.default:
-            return ScaledGroupedMMTensor(args[0]._data)
+            return ScaledGroupedMMTensor(args[0]._data, args[0]._dtype)
 
         # unwrap args and kwargs
-        unwrap = lambda tensor: tensor._data
+        dtype: Optional[torch.dtype] = None
+
+        def unwrap(t):
+            nonlocal dtype
+            if dtype is None:
+                dtype = t._dtype
+            else:
+                assert t._dtype == dtype
+            return t._data
+
         args, kwargs = pytree.tree_map_only(
             ScaledGroupedMMTensor, unwrap, (args, kwargs or {})
         )
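
The old one-line lambda only stripped the wrapper; the new closure also records the dtype of the first wrapped argument (asserting the rest agree) so the outputs can be re-wrapped with that dtype further down. A standalone sketch of the same unwrap pattern, using the torch.utils._pytree helper the method itself calls:

import torch
import torch.utils._pytree as pytree

from torchao.prototype.moe_training.tensor import ScaledGroupedMMTensor

def unwrap_all(args, kwargs):
    dtype = None

    def unwrap(t):
        nonlocal dtype
        if dtype is None:
            dtype = t._dtype          # remember the first wrapped arg's dtype
        else:
            assert t._dtype == dtype  # all wrapped args must agree
        return t._data

    args, kwargs = pytree.tree_map_only(
        ScaledGroupedMMTensor, unwrap, (args, kwargs or {})
    )
    return args, kwargs, dtype

x = ScaledGroupedMMTensor(torch.ones(2, 2, dtype=torch.bfloat16), torch.bfloat16)
args, kwargs, dtype = unwrap_all((x,), {})
assert dtype == torch.bfloat16 and not isinstance(args[0], ScaledGroupedMMTensor)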
@@ -107,13 +127,27 @@ def __torch_dispatch__(cls, func, types, args, kwargs={}):
         # wrap outputs back into ScaledGroupedMMTensor for ops that do preserve subclass
         return pytree.tree_map_only(
             torch.Tensor,
-            lambda x: ScaledGroupedMMTensor(x),
+            lambda x: ScaledGroupedMMTensor(x, dtype),
             out,
         )
 
+    def __repr__(self):
+        return f"ScaledGroupedMMTensor(data={self._data}, dtype={self._dtype})"
+
+    def __tensor_flatten__(self):
+        return ["_data"], {"dtype": self._dtype}
+
+    @staticmethod
+    def __tensor_unflatten__(inner_tensors, flatten_spec, outer_size, outer_stride):
+        return ScaledGroupedMMTensor(
+            inner_tensors["_data"],
+            flatten_spec["dtype"],
+        )
+
     def fsdp_pre_all_gather(self, mesh):
-        metadata = ()
-        return (self._data,), metadata
+        all_gather_inputs = (self._data,)
+        all_gather_metadata = ()
+        return all_gather_inputs, all_gather_metadata
 
     def fsdp_post_all_gather(
         self,
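
The new __tensor_flatten__ / __tensor_unflatten__ pair exposes the inner tensor plus the dtype as metadata, the contract traceable wrapper subclasses use with torch.compile and FSDP. A quick round-trip sketch under the same import assumption:

import torch

from torchao.prototype.moe_training.tensor import ScaledGroupedMMTensor

t = ScaledGroupedMMTensor(torch.randn(4, 4, dtype=torch.bfloat16), torch.bfloat16)

inner_names, spec = t.__tensor_flatten__()   # (["_data"], {"dtype": torch.bfloat16})
inner = {name: getattr(t, name) for name in inner_names}

t2 = ScaledGroupedMMTensor.__tensor_unflatten__(inner, spec, t.size(), t.stride())
assert t2._dtype == t._dtype and t2._data is t._data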
@@ -124,6 +158,6 @@ def fsdp_post_all_gather(
         out: Optional[torch.Tensor] = None,
     ):
         (data,) = all_gather_outputs
-        return ScaledGroupedMMTensor(
-            data,
-        ), (data,)
+        out = ScaledGroupedMMTensor(data, param_dtype)
+        inner_tensors = (data,)
+        return out, inner_tensors
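
fsdp_pre_all_gather and fsdp_post_all_gather implement FSDP2's tensor-subclass all-gather extension: the pre-hook hands FSDP the plain inner shard, and the post-hook now re-wraps the gathered data with FSDP's param_dtype instead of the data's own dtype. A hedged sketch of that flow, calling the hooks directly; the leading parameters of fsdp_post_all_gather are not shown in this hunk, so the usual FSDP2 order (all_gather_outputs, metadata, param_dtype) is assumed:

import torch

from torchao.prototype.moe_training.tensor import ScaledGroupedMMTensor

shard = ScaledGroupedMMTensor(torch.randn(4, 8, dtype=torch.bfloat16), torch.bfloat16)

# Pre-hook: FSDP gets the raw inner tensor(s) to all-gather, plus (empty) metadata.
all_gather_inputs, metadata = shard.fsdp_pre_all_gather(mesh=None)

# Stand-in for the real all-gather across two ranks.
gathered = torch.cat([all_gather_inputs[0], all_gather_inputs[0]], dim=0)

# Post-hook: the gathered data is re-wrapped, advertising FSDP's param_dtype.
out, inner_tensors = shard.fsdp_post_all_gather((gathered,), metadata, torch.bfloat16)
assert isinstance(out, ScaledGroupedMMTensor) and out._dtype == torch.bfloat16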
