
Commit efd993f

make offs optional for scaled grouped mm
1 parent 2898903 commit efd993f

File tree: 2 files changed (+22, -8 lines)


torchao/prototype/moe_training/scaled_grouped_mm.py

Lines changed: 6 additions & 4 deletions
@@ -4,6 +4,7 @@
 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.
 
+import logging
 from typing import Optional
 
 import torch
@@ -18,11 +19,13 @@
     _is_column_major,
 )
 
+logger: logging.Logger = logging.getLogger(__name__)
+
 
 def _scaled_grouped_mm(
     A: torch.Tensor,
     B_t: torch.Tensor,
-    offs: torch.Tensor,
+    offs: Optional[torch.Tensor] = None,
     out_dtype: Optional[torch.dtype] = torch.bfloat16,
 ) -> torch.Tensor:
     """
@@ -38,6 +41,7 @@ def _scaled_grouped_mm(
         out_dtype (Optional[torch.dtype]): The dtype of the output tensor. Currently only torch.bfloat16 is supported.
         use_triton_for_per_group_scales (bool): Whether to use custom triton kernels to compute per-group scales. Default is True.
     """
+    logger.info("Using differentiable _scaled_grouped_mm")
     return _Float8GroupedMM.apply(
         A,
         B_t,
@@ -54,9 +58,8 @@ def forward(
         ctx,
         A: torch.Tensor,
         B_t: torch.Tensor,
-        offs: torch.Tensor,
+        offs: Optional[torch.Tensor] = None,
         out_dtype: Optional[torch.dtype] = torch.bfloat16,
-        use_triton_for_per_group_scales: bool = True,
     ) -> torch.Tensor:
         # torchao _scaled_grouped_mm only supports A=2D, B=3D.
         assert A.ndim == 2, "A must be 2D"
@@ -76,7 +79,6 @@ def forward(
         assert B_t.dtype == torch.float32 or B_t.dtype == torch.bfloat16, (
            "B must be float32 or bfloat16"
        )
-        assert offs.dtype == torch.int32, "offs must be int32"
 
         # Assert A and B dims are compatible for a scaled grouped GEMM.
         assert A.size(-1) == B_t.size(-2), (
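
For context, a minimal usage sketch of the changed signature (not part of this commit). The import path follows the tensor.py diff below, but the shapes, device, and offsets are hypothetical, and actually running it assumes a PyTorch build and GPU where torch._scaled_grouped_mm is supported.

import torch

from torchao.prototype.moe_training import _scaled_grouped_mm

# Hypothetical sizes: 8 experts with square dim-512 weights, 1024 routed tokens.
num_experts, dim, total_tokens = 8, 512, 1024
A = torch.randn(total_tokens, dim, dtype=torch.bfloat16, device="cuda", requires_grad=True)
B_t = torch.randn(num_experts, dim, dim, dtype=torch.bfloat16, device="cuda", requires_grad=True)

# Group-end offsets along the token dimension for the routed-experts case (int32).
offs = torch.arange(
    total_tokens // num_experts,
    total_tokens + 1,
    total_tokens // num_experts,
    dtype=torch.int32,
    device="cuda",
)

# offs is still passed here as a keyword; after this commit it is declared
# Optional, so callers without offsets no longer need a positional placeholder.
out = _scaled_grouped_mm(A, B_t, offs=offs, out_dtype=torch.bfloat16)
out.sum().backward()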

torchao/prototype/moe_training/tensor.py

Lines changed: 16 additions & 4 deletions
@@ -1,3 +1,10 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD 3-Clause license found in the
+# LICENSE file in the root directory of this source tree.
+
+import logging
 from typing import Any, Optional, Tuple
 
 import torch
@@ -6,6 +13,8 @@
 
 from torchao.prototype.moe_training import _scaled_grouped_mm
 
+logger: logging.Logger = logging.getLogger(__name__)
+
 _ops_to_preserve_subclass = {
     torch.ops.aten.empty_like.default,
     torch.ops.aten.new_zeros.default,
@@ -27,7 +36,7 @@ class ScaledGroupedMMTensor(torch.Tensor):
     differentiable _scaled_grouped_mm autograd function.
     """
 
-    grouped_mm_func_name = "_grouped_mm"
+    grouped_mm_func_names = {"_grouped_mm", "_grouped_mm.default"}
     offs_arg_name = "offs"
 
     @staticmethod
@@ -57,7 +66,7 @@ def __init__(
     @classmethod
     def __torch_function__(cls, func, types, args, kwargs={}):
         # override the grouped mm op to use the differentiable _scaled_grouped_mm
-        if func.__name__ == cls.grouped_mm_func_name:
+        if func.__name__ in cls.grouped_mm_func_names:
             # Use torchao scaled grouped mm with dynamic quant for
             # "2d x 3d with offsets" case (used for routed experts).
             # Otherwise, fall back to regular grouped mm.
@@ -69,7 +78,9 @@ def __torch_function__(cls, func, types, args, kwargs={}):
             A_is_2d = A.dim() == 2
             B_is_3d = B.dim() == 3
             has_offs = kwargs.get(cls.offs_arg_name) is not None
-            if A_is_2d and B_is_3d and has_offs:
+            logger.info(f"A.shape={A.shape}, B.shape={B.shape}, has_offs={has_offs}")
+
+            if A_is_2d and B_is_3d:
                 return _scaled_grouped_mm(
                     *args,
                     **kwargs,
@@ -107,7 +118,8 @@ def __torch_dispatch__(cls, func, types, args, kwargs={}):
         )
 
     def fsdp_pre_all_gather(self, mesh):
-        return (self._data,), ()
+        metadata = ()
+        return (self._data,), metadata
 
     def fsdp_post_all_gather(
         self,
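
To make the dispatch change concrete, here is a standalone sketch of the routing condition from __torch_function__ above; _route_to_scaled_grouped_mm is a hypothetical helper written for illustration only and is not part of torchao.

import torch


def _route_to_scaled_grouped_mm(A: torch.Tensor, B: torch.Tensor, offs=None) -> bool:
    # Mirrors the check above: before this commit it was
    # `A_is_2d and B_is_3d and has_offs`; offs is now optional.
    return A.dim() == 2 and B.dim() == 3


A = torch.randn(16, 32)       # 2D activations (tokens, dim)
B = torch.randn(4, 32, 64)    # 3D stacked expert weights (experts, dim, out_dim)
offs = torch.tensor([4, 8, 12, 16], dtype=torch.int32)

assert _route_to_scaled_grouped_mm(A, B)             # now routed even without offs
assert _route_to_scaled_grouped_mm(A, B, offs=offs)  # still routed when offs is given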
