Skip to content

Commit dbb4d2a

Browse files
handle out != None
1 parent bb9626e commit dbb4d2a

File tree

1 file changed

+4
-0
lines changed

1 file changed

+4
-0
lines changed

torchao/prototype/moe_training/scaled_grouped_mm.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,11 @@ def _scaled_grouped_mm(
4040
offs (int32 torch.Tensor): The offsets to use to mark the starting index of each group along dim0 of the A tensor.
4141
out_dtype (Optional[torch.dtype]): The dtype of the output tensor. Currently only torch.bfloat16 is supported.
4242
"""
43+
<<<<<<< HEAD
4344
# logger.info("Using scaled_grouped_mm")
45+
=======
46+
#logger.info("Using scaled_grouped_mm")
47+
>>>>>>> 6ca070de (handle out != None)
4448
return _Float8GroupedMM.apply(
4549
A,
4650
B_t,

0 commit comments

Comments
 (0)