File tree Expand file tree Collapse file tree 2 files changed +0
-8
lines changed
torchao/prototype/moe_training Expand file tree Collapse file tree 2 files changed +0
-8
lines changed Original file line number Diff line number Diff line change @@ -40,11 +40,7 @@ def _scaled_grouped_mm(
40
40
offs (int32 torch.Tensor): The offsets to use to mark the starting index of each group along dim0 of the A tensor.
41
41
out_dtype (Optional[torch.dtype]): The dtype of the output tensor. Currently only torch.bfloat16 is supported.
42
42
"""
43
- < << << << HEAD
44
43
logger .info ("Using scaled_grouped_mm" )
45
- == == == =
46
- logger .info ("Using differentiable _scaled_grouped_mm" )
47
- >> >> >> > eb2dd3e0 (fix dtype bug )
48
44
return _Float8GroupedMM .apply (
49
45
A ,
50
46
B_t ,
Original file line number Diff line number Diff line change 19
19
20
20
logger : logging .Logger = logging .getLogger (__name__ )
21
21
22
- < << << << HEAD
23
- == == == =
24
-
25
- >> >> >> > eb2dd3e0 (fix dtype bug )
26
22
_ops_to_preserve_subclass = {
27
23
torch .ops .aten .empty_like .default ,
28
24
torch .ops .aten .new_zeros .default ,
You can’t perform that action at this time.
0 commit comments