From e52837875ca03b2e35d828494b36a6cd754ddcc9 Mon Sep 17 00:00:00 2001 From: Anshuman Mishra Date: Mon, 19 May 2025 21:46:36 +0530 Subject: [PATCH 1/7] qwen3 moe init --- .../qwen3_moe/qwen3_causal_lm_preprocessor.py | 17 + .../models/qwen3_moe/qwen3_moe_attention.py | 376 +++++++++++ .../models/qwen3_moe/qwen3_moe_backbone.py | 364 +++++++++++ .../src/models/qwen3_moe/qwen3_moe_decoder.py | 598 ++++++++++++++++++ .../models/qwen3_moe/qwen3_moe_layernorm.py | 32 + .../models/qwen3_moe/qwen3_moe_tokenizer.py | 46 ++ .../utils/transformers/convert_qwen3_moe.py | 220 +++++++ .../convert_qwen3_moe_checkpoints.py | 162 +++++ 8 files changed, 1815 insertions(+) create mode 100644 keras_hub/src/models/qwen3_moe/qwen3_causal_lm_preprocessor.py create mode 100644 keras_hub/src/models/qwen3_moe/qwen3_moe_attention.py create mode 100644 keras_hub/src/models/qwen3_moe/qwen3_moe_backbone.py create mode 100644 keras_hub/src/models/qwen3_moe/qwen3_moe_decoder.py create mode 100644 keras_hub/src/models/qwen3_moe/qwen3_moe_layernorm.py create mode 100644 keras_hub/src/models/qwen3_moe/qwen3_moe_tokenizer.py create mode 100644 keras_hub/src/utils/transformers/convert_qwen3_moe.py create mode 100644 tools/checkpoint_conversion/convert_qwen3_moe_checkpoints.py diff --git a/keras_hub/src/models/qwen3_moe/qwen3_causal_lm_preprocessor.py b/keras_hub/src/models/qwen3_moe/qwen3_causal_lm_preprocessor.py new file mode 100644 index 0000000000..34d5bf87ab --- /dev/null +++ b/keras_hub/src/models/qwen3_moe/qwen3_causal_lm_preprocessor.py @@ -0,0 +1,17 @@ +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.models.causal_lm_preprocessor import CausalLMPreprocessor +from keras_hub.src.models.qwen_moe.qwen_moe_backbone import Qwen3MoeBackbone +from keras_hub.src.models.qwen_moe.qwen_moe_tokenizer import Qwen3MoeTokenizer + + +@keras_hub_export( + [ + "keras_hub.models.Qwen3MoeCausalLMPreprocessor", + ] +) +class Qwen3MoeCausalLMPreprocessor(CausalLMPreprocessor): + backbone_cls = Qwen3MoeBackbone + tokenizer_cls = Qwen3MoeTokenizer + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) diff --git a/keras_hub/src/models/qwen3_moe/qwen3_moe_attention.py b/keras_hub/src/models/qwen3_moe/qwen3_moe_attention.py new file mode 100644 index 0000000000..5cf2b1a9f0 --- /dev/null +++ b/keras_hub/src/models/qwen3_moe/qwen3_moe_attention.py @@ -0,0 +1,376 @@ +import math + +import keras +from keras import ops + +from keras_hub.src.layers.modeling.rotary_embedding import RotaryEmbedding +from keras_hub.src.models.qwen3_moe.qwen3_moe_layernorm import Qwen3MoeLayerNorm +from keras_hub.src.utils.keras_utils import clone_initializer +from keras_hub.src.utils.keras_utils import fused_attention_op_available + + +class Qwen3MoeAttention(keras.layers.Layer): + """A multi-head attention layer for Qwen3Moe models + This attention implementation supports grouped-query attention (GQA) where + the number of key-value heads can be less than the number of query heads. + Args: + num_query_heads: Number of query heads. + num_key_value_heads: Number of key/value heads (for GQA). + rope_max_wavelength: Maximum wavelength for RoPE (Rotary Position + Embedding). + rope_scaling_factor: Scaling factor for RoPE, used for extending + context length. + kernel_initializer: Initializer for the kernel weights. + dropout: Dropout rate for attention weights. + use_sliding_window_attention: Whether to use sliding window + attention. + sliding_window_size: Size of the sliding window for attention. 
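+        layer_index: Index of this layer in the decoder stack, used
+            together with `max_window_layers` when configuring the
+            sliding window.
+        head_dim: Size of each attention head. If unset, it is inferred
+            in `build()` as `hidden_dim // num_query_heads`.
+        layer_norm_epsilon: Epsilon for the RMS norms applied to the
+            query and key projections. Defaults to `1e-5`.
+        max_window_layers: Threshold on `layer_index` at or above which
+            the sliding window size is dropped when sliding window
+            attention is disabled. Defaults to `28`.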
+ **kwargs: Additional keyword arguments to pass to the Layer. + """ + + def __init__( + self, + num_query_heads, + num_key_value_heads, + layer_index, + head_dim, + rope_max_wavelength=10000, + rope_scaling_factor=1, + kernel_initializer="glorot_uniform", + dropout=0, + use_sliding_window_attention=False, + layer_norm_epsilon=1e-5, + sliding_window_size=4096, + max_window_layers=28, + **kwargs, + ): + super().__init__( + **kwargs, + ) + self.num_query_heads = num_query_heads + self.num_key_value_heads = num_key_value_heads + self.head_dim = head_dim + self.dropout = dropout + + self.layer_norm_epsilon = layer_norm_epsilon + + self.num_key_value_groups = num_query_heads // num_key_value_heads + self.rope_max_wavelength = rope_max_wavelength + + self.kernel_initializer = keras.initializers.get( + clone_initializer(kernel_initializer) + ) + self.layer_index = layer_index + + self.rope_scaling_factor = rope_scaling_factor + self.use_sliding_window_attention = use_sliding_window_attention + + if ( + not self.use_sliding_window_attention + and sliding_window_size + and self.layer_index >= max_window_layers + ): + self.sliding_window_size = None + else: + self.sliding_window_size = sliding_window_size + + def build(self, inputs_shape): + # Einsum variables: + # b = batch size + # q = query length + # k = key/value length + # m = model dim + # u = num query heads + # v = num key/value heads + # h = head dim + hidden_dim = inputs_shape[-1] + if not self.head_dim: + self.head_dim = hidden_dim // self.num_query_heads + + self._inv_norm_factor = 1.0 / math.sqrt(self.head_dim) + self._query_dense = keras.layers.EinsumDense( + equation="bqm,muh->bquh", + output_shape=(None, self.num_query_heads, self.head_dim), + kernel_initializer=self.kernel_initializer, + dtype=self.dtype_policy, + name="query", + ) + self._query_dense.build(inputs_shape) + + self._query_dense_layer_norm = Qwen3MoeLayerNorm( + epsilon=self.layer_norm_epsilon, + dtype=self.dtype_policy, + head_dim=self.head_dim, + name="query_dense_layernorm", + ) + self._query_dense_layer_norm.build(inputs_shape) + + self._key_dense = keras.layers.EinsumDense( + equation="bkm,mvh->bkvh", + output_shape=( + None, + self.num_key_value_heads, + self.head_dim, + ), + kernel_initializer=self.kernel_initializer, + dtype=self.dtype_policy, + name="key", + ) + self._key_dense.build(inputs_shape) + + self._key_dense_layer_norm = Qwen3MoeLayerNorm( + epsilon=self.layer_norm_epsilon, + dtype=self.dtype_policy, + head_dim=self.head_dim, + name="key_dense_layernorm", + ) + self._key_dense_layer_norm.build(inputs_shape) + + self._value_dense = keras.layers.EinsumDense( + equation="bkm,mvh->bkvh", + output_shape=( + None, + self.num_key_value_heads, + self.head_dim, + ), + kernel_initializer=self.kernel_initializer, + dtype=self.dtype_policy, + name="value", + ) + self._value_dense.build(inputs_shape) + + self._softmax = keras.layers.Softmax( + axis=-1, + dtype="float32", + name="attention_softmax", + ) + + self._dropout_layer = keras.layers.Dropout( + rate=self.dropout, + dtype=self.dtype_policy, + ) + + self._output_dense = keras.layers.EinsumDense( + equation="bquh,uhm->bqm", + output_shape=(None, hidden_dim), + kernel_initializer=self.kernel_initializer, + dtype=self.dtype_policy, + name="attention_output", + ) + self._output_dense.build( + (None, None, self.num_query_heads, self.head_dim) + ) + + self.rotary_embedding_layer = RotaryEmbedding( + max_wavelength=self.rope_max_wavelength, + scaling_factor=self.rope_scaling_factor, + dtype=self.dtype_policy, + ) + 
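+        # Einsum equations for the non-fused attention path in
+        # `_compute_attention`; the dimension names follow the legend at
+        # the top of this method.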
+ self._dot_product_equation = "bquh,bkuh->buqk" + self._combine_equation = "buqk,bkuh->bquh" + + self.built = True + + def call( + self, + hidden_states, + attention_mask=None, + cache=None, + cache_update_index=None, + training=None, + ): + """Applies attention mechanism to the input hidden states. + Args: + hidden_states: Input tensor of shape [batch_size, seq_length, + hidden_size]. + attention_mask: Mask tensor of shape [batch_size, seq_length, + seq_length]. + cache: Optional cached key and value tensors. + cache_update_index: Index at which to update the cache. + training: Boolean indicating whether in training mode. + Returns: + attention_output: Output tensor after applying attention. + cache: Updated cache tensors (if cache is provided). + """ + start_index = ( + cache_update_index if cache_update_index is not None else 0 + ) + + query = self._query_dense(hidden_states) + query = self._query_dense_layer_norm(query) + + # Compute RoPE for queries + query = self.rotary_embedding_layer(query, start_index=start_index) + + def _compute_key_value(x): + key = self._key_dense(x) + key = self._key_dense_layer_norm(key) + key = self.rotary_embedding_layer(key, start_index=start_index) + + value = self._value_dense(x) + + return key, value + + if cache is not None: + key_cache = cache[:, 0, ...] + value_cache = cache[:, 1, ...] + if cache_update_index is None: + key = key_cache + value = value_cache + else: + key_update, value_update = _compute_key_value(hidden_states) + start = [0, cache_update_index, 0, 0] + key = ops.slice_update(key_cache, start, key_update) + value = ops.slice_update(value_cache, start, value_update) + cache = ops.stack((key, value), axis=1) + else: + if cache_update_index is not None: + raise ValueError( + "`cache_update_index` should not be set if `cache` is " + f"`None`. Received: cache={cache}, " + f"cache_update_index={cache_update_index}" + ) + key, value = _compute_key_value(hidden_states) + + # [batch_shape, seq_len, num_key_value_heads, head_dim] + # -> [batch_shape, seq_len, num_heads, head_dim] + key = ops.repeat(key, repeats=self.num_key_value_groups, axis=2) + value = ops.repeat(value, repeats=self.num_key_value_groups, axis=2) + + attention_output = self._compute_attention( + query, + key, + value, + attention_mask, + cache_update_index=cache_update_index, + ) + + attention_output = self._dropout_layer( + attention_output, training=training + ) + + attention_output = self._output_dense(attention_output) + + if cache is not None: + return attention_output, cache + return attention_output + + def _masked_softmax(self, attention_scores, attention_mask=None): + """Applies softmax with optional masking. + Args: + attention_scores: Attention score tensor. + attention_mask: Optional mask tensor. + Returns: + Masked softmax attention weights. + """ + if attention_mask is not None: + return self._softmax( + attention_scores, attention_mask[:, None, :, :] + ) + return self._softmax(attention_scores) + + def _compute_attention( + self, query, key, value, attention_mask=None, cache_update_index=None + ): + """Computes attention using query, key, and value tensors. + Uses Flash Attention when available for better performance. + Args: + query: Query tensor. + key: Key tensor. + value: Value tensor. + attention_mask: Optional mask tensor. + cache_update_index: Index for sliding window computation. + Returns: + attention_output: Output tensor after applying attention. 
+ """ + if fused_attention_op_available(): + # Use `dot_product_attention` with Flash Attention support if + # available. + if attention_mask is not None: + attention_mask = ops.expand_dims(attention_mask, axis=1) + attention_mask = ops.cast(attention_mask, dtype="bool") + attention_output = ops.dot_product_attention( + query, + key, + value, + mask=attention_mask, + scale=self._inv_norm_factor, + ) + return attention_output + + attention_scores = ops.einsum(self._dot_product_equation, query, key) + + attention_scores = ops.multiply( + attention_scores, + ops.cast(self._inv_norm_factor, self.compute_dtype), + ) + if self.use_sliding_window_attention: + attention_mask = self._mask_sliding_window( + attention_mask, + cache_update_index=cache_update_index + if cache_update_index + else 0, + ) + attention_scores = self._masked_softmax( + attention_scores, attention_mask + ) + attention_scores = ops.cast(attention_scores, self.compute_dtype) + attention_output = ops.einsum( + self._combine_equation, attention_scores, value + ) + + return attention_output + + def _mask_sliding_window( + self, + attention_mask, + cache_update_index=0, + ): + """Creates and combines a sliding window mask with the attention mask. + Args: + attention_mask: Original attention mask. + cache_update_index: Starting index for the sliding window. + Returns: + Combined attention mask with sliding window constraints. + """ + _, query_len, key_len = ops.shape(attention_mask) + # Compute the sliding window for square attention. + all_ones = ops.ones((key_len, key_len), "bool") + if keras.config.backend() == "tensorflow": + # TODO: trui/tril has issues with dynamic shape on the tensorflow + # backend. We should fix, but use `band_part` for now. + import tensorflow as tf + + band_size = ops.minimum(key_len, self.sliding_window_size - 1) + band_size = ops.cast(band_size, "int32") + sliding_mask = tf.linalg.band_part(all_ones, band_size, band_size) + else: + sliding_mask = ops.triu( + all_ones, -1 * self.sliding_window_size + 1 + ) * ops.tril(all_ones, self.sliding_window_size - 1) + # Slice the window for short queries during generation. 
+ start = (cache_update_index, 0) + sliding_mask = ops.slice(sliding_mask, start, (query_len, key_len)) + sliding_mask = ops.expand_dims(sliding_mask, 0) + return ops.logical_and(attention_mask, ops.cast(sliding_mask, "bool")) + + def get_config(self): + config = super().get_config() + config.update( + { + "num_query_heads": self.num_query_heads, + "num_key_value_heads": self.num_key_value_heads, + "rope_max_wavelength": self.rope_max_wavelength, + "rope_scaling_factor": self.rope_scaling_factor, + "kernel_initializer": keras.initializers.serialize( + self.kernel_initializer + ), + "dropout": self.dropout, + "use_sliding_window_attention": ( + self.use_sliding_window_attention + ), + "sliding_window_size": self.sliding_window_size, + } + ) + return config diff --git a/keras_hub/src/models/qwen3_moe/qwen3_moe_backbone.py b/keras_hub/src/models/qwen3_moe/qwen3_moe_backbone.py new file mode 100644 index 0000000000..15ad9c9af2 --- /dev/null +++ b/keras_hub/src/models/qwen3_moe/qwen3_moe_backbone.py @@ -0,0 +1,364 @@ +import keras +from keras import ops + +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.layers.modeling.reversible_embedding import ( + ReversibleEmbedding, +) +from keras_hub.src.models.backbone import Backbone +from keras_hub.src.models.qwen.qwen_layernorm import QwenLayerNorm +from keras_hub.src.models.qwen_moe.qwen_moe_decoder import ( + Qwen3MoeTransformerDecoder, +) + + +def _qwen_moe_kernel_initializer(stddev=0.02): + return keras.initializers.RandomNormal(stddev=stddev) + + +@keras_hub_export( + "keras_hub.models.Qwen3MoeBackbone", +) +class Qwen3MoeBackbone(Backbone): + """Qwen MoE core network with hyperparameters. + + This backbone implements the base Transformer network for the Qwen MoE + model. It includes embedding lookups and transformer layers with a Mixture + of Experts (MoE) architecture, where each layer uses a sparse set of experts + for efficient computation. This backbone outputs the final hidden states for + each token, not generative predictions over the vocabulary space. For higher + -level object for text generation, see `keras_hub.models.Qwen3MoeCausalLM`. + + The default constructor gives a fully customizable, randomly initialized + Qwen MoE model with any number of layers, heads, and embedding dimensions. + To load preset architectures and weights, use the `from_preset` constructor. + + Args: + vocabulary_size: int. The size of the token vocabulary. + num_layers: int. The number of transformer layers. + num_query_heads: int. The number of heads for the query projections in + the attention layer. + num_key_value_heads: int. The number of heads for the key and value + projections in the attention layer. + hidden_dim: int. The size of the transformer hidden state at the end of + each transformer layer. + intermediate_dim: int. The output dimension of the first Dense layer in + the feedforward network for each transformer. + moe_intermediate_dim: int. The intermediate dimension for each expert + in the MoE feedforward network. + num_experts: int. The number of experts in each MoE layer. + top_k: int. The number of top experts to select for each token in the + MoE layer. + head_dim: int. The size of each attention head. + layer_norm_epsilon: float. The epsilon value used for every layer norm + in the transformer model. + dropout: float. Dropout probability for the transformer encoder. + use_sliding_window_attention: bool. Whether to use sliding local window + attention. Defaults to False. + sliding_window_size: int. 
Size of the sliding local window. Defaults to + 4096. + max_sequence_length: int. The maximum sequence length supported by the + model. Defaults to 4096. + dtype: str or `keras.mixed_precision.DTypePolicy`. The dtype to use for + the model's computations and weights. Note that some computations, + such as softmax and layer normalization, will always be done at + float32 precision regardless of dtype. + + Example: + ```python + input_data = { + "token_ids": np.ones(shape=(1, 12), dtype="int32"), + "padding_mask": np.array([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]), + } + + # Pretrained Qwen MoE decoder. + model = keras_hub.models.Qwen3MoeBackbone.from_preset("qwen_moe_a2_7b") + model(input_data) + + # Randomly initialized Qwen MoE decoder with custom config. + model = keras_hub.models.Qwen3MoeBackbone( + vocabulary_size=151936, + num_layers=28, + num_query_heads=16, + num_key_value_heads=8, + hidden_dim=2048, + intermediate_dim=4096, + moe_intermediate_dim=128, + num_experts=60, + top_k=4, + head_dim=128, + max_sequence_length=4096, + ) + model(input_data) + """ + + def __init__( + self, + vocabulary_size, + num_layers, + num_query_heads, + num_key_value_heads, + hidden_dim, + intermediate_dim, + moe_intermediate_dim, + num_experts, + top_k=4, + norm_top_k_prob=False, + decoder_sparse_step=1, + rope_max_wavelength=10000, + rope_scaling_factor=1.0, + layer_norm_epsilon=1e-6, + dropout=0, + dtype=None, + tie_word_embeddings=False, + use_sliding_window_attention=False, + sliding_window_size=32768, + output_router_logits=False, + router_aux_loss_coefficient=0.001, + mlp_only_layers=[], + training=None, + **kwargs, + ): + # === Layers === + self.token_embedding = ReversibleEmbedding( + input_dim=vocabulary_size, + output_dim=hidden_dim, + tie_weights=tie_word_embeddings, + embeddings_initializer=_qwen_moe_kernel_initializer(stddev=0.01), + dtype=dtype, + name="token_embedding", + ) + self.transformer_layers = [] + for i in range(num_layers): + layer = Qwen3MoeTransformerDecoder( + intermediate_dim=intermediate_dim, + num_query_heads=num_query_heads, + num_key_value_heads=num_key_value_heads, + moe_intermediate_dim=moe_intermediate_dim, + num_experts=num_experts, + top_k=top_k, + norm_top_k_prob=norm_top_k_prob, + decoder_sparse_step=decoder_sparse_step, + rope_max_wavelength=rope_max_wavelength, + rope_scaling_factor=rope_scaling_factor, + layer_norm_epsilon=layer_norm_epsilon, + activation=ops.silu, + kernel_initializer=_qwen_moe_kernel_initializer(stddev=0.02), + dropout=dropout, + dtype=dtype, + use_sliding_window_attention=use_sliding_window_attention, + sliding_window_size=sliding_window_size, + output_router_logits=output_router_logits, + router_aux_loss_coefficient=router_aux_loss_coefficient, + mlp_only_layers=mlp_only_layers, + name=f"transformer_layer_{i}", + ) + self.transformer_layers.append(layer) + self.layer_norm = QwenLayerNorm( + epsilon=layer_norm_epsilon, + dtype=dtype, + name="sequence_output_layernorm", + ) + + # === Functional Model === + token_id_input = keras.Input( + shape=(None,), dtype="int32", name="token_ids" + ) + padding_mask_input = keras.Input( + shape=(None,), dtype="int32", name="padding_mask" + ) + x = self.token_embedding(token_id_input) + for transformer_layer in self.transformer_layers: + x = transformer_layer( + x, decoder_padding_mask=padding_mask_input, training=training + ) + sequence_output = self.layer_norm(x) + super().__init__( + inputs={ + "token_ids": token_id_input, + "padding_mask": padding_mask_input, + }, + outputs=sequence_output, + dtype=dtype, 
+ **kwargs, + ) + + # === Config === + self.vocabulary_size = vocabulary_size + self.num_layers = num_layers + self.num_query_heads = num_query_heads + self.hidden_dim = hidden_dim + self.intermediate_dim = intermediate_dim + self.moe_intermediate_dim = moe_intermediate_dim + self.rope_max_wavelength = rope_max_wavelength + self.num_key_value_heads = num_key_value_heads + self.rope_scaling_factor = rope_scaling_factor + self.layer_norm_epsilon = layer_norm_epsilon + self.dropout = dropout + self.tie_word_embeddings = tie_word_embeddings + self.use_sliding_window_attention = use_sliding_window_attention + self.sliding_window_size = sliding_window_size + self.num_experts = num_experts + self.top_k = top_k + self.norm_top_k_prob = norm_top_k_prob + self.decoder_sparse_step = decoder_sparse_step + self.mlp_only_layers = mlp_only_layers + self.router_aux_loss_coefficient = router_aux_loss_coefficient + self.output_router_logits = output_router_logits + + def get_config(self): + config = super().get_config() + config.update( + { + "vocabulary_size": self.vocabulary_size, + "num_layers": self.num_layers, + "num_query_heads": self.num_query_heads, + "hidden_dim": self.hidden_dim, + "intermediate_dim": self.intermediate_dim, + "moe_intermediate_dim": self.moe_intermediate_dim, + "rope_max_wavelength": self.rope_max_wavelength, + "num_key_value_heads": self.num_key_value_heads, + "rope_scaling_factor": self.rope_scaling_factor, + "layer_norm_epsilon": self.layer_norm_epsilon, + "dropout": self.dropout, + "tie_word_embeddings": self.tie_word_embeddings, + "use_sliding_window_attention": ( + self.use_sliding_window_attention + ), + "sliding_window_size": self.sliding_window_size, + "num_experts": self.num_experts, + "top_k": self.top_k, + "norm_top_k_prob": self.norm_top_k_prob, + "decoder_sparse_step": self.decoder_sparse_step, + "mlp_only_layers": self.mlp_only_layers, + "output_router_logits": self.output_router_logits, + } + ) + return config + + @staticmethod + def get_layout_map( + device_mesh, + model_parallel_dim_name="model", + data_parallel_dim_name="batch", + ): + """Get a `keras.distribution.LayoutMap` for model parallel distribution. + + The returned `LayoutMap` contains the sharding spec for the Qwen3Moe + backbone weights, so that you can use it to distribute weights across + the accelerators. + + Example: + ``` + # Feel free to change the mesh shape to balance data and model + # parallelism + mesh = keras.distribution.DeviceMesh( + shape=(1, 8), + axis_names=('batch', 'model'), + devices=keras.distribution.list_devices(), + ) + layout_map = Qwen3MoeBackbone.get_layout_map( + mesh, + model_parallel_dim_name="model", + ) + + distribution = keras.distribution.ModelParallel( + layout_map=layout_map, + batch_dim_name='batch', + ) + + with distribution.scope(): + qwen_moe_model = keras_hub.models.Qwen3MoeBackbone.from_preset() + ``` + + To see how the layout map was applied, load the model then run + (for one decoder block): + ``` + embedding_layer = qwen_moe_model.backbone.get_layer("token_embedding") + decoder_block_1 = qwen_moe_model.backbone.get_layer( + 'transformer_layer_0' + ) + for variable in embedding_layer.weights + decoder_block_1.weights: + print( + f'{variable.path:<58} {str(variable.shape):<16} ' + f'{str(variable.value.sharding.spec)}' + ) + ``` + + Args: + device_mesh: The `keras.distribution.DeviceMesh` instance for + distribution. + model_parallel_dim_name: The axis name of the device mesh, where + the weights should be partition on. 
+ data_parallel_dim_name: The axis name of the device mesh, where + the data should be partition on. + Return: + `keras.distribution.LayoutMap` that contains the sharding spec + for all the model weights. + """ + # The weight path and shape of the Llama backbone is like below + # token_embedding/embeddings (128256, 2048) + # repeat block for decoder + # transformer_layer_0/self_attention/query/kernel (2048, 32, 64) + # transformer_layer_0/self_attention/key/kernel (2048, 8, 64) + # transformer_layer_0/self_attention/value/kernel (2048, 8, 64) + # transformer_layer_0/self_attention/attention_output/kernel + # (32, 64, 2048) + # transformer_layer_0/self_attention_layernorm/scale (2048,) + # transformer_layer_0/feedforward_intermediate_dense/kernel + # (2048, 8192) + # transformer_layer_0/feedforward_gate_dense/kernel (2048, 8192) + # transformer_layer_0/feedforward_output_dense/kerne (8192, 2048) + # transformer_layer_0/feedforward_layernorm/scale (2048,) + + if not isinstance(device_mesh, keras.distribution.DeviceMesh): + raise ValueError( + "Invalid device_mesh type. Expected " + f"`keras.distribution.Device`, got {type(device_mesh)}" + ) + if model_parallel_dim_name not in device_mesh.axis_names: + raise ValueError( + f"{model_parallel_dim_name} is not found in the " + f"device_mesh.axis_names. {device_mesh.axis_name=}" + ) + if data_parallel_dim_name not in device_mesh.axis_names: + raise ValueError( + f"{data_parallel_dim_name} is not found in the " + f"device_mesh.axis_names. {device_mesh.axis_name=}" + ) + # Note that it is possible to further config the mesh to be 3D, eg + # (data, seq, model). We leave it as 2D for now for simplicity. + data_dim = data_parallel_dim_name + model_dim = model_parallel_dim_name + # The sharding config is based on the Gemma team training config. + # See https://arxiv.org/abs/2403.08295 + layout_map = keras.distribution.LayoutMap(device_mesh) + layout_map["token_embedding/embeddings"] = (model_dim, data_dim) + layout_map[ + "transformer_layer.*self_attention.*(query|key|value).kernel" + ] = ( + model_dim, + data_dim, + None, + ) + layout_map["transformer_layer.*attention_output.kernel"] = ( + model_dim, + None, + data_dim, + ) + layout_map[ + "transformer_layer.*feedforward_intermediate_dense.kernel" + ] = ( + data_dim, + model_dim, + ) + layout_map["transformer_layer.*feedforward_gate_dense.kernel"] = ( + data_dim, + model_dim, + ) + layout_map["transformer_layer.*feedforward_output_dense.kernel"] = ( + model_dim, + data_dim, + ) + + return layout_map diff --git a/keras_hub/src/models/qwen3_moe/qwen3_moe_decoder.py b/keras_hub/src/models/qwen3_moe/qwen3_moe_decoder.py new file mode 100644 index 0000000000..2758bba531 --- /dev/null +++ b/keras_hub/src/models/qwen3_moe/qwen3_moe_decoder.py @@ -0,0 +1,598 @@ +import keras +from keras import ops + +from keras_hub.src.layers.modeling.transformer_layer_utils import ( + compute_causal_mask, +) +from keras_hub.src.layers.modeling.transformer_layer_utils import ( + merge_padding_and_attention_mask, +) +from keras_hub.src.models.qwen_moe.qwen_moe_attention import Qwen3MoeAttention +from keras_hub.src.models.qwen_moe.qwen_moe_layernorm import Qwen3MoeLayerNorm +from keras_hub.src.utils.keras_utils import clone_initializer + + +def compute_load_balancing_loss( + router_logits, num_experts, top_k, attention_mask=None +): + """ + Compute the load balancing auxiliary loss for a single MoE layer. + + Args: + router_logits: Tensor of shape (batch_size * seq_len, num_experts). 
+ num_experts: Integer, total number of experts. + top_k: Integer, number of experts to select per token. + attention_mask: Tensor of shape (batch_size, seq_len, seq_len), + optional mask for padding. + + Returns: + Scalar tensor representing the auxiliary loss. + """ + # Compute routing probabilities + routing_weights = ops.softmax( + router_logits, axis=-1 + ) # Shape: (batch_size * seq_len, num_experts) + + # Get top-k experts + _, selected_experts = ops.top_k( + routing_weights, k=top_k + ) # Shape: (batch_size * seq_len, top_k) + + # Create one-hot encoding for selected experts + expert_mask = ops.one_hot( + selected_experts, num_experts + ) # Shape: (batch_size * seq_len, top_k, num_experts) + + if attention_mask is not None: + # Convert attention mask to (batch_size, seq_len) + batch_size, seq_len, _ = ops.shape(attention_mask) + flat_mask = ops.any(attention_mask, axis=-1) + flat_mask = ops.reshape( + flat_mask, (-1,) + ) # Shape: (batch_size * seq_len,) + # Expand mask for broadcasting + expert_attention_mask = ops.expand_dims( + flat_mask, axis=-1 + ) # Shape: (batch_size * seq_len, 1) + expert_attention_mask = ops.cast(expert_attention_mask, dtype="float32") + + # Compute masked means + tokens_per_expert = ops.sum( + expert_mask * expert_attention_mask[:, None, :], axis=0 + ) / ops.maximum( + ops.sum(expert_attention_mask[:, None, :], axis=0), 1e-9 + ) # Shape: (top_k, num_experts) + router_prob_per_expert = ops.sum( + routing_weights * expert_attention_mask, axis=0 + ) / ops.maximum( + ops.sum(expert_attention_mask, axis=0), 1e-9 + ) # Shape: (num_experts,) + else: + # Unmasked means + tokens_per_expert = ops.mean( + expert_mask, axis=0 + ) # Shape: (top_k, num_experts) + router_prob_per_expert = ops.mean( + routing_weights, axis=0 + ) # Shape: (num_experts,) + + # Average over top_k dimension if necessary + tokens_per_expert = ops.mean( + tokens_per_expert, axis=0 + ) # Shape: (num_experts,) + + # Compute the loss + overall_loss = ops.sum(tokens_per_expert * router_prob_per_expert) + return overall_loss * num_experts + + +class Qwen3MoeMLP(keras.layers.Layer): + def __init__( + self, + intermediate_dim, + hidden_dim, + activation_fn="silu", + layer_norm_epsilon=1e-5, + kernel_initializer="glorot_uniform", + **kwargs, + ): + super().__init__(**kwargs) + self.intermediate_dim = intermediate_dim + self.hidden_dim = hidden_dim + self.activation_fn = activation_fn + self.kernel_initializer = kernel_initializer + self.layer_norm_epsilon = layer_norm_epsilon + + def build(self, decoder_sequence_shape): + # Feedforward layers. 
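+        # Gated (SwiGLU-style) MLP: the gate and intermediate projections
+        # run in parallel and are multiplied together before the output
+        # projection (see `call` below).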
+ self._feedforward_intermediate_dense = keras.layers.Dense( + self.intermediate_dim, + kernel_initializer=clone_initializer(self.kernel_initializer), + use_bias=False, + dtype=self.dtype_policy, + name="feedforward_intermediate_dense", + ) + self._feedforward_intermediate_dense.build(decoder_sequence_shape) + + self._feedforward_gate_dense = keras.layers.Dense( + self.intermediate_dim, + kernel_initializer=clone_initializer(self.kernel_initializer), + use_bias=False, + dtype=self.dtype_policy, + name="feedforward_gate_dense", + ) + self._feedforward_gate_dense.build(decoder_sequence_shape) + + self._feedforward_output_dense = keras.layers.Dense( + self.hidden_dim, + kernel_initializer=clone_initializer(self.kernel_initializer), + use_bias=False, + dtype=self.dtype_policy, + name="feedforward_output_dense", + ) + + self._feedforward_output_dense.build( + self._feedforward_gate_dense.compute_output_shape( + decoder_sequence_shape + ) + ) + + self.activation = keras.activations.get(self.activation_fn) + self.built = True + + def call(self, x): + gate_output = self._feedforward_gate_dense(x) + + # Note that we run the activation function in full 32-bit + # precision since this is what `torch.nn.functional.silu` + # does. Internally, `torch.nn.functional.silu` converts the + # inputs to float32, computes SiLU, and converts the outputs + # back to compute dtype. + # CPU Kernel: https://github.com/pytorch/pytorch/blob/35c493f2cf9b623bfdc7e6b34dc1cb39690a7919/aten/src/ATen/native/cpu/Activation.cpp#L1221-L1235 # noqa: E501 + # CUDA Kernel: https://github.com/pytorch/pytorch/blob/35c493f2cf9b623bfdc7e6b34dc1cb39690a7919/aten/src/ATen/native/cuda/ActivationSiluKernel.cu # noqa: E501 + gate_output = ops.cast(gate_output, "float32") + gate_output = self.activation(gate_output) + gate_output = ops.cast(gate_output, self.compute_dtype) + + x = self._feedforward_intermediate_dense(x) + + x = self._feedforward_output_dense(ops.multiply(x, gate_output)) + + return x + + +class Qwen3MoeExperts(keras.layers.Layer): + """Batched Experts Layer""" + + def __init__( + self, + num_experts, + hidden_dim, + intermediate_dim, + activation_fn="silu", + kernel_initializer="glorot_uniform", + **kwargs, + ): + super().__init__(**kwargs) + self.num_experts = num_experts + self.hidden_dim = hidden_dim + self.intermediate_dim = intermediate_dim + self.activation = keras.activations.get(activation_fn) + self.kernel_initializer = kernel_initializer + + def build(self, _): + self._expert_feedforward_gate_dense = self.add_weight( + shape=( + self.num_experts, + self.hidden_dim, + 2 * self.intermediate_dim, + ), + initializer=self.kernel_initializer, + trainable=True, + dtype=self.variable_dtype, + name="expert_feedforward_gate_dense", + ) + + self._expert_feedforward_output_dense = self.add_weight( + shape=(self.num_experts, self.intermediate_dim, self.hidden_dim), + initializer=self.kernel_initializer, + trainable=True, + dtype=self.variable_dtype, + name="expert_feedforward_output_dense", + ) + + self.built = True + + def call(self, hidden_states): + gate_up = ops.einsum( + "th,ehm->etm", hidden_states, self._expert_feedforward_gate_dense + ) + gate, up = ops.split(gate_up, 2, axis=-1) + hidden = up * self.activation(gate) + out = ops.einsum( + "eti,eih->eth", hidden, self._expert_feedforward_output_dense + ) + return out + + +class QwenSparseMoeBlock(keras.layers.Layer): + """Qwen-2 Sparse Moe Block""" + + def __init__( + self, + hidden_dim, + moe_intermediate_dim, + num_experts, + top_k, + norm_top_k_prob, + 
kernel_initializer="glorot_uniform", + layer_norm_epsilon=1e-5, + router_aux_loss_coefficient=0.01, + **kwargs, + ): + super().__init__(**kwargs) + self.hidden_dim = hidden_dim + self.intermediate_dim = moe_intermediate_dim + self.num_experts = num_experts + self.top_k = top_k + self.norm_top_k_prob = norm_top_k_prob + self.kernel_initializer = kernel_initializer + self.layer_norm_epsilon = layer_norm_epsilon + self.router_aux_loss_coefficient = router_aux_loss_coefficient + + def build(self, decoder_sequence_shape): + self._sparse_feedforward_gate_dense = keras.layers.Dense( + self.num_experts, + use_bias=False, + kernel_initializer=self.kernel_initializer, + name="sparse_feedforward_gate_dense", + dtype=self.dtype_policy, + ) + self._sparse_feedforward_gate_dense.build(decoder_sequence_shape) + + # NOTE: Experts are implemented as a single layer to enable efficient + # batched computation. Implementing each expert individually is + # currently avoided due to the lack of `ragged_dot` support in the + # Keras ops API, which would make individual implementations unstable + # and prone to bugs. + self.expert_bank = Qwen3MoeExperts( + num_experts=self.num_experts, + hidden_dim=self.hidden_dim, + intermediate_dim=self.intermediate_dim, + kernel_initializer=self.kernel_initializer, + name="experts", + dtype=self.dtype_policy, + ) + self.expert_bank.build(decoder_sequence_shape) + + self.built = True + + def call(self, hidden_states, attention_mask=None, training=None): + batch_size, seq_len, _ = ops.shape(hidden_states) + hidden_states_flattened = ops.reshape( + hidden_states, (-1, self.hidden_dim) + ) + + router_logits = self._sparse_feedforward_gate_dense( + hidden_states_flattened + ) + router_probs = ops.softmax(router_logits, axis=-1) + + top_p, top_i = ops.top_k(router_probs, k=self.top_k) + if self.norm_top_k_prob: + top_p = top_p / ops.sum(top_p, axis=-1, keepdims=True) + + one_hot = ops.one_hot(top_i, self.num_experts) + one_hot = ops.cast(one_hot, top_p.dtype) + routing_full = ops.sum(one_hot * top_p[..., None], axis=1) + routing_full = ops.transpose(routing_full, (1, 0)) + routing_full = ops.cast(routing_full, hidden_states_flattened.dtype) + + expert_out = self.expert_bank(hidden_states_flattened) + + weighted_out = expert_out * routing_full[:, :, None] + expert_contribution = ops.sum(weighted_out, axis=0) + + out = ops.reshape( + expert_contribution, (batch_size, seq_len, self.hidden_dim) + ) + + # Compute and add auxiliary loss during training + if training: + aux_loss = compute_load_balancing_loss( + router_logits=router_logits, + num_experts=self.num_experts, + top_k=self.top_k, + attention_mask=attention_mask, + ) + self.add_loss(self.router_aux_loss_coefficient * aux_loss) + + return out, router_logits + + +class Qwen3MoeTransformerDecoder(keras.layers.Layer): + def __init__( + self, + intermediate_dim, + num_query_heads, + num_key_value_heads, + moe_intermediate_dim, + shared_expert_intermediate_dim, + num_experts, + top_k, + norm_top_k_prob, + decoder_sparse_step, + rope_max_wavelength=10000, + rope_scaling_factor=1.0, + activation="silu", + layer_norm_epsilon=1e-5, + kernel_initializer="glorot_uniform", + dropout=0, + use_sliding_window_attention=False, + sliding_window_size=4096, + layer_index=0, + mlp_only_layers=[], + output_router_logits=False, + router_aux_loss_coefficient=0.001, + **kwargs, + ): + super().__init__(**kwargs) + self.intermediate_dim = intermediate_dim + self.num_query_heads = num_query_heads + self.num_key_value_heads = num_key_value_heads + 
self.rope_max_wavelength = rope_max_wavelength + self.rope_scaling_factor = rope_scaling_factor + self.dropout = dropout + self.use_sliding_window_attention = use_sliding_window_attention + self.sliding_window_size = sliding_window_size + self.activation = keras.activations.get(activation) + self.layer_norm_epsilon = layer_norm_epsilon + self.kernel_initializer = keras.initializers.get(kernel_initializer) + self.layer_index = layer_index + self.mlp_only_layers = mlp_only_layers + self.moe_intermediate_dim = moe_intermediate_dim + self.shared_expert_intermediate_dim = shared_expert_intermediate_dim + self.num_experts = num_experts + self.top_k = top_k + self.norm_top_k_prob = norm_top_k_prob + self.decoder_sparse_step = decoder_sparse_step + self.output_router_logits = output_router_logits + self.router_aux_loss_coefficient = router_aux_loss_coefficient + self.supports_masking = True + + def build(self, decoder_sequence_shape): + self._decoder_sequence_shape = decoder_sequence_shape + self.hidden_dim = decoder_sequence_shape[-1] + + # Self attention layer. + self._self_attention_layer = Qwen3MoeAttention( + num_query_heads=self.num_query_heads, + num_key_value_heads=self.num_key_value_heads, + rope_max_wavelength=self.rope_max_wavelength, + rope_scaling_factor=self.rope_scaling_factor, + kernel_initializer=clone_initializer(self.kernel_initializer), + dropout=self.dropout, + use_sliding_window_attention=self.use_sliding_window_attention, + sliding_window_size=self.sliding_window_size, + name="self_attention", + dtype=self.dtype_policy, + ) + self._self_attention_layer.build(decoder_sequence_shape) + + self._self_attention_layernorm = Qwen3MoeLayerNorm( + epsilon=self.layer_norm_epsilon, + dtype=self.dtype_policy, + name="self_attention_layernorm", + ) + + self._self_attention_layernorm.build(decoder_sequence_shape) + self._self_attention_dropout = keras.layers.Dropout( + rate=self.dropout, + dtype=self.dtype_policy, + name="self_attention_dropout", + ) + + # Feedforward layers. + if (self.layer_index not in self.mlp_only_layers) and ( + self.num_experts > 0 + and (self.layer_index + 1) % self.decoder_sparse_step == 0 + ): + self.mlp = QwenSparseMoeBlock( + hidden_dim=self.hidden_dim, + moe_intermediate_dim=self.moe_intermediate_dim, + shared_expert_intermediate_dim=self.shared_expert_intermediate_dim, + num_experts=self.num_experts, + top_k=self.top_k, + norm_top_k_prob=self.norm_top_k_prob, + router_aux_loss_coefficient=self.router_aux_loss_coefficient, + kernel_initializer=self.kernel_initializer, + dtype=self.dtype_policy, + ) + self.mlp.build(decoder_sequence_shape) + else: + self.mlp = Qwen3MoeMLP( + intermediate_dim=self.intermediate_dim, + hidden_dim=self.hidden_dim, + dtype=self.dtype_policy, + ) + self.mlp.build(decoder_sequence_shape) + + self._feedforward_layernorm = Qwen3MoeLayerNorm( + epsilon=self.layer_norm_epsilon, + dtype=self.dtype_policy, + name="feedforward_layernorm", + ) + self._feedforward_layernorm.build(decoder_sequence_shape) + + self.built = True + + def call( + self, + decoder_sequence, + decoder_padding_mask=None, + decoder_attention_mask=None, + self_attention_cache=None, + self_attention_cache_update_index=None, + training=None, + ): + """Forward pass for the decoder layer. + + Args: + decoder_sequence: Input tensor of shape [batch_size, seq_length, + hidden_size]. + decoder_padding_mask: Mask tensor for padding tokens. + decoder_attention_mask: Additional attention mask. + self_attention_cache: Optional cached key and value tensors for + self-attention. 
+ self_attention_cache_update_index: Index at which to update the + cache. + training: Boolean indicating whether in training mode. + + Returns: + decoder_output: Output tensor after applying transformer decoder + block. + self_attention_cache: Updated cache tensors (if cache is provided). + """ + self_attention_mask = self._compute_self_attention_mask( + decoder_sequence=decoder_sequence, + decoder_padding_mask=decoder_padding_mask, + decoder_attention_mask=decoder_attention_mask, + self_attention_cache=self_attention_cache, + self_attention_cache_update_index=self_attention_cache_update_index, + ) + residual = decoder_sequence + + x = self._self_attention_layernorm(decoder_sequence) + + # Self attention block. + x = self._self_attention_layer( + hidden_states=x, + attention_mask=self_attention_mask, + cache=self_attention_cache, + cache_update_index=self_attention_cache_update_index, + ) + + if self_attention_cache is not None: + x, self_attention_cache = x + + x = self._self_attention_dropout(x, training=training) + + x = x + residual + residual = x + + x = self._feedforward_layernorm(x) + if isinstance(self.mlp, QwenSparseMoeBlock): + x = self.mlp( + x, training=training, attention_mask=self_attention_mask + ) + else: + x = self.mlp(x) + if isinstance(x, tuple): + x, router_logits = x + else: + router_logits = None + + x = ops.cast(x, ops.dtype(residual)) + decoder_output = x + residual + + output = (decoder_output,) + + if self_attention_cache is not None: + output += (self_attention_cache,) + + if self.output_router_logits: + output += (router_logits,) + + return output[0] if len(output) == 1 else output + + def _compute_self_attention_mask( + self, + decoder_sequence, + decoder_padding_mask, + decoder_attention_mask, + self_attention_cache, + self_attention_cache_update_index, + ): + """Computes the self-attention mask combining causal, padding and + attention masks. + + Args: + decoder_sequence: Input tensor. + decoder_padding_mask: Mask tensor for padding tokens. + decoder_attention_mask: Additional attention mask. + self_attention_cache: Optional cached key and value tensors. + self_attention_cache_update_index: Index at which to update the + cache. + + Returns: + Combined attention mask tensor. + """ + decoder_mask = merge_padding_and_attention_mask( + decoder_sequence, decoder_padding_mask, decoder_attention_mask + ) + batch_size = ops.shape(decoder_sequence)[0] + input_length = output_length = ops.shape(decoder_sequence)[1] + # We need to handle a rectangular causal mask when doing cached + # decoding. For generative inference, `decoder_sequence` will + # generally be length 1, and `cache` will be the full generation length. + if self_attention_cache is not None: + input_length = ops.shape(self_attention_cache)[2] + + cache_update_index = ( + 0 + if self_attention_cache_update_index is None + else self_attention_cache_update_index + ) + + causal_mask = compute_causal_mask( + batch_size, input_length, output_length, cache_update_index + ) + + return ( + ops.minimum(decoder_mask, causal_mask) + if decoder_mask is not None + else causal_mask + ) + + def compute_output_shape(self, decoder_sequence_shape): + """Computes the output shape of the layer. + + Args: + decoder_sequence_shape: Shape of the decoder sequence input. + + Returns: + Output shape, which is the same as the input shape. + """ + return decoder_sequence_shape + + def get_config(self): + """Returns the config of the layer. + + Returns: + Dictionary containing the parameters used to initialize this layer. 
+ """ + config = super().get_config() + config.update( + { + "num_query_heads": self.num_query_heads, + "intermediate_dim": self.intermediate_dim, + "moe_intermediate_dim": self.moe_intermediate_dim, + "rope_max_wavelength": self.rope_max_wavelength, + "num_key_value_heads": self.num_key_value_heads, + "rope_scaling_factor": self.rope_scaling_factor, + "layer_norm_epsilon": self.layer_norm_epsilon, + "dropout": self.dropout, + "use_sliding_window_attention": ( + self.use_sliding_window_attention + ), + "sliding_window_size": self.sliding_window_size, + "num_experts": self.num_experts, + "top_k": self.top_k, + "norm_top_k_prob": self.norm_top_k_prob, + "decoder_sparse_step": self.decoder_sparse_step, + "mlp_only_layers": self.mlp_only_layers, + "output_router_logits": self.output_router_logits, + "router_aux_loss_coefficient": self.router_aux_loss_coefficient, + } + ) + return config diff --git a/keras_hub/src/models/qwen3_moe/qwen3_moe_layernorm.py b/keras_hub/src/models/qwen3_moe/qwen3_moe_layernorm.py new file mode 100644 index 0000000000..25e6002669 --- /dev/null +++ b/keras_hub/src/models/qwen3_moe/qwen3_moe_layernorm.py @@ -0,0 +1,32 @@ +import keras +from keras import ops + + +class Qwen3MoeLayerNorm(keras.layers.Layer): + """A normalization layer for Qwen that implements RMS normalization.""" + + def __init__(self, epsilon=1e-6, **kwargs): + super().__init__(**kwargs) + self.epsilon = epsilon + + def build(self, input_shape): + dim = input_shape[-1] + self.scale = self.add_weight( + name="scale", + trainable=True, + shape=(dim,), + initializer="ones", + dtype=self.variable_dtype, + ) + self.built = True + + def call(self, x): + x = ops.cast(x, "float32") + var = ops.mean(ops.power(x, 2), axis=-1, keepdims=True) + x = x * ops.rsqrt(var + self.epsilon) + return ops.cast(x * self.scale, self.compute_dtype) + + def get_config(self): + config = super().get_config() + config.update({"epsilon": self.epsilon}) + return config diff --git a/keras_hub/src/models/qwen3_moe/qwen3_moe_tokenizer.py b/keras_hub/src/models/qwen3_moe/qwen3_moe_tokenizer.py new file mode 100644 index 0000000000..fbb7404c56 --- /dev/null +++ b/keras_hub/src/models/qwen3_moe/qwen3_moe_tokenizer.py @@ -0,0 +1,46 @@ +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.models.qwen_moe.qwen_moe_backbone import Qwen3MoeBackbone +from keras_hub.src.tokenizers.byte_pair_tokenizer import BytePairTokenizer + + +@keras_hub_export( + "keras_hub.tokenizers.Qwen3MoeTokenizer", +) +class Qwen3MoeTokenizer(BytePairTokenizer): + """Tokenizer for Qwen Moe model. + + This tokenizer implements byte-pair encoding (BPE) for Qwen models, + handling special tokens like BOS (beginning of sequence) and EOS (end of + sequence). + + Args: + vocabulary: Dictionary mapping tokens to token IDs, or path to + vocabulary file. + merges: List of BPE merges, or path to merges file. + bos_token: Beginning of sequence token. Defaults to None. + eos_token: End of sequence token. Defaults to "<|endoftext|>". + misc_special_tokens: Set of additional special tokens. Defaults to + empty set. 
+ """ + + backbone_cls = Qwen3MoeBackbone + + def __init__( + self, + vocabulary=None, + merges=None, + **kwargs, + ): + # Add EOS token + eos_token = "<|endoftext|>" + self._add_special_token(eos_token, "end_token") + + self.start_token_id = None + self.start_token = None + self.pad_token_id = 0 + + super().__init__( + vocabulary=vocabulary, + merges=merges, + **kwargs, + ) diff --git a/keras_hub/src/utils/transformers/convert_qwen3_moe.py b/keras_hub/src/utils/transformers/convert_qwen3_moe.py new file mode 100644 index 0000000000..8e5ad4875b --- /dev/null +++ b/keras_hub/src/utils/transformers/convert_qwen3_moe.py @@ -0,0 +1,220 @@ +import numpy as np + +from keras_hub.src.models.qwen_moe.qwen_moe_backbone import QwenMoeBackbone +from keras_hub.src.utils.preset_utils import load_json + +backbone_cls = QwenMoeBackbone + + +def convert_backbone_config(transformers_config): + return { + "vocabulary_size": transformers_config["vocab_size"], + "hidden_dim": transformers_config["hidden_size"], + "num_layers": transformers_config["num_hidden_layers"], + "num_query_heads": transformers_config["num_attention_heads"], + "num_key_value_heads": transformers_config["num_key_value_heads"], + "intermediate_dim": transformers_config["intermediate_size"], + "moe_intermediate_dim": transformers_config["moe_intermediate_size"], + "num_experts": transformers_config["num_experts"], + "top_k": transformers_config["num_experts_per_tok"], + "norm_top_k_prob": transformers_config["norm_topk_prob"], + "decoder_sparse_step": transformers_config["decoder_sparse_step"], + "layer_norm_epsilon": transformers_config["rms_norm_eps"], + "rope_max_wavelength": transformers_config["rope_theta"], + "use_sliding_window": transformers_config["use_sliding_window"], + "sliding_window_size": transformers_config["sliding_window"], + "output_router_logits": transformers_config["output_router_logits"], + "router_aux_loss_coefficient": transformers_config[ + "router_aux_loss_coef" + ], + } + + +def convert_weights(backbone, loader, transformers_config): + loader.port_weight( + keras_variable=backbone.get_layer("token_embedding").embeddings, + hf_weight_key="model.embed_tokens.weight", + ) + if not backbone.tie_word_embeddings: + loader.port_weight( + keras_variable=backbone.get_layer( + "token_embedding" + ).reverse_embeddings, + hf_weight_key="lm_head.weight", + # rearrange_pattern="b a -> a b", + hook_fn=lambda hf_tensor, _: np.transpose(hf_tensor, axes=(1, 0)), + ) + + def transpose_and_reshape(x, shape): + return np.reshape(np.transpose(x), shape) + + for i in range(backbone.num_layers): + decoder_layer = backbone.get_layer(f"transformer_layer_{i}") + + # Input layernorm + loader.port_weight( + keras_variable=decoder_layer._self_attention_layernorm.scale, + hf_weight_key=f"model.layers.{i}.input_layernorm.weight", + ) + + # Attention layers + + ## Query + loader.port_weight( + keras_variable=decoder_layer._self_attention_layer._query_dense.kernel, + hf_weight_key=f"model.layers.{i}.self_attn.q_proj.weight", + hook_fn=transpose_and_reshape, + ) + loader.port_weight( + keras_variable=decoder_layer._self_attention_layer._query_dense.bias, + hf_weight_key=f"model.layers.{i}.self_attn.q_proj.bias", + hook_fn=transpose_and_reshape, + ) + ## Key + loader.port_weight( + keras_variable=decoder_layer._self_attention_layer._key_dense.kernel, + hf_weight_key=f"model.layers.{i}.self_attn.k_proj.weight", + hook_fn=transpose_and_reshape, + ) + loader.port_weight( + keras_variable=decoder_layer._self_attention_layer._key_dense.bias, + 
hf_weight_key=f"model.layers.{i}.self_attn.k_proj.bias", + hook_fn=transpose_and_reshape, + ) + ## Value + loader.port_weight( + keras_variable=decoder_layer._self_attention_layer._value_dense.kernel, + hf_weight_key=f"model.layers.{i}.self_attn.v_proj.weight", + hook_fn=transpose_and_reshape, + ) + loader.port_weight( + keras_variable=decoder_layer._self_attention_layer._value_dense.bias, + hf_weight_key=f"model.layers.{i}.self_attn.v_proj.bias", + hook_fn=transpose_and_reshape, + ) + ## Output + loader.port_weight( + keras_variable=decoder_layer._self_attention_layer._output_dense.kernel, + hf_weight_key=f"model.layers.{i}.self_attn.o_proj.weight", + # rearrange_patterns="c (a b) -> a b c", + # rearrange_dims={"a": backbone.num_query_heads}, + hook_fn=transpose_and_reshape, + ) + + # MLP layers + if ( + (i not in backbone.mlp_only_layers) + and backbone.num_experts > 0 + and ((i + 1) % backbone.decoder_sparse_step == 0) + ): + # MoE layers + loader.port_weight( + keras_variable=decoder_layer.mlp._sparse_feedforward_gate_dense.kernel, + hf_weight_key=f"model.layers.{i}.mlp.gate.weight", + # rearrange_patterns="b a -> a b", + hook_fn=lambda hf_tensor, _: np.transpose( + hf_tensor, axes=(1, 0) + ), + ) + # Batched experts: gate_up_proj and down_proj + gate_up_proj_list = [] + down_proj_list = [] + for expert_idx in range(backbone.num_experts): + # Load gate_proj and up_proj for each expert + gate_proj = loader.get_tensor( + f"model.layers.{i}.mlp.experts.{expert_idx}.gate_proj.weight" + ) + up_proj = loader.get_tensor( + f"model.layers.{i}.mlp.experts.{expert_idx}.up_proj.weight" + ) + # Transpose to (hidden_dim, intermediate_dim) + gate_proj = np.transpose(gate_proj, axes=(1, 0)) + up_proj = np.transpose(up_proj, axes=(1, 0)) + # Concatenate gate_proj and up_proj along the last dimension + gate_up_proj = np.concatenate([gate_proj, up_proj], axis=-1) + gate_up_proj_list.append(gate_up_proj) + + # Load down_proj for each expert + down_proj = loader.get_tensor( + f"model.layers.{i}.mlp.experts.{expert_idx}.down_proj.weight" + ) + down_proj = np.transpose( + down_proj, axes=(1, 0) + ) # (intermediate_dim, hidden_dim) + down_proj_list.append(down_proj) + + # Stack the lists to create batched weights + gate_up_proj_batched = np.stack( + gate_up_proj_list, axis=0 + ) # (num_experts, hidden_dim, 2 * intermediate_dim) + down_proj_batched = np.stack( + down_proj_list, axis=0 + ) # (num_experts, intermediate_dim, hidden_dim) + + # Assign batched weights to expert_bank + decoder_layer.mlp.expert_bank._expert_feedforward_gate_dense.assign( + gate_up_proj_batched + ) + decoder_layer.mlp.expert_bank._expert_feedforward_output_dense.assign( + down_proj_batched + ) + else: + loader.port_weight( + keras_variable=decoder_layer._feedforward_intermediate_dense.kernel, + hf_weight_key=f"model.layers.{i}.mlp.up_proj.weight", + # rearrange_patterns="b a -> a b", + hook_fn=lambda hf_tensor, _: np.transpose( + hf_tensor, axes=(1, 0) + ), + ) + loader.port_weight( + keras_variable=decoder_layer._feedforward_output_dense.kernel, + hf_weight_key=f"model.layers.{i}.mlp.down_proj.weight", + # rearrange_patterns="b a -> a b", + hook_fn=lambda hf_tensor, _: np.transpose( + hf_tensor, axes=(1, 0) + ), + ) + loader.port_weight( + keras_variable=decoder_layer._feedforward_gate_dense.kernel, + hf_weight_key=f"model.layers.{i}.mlp.gate_proj.weight", + # rearrange_patterns="b a -> a b", + hook_fn=lambda hf_tensor, _: np.transpose( + hf_tensor, axes=(1, 0) + ), + ) + + # Feedforward layernorm + loader.port_weight( + 
keras_variable=decoder_layer._feedforward_layernorm.scale, + hf_weight_key=f"model.layers.{i}.post_attention_layernorm.weight", + ) + + # Final normalization layer + loader.port_weight( + keras_variable=backbone.get_layer("sequence_output_layernorm").scale, + hf_weight_key="model.norm.weight", + ) + + return backbone + + +def convert_tokenizer(cls, preset, **kwargs): + tokenizer_config = load_json(preset, "tokenizer.json") + vocab = tokenizer_config["model"]["vocab"] + merges = tokenizer_config["model"]["merges"] + + # Load all special tokens with the exception of "reserved" ones. + special_tokens = set() + for token in tokenizer_config["added_tokens"]: + if not token["content"].startswith("<|reserved_special_token_"): + vocab[token["content"]] = token["id"] + special_tokens.add(token["content"]) + + kwargs.update( + { + "unsplittable_tokens": list(special_tokens), + } + ) + + return cls(vocabulary=vocab, merges=merges, **kwargs) diff --git a/tools/checkpoint_conversion/convert_qwen3_moe_checkpoints.py b/tools/checkpoint_conversion/convert_qwen3_moe_checkpoints.py new file mode 100644 index 0000000000..3c32e46cdf --- /dev/null +++ b/tools/checkpoint_conversion/convert_qwen3_moe_checkpoints.py @@ -0,0 +1,162 @@ +import os +import traceback + +os.environ["KERAS_BACKEND"] = "torch" +os.environ["CUDA_VISIBLE_DEVICES"] = "-1" # Hide any CUDA devices + +import numpy as np +import torch +from absl import app + +# from absl import flags + +device = torch.device("cpu") +# Force PyTorch to use CPU +torch.set_default_device(device) + +from keras import ops # noqa: E402 +from transformers import AutoModelForCausalLM # noqa: E402 +from transformers import AutoTokenizer # noqa: E402 + +import keras_hub # noqa: E402 + +PRESET_MAP = { + "qwen3_moe_3b_en": "Qwen/Qwen3-30B-A3B", +} + +# FLAGS = flags.FLAGS +# flags.DEFINE_string( +# "preset", None, f"Must be one of {','.join(PRESET_MAP.keys())}" +# ) + + +def test_model( + keras_hub_model, keras_hub_tokenizer, hf_model, hf_model_tokenizer +): + # First, test that the number of parameters match + keras_hub_params = keras_hub_model.count_params() + hf_params = hf_model.num_parameters() + assert keras_hub_params == hf_params + + # Test the outputs of both the models + hf_inputs = hf_model_tokenizer(["What is Keras?"], return_tensors="pt").to( + device + ) + hf_outputs = hf_model(**hf_inputs) + hf_output_logits = hf_outputs.logits.detach().cpu().float().numpy() + + keras_hub_preprocessor = keras_hub.models.QwenCausalLMPreprocessor( + keras_hub_tokenizer + ) + keras_hub_inputs = keras_hub_preprocessor( + ["What is Keras?"], sequence_length=5 + )[0] + keras_hub_inputs = {k: v.to(device) for k, v in keras_hub_inputs.items()} + + keras_hub_output = keras_hub_model(keras_hub_inputs) + keras_hub_logits = keras_hub_model.token_embedding( + keras_hub_output, reverse=True + ) + keras_hub_logits = ops.convert_to_numpy(keras_hub_logits) + + # High tolerence since bfloat16 is used as the default dtype for Qwen + + try: + np.testing.assert_allclose( + keras_hub_logits, hf_output_logits, atol=1e-4 + ) + except AssertionError as err: + print("\n") + print(traceback.format_exc()) + print(err.args[0]) + print("\n") + + +def test_tokenizer(keras_hub_tokenizer, hf_tokenizer): + hf_output = hf_tokenizer(["What is Keras?"], return_tensors="pt") + hf_output = hf_output["input_ids"].detach().cpu().numpy() + keras_hub_preprocessor = keras_hub.models.QwenCausalLMPreprocessor( + keras_hub_tokenizer + ) + keras_hub_output = keras_hub_preprocessor( + ["What is Keras?"], sequence_length=5 
+ ) + keras_hub_output = ops.convert_to_numpy(keras_hub_output[0]["token_ids"]) + + np.testing.assert_equal(keras_hub_output, hf_output) + + +def validate_output( + keras_hub_model, keras_hub_tokenizer, hf_model, hf_tokenizer +): + input_str = "What is Keras?" + length = 32 + + # KerasHub + preprocessor = keras_hub.models.QwenMoeCausalLMPreprocessor( + keras_hub_tokenizer + ) + qwen_moe_lm = keras_hub.models.QwenMoeCausalLM( + backbone=keras_hub_model, preprocessor=preprocessor + ) + + keras_output = qwen_moe_lm.generate([input_str], max_length=length) + keras_output = keras_output[0] + print("🔶 KerasHub output:", keras_output) + + # Transformers + hf_inputs = hf_tokenizer([input_str], return_tensors="pt").to(device) + outputs = hf_model.generate( + **hf_inputs, + max_length=length, # Match KerasHub's max_length + # do_sample=True, # Enable sampling (default in KerasHub for generate) + pad_token_id=hf_tokenizer.pad_token_id, + ) + hf_generated_text = hf_tokenizer.batch_decode( + outputs, skip_special_tokens=True + )[0] + print("🔶 Huggingface output:", hf_generated_text) + + +def main(_): + # === Get the preset name === + # if FLAGS.preset not in PRESET_MAP.keys(): + # raise ValueError( + # f"Invalid preset {FLAGS.preset}. Must be one " + # f"of {','.join(PRESET_MAP.keys())}" + # ) + # preset = FLAGS.preset + # hf_preset = PRESET_MAP[preset] + hf_preset = "Qwen/Qwen1.5-MoE-A2.7B" + + # === Load the Huggingface model === + hf_model = AutoModelForCausalLM.from_pretrained( + hf_preset, + device_map=device, + ) + hf_tokenizer = AutoTokenizer.from_pretrained(hf_preset, return_tensors="pt") + hf_model.eval() + + keras_hub_model = keras_hub.models.QwenMoeBackbone.from_preset( + f"hf://{hf_preset}" + ) + keras_hub_tokenizer = keras_hub.tokenizers.QwenMoeTokenizer.from_preset( + f"hf://{hf_preset}" + ) + + print("\n-> Huggingface model and tokenizer loaded") + + # === Check that the models and tokenizers outputs match === + test_tokenizer(keras_hub_tokenizer, hf_tokenizer) + test_model(keras_hub_model, keras_hub_tokenizer, hf_model, hf_tokenizer) + + # == Validate model.generate output == + validate_output( + keras_hub_model, keras_hub_tokenizer, hf_model, hf_tokenizer + ) + print("\n-> Tests passed!") + + +if __name__ == "__main__": + # flags.mark_flag_as_required("preset") + app.run(main) From 84043a3e913d37491a4f7a6014b2feb19d4970d8 Mon Sep 17 00:00:00 2001 From: Anshuman Mishra Date: Tue, 20 May 2025 14:43:59 +0000 Subject: [PATCH 2/7] bug fixes --- keras_hub/api/models/__init__.py | 6 ++++ keras_hub/api/tokenizers/__init__.py | 3 ++ .../qwen3_moe/qwen3_causal_lm_preprocessor.py | 4 +-- .../models/qwen3_moe/qwen3_moe_attention.py | 4 +-- .../models/qwen3_moe/qwen3_moe_backbone.py | 4 ++- .../src/models/qwen3_moe/qwen3_moe_decoder.py | 17 ++++----- .../models/qwen3_moe/qwen3_moe_layernorm.py | 14 +++++--- .../models/qwen3_moe/qwen3_moe_tokenizer.py | 2 +- .../utils/transformers/convert_qwen3_moe.py | 36 ++++++++++--------- .../src/utils/transformers/preset_loader.py | 4 ++- .../convert_qwen3_moe_checkpoints.py | 16 ++++----- 11 files changed, 66 insertions(+), 44 deletions(-) diff --git a/keras_hub/api/models/__init__.py b/keras_hub/api/models/__init__.py index b831f7be2e..44f74727f1 100644 --- a/keras_hub/api/models/__init__.py +++ b/keras_hub/api/models/__init__.py @@ -432,6 +432,12 @@ from keras_hub.src.models.qwen.qwen_tokenizer import ( QwenTokenizer as QwenTokenizer, ) +from keras_hub.src.models.qwen3_moe.qwen3_causal_lm_preprocessor import ( + Qwen3MoeCausalLMPreprocessor as 
Qwen3MoeCausalLMPreprocessor, +) +from keras_hub.src.models.qwen3_moe.qwen3_moe_backbone import ( + Qwen3MoeBackbone as Qwen3MoeBackbone, +) from keras_hub.src.models.qwen_moe.qwen_moe_backbone import ( QwenMoeBackbone as QwenMoeBackbone, ) diff --git a/keras_hub/api/tokenizers/__init__.py b/keras_hub/api/tokenizers/__init__.py index 8d497fff86..769e347bdd 100644 --- a/keras_hub/api/tokenizers/__init__.py +++ b/keras_hub/api/tokenizers/__init__.py @@ -71,6 +71,9 @@ from keras_hub.src.models.qwen.qwen_tokenizer import ( QwenTokenizer as QwenTokenizer, ) +from keras_hub.src.models.qwen3_moe.qwen3_moe_tokenizer import ( + Qwen3MoeTokenizer as Qwen3MoeTokenizer, +) from keras_hub.src.models.qwen_moe.qwen_moe_tokenizer import ( QwenMoeTokenizer as QwenMoeTokenizer, ) diff --git a/keras_hub/src/models/qwen3_moe/qwen3_causal_lm_preprocessor.py b/keras_hub/src/models/qwen3_moe/qwen3_causal_lm_preprocessor.py index 34d5bf87ab..6434b959ce 100644 --- a/keras_hub/src/models/qwen3_moe/qwen3_causal_lm_preprocessor.py +++ b/keras_hub/src/models/qwen3_moe/qwen3_causal_lm_preprocessor.py @@ -1,7 +1,7 @@ from keras_hub.src.api_export import keras_hub_export from keras_hub.src.models.causal_lm_preprocessor import CausalLMPreprocessor -from keras_hub.src.models.qwen_moe.qwen_moe_backbone import Qwen3MoeBackbone -from keras_hub.src.models.qwen_moe.qwen_moe_tokenizer import Qwen3MoeTokenizer +from keras_hub.src.models.qwen3_moe.qwen3_moe_backbone import Qwen3MoeBackbone +from keras_hub.src.models.qwen3_moe.qwen3_moe_tokenizer import Qwen3MoeTokenizer @keras_hub_export( diff --git a/keras_hub/src/models/qwen3_moe/qwen3_moe_attention.py b/keras_hub/src/models/qwen3_moe/qwen3_moe_attention.py index 5cf2b1a9f0..da93e5d8e3 100644 --- a/keras_hub/src/models/qwen3_moe/qwen3_moe_attention.py +++ b/keras_hub/src/models/qwen3_moe/qwen3_moe_attention.py @@ -100,7 +100,7 @@ def build(self, inputs_shape): self._query_dense_layer_norm = Qwen3MoeLayerNorm( epsilon=self.layer_norm_epsilon, dtype=self.dtype_policy, - head_dim=self.head_dim, + hidden_dim=self.head_dim, name="query_dense_layernorm", ) self._query_dense_layer_norm.build(inputs_shape) @@ -121,7 +121,7 @@ def build(self, inputs_shape): self._key_dense_layer_norm = Qwen3MoeLayerNorm( epsilon=self.layer_norm_epsilon, dtype=self.dtype_policy, - head_dim=self.head_dim, + hidden_dim=self.head_dim, name="key_dense_layernorm", ) self._key_dense_layer_norm.build(inputs_shape) diff --git a/keras_hub/src/models/qwen3_moe/qwen3_moe_backbone.py b/keras_hub/src/models/qwen3_moe/qwen3_moe_backbone.py index 15ad9c9af2..8a659edb5c 100644 --- a/keras_hub/src/models/qwen3_moe/qwen3_moe_backbone.py +++ b/keras_hub/src/models/qwen3_moe/qwen3_moe_backbone.py @@ -7,7 +7,7 @@ ) from keras_hub.src.models.backbone import Backbone from keras_hub.src.models.qwen.qwen_layernorm import QwenLayerNorm -from keras_hub.src.models.qwen_moe.qwen_moe_decoder import ( +from keras_hub.src.models.qwen3_moe.qwen3_moe_decoder import ( Qwen3MoeTransformerDecoder, ) @@ -98,6 +98,7 @@ def __init__( num_layers, num_query_heads, num_key_value_heads, + head_dim, hidden_dim, intermediate_dim, moe_intermediate_dim, @@ -135,6 +136,7 @@ def __init__( num_query_heads=num_query_heads, num_key_value_heads=num_key_value_heads, moe_intermediate_dim=moe_intermediate_dim, + head_dim=head_dim, num_experts=num_experts, top_k=top_k, norm_top_k_prob=norm_top_k_prob, diff --git a/keras_hub/src/models/qwen3_moe/qwen3_moe_decoder.py b/keras_hub/src/models/qwen3_moe/qwen3_moe_decoder.py index 2758bba531..351e05db1c 100644 --- 
a/keras_hub/src/models/qwen3_moe/qwen3_moe_decoder.py +++ b/keras_hub/src/models/qwen3_moe/qwen3_moe_decoder.py @@ -7,8 +7,8 @@ from keras_hub.src.layers.modeling.transformer_layer_utils import ( merge_padding_and_attention_mask, ) -from keras_hub.src.models.qwen_moe.qwen_moe_attention import Qwen3MoeAttention -from keras_hub.src.models.qwen_moe.qwen_moe_layernorm import Qwen3MoeLayerNorm +from keras_hub.src.models.qwen3_moe.qwen3_moe_attention import Qwen3MoeAttention +from keras_hub.src.models.qwen3_moe.qwen3_moe_layernorm import Qwen3MoeLayerNorm from keras_hub.src.utils.keras_utils import clone_initializer @@ -215,7 +215,7 @@ def call(self, hidden_states): return out -class QwenSparseMoeBlock(keras.layers.Layer): +class Qwen3SparseMoeBlock(keras.layers.Layer): """Qwen-2 Sparse Moe Block""" def __init__( @@ -315,9 +315,9 @@ def __init__( self, intermediate_dim, num_query_heads, + head_dim, num_key_value_heads, moe_intermediate_dim, - shared_expert_intermediate_dim, num_experts, top_k, norm_top_k_prob, @@ -351,7 +351,7 @@ def __init__( self.layer_index = layer_index self.mlp_only_layers = mlp_only_layers self.moe_intermediate_dim = moe_intermediate_dim - self.shared_expert_intermediate_dim = shared_expert_intermediate_dim + self.head_dim = head_dim self.num_experts = num_experts self.top_k = top_k self.norm_top_k_prob = norm_top_k_prob @@ -367,6 +367,7 @@ def build(self, decoder_sequence_shape): # Self attention layer. self._self_attention_layer = Qwen3MoeAttention( num_query_heads=self.num_query_heads, + head_dim=self.head_dim, num_key_value_heads=self.num_key_value_heads, rope_max_wavelength=self.rope_max_wavelength, rope_scaling_factor=self.rope_scaling_factor, @@ -374,6 +375,7 @@ def build(self, decoder_sequence_shape): dropout=self.dropout, use_sliding_window_attention=self.use_sliding_window_attention, sliding_window_size=self.sliding_window_size, + layer_index=self.layer_index, name="self_attention", dtype=self.dtype_policy, ) @@ -397,10 +399,9 @@ def build(self, decoder_sequence_shape): self.num_experts > 0 and (self.layer_index + 1) % self.decoder_sparse_step == 0 ): - self.mlp = QwenSparseMoeBlock( + self.mlp = Qwen3SparseMoeBlock( hidden_dim=self.hidden_dim, moe_intermediate_dim=self.moe_intermediate_dim, - shared_expert_intermediate_dim=self.shared_expert_intermediate_dim, num_experts=self.num_experts, top_k=self.top_k, norm_top_k_prob=self.norm_top_k_prob, @@ -481,7 +482,7 @@ def call( residual = x x = self._feedforward_layernorm(x) - if isinstance(self.mlp, QwenSparseMoeBlock): + if isinstance(self.mlp, Qwen3SparseMoeBlock): x = self.mlp( x, training=training, attention_mask=self_attention_mask ) diff --git a/keras_hub/src/models/qwen3_moe/qwen3_moe_layernorm.py b/keras_hub/src/models/qwen3_moe/qwen3_moe_layernorm.py index 25e6002669..1c232cdf9e 100644 --- a/keras_hub/src/models/qwen3_moe/qwen3_moe_layernorm.py +++ b/keras_hub/src/models/qwen3_moe/qwen3_moe_layernorm.py @@ -5,12 +5,17 @@ class Qwen3MoeLayerNorm(keras.layers.Layer): """A normalization layer for Qwen that implements RMS normalization.""" - def __init__(self, epsilon=1e-6, **kwargs): + def __init__(self, hidden_dim=None, epsilon=1e-6, **kwargs): super().__init__(**kwargs) + self.hidden_dim = hidden_dim self.epsilon = epsilon def build(self, input_shape): - dim = input_shape[-1] + if self.hidden_dim: + dim = self.hidden_dim + else: + dim = input_shape[-1] + self.scale = self.add_weight( name="scale", trainable=True, @@ -21,12 +26,13 @@ def build(self, input_shape): self.built = True def call(self, x): + 
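For reference, the RMS normalization that Qwen3MoeLayerNorm computes in the call below is only a few lines; a numpy sketch of the same math (toy shapes, placeholder scale) is:

import numpy as np

def rms_norm(x, scale, eps=1e-6):
    # Statistics in float32, result cast back to the input dtype,
    # mirroring the layer's call().
    x32 = x.astype("float32")
    mean_square = np.mean(np.square(x32), axis=-1, keepdims=True)
    normed = x32 / np.sqrt(mean_square + eps)
    return (normed * scale).astype(x.dtype)

x = np.random.rand(2, 4, 8).astype("float16")
print(rms_norm(x, scale=np.ones(8, dtype="float32")).dtype)  # float16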
input_dtype = x.dtype x = ops.cast(x, "float32") var = ops.mean(ops.power(x, 2), axis=-1, keepdims=True) x = x * ops.rsqrt(var + self.epsilon) - return ops.cast(x * self.scale, self.compute_dtype) + return ops.cast(x * self.scale, input_dtype) def get_config(self): config = super().get_config() config.update({"epsilon": self.epsilon}) - return config + return config \ No newline at end of file diff --git a/keras_hub/src/models/qwen3_moe/qwen3_moe_tokenizer.py b/keras_hub/src/models/qwen3_moe/qwen3_moe_tokenizer.py index fbb7404c56..a4e4cc946f 100644 --- a/keras_hub/src/models/qwen3_moe/qwen3_moe_tokenizer.py +++ b/keras_hub/src/models/qwen3_moe/qwen3_moe_tokenizer.py @@ -1,5 +1,5 @@ from keras_hub.src.api_export import keras_hub_export -from keras_hub.src.models.qwen_moe.qwen_moe_backbone import Qwen3MoeBackbone +from keras_hub.src.models.qwen3_moe.qwen3_moe_backbone import Qwen3MoeBackbone from keras_hub.src.tokenizers.byte_pair_tokenizer import BytePairTokenizer diff --git a/keras_hub/src/utils/transformers/convert_qwen3_moe.py b/keras_hub/src/utils/transformers/convert_qwen3_moe.py index 8e5ad4875b..083ac81aa3 100644 --- a/keras_hub/src/utils/transformers/convert_qwen3_moe.py +++ b/keras_hub/src/utils/transformers/convert_qwen3_moe.py @@ -1,15 +1,16 @@ import numpy as np -from keras_hub.src.models.qwen_moe.qwen_moe_backbone import QwenMoeBackbone +from keras_hub.src.models.qwen3_moe.qwen3_moe_backbone import Qwen3MoeBackbone from keras_hub.src.utils.preset_utils import load_json -backbone_cls = QwenMoeBackbone +backbone_cls = Qwen3MoeBackbone def convert_backbone_config(transformers_config): return { "vocabulary_size": transformers_config["vocab_size"], "hidden_dim": transformers_config["hidden_size"], + "head_dim": transformers_config['head_dim'], "num_layers": transformers_config["num_hidden_layers"], "num_query_heads": transformers_config["num_attention_heads"], "num_key_value_heads": transformers_config["num_key_value_heads"], @@ -65,33 +66,33 @@ def transpose_and_reshape(x, shape): hf_weight_key=f"model.layers.{i}.self_attn.q_proj.weight", hook_fn=transpose_and_reshape, ) - loader.port_weight( - keras_variable=decoder_layer._self_attention_layer._query_dense.bias, - hf_weight_key=f"model.layers.{i}.self_attn.q_proj.bias", - hook_fn=transpose_and_reshape, - ) + # loader.port_weight( + # keras_variable=decoder_layer._self_attention_layer._query_dense.bias, + # hf_weight_key=f"model.layers.{i}.self_attn.q_proj.bias", + # hook_fn=transpose_and_reshape, + # ) ## Key loader.port_weight( keras_variable=decoder_layer._self_attention_layer._key_dense.kernel, hf_weight_key=f"model.layers.{i}.self_attn.k_proj.weight", hook_fn=transpose_and_reshape, ) - loader.port_weight( - keras_variable=decoder_layer._self_attention_layer._key_dense.bias, - hf_weight_key=f"model.layers.{i}.self_attn.k_proj.bias", - hook_fn=transpose_and_reshape, - ) + # loader.port_weight( + # keras_variable=decoder_layer._self_attention_layer._key_dense.bias, + # hf_weight_key=f"model.layers.{i}.self_attn.k_proj.bias", + # hook_fn=transpose_and_reshape, + # ) ## Value loader.port_weight( keras_variable=decoder_layer._self_attention_layer._value_dense.kernel, hf_weight_key=f"model.layers.{i}.self_attn.v_proj.weight", hook_fn=transpose_and_reshape, ) - loader.port_weight( - keras_variable=decoder_layer._self_attention_layer._value_dense.bias, - hf_weight_key=f"model.layers.{i}.self_attn.v_proj.bias", - hook_fn=transpose_and_reshape, - ) + # loader.port_weight( + # 
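The q/k/v kernels above are ported through a transpose_and_reshape hook, and the bias ports are dropped since the Qwen3 attention projections carry no bias. Assuming the hook transposes the Hugging Face (out_features, in_features) matrix and reshapes it to the Keras per-head kernel shape (hidden_dim, num_heads, head_dim), the mapping for the query projection looks like this numpy sketch with toy sizes:

import numpy as np

hidden_dim, num_heads, head_dim = 16, 4, 2

# Hugging Face q_proj.weight layout: (num_heads * head_dim, hidden_dim).
hf_kernel = np.random.rand(num_heads * head_dim, hidden_dim).astype("float32")

# Keras-side kernel shape assumed here: (hidden_dim, num_heads, head_dim).
keras_kernel = hf_kernel.T.reshape(hidden_dim, num_heads, head_dim)
print(keras_kernel.shape)  # (16, 4, 2)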
keras_variable=decoder_layer._self_attention_layer._value_dense.bias, + # hf_weight_key=f"model.layers.{i}.self_attn.v_proj.bias", + # hook_fn=transpose_and_reshape, + # ) ## Output loader.port_weight( keras_variable=decoder_layer._self_attention_layer._output_dense.kernel, @@ -203,6 +204,7 @@ def convert_tokenizer(cls, preset, **kwargs): tokenizer_config = load_json(preset, "tokenizer.json") vocab = tokenizer_config["model"]["vocab"] merges = tokenizer_config["model"]["merges"] + merges = [" ".join(item) for item in merges] # Load all special tokens with the exception of "reserved" ones. special_tokens = set() diff --git a/keras_hub/src/utils/transformers/preset_loader.py b/keras_hub/src/utils/transformers/preset_loader.py index 1c126bcbb1..cfae088e3d 100644 --- a/keras_hub/src/utils/transformers/preset_loader.py +++ b/keras_hub/src/utils/transformers/preset_loader.py @@ -3,7 +3,7 @@ from keras_hub.src.models.image_classifier import ImageClassifier from keras_hub.src.utils.preset_utils import PresetLoader from keras_hub.src.utils.preset_utils import jax_memory_cleanup -from keras_hub.src.utils.transformers import convert_albert +from keras_hub.src.utils.transformers import convert_albert, convert_qwen3_moe from keras_hub.src.utils.transformers import convert_bart from keras_hub.src.utils.transformers import convert_bert from keras_hub.src.utils.transformers import convert_distilbert @@ -50,6 +50,8 @@ def __init__(self, preset, config): self.converter = convert_mixtral elif model_type == "qwen2_moe": self.converter = convert_qwen_moe + elif model_type == "qwen3_moe": + self.converter = convert_qwen3_moe else: raise ValueError( "KerasHub has no converter for huggingface/transformers models " diff --git a/tools/checkpoint_conversion/convert_qwen3_moe_checkpoints.py b/tools/checkpoint_conversion/convert_qwen3_moe_checkpoints.py index 3c32e46cdf..fba956a516 100644 --- a/tools/checkpoint_conversion/convert_qwen3_moe_checkpoints.py +++ b/tools/checkpoint_conversion/convert_qwen3_moe_checkpoints.py @@ -93,10 +93,10 @@ def validate_output( length = 32 # KerasHub - preprocessor = keras_hub.models.QwenMoeCausalLMPreprocessor( + preprocessor = keras_hub.models.Qwen3MoeCausalLMPreprocessor( keras_hub_tokenizer ) - qwen_moe_lm = keras_hub.models.QwenMoeCausalLM( + qwen_moe_lm = keras_hub.models.Qwen3MoeCausalLM( backbone=keras_hub_model, preprocessor=preprocessor ) @@ -127,7 +127,7 @@ def main(_): # ) # preset = FLAGS.preset # hf_preset = PRESET_MAP[preset] - hf_preset = "Qwen/Qwen1.5-MoE-A2.7B" + hf_preset = "Qwen/Qwen3-30B-A3B" # === Load the Huggingface model === hf_model = AutoModelForCausalLM.from_pretrained( @@ -137,10 +137,10 @@ def main(_): hf_tokenizer = AutoTokenizer.from_pretrained(hf_preset, return_tensors="pt") hf_model.eval() - keras_hub_model = keras_hub.models.QwenMoeBackbone.from_preset( + keras_hub_model = keras_hub.models.Qwen3MoeBackbone.from_preset( f"hf://{hf_preset}" ) - keras_hub_tokenizer = keras_hub.tokenizers.QwenMoeTokenizer.from_preset( + keras_hub_tokenizer = keras_hub.tokenizers.Qwen3MoeTokenizer.from_preset( f"hf://{hf_preset}" ) @@ -151,9 +151,9 @@ def main(_): test_model(keras_hub_model, keras_hub_tokenizer, hf_model, hf_tokenizer) # == Validate model.generate output == - validate_output( - keras_hub_model, keras_hub_tokenizer, hf_model, hf_tokenizer - ) + # validate_output( + # keras_hub_model, keras_hub_tokenizer, hf_model, hf_tokenizer + # ) print("\n-> Tests passed!") From 750412ce01fdaad9158c17a4282d5e63bd7e481a Mon Sep 17 00:00:00 2001 From: Anshuman Mishra Date: 
Tue, 20 May 2025 14:45:39 +0000 Subject: [PATCH 3/7] update --- .../src/models/qwen3_moe/qwen3_moe_layernorm.py | 2 +- .../src/utils/transformers/convert_qwen3_moe.py | 17 +---------------- .../src/utils/transformers/preset_loader.py | 3 ++- 3 files changed, 4 insertions(+), 18 deletions(-) diff --git a/keras_hub/src/models/qwen3_moe/qwen3_moe_layernorm.py b/keras_hub/src/models/qwen3_moe/qwen3_moe_layernorm.py index 1c232cdf9e..c21da3cca6 100644 --- a/keras_hub/src/models/qwen3_moe/qwen3_moe_layernorm.py +++ b/keras_hub/src/models/qwen3_moe/qwen3_moe_layernorm.py @@ -35,4 +35,4 @@ def call(self, x): def get_config(self): config = super().get_config() config.update({"epsilon": self.epsilon}) - return config \ No newline at end of file + return config diff --git a/keras_hub/src/utils/transformers/convert_qwen3_moe.py b/keras_hub/src/utils/transformers/convert_qwen3_moe.py index 083ac81aa3..1d60977707 100644 --- a/keras_hub/src/utils/transformers/convert_qwen3_moe.py +++ b/keras_hub/src/utils/transformers/convert_qwen3_moe.py @@ -10,7 +10,7 @@ def convert_backbone_config(transformers_config): return { "vocabulary_size": transformers_config["vocab_size"], "hidden_dim": transformers_config["hidden_size"], - "head_dim": transformers_config['head_dim'], + "head_dim": transformers_config["head_dim"], "num_layers": transformers_config["num_hidden_layers"], "num_query_heads": transformers_config["num_attention_heads"], "num_key_value_heads": transformers_config["num_key_value_heads"], @@ -66,33 +66,18 @@ def transpose_and_reshape(x, shape): hf_weight_key=f"model.layers.{i}.self_attn.q_proj.weight", hook_fn=transpose_and_reshape, ) - # loader.port_weight( - # keras_variable=decoder_layer._self_attention_layer._query_dense.bias, - # hf_weight_key=f"model.layers.{i}.self_attn.q_proj.bias", - # hook_fn=transpose_and_reshape, - # ) ## Key loader.port_weight( keras_variable=decoder_layer._self_attention_layer._key_dense.kernel, hf_weight_key=f"model.layers.{i}.self_attn.k_proj.weight", hook_fn=transpose_and_reshape, ) - # loader.port_weight( - # keras_variable=decoder_layer._self_attention_layer._key_dense.bias, - # hf_weight_key=f"model.layers.{i}.self_attn.k_proj.bias", - # hook_fn=transpose_and_reshape, - # ) ## Value loader.port_weight( keras_variable=decoder_layer._self_attention_layer._value_dense.kernel, hf_weight_key=f"model.layers.{i}.self_attn.v_proj.weight", hook_fn=transpose_and_reshape, ) - # loader.port_weight( - # keras_variable=decoder_layer._self_attention_layer._value_dense.bias, - # hf_weight_key=f"model.layers.{i}.self_attn.v_proj.bias", - # hook_fn=transpose_and_reshape, - # ) ## Output loader.port_weight( keras_variable=decoder_layer._self_attention_layer._output_dense.kernel, diff --git a/keras_hub/src/utils/transformers/preset_loader.py b/keras_hub/src/utils/transformers/preset_loader.py index cfae088e3d..1574d21a3b 100644 --- a/keras_hub/src/utils/transformers/preset_loader.py +++ b/keras_hub/src/utils/transformers/preset_loader.py @@ -3,7 +3,7 @@ from keras_hub.src.models.image_classifier import ImageClassifier from keras_hub.src.utils.preset_utils import PresetLoader from keras_hub.src.utils.preset_utils import jax_memory_cleanup -from keras_hub.src.utils.transformers import convert_albert, convert_qwen3_moe +from keras_hub.src.utils.transformers import convert_albert from keras_hub.src.utils.transformers import convert_bart from keras_hub.src.utils.transformers import convert_bert from keras_hub.src.utils.transformers import convert_distilbert @@ -14,6 +14,7 @@ from 
keras_hub.src.utils.transformers import convert_mixtral from keras_hub.src.utils.transformers import convert_pali_gemma from keras_hub.src.utils.transformers import convert_qwen +from keras_hub.src.utils.transformers import convert_qwen3_moe from keras_hub.src.utils.transformers import convert_qwen_moe from keras_hub.src.utils.transformers import convert_vit from keras_hub.src.utils.transformers.safetensor_utils import SafetensorLoader From 6b741719f4f2f843190a57a6c626ca71a7f0e964 Mon Sep 17 00:00:00 2001 From: Anshuman Mishra Date: Sat, 24 May 2025 17:14:38 +0000 Subject: [PATCH 4/7] address comments --- .../models/qwen3_moe/qwen3_moe_attention.py | 18 +--------------- .../models/qwen3_moe/qwen3_moe_backbone.py | 12 ----------- .../src/models/qwen3_moe/qwen3_moe_decoder.py | 21 +++---------------- 3 files changed, 4 insertions(+), 47 deletions(-) diff --git a/keras_hub/src/models/qwen3_moe/qwen3_moe_attention.py b/keras_hub/src/models/qwen3_moe/qwen3_moe_attention.py index da93e5d8e3..78467dd9ca 100644 --- a/keras_hub/src/models/qwen3_moe/qwen3_moe_attention.py +++ b/keras_hub/src/models/qwen3_moe/qwen3_moe_attention.py @@ -22,8 +22,6 @@ class Qwen3MoeAttention(keras.layers.Layer): context length. kernel_initializer: Initializer for the kernel weights. dropout: Dropout rate for attention weights. - use_sliding_window_attention: Whether to use sliding window - attention. sliding_window_size: Size of the sliding window for attention. **kwargs: Additional keyword arguments to pass to the Layer. """ @@ -38,7 +36,6 @@ def __init__( rope_scaling_factor=1, kernel_initializer="glorot_uniform", dropout=0, - use_sliding_window_attention=False, layer_norm_epsilon=1e-5, sliding_window_size=4096, max_window_layers=28, @@ -63,16 +60,6 @@ def __init__( self.layer_index = layer_index self.rope_scaling_factor = rope_scaling_factor - self.use_sliding_window_attention = use_sliding_window_attention - - if ( - not self.use_sliding_window_attention - and sliding_window_size - and self.layer_index >= max_window_layers - ): - self.sliding_window_size = None - else: - self.sliding_window_size = sliding_window_size def build(self, inputs_shape): # Einsum variables: @@ -305,7 +292,7 @@ def _compute_attention( attention_scores, ops.cast(self._inv_norm_factor, self.compute_dtype), ) - if self.use_sliding_window_attention: + if self.sliding_window_size: attention_mask = self._mask_sliding_window( attention_mask, cache_update_index=cache_update_index @@ -367,9 +354,6 @@ def get_config(self): self.kernel_initializer ), "dropout": self.dropout, - "use_sliding_window_attention": ( - self.use_sliding_window_attention - ), "sliding_window_size": self.sliding_window_size, } ) diff --git a/keras_hub/src/models/qwen3_moe/qwen3_moe_backbone.py b/keras_hub/src/models/qwen3_moe/qwen3_moe_backbone.py index 8a659edb5c..34900e26dc 100644 --- a/keras_hub/src/models/qwen3_moe/qwen3_moe_backbone.py +++ b/keras_hub/src/models/qwen3_moe/qwen3_moe_backbone.py @@ -53,8 +53,6 @@ class Qwen3MoeBackbone(Backbone): layer_norm_epsilon: float. The epsilon value used for every layer norm in the transformer model. dropout: float. Dropout probability for the transformer encoder. - use_sliding_window_attention: bool. Whether to use sliding local window - attention. Defaults to False. sliding_window_size: int. Size of the sliding local window. Defaults to 4096. max_sequence_length: int. 
The maximum sequence length supported by the @@ -112,9 +110,7 @@ def __init__( dropout=0, dtype=None, tie_word_embeddings=False, - use_sliding_window_attention=False, sliding_window_size=32768, - output_router_logits=False, router_aux_loss_coefficient=0.001, mlp_only_layers=[], training=None, @@ -148,9 +144,7 @@ def __init__( kernel_initializer=_qwen_moe_kernel_initializer(stddev=0.02), dropout=dropout, dtype=dtype, - use_sliding_window_attention=use_sliding_window_attention, sliding_window_size=sliding_window_size, - output_router_logits=output_router_logits, router_aux_loss_coefficient=router_aux_loss_coefficient, mlp_only_layers=mlp_only_layers, name=f"transformer_layer_{i}", @@ -198,7 +192,6 @@ def __init__( self.layer_norm_epsilon = layer_norm_epsilon self.dropout = dropout self.tie_word_embeddings = tie_word_embeddings - self.use_sliding_window_attention = use_sliding_window_attention self.sliding_window_size = sliding_window_size self.num_experts = num_experts self.top_k = top_k @@ -206,7 +199,6 @@ def __init__( self.decoder_sparse_step = decoder_sparse_step self.mlp_only_layers = mlp_only_layers self.router_aux_loss_coefficient = router_aux_loss_coefficient - self.output_router_logits = output_router_logits def get_config(self): config = super().get_config() @@ -224,16 +216,12 @@ def get_config(self): "layer_norm_epsilon": self.layer_norm_epsilon, "dropout": self.dropout, "tie_word_embeddings": self.tie_word_embeddings, - "use_sliding_window_attention": ( - self.use_sliding_window_attention - ), "sliding_window_size": self.sliding_window_size, "num_experts": self.num_experts, "top_k": self.top_k, "norm_top_k_prob": self.norm_top_k_prob, "decoder_sparse_step": self.decoder_sparse_step, "mlp_only_layers": self.mlp_only_layers, - "output_router_logits": self.output_router_logits, } ) return config diff --git a/keras_hub/src/models/qwen3_moe/qwen3_moe_decoder.py b/keras_hub/src/models/qwen3_moe/qwen3_moe_decoder.py index 351e05db1c..d6146401b5 100644 --- a/keras_hub/src/models/qwen3_moe/qwen3_moe_decoder.py +++ b/keras_hub/src/models/qwen3_moe/qwen3_moe_decoder.py @@ -156,9 +156,7 @@ def call(self, x): x = self._feedforward_intermediate_dense(x) - x = self._feedforward_output_dense(ops.multiply(x, gate_output)) - - return x + return self._feedforward_output_dense(ops.multiply(x, gate_output)) class Qwen3MoeExperts(keras.layers.Layer): @@ -328,11 +326,9 @@ def __init__( layer_norm_epsilon=1e-5, kernel_initializer="glorot_uniform", dropout=0, - use_sliding_window_attention=False, sliding_window_size=4096, layer_index=0, mlp_only_layers=[], - output_router_logits=False, router_aux_loss_coefficient=0.001, **kwargs, ): @@ -343,7 +339,6 @@ def __init__( self.rope_max_wavelength = rope_max_wavelength self.rope_scaling_factor = rope_scaling_factor self.dropout = dropout - self.use_sliding_window_attention = use_sliding_window_attention self.sliding_window_size = sliding_window_size self.activation = keras.activations.get(activation) self.layer_norm_epsilon = layer_norm_epsilon @@ -356,7 +351,6 @@ def __init__( self.top_k = top_k self.norm_top_k_prob = norm_top_k_prob self.decoder_sparse_step = decoder_sparse_step - self.output_router_logits = output_router_logits self.router_aux_loss_coefficient = router_aux_loss_coefficient self.supports_masking = True @@ -373,7 +367,6 @@ def build(self, decoder_sequence_shape): rope_scaling_factor=self.rope_scaling_factor, kernel_initializer=clone_initializer(self.kernel_initializer), dropout=self.dropout, - 
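With the use_sliding_window_attention flag removed, the attention layer applies the window mask whenever sliding_window_size is set. Under the usual convention (a query may not look further than sliding_window_size positions back), a causal sliding-window mask can be sketched in numpy as:

import numpy as np

seq_len, window = 8, 3  # illustrative sizes, not the model defaults
i = np.arange(seq_len)[:, None]  # query positions
j = np.arange(seq_len)[None, :]  # key positions

causal = j <= i            # never attend to future positions
recent = j > i - window    # keep only the last `window` keys
print((causal & recent).astype(int))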
use_sliding_window_attention=self.use_sliding_window_attention, sliding_window_size=self.sliding_window_size, layer_index=self.layer_index, name="self_attention", @@ -488,10 +481,9 @@ def call( ) else: x = self.mlp(x) + if isinstance(x, tuple): - x, router_logits = x - else: - router_logits = None + x, _ = x x = ops.cast(x, ops.dtype(residual)) decoder_output = x + residual @@ -501,9 +493,6 @@ def call( if self_attention_cache is not None: output += (self_attention_cache,) - if self.output_router_logits: - output += (router_logits,) - return output[0] if len(output) == 1 else output def _compute_self_attention_mask( @@ -583,16 +572,12 @@ def get_config(self): "rope_scaling_factor": self.rope_scaling_factor, "layer_norm_epsilon": self.layer_norm_epsilon, "dropout": self.dropout, - "use_sliding_window_attention": ( - self.use_sliding_window_attention - ), "sliding_window_size": self.sliding_window_size, "num_experts": self.num_experts, "top_k": self.top_k, "norm_top_k_prob": self.norm_top_k_prob, "decoder_sparse_step": self.decoder_sparse_step, "mlp_only_layers": self.mlp_only_layers, - "output_router_logits": self.output_router_logits, "router_aux_loss_coefficient": self.router_aux_loss_coefficient, } ) From 730a9c41e95a74906a041d8933b8d7738391b438 Mon Sep 17 00:00:00 2001 From: Anshuman Mishra Date: Tue, 27 May 2025 15:08:49 +0000 Subject: [PATCH 5/7] address comments --- .../models/qwen3_moe/qwen3_moe_attention.py | 1 + .../models/qwen3_moe/qwen3_moe_backbone.py | 12 +++- .../qwen3_moe/qwen3_moe_backbone_test.py | 70 +++++++++++++++++++ .../src/models/qwen3_moe/qwen3_moe_decoder.py | 15 ++-- 4 files changed, 84 insertions(+), 14 deletions(-) create mode 100644 keras_hub/src/models/qwen3_moe/qwen3_moe_backbone_test.py diff --git a/keras_hub/src/models/qwen3_moe/qwen3_moe_attention.py b/keras_hub/src/models/qwen3_moe/qwen3_moe_attention.py index 78467dd9ca..152da540e8 100644 --- a/keras_hub/src/models/qwen3_moe/qwen3_moe_attention.py +++ b/keras_hub/src/models/qwen3_moe/qwen3_moe_attention.py @@ -60,6 +60,7 @@ def __init__( self.layer_index = layer_index self.rope_scaling_factor = rope_scaling_factor + self.sliding_window_size = sliding_window_size def build(self, inputs_shape): # Einsum variables: diff --git a/keras_hub/src/models/qwen3_moe/qwen3_moe_backbone.py b/keras_hub/src/models/qwen3_moe/qwen3_moe_backbone.py index 34900e26dc..dc78ff345c 100644 --- a/keras_hub/src/models/qwen3_moe/qwen3_moe_backbone.py +++ b/keras_hub/src/models/qwen3_moe/qwen3_moe_backbone.py @@ -20,7 +20,7 @@ def _qwen_moe_kernel_initializer(stddev=0.02): "keras_hub.models.Qwen3MoeBackbone", ) class Qwen3MoeBackbone(Backbone): - """Qwen MoE core network with hyperparameters. + """Qwen3 MoE core network with hyperparameters. This backbone implements the base Transformer network for the Qwen MoE model. 
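As background for the Mixture of Experts wiring this backbone describes, the routing inside Qwen3SparseMoeBlock amounts to: softmax the router logits, keep the top_k experts per token, renormalize those probabilities when norm_top_k_prob is set, and sum the selected expert outputs weighted by their probabilities. A numpy sketch with toy shapes and stand-in linear experts (the real experts are gated MLPs):

import numpy as np

num_tokens, hidden_dim, num_experts, top_k = 4, 8, 4, 2
norm_top_k_prob = True
rng = np.random.default_rng(0)

x = rng.standard_normal((num_tokens, hidden_dim))
router_logits = rng.standard_normal((num_tokens, num_experts))

shifted = router_logits - router_logits.max(axis=-1, keepdims=True)
probs = np.exp(shifted) / np.exp(shifted).sum(axis=-1, keepdims=True)

top_idx = np.argsort(probs, axis=-1)[:, -top_k:]   # top_k experts per token
top_probs = np.take_along_axis(probs, top_idx, axis=-1)
if norm_top_k_prob:
    top_probs /= top_probs.sum(axis=-1, keepdims=True)

experts = [
    rng.standard_normal((hidden_dim, hidden_dim)) for _ in range(num_experts)
]
out = np.zeros_like(x)
for t in range(num_tokens):
    for p, e in zip(top_probs[t], top_idx[t]):
        out[t] += p * (x[t] @ experts[e])
print(out.shape)  # (4, 8)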
It includes embedding lookups and transformer layers with a Mixture @@ -127,6 +127,11 @@ def __init__( ) self.transformer_layers = [] for i in range(num_layers): + is_sparse_mlp = ( + (i not in mlp_only_layers) + and num_experts > 0 + and (i + 1) % decoder_sparse_step == 0 + ) layer = Qwen3MoeTransformerDecoder( intermediate_dim=intermediate_dim, num_query_heads=num_query_heads, @@ -136,7 +141,6 @@ def __init__( num_experts=num_experts, top_k=top_k, norm_top_k_prob=norm_top_k_prob, - decoder_sparse_step=decoder_sparse_step, rope_max_wavelength=rope_max_wavelength, rope_scaling_factor=rope_scaling_factor, layer_norm_epsilon=layer_norm_epsilon, @@ -146,7 +150,7 @@ def __init__( dtype=dtype, sliding_window_size=sliding_window_size, router_aux_loss_coefficient=router_aux_loss_coefficient, - mlp_only_layers=mlp_only_layers, + is_sparse_mlp=is_sparse_mlp, name=f"transformer_layer_{i}", ) self.transformer_layers.append(layer) @@ -186,6 +190,7 @@ def __init__( self.hidden_dim = hidden_dim self.intermediate_dim = intermediate_dim self.moe_intermediate_dim = moe_intermediate_dim + self.head_dim = head_dim self.rope_max_wavelength = rope_max_wavelength self.num_key_value_heads = num_key_value_heads self.rope_scaling_factor = rope_scaling_factor @@ -207,6 +212,7 @@ def get_config(self): "vocabulary_size": self.vocabulary_size, "num_layers": self.num_layers, "num_query_heads": self.num_query_heads, + "head_dim": self.head_dim, "hidden_dim": self.hidden_dim, "intermediate_dim": self.intermediate_dim, "moe_intermediate_dim": self.moe_intermediate_dim, diff --git a/keras_hub/src/models/qwen3_moe/qwen3_moe_backbone_test.py b/keras_hub/src/models/qwen3_moe/qwen3_moe_backbone_test.py new file mode 100644 index 0000000000..5373ed3a0b --- /dev/null +++ b/keras_hub/src/models/qwen3_moe/qwen3_moe_backbone_test.py @@ -0,0 +1,70 @@ +import pytest +from keras import ops + +from keras_hub.src.models.qwen3_moe.qwen3_moe_backbone import Qwen3MoeBackbone +from keras_hub.src.tests.test_case import TestCase + + +class Qwen3MoeBackboneTest(TestCase): + def setUp(self): + self.init_kwargs = { + "vocabulary_size": 20, + "num_layers": 2, + "num_query_heads": 4, + "num_key_value_heads": 2, + "hidden_dim": 16, + "intermediate_dim": 32, + "head_dim": 2, + "moe_intermediate_dim": 16, + "shared_expert_intermediate_dim": 32, + "num_experts": 4, + "top_k": 2, + "norm_top_k_prob": True, + "decoder_sparse_step": 1, + "layer_norm_epsilon": 1e-6, + "rope_max_wavelength": 10000, + "rope_scaling_factor": 1.0, + "dropout": 0.0, + "sliding_window_size": 4096, + "router_aux_loss_coefficient": 0.01, + "tie_word_embeddings": False, + "mlp_only_layers": [], + "dtype": "float32", # Explicitly set dtype to avoid mixed precision + } + self.input_data = { + "token_ids": ops.ones((2, 7), dtype="int32"), + "padding_mask": ops.ones((2, 7), dtype="int32"), + } + + def test_backbone_basics(self): + self.run_backbone_test( + cls=Qwen3MoeBackbone, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + expected_output_shape=(2, 7, 16), + run_quantization_check=False, + ) + + @pytest.mark.large + def test_saved_model(self): + self.run_model_saving_test( + cls=Qwen3MoeBackbone, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + ) + + def test_architecture_characteristics(self): + model = Qwen3MoeBackbone(**self.init_kwargs) + expected_params = 7768 + self.assertEqual(model.count_params(), expected_params) + expected_layers = 6 + self.assertEqual(len(model.layers), expected_layers) + + def test_auxiliary_loss(self): + model = 
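The is_sparse_mlp flag computed in the backbone above decides, per layer index, whether a decoder layer gets the sparse MoE block or a plain dense MLP. A quick sketch of that rule with illustrative settings (the test above uses decoder_sparse_step=1 and mlp_only_layers=[], so every layer is sparse there):

num_layers, num_experts = 8, 4
decoder_sparse_step = 2   # illustrative
mlp_only_layers = [5]     # illustrative

for i in range(num_layers):
    is_sparse_mlp = (
        i not in mlp_only_layers
        and num_experts > 0
        and (i + 1) % decoder_sparse_step == 0
    )
    print(i, "sparse" if is_sparse_mlp else "dense")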
Qwen3MoeBackbone(**self.init_kwargs) + _ = model(self.input_data, training=True) + self.assertTrue( + len(model.losses) > 0, "Auxiliary losses should be present" + ) + for loss in model.losses: + self.assertGreater(loss, 0.0, "Auxiliary loss should be positive") diff --git a/keras_hub/src/models/qwen3_moe/qwen3_moe_decoder.py b/keras_hub/src/models/qwen3_moe/qwen3_moe_decoder.py index d6146401b5..bb523c8fee 100644 --- a/keras_hub/src/models/qwen3_moe/qwen3_moe_decoder.py +++ b/keras_hub/src/models/qwen3_moe/qwen3_moe_decoder.py @@ -214,7 +214,7 @@ def call(self, hidden_states): class Qwen3SparseMoeBlock(keras.layers.Layer): - """Qwen-2 Sparse Moe Block""" + """Qwen-3 Sparse Moe Block""" def __init__( self, @@ -319,7 +319,7 @@ def __init__( num_experts, top_k, norm_top_k_prob, - decoder_sparse_step, + is_sparse_mlp=False, rope_max_wavelength=10000, rope_scaling_factor=1.0, activation="silu", @@ -328,7 +328,6 @@ def __init__( dropout=0, sliding_window_size=4096, layer_index=0, - mlp_only_layers=[], router_aux_loss_coefficient=0.001, **kwargs, ): @@ -344,13 +343,12 @@ def __init__( self.layer_norm_epsilon = layer_norm_epsilon self.kernel_initializer = keras.initializers.get(kernel_initializer) self.layer_index = layer_index - self.mlp_only_layers = mlp_only_layers self.moe_intermediate_dim = moe_intermediate_dim self.head_dim = head_dim self.num_experts = num_experts self.top_k = top_k self.norm_top_k_prob = norm_top_k_prob - self.decoder_sparse_step = decoder_sparse_step + self.is_sparse_mlp = is_sparse_mlp self.router_aux_loss_coefficient = router_aux_loss_coefficient self.supports_masking = True @@ -388,10 +386,7 @@ def build(self, decoder_sequence_shape): ) # Feedforward layers. - if (self.layer_index not in self.mlp_only_layers) and ( - self.num_experts > 0 - and (self.layer_index + 1) % self.decoder_sparse_step == 0 - ): + if self.is_sparse_mlp: self.mlp = Qwen3SparseMoeBlock( hidden_dim=self.hidden_dim, moe_intermediate_dim=self.moe_intermediate_dim, @@ -576,8 +571,6 @@ def get_config(self): "num_experts": self.num_experts, "top_k": self.top_k, "norm_top_k_prob": self.norm_top_k_prob, - "decoder_sparse_step": self.decoder_sparse_step, - "mlp_only_layers": self.mlp_only_layers, "router_aux_loss_coefficient": self.router_aux_loss_coefficient, } ) From 5f90d107bc9df8d042475e4d9d6dc2e7d1493dd1 Mon Sep 17 00:00:00 2001 From: Anshuman Mishra Date: Tue, 27 May 2025 15:24:27 +0000 Subject: [PATCH 6/7] update output matching script --- .../convert_qwen3_moe_checkpoints.py | 28 +++++++++---------- 1 file changed, 13 insertions(+), 15 deletions(-) diff --git a/tools/checkpoint_conversion/convert_qwen3_moe_checkpoints.py b/tools/checkpoint_conversion/convert_qwen3_moe_checkpoints.py index fba956a516..95b4c91228 100644 --- a/tools/checkpoint_conversion/convert_qwen3_moe_checkpoints.py +++ b/tools/checkpoint_conversion/convert_qwen3_moe_checkpoints.py @@ -7,8 +7,7 @@ import numpy as np import torch from absl import app - -# from absl import flags +from absl import flags device = torch.device("cpu") # Force PyTorch to use CPU @@ -24,10 +23,10 @@ "qwen3_moe_3b_en": "Qwen/Qwen3-30B-A3B", } -# FLAGS = flags.FLAGS -# flags.DEFINE_string( -# "preset", None, f"Must be one of {','.join(PRESET_MAP.keys())}" -# ) +FLAGS = flags.FLAGS +flags.DEFINE_string( + "preset", None, f"Must be one of {','.join(PRESET_MAP.keys())}" +) def test_model( @@ -120,14 +119,13 @@ def validate_output( def main(_): # === Get the preset name === - # if FLAGS.preset not in PRESET_MAP.keys(): - # raise ValueError( - # f"Invalid 
preset {FLAGS.preset}. Must be one " - # f"of {','.join(PRESET_MAP.keys())}" - # ) - # preset = FLAGS.preset - # hf_preset = PRESET_MAP[preset] - hf_preset = "Qwen/Qwen3-30B-A3B" + if FLAGS.preset not in PRESET_MAP.keys(): + raise ValueError( + f"Invalid preset {FLAGS.preset}. Must be one " + f"of {','.join(PRESET_MAP.keys())}" + ) + preset = FLAGS.preset + hf_preset = PRESET_MAP[preset] # === Load the Huggingface model === hf_model = AutoModelForCausalLM.from_pretrained( @@ -158,5 +156,5 @@ def main(_): if __name__ == "__main__": - # flags.mark_flag_as_required("preset") + flags.mark_flag_as_required("preset") app.run(main) From cda9cfc563906a3804a6f3a002ecf68d3c2b36f5 Mon Sep 17 00:00:00 2001 From: Anshuman Mishra Date: Tue, 27 May 2025 21:40:34 +0530 Subject: [PATCH 7/7] fix test --- keras_hub/src/models/qwen3_moe/qwen3_moe_backbone_test.py | 1 - 1 file changed, 1 deletion(-) diff --git a/keras_hub/src/models/qwen3_moe/qwen3_moe_backbone_test.py b/keras_hub/src/models/qwen3_moe/qwen3_moe_backbone_test.py index 5373ed3a0b..cdfd5440e9 100644 --- a/keras_hub/src/models/qwen3_moe/qwen3_moe_backbone_test.py +++ b/keras_hub/src/models/qwen3_moe/qwen3_moe_backbone_test.py @@ -16,7 +16,6 @@ def setUp(self): "intermediate_dim": 32, "head_dim": 2, "moe_intermediate_dim": 16, - "shared_expert_intermediate_dim": 32, "num_experts": 4, "top_k": 2, "norm_top_k_prob": True,
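With the preset flag restored, the converter is driven from the command line, and the registered hf:// converter lets the converted backbone and tokenizer be pulled straight from the Hub, as the script itself does. A usage sketch (assumes this branch of keras_hub is installed; downloading the 30B checkpoint is a large transfer):

# python tools/checkpoint_conversion/convert_qwen3_moe_checkpoints.py \
#     --preset qwen3_moe_3b_en
import os

os.environ["KERAS_BACKEND"] = "torch"

import keras_hub  # noqa: E402

backbone = keras_hub.models.Qwen3MoeBackbone.from_preset("hf://Qwen/Qwen3-30B-A3B")
tokenizer = keras_hub.tokenizers.Qwen3MoeTokenizer.from_preset(
    "hf://Qwen/Qwen3-30B-A3B"
)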