diff --git a/keras_hub/src/models/gemma/gemma_attention.py b/keras_hub/src/models/gemma/gemma_attention.py
index 4f8c414eb8..f66a4506ce 100644
--- a/keras_hub/src/models/gemma/gemma_attention.py
+++ b/keras_hub/src/models/gemma/gemma_attention.py
@@ -152,7 +152,7 @@ def _compute_attention(
                 attention_mask = ops.expand_dims(attention_mask, axis=1)
                 attention_mask = ops.cast(attention_mask, dtype="bool")
             # Only pass soft cap if needed as not all keras versions support.
-            if self.logit_soft_cap:
+            if self.logit_soft_cap is not None:
                 kwargs = {"attn_logits_soft_cap": self.logit_soft_cap}
             else:
                 kwargs = {}
diff --git a/keras_hub/src/models/qwen_moe/qwen_moe_attention.py b/keras_hub/src/models/qwen_moe/qwen_moe_attention.py
index 8c342ea905..30c4466de0 100644
--- a/keras_hub/src/models/qwen_moe/qwen_moe_attention.py
+++ b/keras_hub/src/models/qwen_moe/qwen_moe_attention.py
@@ -67,6 +67,7 @@ def __init__(
         self.rope_scaling_factor = rope_scaling_factor
         self.use_sliding_window_attention = use_sliding_window_attention
         self.sliding_window_size = sliding_window_size
+        self.logit_soft_cap = None
 
     def build(self, inputs_shape):
         # Einsum variables:
diff --git a/keras_hub/src/utils/keras_utils.py b/keras_hub/src/utils/keras_utils.py
index 21607ffccb..6b5a7ad55c 100644
--- a/keras_hub/src/utils/keras_utils.py
+++ b/keras_hub/src/utils/keras_utils.py
@@ -71,6 +71,23 @@ def fused_attention_op_available():
             )
             return False
         return True
+    elif (
+        hasattr(keras.config, "is_flash_attention_enabled")
+        and keras.config.backend() == "torch"
+    ):
+        try:
+            from torch.backends.cuda import SDPAParams as SDPAParams
+            from torch.backends.cuda import (
+                can_use_flash_attention as can_use_flash_attention,
+            )
+        except ImportError:
+            logging.warning(
+                "Flash attention is not supported in your current PyTorch "
+                "version. Please update it by following the official guide: "
+                "https://pytorch.org/get-started/locally/"
+            )
+            return False
+        return True
     else:
         return False
 
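
A minimal usage sketch (not part of the patch) of how the new torch branch is expected to behave. It assumes keras-hub is installed with the torch backend selected (KERAS_BACKEND=torch); the import path mirrors the keras_utils.py file touched above, and the printed messages are illustrative only.

# Sketch: check whether the fused (flash) attention path is usable on the
# torch backend before relying on flash-attention-specific behavior.
import keras

from keras_hub.src.utils.keras_utils import fused_attention_op_available

if keras.config.backend() == "torch":
    if fused_attention_op_available():
        # The patched branch found `torch.backends.cuda.SDPAParams` and
        # `can_use_flash_attention`, so attention layers may take the fused
        # dot-product-attention path.
        print("Torch flash attention helpers are available.")
    else:
        # Older PyTorch versions miss those helpers; attention layers fall
        # back to the unfused computation.
        print("Falling back to the unfused attention path.")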