 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.

+import logging
 from typing import Any, Optional, Tuple

 import torch

 from torchao.prototype.moe_training import _scaled_grouped_mm

+logger: logging.Logger = logging.getLogger(__name__)
+
+
 _ops_to_preserve_subclass = {
     torch.ops.aten.empty_like.default,
     torch.ops.aten.new_zeros.default,
@@ -40,14 +44,15 @@ class ScaledGroupedMMTensor(torch.Tensor):
     def __new__(
         cls,
         tensor: torch.Tensor,
+        dtype: torch.dtype,
     ):
         return torch.Tensor._make_wrapper_subclass(
             cls,
             tensor.size(),
             strides=tensor.stride(),
             storage_offset=tensor.storage_offset(),
             memory_format=suggest_memory_format(tensor),
-            dtype=tensor.dtype,
+            dtype=dtype,
             layout=tensor.layout,
             device=tensor.device,
             pin_memory=tensor.is_pinned(),
@@ -57,8 +62,10 @@ def __new__(
     def __init__(
         self,
         tensor: torch.Tensor,
+        dtype: torch.dtype,
     ):
         self._data = tensor
+        self._dtype = dtype

     @classmethod
     def __torch_function__(cls, func, types, args, kwargs={}):
@@ -87,12 +94,22 @@ def __torch_function__(cls, func, types, args, kwargs={}):

     @classmethod
     def __torch_dispatch__(cls, func, types, args, kwargs={}):
+        logger.debug(f"{func.__name__}, args={args}, kwargs={kwargs}")
         # detach is special case
         if func == torch.ops.aten.detach.default:
-            return ScaledGroupedMMTensor(args[0]._data)
+            return ScaledGroupedMMTensor(args[0]._data, args[0]._dtype)

         # unwrap args and kwargs
-        unwrap = lambda tensor: tensor._data
+        dtype: Optional[torch.dtype] = None
+
+        def unwrap(t):
+            nonlocal dtype
+            if dtype is None:
+                dtype = t._dtype
+            else:
+                assert t._dtype == dtype
+            return t._data
+
         args, kwargs = pytree.tree_map_only(
             ScaledGroupedMMTensor, unwrap, (args, kwargs or {})
         )
@@ -107,13 +124,27 @@ def __torch_dispatch__(cls, func, types, args, kwargs={}):
         # wrap outputs back into ScaledGroupedMMTensor for ops that do preserve subclass
         return pytree.tree_map_only(
             torch.Tensor,
-            lambda x: ScaledGroupedMMTensor(x),
+            lambda x: ScaledGroupedMMTensor(x, dtype),
             out,
         )

+    def __repr__(self):
+        return f"ScaledGroupedMMTensor(data={self._data}, dtype={self._dtype})"
+
+    def __tensor_flatten__(self):
+        return ["_data"], {"dtype": self._dtype}
+
+    @staticmethod
+    def __tensor_unflatten__(inner_tensors, flatten_spec, outer_size, outer_stride):
+        return ScaledGroupedMMTensor(
+            inner_tensors["_data"],
+            flatten_spec["dtype"],
+        )
+
     def fsdp_pre_all_gather(self, mesh):
-        metadata = ()
-        return (self._data,), metadata
+        all_gather_inputs = (self._data,)
+        all_gather_metadata = ()
+        return all_gather_inputs, all_gather_metadata

     def fsdp_post_all_gather(
         self,
@@ -124,6 +155,6 @@ def fsdp_post_all_gather(
         out: Optional[torch.Tensor] = None,
     ):
         (data,) = all_gather_outputs
-        return ScaledGroupedMMTensor(
-            data,
-        ), (data,)
+        out = ScaledGroupedMMTensor(data, param_dtype)
+        inner_tensors = (data,)
+        return out, inner_tensors
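A minimal usage sketch of the new `dtype` plumbing (not part of the diff). The import path `torchao.prototype.moe_training.tensor` and the shapes/dtypes are assumptions for illustration; the constructor signature follows the diff above:

```python
import torch

from torchao.prototype.moe_training.tensor import ScaledGroupedMMTensor

# The inner data can stay in a low-precision dtype while the wrapper
# advertises a different dtype via _make_wrapper_subclass(dtype=dtype).
data = torch.randn(8, 16, dtype=torch.bfloat16)
wrapped = ScaledGroupedMMTensor(data, torch.float32)

print(wrapped.dtype)        # torch.float32 (the advertised dtype)
print(wrapped._data.dtype)  # torch.bfloat16 (the storage dtype)

# detach is special-cased in __torch_dispatch__, so the inner data and the
# recorded dtype both survive; other subclass-preserving ops are re-wrapped
# with the dtype captured by unwrap().
detached = wrapped.detach()
assert isinstance(detached, ScaledGroupedMMTensor)
assert detached._dtype == torch.float32
```

Threading the dtype through `__tensor_flatten__`/`__tensor_unflatten__` and the `param_dtype` argument of `fsdp_post_all_gather` means tracing and FSDP2 all-gather reconstruction can rebuild the wrapper with the same advertised dtype instead of falling back to the inner data's dtype.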