@@ -41,7 +41,7 @@ class ScaledGroupedMMTensor(torch.Tensor):
     differentiable _scaled_grouped_mm autograd function.
     """

-    grouped_mm_func_name = "_grouped_mm"
+    grouped_mm_func_names = {"_grouped_mm", "_grouped_mm.default"}
     offs_arg_name = "offs"

     @staticmethod
@@ -74,7 +74,7 @@ def __init__(
     @classmethod
     def __torch_function__(cls, func, types, args, kwargs={}):
         # override the grouped mm op to use the differentiable _scaled_grouped_mm
-        if func.__name__ == cls.grouped_mm_func_name:
+        if func.__name__ in cls.grouped_mm_func_names:
             # Use torchao scaled grouped mm with dynamic quant for
             # "2d x 3d with offsets" case (used for routed experts).
             # Otherwise, fall back to regular grouped mm.
@@ -86,7 +86,9 @@ def __torch_function__(cls, func, types, args, kwargs={}):
             A_is_2d = A.dim() == 2
             B_is_3d = B.dim() == 3
             has_offs = kwargs.get(cls.offs_arg_name) is not None
-            if A_is_2d and B_is_3d and has_offs:
+            logger.info(f"A.shape={A.shape}, B.shape={B.shape}, has_offs={has_offs}")
+
+            if A_is_2d and B_is_3d:
                 return _scaled_grouped_mm(
                     *args,
                     **kwargs,
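
Context for the two hunks above (a sketch under assumptions, not torchao code): the override works through the __torch_function__ protocol, which lets a tensor subclass intercept torch-level calls and match the op by func.__name__, the same set-based check this commit introduces. The class, op names, and print call below are hypothetical stand-ins for ScaledGroupedMMTensor and the "_grouped_mm" variants.

import torch


class InterceptingTensor(torch.Tensor):
    # Ops matched by name, mirroring the set-based check in the diff.
    intercepted_func_names = {"mm", "mm.default"}  # stand-ins for the "_grouped_mm" names

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if func.__name__ in cls.intercepted_func_names:
            # The real class would route to _scaled_grouped_mm here after the
            # "2d x 3d" shape check; this sketch only reports the interception.
            print(f"intercepted {func.__name__}")
        # Default dispatch for everything else (and after the hook).
        return super().__torch_function__(func, types, args, kwargs)


# Usage: torch.mm on the subclass hits the hook, then runs normally.
a = torch.randn(2, 3).as_subclass(InterceptingTensor)
b = torch.randn(3, 4).as_subclass(InterceptingTensor)
out = torch.mm(a, b)
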
@@ -133,7 +135,7 @@ def unwrap(t):
        )

    def __repr__(self):
-        return f"ScaledGroupedMMTensor(data={self._data}, dtype={self._dtype})"
+        return f"ScaledGroupedMMTensor(data.dtype={self._data.dtype}, self.dtype={self._dtype})"

    def __tensor_flatten__(self):
        return ["_data"], {"_dtype": self._dtype}
@@ -171,7 +173,6 @@ def fsdp_post_all_gather(
        logger.debug(f"fsdp_post_all_gather: data.dtype={data.dtype}, param_dtype: {param_dtype}")

        if out is not None:
-            #with _unsafe_preserve_version_counter(out):
            with torch.no_grad():
                out.copy_(data)
            return
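
A note on the last hunk (a sketch under assumed context, not torchao code): when FSDP passes a preallocated out tensor to fsdp_post_all_gather, the all-gathered data is copied into it in place, and wrapping the copy in torch.no_grad() keeps that mutation outside autograd. The helper and tensors below are hypothetical illustrations of that pattern.

import torch


def copy_into_preallocated(out: torch.Tensor, data: torch.Tensor) -> None:
    # In-place copy outside of autograd: without no_grad, copying into a leaf
    # tensor that requires grad raises a RuntimeError, and the copy would be
    # recorded on the graph.
    with torch.no_grad():
        out.copy_(data)


# A preallocated bfloat16 parameter-like buffer and freshly gathered data.
out = torch.empty(4, dtype=torch.bfloat16, requires_grad=True)
data = torch.arange(4, dtype=torch.float32)
copy_into_preallocated(out, data)  # copy_ also casts float32 -> bfloat16 here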