@@ -1,11 +1,20 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD 3-Clause license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ import logging
from typing import Any, Optional, Tuple

import torch
import torch.utils._pytree as pytree
from torch._prims_common import suggest_memory_format

from torchao.prototype.moe_training import _scaled_grouped_mm

+ logger: logging.Logger = logging.getLogger(__name__)
+
_ops_to_preserve_subclass = {
    torch.ops.aten.empty_like.default,
    torch.ops.aten.new_zeros.default,
@@ -27,7 +36,7 @@ class ScaledGroupedMMTensor(torch.Tensor):
    differentiable _scaled_grouped_mm autograd function.
    """

-     grouped_mm_func_name = "_grouped_mm"
+     grouped_mm_func_names = {"_grouped_mm", "_grouped_mm.default"}
    offs_arg_name = "offs"

    @staticmethod
@@ -57,7 +66,7 @@ def __init__(
    @classmethod
    def __torch_function__(cls, func, types, args, kwargs={}):
        # override the grouped mm op to use the differentiable _scaled_grouped_mm
-         if func.__name__ == cls.grouped_mm_func_name:
+         if func.__name__ in cls.grouped_mm_func_names:
            # Use torchao scaled grouped mm with dynamic quant for
            # "2d x 3d with offsets" case (used for routed experts).
            # Otherwise, fall back to regular grouped mm.
@@ -69,7 +78,9 @@ def __torch_function__(cls, func, types, args, kwargs={}):
            A_is_2d = A.dim() == 2
            B_is_3d = B.dim() == 3
            has_offs = kwargs.get(cls.offs_arg_name) is not None
-             if A_is_2d and B_is_3d and has_offs:
+             logger.info(f"A.shape={A.shape}, B.shape={B.shape}, has_offs={has_offs}")
+
+             if A_is_2d and B_is_3d:
                return _scaled_grouped_mm(
                    *args,
                    **kwargs,
@@ -107,7 +118,8 @@ def __torch_dispatch__(cls, func, types, args, kwargs={}):
        )

    def fsdp_pre_all_gather(self, mesh):
-         return (self._data,), ()
+         metadata = ()
+         return (self._data,), metadata

    def fsdp_post_all_gather(
        self,
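
For readers who haven't used tensor subclasses, here is a minimal, self-contained sketch of the `__torch_function__` interception pattern the diff relies on: an op is matched by `func.__name__` against a set of names (the diff now matches both `"_grouped_mm"` and its `"_grouped_mm.default"` overload) and either rerouted or passed through. `InterceptingTensor` and the `print` stand-in for `logger.info` are illustrative only, not torchao code.

import torch

class InterceptingTensor(torch.Tensor):
    # Match both the op name and its .default overload name, mirroring
    # grouped_mm_func_names in the diff above.
    target_func_names = {"mm", "mm.default"}

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if getattr(func, "__name__", None) in cls.target_func_names:
            # Stand-in for the logger.info(...) call added in the diff;
            # a real override would dispatch to a custom kernel here.
            print(f"intercepted {func.__name__}")
        # Fall through to the default implementation for every op.
        with torch._C.DisableTorchFunctionSubclass():
            return func(*args, **kwargs)

a = torch.randn(4, 8).as_subclass(InterceptingTensor)
b = torch.randn(8, 2)
out = torch.mm(a, b)  # prints "intercepted mm", then runs the regular mm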
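
On the `fsdp_pre_all_gather` change: the return shape is unchanged, the commit only names the second element. Under FSDP2's tensor-subclass extension point, the method returns a pair of (tensors for FSDP to all-gather, opaque metadata), and that metadata is later handed back to `fsdp_post_all_gather`. Binding `metadata = ()` makes the empty second element's role explicit.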