
Conversation

@danielvegamyhre (Contributor) commented Sep 8, 2025

Summary

  • We have a Triton kernel that does per-group mxfp8 scale conversion to blocked format for 2d-3d grouped gemms, where the groups are along the total_M dimension in (total_M, K) @ (E, K, N).
  • We now need a Triton kernel that does per-group conversion for 2d-2d grouped gemms, where the groups are along the scaled dim being contracted: (M, total_K) @ (total_K, N). A rough reference sketch of the blocked layout follows below.
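For reference, the non-grouped conversion to blocked format rearranges a (rows, cols) e8m0 scale tensor into 128x4 tiles, each stored contiguously as a 32x16 block. Below is a minimal PyTorch sketch of that layout, assuming this is the swizzle the grouped kernels target; the function name is illustrative and this is not necessarily the exact torchao helper:

import torch

def to_blocked_reference(scales: torch.Tensor) -> torch.Tensor:
    # scales: (rows, cols), one e8m0 scale per 32 contiguous elements along K.
    rows, cols = scales.shape
    n_row_blocks = (rows + 127) // 128
    n_col_blocks = (cols + 3) // 4
    # Pad rows to a multiple of 128 and cols to a multiple of 4.
    padded = torch.zeros(
        n_row_blocks * 128, n_col_blocks * 4, dtype=scales.dtype, device=scales.device
    )
    padded[:rows, :cols] = scales
    # Split into 128x4 tiles (row-major over tiles), then store each tile as a 32x16 block.
    tiles = padded.view(n_row_blocks, 128, n_col_blocks, 4).permute(0, 2, 1, 3)
    return tiles.reshape(-1, 4, 32, 4).transpose(1, 2).reshape(-1, 32, 16).flatten()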

Memory layout

LHS operand for 2d-3d MXFP8 grouped gemm

This is the existing kernel, the layout is much simpler for the 2d-3d case where the groups are along M:

[Figure: memory layout of per-group blocked scales when groups are along M (2d-3d case)]

LHS operand for 2d-2d MXFP8 grouped gemm

When the groups are along the scaled dim being contracted, the memory layout is more complicated: we have to represent separate, standalone "row of blocks major" layouts in subtensors that are part of a larger parent tensor. A sketch of the per-group column offsets this requires follows the figure below.

[Figure: memory layout of per-group blocked scales when groups are along K (2d-2d case)]
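Because each group owns a contiguous slab of scale columns, the kernel also needs the starting column of each group in the blocked output, with each group's scale-column count rounded up to the 4-column tile granularity. A rough PyTorch sketch of that offset computation (hypothetical names; the padding granularity is an assumption based on the 128x4 tile layout above):

import torch

def blocked_scale_col_offsets_2d2d(group_end_offsets_k: torch.Tensor, block_size: int = 32) -> torch.Tensor:
    # group_end_offsets_k: cumulative end offsets of each group along total_K.
    zero = torch.zeros(1, dtype=group_end_offsets_k.dtype, device=group_end_offsets_k.device)
    group_sizes_k = torch.diff(group_end_offsets_k, prepend=zero)
    scale_cols = group_sizes_k // block_size          # one e8m0 scale per 32 elements along K
    padded_scale_cols = ((scale_cols + 3) // 4) * 4   # pad each group's scale cols to a multiple of 4
    # Cumulative sum gives the end col of each group after padding;
    # group i starts at offsets[i - 1] (0 for the first group).
    return torch.cumsum(padded_scale_cols, dim=0)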

Other

  • I also removed the Mg param from the torch_to_blocked_per_group_2d function (used for 2d-3d grouped gemms), since it is not used.

Test plan

  • pytest test/prototype/moe_training/test_kernels.py -k test_mxfp8_per_group_blocked_scales_2d2d -s

@pytorch-bot bot commented Sep 8, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/2956

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 213f19b with merge base f1acc1e:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla bot added the CLA Signed label Sep 8, 2025
@danielvegamyhre added the topic: not user facing label and removed the CLA Signed label Sep 8, 2025
@danielvegamyhre marked this pull request as draft September 8, 2025 16:36
@meta-cla bot added the CLA Signed label Sep 8, 2025
@danielvegamyhre marked this pull request as ready for review September 9, 2025 01:54
@danielvegamyhre changed the title from "[WIP] [mxfp8 moe training] blocked scale conversion for LHS of 2d-2d grouped gemm" to "[mxfp8 moe training] blocked scale conversion for LHS of 2d-2d grouped gemm" Sep 9, 2025
@danielvegamyhre changed the title from "[mxfp8 moe training] blocked scale conversion for LHS of 2d-2d grouped gemm" to "[mxfp8 moe training] per group scale conversion to blocked format for LHS of 2d-2d grouped gemm" Sep 9, 2025
_, output_group_offsets = compute_per_group_blocked_scale_offsets_2d2d_lhs(
input_group_offsets
)
assert torch.allclose(output_group_offsets, ref_start_cols_after_padding), (
Contributor:
torch.equal?
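Since the offsets are integer tensors, exact comparison is stricter than allclose and avoids tolerance semantics entirely; the suggested change would look something like this (tensor names as in the excerpt above):

assert torch.equal(output_group_offsets, ref_start_cols_after_padding)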

return blocked_scales, start_row_after_padding


def torch_to_blocked_per_group_2d2d_lhs(
Contributor:
nit: maybe something like 2d_kmajor, since this only operates on 1 tensor

@danielvegamyhre (Contributor, Author) Sep 10, 2025:
Hmm for 2d-3d and 2d-2d, the 2d tensors are all K major aren't they? I was actually thinking we should just remove the _lhs suffix from the name, since I think the kernel will actually work for both LHS and RHS operands in the 2d2d grouped gemm, given shapes (M, total_K) and (N, total_K)

How about: torch_to_blocked_per_group_for_2d2d (added "for")

@danielvegamyhre (Contributor, Author):
Renamed this func and others for clarity that the difference is about grouping along M dim vs K dim

orig_offsets + group_pid - 1, mask=group_pid > 0, other=0
)
input_group_end_col = tl.load(
orig_offsets + group_pid, mask=group_pid < num_groups, other=0
Contributor:
nit: we don't need a mask for this load, right?

@danielvegamyhre (Contributor, Author):
Updated
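For illustration, the simplification suggested above would drop the bounds mask on this load, roughly:

# group_pid is always a valid index for this load, so no mask/other is needed.
input_group_end_col = tl.load(orig_offsets + group_pid)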

orig_offsets + group_pid, mask=group_pid < num_groups, other=0
)
# Output scales start col we will begin writing to
output_group_start_col = tl.load(
Contributor:
ditto

output_scales_group_offsets + group_pid, mask=group_pid < num_groups, other=0
)

# Calculate destination indices for each row and col in block swizzled layout.
Contributor:
can you make this a helper jit function and reuse it elsewhere? it's fine if you stack that commit on this

@danielvegamyhre (Contributor, Author):
good idea, done
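For context, the extracted helper might look roughly like the sketch below; the name and signature are hypothetical, and the index math assumes the 128x4-tile / 32x16-block swizzle described earlier:

import triton
import triton.language as tl

@triton.jit
def _dest_index_block_swizzled(row, col, padded_cols):
    # Map (row, col) in the padded per-group scale matrix to its flat index in
    # the block-swizzled layout: 128x4 tiles in row-major tile order, each tile
    # stored contiguously as a 32x16 block.
    tiles_per_row = padded_cols // 4
    tile_id = (row // 128) * tiles_per_row + (col // 4)
    r_in_tile = row % 128
    c_in_tile = col % 4
    offset_in_tile = (r_in_tile % 32) * 16 + (r_in_tile // 32) * 4 + c_in_tile
    return tile_id * 512 + offset_in_tile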

@danielvegamyhre (Contributor, Author) commented:
@drisspg i finished addressing your comments, ready for another look

@danielvegamyhre changed the title from "[mxfp8 moe training] per group scale conversion to blocked format for LHS of 2d-2d grouped gemm" to "[mxfp8 moe training] per group scale conversion to blocked format with groups along K dim (for 2d2d grouped gemm)" Sep 11, 2025
@danielvegamyhre merged commit 14ca521 into main Sep 11, 2025
18 checks passed

Labels

CLA Signed, moe, mx, topic: not user facing

3 participants