+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD 3-Clause license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ import logging
from typing import Any, Optional, Tuple

import torch

from torchao.prototype.moe_training import _scaled_grouped_mm

+ logger: logging.Logger = logging.getLogger(__name__)
+
+
_ops_to_preserve_subclass = {
    torch.ops.aten.empty_like.default,
    torch.ops.aten.new_zeros.default,
@@ -34,14 +44,15 @@ class ScaledGroupedMMTensor(torch.Tensor):
    def __new__(
        cls,
        tensor: torch.Tensor,
+       dtype: torch.dtype,
    ):
        return torch.Tensor._make_wrapper_subclass(
            cls,
            tensor.size(),
            strides=tensor.stride(),
            storage_offset=tensor.storage_offset(),
            memory_format=suggest_memory_format(tensor),
-           dtype=tensor.dtype,
+           dtype=dtype,
            layout=tensor.layout,
            device=tensor.device,
            pin_memory=tensor.is_pinned(),
@@ -51,8 +62,10 @@ def __new__(
    def __init__(
        self,
        tensor: torch.Tensor,
+       dtype: torch.dtype,
    ):
        self._data = tensor
+       self._dtype = dtype

    @classmethod
    def __torch_function__(cls, func, types, args, kwargs={}):
@@ -84,10 +97,19 @@ def __torch_function__(cls, func, types, args, kwargs={}):
    def __torch_dispatch__(cls, func, types, args, kwargs={}):
        # detach is special case
        if func == torch.ops.aten.detach.default:
-           return ScaledGroupedMMTensor(args[0]._data)
+           return ScaledGroupedMMTensor(args[0]._data, args[0]._dtype)

        # unwrap args and kwargs
-       unwrap = lambda tensor: tensor._data
+       dtype: Optional[torch.dtype] = None
+
+       def unwrap(t):
+           nonlocal dtype
+           if dtype is None:
+               dtype = t._dtype
+           else:
+               assert t._dtype == dtype
+           return t._data
+
        args, kwargs = pytree.tree_map_only(
            ScaledGroupedMMTensor, unwrap, (args, kwargs or {})
        )
@@ -102,12 +124,27 @@ def __torch_dispatch__(cls, func, types, args, kwargs={}):
        # wrap outputs back into ScaledGroupedMMTensor for ops that do preserve subclass
        return pytree.tree_map_only(
            torch.Tensor,
-           lambda x: ScaledGroupedMMTensor(x),
+           lambda x: ScaledGroupedMMTensor(x, dtype),
            out,
        )

+   def __repr__(self):
+       return f"ScaledGroupedMMTensor(data={self._data}, dtype={self._dtype})"
+
+   def __tensor_flatten__(self):
+       return ["_data"], {"dtype": self._dtype}
+
+   @staticmethod
+   def __tensor_unflatten__(inner_tensors, flatten_spec, outer_size, outer_stride):
+       return ScaledGroupedMMTensor(
+           inner_tensors["_data"],
+           flatten_spec["dtype"],
+       )
+
    def fsdp_pre_all_gather(self, mesh):
-       return (self._data,), ()
+       all_gather_inputs = (self._data,)
+       all_gather_metadata = ()
+       return all_gather_inputs, all_gather_metadata

    def fsdp_post_all_gather(
        self,
@@ -118,6 +155,6 @@ def fsdp_post_all_gather(
        out: Optional[torch.Tensor] = None,
    ):
        (data,) = all_gather_outputs
-       return ScaledGroupedMMTensor(
-           data,
-       ), (data,)
+       out = ScaledGroupedMMTensor(data, param_dtype)
+       inner_tensors = (data,)
+       return out, inner_tensors
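
For reference, here is a minimal, self-contained sketch of the wrapper-subclass pattern this diff relies on. `DTypeWrapper` is a hypothetical stand-in, not the torchao class: it only illustrates how `_make_wrapper_subclass` lets the subclass advertise a dtype that can differ from the dtype of the inner tensor it stores, how `__torch_dispatch__` unwraps inputs and re-wraps outputs with that dtype, and how the `__tensor_flatten__` / `__tensor_unflatten__` hooks round-trip the inner tensor plus its metadata.

```python
# Hypothetical sketch of the pattern in the diff above; not the torchao class.
import torch
from torch.utils import _pytree as pytree


class DTypeWrapper(torch.Tensor):
    def __new__(cls, data: torch.Tensor, dtype: torch.dtype):
        # The wrapper advertises `dtype`, which may differ from `data.dtype`.
        return torch.Tensor._make_wrapper_subclass(
            cls,
            data.size(),
            strides=data.stride(),
            storage_offset=data.storage_offset(),
            dtype=dtype,
            layout=data.layout,
            device=data.device,
            requires_grad=data.requires_grad,
        )

    def __init__(self, data: torch.Tensor, dtype: torch.dtype):
        self._data = data
        self._dtype = dtype

    def __tensor_flatten__(self):
        # Inner-tensor attribute names plus non-tensor metadata, as in the diff.
        return ["_data"], {"dtype": self._dtype}

    @staticmethod
    def __tensor_unflatten__(inner_tensors, flatten_spec, outer_size, outer_stride):
        return DTypeWrapper(inner_tensors["_data"], flatten_spec["dtype"])

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        # Unwrap to the inner tensors, run the op, then re-wrap outputs so the
        # advertised dtype survives the op (same unwrap/re-wrap idea as above).
        wrapper_dtype = None

        def unwrap(t):
            nonlocal wrapper_dtype
            wrapper_dtype = t._dtype
            return t._data

        args, kwargs = pytree.tree_map_only(cls, unwrap, (args, kwargs or {}))
        out = func(*args, **kwargs)
        return pytree.tree_map_only(
            torch.Tensor, lambda x: DTypeWrapper(x, wrapper_dtype), out
        )


inner = torch.randn(4, 4, dtype=torch.bfloat16)
wrapped = DTypeWrapper(inner, torch.float32)
print(wrapped.dtype)        # torch.float32: what parameter consumers see
print(wrapped._data.dtype)  # torch.bfloat16: what is actually stored

doubled = wrapped * 2       # dispatch unwraps, computes in bf16, re-wraps
print(doubled.dtype, doubled._data.dtype)

# Manual round trip through the traceable-subclass hooks.
attrs, spec = wrapped.__tensor_flatten__()
rebuilt = DTypeWrapper.__tensor_unflatten__(
    {name: getattr(wrapped, name) for name in attrs},
    spec,
    wrapped.size(),
    wrapped.stride(),
)
print(rebuilt.dtype, rebuilt._data.dtype)
```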
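As far as this diff shows, the effect is that the high-precision dtype travels with the subclass instead of being inferred from the wrapped data: `detach` and other subclass-preserving ops re-wrap with the stored `_dtype`, `__tensor_flatten__` carries it as metadata so the subclass remains traceable, and `fsdp_post_all_gather` re-wraps the all-gathered data with the `param_dtype` that FSDP passes in.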