We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent bb9626e commit dbb4d2aCopy full SHA for dbb4d2a
torchao/prototype/moe_training/scaled_grouped_mm.py
@@ -40,7 +40,11 @@ def _scaled_grouped_mm(
40
offs (int32 torch.Tensor): The offsets to use to mark the starting index of each group along dim0 of the A tensor.
41
out_dtype (Optional[torch.dtype]): The dtype of the output tensor. Currently only torch.bfloat16 is supported.
42
"""
43
+<<<<<<< HEAD
44
# logger.info("Using scaled_grouped_mm")
45
+=======
46
+ #logger.info("Using scaled_grouped_mm")
47
+>>>>>>> 6ca070de (handle out != None)
48
return _Float8GroupedMM.apply(
49
A,
50
B_t,
0 commit comments