Skip to content

Commit f673fb9

Browse files
fix dtype bug
1 parent b51d773 commit f673fb9

File tree

2 files changed

+11
-0
lines changed

2 files changed

+11
-0
lines changed

torchao/prototype/moe_training/scaled_grouped_mm.py

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -40,7 +40,11 @@ def _scaled_grouped_mm(
4040
offs (int32 torch.Tensor): The offsets to use to mark the starting index of each group along dim0 of the A tensor.
4141
out_dtype (Optional[torch.dtype]): The dtype of the output tensor. Currently only torch.bfloat16 is supported.
4242
"""
43+
<<<<<<< HEAD
4344
logger.info("Using scaled_grouped_mm")
45+
=======
46+
logger.info("Using differentiable _scaled_grouped_mm")
47+
>>>>>>> eb2dd3e0 (fix dtype bug)
4448
return _Float8GroupedMM.apply(
4549
A,
4650
B_t,

torchao/prototype/moe_training/tensor.py

Lines changed: 7 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -4,6 +4,7 @@
44
# This source code is licensed under the BSD 3-Clause license found in the
55
# LICENSE file in the root directory of this source tree.
66

7+
import logging
78
from typing import Any, Optional, Tuple
89

910
import torch
@@ -18,6 +19,10 @@
1819

1920
logger: logging.Logger = logging.getLogger(__name__)
2021

22+
<<<<<<< HEAD
23+
=======
24+
25+
>>>>>>> eb2dd3e0 (fix dtype bug)
2126
_ops_to_preserve_subclass = {
2227
torch.ops.aten.empty_like.default,
2328
torch.ops.aten.new_zeros.default,
@@ -96,6 +101,7 @@ def __torch_function__(cls, func, types, args, kwargs={}):
96101

97102
@classmethod
98103
def __torch_dispatch__(cls, func, types, args, kwargs={}):
104+
logger.debug(f"{func.__name__}, args={args}, kwargs={kwargs}")
99105
# detach is special case
100106
if func == torch.ops.aten.detach.default:
101107
return ScaledGroupedMMTensor(args[0]._data, args[0]._dtype)
@@ -135,6 +141,7 @@ def __repr__(self):
135141
def __tensor_flatten__(self):
136142
return ["_data"], {"_dtype": self._dtype}
137143

144+
138145
@staticmethod
139146
def __tensor_unflatten__(inner_tensors, flatten_spec, outer_size, outer_stride):
140147
return ScaledGroupedMMTensor(

0 commit comments

Comments
 (0)