2 files changed (+1, -8 lines) in torchao/prototype/moe_training

@@ -40,16 +40,8 @@ def _scaled_grouped_mm(
        offs (int32 torch.Tensor): The offsets to use to mark the starting index of each group along dim0 of the A tensor.
        out_dtype (Optional[torch.dtype]): The dtype of the output tensor. Currently only torch.bfloat16 is supported.
    """
-<<<<<<< HEAD
-<<<<<<< HEAD
-    # logger.info("Using scaled_grouped_mm")
-=======
-    #logger.info("Using scaled_grouped_mm")
->>>>>>> 6ca070de (handle out != None)
-=======
    # TODO: Remove once prototype is more mature. This is currently very useful for development and debugging.
    logger.info("Using scaled_grouped_mm")
->>>>>>> 2f3bb137 (add tp support for fp8 moe training)
    return _Float8GroupedMM.apply(
        A,
        B_t,
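
For orientation, a minimal, hypothetical usage sketch of the function this hunk touches; the import path, tensor shapes/layouts, and the offs values below are assumptions drawn from the docstring above, not verified against the full file.

import torch
# Assumed module path; the diff only shows the torchao/prototype/moe_training directory.
from torchao.prototype.moe_training.scaled_grouped_mm import _scaled_grouped_mm

n_groups, total_m, k, n = 4, 128, 256, 512
# Assumed layout: A stacks all groups' rows along dim0, B_t holds one (k, n) matrix per group.
A = torch.randn(total_m, k, dtype=torch.bfloat16, device="cuda", requires_grad=True)
B_t = torch.randn(n_groups, k, n, dtype=torch.bfloat16, device="cuda", requires_grad=True)
# Per the docstring, offs marks the starting index of each group along dim0 of A (assumed convention).
offs = torch.tensor([0, 32, 64, 96], dtype=torch.int32, device="cuda")

out = _scaled_grouped_mm(A, B_t, offs=offs, out_dtype=torch.bfloat16)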

@@ -183,6 +183,7 @@ def fsdp_post_all_gather(
            f"{out_data.dtype} {param_dtype}"
        )
        out_data.copy_(data)
+
        return

    # For training step 0, out=None, so we need to return a new ScaledGroupedMMTensor.
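
This hunk only inserts a blank line after out_data.copy_(data) in the branch that reuses a preallocated out. For context, a minimal sketch of that branching pattern, assuming the FSDP2 tensor-subclass extension hook signature; the step-0 wrapper below is an illustrative stand-in for ScaledGroupedMMTensor, not the file's actual code.

from typing import Any, Optional, Tuple

import torch


def fsdp_post_all_gather_sketch(
    all_gather_outputs: Tuple[torch.Tensor, ...],
    metadata: Any,
    param_dtype: torch.dtype,
    *,
    out: Optional[torch.Tensor] = None,
):
    (data,) = all_gather_outputs
    if out is not None:
        # Steady state (step > 0): FSDP passes back the output it allocated
        # earlier, so validate the dtype and copy the gathered data in place.
        assert out.dtype == param_dtype, f"{out.dtype} {param_dtype}"
        out.copy_(data)
        return  # returning None signals that the existing `out` was reused
    # Training step 0: out is None, so return a new wrapper (the real code
    # returns a ScaledGroupedMMTensor here) plus the inner tensors to keep.
    output = data.to(param_dtype)  # stand-in for the real wrapper construction
    return output, (data,)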