From bcc0f2236a89e3b0234ffd322bebd868b9666a5a Mon Sep 17 00:00:00 2001
From: pass_lin <935499957@qq.com>
Date: Sat, 17 May 2025 15:27:39 +0800
Subject: [PATCH 01/11] support flash-attn at torch backend

---
 keras_hub/src/utils/keras_utils.py | 15 +++++++++++++++
 1 file changed, 15 insertions(+)

diff --git a/keras_hub/src/utils/keras_utils.py b/keras_hub/src/utils/keras_utils.py
index 21607ffccb..a800d473ed 100644
--- a/keras_hub/src/utils/keras_utils.py
+++ b/keras_hub/src/utils/keras_utils.py
@@ -71,6 +71,21 @@ def fused_attention_op_available():
             )
             return False
         return True
+    elif (
+        hasattr(keras.config, "is_flash_attention_enabled")
+        and keras.config.backend() == "torch"
+    ):
+        try:
+            from torch.backends.cuda import SDPAParams
+            from torch.backends.cuda import can_use_flash_attention
+        except ImportError:
+            logging.warning(
+                "Flash attention is not supported in your current PyTorch "
+                "version. Please update it by following the official guide: "
+                "https://pytorch.org/get-started/locally/"
+            )
+            return False
+        return True
     else:
         return False
 

From faf8ffbdf1e00397bb808c4cc919c6ee1378f83f Mon Sep 17 00:00:00 2001
From: pass_lin <935499957@qq.com>
Date: Sat, 17 May 2025 16:42:10 +0800
Subject: [PATCH 02/11] fix

---
 keras_hub/src/models/gemma/gemma_attention.py       | 3 +++
 keras_hub/src/models/mixtral/mixtral_attention.py   | 3 ++-
 keras_hub/src/models/qwen_moe/qwen_moe_attention.py | 1 +
 keras_hub/src/utils/keras_utils.py                  | 6 ++++--
 4 files changed, 10 insertions(+), 3 deletions(-)

diff --git a/keras_hub/src/models/gemma/gemma_attention.py b/keras_hub/src/models/gemma/gemma_attention.py
index 4f8c414eb8..78c2beb767 100644
--- a/keras_hub/src/models/gemma/gemma_attention.py
+++ b/keras_hub/src/models/gemma/gemma_attention.py
@@ -109,14 +109,17 @@ def _apply_rope(self, x, start_index):
         return x
 
     def _use_fused_attention_op(self):
+        return True
         if not fused_attention_op_available():
             return False
         if self.dropout > 0.0:
             return False
+
         if running_on_gpu():
             # GPU never supports softcap in the fused op.
             if self.logit_soft_cap is not None:
                 return False
+
             return gpu_supports_fused_attention_op()
         elif running_on_tpu():
             # TPU supports softcap with on keras >= 3.10.
diff --git a/keras_hub/src/models/mixtral/mixtral_attention.py b/keras_hub/src/models/mixtral/mixtral_attention.py
index 080be18047..c034696931 100644
--- a/keras_hub/src/models/mixtral/mixtral_attention.py
+++ b/keras_hub/src/models/mixtral/mixtral_attention.py
@@ -40,6 +40,7 @@ def __init__(
         )
 
         self._rope_scaling_factor = rope_scaling_factor
+        self.logit_soft_cap = None
 
     def build(self, inputs_shape):
         # Einsum variables:
@@ -195,7 +196,7 @@ def _masked_softmax(self, attention_scores, attention_mask=None):
     def _use_fused_attention_op(self):
         if not fused_attention_op_available():
             return False
-        if self.dropout > 0.0:
+        if self._dropout > 0.0:
             return False
         if running_on_gpu():
             # GPU never supports softcap in the fused op.
diff --git a/keras_hub/src/models/qwen_moe/qwen_moe_attention.py b/keras_hub/src/models/qwen_moe/qwen_moe_attention.py
index 1f270d032d..f55c00cfad 100644
--- a/keras_hub/src/models/qwen_moe/qwen_moe_attention.py
+++ b/keras_hub/src/models/qwen_moe/qwen_moe_attention.py
@@ -67,6 +67,7 @@ def __init__(
         self.rope_scaling_factor = rope_scaling_factor
         self.use_sliding_window_attention = use_sliding_window_attention
         self.sliding_window_size = sliding_window_size
+        self.logit_soft_cap = None
 
     def build(self, inputs_shape):
         # Einsum variables:
diff --git a/keras_hub/src/utils/keras_utils.py b/keras_hub/src/utils/keras_utils.py
index a800d473ed..6b5a7ad55c 100644
--- a/keras_hub/src/utils/keras_utils.py
+++ b/keras_hub/src/utils/keras_utils.py
@@ -76,8 +76,10 @@ def fused_attention_op_available():
         and keras.config.backend() == "torch"
     ):
         try:
-            from torch.backends.cuda import SDPAParams
-            from torch.backends.cuda import can_use_flash_attention
+            from torch.backends.cuda import SDPAParams as SDPAParams
+            from torch.backends.cuda import (
+                can_use_flash_attention as can_use_flash_attention,
+            )
         except ImportError:
             logging.warning(
                 "Flash attention is not supported in your current PyTorch "

From 6bba5aec56f3e699a152fa7b5d04b88a59e1978a Mon Sep 17 00:00:00 2001
From: pass_lin <935499957@qq.com>
Date: Sat, 17 May 2025 16:58:53 +0800
Subject: [PATCH 03/11] fix

---
 keras_hub/src/models/gemma/gemma_attention.py | 3 ---
 1 file changed, 3 deletions(-)

diff --git a/keras_hub/src/models/gemma/gemma_attention.py b/keras_hub/src/models/gemma/gemma_attention.py
index 78c2beb767..4f8c414eb8 100644
--- a/keras_hub/src/models/gemma/gemma_attention.py
+++ b/keras_hub/src/models/gemma/gemma_attention.py
@@ -109,17 +109,14 @@ def _apply_rope(self, x, start_index):
         return x
 
     def _use_fused_attention_op(self):
-        return True
         if not fused_attention_op_available():
             return False
         if self.dropout > 0.0:
             return False
-
         if running_on_gpu():
             # GPU never supports softcap in the fused op.
             if self.logit_soft_cap is not None:
                 return False
-
             return gpu_supports_fused_attention_op()
         elif running_on_tpu():
             # TPU supports softcap with on keras >= 3.10.

From 0f960b8af68158e1cf1a906ec5ca4bde66fbbd23 Mon Sep 17 00:00:00 2001
From: pass_lin <935499957@qq.com>
Date: Sat, 17 May 2025 17:15:16 +0800
Subject: [PATCH 04/11] fix

---
 keras_hub/src/models/gemma/gemma_attention.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/keras_hub/src/models/gemma/gemma_attention.py b/keras_hub/src/models/gemma/gemma_attention.py
index 4f8c414eb8..f66a4506ce 100644
--- a/keras_hub/src/models/gemma/gemma_attention.py
+++ b/keras_hub/src/models/gemma/gemma_attention.py
@@ -152,7 +152,7 @@ def _compute_attention(
                 attention_mask = ops.expand_dims(attention_mask, axis=1)
                 attention_mask = ops.cast(attention_mask, dtype="bool")
             # Only pass soft cap if needed as not all keras versions support.
-            if self.logit_soft_cap:
+            if self.logit_soft_cap is not None:
                 kwargs = {"attn_logits_soft_cap": self.logit_soft_cap}
             else:
                 kwargs = {}

From b4dcc7fcf72d471b56a033d5b315827e722a94e9 Mon Sep 17 00:00:00 2001
From: pass_lin <935499957@qq.com>
Date: Tue, 20 May 2025 21:59:15 +0800
Subject: [PATCH 05/11] fix conflit

---
 keras_hub/src/models/mixtral/mixtral_attention.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/keras_hub/src/models/mixtral/mixtral_attention.py b/keras_hub/src/models/mixtral/mixtral_attention.py
index c034696931..4419726a20 100644
--- a/keras_hub/src/models/mixtral/mixtral_attention.py
+++ b/keras_hub/src/models/mixtral/mixtral_attention.py
@@ -39,7 +39,7 @@ def __init__(
             clone_initializer(kernel_initializer)
         )
 
-        self._rope_scaling_factor = rope_scaling_factor
+        self.rope_scaling_factor = rope_scaling_factor
         self.logit_soft_cap = None
 
     def build(self, inputs_shape):
@@ -114,7 +114,7 @@ def build(self, inputs_shape):
 
         self.rotary_embedding_layer = RotaryEmbedding(
             max_wavelength=self._rope_max_wavelength,
-            scaling_factor=self._rope_scaling_factor,
+            scaling_factor=self.rope_scaling_factor,
             dtype=self.dtype_policy,
         )
 
@@ -253,7 +253,7 @@ def get_config(self):
                 "num_query_heads": self._num_query_heads,
                 "num_key_value_heads": self._num_key_value_heads,
                 "rope_max_wavelength": self._rope_max_wavelength,
-                "rope_scaling_factor": self._rope_scaling_factor,
+                "rope_scaling_factor": self.rope_scaling_factor,
                 "kernel_initializer": keras.initializers.serialize(
                     self._kernel_initializer
                 ),

From 72f42605a09f4f85ceba0581c270366852ba97a7 Mon Sep 17 00:00:00 2001
From: pass_lin <935499957@qq.com>
Date: Tue, 20 May 2025 22:01:32 +0800
Subject: [PATCH 06/11] fix conflit

---
 keras_hub/src/models/mixtral/mixtral_attention.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/keras_hub/src/models/mixtral/mixtral_attention.py b/keras_hub/src/models/mixtral/mixtral_attention.py
index 4419726a20..d13c35e376 100644
--- a/keras_hub/src/models/mixtral/mixtral_attention.py
+++ b/keras_hub/src/models/mixtral/mixtral_attention.py
@@ -38,9 +38,8 @@ def __init__(
         self._kernel_initializer = keras.initializers.get(
             clone_initializer(kernel_initializer)
         )
-
-        self.rope_scaling_factor = rope_scaling_factor
         self.logit_soft_cap = None
+        self.rope_scaling_factor = rope_scaling_factor
 
     def build(self, inputs_shape):
         # Einsum variables:

From 6ce366dff70c5490c81e92ad73dd92ab00661372 Mon Sep 17 00:00:00 2001
From: pass_lin <935499957@qq.com>
Date: Tue, 20 May 2025 22:05:53 +0800
Subject: [PATCH 07/11] fix conflit

---
 keras_hub/src/models/mixtral/mixtral_attention.py | 9 ---------
 1 file changed, 9 deletions(-)

diff --git a/keras_hub/src/models/mixtral/mixtral_attention.py b/keras_hub/src/models/mixtral/mixtral_attention.py
index d13c35e376..edad120b7d 100644
--- a/keras_hub/src/models/mixtral/mixtral_attention.py
+++ b/keras_hub/src/models/mixtral/mixtral_attention.py
@@ -198,9 +198,6 @@ def _use_fused_attention_op(self):
         if self._dropout > 0.0:
             return False
         if running_on_gpu():
-            # GPU never supports softcap in the fused op.
-            if self.logit_soft_cap is not None:
-                return False
             return gpu_supports_fused_attention_op()
         elif running_on_tpu():
             # TPU supports softcap with on keras >= 3.10.
@@ -215,18 +212,12 @@ def _compute_attention(self, query, key, value, attention_mask=None):
             attention_mask = ops.expand_dims(attention_mask, axis=1)
             attention_mask = ops.cast(attention_mask, dtype="bool")
 
-            if self.logit_soft_cap:
-                kwargs = {"attn_logits_soft_cap": self.logit_soft_cap}
-            else:
-                kwargs = {}
-
             attention_output = ops.dot_product_attention(
                 query,
                 key,
                 value,
                 mask=attention_mask,
                 scale=self._inv_norm_factor,
-                **kwargs,
             )
             return attention_output
 

From 16c45418193365670240707de567b6a0dd0eef2f Mon Sep 17 00:00:00 2001
From: pass_lin <935499957@qq.com>
Date: Tue, 20 May 2025 22:06:33 +0800
Subject: [PATCH 08/11] fix conflit

---
 keras_hub/src/models/mixtral/mixtral_attention.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/keras_hub/src/models/mixtral/mixtral_attention.py b/keras_hub/src/models/mixtral/mixtral_attention.py
index edad120b7d..071b87572a 100644
--- a/keras_hub/src/models/mixtral/mixtral_attention.py
+++ b/keras_hub/src/models/mixtral/mixtral_attention.py
@@ -38,7 +38,6 @@ def __init__(
         self._kernel_initializer = keras.initializers.get(
             clone_initializer(kernel_initializer)
         )
-        self.logit_soft_cap = None
         self.rope_scaling_factor = rope_scaling_factor
 
     def build(self, inputs_shape):

From 78f2c069bc365eb42cdbfea1571af97b260d0a09 Mon Sep 17 00:00:00 2001
From: pass_lin <935499957@qq.com>
Date: Tue, 20 May 2025 22:07:31 +0800
Subject: [PATCH 09/11] fix conflit

---
 keras_hub/src/models/mixtral/mixtral_attention.py | 13 +++++++------
 1 file changed, 7 insertions(+), 6 deletions(-)

diff --git a/keras_hub/src/models/mixtral/mixtral_attention.py b/keras_hub/src/models/mixtral/mixtral_attention.py
index 071b87572a..81e7a2f5d3 100644
--- a/keras_hub/src/models/mixtral/mixtral_attention.py
+++ b/keras_hub/src/models/mixtral/mixtral_attention.py
@@ -27,17 +27,18 @@ def __init__(
         **kwargs,
     ):
         super().__init__(**kwargs)
-        self._num_query_heads = num_query_heads
-        self._num_key_value_heads = num_key_value_heads
-        self._sliding_window = sliding_window
-        self._dropout = dropout
+        self.num_query_heads = num_query_heads
+        self.num_key_value_heads = num_key_value_heads
+        self.sliding_window = sliding_window
+        self.dropout = dropout
 
-        self._num_key_value_groups = num_query_heads // num_key_value_heads
-        self._rope_max_wavelength = rope_max_wavelength
+        self.num_key_value_groups = num_query_heads // num_key_value_heads
+        self.rope_max_wavelength = rope_max_wavelength
 
         self._kernel_initializer = keras.initializers.get(
             clone_initializer(kernel_initializer)
         )
+
         self.rope_scaling_factor = rope_scaling_factor
 
     def build(self, inputs_shape):

From 52336ac209681905242c82e2b4a2bb783d1815ba Mon Sep 17 00:00:00 2001
From: pass_lin <935499957@qq.com>
Date: Tue, 20 May 2025 22:08:32 +0800
Subject: [PATCH 10/11] fix conflit

---
 .../src/models/mixtral/mixtral_attention.py | 50 +++++++++----------
 1 file changed, 24 insertions(+), 26 deletions(-)

diff --git a/keras_hub/src/models/mixtral/mixtral_attention.py b/keras_hub/src/models/mixtral/mixtral_attention.py
index 81e7a2f5d3..07fbcefba7 100644
--- a/keras_hub/src/models/mixtral/mixtral_attention.py
+++ b/keras_hub/src/models/mixtral/mixtral_attention.py
@@ -51,12 +51,12 @@ def build(self, inputs_shape):
         # v = num key/value heads
         # h = head dim
         self._hidden_dim = inputs_shape[-1]
-        self._head_dim = self._hidden_dim // self._num_query_heads
+        self._head_dim = self._hidden_dim // self.num_query_heads
         self._inv_norm_factor = 1.0 / math.sqrt(self._head_dim)
 
         self.query_dense = keras.layers.EinsumDense(
             equation="bqm,muh->bquh",
-            output_shape=(None, self._num_query_heads, self._head_dim),
+            output_shape=(None, self.num_query_heads, self._head_dim),
             kernel_initializer=self._kernel_initializer,
             dtype=self.dtype_policy,
             name="query",
@@ -67,7 +67,7 @@ def build(self, inputs_shape):
             equation="bkm,mvh->bkvh",
             output_shape=(
                 None,
-                self._num_key_value_heads,
+                self.num_key_value_heads,
                 self._head_dim,
             ),
             kernel_initializer=self._kernel_initializer,
@@ -80,7 +80,7 @@ def build(self, inputs_shape):
             equation="bkm,mvh->bkvh",
             output_shape=(
                 None,
-                self._num_key_value_heads,
+                self.num_key_value_heads,
                 self._head_dim,
             ),
             kernel_initializer=self._kernel_initializer,
@@ -89,30 +89,30 @@ def build(self, inputs_shape):
         )
         self.value_dense.build(inputs_shape)
 
-        self._softmax = keras.layers.Softmax(
+        self.softmax = keras.layers.Softmax(
             axis=-1,
             dtype="float32",
             name="attention_softmax",
         )
 
-        self._dropout_layer = keras.layers.Dropout(
-            rate=self._dropout,
+        self.dropout_layer = keras.layers.Dropout(
+            rate=self.dropout,
             dtype=self.dtype_policy,
         )
 
-        self._output_dense = keras.layers.EinsumDense(
+        self.output_dense = keras.layers.EinsumDense(
             equation="bquh,uhm->bqm",
             output_shape=(None, self._hidden_dim),
             kernel_initializer=self._kernel_initializer,
             dtype=self.dtype_policy,
             name="attention_output",
         )
-        self._output_dense.build(
-            (None, None, self._num_query_heads, self._head_dim)
+        self.output_dense.build(
+            (None, None, self.num_query_heads, self._head_dim)
         )
 
         self.rotary_embedding_layer = RotaryEmbedding(
-            max_wavelength=self._rope_max_wavelength,
+            max_wavelength=self.rope_max_wavelength,
             scaling_factor=self.rope_scaling_factor,
             dtype=self.dtype_policy,
         )
@@ -168,18 +168,18 @@ def _compute_key_value(x):
 
         # [batch_shape, seq_len, num_key_value_heads, head_dim]
         # -> [batch_shape, seq_len, num_heads, head_dim]
-        key = ops.repeat(key, repeats=self._num_key_value_groups, axis=2)
-        value = ops.repeat(value, repeats=self._num_key_value_groups, axis=2)
+        key = ops.repeat(key, repeats=self.num_key_value_groups, axis=2)
+        value = ops.repeat(value, repeats=self.num_key_value_groups, axis=2)
 
         attention_output = self._compute_attention(
             query, key, value, attention_mask
         )
 
-        attention_output = self._dropout_layer(
+        attention_output = self.dropout_layer(
             attention_output, training=training
         )
-        attention_output = self._output_dense(attention_output)
+        attention_output = self.output_dense(attention_output)
 
         if cache is not None:
             return attention_output, cache
@@ -187,15 +187,13 @@ def _compute_key_value(x):
     def _masked_softmax(self, attention_scores, attention_mask=None):
         if attention_mask is not None:
-            return self._softmax(
-                attention_scores, attention_mask[:, None, :, :]
-            )
-        return self._softmax(attention_scores)
+            return self.softmax(attention_scores, attention_mask[:, None, :, :])
+        return self.softmax(attention_scores)
 
     def _use_fused_attention_op(self):
         if not fused_attention_op_available():
             return False
-        if self._dropout > 0.0:
+        if self.dropout > 0.0:
             return False
         if running_on_gpu():
             return gpu_supports_fused_attention_op()
         elif running_on_tpu():
             # TPU supports softcap with on keras >= 3.10.
@@ -240,15 +238,15 @@ def get_config(self):
         config = super().get_config()
         config.update(
             {
-                "num_query_heads": self._num_query_heads,
-                "num_key_value_heads": self._num_key_value_heads,
-                "rope_max_wavelength": self._rope_max_wavelength,
+                "num_query_heads": self.num_query_heads,
+                "num_key_value_heads": self.num_key_value_heads,
+                "rope_max_wavelength": self.rope_max_wavelength,
                 "rope_scaling_factor": self.rope_scaling_factor,
                 "kernel_initializer": keras.initializers.serialize(
                     self._kernel_initializer
                 ),
-                "sliding_window": self._sliding_window,
-                "dropout": self._dropout,
+                "sliding_window": self.sliding_window,
+                "dropout": self.dropout,
             }
         )
-        return config
+        return config
\ No newline at end of file

From edbee6fc6f4a143a6c20ff47c2e769a56e541917 Mon Sep 17 00:00:00 2001
From: pass_lin <935499957@qq.com>
Date: Tue, 20 May 2025 22:12:19 +0800
Subject: [PATCH 11/11] format

---
 keras_hub/src/models/mixtral/mixtral_attention.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/keras_hub/src/models/mixtral/mixtral_attention.py b/keras_hub/src/models/mixtral/mixtral_attention.py
index 07fbcefba7..0cae75a21c 100644
--- a/keras_hub/src/models/mixtral/mixtral_attention.py
+++ b/keras_hub/src/models/mixtral/mixtral_attention.py
@@ -249,4 +249,4 @@ def get_config(self):
                 "dropout": self.dropout,
             }
         )
-        return config
\ No newline at end of file
+        return config
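
For reference, the torch-backend check that PATCH 01 adds to fused_attention_op_available() can be exercised on its own. Below is a minimal standalone sketch; the wrapper name torch_flash_attention_available and the module-level logging import are illustrative only and not part of the patch series, while the import probe and warning text mirror the diff above.

import logging

import keras


def torch_flash_attention_available():
    # Mirrors the torch branch added in PATCH 01: the two imports only
    # succeed on PyTorch builds that ship the fused SDPA flash kernels.
    if not hasattr(keras.config, "is_flash_attention_enabled"):
        return False
    if keras.config.backend() != "torch":
        return False
    try:
        from torch.backends.cuda import SDPAParams  # noqa: F401
        from torch.backends.cuda import can_use_flash_attention  # noqa: F401
    except ImportError:
        logging.warning(
            "Flash attention is not supported in your current PyTorch "
            "version. Please update it by following the official guide: "
            "https://pytorch.org/get-started/locally/"
        )
        return False
    return True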