 """A pass that rewrites KV cache creation functions in IRModule."""
 
 import json
-from typing import Any, Dict
+from typing import Any, Dict, Literal, Tuple
 
 import tvm
 from tvm import IRModule, relax
 from tvm.relax.frontend.nn.llm import kv_cache
 from tvm.relax.frontend.nn.llm.kv_cache import RopeMode
 
 
-def extract_creation_args(func: relax.Function) -> Dict[str, Any]:
+def extract_creation_args(func: relax.Function) -> Tuple[Literal["mha", "mla"], Dict[str, Any]]:
     """Extract the KV cache creation args from the given generic creation func."""
     assert isinstance(func.body, relax.SeqExpr)
     assert len(func.body.blocks) == 1
     assert isinstance(func.body.blocks[0], relax.DataflowBlock)
     assert isinstance(func.body.blocks[0].bindings[0], relax.VarBinding)
     assert isinstance(func.body.blocks[0].bindings[0].value, relax.Call)
     assert func.body.blocks[0].bindings[0].value.op == tvm.ir.Op.get("relax.call_pure_packed")
-    args = func.body.blocks[0].bindings[0].value.args
-    assert isinstance(args[0], relax.ExternFunc)
-    assert args[0].global_symbol == "mlc.create_paged_kv_cache_generic"
-
-    assert len(args) == 15
-    assert isinstance(args[1], relax.ShapeExpr)
-    assert len(args[1].values) == 5
-    assert isinstance(args[2], relax.ShapeExpr)
-    for i in range(3, 14):
-        if i in [10, 11]:
-            continue
-        assert isinstance(args[i], relax.PrimValue)
-        assert isinstance(args[i].value, (tvm.tir.IntImm, tvm.tir.FloatImm))
-    assert isinstance(args[10], relax.StringImm)
-    assert isinstance(args[11], (relax.Constant, relax.PrimValue))
-    assert isinstance(args[14], relax.DataTypeImm)
-
-    return {
-        "max_batch_size": args[1].values[0],
-        "max_total_seq_len": args[1].values[1],
-        "prefill_chunk_size": args[1].values[2],
-        "page_size": args[1].values[3],
-        "support_sliding_window": args[1].values[4],
-        "layer_partition": args[2],
-        "num_hidden_layers": args[3].value.value,
-        "num_attention_heads": args[4].value.value,
-        "num_key_value_heads": args[5].value.value,
-        "head_dim": args[6].value.value,
-        "rope_mode": args[7].value.value,
-        "rope_scale": args[8].value.value,
-        "rope_theta": args[9].value.value,
-        "rope_scaling": json.loads(args[10].value),
-        "rope_ext_factors": args[11],
-        "rotary_dim": args[12].value.value,
-        "enable_disaggregation": bool(args[13].value.value),
-        "dtype": args[14].value,
-    }
+    call_args = func.body.blocks[0].bindings[0].value.args
+    assert isinstance(call_args[0], relax.ExternFunc)
+    assert call_args[0].global_symbol == "mlc.create_paged_kv_cache_generic"
+    assert isinstance(call_args[1], relax.StringImm)
+
+    args = call_args[1:]
+    if args[0].value == "mha":
+        assert len(args) == 15
+        assert isinstance(args[1], relax.ShapeExpr)
+        assert len(args[1].values) == 5
+        assert isinstance(args[2], relax.ShapeExpr)
+        for i in range(3, 14):
+            if i in [10, 11]:
+                continue
+            assert isinstance(args[i], relax.PrimValue)
+            assert isinstance(args[i].value, (tvm.tir.IntImm, tvm.tir.FloatImm))
+        assert isinstance(args[10], relax.StringImm)
+        assert isinstance(args[11], (relax.Constant, relax.PrimValue))
+        assert isinstance(args[14], relax.DataTypeImm)
+
+        return "mha", {
+            "max_batch_size": args[1].values[0],
+            "max_total_seq_len": args[1].values[1],
+            "prefill_chunk_size": args[1].values[2],
+            "page_size": args[1].values[3],
+            "support_sliding_window": args[1].values[4],
+            "layer_partition": args[2],
+            "num_hidden_layers": args[3].value.value,
+            "num_attention_heads": args[4].value.value,
+            "num_key_value_heads": args[5].value.value,
+            "head_dim": args[6].value.value,
+            "rope_mode": args[7].value.value,
+            "rope_scale": args[8].value.value,
+            "rope_theta": args[9].value.value,
+            "rope_scaling": json.loads(args[10].value),
+            "rope_ext_factors": args[11],
+            "rotary_dim": args[12].value.value,
+            "enable_disaggregation": bool(args[13].value.value),
+            "dtype": args[14].value,
+        }
+    if call_args[1].value == "mla":
+        assert len(args) == 12
+        assert isinstance(args[1], relax.ShapeExpr)
+        assert len(args[1].values) == 5
+        assert isinstance(args[2], relax.ShapeExpr)
+        for i in range(3, 11):
+            assert isinstance(args[i], relax.PrimValue)
+            assert isinstance(args[i].value, tvm.tir.IntImm)
+        assert isinstance(args[11], relax.DataTypeImm)
+
+        return "mla", {
+            "max_batch_size": args[1].values[0],
+            "max_total_seq_len": args[1].values[1],
+            "prefill_chunk_size": args[1].values[2],
+            "page_size": args[1].values[3],
+            "support_sliding_window": args[1].values[4],
+            "layer_partition": args[2],
+            "num_hidden_layers": args[3].value.value,
+            "num_attention_heads": args[4].value.value,
+            "num_key_value_heads": args[5].value.value,
+            "qk_nope_head_dim": args[6].value.value,
+            "qk_rope_head_dim": args[7].value.value,
+            "v_head_dim": args[8].value.value,
+            "kv_lora_rank": args[9].value.value,
+            "enable_disaggregation": bool(args[10].value.value),
+            "dtype": args[11].value,
+        }
+
+    raise ValueError("Cannot reach here")
 
 
 @tvm.transform.module_pass(opt_level=0, name="DispatchKVCacheCreation")
