13
13
from ..support .max_thread_check import get_max_num_threads_per_block
14
14
15
15
16
- def _get_add_rms_norm_decode (hidden_size : int , eps : float , TX : int ):
16
+ def _get_add_rms_norm_decode (hidden_size : int , eps : float , TX : int , in_dtype : str ):
17
+ if in_dtype not in ("float16" , "bfloat16" ):
18
+ raise ValueError (f"Unsupported data type: { in_dtype } " )
17
19
inv_hidden_size = T .float32 (1.0 / float (hidden_size ))
18
20
eps = T .float32 (eps )
19
21
add_local_size = hidden_size // TX
@@ -24,12 +26,12 @@ def decode_add_rms( # pylint: disable=too-many-locals
24
26
):
25
27
T .func_attr ({"tir.noalias" : T .bool (True ), "tir.is_scheduled" : 1 })
26
28
batch_size = T .int32 ()
27
- A = T .match_buffer (pA , (batch_size , 1 , hidden_size ), "float16" )
28
- B = T .match_buffer (pB , (batch_size , 1 , hidden_size ), "float16" )
29
- C = T .match_buffer (pC , (hidden_size ,), "float16" )
30
- O = T .match_buffer (pO , (batch_size , 1 , hidden_size ), "float16" )
31
- add = T .match_buffer (pAdd , (batch_size , 1 , hidden_size ), "float16" )
32
- add_local = T .alloc_buffer ((hidden_size // TX ,), "float16" , scope = "local" )
29
+ A = T .match_buffer (pA , (batch_size , 1 , hidden_size ), in_dtype )
30
+ B = T .match_buffer (pB , (batch_size , 1 , hidden_size ), in_dtype )
31
+ C = T .match_buffer (pC , (hidden_size ,), in_dtype )
32
+ O = T .match_buffer (pO , (batch_size , 1 , hidden_size ), in_dtype )
33
+ add = T .match_buffer (pAdd , (batch_size , 1 , hidden_size ), in_dtype )
34
+ add_local = T .alloc_buffer ((hidden_size // TX ,), in_dtype , scope = "local" )
33
35
sum_shared = T .alloc_buffer ((batch_size , 1 ), scope = "shared" )
34
36
sum_local = T .alloc_buffer ((TX , batch_size , 1 ), scope = "local" )
35
37
for v_bx in T .thread_binding (batch_size , thread = "blockIdx.x" ):
@@ -69,16 +71,19 @@ def decode_add_rms( # pylint: disable=too-many-locals
69
71
with T .block ("T_cast_2" ):
70
72
bx = T .axis .spatial (batch_size , v_bx )
71
73
h = T .axis .spatial (hidden_size , i * TX + v_tx_2 )
72
- O [bx , 0 , h ] = T .float16 (
74
+ O [bx , 0 , h ] = T .cast (
73
75
T .rsqrt (sum_shared [bx , 0 ] * inv_hidden_size + eps )
74
76
* T .float32 (add_local [h // TX ])
75
- * T .float32 (C [h ])
77
+ * T .float32 (C [h ]),
78
+ dtype = in_dtype ,
76
79
)
77
80
78
81
return decode_add_rms
79
82
80
83
81
- def _get_add_rms_norm_prefill (hidden_size : int , eps : float , TX : int ):
84
+ def _get_add_rms_norm_prefill (hidden_size : int , eps : float , TX : int , in_dtype : str ):
85
+ if in_dtype not in ("float16" , "bfloat16" ):
86
+ raise ValueError (f"Unsupported data type: { in_dtype } " )
82
87
inv_hidden_size = T .float32 (1.0 / float (hidden_size ))
83
88
eps = T .float32 (eps )
84
89
add_local_size = hidden_size // TX
@@ -89,12 +94,12 @@ def prefill_add_rms( # pylint: disable=too-many-locals
89
94
):
90
95
T .func_attr ({"tir.noalias" : T .bool (True ), "tir.is_scheduled" : 1 })
91
96
seq_len = T .int32 ()
92
- A = T .match_buffer (pA , (1 , seq_len , hidden_size ), "float16" )
93
- B = T .match_buffer (pB , (1 , seq_len , hidden_size ), "float16" )
94
- C = T .match_buffer (pC , (hidden_size ,), "float16" )
95
- O = T .match_buffer (pO , (1 , seq_len , hidden_size ), "float16" )
96
- add = T .match_buffer (pAdd , (1 , seq_len , hidden_size ), "float16" )
97
- add_local = T .alloc_buffer ((hidden_size // TX ,), "float16" , scope = "local" )
97
+ A = T .match_buffer (pA , (1 , seq_len , hidden_size ), in_dtype )
98
+ B = T .match_buffer (pB , (1 , seq_len , hidden_size ), in_dtype )
99
+ C = T .match_buffer (pC , (hidden_size ,), in_dtype )
100
+ O = T .match_buffer (pO , (1 , seq_len , hidden_size ), in_dtype )
101
+ add = T .match_buffer (pAdd , (1 , seq_len , hidden_size ), in_dtype )
102
+ add_local = T .alloc_buffer ((hidden_size // TX ,), in_dtype , scope = "local" )
98
103
sum_shared = T .alloc_buffer ((1 , seq_len ), scope = "shared" )
99
104
sum_local = T .alloc_buffer ((TX , 1 , seq_len ), scope = "local" )
100
105
for v_bx in T .thread_binding (seq_len , thread = "blockIdx.x" ):
@@ -131,10 +136,11 @@ def prefill_add_rms( # pylint: disable=too-many-locals
131
136
with T .block ("T_cast_2" ):
132
137
bx = T .axis .spatial (seq_len , v_bx )
133
138
v1 = T .axis .spatial (hidden_size , v_i * TX + v_tx_2 )
134
- O [0 , bx , v1 ] = T .float16 (
139
+ O [0 , bx , v1 ] = T .cast (
135
140
T .rsqrt (sum_shared [0 , bx ] * inv_hidden_size + eps )
136
141
* T .float32 (add_local [v1 // TX ])
137
- * T .float32 (C [v1 ])
142
+ * T .float32 (C [v1 ]),
143
+ dtype = in_dtype ,
138
144
)
139
145
140
146
return prefill_add_rms
@@ -182,8 +188,10 @@ def visit_call_(self, call: relax.Call) -> relax.Expr: # pylint: disable=argume
182
188
call = super ().visit_call_ (call )
183
189
184
190
# Match the "rms_norm(add(x1, x2), w)" pattern
185
- # Todo: support bf16 # pylint: disable=fixme
186
- if call .op != tvm .ir .Op .get ("relax.nn.rms_norm" ) or call .struct_info .dtype != "float16" :
191
+ if call .op != tvm .ir .Op .get ("relax.nn.rms_norm" ) or call .struct_info .dtype not in [
192
+ "bfloat16" ,
193
+ "float16" ,
194
+ ]:
187
195
return call
188
196
assert len (call .args ) == 2
189
197
weight = call .args [1 ]
@@ -206,12 +214,14 @@ def visit_call_(self, call: relax.Call) -> relax.Expr: # pylint: disable=argume
206
214
if func_gv is None :
207
215
if is_prefill :
208
216
func_gv = self .builder_ .add_func (
209
- _get_add_rms_norm_prefill (h , eps , self .TX ), "fuse_add_norm_prefill"
217
+ _get_add_rms_norm_prefill (h , eps , self .TX , call .struct_info .dtype ),
218
+ "fuse_add_norm_prefill" ,
210
219
)
211
220
self .prefill_norm_gv = func_gv
212
221
else :
213
222
func_gv = self .builder_ .add_func (
214
- _get_add_rms_norm_decode (h , eps , self .TX ), "fuse_add_norm_decode"
223
+ _get_add_rms_norm_decode (h , eps , self .TX , call .struct_info .dtype ),
224
+ "fuse_add_norm_decode" ,
215
225
)
216
226
self .decode_norm_gv = func_gv
217
227
0 commit comments