@@ -1,11 +1,21 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD 3-Clause license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ import logging
from typing import Any, Optional, Tuple

import torch
import torch.utils._pytree as pytree
from torch._prims_common import suggest_memory_format

from torchao.prototype.moe_training import _scaled_grouped_mm

+ logger: logging.Logger = logging.getLogger(__name__)
+
+
_ops_to_preserve_subclass = {
    torch.ops.aten.empty_like.default,
    torch.ops.aten.new_zeros.default,
@@ -34,14 +44,15 @@ class ScaledGroupedMMTensor(torch.Tensor):
    def __new__(
        cls,
        tensor: torch.Tensor,
+         dtype: torch.dtype,
    ):
        return torch.Tensor._make_wrapper_subclass(
            cls,
            tensor.size(),
            strides=tensor.stride(),
            storage_offset=tensor.storage_offset(),
            memory_format=suggest_memory_format(tensor),
-             dtype=tensor.dtype,
+             dtype=dtype,
            layout=tensor.layout,
            device=tensor.device,
            pin_memory=tensor.is_pinned(),
@@ -51,11 +62,14 @@ def __new__(
    def __init__(
        self,
        tensor: torch.Tensor,
+         dtype: torch.dtype,
    ):
        self._data = tensor
+         self._dtype = dtype

    @classmethod
    def __torch_function__(cls, func, types, args, kwargs={}):
+         logger.info(f"{func.__name__}, args: {args}, kwargs: {kwargs}")
        # override the grouped mm op to use the differentiable _scaled_grouped_mm
        if func.__name__ == cls.grouped_mm_func_name:
            # Use torchao scaled grouped mm with dynamic quant for
@@ -84,10 +98,19 @@ def __torch_function__(cls, func, types, args, kwargs={}):
    def __torch_dispatch__(cls, func, types, args, kwargs={}):
        # detach is special case
        if func == torch.ops.aten.detach.default:
-             return ScaledGroupedMMTensor(args[0]._data)
+             return ScaledGroupedMMTensor(args[0]._data, args[0]._dtype)

        # unwrap args and kwargs
-         unwrap = lambda tensor: tensor._data
+         dtype: Optional[torch.dtype] = None
+
+         def unwrap(t):
+             nonlocal dtype
+             if dtype is None:
+                 dtype = t._dtype
+             else:
+                 assert t._dtype == dtype
+             return t._data
+
        args, kwargs = pytree.tree_map_only(
            ScaledGroupedMMTensor, unwrap, (args, kwargs or {})
        )
@@ -102,12 +125,27 @@ def __torch_dispatch__(cls, func, types, args, kwargs={}):
        # wrap outputs back into ScaledGroupedMMTensor for ops that do preserve subclass
        return pytree.tree_map_only(
            torch.Tensor,
-             lambda x: ScaledGroupedMMTensor(x),
+             lambda x: ScaledGroupedMMTensor(x, dtype),
            out,
        )

+     def __repr__(self):
+         return f"ScaledGroupedMMTensor(data={self._data}, dtype={self._dtype})"
+
+     def __tensor_flatten__(self):
+         return ["_data"], {"_dtype": self._dtype}
+
+     @staticmethod
+     def __tensor_unflatten__(inner_tensors, flatten_spec, outer_size, outer_stride):
+         return ScaledGroupedMMTensor(
+             inner_tensors["_data"],
+             flatten_spec["_dtype"],
+         )
+
    def fsdp_pre_all_gather(self, mesh):
-         return (self._data,), ()
+         all_gather_inputs = (self._data,)
+         all_gather_metadata = ()
+         return all_gather_inputs, all_gather_metadata

    def fsdp_post_all_gather(
        self,
@@ -118,6 +156,6 @@ def fsdp_post_all_gather(
        out: Optional[torch.Tensor] = None,
    ):
        (data,) = all_gather_outputs
-         return ScaledGroupedMMTensor(
-             data,
-         ), (data,)
+         output = ScaledGroupedMMTensor(data, param_dtype)
+         inner_tensors = (data,)
+         return output, inner_tensors
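
Note on the `dtype` plumbing above: `__new__` now advertises an explicit `dtype` through `_make_wrapper_subclass` while the inner `_data` keeps its own storage dtype, and `__torch_dispatch__` threads the advertised dtype through the unwrap/rewrap cycle. Below is a minimal self-contained sketch of that pattern, not torchao code: the class name `DTypeOverrideTensor` is hypothetical and error handling is omitted.

```python
import torch
import torch.utils._pytree as pytree


class DTypeOverrideTensor(torch.Tensor):
    # Toy reproduction of the pattern in this diff: the wrapper advertises
    # an explicit dtype that may differ from the inner tensor's dtype.

    @staticmethod
    def __new__(cls, tensor: torch.Tensor, dtype: torch.dtype):
        return torch.Tensor._make_wrapper_subclass(
            cls,
            tensor.size(),
            strides=tensor.stride(),
            storage_offset=tensor.storage_offset(),
            dtype=dtype,  # advertised dtype, decoupled from tensor.dtype
            layout=tensor.layout,
            device=tensor.device,
            requires_grad=tensor.requires_grad,
        )

    def __init__(self, tensor: torch.Tensor, dtype: torch.dtype):
        self._data = tensor
        self._dtype = dtype

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        # Unwrap inputs, run the op on the inner tensors, then rewrap
        # outputs so the advertised dtype survives the op, mirroring the
        # nonlocal-dtype closure added in this diff.
        dtype = None

        def unwrap(t):
            nonlocal dtype
            dtype = t._dtype if dtype is None else dtype
            return t._data

        args, kwargs = pytree.tree_map_only(cls, unwrap, (args, kwargs or {}))
        out = func(*args, **kwargs)
        return pytree.tree_map_only(torch.Tensor, lambda x: cls(x, dtype), out)


inner = torch.randn(4, 4, dtype=torch.bfloat16)
w = DTypeOverrideTensor(inner, torch.float32)
print(w.dtype)        # torch.float32: what autograd/FSDP "see"
print(w._data.dtype)  # torch.bfloat16: the actual storage dtype
print((w + w).dtype)  # ops rewrap their outputs, preserving torch.float32
```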
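
Likewise, the new `__tensor_flatten__`/`__tensor_unflatten__` pair follows PyTorch's traceable tensor subclass contract (a list of inner tensor attribute names plus a metadata dict), which is what lets tracing and compile machinery take the wrapper apart and rebuild it. A round-trip sketch of that contract, reusing the hypothetical `DTypeOverrideTensor` from the snippet above:

```python
# Round trip mirroring the diff's flatten spec: attr names + metadata dict.
attrs, spec = ["_data"], {"_dtype": w._dtype}  # what __tensor_flatten__ returns
inner_tensors = {name: getattr(w, name) for name in attrs}
w2 = DTypeOverrideTensor(inner_tensors["_data"], spec["_dtype"])  # __tensor_unflatten__'s job
assert w2.dtype == w.dtype and w2._data is w._data
```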
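
Finally, the `fsdp_pre_all_gather`/`fsdp_post_all_gather` changes keep the same FSDP2 extension contract as before (a tuple of all-gather inputs plus metadata going in, an output tensor plus its inner tensors coming out); the behavioral change is that the unsharded parameter is now rebuilt with `param_dtype`, so after all-gather it advertises the mixed-precision dtype rather than the raw data's dtype.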