diff --git a/python/mlc_llm/model/baichuan/baichuan_model.py b/python/mlc_llm/model/baichuan/baichuan_model.py
index cc5b3c28f6..fa76255262 100644
--- a/python/mlc_llm/model/baichuan/baichuan_model.py
+++ b/python/mlc_llm/model/baichuan/baichuan_model.py
@@ -10,7 +10,7 @@
 from tvm.relax.frontend.nn import Tensor, op
 
 from mlc_llm import op as op_ext
-from mlc_llm.nn import PagedKVCache, RopeMode
+from mlc_llm.nn import PagedKVCache, RopeMode, create_generic_paged_kv_cache
 from mlc_llm.support import logging
 from mlc_llm.support import tensor_parallel as tp
 from mlc_llm.support.config import ConfigBase
@@ -280,7 +280,7 @@ def create_paged_kv_cache(  # pylint: disable=too-many-arguments
         page_size: tir.Var,
         support_sliding_window: tir.Var,
     ) -> PagedKVCache:
-        return PagedKVCache.create_generic(
+        return create_generic_paged_kv_cache(
             attn_kind="mha",
             max_batch_size=max_batch_size,
             max_total_seq_len=max_total_seq_len,
@@ -309,7 +309,7 @@ def get_default_spec(self):
             },
             "prefill": {
                 "input_embed": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype),
-                "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache),
+                "paged_kv_cache": nn.spec.PagedKVCache(),
                 "$": {
                     "param_mode": "packed",
                     "effect_mode": "none",
@@ -317,7 +317,7 @@ def get_default_spec(self):
             },
             "decode": {
                 "input_embed": nn.spec.Tensor([1, 1, self.hidden_size], self.dtype),
-                "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache),
+                "paged_kv_cache": nn.spec.PagedKVCache(),
                 "$": {
                     "param_mode": "packed",
                     "effect_mode": "none",
@@ -326,7 +326,7 @@ def get_default_spec(self):
             "batch_prefill": {
                 "input_embeds": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype),
                 "logit_positions": nn.spec.Tensor(["batch_size"], "int32"),
-                "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache),
+                "paged_kv_cache": nn.spec.PagedKVCache(),
                 "$": {
                     "param_mode": "packed",
                     "effect_mode": "none",
@@ -334,7 +334,7 @@ def get_default_spec(self):
             },
             "batch_decode": {
                 "input_embeds": nn.spec.Tensor(["batch_size", 1, self.hidden_size], self.dtype),
-                "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache),
+                "paged_kv_cache": nn.spec.PagedKVCache(),
                 "$": {
                     "param_mode": "packed",
                     "effect_mode": "none",
@@ -342,7 +342,7 @@ def get_default_spec(self):
             },
             "batch_verify": {
                 "input_embeds": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype),
-                "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache),
+                "paged_kv_cache": nn.spec.PagedKVCache(),
                 "$": {
                     "param_mode": "packed",
                     "effect_mode": "none",
diff --git a/python/mlc_llm/model/chatglm3/chatglm3_model.py b/python/mlc_llm/model/chatglm3/chatglm3_model.py
index c895ad6cd4..f1ef1a5bc6 100644
--- a/python/mlc_llm/model/chatglm3/chatglm3_model.py
+++ b/python/mlc_llm/model/chatglm3/chatglm3_model.py
@@ -10,7 +10,7 @@
 from tvm.relax.frontend.nn import Tensor, op
 
 from mlc_llm import op as op_ext
-from mlc_llm.nn import PagedKVCache, RopeMode
+from mlc_llm.nn import PagedKVCache, RopeMode, create_generic_paged_kv_cache
 from mlc_llm.support import logging
 from mlc_llm.support import tensor_parallel as tp
 from mlc_llm.support.config import ConfigBase
@@ -355,7 +355,7 @@ def create_paged_kv_cache(  # pylint: disable=too-many-arguments
         page_size: tir.Var,
         support_sliding_window: tir.Var,
     ) -> PagedKVCache:
-        return PagedKVCache.create_generic(
+        return create_generic_paged_kv_cache(
             attn_kind="mha",
             max_batch_size=max_batch_size,
             max_total_seq_len=max_total_seq_len,
@@ -384,7 +384,7 @@ def get_default_spec(self):
             },
             "prefill": {
                 "input_embed": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype),
-                "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache),
+                "paged_kv_cache": nn.spec.PagedKVCache(),
                 "$": {
                     "param_mode": "packed",
                     "effect_mode": "none",
@@ -392,7 +392,7 @@ def get_default_spec(self):
             },
             "decode": {
                 "input_embed": nn.spec.Tensor([1, 1, self.hidden_size], self.dtype),
-                "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache),
+                "paged_kv_cache": nn.spec.PagedKVCache(),
                 "$": {
                     "param_mode": "packed",
                     "effect_mode": "none",
@@ -401,7 +401,7 @@ def get_default_spec(self):
             "batch_prefill": {
                 "input_embeds": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype),
                 "logit_positions": nn.spec.Tensor(["batch_size"], "int32"),
-                "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache),
+                "paged_kv_cache": nn.spec.PagedKVCache(),
                 "$": {
                     "param_mode": "packed",
                     "effect_mode": "none",
@@ -409,7 +409,7 @@ def get_default_spec(self):
             },
             "batch_decode": {
                 "input_embeds": nn.spec.Tensor(["batch_size", 1, self.hidden_size], self.dtype),
-                "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache),
+                "paged_kv_cache": nn.spec.PagedKVCache(),
                 "$": {
                     "param_mode": "packed",
                     "effect_mode": "none",
@@ -417,7 +417,7 @@ def get_default_spec(self):
             },
             "batch_verify": {
                 "input_embeds": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype),
-                "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache),
+                "paged_kv_cache": nn.spec.PagedKVCache(),
                 "$": {
                     "param_mode": "packed",
                     "effect_mode": "none",
diff --git a/python/mlc_llm/model/cohere/cohere_model.py b/python/mlc_llm/model/cohere/cohere_model.py
index 0f6d43823a..c79eb6111d 100644
--- a/python/mlc_llm/model/cohere/cohere_model.py
+++ b/python/mlc_llm/model/cohere/cohere_model.py
@@ -10,7 +10,7 @@
 from tvm.relax.frontend.nn import Tensor, op
 
 from mlc_llm import op as op_ext
-from mlc_llm.nn import PagedKVCache, RopeMode
+from mlc_llm.nn import PagedKVCache, RopeMode, create_generic_paged_kv_cache
 from mlc_llm.support import logging
 from mlc_llm.support import tensor_parallel as tp
 from mlc_llm.support.config import ConfigBase
@@ -324,7 +324,7 @@ def create_paged_kv_cache(  # pylint: disable=too-many-arguments
         page_size: tir.Var,
         support_sliding_window: tir.Var,
     ) -> PagedKVCache:
-        return PagedKVCache.create_generic(
+        return create_generic_paged_kv_cache(
             attn_kind="mha",
             max_batch_size=max_batch_size,
             max_total_seq_len=max_total_seq_len,
@@ -353,7 +353,7 @@ def get_default_spec(self):
             },
             "prefill": {
                 "input_embed": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype),
-                "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache),
+                "paged_kv_cache": nn.spec.PagedKVCache(),
                 "$": {
                     "param_mode": "packed",
                     "effect_mode": "none",
@@ -361,7 +361,7 @@ def get_default_spec(self):
             },
             "decode": {
                 "input_embed": nn.spec.Tensor([1, 1, self.hidden_size], self.dtype),
-                "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache),
+                "paged_kv_cache": nn.spec.PagedKVCache(),
                 "$": {
                     "param_mode": "packed",
                     "effect_mode": "none",
@@ -370,7 +370,7 @@ def get_default_spec(self):
             "batch_prefill": {
                 "input_embeds": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype),
                 "logit_positions": nn.spec.Tensor(["batch_size"], "int32"),
-                "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache),
+                "paged_kv_cache": nn.spec.PagedKVCache(),
                 "$": {
                     "param_mode": "packed",
                     "effect_mode": "none",
@@ -378,7 +378,7 @@ def get_default_spec(self):
             },
             "batch_decode": {
                 "input_embeds": nn.spec.Tensor(["batch_size", 1, self.hidden_size], self.dtype),
-                "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache),
+                "paged_kv_cache": nn.spec.PagedKVCache(),
                 "$": {
                     "param_mode": "packed",
                     "effect_mode": "none",
@@ -386,7 +386,7 @@ def get_default_spec(self):
             },
             "batch_verify": {
                 "input_embeds": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype),
-                "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache),
+                "paged_kv_cache": nn.spec.PagedKVCache(),
                 "$": {
                     "param_mode": "packed",
                     "effect_mode": "none",
diff --git a/python/mlc_llm/model/deepseek/deepseek_model.py b/python/mlc_llm/model/deepseek/deepseek_model.py
index d3c931f399..9dc0cb5ed4 100644
--- a/python/mlc_llm/model/deepseek/deepseek_model.py
+++ b/python/mlc_llm/model/deepseek/deepseek_model.py
@@ -11,7 +11,7 @@
 from tvm.relax.frontend.nn import Tensor, op
 
 from mlc_llm import op as op_ext
-from mlc_llm.nn import PagedKVCache, RopeMode
+from mlc_llm.nn import PagedKVCache, RopeMode, create_generic_paged_kv_cache
 from mlc_llm.nn.expert import MixtralExperts
 from mlc_llm.support import logging
 from mlc_llm.support import tensor_parallel as tp
@@ -430,7 +430,7 @@ def create_paged_kv_cache(  # pylint: disable=too-many-arguments
         page_size: tir.Var,
         support_sliding_window: tir.Var,
     ) -> PagedKVCache:
-        return PagedKVCache.create_generic(
+        return create_generic_paged_kv_cache(
             attn_kind="mha",
             max_batch_size=max_batch_size,
             max_total_seq_len=max_total_seq_len,
@@ -459,7 +459,7 @@ def get_default_spec(self):
             },
             "prefill": {
                 "input_embed": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype),
-                "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache),
+                "paged_kv_cache": nn.spec.PagedKVCache(),
                 "$": {
                     "param_mode": "packed",
                     "effect_mode": "none",
@@ -467,7 +467,7 @@ def get_default_spec(self):
             },
             "decode": {
                 "input_embed": nn.spec.Tensor([1, 1, self.hidden_size], self.dtype),
-                "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache),
+                "paged_kv_cache": nn.spec.PagedKVCache(),
                 "$": {
                     "param_mode": "packed",
                     "effect_mode": "none",
@@ -476,7 +476,7 @@ def get_default_spec(self):
             "batch_prefill": {
                 "input_embeds": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype),
                 "logit_positions": nn.spec.Tensor(["batch_size"], "int32"),
-                "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache),
+                "paged_kv_cache": nn.spec.PagedKVCache(),
                 "$": {
                     "param_mode": "packed",
                     "effect_mode": "none",
@@ -484,7 +484,7 @@ def get_default_spec(self):
             },
             "batch_decode": {
                 "input_embeds": nn.spec.Tensor(["batch_size", 1, self.hidden_size], self.dtype),
-                "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache),
+                "paged_kv_cache": nn.spec.PagedKVCache(),
                 "$": {
                     "param_mode": "packed",
                     "effect_mode": "none",
@@ -492,7 +492,7 @@ def get_default_spec(self):
             },
             "batch_verify": {
                 "input_embeds": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype),
-                "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache),
+                "paged_kv_cache": nn.spec.PagedKVCache(),
                 "$": {
                     "param_mode": "packed",
                     "effect_mode": "none",
diff --git a/python/mlc_llm/model/deepseek_v2/deepseek_v2_model.py b/python/mlc_llm/model/deepseek_v2/deepseek_v2_model.py
index 45ae4c3787..7a2c76b7c5 100644
--- a/python/mlc_llm/model/deepseek_v2/deepseek_v2_model.py
+++ b/python/mlc_llm/model/deepseek_v2/deepseek_v2_model.py
@@ -12,7 +12,7 @@
 from tvm.relax.frontend.nn.llm import position_embedding
 
 from mlc_llm import op as op_ext
-from mlc_llm.nn import PagedKVCache, RopeMode
+from mlc_llm.nn import PagedKVCache, RopeMode, create_generic_paged_kv_cache
 from mlc_llm.nn.expert import MixtralExperts
 from mlc_llm.op import batch_matmul
 from mlc_llm.support import logging
@@ -771,7 +771,7 @@ def create_paged_kv_cache(  # pylint: disable=too-many-arguments
         page_size: tir.Var,
         support_sliding_window: tir.Var,
     ) -> PagedKVCache:
-        return PagedKVCache.create_generic(
+        return create_generic_paged_kv_cache(
             attn_kind="mla",
             max_batch_size=max_batch_size,
             max_total_seq_len=max_total_seq_len,
@@ -802,7 +802,7 @@ def get_default_spec(self):
             },
             "prefill": {
                 "input_embed": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype),
-                "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache),
+                "paged_kv_cache": nn.spec.PagedKVCache(),
                 "$": {
                     "param_mode": "packed",
                     "effect_mode": "none",
@@ -810,7 +810,7 @@ def get_default_spec(self):
             },
             "extend": {
                 "input_embed": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype),
-                "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache),
+                "paged_kv_cache": nn.spec.PagedKVCache(),
                 "$": {
                     "param_mode": "packed",
                     "effect_mode": "none",
@@ -818,7 +818,7 @@ def get_default_spec(self):
             },
             "decode": {
                 "input_embed": nn.spec.Tensor([1, 1, self.hidden_size], self.dtype),
-                "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache),
+                "paged_kv_cache": nn.spec.PagedKVCache(),
                 "$": {
                     "param_mode": "packed",
                     "effect_mode": "none",
@@ -827,7 +827,7 @@ def get_default_spec(self):
             "batch_prefill": {
                 "input_embeds": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype),
                 "logit_positions": nn.spec.Tensor(["batch_size"], "int32"),
-                "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache),
+                "paged_kv_cache": nn.spec.PagedKVCache(),
                 "$": {
                     "param_mode": "packed",
                     "effect_mode": "none",
@@ -836,7 +836,7 @@ def get_default_spec(self):
             "batch_extend": {
                 "input_embeds": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype),
                 "logit_positions": nn.spec.Tensor(["batch_size"], "int32"),
-                "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache),
+                "paged_kv_cache": nn.spec.PagedKVCache(),
                 "$": {
                     "param_mode": "packed",
                     "effect_mode": "none",
@@ -844,7 +844,7 @@ def get_default_spec(self):
             },
             "batch_decode": {
                 "input_embeds": nn.spec.Tensor(["batch_size", 1, self.hidden_size], self.dtype),
-                "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache),
+                "paged_kv_cache": nn.spec.PagedKVCache(),
                 "$": {
                     "param_mode": "packed",
                     "effect_mode": "none",
@@ -852,7 +852,7 @@ def get_default_spec(self):
             },
             "batch_verify": {
                 "input_embeds": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype),
-                "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache),
+                "paged_kv_cache": nn.spec.PagedKVCache(),
                 "$": {
                     "param_mode": "packed",
                     "effect_mode": "none",
diff --git a/python/mlc_llm/model/eagle/eagle_model.py b/python/mlc_llm/model/eagle/eagle_model.py
index 2eefda7493..0877ce7ef2 100644
--- a/python/mlc_llm/model/eagle/eagle_model.py
+++ b/python/mlc_llm/model/eagle/eagle_model.py
@@ -11,7 +11,7 @@
 
 from mlc_llm import op as op_ext
 from mlc_llm.model.llama.llama_model import LlamaAttention, LlamaConfig, LlamaFFN
-from mlc_llm.nn import PagedKVCache, RopeMode
+from mlc_llm.nn import PagedKVCache, RopeMode, create_generic_paged_kv_cache
 from mlc_llm.support import logging
 from mlc_llm.support import tensor_parallel as tp
 
@@ -164,7 +164,7 @@ def create_paged_kv_cache(  # pylint: disable=too-many-arguments
         page_size: tir.Var,
         support_sliding_window: tir.Var,
     ) -> PagedKVCache:
-        return PagedKVCache.create_generic(
+        return create_generic_paged_kv_cache(
             attn_kind="mha",
             max_batch_size=max_batch_size,
             max_total_seq_len=max_total_seq_len,
@@ -201,7 +201,7 @@ def get_default_spec(self):
             },
             "prefill_to_last_hidden_states": {
                 "hidden_states": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype),
-                "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache),
+                "paged_kv_cache": nn.spec.PagedKVCache(),
                 "$": {
                     "param_mode": "packed",
                     "effect_mode": "none",
@@ -209,7 +209,7 @@ def get_default_spec(self):
             },
             "decode_to_last_hidden_states": {
                 "hidden_states": nn.spec.Tensor([1, 1, self.hidden_size], self.dtype),
-                "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache),
+                "paged_kv_cache": nn.spec.PagedKVCache(),
                 "$": {
                     "param_mode": "packed",
                     "effect_mode": "none",
@@ -217,7 +217,7 @@ def get_default_spec(self):
             },
             "batch_prefill_to_last_hidden_states": {
                 "hidden_states": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype),
-                "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache),
+                "paged_kv_cache": nn.spec.PagedKVCache(),
                 "$": {
                     "param_mode": "packed",
                     "effect_mode": "none",
@@ -225,7 +225,7 @@ def get_default_spec(self):
             },
             "batch_decode_to_last_hidden_states": {
                 "hidden_states": nn.spec.Tensor(["batch_size", 1, self.hidden_size], self.dtype),
-                "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache),
+                "paged_kv_cache": nn.spec.PagedKVCache(),
                 "$": {
                     "param_mode": "packed",
                     "effect_mode": "none",
diff --git a/python/mlc_llm/model/gemma/gemma_model.py b/python/mlc_llm/model/gemma/gemma_model.py
index 26668030ea..d0e602db69 100644
--- a/python/mlc_llm/model/gemma/gemma_model.py
+++ b/python/mlc_llm/model/gemma/gemma_model.py
@@ -8,7 +8,7 @@
 from tvm.relax.frontend.nn import Tensor, op
 
 from mlc_llm import op as op_ext
-from mlc_llm.nn import PagedKVCache, RopeMode
+from mlc_llm.nn import PagedKVCache, RopeMode, create_generic_paged_kv_cache
 from mlc_llm.support import logging
 from mlc_llm.support import tensor_parallel as tp
 from mlc_llm.support.config import ConfigBase
@@ -305,7 +305,7 @@ def create_paged_kv_cache(  # pylint: disable=too-many-arguments
         page_size: tir.Var,
         support_sliding_window: tir.Var,
     ) -> PagedKVCache:
-        return PagedKVCache.create_generic(
+        return create_generic_paged_kv_cache(
             attn_kind="mha",
             max_batch_size=max_batch_size,
             max_total_seq_len=max_total_seq_len,
@@ -334,7 +334,7 @@ def get_default_spec(self):
             },
             "prefill": {
                 "input_embed": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype),
-                "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache),
+                "paged_kv_cache": nn.spec.PagedKVCache(),
                 "$": {
                     "param_mode": "packed",
                     "effect_mode": "none",
@@ -342,7 +342,7 @@ def get_default_spec(self):
             },
             "decode": {
                 "input_embed": nn.spec.Tensor([1, 1, self.hidden_size], self.dtype),
-                "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache),
+                "paged_kv_cache": nn.spec.PagedKVCache(),
                 "$": {
                     "param_mode": "packed",
                     "effect_mode": "none",
@@ -351,7 +351,7 @@ def get_default_spec(self):
             "batch_prefill": {
                 "input_embeds": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype),
                 "logit_positions": nn.spec.Tensor(["batch_size"], "int32"),
-                "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache),
+                "paged_kv_cache": nn.spec.PagedKVCache(),
                 "$": {
                     "param_mode": "packed",
                     "effect_mode": "none",
@@ -359,7 +359,7 @@ def get_default_spec(self):
             },
             "batch_decode": {
                 "input_embeds": nn.spec.Tensor(["batch_size", 1, self.hidden_size], self.dtype),
-                "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache),
+                "paged_kv_cache": nn.spec.PagedKVCache(),
                 "$": {
                     "param_mode": "packed",
                     "effect_mode": "none",
@@ -367,7 +367,7 @@ def get_default_spec(self):
             },
             "batch_verify": {
                 "input_embeds": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype),
-                "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache),
+                "paged_kv_cache": nn.spec.PagedKVCache(),
                 "$": {
                     "param_mode": "packed",
                     "effect_mode": "none",
diff --git a/python/mlc_llm/model/gemma3/gemma3_model.py b/python/mlc_llm/model/gemma3/gemma3_model.py
index 3de24bc5ff..d3a302f788 100644
--- a/python/mlc_llm/model/gemma3/gemma3_model.py
+++ b/python/mlc_llm/model/gemma3/gemma3_model.py
@@ -9,7 +9,7 @@
 
 from mlc_llm import op as op_ext
 from mlc_llm.model.gemma.gemma_model import GemmaEmbedding
-from mlc_llm.nn import PagedKVCache, RopeMode
+from mlc_llm.nn import PagedKVCache, RopeMode, create_generic_paged_kv_cache
 from mlc_llm.support import logging
 from mlc_llm.support import tensor_parallel as tp
 from mlc_llm.support.config import ConfigBase
@@ -381,7 +381,7 @@ def create_paged_kv_cache(  # pylint: disable=too-many-arguments
         page_size: tir.Var,
         support_sliding_window: tir.Var,
     ) -> PagedKVCache:
-        return PagedKVCache.create_generic(
+        return create_generic_paged_kv_cache(
             attn_kind="mha",
             max_batch_size=max_batch_size,
             max_total_seq_len=max_total_seq_len,
@@ -410,7 +410,7 @@ def get_default_spec(self):
             },
             "prefill": {
                 "input_embed": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype),
-                "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache),
+                "paged_kv_cache": nn.spec.PagedKVCache(),
                 "$": {
                     "param_mode": "packed",
                     "effect_mode": "none",
@@ -418,7 +418,7 @@ def get_default_spec(self):
             },
             "decode": {
                 "input_embed": nn.spec.Tensor([1, 1, self.hidden_size], self.dtype),
-                "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache),
+                "paged_kv_cache": nn.spec.PagedKVCache(),
                 "$": {
                     "param_mode": "packed",
                     "effect_mode": "none",
@@ -427,7 +427,7 @@ def get_default_spec(self):
             "batch_prefill": {
                 "input_embeds": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype),
                 "logit_positions": nn.spec.Tensor(["batch_size"], "int32"),
-                "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache),
+                "paged_kv_cache": nn.spec.PagedKVCache(),
                 "$": {
                     "param_mode": "packed",
                     "effect_mode": "none",
@@ -435,7 +435,7 @@ def get_default_spec(self):
             },
             "batch_decode": {
                 "input_embeds": nn.spec.Tensor(["batch_size", 1, self.hidden_size], self.dtype),
-                "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache),
+                "paged_kv_cache": nn.spec.PagedKVCache(),
                 "$": {
                     "param_mode": "packed",
                     "effect_mode": "none",
@@ -443,7 +443,7 @@ def get_default_spec(self):
             },
             "batch_verify": {
                 "input_embeds": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype),
-                "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache),
+                "paged_kv_cache": nn.spec.PagedKVCache(),
                 "$": {
                     "param_mode": "packed",
                     "effect_mode": "none",
@@ -547,7 +547,7 @@ def create_paged_kv_cache(  # pylint: disable=too-many-arguments
         page_size: tir.Var,
         support_sliding_window: tir.Var,
     ) -> PagedKVCache:
-        return PagedKVCache.create_generic(
+        return create_generic_paged_kv_cache(
             attn_kind="mha",
             max_batch_size=max_batch_size,
             max_total_seq_len=max_total_seq_len,
@@ -580,7 +580,7 @@ def get_default_spec(self):
                 "input_embed": nn.spec.Tensor(
                     [1, "seq_len", self.language_model.hidden_size], self.dtype
                 ),
-                "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache),
+                "paged_kv_cache": nn.spec.PagedKVCache(),
                 "$": {
                     "param_mode": "packed",
                     "effect_mode": "none",
@@ -588,7 +588,7 @@ def get_default_spec(self):
             },
             "decode": {
                 "input_embed": nn.spec.Tensor([1, 1, self.language_model.hidden_size], self.dtype),
-                "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache),
+                "paged_kv_cache": nn.spec.PagedKVCache(),
                 "$": {
                     "param_mode": "packed",
                     "effect_mode": "none",
@@ -599,7 +599,7 @@ def get_default_spec(self):
                     [1, "seq_len", self.language_model.hidden_size], self.dtype
                 ),
                 "logit_positions": nn.spec.Tensor(["batch_size"], "int32"),
-                "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache),
+                "paged_kv_cache": nn.spec.PagedKVCache(),
                 "$": {
                     "param_mode": "packed",
                     "effect_mode": "none",
@@ -609,7 +609,7 @@ def get_default_spec(self):
                 "input_embeds": nn.spec.Tensor(
                     ["batch_size", 1, self.language_model.hidden_size], self.dtype
                 ),
-                "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache),
+                "paged_kv_cache": nn.spec.PagedKVCache(),
                 "$": {
                     "param_mode": "packed",
                     "effect_mode": "none",
@@ -619,7 +619,7 @@ def get_default_spec(self):
                 "input_embeds": nn.spec.Tensor(
                     [1, "seq_len", self.language_model.hidden_size], self.dtype
                 ),
-                "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache),
+                "paged_kv_cache": nn.spec.PagedKVCache(),
                 "$": {
                     "param_mode": "packed",
                     "effect_mode": "none",
diff --git a/python/mlc_llm/model/gpt2/gpt2_model.py b/python/mlc_llm/model/gpt2/gpt2_model.py
index 3b6ba009ba..d376c46ff4 100644
--- a/python/mlc_llm/model/gpt2/gpt2_model.py
+++ b/python/mlc_llm/model/gpt2/gpt2_model.py
@@ -10,7 +10,7 @@
 from tvm.relax.frontend.nn import Tensor, op
 
 from mlc_llm import op as op_ext
-from mlc_llm.nn import PagedKVCache, RopeMode
+from mlc_llm.nn import PagedKVCache, RopeMode, create_generic_paged_kv_cache
 from mlc_llm.support import logging
 from mlc_llm.support import tensor_parallel as tp
 from mlc_llm.support.config import ConfigBase
@@ -300,7 +300,7 @@ def create_paged_kv_cache(  # pylint: disable=too-many-arguments
         page_size: tir.Var,
         support_sliding_window: tir.Var,
     ) -> PagedKVCache:
-        return PagedKVCache.create_generic(
+        return create_generic_paged_kv_cache(
             attn_kind="mha",
             max_batch_size=max_batch_size,
             max_total_seq_len=max_total_seq_len,
@@ -329,7 +329,7 @@ def get_default_spec(self):
             },
             "prefill": {
                 "input_embed": nn.spec.Tensor([1, "seq_len", self.n_embed], self.dtype),
-                "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache),
+                "paged_kv_cache": nn.spec.PagedKVCache(),
                 "$": {
                     "param_mode": "packed",
                     "effect_mode": "none",
@@ -337,7 +337,7 @@ def get_default_spec(self):
             },
             "decode": {
                 "input_embed": nn.spec.Tensor([1, 1, self.n_embed], self.dtype),
-                "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache),
+                "paged_kv_cache": nn.spec.PagedKVCache(),
                 "$": {
                     "param_mode": "packed",
                     "effect_mode": "none",
@@ -346,7 +346,7 @@ def get_default_spec(self):
             "batch_prefill": {
                 "input_embeds": nn.spec.Tensor([1, "seq_len", self.n_embed], self.dtype),
                 "logit_positions": nn.spec.Tensor(["batch_size"], "int32"),
-                "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache),
+                "paged_kv_cache": nn.spec.PagedKVCache(),
                 "$": {
                     "param_mode": "packed",
                     "effect_mode": "none",
@@ -354,7 +354,7 @@ def get_default_spec(self):
             },
             "batch_decode": {
                 "input_embeds": nn.spec.Tensor(["batch_size", 1, self.n_embed], self.dtype),
-                "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache),
+                "paged_kv_cache": nn.spec.PagedKVCache(),
                 "$": {
                     "param_mode": "packed",
                     "effect_mode": "none",
@@ -362,7 +362,7 @@ def get_default_spec(self):
             },
             "batch_verify": {
                 "input_embeds": nn.spec.Tensor([1, "seq_len", self.n_embed], self.dtype),
-                "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache),
+                "paged_kv_cache": nn.spec.PagedKVCache(),
                 "$": {
                     "param_mode": "packed",
                     "effect_mode": "none",
diff --git a/python/mlc_llm/model/gpt_bigcode/gpt_bigcode_model.py b/python/mlc_llm/model/gpt_bigcode/gpt_bigcode_model.py
index 0fa41074fe..51de4aca98 100644
--- a/python/mlc_llm/model/gpt_bigcode/gpt_bigcode_model.py
+++ b/python/mlc_llm/model/gpt_bigcode/gpt_bigcode_model.py
@@ -10,7 +10,7 @@
 from tvm.relax.frontend.nn import Tensor, op
 
 from mlc_llm import op as op_ext
-from mlc_llm.nn import PagedKVCache, RopeMode
+from mlc_llm.nn import PagedKVCache, RopeMode, create_generic_paged_kv_cache
 from mlc_llm.support import logging
 from mlc_llm.support import tensor_parallel as tp
 from mlc_llm.support.config import ConfigBase
@@ -267,7 +267,7 @@ def create_paged_kv_cache(  # pylint: disable=too-many-arguments
         page_size: tir.Var,
         support_sliding_window: tir.Var,
     ) -> PagedKVCache:
-        return PagedKVCache.create_generic(
+        return create_generic_paged_kv_cache(
             attn_kind="mha",
             max_batch_size=max_batch_size,
             max_total_seq_len=max_total_seq_len,
@@ -296,7 +296,7 @@ def get_default_spec(self):
             },
             "prefill": {
                 "input_embed": nn.spec.Tensor([1, "seq_len", self.n_embd], self.dtype),
-                "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache),
+                "paged_kv_cache": nn.spec.PagedKVCache(),
                 "$": {
                     "param_mode": "packed",
                     "effect_mode": "none",
@@ -304,7 +304,7 @@ def get_default_spec(self):
             },
             "decode": {
                 "input_embed": nn.spec.Tensor([1, 1, self.n_embd], self.dtype),
-                "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache),
+                "paged_kv_cache": nn.spec.PagedKVCache(),
                 "$": {
                     "param_mode": "packed",
                     "effect_mode": "none",
@@ -313,7 +313,7 @@ def get_default_spec(self):
             "batch_prefill": {
                 "input_embeds": nn.spec.Tensor([1, "seq_len", self.n_embd], self.dtype),
                 "logit_positions": nn.spec.Tensor(["batch_size"], "int32"),
-                "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache),
+                "paged_kv_cache": nn.spec.PagedKVCache(),
                 "$": {
                     "param_mode": "packed",
                     "effect_mode": "none",
@@ -321,7 +321,7 @@ def get_default_spec(self):
             },
             "batch_decode": {
                 "input_embeds": nn.spec.Tensor(["batch_size", 1, self.n_embd], self.dtype),
-                "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache),
+                "paged_kv_cache": nn.spec.PagedKVCache(),
                 "$": {
                     "param_mode": "packed",
                     "effect_mode": "none",
@@ -329,7 +329,7 @@ def get_default_spec(self):
             },
             "batch_verify": {
                 "input_embeds": nn.spec.Tensor([1, "seq_len", self.n_embd], self.dtype),
-                "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache),
+                "paged_kv_cache": nn.spec.PagedKVCache(),
                 "$": {
                     "param_mode": "packed",
                     "effect_mode": "none",
diff --git a/python/mlc_llm/model/gpt_j/gpt_j_model.py b/python/mlc_llm/model/gpt_j/gpt_j_model.py
index d00a79c48f..7d613c3416 100644
--- a/python/mlc_llm/model/gpt_j/gpt_j_model.py
+++ b/python/mlc_llm/model/gpt_j/gpt_j_model.py
@@ -12,7 +12,7 @@
 from tvm.relax.frontend.nn import Tensor, op
 
 from mlc_llm import op as op_ext
-from mlc_llm.nn import PagedKVCache, RopeMode
+from mlc_llm.nn import PagedKVCache, RopeMode, create_generic_paged_kv_cache
 from mlc_llm.support import logging
 from mlc_llm.support import tensor_parallel as tp
 from mlc_llm.support.config import ConfigBase
@@ -288,7 +288,7 @@ def create_paged_kv_cache(  # pylint: disable=too-many-arguments
         page_size: tir.Var,
         support_sliding_window: tir.Var,
     ) -> PagedKVCache:
-        return PagedKVCache.create_generic(
+        return create_generic_paged_kv_cache(
             attn_kind="mha",
             max_batch_size=max_batch_size,
             max_total_seq_len=max_total_seq_len,
@@ -319,7 +319,7 @@ def get_default_spec(self):
             },
             "prefill": {
                 "input_embed": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype),
-                "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache),
+                "paged_kv_cache": nn.spec.PagedKVCache(),
                 "$": {
                     "param_mode": "packed",
                     "effect_mode": "none",
@@ -327,7 +327,7 @@ def get_default_spec(self):
             },
             "decode": {
                 "input_embed": nn.spec.Tensor([1, 1, self.hidden_size], self.dtype),
-                "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache),
+                "paged_kv_cache": nn.spec.PagedKVCache(),
                 "$": {
                     "param_mode": "packed",
                     "effect_mode": "none",
@@ -336,7 +336,7 @@ def get_default_spec(self):
             "batch_prefill": {
                 "input_embeds": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype),
                 "logit_positions": nn.spec.Tensor(["batch_size"], "int32"),
-                "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache),
+                "paged_kv_cache": nn.spec.PagedKVCache(),
                 "$": {
                     "param_mode": "packed",
                     "effect_mode": "none",
@@ -344,7 +344,7 @@ def get_default_spec(self):
             },
             "batch_decode": {
                 "input_embeds": nn.spec.Tensor(["batch_size", 1, self.hidden_size], self.dtype),
-                "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache),
+                "paged_kv_cache": nn.spec.PagedKVCache(),
                 "$": {
                     "param_mode": "packed",
                     "effect_mode": "none",
@@ -352,7 +352,7 @@ def get_default_spec(self):
             },
             "batch_verify": {
                 "input_embeds": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype),
-                "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache),
+                "paged_kv_cache": nn.spec.PagedKVCache(),
                 "$": {
                     "param_mode": "packed",
                     "effect_mode": "none",
diff --git a/python/mlc_llm/model/gpt_neox/gpt_neox_model.py b/python/mlc_llm/model/gpt_neox/gpt_neox_model.py
index aae63d7271..74cacad172 100644
--- a/python/mlc_llm/model/gpt_neox/gpt_neox_model.py
+++ b/python/mlc_llm/model/gpt_neox/gpt_neox_model.py
@@ -11,7 +11,7 @@
 from tvm.relax.frontend.nn import Tensor, op
 
 from mlc_llm import op as op_ext
-from mlc_llm.nn import PagedKVCache, RopeMode
+from mlc_llm.nn import PagedKVCache, RopeMode, create_generic_paged_kv_cache
 from mlc_llm.support import tensor_parallel as tp
 from mlc_llm.support.config import ConfigBase
 from mlc_llm.support.style import bold
@@ -330,7 +330,7 @@ def create_paged_kv_cache(  # pylint: disable=too-many-arguments
         page_size: tir.Var,
         support_sliding_window: tir.Var,
     ) -> PagedKVCache:
-        return PagedKVCache.create_generic(
+        return create_generic_paged_kv_cache(
             attn_kind="mha",
             max_batch_size=max_batch_size,
             max_total_seq_len=max_total_seq_len,
@@ -360,7 +360,7 @@ def get_default_spec(self):
             },
             "prefill": {
                 "input_embed": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype),
-                "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache),
+                "paged_kv_cache": nn.spec.PagedKVCache(),
                 "$": {
                     "param_mode": "packed",
                     "effect_mode": "none",
@@ -368,7 +368,7 @@ def get_default_spec(self):
             },
             "decode": {
                 "input_embed": nn.spec.Tensor([1, 1, self.hidden_size], self.dtype),
-                "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache),
+                "paged_kv_cache": nn.spec.PagedKVCache(),
                 "$": {
                     "param_mode": "packed",
                     "effect_mode": "none",
@@ -377,7 +377,7 @@ def get_default_spec(self):
             "batch_prefill": {
                 "input_embeds": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype),
                 "logit_positions": nn.spec.Tensor(["batch_size"], "int32"),
-                "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache),
+                "paged_kv_cache": nn.spec.PagedKVCache(),
                 "$": {
                     "param_mode": "packed",
                     "effect_mode": "none",
@@ -385,7 +385,7 @@ def get_default_spec(self):
             },
             "batch_decode": {
                 "input_embeds": nn.spec.Tensor(["batch_size", 1, self.hidden_size], self.dtype),
-                "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache),
+                "paged_kv_cache": nn.spec.PagedKVCache(),
                 "$": {
                     "param_mode": "packed",
                     "effect_mode": "none",
@@ -393,7 +393,7 @@ def get_default_spec(self):
             },
             "batch_verify": {
                 "input_embeds": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype),
-                "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache),
+                "paged_kv_cache": nn.spec.PagedKVCache(),
                 "$": {
                     "param_mode": "packed",
                     "effect_mode": "none",
diff --git a/python/mlc_llm/model/internlm/internlm_model.py b/python/mlc_llm/model/internlm/internlm_model.py
index 03a6c1730b..02f1742265 100644
--- a/python/mlc_llm/model/internlm/internlm_model.py
+++ b/python/mlc_llm/model/internlm/internlm_model.py
@@ -10,7 +10,7 @@
 from tvm.relax.frontend.nn import Tensor, op
 
 from mlc_llm import op as op_ext
-from mlc_llm.nn import PagedKVCache, RopeMode
+from mlc_llm.nn import PagedKVCache, RopeMode, create_generic_paged_kv_cache
 from mlc_llm.support import logging
 from mlc_llm.support import tensor_parallel as tp
 from mlc_llm.support.config import ConfigBase
@@ -291,7 +291,7 @@ def create_paged_kv_cache(  # pylint: disable=too-many-arguments
         page_size: tir.Var,
         support_sliding_window: tir.Var,
     ) -> PagedKVCache:
-        return PagedKVCache.create_generic(
+        return create_generic_paged_kv_cache(
             attn_kind="mha",
             max_batch_size=max_batch_size,
             max_total_seq_len=max_total_seq_len,
@@ -320,7 +320,7 @@ def get_default_spec(self):
             },
             "prefill": {
                 "input_embed": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype),
-                "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache),
+                "paged_kv_cache": nn.spec.PagedKVCache(),
                 "$": {
                     "param_mode": "packed",
                     "effect_mode": "none",
@@ -328,7 +328,7 @@ def get_default_spec(self):
             },
             "decode": {
                 "input_embed": nn.spec.Tensor([1, 1, self.hidden_size], self.dtype),
-                "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache),
+                "paged_kv_cache": nn.spec.PagedKVCache(),
                 "$": {
                     "param_mode": "packed",
                     "effect_mode": "none",
@@ -337,7 +337,7 @@ def get_default_spec(self):
             "batch_prefill": {
                 "input_embeds": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype),
                 "logit_positions": nn.spec.Tensor(["batch_size"], "int32"),
-                "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache),
+                "paged_kv_cache": nn.spec.PagedKVCache(),
                 "$": {
                     "param_mode": "packed",
                     "effect_mode": "none",
@@ -345,7 +345,7 @@ def get_default_spec(self):
             },
             "batch_decode": {
                 "input_embeds": nn.spec.Tensor(["batch_size", 1, self.hidden_size], self.dtype),
-                "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache),
+                "paged_kv_cache": nn.spec.PagedKVCache(),
                 "$": {
                     "param_mode": "packed",
                     "effect_mode": "none",
@@ -353,7 +353,7 @@ def get_default_spec(self):
             },
             "batch_verify": {
                 "input_embeds": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype),
-                "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache),
+                "paged_kv_cache": nn.spec.PagedKVCache(),
                 "$": {
                     "param_mode": "packed",
                     "effect_mode": "none",
diff --git a/python/mlc_llm/model/internlm2/internlm2_model.py b/python/mlc_llm/model/internlm2/internlm2_model.py
index 78a080e43e..8bff50376e 100644
--- a/python/mlc_llm/model/internlm2/internlm2_model.py
+++ b/python/mlc_llm/model/internlm2/internlm2_model.py
@@ -10,7 +10,7 @@
 from tvm.relax.frontend.nn import Tensor, op
 
 from mlc_llm import op as op_ext
-from mlc_llm.nn import PagedKVCache, RopeMode
+from mlc_llm.nn import PagedKVCache, RopeMode, create_generic_paged_kv_cache
 from mlc_llm.support import logging
 from mlc_llm.support import tensor_parallel as tp
 from mlc_llm.support.config import ConfigBase
@@ -297,7 +297,7 @@ def create_paged_kv_cache(  # pylint: disable=too-many-arguments
         page_size: tir.Var,
         support_sliding_window: tir.Var,
     ) -> PagedKVCache:
-        return PagedKVCache.create_generic(
+        return create_generic_paged_kv_cache(
             attn_kind="mha",
             max_batch_size=max_batch_size,
             max_total_seq_len=max_total_seq_len,
@@ -326,7 +326,7 @@ def get_default_spec(self):
             },
             "prefill": {
                 "input_embed": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype),
-                "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache),
+                "paged_kv_cache": nn.spec.PagedKVCache(),
                 "$": {
                     "param_mode": "packed",
                     "effect_mode": "none",
@@ -334,7 +334,7 @@ def get_default_spec(self):
             },
             "decode": {
                 "input_embed": nn.spec.Tensor([1, 1, self.hidden_size], self.dtype),
-                "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache),
+                "paged_kv_cache": nn.spec.PagedKVCache(),
                 "$": {
                     "param_mode": "packed",
                     "effect_mode": "none",
@@ -343,7 +343,7 @@ def get_default_spec(self):
             "batch_prefill": {
                 "input_embeds": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype),
                 "logit_positions": nn.spec.Tensor(["batch_size"], "int32"),
-                "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache),
+                "paged_kv_cache": nn.spec.PagedKVCache(),
                 "$": {
                     "param_mode": "packed",
                     "effect_mode": "none",
@@ -351,7 +351,7 @@ def get_default_spec(self):
             },
             "batch_decode": {
                 "input_embeds": nn.spec.Tensor(["batch_size", 1, self.hidden_size], self.dtype),
-                "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache),
+                "paged_kv_cache": nn.spec.PagedKVCache(),
                 "$": {
                     "param_mode": "packed",
                     "effect_mode": "none",
@@ -359,7 +359,7 @@ def get_default_spec(self):
             },
             "batch_verify": {
                 "input_embeds": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype),
-                "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache),
+                "paged_kv_cache": nn.spec.PagedKVCache(),
                 "$": {
                     "param_mode": "packed",
                     "effect_mode": "none",
diff --git a/python/mlc_llm/model/llama/llama_model.py b/python/mlc_llm/model/llama/llama_model.py
index 24db8aa06d..214dddcdcc 100644
--- a/python/mlc_llm/model/llama/llama_model.py
+++ b/python/mlc_llm/model/llama/llama_model.py
@@ -10,7 +10,7 @@
 from tvm.relax.frontend.nn import Tensor, op
 
 from mlc_llm import op as op_ext
-from mlc_llm.nn import PagedKVCache, RopeMode
+from mlc_llm.nn import PagedKVCache, RopeMode, create_generic_paged_kv_cache
 from mlc_llm.support import logging
 from mlc_llm.support import tensor_parallel as tp
 from mlc_llm.support.config import ConfigBase
@@ -398,7 +398,7 @@ def create_paged_kv_cache(  # pylint: disable=too-many-arguments
         page_size: tir.Var,
         support_sliding_window: tir.Var,
     ) -> PagedKVCache:
-        return PagedKVCache.create_generic(
+        return create_generic_paged_kv_cache(
             attn_kind="mha",
             max_batch_size=max_batch_size,
             max_total_seq_len=max_total_seq_len,
@@ -445,7 +445,7 @@ def get_default_spec(self):
             },
             "prefill": {
                 "input_embed": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype),
-                "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache),
+                "paged_kv_cache": nn.spec.PagedKVCache(),
                 "$": {
                     "param_mode": "packed",
                     "effect_mode": "none",
@@ -453,7 +453,7 @@ def get_default_spec(self):
             },
             "decode": {
                 "input_embed": nn.spec.Tensor([1, 1, self.hidden_size], self.dtype),
-                "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache),
+                "paged_kv_cache": nn.spec.PagedKVCache(),
                 "$": {
                     "param_mode": "packed",
                     "effect_mode": "none",
@@ -461,7 +461,7 @@ def get_default_spec(self):
             },
             "prefill_to_last_hidden_states": {
                 "input_embed": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype),
-                "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache),
+                "paged_kv_cache": nn.spec.PagedKVCache(),
                 "$": {
                     "param_mode": "packed",
                     "effect_mode": "none",
@@ -469,7 +469,7 @@ def get_default_spec(self):
             },
             "decode_to_last_hidden_states": {
                 "input_embed": nn.spec.Tensor([1, 1, self.hidden_size], self.dtype),
-                "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache),
+                "paged_kv_cache": nn.spec.PagedKVCache(),
                 "$": {
                     "param_mode": "packed",
                     "effect_mode": "none",
@@ -478,7 +478,7 @@ def get_default_spec(self):
             "batch_prefill": {
                 "input_embeds": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype),
                 "logit_positions": nn.spec.Tensor(["batch_size"], "int32"),
-                "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache),
+                "paged_kv_cache": nn.spec.PagedKVCache(),
                 "$": {
                     "param_mode": "packed",
                     "effect_mode": "none",
@@ -486,7 +486,7 @@ def get_default_spec(self):
             },
             "batch_decode": {
                 "input_embeds": nn.spec.Tensor(["batch_size", 1, self.hidden_size], self.dtype),
-                "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache),
+                "paged_kv_cache": nn.spec.PagedKVCache(),
                 "$": {
                     "param_mode": "packed",
                     "effect_mode": "none",
@@ -494,7 +494,7 @@ def get_default_spec(self):
             },
             "batch_verify": {
                 "input_embeds": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype),
-                "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache),
+                "paged_kv_cache": nn.spec.PagedKVCache(),
                 "$": {
                     "param_mode": "packed",
                     "effect_mode": "none",
@@ -502,7 +502,7 @@ def get_default_spec(self):
             },
             "batch_prefill_to_last_hidden_states": {
                 "input_embeds": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype),
-                "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache),
+                "paged_kv_cache": nn.spec.PagedKVCache(),
                 "$": {
                     "param_mode": "packed",
                     "effect_mode": "none",
@@ -510,7 +510,7 @@ def get_default_spec(self):
             },
             "batch_decode_to_last_hidden_states": {
                 "input_embeds": nn.spec.Tensor(["batch_size", 1, self.hidden_size], self.dtype),
-                "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache),
+                "paged_kv_cache": nn.spec.PagedKVCache(),
                 "$": {
                     "param_mode": "packed",
                     "effect_mode": "none",
@@ -518,7 +518,7 @@ def get_default_spec(self):
             },
             "batch_verify_to_last_hidden_states": {
                 "input_embeds": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype),
-                "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache),
+                "paged_kv_cache": nn.spec.PagedKVCache(),
                 "$": {
                     "param_mode": "packed",
                     "effect_mode": "none",
diff --git a/python/mlc_llm/model/llava/llava_model.py b/python/mlc_llm/model/llava/llava_model.py
index 4bf6612d4e..e1855f1da5 100644
--- a/python/mlc_llm/model/llava/llava_model.py
+++ b/python/mlc_llm/model/llava/llava_model.py
@@ -16,7 +16,7 @@
 from mlc_llm import op as op_ext
 from mlc_llm.model.model_preset import MODEL_PRESETS
 from mlc_llm.model.vision import CLIPVisionConfig, CLIPVisionModel, ImageProcessor
-from mlc_llm.nn import PagedKVCache, RopeMode
+from mlc_llm.nn import PagedKVCache, RopeMode, create_generic_paged_kv_cache
 
 from ...support.config import ConfigBase
 from ..llama.llama_model import LlamaConfig, LlamaForCausalLM
@@ -229,7 +229,7 @@ def create_paged_kv_cache(  # pylint: disable=too-many-arguments
         page_size: tir.Var,
         support_sliding_window: tir.Var,
     ) -> PagedKVCache:
-        return PagedKVCache.create_generic(
+        return create_generic_paged_kv_cache(
             attn_kind="mha",
             max_batch_size=max_batch_size,
             max_total_seq_len=max_total_seq_len,
@@ -272,7 +272,7 @@ def get_default_spec(self):
                 "input_embed": nn.spec.Tensor(
                     [1, "seq_len", self.config.text_config.hidden_size], self.dtype
                 ),
-                "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache),
+                "paged_kv_cache": nn.spec.PagedKVCache(),
                 "$": {
                     "param_mode": "packed",
                     "effect_mode": "none",
@@ -282,7 +282,7 @@ def get_default_spec(self):
                 "input_embed": nn.spec.Tensor(
                     [1, 1, self.config.text_config.hidden_size], self.dtype
                 ),
-                "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache),
+                "paged_kv_cache": nn.spec.PagedKVCache(),
                 "$": {
                     "param_mode": "packed",
                     "effect_mode": "none",
@@ -293,7 +293,7 @@ def get_default_spec(self):
                     [1, "seq_len", self.config.text_config.hidden_size], self.dtype
                 ),
                 "logit_positions": nn.spec.Tensor(["batch_size"], "int32"),
-                "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache),
+                "paged_kv_cache": nn.spec.PagedKVCache(),
                 "$": {
                     "param_mode": "packed",
                     "effect_mode": "none",
@@ -303,7 +303,7 @@ def get_default_spec(self):
                 "input_embeds": nn.spec.Tensor(
                     ["batch_size", 1, self.config.text_config.hidden_size], self.dtype
                 ),
-                "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache),
+                "paged_kv_cache": nn.spec.PagedKVCache(),
                 "$": {
                     "param_mode": "packed",
                     "effect_mode": "none",
@@ -313,7 +313,7 @@ def get_default_spec(self):
                 "input_embeds": nn.spec.Tensor(
                     [1, "seq_len", self.config.text_config.hidden_size], self.dtype
                 ),
-                "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache),
+                "paged_kv_cache": nn.spec.PagedKVCache(),
                 "$": {
                     "param_mode": "packed",
                     "effect_mode": "none",
diff --git a/python/mlc_llm/model/minicpm/minicpm_model.py b/python/mlc_llm/model/minicpm/minicpm_model.py
index 4a2c2eaef9..a3f4d1a9cd 100644
--- a/python/mlc_llm/model/minicpm/minicpm_model.py
+++ b/python/mlc_llm/model/minicpm/minicpm_model.py
@@ -12,7 +12,7 @@
 from tvm.relax.frontend.nn import Tensor, op
 
 from mlc_llm import op as op_ext
-from mlc_llm.nn import PagedKVCache, RopeMode
+from mlc_llm.nn import PagedKVCache, RopeMode, create_generic_paged_kv_cache
 from mlc_llm.nn.expert import MixtralExperts
 from mlc_llm.support import logging
 from mlc_llm.support import tensor_parallel as tp
@@ -430,7 +430,7 @@ def create_paged_kv_cache(  # pylint: disable=too-many-arguments
         page_size: tir.Var,
         support_sliding_window: tir.Var,
     ) -> PagedKVCache:
-        return PagedKVCache.create_generic(
+        return create_generic_paged_kv_cache(
             attn_kind="mha",
             max_batch_size=max_batch_size,
             max_total_seq_len=max_total_seq_len,
@@ -459,7 +459,7 @@ def get_default_spec(self):
             },
             "prefill": {
                 "input_embed": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype),
-                "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache),
+                "paged_kv_cache": nn.spec.PagedKVCache(),
                 "$": {
                     "param_mode": "packed",
                     "effect_mode": "none",
@@ -467,7 +467,7 @@ def get_default_spec(self):
             },
             "decode": {
                 "input_embed": nn.spec.Tensor([1, 1, self.hidden_size], self.dtype),
-                "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache),
+                "paged_kv_cache": nn.spec.PagedKVCache(),
                 "$": {
                     "param_mode": "packed",
                     "effect_mode": "none",
@@ -476,7 +476,7 @@ def get_default_spec(self):
             "batch_prefill": {
                 "input_embeds": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype),
                 "logit_positions": nn.spec.Tensor(["batch_size"], "int32"),
-                "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache),
+                "paged_kv_cache": nn.spec.PagedKVCache(),
                 "$": {
                     "param_mode": "packed",
                     "effect_mode": "none",
@@ -484,7 +484,7 @@ def get_default_spec(self):
             },
             "batch_decode": {
                 "input_embeds": nn.spec.Tensor(["batch_size", 1, self.hidden_size], self.dtype),
-                "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache),
+                "paged_kv_cache": nn.spec.PagedKVCache(),
                 "$": {
                     "param_mode": "packed",
                     "effect_mode": "none",
@@ -492,7 +492,7 @@ def get_default_spec(self):
             },
             "batch_verify": {
                 "input_embeds": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype),
-                "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache),
+                "paged_kv_cache": nn.spec.PagedKVCache(),
                 "$": {
                     "param_mode": "packed",
                     "effect_mode": "none",
diff --git a/python/mlc_llm/model/mistral/mistral_model.py b/python/mlc_llm/model/mistral/mistral_model.py
index 5bab089744..a69d4da9dd 100644
--- a/python/mlc_llm/model/mistral/mistral_model.py
+++ b/python/mlc_llm/model/mistral/mistral_model.py
@@ -10,7 +10,7 @@
 from tvm.relax.frontend.nn import Tensor, op
 
 from mlc_llm import op as op_ext
-from mlc_llm.nn import PagedKVCache, RopeMode
+from mlc_llm.nn import PagedKVCache, RopeMode, create_generic_paged_kv_cache
 from mlc_llm.support import logging
 from mlc_llm.support import tensor_parallel as tp
 from mlc_llm.support.config import ConfigBase
@@ -304,7 +304,7 @@ def create_paged_kv_cache(  # pylint: disable=too-many-arguments
         page_size: tir.Var,
         support_sliding_window: tir.Var,
     ) -> PagedKVCache:
-        return PagedKVCache.create_generic(
+        return create_generic_paged_kv_cache(
             attn_kind="mha",
             max_batch_size=max_batch_size,
             max_total_seq_len=max_total_seq_len,
@@ -333,7 +333,7 @@ def get_default_spec(self):
             },
             "prefill": {
                 "input_embed": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype),
-                "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache),
+                "paged_kv_cache": nn.spec.PagedKVCache(),
                 "$": {
                     "param_mode": "packed",
                     "effect_mode": "none",
@@ -341,7 +341,7 @@ def get_default_spec(self):
             },
             "decode": {
                 "input_embed": nn.spec.Tensor([1, 1, self.hidden_size], self.dtype),
-                "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache),
+                "paged_kv_cache": nn.spec.PagedKVCache(),
                 "$": {
                     "param_mode": "packed",
                     "effect_mode": "none",
@@ -350,7 +350,7 @@ def get_default_spec(self):
             "batch_prefill": {
                 "input_embeds": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype),
                 "logit_positions": nn.spec.Tensor(["batch_size"], "int32"),
-                "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache),
+                "paged_kv_cache": nn.spec.PagedKVCache(),
                 "$": {
                     "param_mode": "packed",
                     "effect_mode": "none",
@@ -358,7 +358,7 @@ def get_default_spec(self):
             },
             "batch_decode": {
                 "input_embeds": nn.spec.Tensor(["batch_size", 1, self.hidden_size], self.dtype),
-                "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache),
+                "paged_kv_cache": nn.spec.PagedKVCache(),
                 "$": {
                     "param_mode": "packed",
                     "effect_mode": "none",
@@ -366,7 +366,7 @@ def get_default_spec(self):
             },
             "batch_verify": {
                 "input_embeds": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype),
-                "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache),
+                "paged_kv_cache": nn.spec.PagedKVCache(),
                 "$": {
                     "param_mode": "packed",
                     "effect_mode": "none",
diff --git a/python/mlc_llm/model/nemotron/nemotron_model.py b/python/mlc_llm/model/nemotron/nemotron_model.py
index afd4fd07c0..2710a69c98 100644
--- a/python/mlc_llm/model/nemotron/nemotron_model.py
+++ b/python/mlc_llm/model/nemotron/nemotron_model.py
@@ -10,7 +10,7 @@
 from tvm.relax.frontend.nn import Tensor, op
 
 from mlc_llm import op as op_ext
-from mlc_llm.nn import PagedKVCache, RopeMode
+from mlc_llm.nn import PagedKVCache, RopeMode, create_generic_paged_kv_cache
 from mlc_llm.support import logging
 from mlc_llm.support import tensor_parallel as tp
 from mlc_llm.support.config import ConfigBase
@@ -380,7 +380,7 @@ def create_paged_kv_cache(  # pylint: disable=too-many-arguments
         page_size: tir.Var,
         support_sliding_window: tir.Var,
     ) -> PagedKVCache:
-        return PagedKVCache.create_generic(
+        return create_generic_paged_kv_cache(
             attn_kind="mha",
             max_batch_size=max_batch_size,
             max_total_seq_len=max_total_seq_len,
@@ -428,7 +428,7 @@ def get_default_spec(self):
             },
             "prefill": {
                 "input_embed": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype),
-                "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache),
+                "paged_kv_cache": nn.spec.PagedKVCache(),
                 "$": {
                     "param_mode": "packed",
                     "effect_mode": "none",
@@ -436,7 +436,7 @@ def get_default_spec(self):
             },
             "decode": {
                 "input_embed": nn.spec.Tensor([1, 1, self.hidden_size], self.dtype),
-                "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache),
+                "paged_kv_cache": nn.spec.PagedKVCache(),
                 "$": {
                     "param_mode": "packed",
                     "effect_mode": "none",
@@ -444,7 +444,7 @@ def get_default_spec(self):
             },
             "prefill_to_last_hidden_states": {
                 "input_embed": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype),
-                "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache),
+                "paged_kv_cache": nn.spec.PagedKVCache(),
                 "$": {
                     "param_mode": "packed",
                     "effect_mode": "none",
@@ -452,7 +452,7 @@ def get_default_spec(self):
             },
             "decode_to_last_hidden_states": {
                 "input_embed": nn.spec.Tensor([1, 1, self.hidden_size], self.dtype),
-                "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache),
+                "paged_kv_cache": nn.spec.PagedKVCache(),
                 "$": {
                     "param_mode": "packed",
                     "effect_mode": "none",
@@ -461,7 +461,7 @@ def get_default_spec(self):
             "batch_prefill": {
                 "input_embeds": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype),
                 "logit_positions": nn.spec.Tensor(["batch_size"], "int32"),
-                "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache),
+                "paged_kv_cache": nn.spec.PagedKVCache(),
                 "$": {
                     "param_mode": "packed",
                     "effect_mode": "none",
@@ -469,7 +469,7 @@ def get_default_spec(self):
             },
             "batch_decode": {
                 "input_embeds": nn.spec.Tensor(["batch_size", 1, self.hidden_size], self.dtype),
-                "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache),
+                "paged_kv_cache": nn.spec.PagedKVCache(),
                 "$": {
                     "param_mode": "packed",
                     "effect_mode": "none",
@@ -477,7 +477,7 @@ def get_default_spec(self):
             },
             "batch_verify": {
                 "input_embeds": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype),
-                "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache),
+                "paged_kv_cache": nn.spec.PagedKVCache(),
                 "$": {
                     "param_mode": "packed",
                     "effect_mode": "none",
@@ -485,7 +485,7 @@ def get_default_spec(self):
             },
             "batch_prefill_to_last_hidden_states": {
                 "input_embeds": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype),
-                "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache),
+                "paged_kv_cache": nn.spec.PagedKVCache(),
                 "$": {
                     "param_mode": "packed",
                     "effect_mode": "none",
@@ -493,7 +493,7 @@ def get_default_spec(self):
             },
             "batch_decode_to_last_hidden_states": {
                 "input_embeds": nn.spec.Tensor(["batch_size", 1, self.hidden_size], self.dtype),
-                "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache),
+                "paged_kv_cache": nn.spec.PagedKVCache(),
                 "$": {
                     "param_mode": "packed",
                     "effect_mode": "none",
@@ -501,7 +501,7 @@ def get_default_spec(self):
             },
             "batch_verify_to_last_hidden_states": {
                 "input_embeds": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype),
-                "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache),
+                "paged_kv_cache": nn.spec.PagedKVCache(),
                 "$": {
                     "param_mode": "packed",
                     "effect_mode": "none",
diff --git a/python/mlc_llm/model/olmo/olmo_model.py b/python/mlc_llm/model/olmo/olmo_model.py
index 4427d89b42..2a4d27d4b4 100644
--- a/python/mlc_llm/model/olmo/olmo_model.py
+++ b/python/mlc_llm/model/olmo/olmo_model.py
@@ -12,7 +12,7 @@
 from tvm.relax.frontend.nn import Tensor, op
 
 from mlc_llm import op as op_ext
-from mlc_llm.nn import PagedKVCache, RopeMode
+from mlc_llm.nn import PagedKVCache, RopeMode, create_generic_paged_kv_cache
 from mlc_llm.support import logging
 from mlc_llm.support import tensor_parallel as tp
 from mlc_llm.support.config import ConfigBase
@@ -452,7 +452,7 @@ def create_paged_kv_cache(  # pylint: disable=missing-function-docstring,too-man
         page_size: tir.Var,
         support_sliding_window: tir.Var,
     ) -> PagedKVCache:
-        return PagedKVCache.create_generic(
+        return create_generic_paged_kv_cache(
             attn_kind="mha",
             max_batch_size=max_batch_size,
             max_total_seq_len=max_total_seq_len,
@@ -498,7 +498,7 @@ def get_default_spec(self):  # pylint: disable=missing-function-docstring
             },
             "prefill": {
                 "input_embed": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype),
-                "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache),
+                "paged_kv_cache": nn.spec.PagedKVCache(),
                 "$": {
                     "param_mode": "packed",
                     "effect_mode": "none",
@@ -506,7 +506,7 @@ def get_default_spec(self):  # pylint: disable=missing-function-docstring
             },
             "decode": {
                 "input_embed": nn.spec.Tensor([1, 1, self.hidden_size], self.dtype),
-                "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache),
+                "paged_kv_cache": nn.spec.PagedKVCache(),
                 "$": {
                     "param_mode": "packed",
                     "effect_mode": "none",
@@ -514,7 +514,7 @@ def get_default_spec(self):  # pylint: disable=missing-function-docstring
             },
             "prefill_to_last_hidden_states": {
                 "input_embed": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype),
-                "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache),
+                "paged_kv_cache": nn.spec.PagedKVCache(),
                 "$": {
                     "param_mode": "packed",
                     "effect_mode": "none",
@@ -522,7 +522,7 @@ def get_default_spec(self):  # pylint: disable=missing-function-docstring
             },
             "decode_to_last_hidden_states": {
                 "input_embed": nn.spec.Tensor([1, 1, self.hidden_size], self.dtype),
-                "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache),
+                "paged_kv_cache": nn.spec.PagedKVCache(),
                 "$": {
                     "param_mode": "packed",
                     "effect_mode": "none",
@@ -531,7 +531,7 @@ def get_default_spec(self):  # pylint: disable=missing-function-docstring
             "batch_prefill": {
                 "input_embeds": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype),
                 "logit_positions": nn.spec.Tensor(["batch_size"], "int32"),
-                "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache),
+                "paged_kv_cache": nn.spec.PagedKVCache(),
                 "$": {
                     "param_mode": "packed",
                     "effect_mode": "none",
@@ -539,7 +539,7 @@ def get_default_spec(self):  # pylint: disable=missing-function-docstring
             },
             "batch_decode": {
                 "input_embeds": nn.spec.Tensor(["batch_size", 1, self.hidden_size], self.dtype),
-                "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache),
+                "paged_kv_cache": nn.spec.PagedKVCache(),
                 "$": {
                     "param_mode": "packed",
                     "effect_mode": "none",
@@ -547,7 +547,7 @@ def get_default_spec(self):  # pylint: disable=missing-function-docstring
             },
             "batch_verify": {
                 "input_embeds": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype),
-                "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache),
+                "paged_kv_cache": nn.spec.PagedKVCache(),
                 "$": {
                     "param_mode": "packed",
                     "effect_mode": "none",
@@ -555,7 +555,7 @@ def get_default_spec(self):  # pylint: disable=missing-function-docstring
             },
             "batch_prefill_to_last_hidden_states": {
                 "input_embeds": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype),
-                "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache),
+                "paged_kv_cache": nn.spec.PagedKVCache(),
                 "$": {
                     "param_mode": "packed",
                     "effect_mode": "none",
@@ -563,7 +563,7 @@ def get_default_spec(self):  # pylint: disable=missing-function-docstring
             },
             "batch_decode_to_last_hidden_states": {
                 "input_embeds": nn.spec.Tensor(["batch_size", 1, self.hidden_size], self.dtype),
-                "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache),
+                "paged_kv_cache": nn.spec.PagedKVCache(),
                 "$": {
                     "param_mode": "packed",
                     "effect_mode": "none",
@@ -571,7 +571,7 @@ def get_default_spec(self):  # pylint: disable=missing-function-docstring
             },
             "batch_verify_to_last_hidden_states": {
                 "input_embeds": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype),
-                "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache),
+                "paged_kv_cache": nn.spec.PagedKVCache(),
                 "$": {
                     "param_mode": "packed",
                     "effect_mode": "none",
diff --git a/python/mlc_llm/model/orion/orion_model.py b/python/mlc_llm/model/orion/orion_model.py
index 224b2d9b3e..e2faaf656c 100644
--- a/python/mlc_llm/model/orion/orion_model.py
+++ b/python/mlc_llm/model/orion/orion_model.py
@@ -10,7 +10,7 @@
 from tvm.relax.frontend.nn import Tensor, op
 
 from mlc_llm import op as op_ext
-from mlc_llm.nn import PagedKVCache, RopeMode
+from mlc_llm.nn import PagedKVCache, RopeMode, create_generic_paged_kv_cache
 from mlc_llm.support import logging
 from mlc_llm.support import tensor_parallel as tp
 from mlc_llm.support.config import ConfigBase
@@ -286,7 +286,7 @@ def create_paged_kv_cache(  # pylint: disable=too-many-arguments
         page_size: tir.Var,
         support_sliding_window: tir.Var,
     ) -> PagedKVCache:
-        return PagedKVCache.create_generic(
+        return create_generic_paged_kv_cache(
            attn_kind="mha",
             max_batch_size=max_batch_size,
             max_total_seq_len=max_total_seq_len,
@@ -315,7 +315,7 @@ def get_default_spec(self):
             },
             "prefill": {
                 "input_embed": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype),
-                "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache),
+                "paged_kv_cache": nn.spec.PagedKVCache(),
                 "$": {
                     "param_mode": "packed",
                     "effect_mode": "none",
@@ -323,7 +323,7 @@ def get_default_spec(self):
             },
             "decode": {
                 "input_embed": nn.spec.Tensor([1, 1, self.hidden_size], self.dtype),
-                "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache),
+                "paged_kv_cache": nn.spec.PagedKVCache(),
                 "$": {
                     "param_mode": "packed",
                     "effect_mode": "none",
@@ -332,7 +332,7 @@ def get_default_spec(self):
             "batch_prefill": {
                 "input_embeds": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype),
                 "logit_positions": nn.spec.Tensor(["batch_size"], "int32"),
-                "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache),
+                "paged_kv_cache": nn.spec.PagedKVCache(),
                 "$": {
                     "param_mode": "packed",
                     "effect_mode": "none",
@@ -340,7 +340,7 @@ def get_default_spec(self):
             },
             "batch_decode": {
                 "input_embeds": nn.spec.Tensor(["batch_size", 1, self.hidden_size], self.dtype),
-                "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache),
+                "paged_kv_cache": nn.spec.PagedKVCache(),
                 "$": {
                     "param_mode": "packed",
                     "effect_mode": "none",
@@ -348,7 +348,7 @@ def get_default_spec(self):
             },
             "batch_verify": {
                 "input_embeds": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype),
-                "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache),
+                "paged_kv_cache": nn.spec.PagedKVCache(),
                 "$": {
                     "param_mode": "packed",
                     "effect_mode": "none",
diff --git a/python/mlc_llm/model/phi/phi_model.py b/python/mlc_llm/model/phi/phi_model.py
index 2809d2e0f4..3a5521002a 100644
--- a/python/mlc_llm/model/phi/phi_model.py
+++ b/python/mlc_llm/model/phi/phi_model.py
@@ -10,7 +10,7 @@
 from tvm.relax.frontend.nn import Tensor, op
 
 from mlc_llm import op as op_ext
-from mlc_llm.nn import PagedKVCache, RopeMode
+from mlc_llm.nn import PagedKVCache, RopeMode, create_generic_paged_kv_cache
 from mlc_llm.support import logging
 from mlc_llm.support import tensor_parallel as tp
 from mlc_llm.support.config import ConfigBase
@@ -407,7 +407,7 @@ def create_paged_kv_cache(  # pylint: disable=too-many-arguments
         page_size: tir.Var,
         support_sliding_window: tir.Var,
     ) -> PagedKVCache:
-        return PagedKVCache.create_generic(
+        return create_generic_paged_kv_cache(
             attn_kind="mha",
             max_batch_size=max_batch_size,
             max_total_seq_len=max_total_seq_len,
@@ -437,7 +437,7 @@ def get_default_spec(self):
             },
             "prefill": {
                 "input_embed": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype),
-                "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache),
+                "paged_kv_cache": nn.spec.PagedKVCache(),
                 "$": {
                     "param_mode": "packed",
                     "effect_mode": "none",
@@ -445,7 +445,7 @@ def get_default_spec(self):
             },
             "decode": {
                 "input_embed": nn.spec.Tensor([1, 1, self.hidden_size], self.dtype),
-                "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache),
+                "paged_kv_cache": nn.spec.PagedKVCache(),
                 "$": {
                     "param_mode": "packed",
                     "effect_mode": "none",
@@ -454,7 +454,7 @@ def get_default_spec(self):
             "batch_prefill": {
                 "input_embeds": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype),
                 "logit_positions": nn.spec.Tensor(["batch_size"], "int32"),
-                "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache),
+                "paged_kv_cache": nn.spec.PagedKVCache(),
                 "$": {
                     "param_mode": "packed",
                     "effect_mode": "none",
@@ -462,7 +462,7 @@ def get_default_spec(self):
             },
             "batch_decode": {
                 "input_embeds": nn.spec.Tensor(["batch_size", 1, self.hidden_size], self.dtype),
-                "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache),
+                "paged_kv_cache": nn.spec.PagedKVCache(),
                 "$": {
                     "param_mode": "packed",
                     "effect_mode": "none",
@@ -470,7 +470,7 @@ def get_default_spec(self):
             },
             "batch_verify": {
                 "input_embeds": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype),
-                "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache),
+                "paged_kv_cache": nn.spec.PagedKVCache(),
                 "$": {
                     "param_mode": "packed",
                     "effect_mode": "none",
diff --git a/python/mlc_llm/model/phi3/phi3_model.py b/python/mlc_llm/model/phi3/phi3_model.py
index 0ef6f4b7dc..b5e2102ebc 100644
--- a/python/mlc_llm/model/phi3/phi3_model.py
+++ b/python/mlc_llm/model/phi3/phi3_model.py
@@ -10,7 +10,7 @@
 from tvm.relax.frontend.nn import Tensor, op
 
 from mlc_llm import op as op_ext
-from mlc_llm.nn import PagedKVCache, RopeMode
+from mlc_llm.nn import PagedKVCache, RopeMode, create_generic_paged_kv_cache
 from mlc_llm.support import logging
 from mlc_llm.support import tensor_parallel as tp
 from mlc_llm.support.config import ConfigBase
@@ -304,7 +304,7 @@ def create_paged_kv_cache(  # pylint: disable=too-many-arguments
         page_size: tir.Var,
         support_sliding_window: tir.Var,
     ) -> PagedKVCache:
-        return PagedKVCache.create_generic(
+        return create_generic_paged_kv_cache(
             attn_kind="mha",
             max_batch_size=max_batch_size,
             max_total_seq_len=max_total_seq_len,
@@ -335,7 +335,7 @@ def get_default_spec(self):
             },
             "prefill": {
                 "input_embed": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype),
-                "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache),
+                "paged_kv_cache": nn.spec.PagedKVCache(),
                 "$": {
                     "param_mode": "packed",
                     "effect_mode": "none",
@@ -343,7 +343,7 @@ def get_default_spec(self):
             },
             "decode": {
                 "input_embed": nn.spec.Tensor([1, 1, self.hidden_size], self.dtype),
-                "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache),
+                "paged_kv_cache": nn.spec.PagedKVCache(),
                 "$": {
"param_mode": "packed", "effect_mode": "none", @@ -352,7 +352,7 @@ def get_default_spec(self): "batch_prefill": { "input_embeds": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype), "logit_positions": nn.spec.Tensor(["batch_size"], "int32"), - "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache), + "paged_kv_cache": nn.spec.PagedKVCache(), "$": { "param_mode": "packed", "effect_mode": "none", @@ -360,7 +360,7 @@ def get_default_spec(self): }, "batch_decode": { "input_embeds": nn.spec.Tensor(["batch_size", 1, self.hidden_size], self.dtype), - "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache), + "paged_kv_cache": nn.spec.PagedKVCache(), "$": { "param_mode": "packed", "effect_mode": "none", @@ -368,7 +368,7 @@ def get_default_spec(self): }, "batch_verify": { "input_embeds": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype), - "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache), + "paged_kv_cache": nn.spec.PagedKVCache(), "$": { "param_mode": "packed", "effect_mode": "none", diff --git a/python/mlc_llm/model/phi3v/phi3v_model.py b/python/mlc_llm/model/phi3v/phi3v_model.py index 02b88bee8a..83d10daa2a 100644 --- a/python/mlc_llm/model/phi3v/phi3v_model.py +++ b/python/mlc_llm/model/phi3v/phi3v_model.py @@ -12,7 +12,7 @@ from mlc_llm import op as op_ext from mlc_llm.model.phi3 import Phi3Model from mlc_llm.model.vision import CLIPVisionConfig, ImageProcessor -from mlc_llm.nn import PagedKVCache, RopeMode +from mlc_llm.nn import PagedKVCache, RopeMode, create_generic_paged_kv_cache from mlc_llm.support import logging from mlc_llm.support.config import ConfigBase from mlc_llm.support.style import bold @@ -282,7 +282,7 @@ def create_paged_kv_cache( # pylint: disable=too-many-arguments page_size: tir.Var, support_sliding_window: tir.Var, ) -> PagedKVCache: - return PagedKVCache.create_generic( + return create_generic_paged_kv_cache( attn_kind="mha", max_batch_size=max_batch_size, max_total_seq_len=max_total_seq_len, @@ -324,7 +324,7 @@ def get_default_spec(self): }, "prefill": { "input_embed": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype), - "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache), + "paged_kv_cache": nn.spec.PagedKVCache(), "$": { "param_mode": "packed", "effect_mode": "none", @@ -332,7 +332,7 @@ def get_default_spec(self): }, "decode": { "input_embed": nn.spec.Tensor([1, 1, self.hidden_size], self.dtype), - "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache), + "paged_kv_cache": nn.spec.PagedKVCache(), "$": { "param_mode": "packed", "effect_mode": "none", @@ -341,7 +341,7 @@ def get_default_spec(self): "batch_prefill": { "input_embeds": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype), "logit_positions": nn.spec.Tensor(["batch_size"], "int32"), - "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache), + "paged_kv_cache": nn.spec.PagedKVCache(), "$": { "param_mode": "packed", "effect_mode": "none", @@ -349,7 +349,7 @@ def get_default_spec(self): }, "batch_decode": { "input_embeds": nn.spec.Tensor(["batch_size", 1, self.hidden_size], self.dtype), - "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache), + "paged_kv_cache": nn.spec.PagedKVCache(), "$": { "param_mode": "packed", "effect_mode": "none", @@ -357,7 +357,7 @@ def get_default_spec(self): }, "batch_verify": { "input_embeds": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype), - "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache), + "paged_kv_cache": nn.spec.PagedKVCache(), "$": { "param_mode": "packed", "effect_mode": 
"none", diff --git a/python/mlc_llm/model/qwen/qwen_model.py b/python/mlc_llm/model/qwen/qwen_model.py index 6daedcb1fe..3d0d1f7481 100644 --- a/python/mlc_llm/model/qwen/qwen_model.py +++ b/python/mlc_llm/model/qwen/qwen_model.py @@ -10,7 +10,7 @@ from tvm.relax.frontend.nn import Tensor, op from mlc_llm import op as op_ext -from mlc_llm.nn import PagedKVCache, RopeMode +from mlc_llm.nn import PagedKVCache, RopeMode, create_generic_paged_kv_cache from mlc_llm.support import logging from mlc_llm.support import tensor_parallel as tp from mlc_llm.support.config import ConfigBase @@ -286,7 +286,7 @@ def create_paged_kv_cache( # pylint: disable=too-many-arguments page_size: tir.Var, support_sliding_window: tir.Var, ) -> PagedKVCache: - return PagedKVCache.create_generic( + return create_generic_paged_kv_cache( attn_kind="mha", max_batch_size=max_batch_size, max_total_seq_len=max_total_seq_len, @@ -315,7 +315,7 @@ def get_default_spec(self): }, "prefill": { "inputs": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype), - "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache), + "paged_kv_cache": nn.spec.PagedKVCache(), "$": { "param_mode": "packed", "effect_mode": "none", @@ -323,7 +323,7 @@ def get_default_spec(self): }, "decode": { "inputs": nn.spec.Tensor([1, 1, self.hidden_size], self.dtype), - "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache), + "paged_kv_cache": nn.spec.PagedKVCache(), "$": { "param_mode": "packed", "effect_mode": "none", @@ -332,7 +332,7 @@ def get_default_spec(self): "batch_prefill": { "inputs": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype), "logit_positions": nn.spec.Tensor(["batch_size"], "int32"), - "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache), + "paged_kv_cache": nn.spec.PagedKVCache(), "$": { "param_mode": "packed", "effect_mode": "none", @@ -340,7 +340,7 @@ def get_default_spec(self): }, "batch_decode": { "inputs": nn.spec.Tensor(["batch_size", 1, self.hidden_size], self.dtype), - "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache), + "paged_kv_cache": nn.spec.PagedKVCache(), "$": { "param_mode": "packed", "effect_mode": "none", @@ -348,7 +348,7 @@ def get_default_spec(self): }, "batch_verify": { "inputs": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype), - "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache), + "paged_kv_cache": nn.spec.PagedKVCache(), "$": { "param_mode": "packed", "effect_mode": "none", diff --git a/python/mlc_llm/model/qwen2/qwen2_model.py b/python/mlc_llm/model/qwen2/qwen2_model.py index 2ee4909a9a..e48f28849f 100644 --- a/python/mlc_llm/model/qwen2/qwen2_model.py +++ b/python/mlc_llm/model/qwen2/qwen2_model.py @@ -11,7 +11,7 @@ from tvm.relax.frontend.nn import Tensor, op from mlc_llm import op as op_ext -from mlc_llm.nn import PagedKVCache, RopeMode +from mlc_llm.nn import PagedKVCache, RopeMode, create_generic_paged_kv_cache from mlc_llm.support import logging from mlc_llm.support import tensor_parallel as tp from mlc_llm.support.config import ConfigBase @@ -326,7 +326,7 @@ def create_paged_kv_cache( # pylint: disable=too-many-arguments page_size: tir.Var, support_sliding_window: tir.Var, ) -> PagedKVCache: - return PagedKVCache.create_generic( + return create_generic_paged_kv_cache( attn_kind="mha", max_batch_size=max_batch_size, max_total_seq_len=max_total_seq_len, @@ -355,7 +355,7 @@ def get_default_spec(self): }, "prefill": { "input_embed": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype), - "paged_kv_cache": 
nn.spec.Object(object_type=PagedKVCache), + "paged_kv_cache": nn.spec.PagedKVCache(), "$": { "param_mode": "packed", "effect_mode": "none", @@ -363,7 +363,7 @@ def get_default_spec(self): }, "decode": { "input_embed": nn.spec.Tensor([1, 1, self.hidden_size], self.dtype), - "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache), + "paged_kv_cache": nn.spec.PagedKVCache(), "$": { "param_mode": "packed", "effect_mode": "none", @@ -372,7 +372,7 @@ def get_default_spec(self): "batch_prefill": { "input_embeds": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype), "logit_positions": nn.spec.Tensor(["batch_size"], "int32"), - "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache), + "paged_kv_cache": nn.spec.PagedKVCache(), "$": { "param_mode": "packed", "effect_mode": "none", @@ -380,7 +380,7 @@ def get_default_spec(self): }, "batch_decode": { "input_embeds": nn.spec.Tensor(["batch_size", 1, self.hidden_size], self.dtype), - "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache), + "paged_kv_cache": nn.spec.PagedKVCache(), "$": { "param_mode": "packed", "effect_mode": "none", @@ -388,7 +388,7 @@ def get_default_spec(self): }, "batch_verify": { "input_embeds": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype), - "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache), + "paged_kv_cache": nn.spec.PagedKVCache(), "$": { "param_mode": "packed", "effect_mode": "none", diff --git a/python/mlc_llm/model/qwen2_moe/qwen2_moe_model.py b/python/mlc_llm/model/qwen2_moe/qwen2_moe_model.py index 1fbc63c90b..91c6a0a53c 100644 --- a/python/mlc_llm/model/qwen2_moe/qwen2_moe_model.py +++ b/python/mlc_llm/model/qwen2_moe/qwen2_moe_model.py @@ -11,7 +11,7 @@ from mlc_llm import op as op_ext from mlc_llm.model.qwen2.qwen2_model import ACT2FN, QWen2Attention, QWen2Config -from mlc_llm.nn import PagedKVCache, RopeMode +from mlc_llm.nn import PagedKVCache, RopeMode, create_generic_paged_kv_cache from mlc_llm.nn.expert import MixtralExperts from mlc_llm.support import logging from mlc_llm.support import tensor_parallel as tp @@ -307,7 +307,7 @@ def create_paged_kv_cache( # pylint: disable=too-many-arguments page_size: tir.Var, support_sliding_window: tir.Var, ) -> PagedKVCache: - return PagedKVCache.create_generic( + return create_generic_paged_kv_cache( attn_kind="mha", max_batch_size=max_batch_size, max_total_seq_len=max_total_seq_len, @@ -336,7 +336,7 @@ def get_default_spec(self): }, "prefill": { "input_embed": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype), - "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache), + "paged_kv_cache": nn.spec.PagedKVCache(), "$": { "param_mode": "packed", "effect_mode": "none", @@ -344,7 +344,7 @@ def get_default_spec(self): }, "decode": { "input_embed": nn.spec.Tensor([1, 1, self.hidden_size], self.dtype), - "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache), + "paged_kv_cache": nn.spec.PagedKVCache(), "$": { "param_mode": "packed", "effect_mode": "none", @@ -353,7 +353,7 @@ def get_default_spec(self): "batch_prefill": { "input_embeds": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype), "logit_positions": nn.spec.Tensor(["batch_size"], "int32"), - "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache), + "paged_kv_cache": nn.spec.PagedKVCache(), "$": { "param_mode": "packed", "effect_mode": "none", @@ -361,7 +361,7 @@ def get_default_spec(self): }, "batch_decode": { "input_embeds": nn.spec.Tensor(["batch_size", 1, self.hidden_size], self.dtype), - "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache), + 
"paged_kv_cache": nn.spec.PagedKVCache(), "$": { "param_mode": "packed", "effect_mode": "none", @@ -369,7 +369,7 @@ def get_default_spec(self): }, "batch_verify": { "input_embeds": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype), - "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache), + "paged_kv_cache": nn.spec.PagedKVCache(), "$": { "param_mode": "packed", "effect_mode": "none", diff --git a/python/mlc_llm/model/stable_lm/stablelm_model.py b/python/mlc_llm/model/stable_lm/stablelm_model.py index 575fda54b4..75ba0d47ae 100644 --- a/python/mlc_llm/model/stable_lm/stablelm_model.py +++ b/python/mlc_llm/model/stable_lm/stablelm_model.py @@ -10,7 +10,7 @@ from tvm.relax.frontend.nn import Tensor, op from mlc_llm import op as op_ext -from mlc_llm.nn import PagedKVCache, RopeMode +from mlc_llm.nn import PagedKVCache, RopeMode, create_generic_paged_kv_cache from mlc_llm.support import logging from mlc_llm.support import tensor_parallel as tp from mlc_llm.support.config import ConfigBase @@ -294,7 +294,7 @@ def create_paged_kv_cache( # pylint: disable=too-many-arguments page_size: tir.Var, support_sliding_window: tir.Var, ) -> PagedKVCache: - return PagedKVCache.create_generic( + return create_generic_paged_kv_cache( attn_kind="mha", max_batch_size=max_batch_size, max_total_seq_len=max_total_seq_len, @@ -324,7 +324,7 @@ def get_default_spec(self): }, "prefill": { "input_embed": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype), - "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache), + "paged_kv_cache": nn.spec.PagedKVCache(), "$": { "param_mode": "packed", "effect_mode": "none", @@ -332,7 +332,7 @@ def get_default_spec(self): }, "decode": { "input_embed": nn.spec.Tensor([1, 1, self.hidden_size], self.dtype), - "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache), + "paged_kv_cache": nn.spec.PagedKVCache(), "$": { "param_mode": "packed", "effect_mode": "none", @@ -341,7 +341,7 @@ def get_default_spec(self): "batch_prefill": { "input_embeds": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype), "logit_positions": nn.spec.Tensor(["batch_size"], "int32"), - "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache), + "paged_kv_cache": nn.spec.PagedKVCache(), "$": { "param_mode": "packed", "effect_mode": "none", @@ -349,7 +349,7 @@ def get_default_spec(self): }, "batch_decode": { "input_embeds": nn.spec.Tensor(["batch_size", 1, self.hidden_size], self.dtype), - "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache), + "paged_kv_cache": nn.spec.PagedKVCache(), "$": { "param_mode": "packed", "effect_mode": "none", @@ -357,7 +357,7 @@ def get_default_spec(self): }, "batch_verify": { "input_embeds": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype), - "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache), + "paged_kv_cache": nn.spec.PagedKVCache(), "$": { "param_mode": "packed", "effect_mode": "none", diff --git a/python/mlc_llm/model/starcoder2/starcoder2_model.py b/python/mlc_llm/model/starcoder2/starcoder2_model.py index 670cd6b6a4..985ad3b7f5 100644 --- a/python/mlc_llm/model/starcoder2/starcoder2_model.py +++ b/python/mlc_llm/model/starcoder2/starcoder2_model.py @@ -10,7 +10,7 @@ from tvm.relax.frontend.nn import Tensor, op from mlc_llm import op as op_ext -from mlc_llm.nn import PagedKVCache, RopeMode +from mlc_llm.nn import PagedKVCache, RopeMode, create_generic_paged_kv_cache from mlc_llm.support import logging from mlc_llm.support import tensor_parallel as tp from mlc_llm.support.config import ConfigBase @@ -312,7 +312,7 @@ 
         page_size: tir.Var,
         support_sliding_window: tir.Var,
     ) -> PagedKVCache:
-        return PagedKVCache.create_generic(
+        return create_generic_paged_kv_cache(
             attn_kind="mha",
             max_batch_size=max_batch_size,
             max_total_seq_len=max_total_seq_len,
@@ -341,7 +341,7 @@ def get_default_spec(self):
             },
             "prefill": {
                 "input_embed": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype),
-                "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache),
+                "paged_kv_cache": nn.spec.PagedKVCache(),
                 "$": {
                     "param_mode": "packed",
                     "effect_mode": "none",
@@ -349,7 +349,7 @@ def get_default_spec(self):
             },
             "decode": {
                 "input_embed": nn.spec.Tensor([1, 1, self.hidden_size], self.dtype),
-                "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache),
+                "paged_kv_cache": nn.spec.PagedKVCache(),
                 "$": {
                     "param_mode": "packed",
                     "effect_mode": "none",
@@ -358,7 +358,7 @@ def get_default_spec(self):
             "batch_prefill": {
                 "input_embeds": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype),
                 "logit_positions": nn.spec.Tensor(["batch_size"], "int32"),
-                "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache),
+                "paged_kv_cache": nn.spec.PagedKVCache(),
                 "$": {
                     "param_mode": "packed",
                     "effect_mode": "none",
@@ -366,7 +366,7 @@ def get_default_spec(self):
             },
             "batch_decode": {
                 "input_embeds": nn.spec.Tensor(["batch_size", 1, self.hidden_size], self.dtype),
-                "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache),
+                "paged_kv_cache": nn.spec.PagedKVCache(),
                 "$": {
                     "param_mode": "packed",
                     "effect_mode": "none",
@@ -374,7 +374,7 @@ def get_default_spec(self):
             },
             "batch_verify": {
                 "input_embeds": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype),
-                "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache),
+                "paged_kv_cache": nn.spec.PagedKVCache(),
                 "$": {
                     "param_mode": "packed",
                     "effect_mode": "none",
diff --git a/python/mlc_llm/nn/__init__.py b/python/mlc_llm/nn/__init__.py
index bd3ee5cf53..be903a8058 100644
--- a/python/mlc_llm/nn/__init__.py
+++ b/python/mlc_llm/nn/__init__.py
@@ -1,4 +1,4 @@
 """Common `nn.Modules` used to define LLMs in this project."""
 
 from .expert import MixtralExperts
-from .kv_cache import PagedKVCache, RopeMode
+from .kv_cache import PagedKVCache, RopeMode, create_generic_paged_kv_cache
diff --git a/python/mlc_llm/nn/kv_cache.py b/python/mlc_llm/nn/kv_cache.py
index 2dd878c8a4..4e53dc9154 100644
--- a/python/mlc_llm/nn/kv_cache.py
+++ b/python/mlc_llm/nn/kv_cache.py
@@ -7,84 +7,79 @@
 import numpy as np
 from tvm import relax as rx
 from tvm import tir
-from tvm.relax.frontend.nn.llm.kv_cache import PagedKVCache as TVMPagedKVCache
-from tvm.relax.frontend.nn.llm.kv_cache import RopeMode
+from tvm.relax.frontend.nn.llm.kv_cache import PagedKVCache, RopeMode
 
 
-class PagedKVCache(TVMPagedKVCache):  # pylint: disable=too-few-public-methods
-    """The Paged KV Cache used in LLM batching for efficient attention computation."""
-
-    @staticmethod
-    def create_generic(  # pylint: disable=too-many-locals
-        attn_kind: Literal["mha", "mla"],
-        max_batch_size: tir.Var,
-        max_total_seq_len: tir.Var,
-        prefill_chunk_size: tir.Var,
-        page_size: tir.Var,
-        support_sliding_window: tir.Var,
-        num_hidden_layers: int,
-        num_attention_heads: int,
-        num_key_value_heads: int,
-        qk_head_dim: int,
-        v_head_dim: int,
-        rope_mode: RopeMode,
-        rope_scale: int,
-        rope_theta: int,
-        dtype: str,
-        mla_original_qk_head_dim: int = 0,
-        mla_original_v_head_dim: int = 0,
-        rotary_dim: Optional[int] = None,
-        rope_scaling: Optional[Dict[str, Any]] = None,
-        rope_ext_factors: Optional[List[int]] = None,
-        layer_partition: Optional[List[int]] = None,
-        enable_disaggregation: bool = False,
-        name: str = "paged_kv_cache",
-    ) -> "PagedKVCache":
-        """The generic function of creating a multi-head attention PagedKVCache,
-        which will be rewritten by functions in compilation pipeline.
-        """
-        if rotary_dim is None:
-            rotary_dim = qk_head_dim
-        if rope_scaling is None:
-            rope_scaling = {}
-        if layer_partition is None:
-            layer_partition = [0, num_hidden_layers]
-        return PagedKVCache(
-            _expr=rx.call_pure_packed(
-                "mlc.create_paged_kv_cache_generic",
-                rx.StringImm(attn_kind),
-                rx.ShapeExpr(
-                    [
-                        max_batch_size,
-                        max_total_seq_len,
-                        prefill_chunk_size,
-                        page_size,
-                        support_sliding_window,
-                    ]
-                ),
-                rx.ShapeExpr(layer_partition),
-                rx.PrimValue(num_hidden_layers),
-                rx.PrimValue(num_attention_heads),
-                rx.PrimValue(num_key_value_heads),
-                rx.PrimValue(qk_head_dim),
-                rx.PrimValue(v_head_dim),
-                rx.PrimValue(mla_original_qk_head_dim),
-                rx.PrimValue(mla_original_v_head_dim),
-                rx.PrimValue(rope_mode),
-                rx.PrimValue(rope_scale),
-                rx.PrimValue(rope_theta),
-                rx.StringImm(json.dumps(rope_scaling)),
-                (
-                    rx.const(np.array(rope_ext_factors, "float32"))
-                    if rope_ext_factors is not None
-                    else rx.PrimValue(0)
-                    # NOTE: since relax does not have "Optional" type, we use PrimValue(0)
-                    # to represent "undefined".
-                ),
-                rx.PrimValue(rotary_dim),
-                rx.PrimValue(int(enable_disaggregation)),
-                rx.DataTypeImm(dtype),
-                sinfo_args=rx.ObjectStructInfo(),
-            ),
-            _name=name,
-        )
+def create_generic_paged_kv_cache(  # pylint: disable=too-many-locals
+    attn_kind: Literal["mha", "mla"],
+    max_batch_size: tir.Var,
+    max_total_seq_len: tir.Var,
+    prefill_chunk_size: tir.Var,
+    page_size: tir.Var,
+    support_sliding_window: tir.Var,
+    num_hidden_layers: int,
+    num_attention_heads: int,
+    num_key_value_heads: int,
+    qk_head_dim: int,
+    v_head_dim: int,
+    rope_mode: RopeMode,
+    rope_scale: int,
+    rope_theta: int,
+    dtype: str,
+    mla_original_qk_head_dim: int = 0,
+    mla_original_v_head_dim: int = 0,
+    rotary_dim: Optional[int] = None,
+    rope_scaling: Optional[Dict[str, Any]] = None,
+    rope_ext_factors: Optional[List[int]] = None,
+    layer_partition: Optional[List[int]] = None,
+    enable_disaggregation: bool = False,
+    name: str = "paged_kv_cache",
+) -> "PagedKVCache":
+    """The generic function of creating a multi-head attention PagedKVCache,
+    which will be rewritten by functions in compilation pipeline.
+    """
+    if rotary_dim is None:
+        rotary_dim = qk_head_dim
+    if rope_scaling is None:
+        rope_scaling = {}
+    if layer_partition is None:
+        layer_partition = [0, num_hidden_layers]
+    return PagedKVCache(
+        _expr=rx.call_pure_packed(
+            "mlc.create_paged_kv_cache_generic",
+            rx.StringImm(attn_kind),
+            rx.ShapeExpr(
+                [
+                    max_batch_size,
+                    max_total_seq_len,
+                    prefill_chunk_size,
+                    page_size,
+                    support_sliding_window,
+                ]
+            ),
+            rx.ShapeExpr(layer_partition),
+            rx.PrimValue(num_hidden_layers),
+            rx.PrimValue(num_attention_heads),
+            rx.PrimValue(num_key_value_heads),
+            rx.PrimValue(qk_head_dim),
+            rx.PrimValue(v_head_dim),
+            rx.PrimValue(mla_original_qk_head_dim),
+            rx.PrimValue(mla_original_v_head_dim),
+            rx.PrimValue(rope_mode),
+            rx.PrimValue(rope_scale),
+            rx.PrimValue(rope_theta),
+            rx.StringImm(json.dumps(rope_scaling)),
+            (
+                rx.const(np.array(rope_ext_factors, "float32"))
+                if rope_ext_factors is not None
+                else rx.PrimValue(0)
+                # NOTE: since relax does not have "Optional" type, we use PrimValue(0)
+                # to represent "undefined".
+            ),
+            rx.PrimValue(rotary_dim),
+            rx.PrimValue(int(enable_disaggregation)),
+            rx.DataTypeImm(dtype),
+            sinfo_args=rx.ObjectStructInfo(),
+        ),
+        _name=name,
+    )
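
Reviewer note (illustrative, not part of the patch): the change is mechanical throughout. `PagedKVCache.create_generic(...)` becomes the module-level `create_generic_paged_kv_cache(...)` with an unchanged keyword-argument list, and every `nn.spec.Object(object_type=PagedKVCache)` entry in `get_default_spec` becomes `nn.spec.PagedKVCache()`. A minimal sketch of the post-patch call pattern follows; the `config` object, its fields, and the `RopeMode.NORMAL`/`"float16"` values are hypothetical stand-ins for what a concrete model config provides:

    # Sketch only: mirrors the per-model create_paged_kv_cache methods touched
    # above, using the parameter names from the new create_generic_paged_kv_cache
    # signature in python/mlc_llm/nn/kv_cache.py.
    from tvm import tir

    from mlc_llm.nn import PagedKVCache, RopeMode, create_generic_paged_kv_cache

    def create_paged_kv_cache(
        config,  # hypothetical model-config object carrying the hyperparameters
        max_batch_size: tir.Var,
        max_total_seq_len: tir.Var,
        prefill_chunk_size: tir.Var,
        page_size: tir.Var,
        support_sliding_window: tir.Var,
    ) -> PagedKVCache:
        # Formerly spelled PagedKVCache.create_generic(...); only the callee moved.
        return create_generic_paged_kv_cache(
            attn_kind="mha",
            max_batch_size=max_batch_size,
            max_total_seq_len=max_total_seq_len,
            prefill_chunk_size=prefill_chunk_size,
            page_size=page_size,
            support_sliding_window=support_sliding_window,
            num_hidden_layers=config.num_hidden_layers,
            num_attention_heads=config.num_attention_heads,
            num_key_value_heads=config.num_key_value_heads,
            qk_head_dim=config.head_dim,
            v_head_dim=config.head_dim,
            rope_mode=RopeMode.NORMAL,
            rope_scale=1,
            rope_theta=config.rope_theta,
            dtype="float16",
        )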