Replies: 2 comments
@lucianyao Hi, thank you for your great suggestion. It's indeed a big bottleneck, but sadly it cannot be done in Triton.
We are looking into some solutions; these PRs could help you.
Hi Team,
I wonder if replacing
b_aw = tl.sum(tl.where(mask[:, None], b_Aw, 0), 0)
with
b_aw = b_Aw[i, :]
in fwd_prepare_wy_repr_kernel of wy_fast.py
could reduce complexity from O(BC^2) to O(BC), saving memory and compute effort?

Furthermore, I wonder if we could replace
b_aw = b_aw + tl.sum(b_aw[:, None] * b_Aw, 0) * (tl.arange(0, BC) < i)
with
b_aw = b_aw + tl.dot(b_aw, b_Aw)
to make the code more readable? The updated loop would look like:
for i in range(1, BC):
It seems to maintain correctness while simplifying the logic and potentially improving GPU performance.
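As a rough sanity check of the two equivalences outside Triton, here is a minimal pure-Python sketch. It assumes that b_Aw is a strictly lower-triangular BC x BC tile and that mask selects row i (mask = arange == i); both are assumptions made for illustration, and the helper names masked_row and vec_mat are hypothetical. It checks the algebra on a static tile only and says nothing about whether the rewritten expressions compile or run faster in an actual Triton kernel.

```python
import random

BC = 8  # small, hypothetical chunk size chosen only for this check
random.seed(0)

# Assumption: b_Aw is strictly lower-triangular (entries only where column < row).
b_Aw = [[random.random() if j < i else 0.0 for j in range(BC)] for i in range(BC)]


def masked_row(A, i):
    """Mirrors tl.sum(tl.where(mask[:, None], A, 0), 0) with mask = (arange == i)."""
    n = len(A)
    return [sum(A[k][j] if k == i else 0.0 for k in range(n)) for j in range(n)]


def vec_mat(v, A):
    """Mirrors tl.sum(v[:, None] * A, 0), i.e. the vector-matrix product v @ A."""
    n = len(A)
    return [sum(v[k] * A[k][j] for k in range(n)) for j in range(n)]


row_select_ok = True      # masked sum over rows == direct row read
mask_redundant_ok = True  # (arange < i) mask changes nothing

for i in range(1, BC):
    b_aw = masked_row(b_Aw, i)
    row_select_ok = row_select_ok and (b_aw == b_Aw[i])

    update = vec_mat(b_aw, b_Aw)
    masked_update = [u if j < i else 0.0 for j, u in enumerate(update)]
    mask_redundant_ok = mask_redundant_ok and (update == masked_update)

print(row_select_ok, mask_redundant_ok)
```

The second check passes because entry j of the product only picks up terms with j < k < i, so every column at or beyond i is already zero. If b_Aw were not strictly lower-triangular, dropping the (arange < i) mask would change the result.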
Any thoughts on these changes? Thanks for your great work!
Hong