support flash-attn at torch backend #2257


Open · wants to merge 12 commits into base: master
2 changes: 1 addition & 1 deletion keras_hub/src/models/gemma/gemma_attention.py
@@ -152,7 +152,7 @@ def _compute_attention(
         attention_mask = ops.expand_dims(attention_mask, axis=1)
         attention_mask = ops.cast(attention_mask, dtype="bool")
         # Only pass soft cap if needed as not all keras versions support.
-        if self.logit_soft_cap:
+        if self.logit_soft_cap is not None:
             kwargs = {"attn_logits_soft_cap": self.logit_soft_cap}
         else:
             kwargs = {}
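For context, a minimal sketch (not part of the PR) of why the truthiness test becomes an explicit None check: a configured soft cap of 0.0 is falsy, so the old check would silently drop the kwarg, while `is not None` only omits it when no cap was set at all.

# Hypothetical illustration of the behavioral difference; not code from the PR.
def build_attention_kwargs(logit_soft_cap):
    # Old truthiness check: a configured cap of 0.0 would be dropped here.
    truthy = {"attn_logits_soft_cap": logit_soft_cap} if logit_soft_cap else {}
    # New explicit check: only an unset (None) cap is dropped.
    explicit = (
        {"attn_logits_soft_cap": logit_soft_cap}
        if logit_soft_cap is not None
        else {}
    )
    return truthy, explicit

print(build_attention_kwargs(None))  # ({}, {})
print(build_attention_kwargs(0.0))   # ({}, {'attn_logits_soft_cap': 0.0})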
1 change: 1 addition & 0 deletions keras_hub/src/models/qwen_moe/qwen_moe_attention.py
@@ -67,6 +67,7 @@ def __init__(
         self.rope_scaling_factor = rope_scaling_factor
         self.use_sliding_window_attention = use_sliding_window_attention
         self.sliding_window_size = sliding_window_size
+        self.logit_soft_cap = None

     def build(self, inputs_shape):
         # Einsum variables:
17 changes: 17 additions & 0 deletions keras_hub/src/utils/keras_utils.py
@@ -71,6 +71,23 @@ def fused_attention_op_available():
             )
             return False
         return True
+    elif (
Collaborator

This looks good! Can you please enable this test
https://github.com/keras-team/keras-hub/blob/master/keras_hub/src/models/gemma/gemma_causal_lm_test.py#L101
in the PyTorch backend and make sure it passes on a supported GPU? (This may not be supported on the T4 that our CI uses, so a demo Colab showing the tests passing on a supported GPU would be great.)
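A minimal sketch of what such a backend gate could look like; this is hypothetical, and the actual decorator used at gemma_causal_lm_test.py#L101 may differ.

import pytest
import keras

# Hypothetical test gate; the real test's skip condition may differ, and the
# real test presumably exercises GemmaCausalLM rather than an empty body.
@pytest.mark.skipif(
    keras.config.backend() not in ("jax", "torch"),
    reason="fused/flash attention is only exercised on JAX and Torch here",
)
def test_flash_attention_call():
    ...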

Contributor Author

@pass-lin pass-lin May 25, 2025

> This looks good! Can you please enable this test https://github.com/keras-team/keras-hub/blob/master/keras_hub/src/models/gemma/gemma_causal_lm_test.py#L101 in the PyTorch backend and make sure it passes on a supported GPU? (This may not be supported on the T4 that our CI uses, so a demo Colab showing the tests passing on a supported GPU would be great.)

[image: models that reference the fused_attention_op_available() function]
These are the models that reference the fused_attention_op_available() function.
Here are the test results on an A100:
[image: A100 test results]

Collaborator

@pass-lin the test has not been enabled on the PyTorch backend. Can you please refer to the above comment on enabling it?

Contributor Author

> @pass-lin the test has not been enabled on the PyTorch backend. Can you please refer to the above comment on enabling it?

I'm not sure whether you have tested it on an A100. At present, the flash-attn test code for Gemma and Gemma3 fails, on both the JAX and Torch backends.
I propose instead designing tests on models like Qwen and Llama that are better suited to flash-attn.

Collaborator

@pctablet505 - have you tested this? Can you please take a look?

Collaborator

I'm not sure about it; I'll have to look into it.

Contributor Author

@pass-lin pass-lin May 28, 2025

> @pctablet505 - have you tested this? Can you please take a look?

@pctablet505 @divyashreepathihalli
I can confirm this test is wrong, because it is testing Gemma2, and Gemma2 does not support flash-attn.

Collaborator

@pctablet505 pctablet505 May 30, 2025

@pass-lin
I just verified that Gemma2 and Gemma3 can't use flash attention on an A100 GPU.
Gemma3 can use flash attention on TPUs, or on GPUs with CUDA compute capability >= 9.0, that is, the H series or later (for example, the H100).

#21333
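A small sketch (mine, not from the PR) of how one might check the compute-capability condition mentioned above before expecting flash attention to kick in:

import torch

# Hypothetical helper: Hopper-class GPUs (H100 and later) report CUDA
# compute capability >= 9.0; an A100 reports 8.0.
def meets_hopper_capability() -> bool:
    if not torch.cuda.is_available():
        return False
    major, minor = torch.cuda.get_device_capability()
    return (major, minor) >= (9, 0)

print(meets_hopper_capability())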

+        hasattr(keras.config, "is_flash_attention_enabled")
+        and keras.config.backend() == "torch"
+    ):
+        try:
+            from torch.backends.cuda import SDPAParams as SDPAParams
+            from torch.backends.cuda import (
+                can_use_flash_attention as can_use_flash_attention,
+            )
+        except ImportError:
+            logging.warning(
+                "Flash attention is not supported in your current PyTorch "
+                "version. Please update it by following the official guide: "
+                "https://pytorch.org/get-started/locally/"
+            )
+            return False
+        return True
     else:
         return False
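For readers following the diff, a minimal usage sketch (not part of the PR) of how the availability check above can be probed; it only uses APIs referenced in this change:

import keras
from keras_hub.src.utils.keras_utils import fused_attention_op_available

# On the torch backend this now returns True only when SDPAParams and
# can_use_flash_attention can be imported from torch.backends.cuda.
print("backend:", keras.config.backend())
print("flash attention enabled:", keras.config.is_flash_attention_enabled())
print("fused attention op available:", fused_attention_op_available())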
