[mxfp8 moe training] per group scale conversion to blocked format with groups along K dim (for 2d2d grouped gemm) #2956
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/2956. Note: links to docs will display an error until the docs builds have been completed.
✅ No failures as of commit 213f19b with merge base f1acc1e. (This comment was automatically generated by Dr. CI and updates every 15 minutes.)
Force-pushed from df502e0 to 25f4c23 (compare)
```python
_, output_group_offsets = compute_per_group_blocked_scale_offsets_2d2d_lhs(
    input_group_offsets
)
assert torch.allclose(output_group_offsets, ref_start_cols_after_padding), (
```
torch.equal?
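A minimal sketch of the suggested change (assuming the offsets are integer tensors, so an exact comparison is appropriate; the helper and reference names come from the test snippet above):

```python
# Sketch only: torch.equal checks exact element-wise equality (and matching shape),
# which is the right check for integer offset tensors, whereas torch.allclose is
# intended for floating-point tolerance comparisons.
_, output_group_offsets = compute_per_group_blocked_scale_offsets_2d2d_lhs(
    input_group_offsets
)
assert torch.equal(output_group_offsets, ref_start_cols_after_padding), (
    f"expected {ref_start_cols_after_padding}, got {output_group_offsets}"
)
```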
```python
    return blocked_scales, start_row_after_padding


def torch_to_blocked_per_group_2d2d_lhs(
```
nit: maybe something like 2d_kmajor, since this only operates on 1 tensor
Hmm for 2d-3d and 2d-2d, the 2d tensors are all K major aren't they? I was actually thinking we should just remove the _lhs suffix from the name, since I think the kernel will actually work for both LHS and RHS operands in the 2d2d grouped gemm, given shapes (M, total_K) and (N, total_K)
How about: torch_to_blocked_per_group_for_2d2d (added "for")
Renamed this func and others for clarity that the difference is about grouping along M dim vs K dim
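To illustrate the point above about the conversion applying to both operands, a hedged usage sketch (the function name and return values here are illustrative, not the exact API added in this PR):

```python
# Hypothetical sketch: with groups along total_K, both operands' scales are
# K-major 2d tensors, so the same per-group blocked conversion applies to each.
# A: (M, total_K)  -> A_scales: (M, total_K // 32)
# B: (N, total_K)  -> B_scales: (N, total_K // 32)
A_scales_blocked, group_start_cols = to_blocked_per_group_along_k(A_scales, scale_group_offsets)
B_scales_blocked, _ = to_blocked_per_group_along_k(B_scales, scale_group_offsets)
```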
```python
        orig_offsets + group_pid - 1, mask=group_pid > 0, other=0
    )
    input_group_end_col = tl.load(
        orig_offsets + group_pid, mask=group_pid < num_groups, other=0
```
nit: we don't need a mask for this load, right?
Updated
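For context, a hedged sketch of what dropping the mask might look like, assuming the kernel is launched with exactly one program per group so `group_pid < num_groups` always holds (variable names may differ from the actual kernel):

```python
# Sketch only: with a launch grid of (num_groups,), group_pid is always in range,
# so loading this group's end offset needs no mask. The previous-group load still
# needs one, because group_pid - 1 is out of bounds for the first group.
input_group_start_col = tl.load(
    orig_offsets + group_pid - 1, mask=group_pid > 0, other=0
)
input_group_end_col = tl.load(orig_offsets + group_pid)
```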
```python
        orig_offsets + group_pid, mask=group_pid < num_groups, other=0
    )
    # Output scales start row we will begin writing to
    output_group_start_col = tl.load(
```
ditto
```python
        output_scales_group_offsets + group_pid, mask=group_pid < num_groups, other=0
    )

    # Calculate destination indices for each row and col in block swizzled layout.
```
can you make this a helper jit function and reuse it elsewhere? It's fine if you stack that commit on this.
good idea, done
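For reference, a hedged sketch of what such a helper could look like. It assumes the block-swizzled layout produced by the reference `to_blocked` conversion (128x4 tiles stored as contiguous 32x16 chunks); the name and signature are illustrative, not the actual helper added in this PR:

```python
import triton
import triton.language as tl


@triton.jit
def _dest_offset_in_blocked_layout(row, col, padded_num_cols):
    # Map (row, col) in the unswizzled scale matrix to its flat destination offset
    # in the block-swizzled layout: the matrix is tiled into 128x4 blocks, and
    # within each tile, element (r, c) lands at (r % 32) * 16 + (r // 32) * 4 + c.
    tile_row = row // 128
    tile_col = col // 4
    n_col_tiles = padded_num_cols // 4
    tile_base = (tile_row * n_col_tiles + tile_col) * 512  # 128 * 4 elems per tile
    r_in_tile = row % 128
    c_in_tile = col % 4
    return tile_base + (r_in_tile % 32) * 16 + (r_in_tile // 32) * 4 + c_in_tile
```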
Force-pushed from 25f4c23 to 213f19b (compare)
@drisspg I finished addressing your comments, ready for another look.
Summary
Adds a kernel to convert per-group scales to blocked format when groups are along the K dim, for use in the 2d-2d MXFP8 grouped gemm (M, total_K) @ (total_K, N). The existing kernels handle groups along the total_M dimension in (total_M, K) @ (E, K, N).
LHS operand for 2d-3d MXFP8 grouped gemm
This is the existing kernel; the layout is much simpler in the 2d-3d case, where the groups are along M:
LHS operand for 2d-2d MXFP8 grouped gemm
When groups are along the scaled dim being contracted, the memory layout is more complicated, as we have to represent separate standalone "row of blocks major" layouts in subtensors that are part of a larger parent tensor.
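As a concrete illustration of the bookkeeping this requires, here is a hedged reference sketch (plain PyTorch, not the actual kernel or helper in this PR) of computing where each group's block-swizzled scales start, assuming each group's scale column count must be padded up to a multiple of 4 to form complete 128x4 blocks:

```python
import torch


def start_cols_after_padding(scale_group_end_cols: torch.Tensor) -> torch.Tensor:
    # scale_group_end_cols: cumulative end-column offsets of each group's scales
    # along the K dim (i.e., cumsum of K_g // 32 for block_size=32).
    zero = torch.zeros(1, dtype=scale_group_end_cols.dtype, device=scale_group_end_cols.device)
    group_cols = torch.diff(scale_group_end_cols, prepend=zero)
    # Round each group's column count up to a multiple of 4 so its sub-tensor
    # forms complete 128x4 blocks in the swizzled output.
    padded_cols = ((group_cols + 3) // 4) * 4
    ends = torch.cumsum(padded_cols, dim=0)
    return ends - padded_cols  # exclusive prefix sum = start column of each group
```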
Other
Removed the Mg param from the torch_to_blocked_per_group_2d function (used for 2d-3d grouped gemms), since it is not used.
Test plan
pytest test/prototype/moe_training/test_kernels.py -k test_mxfp8_per_group_blocked_scales_2d2d -s