Description
Triton Python code
import triton
import triton.language as tl


@triton.jit
def _fwd_grouped_kernel_stage1(
    Q,
    K_Buffer,
    V_Buffer,
    sm_scale,
    Req_to_tokens,
    B_Seqlen,
    Att_Out,
    stride_req_to_tokens_b,
    stride_qbs,
    stride_qh,
    stride_buf_kbs,
    stride_buf_kh,
    stride_buf_vbs,
    stride_buf_vh,
    stride_mid_ob,
    stride_mid_oh,
    stride_mid_os,
    kv_group_num: tl.constexpr,
    q_head_num: tl.constexpr,
    BLOCK_DMODEL: tl.constexpr,
    BLOCK_DPE: tl.constexpr,
    BLOCK_DV: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_H: tl.constexpr,
    NUM_KV_SPLITS: tl.constexpr,
    PAGE_SIZE: tl.constexpr,
    logit_cap: tl.constexpr,
    Lk: tl.constexpr,
    Lv: tl.constexpr,
):
    cur_batch = tl.program_id(0)
    cur_head_id = tl.program_id(1)
    cur_kv_head = cur_head_id // tl.cdiv(kv_group_num, BLOCK_H)
    split_kv_id = tl.program_id(2)

    if kv_group_num > BLOCK_H:
        VALID_BLOCK_H: tl.constexpr = BLOCK_H
    else:
        VALID_BLOCK_H: tl.constexpr = kv_group_num
    cur_head = cur_head_id * VALID_BLOCK_H + tl.arange(0, BLOCK_H)
    mask_h = cur_head < (cur_head_id + 1) * VALID_BLOCK_H
    mask_h = mask_h & (cur_head < q_head_num)

    offs_d = tl.arange(0, BLOCK_DMODEL)
    offs_dv = tl.arange(0, BLOCK_DV)
    mask_d = offs_d < Lk
    mask_dv = offs_dv < Lv
    cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
    cur_batch_req_idx = cur_batch

    offs_q = cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_d[None, :]
    q = tl.load(Q + offs_q,
                mask=(mask_h[:, None]) & (mask_d[None, :]),
                other=0.0)

    if BLOCK_DPE > 0:
        offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE)
        mask_dpe = offs_dpe < Lk
        off_qpe = (cur_batch * stride_qbs + cur_head[:, None] * stride_qh +
                   offs_dpe[None, :])
        qpe = tl.load(Q + off_qpe,
                      mask=(mask_h[:, None]) & (mask_dpe[None, :]),
                      other=0.0)

    # Token range handled by this KV split.
    kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS)
    split_kv_start = kv_len_per_split * split_kv_id
    split_kv_end = tl.minimum(split_kv_start + kv_len_per_split,
                              cur_batch_seq_len)

    # Online-softmax running state.
    e_max = tl.zeros([BLOCK_H], dtype=tl.float32) - float("inf")
    e_sum = tl.zeros([BLOCK_H], dtype=tl.float32)
    acc = tl.zeros([BLOCK_H, BLOCK_DV], dtype=tl.float32)

    if split_kv_end > split_kv_start:
        for start_n in range(split_kv_start, split_kv_end, BLOCK_N):
            offs_n = start_n + tl.arange(0, BLOCK_N)
            # Page-table lookup: the K/V addresses below depend on these
            # loaded page numbers (data-dependent addressing).
            kv_page_number = tl.load(
                Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx +
                offs_n // PAGE_SIZE,
                mask=offs_n < split_kv_end,
                other=0,
            )
            kv_loc = kv_page_number * PAGE_SIZE + offs_n % PAGE_SIZE
            offs_buf_k = (kv_loc[None, :] * stride_buf_kbs +
                          cur_kv_head * stride_buf_kh + offs_d[:, None])
            k = tl.load(
                K_Buffer + offs_buf_k,
                mask=(offs_n[None, :] < split_kv_end) & (mask_d[:, None]),
                other=0.0,
            )
            qk = tl.dot(q, k.to(q.dtype))
            if BLOCK_DPE > 0:
                offs_buf_kpe = (kv_loc[None, :] * stride_buf_kbs +
                                cur_kv_head * stride_buf_kh +
                                offs_dpe[:, None])
                kpe = tl.load(
                    K_Buffer + offs_buf_kpe,
                    mask=(offs_n[None, :] < split_kv_end) &
                    (mask_dpe[:, None]),
                    other=0.0,
                )
                qk += tl.dot(qpe, kpe.to(qpe.dtype))
            qk *= sm_scale

            if logit_cap > 0:
                # tanh is a @triton.jit helper assumed defined elsewhere in the file.
                qk = logit_cap * tanh(qk / logit_cap)

            qk = tl.where(mask_h[:, None] & (offs_n[None, :] < split_kv_end),
                          qk, float("-inf"))

            offs_buf_v = (kv_loc[:, None] * stride_buf_vbs +
                          cur_kv_head * stride_buf_vh + offs_dv[None, :])
            v = tl.load(
                V_Buffer + offs_buf_v,
                mask=(offs_n[:, None] < split_kv_end) & (mask_dv[None, :]),
                other=0.0,
            )

            # Online-softmax update: rescale the accumulator, fold in this block.
            n_e_max = tl.maximum(tl.max(qk, 1), e_max)
            re_scale = tl.exp(e_max - n_e_max)
            p = tl.exp(qk - n_e_max[:, None])
            acc *= re_scale[:, None]
            acc += tl.dot(p.to(v.dtype), v)

            e_sum = e_sum * re_scale + tl.sum(p, 1)
            e_max = n_e_max

        # Partial output for this split, with the log-sum-exp in the last slot.
        offs_mid_o = (cur_batch * stride_mid_ob +
                      cur_head[:, None] * stride_mid_oh +
                      split_kv_id * stride_mid_os + offs_dv[None, :])
        tl.store(
            Att_Out + offs_mid_o,
            acc / e_sum[:, None],
            mask=(mask_h[:, None]) & (mask_dv[None, :]),
        )
        offs_mid_o_1 = (cur_batch * stride_mid_ob + cur_head * stride_mid_oh +
                        split_kv_id * stride_mid_os + Lv)
        tl.store(
            Att_Out + offs_mid_o_1,
            e_max + tl.log(e_sum),
            mask=mask_h,
        )
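
For reference, the grid layout implied by the program_id usage is (batch, head blocks, KV splits). Judging from the constants in the IR dump below, this instance appears specialized with BLOCK_H = 16, kv_group_num = q_head_num = 8, BLOCK_DMODEL = Lk = 64, BLOCK_DV = Lv = 32, BLOCK_N = 32, NUM_KV_SPLITS = 4, and BLOCK_DPE = 0. Below is a minimal launch sketch; the wrapper name and tensor layouts are my assumptions, not part of the issue:

import triton

def launch_stage1(q, k_buffer, v_buffer, req_to_tokens, b_seqlen, att_out,
                  sm_scale, num_kv_splits, page_size, logit_cap=0.0):
    # Assumed layouts: q [batch, q_heads, Lk], k/v buffers [tokens, kv_heads, d],
    # att_out [batch, q_heads, num_kv_splits, Lv + 1] (last slot = log-sum-exp).
    batch, q_head_num, Lk = q.shape
    Lv = v_buffer.shape[-1]
    kv_group_num = q_head_num // k_buffer.shape[1]
    BLOCK_H = 16
    grid = (
        batch,
        triton.cdiv(q_head_num, min(BLOCK_H, kv_group_num)),
        num_kv_splits,
    )
    _fwd_grouped_kernel_stage1[grid](
        q, k_buffer, v_buffer, sm_scale, req_to_tokens, b_seqlen, att_out,
        req_to_tokens.stride(0),
        q.stride(0), q.stride(1),
        k_buffer.stride(0), k_buffer.stride(1),
        v_buffer.stride(0), v_buffer.stride(1),
        att_out.stride(0), att_out.stride(1), att_out.stride(2),
        kv_group_num=kv_group_num,
        q_head_num=q_head_num,
        BLOCK_DMODEL=triton.next_power_of_2(Lk),
        BLOCK_DPE=0,
        BLOCK_DV=triton.next_power_of_2(Lv),
        BLOCK_N=32,
        BLOCK_H=BLOCK_H,
        NUM_KV_SPLITS=num_kv_splits,
        PAGE_SIZE=page_size,
        logit_cap=logit_cap,
        Lk=Lk,
        Lv=Lv,
    )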
Triton IR
// -----// IR Dump Before TritonToLinalg (triton-to-linalg) //----- //
module {
tt.func public @_fwd_grouped_kernel_stage1(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: f32, %arg4: !tt.ptr<i32> {tt.divisibility = 16 : i32}, %arg5: !tt.ptr<i32> {tt.divisibility = 16 : i32}, %arg6: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}, %arg9: i32 {tt.divisibility = 16 : i32}, %arg10: i32 {tt.divisibility = 16 : i32}, %arg11: i32 {tt.divisibility = 16 : i32}, %arg12: i32 {tt.divisibility = 16 : i32}, %arg13: i32 {tt.divisibility = 16 : i32}, %arg14: i32 {tt.divisibility = 16 : i32}, %arg15: i32, %arg16: i32) attributes {noinline = false} {
%cst = arith.constant dense<0xFF800000> : tensor<16xf32>
%cst_0 = arith.constant dense<0.000000e+00> : tensor<16xf32>
%c3_i32 = arith.constant 3 : i32
%c4_i32 = arith.constant 4 : i32
%c32_i32 = arith.constant 32 : i32
%cst_1 = arith.constant dense<32> : tensor<16xi32>
%cst_2 = arith.constant dense<0.000000e+00> : tensor<32x32xf32>
%cst_3 = arith.constant dense<0xFF800000> : tensor<16x32xf32>
%cst_4 = arith.constant dense<0.000000e+00> : tensor<16x32xf32>
%cst_5 = arith.constant dense<0.000000e+00> : tensor<64x32xf32>
%cst_6 = arith.constant dense<0> : tensor<32xi32>
%cst_7 = arith.constant dense<0.000000e+00> : tensor<16x64xf32>
%cst_8 = arith.constant dense<32> : tensor<32xi32>
%cst_9 = arith.constant dense<64> : tensor<64xi32>
%cst_10 = arith.constant dense<8> : tensor<16xi32>
%c1_i32 = arith.constant 1 : i32
%c8_i32 = arith.constant 8 : i32
%0 = tt.get_program_id x : i32
%1 = tt.get_program_id y : i32
%2 = tt.get_program_id z : i32
%3 = arith.muli %1, %c8_i32 : i32
%4 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32>
%5 = tt.splat %3 : i32 -> tensor<16xi32>
%6 = arith.addi %5, %4 : tensor<16xi32>
%7 = arith.addi %1, %c1_i32 : i32
%8 = arith.muli %7, %c8_i32 : i32
%9 = tt.splat %8 : i32 -> tensor<16xi32>
%10 = arith.cmpi slt, %6, %9 : tensor<16xi32>
%11 = arith.cmpi slt, %6, %cst_10 : tensor<16xi32>
%12 = arith.andi %10, %11 : tensor<16xi1>
%13 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32>
%14 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32>
%15 = arith.cmpi slt, %13, %cst_9 : tensor<64xi32>
%16 = arith.cmpi slt, %14, %cst_8 : tensor<32xi32>
%17 = tt.addptr %arg5, %0 : !tt.ptr<i32>, i32
%18 = tt.load %17 : !tt.ptr<i32>
%19 = arith.muli %0, %arg8 : i32
%20 = tt.expand_dims %6 {axis = 1 : i32} : tensor<16xi32> -> tensor<16x1xi32>
%21 = tt.splat %arg9 : i32 -> tensor<16x1xi32>
%22 = arith.muli %20, %21 : tensor<16x1xi32>
%23 = tt.splat %19 : i32 -> tensor<16x1xi32>
%24 = arith.addi %23, %22 : tensor<16x1xi32>
%25 = tt.expand_dims %13 {axis = 0 : i32} : tensor<64xi32> -> tensor<1x64xi32>
%26 = tt.broadcast %24 : tensor<16x1xi32> -> tensor<16x64xi32>
%27 = tt.broadcast %25 : tensor<1x64xi32> -> tensor<16x64xi32>
%28 = arith.addi %26, %27 : tensor<16x64xi32>
%29 = tt.expand_dims %12 {axis = 1 : i32} : tensor<16xi1> -> tensor<16x1xi1>
%30 = tt.expand_dims %15 {axis = 0 : i32} : tensor<64xi1> -> tensor<1x64xi1>
%31 = tt.broadcast %29 : tensor<16x1xi1> -> tensor<16x64xi1>
%32 = tt.broadcast %30 : tensor<1x64xi1> -> tensor<16x64xi1>
%33 = arith.andi %31, %32 : tensor<16x64xi1>
%34 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<16x64x!tt.ptr<f32>>
%35 = tt.addptr %34, %28 : tensor<16x64x!tt.ptr<f32>>, tensor<16x64xi32>
%36 = tt.load %35, %33, %cst_7 : tensor<16x64x!tt.ptr<f32>>
%37 = arith.addi %18, %c3_i32 : i32
%38 = arith.divsi %37, %c4_i32 : i32
%39 = arith.muli %38, %2 : i32
%40 = arith.addi %39, %38 : i32
%41 = arith.minsi %40, %18 : i32
%42 = arith.cmpi sgt, %41, %39 : i32
scf.if %42 {
%43 = tt.splat %41 : i32 -> tensor<32xi32>
%44 = arith.muli %arg7, %0 : i32
%45 = tt.addptr %arg4, %44 : !tt.ptr<i32>, i32
%46 = tt.splat %45 : !tt.ptr<i32> -> tensor<32x!tt.ptr<i32>>
%47 = tt.splat %arg10 : i32 -> tensor<1x32xi32>
%48 = arith.muli %1, %arg11 : i32
%49 = tt.splat %48 : i32 -> tensor<1x32xi32>
%50 = tt.expand_dims %13 {axis = 1 : i32} : tensor<64xi32> -> tensor<64x1xi32>
%51 = tt.broadcast %50 : tensor<64x1xi32> -> tensor<64x32xi32>
%52 = tt.splat %41 : i32 -> tensor<1x32xi32>
%53 = tt.expand_dims %15 {axis = 1 : i32} : tensor<64xi1> -> tensor<64x1xi1>
%54 = tt.broadcast %53 : tensor<64x1xi1> -> tensor<64x32xi1>
%55 = tt.splat %arg1 : !tt.ptr<f32> -> tensor<64x32x!tt.ptr<f32>>
%56 = tt.splat %arg3 : f32 -> tensor<16x32xf32>
%57 = tt.broadcast %29 : tensor<16x1xi1> -> tensor<16x32xi1>
%58 = tt.splat %arg12 : i32 -> tensor<32x1xi32>
%59 = arith.muli %1, %arg13 : i32
%60 = tt.splat %59 : i32 -> tensor<32x1xi32>
%61 = tt.expand_dims %14 {axis = 0 : i32} : tensor<32xi32> -> tensor<1x32xi32>
%62 = tt.broadcast %61 : tensor<1x32xi32> -> tensor<32x32xi32>
%63 = tt.splat %41 : i32 -> tensor<32x1xi32>
%64 = tt.expand_dims %16 {axis = 0 : i32} : tensor<32xi1> -> tensor<1x32xi1>
%65 = tt.broadcast %64 : tensor<1x32xi1> -> tensor<32x32xi1>
%66 = tt.splat %arg2 : !tt.ptr<f32> -> tensor<32x32x!tt.ptr<f32>>
%67:3 = scf.for %arg17 = %39 to %41 step %c32_i32 iter_args(%arg18 = %cst_4, %arg19 = %cst_0, %arg20 = %cst) -> (tensor<16x32xf32>, tensor<16xf32>, tensor<16xf32>) : i32 {
%100 = tt.splat %arg17 : i32 -> tensor<32xi32>
%101 = arith.addi %100, %14 : tensor<32xi32>
%102 = arith.cmpi slt, %101, %43 : tensor<32xi32>
%103 = tt.addptr %46, %101 : tensor<32x!tt.ptr<i32>>, tensor<32xi32>
%104 = tt.load %103, %102, %cst_6 : tensor<32x!tt.ptr<i32>>
%105 = tt.expand_dims %104 {axis = 0 : i32} : tensor<32xi32> -> tensor<1x32xi32>
%106 = arith.muli %105, %47 : tensor<1x32xi32>
%107 = arith.addi %106, %49 : tensor<1x32xi32>
%108 = tt.broadcast %107 : tensor<1x32xi32> -> tensor<64x32xi32>
%109 = arith.addi %108, %51 : tensor<64x32xi32>
%110 = tt.expand_dims %101 {axis = 0 : i32} : tensor<32xi32> -> tensor<1x32xi32>
%111 = arith.cmpi slt, %110, %52 : tensor<1x32xi32>
%112 = tt.broadcast %111 : tensor<1x32xi1> -> tensor<64x32xi1>
%113 = arith.andi %112, %54 : tensor<64x32xi1>
%114 = tt.addptr %55, %109 : tensor<64x32x!tt.ptr<f32>>, tensor<64x32xi32>
%115 = tt.load %114, %113, %cst_5 : tensor<64x32x!tt.ptr<f32>>
%116 = tt.dot %36, %115, %cst_4 : tensor<16x64xf32> * tensor<64x32xf32> -> tensor<16x32xf32>
%117 = arith.mulf %116, %56 : tensor<16x32xf32>
%118 = tt.broadcast %111 : tensor<1x32xi1> -> tensor<16x32xi1>
%119 = arith.andi %57, %118 : tensor<16x32xi1>
%120 = arith.select %119, %117, %cst_3 : tensor<16x32xi1>, tensor<16x32xf32>
%121 = tt.expand_dims %104 {axis = 1 : i32} : tensor<32xi32> -> tensor<32x1xi32>
%122 = arith.muli %121, %58 : tensor<32x1xi32>
%123 = arith.addi %122, %60 : tensor<32x1xi32>
%124 = tt.broadcast %123 : tensor<32x1xi32> -> tensor<32x32xi32>
%125 = arith.addi %124, %62 : tensor<32x32xi32>
%126 = tt.expand_dims %101 {axis = 1 : i32} : tensor<32xi32> -> tensor<32x1xi32>
%127 = arith.cmpi slt, %126, %63 : tensor<32x1xi32>
%128 = tt.broadcast %127 : tensor<32x1xi1> -> tensor<32x32xi1>
%129 = arith.andi %128, %65 : tensor<32x32xi1>
%130 = tt.addptr %66, %125 : tensor<32x32x!tt.ptr<f32>>, tensor<32x32xi32>
%131 = tt.load %130, %129, %cst_2 : tensor<32x32x!tt.ptr<f32>>
%132 = "tt.reduce"(%120) <{axis = 1 : i32}> ({
^bb0(%arg21: f32, %arg22: f32):
%147 = arith.maxnumf %arg21, %arg22 : f32
tt.reduce.return %147 : f32
}) : (tensor<16x32xf32>) -> tensor<16xf32>
%133 = arith.maxnumf %132, %arg20 : tensor<16xf32>
%134 = arith.subf %arg20, %133 : tensor<16xf32>
%135 = math.exp %134 : tensor<16xf32>
%136 = tt.expand_dims %133 {axis = 1 : i32} : tensor<16xf32> -> tensor<16x1xf32>
%137 = tt.broadcast %136 : tensor<16x1xf32> -> tensor<16x32xf32>
%138 = arith.subf %120, %137 : tensor<16x32xf32>
%139 = math.exp %138 : tensor<16x32xf32>
%140 = tt.expand_dims %135 {axis = 1 : i32} : tensor<16xf32> -> tensor<16x1xf32>
%141 = tt.broadcast %140 : tensor<16x1xf32> -> tensor<16x32xf32>
%142 = arith.mulf %arg18, %141 : tensor<16x32xf32>
%143 = tt.dot %139, %131, %142 : tensor<16x32xf32> * tensor<32x32xf32> -> tensor<16x32xf32>
%144 = arith.mulf %arg19, %135 : tensor<16xf32>
%145 = "tt.reduce"(%139) <{axis = 1 : i32}> ({
^bb0(%arg21: f32, %arg22: f32):
%147 = arith.addf %arg21, %arg22 : f32
tt.reduce.return %147 : f32
}) : (tensor<16x32xf32>) -> tensor<16xf32>
%146 = arith.addf %144, %145 : tensor<16xf32>
scf.yield %143, %146, %133 : tensor<16x32xf32>, tensor<16xf32>, tensor<16xf32>
}
%68 = arith.muli %0, %arg14 : i32
%69 = tt.splat %arg15 : i32 -> tensor<16x1xi32>
%70 = arith.muli %20, %69 : tensor<16x1xi32>
%71 = tt.splat %68 : i32 -> tensor<16x1xi32>
%72 = arith.addi %71, %70 : tensor<16x1xi32>
%73 = arith.muli %2, %arg16 : i32
%74 = tt.splat %73 : i32 -> tensor<16x1xi32>
%75 = arith.addi %72, %74 : tensor<16x1xi32>
%76 = tt.expand_dims %14 {axis = 0 : i32} : tensor<32xi32> -> tensor<1x32xi32>
%77 = tt.broadcast %75 : tensor<16x1xi32> -> tensor<16x32xi32>
%78 = tt.broadcast %76 : tensor<1x32xi32> -> tensor<16x32xi32>
%79 = arith.addi %77, %78 : tensor<16x32xi32>
%80 = tt.expand_dims %16 {axis = 0 : i32} : tensor<32xi1> -> tensor<1x32xi1>
%81 = tt.broadcast %29 : tensor<16x1xi1> -> tensor<16x32xi1>
%82 = tt.broadcast %80 : tensor<1x32xi1> -> tensor<16x32xi1>
%83 = arith.andi %81, %82 : tensor<16x32xi1>
%84 = tt.splat %arg6 : !tt.ptr<f32> -> tensor<16x32x!tt.ptr<f32>>
%85 = tt.addptr %84, %79 : tensor<16x32x!tt.ptr<f32>>, tensor<16x32xi32>
%86 = tt.expand_dims %67#1 {axis = 1 : i32} : tensor<16xf32> -> tensor<16x1xf32>
%87 = tt.broadcast %86 : tensor<16x1xf32> -> tensor<16x32xf32>
%88 = arith.divf %67#0, %87 : tensor<16x32xf32>
tt.store %85, %88, %83 : tensor<16x32x!tt.ptr<f32>>
%89 = tt.splat %arg15 : i32 -> tensor<16xi32>
%90 = arith.muli %6, %89 : tensor<16xi32>
%91 = tt.splat %68 : i32 -> tensor<16xi32>
%92 = arith.addi %91, %90 : tensor<16xi32>
%93 = tt.splat %73 : i32 -> tensor<16xi32>
%94 = arith.addi %92, %93 : tensor<16xi32>
%95 = arith.addi %94, %cst_1 : tensor<16xi32>
%96 = tt.splat %arg6 : !tt.ptr<f32> -> tensor<16x!tt.ptr<f32>>
%97 = tt.addptr %96, %95 : tensor<16x!tt.ptr<f32>>, tensor<16xi32>
%98 = math.log %67#1 : tensor<16xf32>
%99 = arith.addf %67#2, %98 : tensor<16xf32>
tt.store %97, %99, %12 : tensor<16x!tt.ptr<f32>>
}
tt.return
}
}
%260 = "tt.load"(%259, %253, %25) <{boundaryCheck = array<i32>, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operandSegmentSizes = array<i32: 1, 1, 1>}> {MetaUse} : (tensor<32x!tt.ptr<i32>>, tensor<32xi1>, tensor<32xi32>) -> tensor<32xi32>
encountered addptr operand produced by an unsupported operation
UNREACHABLE executed at /data03/yanxitan/triton_/develop/triton-shared/lib/Analysis/PtrAnalysis.cpp:723!
PLEASE submit a bug report to https://github.com/llvm/llvm-project/issues/ and include the crash backtrace.
Stack dump:
0. Program arguments: /data03/yanxitan/triton_/develop/triton/build/cmake.linux-x86_64-cpython-3.10/third_party/triton_shared/tools/triton-shared-opt/triton-shared-opt ./stage1_tt.mlir --triton-to-linalg --mlir-print-ir-before-all --mlir-print-ir-after-all
Additional information
The right-hand operand of the tt.addptr (%283 in the full dump) ultimately traces back to a tt.load, and PtrAnalysis does not support offsets produced by a load.
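
In other words, the failure reduces to indirect (data-dependent) addressing: the offset fed into tt.addptr is itself the result of a tt.load (here, the page numbers gathered from Req_to_tokens). A hypothetical minimal kernel that I would expect to hit the same PtrAnalysis crash (my own reduction, not taken from the issue):

import triton
import triton.language as tl

@triton.jit
def gather_kernel(Idx, Src, Out, BLOCK: tl.constexpr):
    offs = tl.arange(0, BLOCK)
    idx = tl.load(Idx + offs)   # offsets materialized by a load
    val = tl.load(Src + idx)    # tt.addptr whose offset operand comes from tt.load
    tl.store(Out + offs, val)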