
Commit b6e08aa

work
1 parent f673fb9 commit b6e08aa

2 files changed: +0 -8 lines changed

torchao/prototype/moe_training/scaled_grouped_mm.py

Lines changed: 0 additions & 4 deletions
@@ -40,11 +40,7 @@ def _scaled_grouped_mm(
         offs (int32 torch.Tensor): The offsets to use to mark the starting index of each group along dim0 of the A tensor.
         out_dtype (Optional[torch.dtype]): The dtype of the output tensor. Currently only torch.bfloat16 is supported.
     """
-<<<<<<< HEAD
     logger.info("Using scaled_grouped_mm")
-=======
-    logger.info("Using differentiable _scaled_grouped_mm")
->>>>>>> eb2dd3e0 (fix dtype bug)
     return _Float8GroupedMM.apply(
         A,
         B_t,
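
For context, a minimal sketch of how the resolved _scaled_grouped_mm might be called, based only on the parameter names and docstring visible in this hunk (A, B_t, offs, out_dtype). The shapes, the offset values, and the use of a CUDA device below are illustrative assumptions, not taken from the commit.

# Hedged sketch: shapes, device, and offset values are assumptions for
# illustration; only the parameter names (A, B_t, offs, out_dtype) and the
# docstring semantics come from the hunk above.
import torch
from torchao.prototype.moe_training.scaled_grouped_mm import _scaled_grouped_mm

num_groups, total_tokens, dim, out_dim = 4, 128, 64, 256

# A stacks the tokens of all groups along dim0; per the docstring above, offs
# marks the starting index of each group along dim0 of A.
A = torch.randn(total_tokens, dim, dtype=torch.bfloat16, device="cuda", requires_grad=True)
B_t = torch.randn(num_groups, dim, out_dim, dtype=torch.bfloat16, device="cuda", requires_grad=True)
offs = torch.tensor([0, 32, 64, 96], dtype=torch.int32, device="cuda")

# Per the docstring, only torch.bfloat16 is currently supported as out_dtype.
out = _scaled_grouped_mm(A, B_t, offs=offs, out_dtype=torch.bfloat16)
out.sum().backward()  # differentiable: forward/backward go through _Float8GroupedMM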

torchao/prototype/moe_training/tensor.py

Lines changed: 0 additions & 4 deletions
@@ -19,10 +19,6 @@

 logger: logging.Logger = logging.getLogger(__name__)

-<<<<<<< HEAD
-=======
-
->>>>>>> eb2dd3e0 (fix dtype bug)
 _ops_to_preserve_subclass = {
     torch.ops.aten.empty_like.default,
     torch.ops.aten.new_zeros.default,
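
As an aside, a set like _ops_to_preserve_subclass is typically consulted from a tensor subclass's __torch_dispatch__ to decide which aten ops should return the subclass rather than a plain tensor. Below is a minimal, generic sketch of that pattern; the WrapperTensor class and its unwrap/re-wrap logic are illustrative assumptions and are not the subclass defined in tensor.py.

# Hedged sketch of the general pattern only: WrapperTensor and its
# unwrap/re-wrap logic are assumptions, not the code in tensor.py.
import torch
from torch.utils._pytree import tree_map

_ops_to_preserve_subclass = {
    torch.ops.aten.empty_like.default,
    torch.ops.aten.new_zeros.default,
}

class WrapperTensor(torch.Tensor):
    @staticmethod
    def __new__(cls, data: torch.Tensor):
        return torch.Tensor._make_wrapper_subclass(
            cls, data.shape, dtype=data.dtype, device=data.device
        )

    def __init__(self, data: torch.Tensor):
        self._data = data

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}

        def unwrap(t):
            return t._data if isinstance(t, WrapperTensor) else t

        # Run the op on the plain inner tensors.
        out = func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs))
        # Re-wrap the output only for ops explicitly chosen to preserve the subclass.
        if func in _ops_to_preserve_subclass and isinstance(out, torch.Tensor):
            return WrapperTensor(out)
        return out

x = WrapperTensor(torch.randn(4, 4))
print(type(torch.empty_like(x)).__name__)  # WrapperTensor: op is in the preserve set
print(type(x + x).__name__)                # Tensor: aten.add is not in the set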
