support flash-attn at torch backend #2257
Open
pass-lin wants to merge 12 commits into keras-team:master from pass-lin:master
Commits (12):
bcc0f22  support flash-attn at torch backend (pass-lin)
faf8ffb  fix (pass-lin)
6bba5ae  fix (pass-lin)
0f960b8  fix (pass-lin)
b4dcc7f  fix conflit (pass-lin)
72f4260  fix conflit (pass-lin)
6ce366d  fix conflit (pass-lin)
16c4541  fix conflit (pass-lin)
78f2c06  fix conflit (pass-lin)
52336ac  fix conflit (pass-lin)
edbee6f  format (pass-lin)
5c7f11f  Merge branch 'keras-team:master' into master (pass-lin)
Conversations
This looks good! Can you please enable this test
https://github.com/keras-team/keras-hub/blob/master/keras_hub/src/models/gemma/gemma_causal_lm_test.py#L101
on the PyTorch backend and make sure the tests pass on a supported GPU? (This may not be supported on the T4, which our CI tests use, so a demo Colab showing the tests passing on a supported GPU would be great.)
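For reference, here is a minimal sketch of the kind of backend/GPU guard such a test might need before it can be enabled for CI. The decorator name requires_flash_attention and the compute-capability threshold are assumptions for illustration, not the repository's actual test code.

```python
import keras
import pytest
import torch

# Hypothetical guard (not the repo's actual test code): run the flash-attention
# test only on the torch backend and only on GPUs that can execute
# flash-attention kernels (Ampere or newer, i.e. compute capability >= 8.0;
# a T4 is 7.5, an A100 is 8.0).
requires_flash_attention = pytest.mark.skipif(
    keras.config.backend() != "torch"
    or not torch.cuda.is_available()
    or torch.cuda.get_device_capability()[0] < 8,
    reason="Requires the torch backend and an Ampere-or-newer GPU.",
)
```

The Gemma test linked above could then carry such a marker once the torch path has been verified on a suitable GPU.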
These are the models that reference the fused_attention_op_available() function. Here are the test results on an A100.
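For context, a minimal sketch of how an availability check like this can gate the attention path; the import path and the dispatch function below are assumptions for illustration, not code from this PR.

```python
import torch

# The import path is an assumption based on keras_hub's layout.
from keras_hub.src.utils.keras_utils import fused_attention_op_available


def attention(q, k, v):
    """Dispatch between the fused attention op and a plain softmax fallback."""
    if fused_attention_op_available():
        # Fused path: torch's scaled_dot_product_attention can select the
        # flash-attention kernel on supported hardware.
        return torch.nn.functional.scaled_dot_product_attention(q, k, v)
    # Unfused fallback: softmax(QK^T / sqrt(d)) V computed explicitly.
    scale = q.shape[-1] ** -0.5
    scores = torch.softmax((q @ k.transpose(-2, -1)) * scale, dim=-1)
    return scores @ v
```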
@pass-lin the test has not been enabled on the PyTorch backend. Please refer to the comment above on enabling it.
I don't know if you have tested it on an A100. At present, the flash-attn tests for Gemma and Gemma3 fail; this is true for both the JAX and Torch backends.
Could you instead design tests on models like Qwen and Llama, which are better suited to flash-attn?
@pctablet505 - have you tested this? Can you please take a look?
I'm not sure about it; I'll have to look into it.
@pctablet505 @divyashreepathihalli
I'm certain this test is wrong, because it tests Gemma2, and Gemma2 does not support flash-attn.
@pass-lin
I just verified that Gemma2 and Gemma3 cannot use flash attention on an A100 GPU.
Gemma3 can use flash attention on TPUs, or on GPUs with CUDA compute capability >= 9.0, i.e. the H series or later (for example, the H100).
#21333
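A minimal sketch of the hardware check this constraint implies on the torch backend; the helper name is hypothetical, and TPUs are a separate case not covered by this CUDA-only check.

```python
import torch


def gemma3_flash_attention_supported() -> bool:
    """Hypothetical helper: True only on GPUs with CUDA compute capability
    >= 9.0 (Hopper / H series, e.g. H100), per the constraint described
    above. TPUs are not covered by this CUDA-only check."""
    if not torch.cuda.is_available():
        return False
    major, _minor = torch.cuda.get_device_capability()
    return major >= 9
```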