From fd88dddde0b3c7932677ae1bee9bd6105d7fe3e2 Mon Sep 17 00:00:00 2001 From: Shushi Hong <820958424@qq.com> Date: Tue, 3 Sep 2024 01:51:27 +0800 Subject: [PATCH 1/4] Update cohere_model.py --- python/mlc_llm/model/cohere/cohere_model.py | 35 +++++++++++++++++---- 1 file changed, 29 insertions(+), 6 deletions(-) diff --git a/python/mlc_llm/model/cohere/cohere_model.py b/python/mlc_llm/model/cohere/cohere_model.py index 180c60ba13..af11c6e663 100644 --- a/python/mlc_llm/model/cohere/cohere_model.py +++ b/python/mlc_llm/model/cohere/cohere_model.py @@ -4,7 +4,7 @@ """ import dataclasses -from typing import Any, Dict, Optional +from typing import Any, Dict, Optional, Union from tvm import te, tir from tvm.relax.frontend import nn @@ -32,6 +32,7 @@ class CohereConfig(ConfigBase): # pylint: disable=too-many-instance-attributes num_key_value_heads: int intermediate_size: int layer_norm_eps: float + use_qk_norm: bool position_embedding_base: int = 0 context_window_size: int = 0 prefill_chunk_size: int = 0 @@ -124,7 +125,17 @@ def __init__(self, config: CohereConfig): f"num_attention_heads({config.num_key_value_heads}) " "must be divisible by tensor_parallel_shards" ) + self.head_dim = config.head_dim + self.use_qk_norm = config.use_qk_norm + + if self.use_qk_norm: + self.q_norm = CohereNorm( + hidden_size=[self.num_q_heads, self.head_dim], eps=config.layer_norm_eps + ) + self.k_norm = CohereNorm( + hidden_size=[self.num_key_value_heads, self.head_dim], eps=config.layer_norm_eps + ) self.qkv_proj = nn.Linear( in_features=config.hidden_size, @@ -139,6 +150,13 @@ def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: # QKV Projection qkv = self.qkv_proj(hidden_states) qkv = op.reshape(qkv, (b, s, h_q + h_kv + h_kv, d)) + + if self.use_qk_norm: + q, k, v = op.split(qkv, indices_or_sections=[h_q, h_q + h_kv], axis=2) + q = self.q_norm(q) + k = self.k_norm(k) + qkv = op.concat([q, k, v], dim=2) + # Attention output = op.reshape( paged_kv_cache.attention_with_fused_qkv(layer_id, qkv, self.num_q_heads), @@ -188,17 +206,22 @@ def _apply_parallel_residual(self, mlp_out, residual): class CohereNorm(nn.Module): def __init__( - self, normalized_shape: int, eps: float = 1e-5, dtype: Optional[str] = None + self, hidden_size: Optional[Union[int, list]] = None, eps: float = 1e-5, dtype: Optional[str] = None ) -> None: - super().__init__() - self.normalized_shape = normalized_shape + self.hidden_size = hidden_size self.eps = eps - self.weight = nn.Parameter((normalized_shape,), dtype=dtype) + if isinstance(hidden_size, int): + normalized_shape = [hidden_size] + elif isinstance(hidden_size, list): + normalized_shape = hidden_size + else: + raise ValueError("hidden_size must be an int or a list of ints") + self.weight = nn.Parameter(normalized_shape, dtype=dtype) def forward(self, x: Tensor) -> Tensor: return op.layer_norm( x, - normalized_shape=self.normalized_shape, + normalized_shape=self.hidden_size, weight=self.weight, bias=None, eps=self.eps, From 5247ca44794dd3557458fdbebe88119691f3af21 Mon Sep 17 00:00:00 2001 From: Shushi Hong <820958424@qq.com> Date: Tue, 3 Sep 2024 01:59:16 +0800 Subject: [PATCH 2/4] fix lint --- python/mlc_llm/model/cohere/cohere_model.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/python/mlc_llm/model/cohere/cohere_model.py b/python/mlc_llm/model/cohere/cohere_model.py index af11c6e663..1c9df62021 100644 --- a/python/mlc_llm/model/cohere/cohere_model.py +++ b/python/mlc_llm/model/cohere/cohere_model.py @@ -128,7 +128,7 @@ def __init__(self, config: CohereConfig): self.head_dim = config.head_dim self.use_qk_norm = config.use_qk_norm - + if self.use_qk_norm: self.q_norm = CohereNorm( hidden_size=[self.num_q_heads, self.head_dim], eps=config.layer_norm_eps @@ -206,7 +206,10 @@ def _apply_parallel_residual(self, mlp_out, residual): class CohereNorm(nn.Module): def __init__( - self, hidden_size: Optional[Union[int, list]] = None, eps: float = 1e-5, dtype: Optional[str] = None + self, + hidden_size: Optional[Union[int, list]] = None, + eps: float = 1e-5, + dtype: Optional[str] = None ) -> None: self.hidden_size = hidden_size self.eps = eps From aeb0be3b07f1741f0524275cbaa19cc2193a2d6c Mon Sep 17 00:00:00 2001 From: Shushi Hong <820958424@qq.com> Date: Tue, 3 Sep 2024 02:00:48 +0800 Subject: [PATCH 3/4] fix lint --- python/mlc_llm/model/cohere/cohere_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/mlc_llm/model/cohere/cohere_model.py b/python/mlc_llm/model/cohere/cohere_model.py index 1c9df62021..90478e3e04 100644 --- a/python/mlc_llm/model/cohere/cohere_model.py +++ b/python/mlc_llm/model/cohere/cohere_model.py @@ -209,7 +209,7 @@ def __init__( self, hidden_size: Optional[Union[int, list]] = None, eps: float = 1e-5, - dtype: Optional[str] = None + dtype: Optional[str] = None, ) -> None: self.hidden_size = hidden_size self.eps = eps From 892ae6b5243f74606fde6381db10c524d4970c9c Mon Sep 17 00:00:00 2001 From: Shushi Hong <820958424@qq.com> Date: Tue, 3 Sep 2024 02:02:36 +0800 Subject: [PATCH 4/4] fix lint --- python/mlc_llm/model/cohere/cohere_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/mlc_llm/model/cohere/cohere_model.py b/python/mlc_llm/model/cohere/cohere_model.py index 90478e3e04..8fd7c349a3 100644 --- a/python/mlc_llm/model/cohere/cohere_model.py +++ b/python/mlc_llm/model/cohere/cohere_model.py @@ -113,7 +113,7 @@ def forward(self, x): # pylint: disable=invalid-name,missing-docstring -class CohereAttention(nn.Module): +class CohereAttention(nn.Module): # pylint: disable=too-many-instance-attributes def __init__(self, config: CohereConfig): self.num_q_heads = config.num_attention_heads // config.tensor_parallel_shards assert config.num_attention_heads % config.tensor_parallel_shards == 0, (