@@ -100,24 +132,38 @@ def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IRModule:
         if mod.attrs is not None:
             new_mod = new_mod.with_attrs(mod.attrs)
 
-        kwargs = extract_creation_args(creation_func)
-        self.attach_kv_cache_metadata(kwargs)
+        kv_cache_kind, kwargs = extract_creation_args(creation_func)
+        self.attach_kv_cache_metadata(kv_cache_kind, kwargs)
 
         bb = relax.BlockBuilder(new_mod)
-        self.create_tir_paged_kv_cache(bb, kwargs)
-        self.create_flashinfer_paged_kv_cache(bb, kwargs)
+        self.create_tir_paged_kv_cache(bb, kv_cache_kind, kwargs)
+        self.create_flashinfer_paged_kv_cache(bb, kv_cache_kind, kwargs)
         return bb.finalize()
 
-    def attach_kv_cache_metadata(self, kwargs: Dict[str, Any]):
+    def attach_kv_cache_metadata(
+        self, kv_cache_kind: Literal["mha", "mla"], kwargs: Dict[str, Any]
+    ):
         """Attach the KV cache metadata to model metadata."""
-        self.metadata["kv_cache"] = {
-            "num_hidden_layers": kwargs["num_hidden_layers"],
-            "num_attention_heads": kwargs["num_attention_heads"],
-            "num_key_value_heads": kwargs["num_key_value_heads"],
-            "head_dim": kwargs["head_dim"],
-        }
-
-    def create_tir_paged_kv_cache(self, bb: relax.BlockBuilder, kwargs: Dict[str, Any]) -> None:
+        if kv_cache_kind == "mha":
+            self.metadata["kv_cache"] = {
+                "num_hidden_layers": kwargs["num_hidden_layers"],
+                "num_attention_heads": kwargs["num_attention_heads"],
+                "num_key_value_heads": kwargs["num_key_value_heads"],
+                "head_dim": kwargs["head_dim"],
+            }
+        elif kv_cache_kind == "mla":
+            self.metadata["kv_cache"] = {
+                "num_hidden_layers": kwargs["num_hidden_layers"],
+                "num_attention_heads": kwargs["num_attention_heads"],
+                "num_key_value_heads": 1,
+                "head_dim": kwargs["kv_lora_rank"] + kwargs["qk_rope_head_dim"],
+            }
+        else:
+            raise ValueError("Cannot reach here.")
+
+    def create_tir_paged_kv_cache(
+        self, bb: relax.BlockBuilder, kv_cache_kind: Literal["mha", "mla"], kwargs: Dict[str, Any]
+    ) -> None:
         """Create the TIR-based PagedKVCache"""
         max_batch_size = relax.Var(
             "max_batch_size_", relax.ShapeStructInfo([kwargs["max_batch_size"]])
@@ -143,16 +189,22 @@ def create_tir_paged_kv_cache(self, bb: relax.BlockBuilder, kwargs: Dict[str, Any]) -> None:
                 support_sliding_window,
             ],
         ):
-            cache = kv_cache.TIRPagedKVCache(target=self.target, **kwargs)
+            if kv_cache_kind == "mha":
+                cache = kv_cache.TIRPagedKVCache(target=self.target, **kwargs)
+            elif kv_cache_kind == "mla":
+                cache = kv_cache.TIRPagedKVCache.create_mla_kv_cache(target=self.target, **kwargs)
+            else:
+                raise ValueError("Cannot reach here")
             bb.emit_func_output(cache._expr)  # pylint: disable=protected-access
 
     def create_flashinfer_paged_kv_cache(
-        self, bb: relax.BlockBuilder, kwargs: Dict[str, Any]
+        self, bb: relax.BlockBuilder, kv_cache_kind: Literal["mha", "mla"], kwargs: Dict[str, Any]
     ) -> None:
         """Create the FlashInfer-based PagedKVCache"""
         # Filter the cases which FlashInfer does not support.
         if (  # pylint: disable=too-many-boolean-expressions
             not self.flashinfer
+            or kv_cache_kind != "mha"
             or str(kwargs["dtype"]) != "float16"
             or kwargs["head_dim"] != 128
             or (