Commit ff9467e

[Bfloat16] Added support for bfloat16 in fuse add norm (#3211)
In this PR I have added support for bfloat16 in the fuse-add-norm pass. I have confirmed that the pass is applied by compiling Llama3 in debug mode.
1 parent 1434760 commit ff9467e

File tree

1 file changed: +32 -22 lines changed

python/mlc_llm/compiler_pass/fuse_add_norm.py

Lines changed: 32 additions & 22 deletions
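For orientation, the kernels touched below compute a residual add fused with an RMS norm, accumulating in float32 and casting the result back to the input dtype, which after this PR may be float16 or bfloat16. The following is a minimal NumPy reference sketch of that computation, not code from the repo; the function name fused_add_rms_norm is made up for illustration, and bfloat16 arrays would additionally need a package such as ml_dtypes.

import numpy as np

def fused_add_rms_norm(x, residual, weight, eps=1e-5):
    """Reference: normalize (x + residual) by its RMS, scale by weight, and also return the sum."""
    if x.dtype.name not in ("float16", "bfloat16"):  # same guard the PR adds, expressed on dtype names
        raise ValueError(f"Unsupported data type: {x.dtype}")
    added = x.astype(np.float32) + residual.astype(np.float32)          # the "add" half
    inv_rms = 1.0 / np.sqrt((added * added).mean(-1, keepdims=True) + eps)
    out = added * inv_rms * weight.astype(np.float32)                   # the "norm" half
    # Cast back to the input dtype -- previously hard-coded as T.float16 in the TIR
    # kernels, now T.cast(..., dtype=in_dtype) as shown in the diff below.
    return out.astype(x.dtype), added.astype(x.dtype)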
@@ -13,7 +13,9 @@
 from ..support.max_thread_check import get_max_num_threads_per_block
 
 
-def _get_add_rms_norm_decode(hidden_size: int, eps: float, TX: int):
+def _get_add_rms_norm_decode(hidden_size: int, eps: float, TX: int, in_dtype: str):
+    if in_dtype not in ("float16", "bfloat16"):
+        raise ValueError(f"Unsupported data type: {in_dtype}")
     inv_hidden_size = T.float32(1.0 / float(hidden_size))
     eps = T.float32(eps)
     add_local_size = hidden_size // TX
@@ -24,12 +26,12 @@ def decode_add_rms( # pylint: disable=too-many-locals
     ):
         T.func_attr({"tir.noalias": T.bool(True), "tir.is_scheduled": 1})
         batch_size = T.int32()
-        A = T.match_buffer(pA, (batch_size, 1, hidden_size), "float16")
-        B = T.match_buffer(pB, (batch_size, 1, hidden_size), "float16")
-        C = T.match_buffer(pC, (hidden_size,), "float16")
-        O = T.match_buffer(pO, (batch_size, 1, hidden_size), "float16")
-        add = T.match_buffer(pAdd, (batch_size, 1, hidden_size), "float16")
-        add_local = T.alloc_buffer((hidden_size // TX,), "float16", scope="local")
+        A = T.match_buffer(pA, (batch_size, 1, hidden_size), in_dtype)
+        B = T.match_buffer(pB, (batch_size, 1, hidden_size), in_dtype)
+        C = T.match_buffer(pC, (hidden_size,), in_dtype)
+        O = T.match_buffer(pO, (batch_size, 1, hidden_size), in_dtype)
+        add = T.match_buffer(pAdd, (batch_size, 1, hidden_size), in_dtype)
+        add_local = T.alloc_buffer((hidden_size // TX,), in_dtype, scope="local")
         sum_shared = T.alloc_buffer((batch_size, 1), scope="shared")
         sum_local = T.alloc_buffer((TX, batch_size, 1), scope="local")
         for v_bx in T.thread_binding(batch_size, thread="blockIdx.x"):
@@ -69,16 +71,19 @@ def decode_add_rms( # pylint: disable=too-many-locals
                     with T.block("T_cast_2"):
                         bx = T.axis.spatial(batch_size, v_bx)
                         h = T.axis.spatial(hidden_size, i * TX + v_tx_2)
-                        O[bx, 0, h] = T.float16(
+                        O[bx, 0, h] = T.cast(
                             T.rsqrt(sum_shared[bx, 0] * inv_hidden_size + eps)
                             * T.float32(add_local[h // TX])
-                            * T.float32(C[h])
+                            * T.float32(C[h]),
+                            dtype=in_dtype,
                         )
 
     return decode_add_rms
 
 
-def _get_add_rms_norm_prefill(hidden_size: int, eps: float, TX: int):
+def _get_add_rms_norm_prefill(hidden_size: int, eps: float, TX: int, in_dtype: str):
+    if in_dtype not in ("float16", "bfloat16"):
+        raise ValueError(f"Unsupported data type: {in_dtype}")
     inv_hidden_size = T.float32(1.0 / float(hidden_size))
     eps = T.float32(eps)
     add_local_size = hidden_size // TX
@@ -89,12 +94,12 @@ def prefill_add_rms( # pylint: disable=too-many-locals
     ):
         T.func_attr({"tir.noalias": T.bool(True), "tir.is_scheduled": 1})
         seq_len = T.int32()
-        A = T.match_buffer(pA, (1, seq_len, hidden_size), "float16")
-        B = T.match_buffer(pB, (1, seq_len, hidden_size), "float16")
-        C = T.match_buffer(pC, (hidden_size,), "float16")
-        O = T.match_buffer(pO, (1, seq_len, hidden_size), "float16")
-        add = T.match_buffer(pAdd, (1, seq_len, hidden_size), "float16")
-        add_local = T.alloc_buffer((hidden_size // TX,), "float16", scope="local")
+        A = T.match_buffer(pA, (1, seq_len, hidden_size), in_dtype)
+        B = T.match_buffer(pB, (1, seq_len, hidden_size), in_dtype)
+        C = T.match_buffer(pC, (hidden_size,), in_dtype)
+        O = T.match_buffer(pO, (1, seq_len, hidden_size), in_dtype)
+        add = T.match_buffer(pAdd, (1, seq_len, hidden_size), in_dtype)
+        add_local = T.alloc_buffer((hidden_size // TX,), in_dtype, scope="local")
         sum_shared = T.alloc_buffer((1, seq_len), scope="shared")
         sum_local = T.alloc_buffer((TX, 1, seq_len), scope="local")
         for v_bx in T.thread_binding(seq_len, thread="blockIdx.x"):
@@ -131,10 +136,11 @@ def prefill_add_rms( # pylint: disable=too-many-locals
                     with T.block("T_cast_2"):
                         bx = T.axis.spatial(seq_len, v_bx)
                         v1 = T.axis.spatial(hidden_size, v_i * TX + v_tx_2)
-                        O[0, bx, v1] = T.float16(
+                        O[0, bx, v1] = T.cast(
                             T.rsqrt(sum_shared[0, bx] * inv_hidden_size + eps)
                             * T.float32(add_local[v1 // TX])
-                            * T.float32(C[v1])
+                            * T.float32(C[v1]),
+                            dtype=in_dtype,
                         )
 
     return prefill_add_rms
@@ -182,8 +188,10 @@ def visit_call_(self, call: relax.Call) -> relax.Expr: # pylint: disable=argume
         call = super().visit_call_(call)
 
         # Match the "rms_norm(add(x1, x2), w)" pattern
-        # Todo: support bf16 # pylint: disable=fixme
-        if call.op != tvm.ir.Op.get("relax.nn.rms_norm") or call.struct_info.dtype != "float16":
+        if call.op != tvm.ir.Op.get("relax.nn.rms_norm") or call.struct_info.dtype not in [
+            "bfloat16",
+            "float16",
+        ]:
             return call
         assert len(call.args) == 2
         weight = call.args[1]
@@ -206,12 +214,14 @@ def visit_call_(self, call: relax.Call) -> relax.Expr: # pylint: disable=argume
         if func_gv is None:
             if is_prefill:
                 func_gv = self.builder_.add_func(
-                    _get_add_rms_norm_prefill(h, eps, self.TX), "fuse_add_norm_prefill"
+                    _get_add_rms_norm_prefill(h, eps, self.TX, call.struct_info.dtype),
+                    "fuse_add_norm_prefill",
                 )
                 self.prefill_norm_gv = func_gv
             else:
                 func_gv = self.builder_.add_func(
-                    _get_add_rms_norm_decode(h, eps, self.TX), "fuse_add_norm_decode"
+                    _get_add_rms_norm_decode(h, eps, self.TX, call.struct_info.dtype),
+                    "fuse_add_norm_decode",
                 )
                 self.decode_norm_gv = func_gv
 

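For completeness, a quick usage sketch of the new in_dtype parameter. This is not code from the repo: the import path and helper signature are taken from the diff above, the sizes (4096, 1e-5, 1024) are arbitrary, and it assumes an environment where mlc_llm and TVM are importable so the TVMScript kernel can be parsed.

from mlc_llm.compiler_pass.fuse_add_norm import _get_add_rms_norm_decode

# The guard added in this commit accepts float16 and bfloat16 and rejects other dtypes.
for dtype in ("float16", "bfloat16", "float32"):
    try:
        _get_add_rms_norm_decode(hidden_size=4096, eps=1e-5, TX=1024, in_dtype=dtype)
        print(dtype, "-> kernel built")
    except ValueError as err:
        print(dtype, "-> rejected:", err)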