diff --git a/keras_hub/api/models/__init__.py b/keras_hub/api/models/__init__.py index 0a71dbcace..48a188f5d2 100644 --- a/keras_hub/api/models/__init__.py +++ b/keras_hub/api/models/__init__.py @@ -444,6 +444,12 @@ from keras_hub.src.models.qwen.qwen_tokenizer import ( QwenTokenizer as QwenTokenizer, ) +from keras_hub.src.models.qwen3_moe.qwen3_causal_lm_preprocessor import ( + Qwen3MoeCausalLMPreprocessor as Qwen3MoeCausalLMPreprocessor, +) +from keras_hub.src.models.qwen3_moe.qwen3_moe_backbone import ( + Qwen3MoeBackbone as Qwen3MoeBackbone, +) from keras_hub.src.models.qwen_moe.qwen_moe_backbone import ( QwenMoeBackbone as QwenMoeBackbone, ) diff --git a/keras_hub/api/tokenizers/__init__.py b/keras_hub/api/tokenizers/__init__.py index 082078184f..532cf8b455 100644 --- a/keras_hub/api/tokenizers/__init__.py +++ b/keras_hub/api/tokenizers/__init__.py @@ -74,6 +74,9 @@ from keras_hub.src.models.qwen.qwen_tokenizer import ( QwenTokenizer as QwenTokenizer, ) +from keras_hub.src.models.qwen3_moe.qwen3_moe_tokenizer import ( + Qwen3MoeTokenizer as Qwen3MoeTokenizer, +) from keras_hub.src.models.qwen_moe.qwen_moe_tokenizer import ( QwenMoeTokenizer as QwenMoeTokenizer, ) diff --git a/keras_hub/src/models/qwen3_moe/qwen3_causal_lm_preprocessor.py b/keras_hub/src/models/qwen3_moe/qwen3_causal_lm_preprocessor.py new file mode 100644 index 0000000000..6434b959ce --- /dev/null +++ b/keras_hub/src/models/qwen3_moe/qwen3_causal_lm_preprocessor.py @@ -0,0 +1,17 @@ +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.models.causal_lm_preprocessor import CausalLMPreprocessor +from keras_hub.src.models.qwen3_moe.qwen3_moe_backbone import Qwen3MoeBackbone +from keras_hub.src.models.qwen3_moe.qwen3_moe_tokenizer import Qwen3MoeTokenizer + + +@keras_hub_export( + [ + "keras_hub.models.Qwen3MoeCausalLMPreprocessor", + ] +) +class Qwen3MoeCausalLMPreprocessor(CausalLMPreprocessor): + backbone_cls = Qwen3MoeBackbone + tokenizer_cls = Qwen3MoeTokenizer + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) diff --git a/keras_hub/src/models/qwen3_moe/qwen3_moe_attention.py b/keras_hub/src/models/qwen3_moe/qwen3_moe_attention.py new file mode 100644 index 0000000000..152da540e8 --- /dev/null +++ b/keras_hub/src/models/qwen3_moe/qwen3_moe_attention.py @@ -0,0 +1,361 @@ +import math + +import keras +from keras import ops + +from keras_hub.src.layers.modeling.rotary_embedding import RotaryEmbedding +from keras_hub.src.models.qwen3_moe.qwen3_moe_layernorm import Qwen3MoeLayerNorm +from keras_hub.src.utils.keras_utils import clone_initializer +from keras_hub.src.utils.keras_utils import fused_attention_op_available + + +class Qwen3MoeAttention(keras.layers.Layer): + """A multi-head attention layer for Qwen3Moe models + This attention implementation supports grouped-query attention (GQA) where + the number of key-value heads can be less than the number of query heads. + Args: + num_query_heads: Number of query heads. + num_key_value_heads: Number of key/value heads (for GQA). + rope_max_wavelength: Maximum wavelength for RoPE (Rotary Position + Embedding). + rope_scaling_factor: Scaling factor for RoPE, used for extending + context length. + kernel_initializer: Initializer for the kernel weights. + dropout: Dropout rate for attention weights. + sliding_window_size: Size of the sliding window for attention. + **kwargs: Additional keyword arguments to pass to the Layer. 
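+        layer_index: Index of this attention layer within the decoder
+            stack.
+        head_dim: Size of each attention head. If `None`, it is inferred
+            as `hidden_dim // num_query_heads` at build time.
+        layer_norm_epsilon: Epsilon used by the query/key RMS
+            normalization layers.
+        max_window_layers: Number of layers that use sliding window
+            attention; accepted for config compatibility and currently
+            not used by this layer.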
+ """ + + def __init__( + self, + num_query_heads, + num_key_value_heads, + layer_index, + head_dim, + rope_max_wavelength=10000, + rope_scaling_factor=1, + kernel_initializer="glorot_uniform", + dropout=0, + layer_norm_epsilon=1e-5, + sliding_window_size=4096, + max_window_layers=28, + **kwargs, + ): + super().__init__( + **kwargs, + ) + self.num_query_heads = num_query_heads + self.num_key_value_heads = num_key_value_heads + self.head_dim = head_dim + self.dropout = dropout + + self.layer_norm_epsilon = layer_norm_epsilon + + self.num_key_value_groups = num_query_heads // num_key_value_heads + self.rope_max_wavelength = rope_max_wavelength + + self.kernel_initializer = keras.initializers.get( + clone_initializer(kernel_initializer) + ) + self.layer_index = layer_index + + self.rope_scaling_factor = rope_scaling_factor + self.sliding_window_size = sliding_window_size + + def build(self, inputs_shape): + # Einsum variables: + # b = batch size + # q = query length + # k = key/value length + # m = model dim + # u = num query heads + # v = num key/value heads + # h = head dim + hidden_dim = inputs_shape[-1] + if not self.head_dim: + self.head_dim = hidden_dim // self.num_query_heads + + self._inv_norm_factor = 1.0 / math.sqrt(self.head_dim) + self._query_dense = keras.layers.EinsumDense( + equation="bqm,muh->bquh", + output_shape=(None, self.num_query_heads, self.head_dim), + kernel_initializer=self.kernel_initializer, + dtype=self.dtype_policy, + name="query", + ) + self._query_dense.build(inputs_shape) + + self._query_dense_layer_norm = Qwen3MoeLayerNorm( + epsilon=self.layer_norm_epsilon, + dtype=self.dtype_policy, + hidden_dim=self.head_dim, + name="query_dense_layernorm", + ) + self._query_dense_layer_norm.build(inputs_shape) + + self._key_dense = keras.layers.EinsumDense( + equation="bkm,mvh->bkvh", + output_shape=( + None, + self.num_key_value_heads, + self.head_dim, + ), + kernel_initializer=self.kernel_initializer, + dtype=self.dtype_policy, + name="key", + ) + self._key_dense.build(inputs_shape) + + self._key_dense_layer_norm = Qwen3MoeLayerNorm( + epsilon=self.layer_norm_epsilon, + dtype=self.dtype_policy, + hidden_dim=self.head_dim, + name="key_dense_layernorm", + ) + self._key_dense_layer_norm.build(inputs_shape) + + self._value_dense = keras.layers.EinsumDense( + equation="bkm,mvh->bkvh", + output_shape=( + None, + self.num_key_value_heads, + self.head_dim, + ), + kernel_initializer=self.kernel_initializer, + dtype=self.dtype_policy, + name="value", + ) + self._value_dense.build(inputs_shape) + + self._softmax = keras.layers.Softmax( + axis=-1, + dtype="float32", + name="attention_softmax", + ) + + self._dropout_layer = keras.layers.Dropout( + rate=self.dropout, + dtype=self.dtype_policy, + ) + + self._output_dense = keras.layers.EinsumDense( + equation="bquh,uhm->bqm", + output_shape=(None, hidden_dim), + kernel_initializer=self.kernel_initializer, + dtype=self.dtype_policy, + name="attention_output", + ) + self._output_dense.build( + (None, None, self.num_query_heads, self.head_dim) + ) + + self.rotary_embedding_layer = RotaryEmbedding( + max_wavelength=self.rope_max_wavelength, + scaling_factor=self.rope_scaling_factor, + dtype=self.dtype_policy, + ) + + self._dot_product_equation = "bquh,bkuh->buqk" + self._combine_equation = "buqk,bkuh->bquh" + + self.built = True + + def call( + self, + hidden_states, + attention_mask=None, + cache=None, + cache_update_index=None, + training=None, + ): + """Applies attention mechanism to the input hidden states. 
+ Args: + hidden_states: Input tensor of shape [batch_size, seq_length, + hidden_size]. + attention_mask: Mask tensor of shape [batch_size, seq_length, + seq_length]. + cache: Optional cached key and value tensors. + cache_update_index: Index at which to update the cache. + training: Boolean indicating whether in training mode. + Returns: + attention_output: Output tensor after applying attention. + cache: Updated cache tensors (if cache is provided). + """ + start_index = ( + cache_update_index if cache_update_index is not None else 0 + ) + + query = self._query_dense(hidden_states) + query = self._query_dense_layer_norm(query) + + # Compute RoPE for queries + query = self.rotary_embedding_layer(query, start_index=start_index) + + def _compute_key_value(x): + key = self._key_dense(x) + key = self._key_dense_layer_norm(key) + key = self.rotary_embedding_layer(key, start_index=start_index) + + value = self._value_dense(x) + + return key, value + + if cache is not None: + key_cache = cache[:, 0, ...] + value_cache = cache[:, 1, ...] + if cache_update_index is None: + key = key_cache + value = value_cache + else: + key_update, value_update = _compute_key_value(hidden_states) + start = [0, cache_update_index, 0, 0] + key = ops.slice_update(key_cache, start, key_update) + value = ops.slice_update(value_cache, start, value_update) + cache = ops.stack((key, value), axis=1) + else: + if cache_update_index is not None: + raise ValueError( + "`cache_update_index` should not be set if `cache` is " + f"`None`. Received: cache={cache}, " + f"cache_update_index={cache_update_index}" + ) + key, value = _compute_key_value(hidden_states) + + # [batch_shape, seq_len, num_key_value_heads, head_dim] + # -> [batch_shape, seq_len, num_heads, head_dim] + key = ops.repeat(key, repeats=self.num_key_value_groups, axis=2) + value = ops.repeat(value, repeats=self.num_key_value_groups, axis=2) + + attention_output = self._compute_attention( + query, + key, + value, + attention_mask, + cache_update_index=cache_update_index, + ) + + attention_output = self._dropout_layer( + attention_output, training=training + ) + + attention_output = self._output_dense(attention_output) + + if cache is not None: + return attention_output, cache + return attention_output + + def _masked_softmax(self, attention_scores, attention_mask=None): + """Applies softmax with optional masking. + Args: + attention_scores: Attention score tensor. + attention_mask: Optional mask tensor. + Returns: + Masked softmax attention weights. + """ + if attention_mask is not None: + return self._softmax( + attention_scores, attention_mask[:, None, :, :] + ) + return self._softmax(attention_scores) + + def _compute_attention( + self, query, key, value, attention_mask=None, cache_update_index=None + ): + """Computes attention using query, key, and value tensors. + Uses Flash Attention when available for better performance. + Args: + query: Query tensor. + key: Key tensor. + value: Value tensor. + attention_mask: Optional mask tensor. + cache_update_index: Index for sliding window computation. + Returns: + attention_output: Output tensor after applying attention. + """ + if fused_attention_op_available(): + # Use `dot_product_attention` with Flash Attention support if + # available. 
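+            # Add a head axis so the boolean mask broadcasts across
+            # attention heads inside `dot_product_attention`.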
+ if attention_mask is not None: + attention_mask = ops.expand_dims(attention_mask, axis=1) + attention_mask = ops.cast(attention_mask, dtype="bool") + attention_output = ops.dot_product_attention( + query, + key, + value, + mask=attention_mask, + scale=self._inv_norm_factor, + ) + return attention_output + + attention_scores = ops.einsum(self._dot_product_equation, query, key) + + attention_scores = ops.multiply( + attention_scores, + ops.cast(self._inv_norm_factor, self.compute_dtype), + ) + if self.sliding_window_size: + attention_mask = self._mask_sliding_window( + attention_mask, + cache_update_index=cache_update_index + if cache_update_index + else 0, + ) + attention_scores = self._masked_softmax( + attention_scores, attention_mask + ) + attention_scores = ops.cast(attention_scores, self.compute_dtype) + attention_output = ops.einsum( + self._combine_equation, attention_scores, value + ) + + return attention_output + + def _mask_sliding_window( + self, + attention_mask, + cache_update_index=0, + ): + """Creates and combines a sliding window mask with the attention mask. + Args: + attention_mask: Original attention mask. + cache_update_index: Starting index for the sliding window. + Returns: + Combined attention mask with sliding window constraints. + """ + _, query_len, key_len = ops.shape(attention_mask) + # Compute the sliding window for square attention. + all_ones = ops.ones((key_len, key_len), "bool") + if keras.config.backend() == "tensorflow": + # TODO: trui/tril has issues with dynamic shape on the tensorflow + # backend. We should fix, but use `band_part` for now. + import tensorflow as tf + + band_size = ops.minimum(key_len, self.sliding_window_size - 1) + band_size = ops.cast(band_size, "int32") + sliding_mask = tf.linalg.band_part(all_ones, band_size, band_size) + else: + sliding_mask = ops.triu( + all_ones, -1 * self.sliding_window_size + 1 + ) * ops.tril(all_ones, self.sliding_window_size - 1) + # Slice the window for short queries during generation. 
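+        # `cache_update_index` gives the offset of the query rows within
+        # the full key length.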
+ start = (cache_update_index, 0) + sliding_mask = ops.slice(sliding_mask, start, (query_len, key_len)) + sliding_mask = ops.expand_dims(sliding_mask, 0) + return ops.logical_and(attention_mask, ops.cast(sliding_mask, "bool")) + + def get_config(self): + config = super().get_config() + config.update( + { + "num_query_heads": self.num_query_heads, + "num_key_value_heads": self.num_key_value_heads, + "rope_max_wavelength": self.rope_max_wavelength, + "rope_scaling_factor": self.rope_scaling_factor, + "kernel_initializer": keras.initializers.serialize( + self.kernel_initializer + ), + "dropout": self.dropout, + "sliding_window_size": self.sliding_window_size, + } + ) + return config diff --git a/keras_hub/src/models/qwen3_moe/qwen3_moe_backbone.py b/keras_hub/src/models/qwen3_moe/qwen3_moe_backbone.py new file mode 100644 index 0000000000..dc78ff345c --- /dev/null +++ b/keras_hub/src/models/qwen3_moe/qwen3_moe_backbone.py @@ -0,0 +1,360 @@ +import keras +from keras import ops + +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.layers.modeling.reversible_embedding import ( + ReversibleEmbedding, +) +from keras_hub.src.models.backbone import Backbone +from keras_hub.src.models.qwen.qwen_layernorm import QwenLayerNorm +from keras_hub.src.models.qwen3_moe.qwen3_moe_decoder import ( + Qwen3MoeTransformerDecoder, +) + + +def _qwen_moe_kernel_initializer(stddev=0.02): + return keras.initializers.RandomNormal(stddev=stddev) + + +@keras_hub_export( + "keras_hub.models.Qwen3MoeBackbone", +) +class Qwen3MoeBackbone(Backbone): + """Qwen3 MoE core network with hyperparameters. + + This backbone implements the base Transformer network for the Qwen MoE + model. It includes embedding lookups and transformer layers with a Mixture + of Experts (MoE) architecture, where each layer uses a sparse set of experts + for efficient computation. This backbone outputs the final hidden states for + each token, not generative predictions over the vocabulary space. For higher + -level object for text generation, see `keras_hub.models.Qwen3MoeCausalLM`. + + The default constructor gives a fully customizable, randomly initialized + Qwen MoE model with any number of layers, heads, and embedding dimensions. + To load preset architectures and weights, use the `from_preset` constructor. + + Args: + vocabulary_size: int. The size of the token vocabulary. + num_layers: int. The number of transformer layers. + num_query_heads: int. The number of heads for the query projections in + the attention layer. + num_key_value_heads: int. The number of heads for the key and value + projections in the attention layer. + hidden_dim: int. The size of the transformer hidden state at the end of + each transformer layer. + intermediate_dim: int. The output dimension of the first Dense layer in + the feedforward network for each transformer. + moe_intermediate_dim: int. The intermediate dimension for each expert + in the MoE feedforward network. + num_experts: int. The number of experts in each MoE layer. + top_k: int. The number of top experts to select for each token in the + MoE layer. + head_dim: int. The size of each attention head. + layer_norm_epsilon: float. The epsilon value used for every layer norm + in the transformer model. + dropout: float. Dropout probability for the transformer encoder. + sliding_window_size: int. Size of the sliding local window. Defaults to + 4096. + max_sequence_length: int. The maximum sequence length supported by the + model. Defaults to 4096. 
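+        norm_top_k_prob: bool. Whether to renormalize the top-k routing
+            probabilities so that they sum to one for each token.
+        decoder_sparse_step: int. The frequency of MoE layers; layer `i`
+            uses a sparse MoE block when `(i + 1)` is divisible by this
+            value and `i` is not listed in `mlp_only_layers`.
+        mlp_only_layers: list of int. Indices of layers that use a dense
+            feedforward block instead of the sparse MoE block.
+        router_aux_loss_coefficient: float. Weight of the auxiliary
+            load-balancing loss added by each MoE layer during training.
+        tie_word_embeddings: bool. Whether to tie the input token
+            embedding and the output projection weights.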
+ dtype: str or `keras.mixed_precision.DTypePolicy`. The dtype to use for + the model's computations and weights. Note that some computations, + such as softmax and layer normalization, will always be done at + float32 precision regardless of dtype. + + Example: + ```python + input_data = { + "token_ids": np.ones(shape=(1, 12), dtype="int32"), + "padding_mask": np.array([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]), + } + + # Pretrained Qwen MoE decoder. + model = keras_hub.models.Qwen3MoeBackbone.from_preset("qwen_moe_a2_7b") + model(input_data) + + # Randomly initialized Qwen MoE decoder with custom config. + model = keras_hub.models.Qwen3MoeBackbone( + vocabulary_size=151936, + num_layers=28, + num_query_heads=16, + num_key_value_heads=8, + hidden_dim=2048, + intermediate_dim=4096, + moe_intermediate_dim=128, + num_experts=60, + top_k=4, + head_dim=128, + max_sequence_length=4096, + ) + model(input_data) + """ + + def __init__( + self, + vocabulary_size, + num_layers, + num_query_heads, + num_key_value_heads, + head_dim, + hidden_dim, + intermediate_dim, + moe_intermediate_dim, + num_experts, + top_k=4, + norm_top_k_prob=False, + decoder_sparse_step=1, + rope_max_wavelength=10000, + rope_scaling_factor=1.0, + layer_norm_epsilon=1e-6, + dropout=0, + dtype=None, + tie_word_embeddings=False, + sliding_window_size=32768, + router_aux_loss_coefficient=0.001, + mlp_only_layers=[], + training=None, + **kwargs, + ): + # === Layers === + self.token_embedding = ReversibleEmbedding( + input_dim=vocabulary_size, + output_dim=hidden_dim, + tie_weights=tie_word_embeddings, + embeddings_initializer=_qwen_moe_kernel_initializer(stddev=0.01), + dtype=dtype, + name="token_embedding", + ) + self.transformer_layers = [] + for i in range(num_layers): + is_sparse_mlp = ( + (i not in mlp_only_layers) + and num_experts > 0 + and (i + 1) % decoder_sparse_step == 0 + ) + layer = Qwen3MoeTransformerDecoder( + intermediate_dim=intermediate_dim, + num_query_heads=num_query_heads, + num_key_value_heads=num_key_value_heads, + moe_intermediate_dim=moe_intermediate_dim, + head_dim=head_dim, + num_experts=num_experts, + top_k=top_k, + norm_top_k_prob=norm_top_k_prob, + rope_max_wavelength=rope_max_wavelength, + rope_scaling_factor=rope_scaling_factor, + layer_norm_epsilon=layer_norm_epsilon, + activation=ops.silu, + kernel_initializer=_qwen_moe_kernel_initializer(stddev=0.02), + dropout=dropout, + dtype=dtype, + sliding_window_size=sliding_window_size, + router_aux_loss_coefficient=router_aux_loss_coefficient, + is_sparse_mlp=is_sparse_mlp, + name=f"transformer_layer_{i}", + ) + self.transformer_layers.append(layer) + self.layer_norm = QwenLayerNorm( + epsilon=layer_norm_epsilon, + dtype=dtype, + name="sequence_output_layernorm", + ) + + # === Functional Model === + token_id_input = keras.Input( + shape=(None,), dtype="int32", name="token_ids" + ) + padding_mask_input = keras.Input( + shape=(None,), dtype="int32", name="padding_mask" + ) + x = self.token_embedding(token_id_input) + for transformer_layer in self.transformer_layers: + x = transformer_layer( + x, decoder_padding_mask=padding_mask_input, training=training + ) + sequence_output = self.layer_norm(x) + super().__init__( + inputs={ + "token_ids": token_id_input, + "padding_mask": padding_mask_input, + }, + outputs=sequence_output, + dtype=dtype, + **kwargs, + ) + + # === Config === + self.vocabulary_size = vocabulary_size + self.num_layers = num_layers + self.num_query_heads = num_query_heads + self.hidden_dim = hidden_dim + self.intermediate_dim = intermediate_dim + 
self.moe_intermediate_dim = moe_intermediate_dim + self.head_dim = head_dim + self.rope_max_wavelength = rope_max_wavelength + self.num_key_value_heads = num_key_value_heads + self.rope_scaling_factor = rope_scaling_factor + self.layer_norm_epsilon = layer_norm_epsilon + self.dropout = dropout + self.tie_word_embeddings = tie_word_embeddings + self.sliding_window_size = sliding_window_size + self.num_experts = num_experts + self.top_k = top_k + self.norm_top_k_prob = norm_top_k_prob + self.decoder_sparse_step = decoder_sparse_step + self.mlp_only_layers = mlp_only_layers + self.router_aux_loss_coefficient = router_aux_loss_coefficient + + def get_config(self): + config = super().get_config() + config.update( + { + "vocabulary_size": self.vocabulary_size, + "num_layers": self.num_layers, + "num_query_heads": self.num_query_heads, + "head_dim": self.head_dim, + "hidden_dim": self.hidden_dim, + "intermediate_dim": self.intermediate_dim, + "moe_intermediate_dim": self.moe_intermediate_dim, + "rope_max_wavelength": self.rope_max_wavelength, + "num_key_value_heads": self.num_key_value_heads, + "rope_scaling_factor": self.rope_scaling_factor, + "layer_norm_epsilon": self.layer_norm_epsilon, + "dropout": self.dropout, + "tie_word_embeddings": self.tie_word_embeddings, + "sliding_window_size": self.sliding_window_size, + "num_experts": self.num_experts, + "top_k": self.top_k, + "norm_top_k_prob": self.norm_top_k_prob, + "decoder_sparse_step": self.decoder_sparse_step, + "mlp_only_layers": self.mlp_only_layers, + } + ) + return config + + @staticmethod + def get_layout_map( + device_mesh, + model_parallel_dim_name="model", + data_parallel_dim_name="batch", + ): + """Get a `keras.distribution.LayoutMap` for model parallel distribution. + + The returned `LayoutMap` contains the sharding spec for the Qwen3Moe + backbone weights, so that you can use it to distribute weights across + the accelerators. + + Example: + ``` + # Feel free to change the mesh shape to balance data and model + # parallelism + mesh = keras.distribution.DeviceMesh( + shape=(1, 8), + axis_names=('batch', 'model'), + devices=keras.distribution.list_devices(), + ) + layout_map = Qwen3MoeBackbone.get_layout_map( + mesh, + model_parallel_dim_name="model", + ) + + distribution = keras.distribution.ModelParallel( + layout_map=layout_map, + batch_dim_name='batch', + ) + + with distribution.scope(): + qwen_moe_model = keras_hub.models.Qwen3MoeBackbone.from_preset() + ``` + + To see how the layout map was applied, load the model then run + (for one decoder block): + ``` + embedding_layer = qwen_moe_model.backbone.get_layer("token_embedding") + decoder_block_1 = qwen_moe_model.backbone.get_layer( + 'transformer_layer_0' + ) + for variable in embedding_layer.weights + decoder_block_1.weights: + print( + f'{variable.path:<58} {str(variable.shape):<16} ' + f'{str(variable.value.sharding.spec)}' + ) + ``` + + Args: + device_mesh: The `keras.distribution.DeviceMesh` instance for + distribution. + model_parallel_dim_name: The axis name of the device mesh, where + the weights should be partition on. + data_parallel_dim_name: The axis name of the device mesh, where + the data should be partition on. + Return: + `keras.distribution.LayoutMap` that contains the sharding spec + for all the model weights. 
+ """ + # The weight path and shape of the Llama backbone is like below + # token_embedding/embeddings (128256, 2048) + # repeat block for decoder + # transformer_layer_0/self_attention/query/kernel (2048, 32, 64) + # transformer_layer_0/self_attention/key/kernel (2048, 8, 64) + # transformer_layer_0/self_attention/value/kernel (2048, 8, 64) + # transformer_layer_0/self_attention/attention_output/kernel + # (32, 64, 2048) + # transformer_layer_0/self_attention_layernorm/scale (2048,) + # transformer_layer_0/feedforward_intermediate_dense/kernel + # (2048, 8192) + # transformer_layer_0/feedforward_gate_dense/kernel (2048, 8192) + # transformer_layer_0/feedforward_output_dense/kerne (8192, 2048) + # transformer_layer_0/feedforward_layernorm/scale (2048,) + + if not isinstance(device_mesh, keras.distribution.DeviceMesh): + raise ValueError( + "Invalid device_mesh type. Expected " + f"`keras.distribution.Device`, got {type(device_mesh)}" + ) + if model_parallel_dim_name not in device_mesh.axis_names: + raise ValueError( + f"{model_parallel_dim_name} is not found in the " + f"device_mesh.axis_names. {device_mesh.axis_name=}" + ) + if data_parallel_dim_name not in device_mesh.axis_names: + raise ValueError( + f"{data_parallel_dim_name} is not found in the " + f"device_mesh.axis_names. {device_mesh.axis_name=}" + ) + # Note that it is possible to further config the mesh to be 3D, eg + # (data, seq, model). We leave it as 2D for now for simplicity. + data_dim = data_parallel_dim_name + model_dim = model_parallel_dim_name + # The sharding config is based on the Gemma team training config. + # See https://arxiv.org/abs/2403.08295 + layout_map = keras.distribution.LayoutMap(device_mesh) + layout_map["token_embedding/embeddings"] = (model_dim, data_dim) + layout_map[ + "transformer_layer.*self_attention.*(query|key|value).kernel" + ] = ( + model_dim, + data_dim, + None, + ) + layout_map["transformer_layer.*attention_output.kernel"] = ( + model_dim, + None, + data_dim, + ) + layout_map[ + "transformer_layer.*feedforward_intermediate_dense.kernel" + ] = ( + data_dim, + model_dim, + ) + layout_map["transformer_layer.*feedforward_gate_dense.kernel"] = ( + data_dim, + model_dim, + ) + layout_map["transformer_layer.*feedforward_output_dense.kernel"] = ( + model_dim, + data_dim, + ) + + return layout_map diff --git a/keras_hub/src/models/qwen3_moe/qwen3_moe_backbone_test.py b/keras_hub/src/models/qwen3_moe/qwen3_moe_backbone_test.py new file mode 100644 index 0000000000..cdfd5440e9 --- /dev/null +++ b/keras_hub/src/models/qwen3_moe/qwen3_moe_backbone_test.py @@ -0,0 +1,69 @@ +import pytest +from keras import ops + +from keras_hub.src.models.qwen3_moe.qwen3_moe_backbone import Qwen3MoeBackbone +from keras_hub.src.tests.test_case import TestCase + + +class Qwen3MoeBackboneTest(TestCase): + def setUp(self): + self.init_kwargs = { + "vocabulary_size": 20, + "num_layers": 2, + "num_query_heads": 4, + "num_key_value_heads": 2, + "hidden_dim": 16, + "intermediate_dim": 32, + "head_dim": 2, + "moe_intermediate_dim": 16, + "num_experts": 4, + "top_k": 2, + "norm_top_k_prob": True, + "decoder_sparse_step": 1, + "layer_norm_epsilon": 1e-6, + "rope_max_wavelength": 10000, + "rope_scaling_factor": 1.0, + "dropout": 0.0, + "sliding_window_size": 4096, + "router_aux_loss_coefficient": 0.01, + "tie_word_embeddings": False, + "mlp_only_layers": [], + "dtype": "float32", # Explicitly set dtype to avoid mixed precision + } + self.input_data = { + "token_ids": ops.ones((2, 7), dtype="int32"), + "padding_mask": ops.ones((2, 7), 
dtype="int32"), + } + + def test_backbone_basics(self): + self.run_backbone_test( + cls=Qwen3MoeBackbone, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + expected_output_shape=(2, 7, 16), + run_quantization_check=False, + ) + + @pytest.mark.large + def test_saved_model(self): + self.run_model_saving_test( + cls=Qwen3MoeBackbone, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + ) + + def test_architecture_characteristics(self): + model = Qwen3MoeBackbone(**self.init_kwargs) + expected_params = 7768 + self.assertEqual(model.count_params(), expected_params) + expected_layers = 6 + self.assertEqual(len(model.layers), expected_layers) + + def test_auxiliary_loss(self): + model = Qwen3MoeBackbone(**self.init_kwargs) + _ = model(self.input_data, training=True) + self.assertTrue( + len(model.losses) > 0, "Auxiliary losses should be present" + ) + for loss in model.losses: + self.assertGreater(loss, 0.0, "Auxiliary loss should be positive") diff --git a/keras_hub/src/models/qwen3_moe/qwen3_moe_decoder.py b/keras_hub/src/models/qwen3_moe/qwen3_moe_decoder.py new file mode 100644 index 0000000000..bb523c8fee --- /dev/null +++ b/keras_hub/src/models/qwen3_moe/qwen3_moe_decoder.py @@ -0,0 +1,577 @@ +import keras +from keras import ops + +from keras_hub.src.layers.modeling.transformer_layer_utils import ( + compute_causal_mask, +) +from keras_hub.src.layers.modeling.transformer_layer_utils import ( + merge_padding_and_attention_mask, +) +from keras_hub.src.models.qwen3_moe.qwen3_moe_attention import Qwen3MoeAttention +from keras_hub.src.models.qwen3_moe.qwen3_moe_layernorm import Qwen3MoeLayerNorm +from keras_hub.src.utils.keras_utils import clone_initializer + + +def compute_load_balancing_loss( + router_logits, num_experts, top_k, attention_mask=None +): + """ + Compute the load balancing auxiliary loss for a single MoE layer. + + Args: + router_logits: Tensor of shape (batch_size * seq_len, num_experts). + num_experts: Integer, total number of experts. + top_k: Integer, number of experts to select per token. + attention_mask: Tensor of shape (batch_size, seq_len, seq_len), + optional mask for padding. + + Returns: + Scalar tensor representing the auxiliary loss. 
+ """ + # Compute routing probabilities + routing_weights = ops.softmax( + router_logits, axis=-1 + ) # Shape: (batch_size * seq_len, num_experts) + + # Get top-k experts + _, selected_experts = ops.top_k( + routing_weights, k=top_k + ) # Shape: (batch_size * seq_len, top_k) + + # Create one-hot encoding for selected experts + expert_mask = ops.one_hot( + selected_experts, num_experts + ) # Shape: (batch_size * seq_len, top_k, num_experts) + + if attention_mask is not None: + # Convert attention mask to (batch_size, seq_len) + batch_size, seq_len, _ = ops.shape(attention_mask) + flat_mask = ops.any(attention_mask, axis=-1) + flat_mask = ops.reshape( + flat_mask, (-1,) + ) # Shape: (batch_size * seq_len,) + # Expand mask for broadcasting + expert_attention_mask = ops.expand_dims( + flat_mask, axis=-1 + ) # Shape: (batch_size * seq_len, 1) + expert_attention_mask = ops.cast(expert_attention_mask, dtype="float32") + + # Compute masked means + tokens_per_expert = ops.sum( + expert_mask * expert_attention_mask[:, None, :], axis=0 + ) / ops.maximum( + ops.sum(expert_attention_mask[:, None, :], axis=0), 1e-9 + ) # Shape: (top_k, num_experts) + router_prob_per_expert = ops.sum( + routing_weights * expert_attention_mask, axis=0 + ) / ops.maximum( + ops.sum(expert_attention_mask, axis=0), 1e-9 + ) # Shape: (num_experts,) + else: + # Unmasked means + tokens_per_expert = ops.mean( + expert_mask, axis=0 + ) # Shape: (top_k, num_experts) + router_prob_per_expert = ops.mean( + routing_weights, axis=0 + ) # Shape: (num_experts,) + + # Average over top_k dimension if necessary + tokens_per_expert = ops.mean( + tokens_per_expert, axis=0 + ) # Shape: (num_experts,) + + # Compute the loss + overall_loss = ops.sum(tokens_per_expert * router_prob_per_expert) + return overall_loss * num_experts + + +class Qwen3MoeMLP(keras.layers.Layer): + def __init__( + self, + intermediate_dim, + hidden_dim, + activation_fn="silu", + layer_norm_epsilon=1e-5, + kernel_initializer="glorot_uniform", + **kwargs, + ): + super().__init__(**kwargs) + self.intermediate_dim = intermediate_dim + self.hidden_dim = hidden_dim + self.activation_fn = activation_fn + self.kernel_initializer = kernel_initializer + self.layer_norm_epsilon = layer_norm_epsilon + + def build(self, decoder_sequence_shape): + # Feedforward layers. + self._feedforward_intermediate_dense = keras.layers.Dense( + self.intermediate_dim, + kernel_initializer=clone_initializer(self.kernel_initializer), + use_bias=False, + dtype=self.dtype_policy, + name="feedforward_intermediate_dense", + ) + self._feedforward_intermediate_dense.build(decoder_sequence_shape) + + self._feedforward_gate_dense = keras.layers.Dense( + self.intermediate_dim, + kernel_initializer=clone_initializer(self.kernel_initializer), + use_bias=False, + dtype=self.dtype_policy, + name="feedforward_gate_dense", + ) + self._feedforward_gate_dense.build(decoder_sequence_shape) + + self._feedforward_output_dense = keras.layers.Dense( + self.hidden_dim, + kernel_initializer=clone_initializer(self.kernel_initializer), + use_bias=False, + dtype=self.dtype_policy, + name="feedforward_output_dense", + ) + + self._feedforward_output_dense.build( + self._feedforward_gate_dense.compute_output_shape( + decoder_sequence_shape + ) + ) + + self.activation = keras.activations.get(self.activation_fn) + self.built = True + + def call(self, x): + gate_output = self._feedforward_gate_dense(x) + + # Note that we run the activation function in full 32-bit + # precision since this is what `torch.nn.functional.silu` + # does. 
Internally, `torch.nn.functional.silu` converts the + # inputs to float32, computes SiLU, and converts the outputs + # back to compute dtype. + # CPU Kernel: https://github.com/pytorch/pytorch/blob/35c493f2cf9b623bfdc7e6b34dc1cb39690a7919/aten/src/ATen/native/cpu/Activation.cpp#L1221-L1235 # noqa: E501 + # CUDA Kernel: https://github.com/pytorch/pytorch/blob/35c493f2cf9b623bfdc7e6b34dc1cb39690a7919/aten/src/ATen/native/cuda/ActivationSiluKernel.cu # noqa: E501 + gate_output = ops.cast(gate_output, "float32") + gate_output = self.activation(gate_output) + gate_output = ops.cast(gate_output, self.compute_dtype) + + x = self._feedforward_intermediate_dense(x) + + return self._feedforward_output_dense(ops.multiply(x, gate_output)) + + +class Qwen3MoeExperts(keras.layers.Layer): + """Batched Experts Layer""" + + def __init__( + self, + num_experts, + hidden_dim, + intermediate_dim, + activation_fn="silu", + kernel_initializer="glorot_uniform", + **kwargs, + ): + super().__init__(**kwargs) + self.num_experts = num_experts + self.hidden_dim = hidden_dim + self.intermediate_dim = intermediate_dim + self.activation = keras.activations.get(activation_fn) + self.kernel_initializer = kernel_initializer + + def build(self, _): + self._expert_feedforward_gate_dense = self.add_weight( + shape=( + self.num_experts, + self.hidden_dim, + 2 * self.intermediate_dim, + ), + initializer=self.kernel_initializer, + trainable=True, + dtype=self.variable_dtype, + name="expert_feedforward_gate_dense", + ) + + self._expert_feedforward_output_dense = self.add_weight( + shape=(self.num_experts, self.intermediate_dim, self.hidden_dim), + initializer=self.kernel_initializer, + trainable=True, + dtype=self.variable_dtype, + name="expert_feedforward_output_dense", + ) + + self.built = True + + def call(self, hidden_states): + gate_up = ops.einsum( + "th,ehm->etm", hidden_states, self._expert_feedforward_gate_dense + ) + gate, up = ops.split(gate_up, 2, axis=-1) + hidden = up * self.activation(gate) + out = ops.einsum( + "eti,eih->eth", hidden, self._expert_feedforward_output_dense + ) + return out + + +class Qwen3SparseMoeBlock(keras.layers.Layer): + """Qwen-3 Sparse Moe Block""" + + def __init__( + self, + hidden_dim, + moe_intermediate_dim, + num_experts, + top_k, + norm_top_k_prob, + kernel_initializer="glorot_uniform", + layer_norm_epsilon=1e-5, + router_aux_loss_coefficient=0.01, + **kwargs, + ): + super().__init__(**kwargs) + self.hidden_dim = hidden_dim + self.intermediate_dim = moe_intermediate_dim + self.num_experts = num_experts + self.top_k = top_k + self.norm_top_k_prob = norm_top_k_prob + self.kernel_initializer = kernel_initializer + self.layer_norm_epsilon = layer_norm_epsilon + self.router_aux_loss_coefficient = router_aux_loss_coefficient + + def build(self, decoder_sequence_shape): + self._sparse_feedforward_gate_dense = keras.layers.Dense( + self.num_experts, + use_bias=False, + kernel_initializer=self.kernel_initializer, + name="sparse_feedforward_gate_dense", + dtype=self.dtype_policy, + ) + self._sparse_feedforward_gate_dense.build(decoder_sequence_shape) + + # NOTE: Experts are implemented as a single layer to enable efficient + # batched computation. Implementing each expert individually is + # currently avoided due to the lack of `ragged_dot` support in the + # Keras ops API, which would make individual implementations unstable + # and prone to bugs. 
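+        # Each expert's gate and up projections are fused into a single
+        # (num_experts, hidden_dim, 2 * intermediate_dim) weight and the
+        # down projections into a (num_experts, intermediate_dim,
+        # hidden_dim) weight, so every expert runs in two einsums per
+        # call.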
+ self.expert_bank = Qwen3MoeExperts( + num_experts=self.num_experts, + hidden_dim=self.hidden_dim, + intermediate_dim=self.intermediate_dim, + kernel_initializer=self.kernel_initializer, + name="experts", + dtype=self.dtype_policy, + ) + self.expert_bank.build(decoder_sequence_shape) + + self.built = True + + def call(self, hidden_states, attention_mask=None, training=None): + batch_size, seq_len, _ = ops.shape(hidden_states) + hidden_states_flattened = ops.reshape( + hidden_states, (-1, self.hidden_dim) + ) + + router_logits = self._sparse_feedforward_gate_dense( + hidden_states_flattened + ) + router_probs = ops.softmax(router_logits, axis=-1) + + top_p, top_i = ops.top_k(router_probs, k=self.top_k) + if self.norm_top_k_prob: + top_p = top_p / ops.sum(top_p, axis=-1, keepdims=True) + + one_hot = ops.one_hot(top_i, self.num_experts) + one_hot = ops.cast(one_hot, top_p.dtype) + routing_full = ops.sum(one_hot * top_p[..., None], axis=1) + routing_full = ops.transpose(routing_full, (1, 0)) + routing_full = ops.cast(routing_full, hidden_states_flattened.dtype) + + expert_out = self.expert_bank(hidden_states_flattened) + + weighted_out = expert_out * routing_full[:, :, None] + expert_contribution = ops.sum(weighted_out, axis=0) + + out = ops.reshape( + expert_contribution, (batch_size, seq_len, self.hidden_dim) + ) + + # Compute and add auxiliary loss during training + if training: + aux_loss = compute_load_balancing_loss( + router_logits=router_logits, + num_experts=self.num_experts, + top_k=self.top_k, + attention_mask=attention_mask, + ) + self.add_loss(self.router_aux_loss_coefficient * aux_loss) + + return out, router_logits + + +class Qwen3MoeTransformerDecoder(keras.layers.Layer): + def __init__( + self, + intermediate_dim, + num_query_heads, + head_dim, + num_key_value_heads, + moe_intermediate_dim, + num_experts, + top_k, + norm_top_k_prob, + is_sparse_mlp=False, + rope_max_wavelength=10000, + rope_scaling_factor=1.0, + activation="silu", + layer_norm_epsilon=1e-5, + kernel_initializer="glorot_uniform", + dropout=0, + sliding_window_size=4096, + layer_index=0, + router_aux_loss_coefficient=0.001, + **kwargs, + ): + super().__init__(**kwargs) + self.intermediate_dim = intermediate_dim + self.num_query_heads = num_query_heads + self.num_key_value_heads = num_key_value_heads + self.rope_max_wavelength = rope_max_wavelength + self.rope_scaling_factor = rope_scaling_factor + self.dropout = dropout + self.sliding_window_size = sliding_window_size + self.activation = keras.activations.get(activation) + self.layer_norm_epsilon = layer_norm_epsilon + self.kernel_initializer = keras.initializers.get(kernel_initializer) + self.layer_index = layer_index + self.moe_intermediate_dim = moe_intermediate_dim + self.head_dim = head_dim + self.num_experts = num_experts + self.top_k = top_k + self.norm_top_k_prob = norm_top_k_prob + self.is_sparse_mlp = is_sparse_mlp + self.router_aux_loss_coefficient = router_aux_loss_coefficient + self.supports_masking = True + + def build(self, decoder_sequence_shape): + self._decoder_sequence_shape = decoder_sequence_shape + self.hidden_dim = decoder_sequence_shape[-1] + + # Self attention layer. 
+ self._self_attention_layer = Qwen3MoeAttention( + num_query_heads=self.num_query_heads, + head_dim=self.head_dim, + num_key_value_heads=self.num_key_value_heads, + rope_max_wavelength=self.rope_max_wavelength, + rope_scaling_factor=self.rope_scaling_factor, + kernel_initializer=clone_initializer(self.kernel_initializer), + dropout=self.dropout, + sliding_window_size=self.sliding_window_size, + layer_index=self.layer_index, + name="self_attention", + dtype=self.dtype_policy, + ) + self._self_attention_layer.build(decoder_sequence_shape) + + self._self_attention_layernorm = Qwen3MoeLayerNorm( + epsilon=self.layer_norm_epsilon, + dtype=self.dtype_policy, + name="self_attention_layernorm", + ) + + self._self_attention_layernorm.build(decoder_sequence_shape) + self._self_attention_dropout = keras.layers.Dropout( + rate=self.dropout, + dtype=self.dtype_policy, + name="self_attention_dropout", + ) + + # Feedforward layers. + if self.is_sparse_mlp: + self.mlp = Qwen3SparseMoeBlock( + hidden_dim=self.hidden_dim, + moe_intermediate_dim=self.moe_intermediate_dim, + num_experts=self.num_experts, + top_k=self.top_k, + norm_top_k_prob=self.norm_top_k_prob, + router_aux_loss_coefficient=self.router_aux_loss_coefficient, + kernel_initializer=self.kernel_initializer, + dtype=self.dtype_policy, + ) + self.mlp.build(decoder_sequence_shape) + else: + self.mlp = Qwen3MoeMLP( + intermediate_dim=self.intermediate_dim, + hidden_dim=self.hidden_dim, + dtype=self.dtype_policy, + ) + self.mlp.build(decoder_sequence_shape) + + self._feedforward_layernorm = Qwen3MoeLayerNorm( + epsilon=self.layer_norm_epsilon, + dtype=self.dtype_policy, + name="feedforward_layernorm", + ) + self._feedforward_layernorm.build(decoder_sequence_shape) + + self.built = True + + def call( + self, + decoder_sequence, + decoder_padding_mask=None, + decoder_attention_mask=None, + self_attention_cache=None, + self_attention_cache_update_index=None, + training=None, + ): + """Forward pass for the decoder layer. + + Args: + decoder_sequence: Input tensor of shape [batch_size, seq_length, + hidden_size]. + decoder_padding_mask: Mask tensor for padding tokens. + decoder_attention_mask: Additional attention mask. + self_attention_cache: Optional cached key and value tensors for + self-attention. + self_attention_cache_update_index: Index at which to update the + cache. + training: Boolean indicating whether in training mode. + + Returns: + decoder_output: Output tensor after applying transformer decoder + block. + self_attention_cache: Updated cache tensors (if cache is provided). + """ + self_attention_mask = self._compute_self_attention_mask( + decoder_sequence=decoder_sequence, + decoder_padding_mask=decoder_padding_mask, + decoder_attention_mask=decoder_attention_mask, + self_attention_cache=self_attention_cache, + self_attention_cache_update_index=self_attention_cache_update_index, + ) + residual = decoder_sequence + + x = self._self_attention_layernorm(decoder_sequence) + + # Self attention block. 
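+        # When a cache is passed, the attention layer returns an
+        # `(attention_output, cache)` tuple, which is unpacked below.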
+ x = self._self_attention_layer( + hidden_states=x, + attention_mask=self_attention_mask, + cache=self_attention_cache, + cache_update_index=self_attention_cache_update_index, + ) + + if self_attention_cache is not None: + x, self_attention_cache = x + + x = self._self_attention_dropout(x, training=training) + + x = x + residual + residual = x + + x = self._feedforward_layernorm(x) + if isinstance(self.mlp, Qwen3SparseMoeBlock): + x = self.mlp( + x, training=training, attention_mask=self_attention_mask + ) + else: + x = self.mlp(x) + + if isinstance(x, tuple): + x, _ = x + + x = ops.cast(x, ops.dtype(residual)) + decoder_output = x + residual + + output = (decoder_output,) + + if self_attention_cache is not None: + output += (self_attention_cache,) + + return output[0] if len(output) == 1 else output + + def _compute_self_attention_mask( + self, + decoder_sequence, + decoder_padding_mask, + decoder_attention_mask, + self_attention_cache, + self_attention_cache_update_index, + ): + """Computes the self-attention mask combining causal, padding and + attention masks. + + Args: + decoder_sequence: Input tensor. + decoder_padding_mask: Mask tensor for padding tokens. + decoder_attention_mask: Additional attention mask. + self_attention_cache: Optional cached key and value tensors. + self_attention_cache_update_index: Index at which to update the + cache. + + Returns: + Combined attention mask tensor. + """ + decoder_mask = merge_padding_and_attention_mask( + decoder_sequence, decoder_padding_mask, decoder_attention_mask + ) + batch_size = ops.shape(decoder_sequence)[0] + input_length = output_length = ops.shape(decoder_sequence)[1] + # We need to handle a rectangular causal mask when doing cached + # decoding. For generative inference, `decoder_sequence` will + # generally be length 1, and `cache` will be the full generation length. + if self_attention_cache is not None: + input_length = ops.shape(self_attention_cache)[2] + + cache_update_index = ( + 0 + if self_attention_cache_update_index is None + else self_attention_cache_update_index + ) + + causal_mask = compute_causal_mask( + batch_size, input_length, output_length, cache_update_index + ) + + return ( + ops.minimum(decoder_mask, causal_mask) + if decoder_mask is not None + else causal_mask + ) + + def compute_output_shape(self, decoder_sequence_shape): + """Computes the output shape of the layer. + + Args: + decoder_sequence_shape: Shape of the decoder sequence input. + + Returns: + Output shape, which is the same as the input shape. + """ + return decoder_sequence_shape + + def get_config(self): + """Returns the config of the layer. + + Returns: + Dictionary containing the parameters used to initialize this layer. 
+ """ + config = super().get_config() + config.update( + { + "num_query_heads": self.num_query_heads, + "intermediate_dim": self.intermediate_dim, + "moe_intermediate_dim": self.moe_intermediate_dim, + "rope_max_wavelength": self.rope_max_wavelength, + "num_key_value_heads": self.num_key_value_heads, + "rope_scaling_factor": self.rope_scaling_factor, + "layer_norm_epsilon": self.layer_norm_epsilon, + "dropout": self.dropout, + "sliding_window_size": self.sliding_window_size, + "num_experts": self.num_experts, + "top_k": self.top_k, + "norm_top_k_prob": self.norm_top_k_prob, + "router_aux_loss_coefficient": self.router_aux_loss_coefficient, + } + ) + return config diff --git a/keras_hub/src/models/qwen3_moe/qwen3_moe_layernorm.py b/keras_hub/src/models/qwen3_moe/qwen3_moe_layernorm.py new file mode 100644 index 0000000000..c21da3cca6 --- /dev/null +++ b/keras_hub/src/models/qwen3_moe/qwen3_moe_layernorm.py @@ -0,0 +1,38 @@ +import keras +from keras import ops + + +class Qwen3MoeLayerNorm(keras.layers.Layer): + """A normalization layer for Qwen that implements RMS normalization.""" + + def __init__(self, hidden_dim=None, epsilon=1e-6, **kwargs): + super().__init__(**kwargs) + self.hidden_dim = hidden_dim + self.epsilon = epsilon + + def build(self, input_shape): + if self.hidden_dim: + dim = self.hidden_dim + else: + dim = input_shape[-1] + + self.scale = self.add_weight( + name="scale", + trainable=True, + shape=(dim,), + initializer="ones", + dtype=self.variable_dtype, + ) + self.built = True + + def call(self, x): + input_dtype = x.dtype + x = ops.cast(x, "float32") + var = ops.mean(ops.power(x, 2), axis=-1, keepdims=True) + x = x * ops.rsqrt(var + self.epsilon) + return ops.cast(x * self.scale, input_dtype) + + def get_config(self): + config = super().get_config() + config.update({"epsilon": self.epsilon}) + return config diff --git a/keras_hub/src/models/qwen3_moe/qwen3_moe_tokenizer.py b/keras_hub/src/models/qwen3_moe/qwen3_moe_tokenizer.py new file mode 100644 index 0000000000..a4e4cc946f --- /dev/null +++ b/keras_hub/src/models/qwen3_moe/qwen3_moe_tokenizer.py @@ -0,0 +1,46 @@ +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.models.qwen3_moe.qwen3_moe_backbone import Qwen3MoeBackbone +from keras_hub.src.tokenizers.byte_pair_tokenizer import BytePairTokenizer + + +@keras_hub_export( + "keras_hub.tokenizers.Qwen3MoeTokenizer", +) +class Qwen3MoeTokenizer(BytePairTokenizer): + """Tokenizer for Qwen Moe model. + + This tokenizer implements byte-pair encoding (BPE) for Qwen models, + handling special tokens like BOS (beginning of sequence) and EOS (end of + sequence). + + Args: + vocabulary: Dictionary mapping tokens to token IDs, or path to + vocabulary file. + merges: List of BPE merges, or path to merges file. + bos_token: Beginning of sequence token. Defaults to None. + eos_token: End of sequence token. Defaults to "<|endoftext|>". + misc_special_tokens: Set of additional special tokens. Defaults to + empty set. 
+ """ + + backbone_cls = Qwen3MoeBackbone + + def __init__( + self, + vocabulary=None, + merges=None, + **kwargs, + ): + # Add EOS token + eos_token = "<|endoftext|>" + self._add_special_token(eos_token, "end_token") + + self.start_token_id = None + self.start_token = None + self.pad_token_id = 0 + + super().__init__( + vocabulary=vocabulary, + merges=merges, + **kwargs, + ) diff --git a/keras_hub/src/utils/transformers/convert_qwen3_moe.py b/keras_hub/src/utils/transformers/convert_qwen3_moe.py new file mode 100644 index 0000000000..1d60977707 --- /dev/null +++ b/keras_hub/src/utils/transformers/convert_qwen3_moe.py @@ -0,0 +1,207 @@ +import numpy as np + +from keras_hub.src.models.qwen3_moe.qwen3_moe_backbone import Qwen3MoeBackbone +from keras_hub.src.utils.preset_utils import load_json + +backbone_cls = Qwen3MoeBackbone + + +def convert_backbone_config(transformers_config): + return { + "vocabulary_size": transformers_config["vocab_size"], + "hidden_dim": transformers_config["hidden_size"], + "head_dim": transformers_config["head_dim"], + "num_layers": transformers_config["num_hidden_layers"], + "num_query_heads": transformers_config["num_attention_heads"], + "num_key_value_heads": transformers_config["num_key_value_heads"], + "intermediate_dim": transformers_config["intermediate_size"], + "moe_intermediate_dim": transformers_config["moe_intermediate_size"], + "num_experts": transformers_config["num_experts"], + "top_k": transformers_config["num_experts_per_tok"], + "norm_top_k_prob": transformers_config["norm_topk_prob"], + "decoder_sparse_step": transformers_config["decoder_sparse_step"], + "layer_norm_epsilon": transformers_config["rms_norm_eps"], + "rope_max_wavelength": transformers_config["rope_theta"], + "use_sliding_window": transformers_config["use_sliding_window"], + "sliding_window_size": transformers_config["sliding_window"], + "output_router_logits": transformers_config["output_router_logits"], + "router_aux_loss_coefficient": transformers_config[ + "router_aux_loss_coef" + ], + } + + +def convert_weights(backbone, loader, transformers_config): + loader.port_weight( + keras_variable=backbone.get_layer("token_embedding").embeddings, + hf_weight_key="model.embed_tokens.weight", + ) + if not backbone.tie_word_embeddings: + loader.port_weight( + keras_variable=backbone.get_layer( + "token_embedding" + ).reverse_embeddings, + hf_weight_key="lm_head.weight", + # rearrange_pattern="b a -> a b", + hook_fn=lambda hf_tensor, _: np.transpose(hf_tensor, axes=(1, 0)), + ) + + def transpose_and_reshape(x, shape): + return np.reshape(np.transpose(x), shape) + + for i in range(backbone.num_layers): + decoder_layer = backbone.get_layer(f"transformer_layer_{i}") + + # Input layernorm + loader.port_weight( + keras_variable=decoder_layer._self_attention_layernorm.scale, + hf_weight_key=f"model.layers.{i}.input_layernorm.weight", + ) + + # Attention layers + + ## Query + loader.port_weight( + keras_variable=decoder_layer._self_attention_layer._query_dense.kernel, + hf_weight_key=f"model.layers.{i}.self_attn.q_proj.weight", + hook_fn=transpose_and_reshape, + ) + ## Key + loader.port_weight( + keras_variable=decoder_layer._self_attention_layer._key_dense.kernel, + hf_weight_key=f"model.layers.{i}.self_attn.k_proj.weight", + hook_fn=transpose_and_reshape, + ) + ## Value + loader.port_weight( + keras_variable=decoder_layer._self_attention_layer._value_dense.kernel, + hf_weight_key=f"model.layers.{i}.self_attn.v_proj.weight", + hook_fn=transpose_and_reshape, + ) + ## Output + loader.port_weight( + 
keras_variable=decoder_layer._self_attention_layer._output_dense.kernel, + hf_weight_key=f"model.layers.{i}.self_attn.o_proj.weight", + # rearrange_patterns="c (a b) -> a b c", + # rearrange_dims={"a": backbone.num_query_heads}, + hook_fn=transpose_and_reshape, + ) + + # MLP layers + if ( + (i not in backbone.mlp_only_layers) + and backbone.num_experts > 0 + and ((i + 1) % backbone.decoder_sparse_step == 0) + ): + # MoE layers + loader.port_weight( + keras_variable=decoder_layer.mlp._sparse_feedforward_gate_dense.kernel, + hf_weight_key=f"model.layers.{i}.mlp.gate.weight", + # rearrange_patterns="b a -> a b", + hook_fn=lambda hf_tensor, _: np.transpose( + hf_tensor, axes=(1, 0) + ), + ) + # Batched experts: gate_up_proj and down_proj + gate_up_proj_list = [] + down_proj_list = [] + for expert_idx in range(backbone.num_experts): + # Load gate_proj and up_proj for each expert + gate_proj = loader.get_tensor( + f"model.layers.{i}.mlp.experts.{expert_idx}.gate_proj.weight" + ) + up_proj = loader.get_tensor( + f"model.layers.{i}.mlp.experts.{expert_idx}.up_proj.weight" + ) + # Transpose to (hidden_dim, intermediate_dim) + gate_proj = np.transpose(gate_proj, axes=(1, 0)) + up_proj = np.transpose(up_proj, axes=(1, 0)) + # Concatenate gate_proj and up_proj along the last dimension + gate_up_proj = np.concatenate([gate_proj, up_proj], axis=-1) + gate_up_proj_list.append(gate_up_proj) + + # Load down_proj for each expert + down_proj = loader.get_tensor( + f"model.layers.{i}.mlp.experts.{expert_idx}.down_proj.weight" + ) + down_proj = np.transpose( + down_proj, axes=(1, 0) + ) # (intermediate_dim, hidden_dim) + down_proj_list.append(down_proj) + + # Stack the lists to create batched weights + gate_up_proj_batched = np.stack( + gate_up_proj_list, axis=0 + ) # (num_experts, hidden_dim, 2 * intermediate_dim) + down_proj_batched = np.stack( + down_proj_list, axis=0 + ) # (num_experts, intermediate_dim, hidden_dim) + + # Assign batched weights to expert_bank + decoder_layer.mlp.expert_bank._expert_feedforward_gate_dense.assign( + gate_up_proj_batched + ) + decoder_layer.mlp.expert_bank._expert_feedforward_output_dense.assign( + down_proj_batched + ) + else: + loader.port_weight( + keras_variable=decoder_layer._feedforward_intermediate_dense.kernel, + hf_weight_key=f"model.layers.{i}.mlp.up_proj.weight", + # rearrange_patterns="b a -> a b", + hook_fn=lambda hf_tensor, _: np.transpose( + hf_tensor, axes=(1, 0) + ), + ) + loader.port_weight( + keras_variable=decoder_layer._feedforward_output_dense.kernel, + hf_weight_key=f"model.layers.{i}.mlp.down_proj.weight", + # rearrange_patterns="b a -> a b", + hook_fn=lambda hf_tensor, _: np.transpose( + hf_tensor, axes=(1, 0) + ), + ) + loader.port_weight( + keras_variable=decoder_layer._feedforward_gate_dense.kernel, + hf_weight_key=f"model.layers.{i}.mlp.gate_proj.weight", + # rearrange_patterns="b a -> a b", + hook_fn=lambda hf_tensor, _: np.transpose( + hf_tensor, axes=(1, 0) + ), + ) + + # Feedforward layernorm + loader.port_weight( + keras_variable=decoder_layer._feedforward_layernorm.scale, + hf_weight_key=f"model.layers.{i}.post_attention_layernorm.weight", + ) + + # Final normalization layer + loader.port_weight( + keras_variable=backbone.get_layer("sequence_output_layernorm").scale, + hf_weight_key="model.norm.weight", + ) + + return backbone + + +def convert_tokenizer(cls, preset, **kwargs): + tokenizer_config = load_json(preset, "tokenizer.json") + vocab = tokenizer_config["model"]["vocab"] + merges = tokenizer_config["model"]["merges"] + merges = [" 
".join(item) for item in merges] + + # Load all special tokens with the exception of "reserved" ones. + special_tokens = set() + for token in tokenizer_config["added_tokens"]: + if not token["content"].startswith("<|reserved_special_token_"): + vocab[token["content"]] = token["id"] + special_tokens.add(token["content"]) + + kwargs.update( + { + "unsplittable_tokens": list(special_tokens), + } + ) + + return cls(vocabulary=vocab, merges=merges, **kwargs) diff --git a/keras_hub/src/utils/transformers/preset_loader.py b/keras_hub/src/utils/transformers/preset_loader.py index 1c126bcbb1..1574d21a3b 100644 --- a/keras_hub/src/utils/transformers/preset_loader.py +++ b/keras_hub/src/utils/transformers/preset_loader.py @@ -14,6 +14,7 @@ from keras_hub.src.utils.transformers import convert_mixtral from keras_hub.src.utils.transformers import convert_pali_gemma from keras_hub.src.utils.transformers import convert_qwen +from keras_hub.src.utils.transformers import convert_qwen3_moe from keras_hub.src.utils.transformers import convert_qwen_moe from keras_hub.src.utils.transformers import convert_vit from keras_hub.src.utils.transformers.safetensor_utils import SafetensorLoader @@ -50,6 +51,8 @@ def __init__(self, preset, config): self.converter = convert_mixtral elif model_type == "qwen2_moe": self.converter = convert_qwen_moe + elif model_type == "qwen3_moe": + self.converter = convert_qwen3_moe else: raise ValueError( "KerasHub has no converter for huggingface/transformers models " diff --git a/tools/checkpoint_conversion/convert_qwen3_moe_checkpoints.py b/tools/checkpoint_conversion/convert_qwen3_moe_checkpoints.py new file mode 100644 index 0000000000..95b4c91228 --- /dev/null +++ b/tools/checkpoint_conversion/convert_qwen3_moe_checkpoints.py @@ -0,0 +1,160 @@ +import os +import traceback + +os.environ["KERAS_BACKEND"] = "torch" +os.environ["CUDA_VISIBLE_DEVICES"] = "-1" # Hide any CUDA devices + +import numpy as np +import torch +from absl import app +from absl import flags + +device = torch.device("cpu") +# Force PyTorch to use CPU +torch.set_default_device(device) + +from keras import ops # noqa: E402 +from transformers import AutoModelForCausalLM # noqa: E402 +from transformers import AutoTokenizer # noqa: E402 + +import keras_hub # noqa: E402 + +PRESET_MAP = { + "qwen3_moe_3b_en": "Qwen/Qwen3-30B-A3B", +} + +FLAGS = flags.FLAGS +flags.DEFINE_string( + "preset", None, f"Must be one of {','.join(PRESET_MAP.keys())}" +) + + +def test_model( + keras_hub_model, keras_hub_tokenizer, hf_model, hf_model_tokenizer +): + # First, test that the number of parameters match + keras_hub_params = keras_hub_model.count_params() + hf_params = hf_model.num_parameters() + assert keras_hub_params == hf_params + + # Test the outputs of both the models + hf_inputs = hf_model_tokenizer(["What is Keras?"], return_tensors="pt").to( + device + ) + hf_outputs = hf_model(**hf_inputs) + hf_output_logits = hf_outputs.logits.detach().cpu().float().numpy() + + keras_hub_preprocessor = keras_hub.models.QwenCausalLMPreprocessor( + keras_hub_tokenizer + ) + keras_hub_inputs = keras_hub_preprocessor( + ["What is Keras?"], sequence_length=5 + )[0] + keras_hub_inputs = {k: v.to(device) for k, v in keras_hub_inputs.items()} + + keras_hub_output = keras_hub_model(keras_hub_inputs) + keras_hub_logits = keras_hub_model.token_embedding( + keras_hub_output, reverse=True + ) + keras_hub_logits = ops.convert_to_numpy(keras_hub_logits) + + # High tolerence since bfloat16 is used as the default dtype for Qwen + + try: + 
np.testing.assert_allclose( + keras_hub_logits, hf_output_logits, atol=1e-4 + ) + except AssertionError as err: + print("\n") + print(traceback.format_exc()) + print(err.args[0]) + print("\n") + + +def test_tokenizer(keras_hub_tokenizer, hf_tokenizer): + hf_output = hf_tokenizer(["What is Keras?"], return_tensors="pt") + hf_output = hf_output["input_ids"].detach().cpu().numpy() + keras_hub_preprocessor = keras_hub.models.QwenCausalLMPreprocessor( + keras_hub_tokenizer + ) + keras_hub_output = keras_hub_preprocessor( + ["What is Keras?"], sequence_length=5 + ) + keras_hub_output = ops.convert_to_numpy(keras_hub_output[0]["token_ids"]) + + np.testing.assert_equal(keras_hub_output, hf_output) + + +def validate_output( + keras_hub_model, keras_hub_tokenizer, hf_model, hf_tokenizer +): + input_str = "What is Keras?" + length = 32 + + # KerasHub + preprocessor = keras_hub.models.Qwen3MoeCausalLMPreprocessor( + keras_hub_tokenizer + ) + qwen_moe_lm = keras_hub.models.Qwen3MoeCausalLM( + backbone=keras_hub_model, preprocessor=preprocessor + ) + + keras_output = qwen_moe_lm.generate([input_str], max_length=length) + keras_output = keras_output[0] + print("🔶 KerasHub output:", keras_output) + + # Transformers + hf_inputs = hf_tokenizer([input_str], return_tensors="pt").to(device) + outputs = hf_model.generate( + **hf_inputs, + max_length=length, # Match KerasHub's max_length + # do_sample=True, # Enable sampling (default in KerasHub for generate) + pad_token_id=hf_tokenizer.pad_token_id, + ) + hf_generated_text = hf_tokenizer.batch_decode( + outputs, skip_special_tokens=True + )[0] + print("🔶 Huggingface output:", hf_generated_text) + + +def main(_): + # === Get the preset name === + if FLAGS.preset not in PRESET_MAP.keys(): + raise ValueError( + f"Invalid preset {FLAGS.preset}. Must be one " + f"of {','.join(PRESET_MAP.keys())}" + ) + preset = FLAGS.preset + hf_preset = PRESET_MAP[preset] + + # === Load the Huggingface model === + hf_model = AutoModelForCausalLM.from_pretrained( + hf_preset, + device_map=device, + ) + hf_tokenizer = AutoTokenizer.from_pretrained(hf_preset, return_tensors="pt") + hf_model.eval() + + keras_hub_model = keras_hub.models.Qwen3MoeBackbone.from_preset( + f"hf://{hf_preset}" + ) + keras_hub_tokenizer = keras_hub.tokenizers.Qwen3MoeTokenizer.from_preset( + f"hf://{hf_preset}" + ) + + print("\n-> Huggingface model and tokenizer loaded") + + # === Check that the models and tokenizers outputs match === + test_tokenizer(keras_hub_tokenizer, hf_tokenizer) + test_model(keras_hub_model, keras_hub_tokenizer, hf_model, hf_tokenizer) + + # == Validate model.generate output == + # validate_output( + # keras_hub_model, keras_hub_tokenizer, hf_model, hf_tokenizer + # ) + print("\n-> Tests passed!") + + +if __name__ == "__main__": + flags.mark_flag_as_required("preset") + app.run(main)