# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

+ import logging
from typing import Any, Optional, Tuple

import torch

from torchao.prototype.moe_training import _scaled_grouped_mm

+ logger: logging.Logger = logging.getLogger(__name__)
+
+
_ops_to_preserve_subclass = {
    torch.ops.aten.empty_like.default,
    torch.ops.aten.new_zeros.default,
@@ -40,14 +44,18 @@ class ScaledGroupedMMTensor(torch.Tensor):
    def __new__(
        cls,
        tensor: torch.Tensor,
+         dtype: torch.dtype,
    ):
+         logger.info(
+             f"__new__ tensor={tensor}, tensor.dtype={tensor.dtype}, dtype={dtype}"
+         )
        return torch.Tensor._make_wrapper_subclass(
            cls,
            tensor.size(),
            strides=tensor.stride(),
            storage_offset=tensor.storage_offset(),
            memory_format=suggest_memory_format(tensor),
-             dtype=tensor.dtype,
+             dtype=dtype,
            layout=tensor.layout,
            device=tensor.device,
            pin_memory=tensor.is_pinned(),
@@ -57,8 +65,10 @@ def __new__(
    def __init__(
        self,
        tensor: torch.Tensor,
+         dtype: torch.dtype,
    ):
        self._data = tensor
+         self._dtype = dtype

    @classmethod
    def __torch_function__(cls, func, types, args, kwargs={}):
@@ -87,12 +97,22 @@ def __torch_function__(cls, func, types, args, kwargs={}):
    @classmethod
    def __torch_dispatch__(cls, func, types, args, kwargs={}):
+         logger.debug(f"{func.__name__}, args={args}, kwargs={kwargs}")
        # detach is special case
        if func == torch.ops.aten.detach.default:
-             return ScaledGroupedMMTensor(args[0]._data)
+             return ScaledGroupedMMTensor(args[0]._data, args[0]._dtype)

        # unwrap args and kwargs
-         unwrap = lambda tensor: tensor._data
+         dtype: Optional[torch.dtype] = None
+
+         def unwrap(t):
+             nonlocal dtype
+             if dtype is None:
+                 dtype = t._dtype
+             else:
+                 assert t._dtype == dtype
+             return t._data
+
        args, kwargs = pytree.tree_map_only(
            ScaledGroupedMMTensor, unwrap, (args, kwargs or {})
        )
@@ -107,13 +127,27 @@ def __torch_dispatch__(cls, func, types, args, kwargs={}):
        # wrap outputs back into ScaledGroupedMMTensor for ops that do preserve subclass
        return pytree.tree_map_only(
            torch.Tensor,
-             lambda x: ScaledGroupedMMTensor(x),
+             lambda x: ScaledGroupedMMTensor(x, dtype),
            out,
        )

+     def __repr__(self):
+         return f"ScaledGroupedMMTensor(data={self._data}, dtype={self._dtype})"
+
+     def __tensor_flatten__(self):
+         return ["_data"], {"dtype": self._dtype}
+
+     @staticmethod
+     def __tensor_unflatten__(inner_tensors, flatten_spec, outer_size, outer_stride):
+         return ScaledGroupedMMTensor(
+             inner_tensors["_data"],
+             flatten_spec["dtype"],
+         )
+
    def fsdp_pre_all_gather(self, mesh):
-         metadata = ()
-         return (self._data,), metadata
+         all_gather_inputs = (self._data,)
+         all_gather_metadata = ()
+         return all_gather_inputs, all_gather_metadata

    def fsdp_post_all_gather(
        self,
@@ -124,6 +158,6 @@ def fsdp_post_all_gather(
        out: Optional[torch.Tensor] = None,
    ):
        (data,) = all_gather_outputs
-         return ScaledGroupedMMTensor(
-             data,
-         ), (data,)
+         out = ScaledGroupedMMTensor(data, param_dtype)
+         inner_tensors = (data,)
+         return out, inner_tensors
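
For readers skimming the diff: the new `dtype` argument lets the wrapper advertise a dtype (for example an FSDP `param_dtype`) that can differ from the dtype of the inner `_data` tensor, and `__tensor_flatten__` / `__tensor_unflatten__` round-trip that metadata. Below is a minimal usage sketch, assuming the class is importable from `torchao.prototype.moe_training.tensor`; the module path and tensor values are illustrative only, not part of the PR.

```python
import torch

# Assumed import path for illustration; the diff above only shows the file body.
from torchao.prototype.moe_training.tensor import ScaledGroupedMMTensor

# The wrapper can carry a dtype that differs from the inner data's dtype.
data = torch.randn(16, 32, dtype=torch.float32)
w = ScaledGroupedMMTensor(data, torch.bfloat16)
assert w._dtype == torch.bfloat16
assert w._data.dtype == torch.float32

# __tensor_flatten__ / __tensor_unflatten__ preserve the dtype metadata,
# which subclass-tracing machinery (torch.compile, FSDP2) relies on.
inner_names, spec = w.__tensor_flatten__()
rebuilt = ScaledGroupedMMTensor.__tensor_unflatten__(
    {name: getattr(w, name) for name in inner_names},
    spec,
    w.size(),
    w.stride(),
)
assert rebuilt._dtype == w._dtype
assert torch.equal(rebuilt._data, w._data)
```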