Commit ab79bce

make offs optional for scaled grouped mm
1 parent: 02f061c

2 files changed (+8, -10 lines)


torchao/prototype/moe_training/scaled_grouped_mm.py

Lines changed: 2 additions & 5 deletions

@@ -19,14 +19,13 @@
     _is_column_major,
 )

-
 logger: logging.Logger = logging.getLogger(__name__)


 def _scaled_grouped_mm(
     A: torch.Tensor,
     B_t: torch.Tensor,
-    offs: torch.Tensor,
+    offs: Optional[torch.Tensor] = None,
     out_dtype: Optional[torch.dtype] = torch.bfloat16,
 ) -> torch.Tensor:
     """
@@ -58,9 +57,8 @@ def forward(
         ctx,
         A: torch.Tensor,
         B_t: torch.Tensor,
-        offs: torch.Tensor,
+        offs: Optional[torch.Tensor] = None,
         out_dtype: Optional[torch.dtype] = torch.bfloat16,
-        use_triton_for_per_group_scales: bool = True,
     ) -> torch.Tensor:
         # torchao _scaled_grouped_mm only supports A=2D, B=3D.
         assert A.ndim == 2, "A must be 2D"
@@ -80,7 +78,6 @@ def forward(
         assert B_t.dtype == torch.float32 or B_t.dtype == torch.bfloat16, (
             "B must be float32 or bfloat16"
         )
-        assert offs.dtype == torch.int32, "offs must be int32"

         # Assert A and B dims are compatible for a scaled grouped GEMM.
         assert A.size(-1) == B_t.size(-2), (
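
For context, a minimal usage sketch of the updated signature. The shapes, device, and the no-offs call are illustrative assumptions, not taken from this commit, and running it requires a GPU that supports the underlying scaled grouped GEMM.

import torch
from torchao.prototype.moe_training.scaled_grouped_mm import _scaled_grouped_mm

# Hypothetical routed-experts shapes: A is (tokens, K), B_t is (experts, K, N),
# so A.size(-1) == B_t.size(-2) as the asserts above require.
A = torch.randn(128, 256, dtype=torch.bfloat16, device="cuda")
B_t = torch.randn(4, 256, 512, dtype=torch.bfloat16, device="cuda")
offs = torch.tensor([32, 64, 96, 128], dtype=torch.int32, device="cuda")

# Existing call style: per-group offsets supplied explicitly.
out = _scaled_grouped_mm(A, B_t, offs=offs, out_dtype=torch.bfloat16)

# After this commit, offs defaults to None and can be omitted at the Python
# level; whether a given backend kernel accepts a missing offs is outside
# the scope of this diff.
out_no_offs = _scaled_grouped_mm(A, B_t)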

torchao/prototype/moe_training/tensor.py

Lines changed: 6 additions & 5 deletions

@@ -41,7 +41,7 @@ class ScaledGroupedMMTensor(torch.Tensor):
     differentiable _scaled_grouped_mm autograd function.
     """

-    grouped_mm_func_name = "_grouped_mm"
+    grouped_mm_func_names = {"_grouped_mm", "_grouped_mm.default"}
     offs_arg_name = "offs"

     @staticmethod
@@ -74,7 +74,7 @@ def __init__(
     @classmethod
     def __torch_function__(cls, func, types, args, kwargs={}):
         # override the grouped mm op to use the differentiable _scaled_grouped_mm
-        if func.__name__ == cls.grouped_mm_func_name:
+        if func.__name__ in cls.grouped_mm_func_names:
             # Use torchao scaled grouped mm with dynamic quant for
             # "2d x 3d with offsets" case (used for routed experts).
             # Otherwise, fall back to regular grouped mm.
@@ -86,7 +86,9 @@ def __torch_function__(cls, func, types, args, kwargs={}):
             A_is_2d = A.dim() == 2
             B_is_3d = B.dim() == 3
             has_offs = kwargs.get(cls.offs_arg_name) is not None
-            if A_is_2d and B_is_3d and has_offs:
+            logger.info(f"A.shape={A.shape}, B.shape={B.shape}, has_offs={has_offs}")
+
+            if A_is_2d and B_is_3d:
                 return _scaled_grouped_mm(
                     *args,
                     **kwargs,
@@ -133,7 +135,7 @@ def unwrap(t):
         )

     def __repr__(self):
-        return f"ScaledGroupedMMTensor(data={self._data}, dtype={self._dtype})"
+        return f"ScaledGroupedMMTensor(data.dtype={self._data.dtype}, self.dtype={self._dtype})"

     def __tensor_flatten__(self):
         return ["_data"], {"_dtype": self._dtype}
@@ -171,7 +173,6 @@ def fsdp_post_all_gather(
         logger.debug(f"fsdp_post_all_gather: data.dtype={data.dtype}, param_dtype: {param_dtype}")

         if out is not None:
-            #with _unsafe_preserve_version_counter(out):
             with torch.no_grad():
                 out.copy_(data)
             return
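
To summarize the behavioral change in the dispatch, here is a small, self-contained paraphrase of the updated routing condition. The helper function below is hypothetical and only mirrors the logic shown in the diff; the real check lives in ScaledGroupedMMTensor.__torch_function__.

# Hypothetical helper mirroring the __torch_function__ routing above.
GROUPED_MM_FUNC_NAMES = {"_grouped_mm", "_grouped_mm.default"}

def _should_route_to_scaled_grouped_mm(func_name: str, a_ndim: int, b_ndim: int, has_offs: bool) -> bool:
    # Both the op name and its ".default" overload are intercepted.
    if func_name not in GROUPED_MM_FUNC_NAMES:
        return False
    # Previously the routing also required has_offs; now that offs is optional,
    # only the 2D x 3D (routed experts) layout is checked and has_offs is merely logged.
    return a_ndim == 2 and b_ndim == 3

assert _should_route_to_scaled_grouped_mm("_grouped_mm", 2, 3, has_offs=False)          # newly routed
assert _should_route_to_scaled_grouped_mm("_grouped_mm.default", 2, 3, has_offs=True)   # as before
assert not _should_route_to_scaled_grouped_mm("_grouped_mm", 3, 3, has_offs=True)       # falls back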
