Skip to content

Add nd_loop and Enable block_n tiling for all_gather_lhs_matmul #29822

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

hanzlfs
Copy link
Contributor

@hanzlfs hanzlfs commented Jun 27, 2025

  1. Enable block_n tiling
  2. use plgpu.nd_loop
    [PAIR] justinfu@google.com

Copy link
Collaborator

@justinjfu justinjfu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the fixes!

@google-ml-butler google-ml-butler bot added kokoro:force-run pull ready Ready for copybara import and testing labels Jun 27, 2025
@justinjfu justinjfu requested a review from apaszke June 27, 2025 21:11
@@ -150,7 +152,7 @@ def k_loop(idxs, lhs_smem, rhs_smem):
# We only delay release by 1 step, so we need to wait for the
# previous copies.
plgpu.wait_smem_to_gmem(1, wait_read_only=True)
k_loop(scratch_ref.at[scratch_slot], rhs_ref)
k_loop(scratch_ref.at[scratch_slot], rhs_ref.at[:,n_tile_slice])
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: Add a space for formatting: rhs_ref.at[:,n_tile_slice] -> rhs_ref.at[:, n_tile_slice]

Copy link
Member

@apaszke apaszke left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think that's the right way to do it. The gathers happen only on the M dimension and adding the N loop in the same place will perform the same gather n // block_n times. Instead, for every M chunk we gather, we should run an inner loop that steps over all the N blocks that need to be multiplied with it

@hanzlfs
Copy link
Contributor Author

hanzlfs commented Jun 30, 2025

I don't think that's the right way to do it. The gathers happen only on the M dimension and adding the N loop in the same place will perform the same gather n // block_n times. Instead, for every M chunk we gather, we should run an inner loop that steps over all the N blocks that need to be multiplied with it

Thanks I will update this part tmr. I met another issue, if I want to use a 2d mesh x, y: (2, 4), or even x, y : (1, 8), to use axis_name = (x, y) it will give me following error

    out = core_map_p.bind(*consts, jaxpr=jaxpr, mesh=mesh,
jax._src.source_info_util.JaxStackTraceBeforeTransformation: ValueError: Failed to recompute the async_copy peer id on the host``` 
what's the corresponding change would be needed? cc @justinjfu 

@apaszke
Copy link
Member

apaszke commented Jun 30, 2025

Could you please send a small PR with the change needed to reproduce the problem? I can take a look

@hanzlfs
Copy link
Contributor Author

hanzlfs commented Jun 30, 2025

Could you please send a small PR with the change needed to reproduce the problem? I can take a look
This one should do
#29849

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
kokoro:force-run pull ready Ready for copybara import and testing
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants