[Bug]: The AddPtr right-hand operand may need to support TT_IntTensor input #298

@Realtyxxx

Description

Triton python code

import triton
import triton.language as tl

# `tanh` is assumed to be a @triton.jit helper available in the original
# source (e.g. built from tl.sigmoid); it is not defined in this excerpt.


@triton.jit
def _fwd_grouped_kernel_stage1(
    Q,
    K_Buffer,
    V_Buffer,
    sm_scale,
    Req_to_tokens,
    B_Seqlen,
    Att_Out,
    stride_req_to_tokens_b,
    stride_qbs,
    stride_qh,
    stride_buf_kbs,
    stride_buf_kh,
    stride_buf_vbs,
    stride_buf_vh,
    stride_mid_ob,
    stride_mid_oh,
    stride_mid_os,
    kv_group_num: tl.constexpr,
    q_head_num: tl.constexpr,
    BLOCK_DMODEL: tl.constexpr,
    BLOCK_DPE: tl.constexpr,
    BLOCK_DV: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_H: tl.constexpr,
    NUM_KV_SPLITS: tl.constexpr,
    PAGE_SIZE: tl.constexpr,
    logit_cap: tl.constexpr,
    Lk: tl.constexpr,
    Lv: tl.constexpr,
):
    cur_batch = tl.program_id(0)
    cur_head_id = tl.program_id(1)
    cur_kv_head = cur_head_id // tl.cdiv(kv_group_num, BLOCK_H)
    split_kv_id = tl.program_id(2)

    if kv_group_num > BLOCK_H:
        VALID_BLOCK_H: tl.constexpr = BLOCK_H
    else:
        VALID_BLOCK_H: tl.constexpr = kv_group_num
    cur_head = cur_head_id * VALID_BLOCK_H + tl.arange(0, BLOCK_H)
    mask_h = cur_head < (cur_head_id + 1) * VALID_BLOCK_H
    mask_h = mask_h & (cur_head < q_head_num)

    offs_d = tl.arange(0, BLOCK_DMODEL)
    offs_dv = tl.arange(0, BLOCK_DV)
    mask_d = offs_d < Lk
    mask_dv = offs_dv < Lv
    cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
    cur_batch_req_idx = cur_batch

    offs_q = cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_d[
        None, :]
    q = tl.load(Q + offs_q,
                mask=(mask_h[:, None]) & (mask_d[None, :]),
                other=0.0)

    if BLOCK_DPE > 0:
        offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE)
        mask_dpe = offs_dpe < Lk
        off_qpe = (cur_batch * stride_qbs + cur_head[:, None] * stride_qh +
                   offs_dpe[None, :])
        qpe = tl.load(Q + off_qpe,
                      mask=(mask_h[:, None]) & (mask_dpe[None, :]),
                      other=0.0)

    kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS)
    split_kv_start = kv_len_per_split * split_kv_id
    split_kv_end = tl.minimum(split_kv_start + kv_len_per_split,
                              cur_batch_seq_len)

    e_max = tl.zeros([BLOCK_H], dtype=tl.float32) - float("inf")
    e_sum = tl.zeros([BLOCK_H], dtype=tl.float32)
    acc = tl.zeros([BLOCK_H, BLOCK_DV], dtype=tl.float32)

    if split_kv_end > split_kv_start:
        for start_n in range(split_kv_start, split_kv_end, BLOCK_N):
            offs_n = start_n + tl.arange(0, BLOCK_N)
            kv_page_number = tl.load(
                Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx +
                offs_n // PAGE_SIZE,
                mask=offs_n < split_kv_end,
                other=0,
            )
            kv_loc = kv_page_number * PAGE_SIZE + offs_n % PAGE_SIZE
            offs_buf_k = (kv_loc[None, :] * stride_buf_kbs +
                          cur_kv_head * stride_buf_kh + offs_d[:, None])
            k = tl.load(
                K_Buffer + offs_buf_k,
                mask=(offs_n[None, :] < split_kv_end) & (mask_d[:, None]),
                other=0.0,
            )
            qk = tl.dot(q, k.to(q.dtype))
            if BLOCK_DPE > 0:
                offs_buf_kpe = (kv_loc[None, :] * stride_buf_kbs +
                                cur_kv_head * stride_buf_kh +
                                offs_dpe[:, None])
                kpe = tl.load(
                    K_Buffer + offs_buf_kpe,
                    mask=(offs_n[None, :] < split_kv_end) &
                    (mask_dpe[:, None]),
                    other=0.0,
                )
                qk += tl.dot(qpe, kpe.to(qpe.dtype))
            qk *= sm_scale

            if logit_cap > 0:
                qk = logit_cap * tanh(qk / logit_cap)

            qk = tl.where(mask_h[:, None] & (offs_n[None, :] < split_kv_end),
                          qk, float("-inf"))

            offs_buf_v = (kv_loc[:, None] * stride_buf_vbs +
                          cur_kv_head * stride_buf_vh + offs_dv[None, :])
            v = tl.load(
                V_Buffer + offs_buf_v,
                mask=(offs_n[:, None] < split_kv_end) & (mask_dv[None, :]),
                other=0.0,
            )

            n_e_max = tl.maximum(tl.max(qk, 1), e_max)
            re_scale = tl.exp(e_max - n_e_max)
            p = tl.exp(qk - n_e_max[:, None])
            acc *= re_scale[:, None]
            acc += tl.dot(p.to(v.dtype), v)

            e_sum = e_sum * re_scale + tl.sum(p, 1)
            e_max = n_e_max

        offs_mid_o = (cur_batch * stride_mid_ob +
                      cur_head[:, None] * stride_mid_oh +
                      split_kv_id * stride_mid_os + offs_dv[None, :])

        tl.store(
            Att_Out + offs_mid_o,
            acc / e_sum[:, None],
            mask=(mask_h[:, None]) & (mask_dv[None, :]),
        )

        offs_mid_o_1 = (cur_batch * stride_mid_ob + cur_head * stride_mid_oh +
                        split_kv_id * stride_mid_os + Lv)

        tl.store(
            Att_Out + offs_mid_o_1,
            e_max + tl.log(e_sum),
            mask=mask_h,
        )
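
The access pattern that trips the lowering can be reduced to a much smaller kernel: an indirect gather whose pointer offset is itself produced by a tl.load. A minimal sketch (kernel and argument names here are hypothetical, not taken from the reproducer above):

import triton
import triton.language as tl


@triton.jit
def _gather_kernel(Index, Src, Dst, n_elements, BLOCK: tl.constexpr):
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n_elements
    # Indices loaded from memory ...
    idx = tl.load(Index + offs, mask=mask, other=0)
    # ... are then used as the pointer offset: the addptr's right-hand side
    # is a tensor produced by tt.load, which is the case PtrAnalysis rejects.
    val = tl.load(Src + idx, mask=mask, other=0.0)
    tl.store(Dst + offs, val, mask=mask)

Any kernel of this shape is expected to hit the same PtrAnalysis path, since the addptr offset (idx) has no statically analyzable structure.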

Triton IR

// -----// IR Dump Before TritonToLinalg (triton-to-linalg) //----- //
module {
  tt.func public @_fwd_grouped_kernel_stage1(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: f32, %arg4: !tt.ptr<i32> {tt.divisibility = 16 : i32}, %arg5: !tt.ptr<i32> {tt.divisibility = 16 : i32}, %arg6: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}, %arg9: i32 {tt.divisibility = 16 : i32}, %arg10: i32 {tt.divisibility = 16 : i32}, %arg11: i32 {tt.divisibility = 16 : i32}, %arg12: i32 {tt.divisibility = 16 : i32}, %arg13: i32 {tt.divisibility = 16 : i32}, %arg14: i32 {tt.divisibility = 16 : i32}, %arg15: i32, %arg16: i32) attributes {noinline = false} {
    %cst = arith.constant dense<0xFF800000> : tensor<16xf32>
    %cst_0 = arith.constant dense<0.000000e+00> : tensor<16xf32>
    %c3_i32 = arith.constant 3 : i32
    %c4_i32 = arith.constant 4 : i32
    %c32_i32 = arith.constant 32 : i32
    %cst_1 = arith.constant dense<32> : tensor<16xi32>
    %cst_2 = arith.constant dense<0.000000e+00> : tensor<32x32xf32>
    %cst_3 = arith.constant dense<0xFF800000> : tensor<16x32xf32>
    %cst_4 = arith.constant dense<0.000000e+00> : tensor<16x32xf32>
    %cst_5 = arith.constant dense<0.000000e+00> : tensor<64x32xf32>
    %cst_6 = arith.constant dense<0> : tensor<32xi32>
    %cst_7 = arith.constant dense<0.000000e+00> : tensor<16x64xf32>
    %cst_8 = arith.constant dense<32> : tensor<32xi32>
    %cst_9 = arith.constant dense<64> : tensor<64xi32>
    %cst_10 = arith.constant dense<8> : tensor<16xi32>
    %c1_i32 = arith.constant 1 : i32
    %c8_i32 = arith.constant 8 : i32
    %0 = tt.get_program_id x : i32
    %1 = tt.get_program_id y : i32
    %2 = tt.get_program_id z : i32
    %3 = arith.muli %1, %c8_i32 : i32
    %4 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32>
    %5 = tt.splat %3 : i32 -> tensor<16xi32>
    %6 = arith.addi %5, %4 : tensor<16xi32>
    %7 = arith.addi %1, %c1_i32 : i32
    %8 = arith.muli %7, %c8_i32 : i32
    %9 = tt.splat %8 : i32 -> tensor<16xi32>
    %10 = arith.cmpi slt, %6, %9 : tensor<16xi32>
    %11 = arith.cmpi slt, %6, %cst_10 : tensor<16xi32>
    %12 = arith.andi %10, %11 : tensor<16xi1>
    %13 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32>
    %14 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32>
    %15 = arith.cmpi slt, %13, %cst_9 : tensor<64xi32>
    %16 = arith.cmpi slt, %14, %cst_8 : tensor<32xi32>
    %17 = tt.addptr %arg5, %0 : !tt.ptr<i32>, i32
    %18 = tt.load %17 : !tt.ptr<i32>
    %19 = arith.muli %0, %arg8 : i32
    %20 = tt.expand_dims %6 {axis = 1 : i32} : tensor<16xi32> -> tensor<16x1xi32>
    %21 = tt.splat %arg9 : i32 -> tensor<16x1xi32>
    %22 = arith.muli %20, %21 : tensor<16x1xi32>
    %23 = tt.splat %19 : i32 -> tensor<16x1xi32>
    %24 = arith.addi %23, %22 : tensor<16x1xi32>
    %25 = tt.expand_dims %13 {axis = 0 : i32} : tensor<64xi32> -> tensor<1x64xi32>
    %26 = tt.broadcast %24 : tensor<16x1xi32> -> tensor<16x64xi32>
    %27 = tt.broadcast %25 : tensor<1x64xi32> -> tensor<16x64xi32>
    %28 = arith.addi %26, %27 : tensor<16x64xi32>
    %29 = tt.expand_dims %12 {axis = 1 : i32} : tensor<16xi1> -> tensor<16x1xi1>
    %30 = tt.expand_dims %15 {axis = 0 : i32} : tensor<64xi1> -> tensor<1x64xi1>
    %31 = tt.broadcast %29 : tensor<16x1xi1> -> tensor<16x64xi1>
    %32 = tt.broadcast %30 : tensor<1x64xi1> -> tensor<16x64xi1>
    %33 = arith.andi %31, %32 : tensor<16x64xi1>
    %34 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<16x64x!tt.ptr<f32>>
    %35 = tt.addptr %34, %28 : tensor<16x64x!tt.ptr<f32>>, tensor<16x64xi32>
    %36 = tt.load %35, %33, %cst_7 : tensor<16x64x!tt.ptr<f32>>
    %37 = arith.addi %18, %c3_i32 : i32
    %38 = arith.divsi %37, %c4_i32 : i32
    %39 = arith.muli %38, %2 : i32
    %40 = arith.addi %39, %38 : i32
    %41 = arith.minsi %40, %18 : i32
    %42 = arith.cmpi sgt, %41, %39 : i32
    scf.if %42 {
      %43 = tt.splat %41 : i32 -> tensor<32xi32>
      %44 = arith.muli %arg7, %0 : i32
      %45 = tt.addptr %arg4, %44 : !tt.ptr<i32>, i32
      %46 = tt.splat %45 : !tt.ptr<i32> -> tensor<32x!tt.ptr<i32>>
      %47 = tt.splat %arg10 : i32 -> tensor<1x32xi32>
      %48 = arith.muli %1, %arg11 : i32
      %49 = tt.splat %48 : i32 -> tensor<1x32xi32>
      %50 = tt.expand_dims %13 {axis = 1 : i32} : tensor<64xi32> -> tensor<64x1xi32>
      %51 = tt.broadcast %50 : tensor<64x1xi32> -> tensor<64x32xi32>
      %52 = tt.splat %41 : i32 -> tensor<1x32xi32>
      %53 = tt.expand_dims %15 {axis = 1 : i32} : tensor<64xi1> -> tensor<64x1xi1>
      %54 = tt.broadcast %53 : tensor<64x1xi1> -> tensor<64x32xi1>
      %55 = tt.splat %arg1 : !tt.ptr<f32> -> tensor<64x32x!tt.ptr<f32>>
      %56 = tt.splat %arg3 : f32 -> tensor<16x32xf32>
      %57 = tt.broadcast %29 : tensor<16x1xi1> -> tensor<16x32xi1>
      %58 = tt.splat %arg12 : i32 -> tensor<32x1xi32>
      %59 = arith.muli %1, %arg13 : i32
      %60 = tt.splat %59 : i32 -> tensor<32x1xi32>
      %61 = tt.expand_dims %14 {axis = 0 : i32} : tensor<32xi32> -> tensor<1x32xi32>
      %62 = tt.broadcast %61 : tensor<1x32xi32> -> tensor<32x32xi32>
      %63 = tt.splat %41 : i32 -> tensor<32x1xi32>
      %64 = tt.expand_dims %16 {axis = 0 : i32} : tensor<32xi1> -> tensor<1x32xi1>
      %65 = tt.broadcast %64 : tensor<1x32xi1> -> tensor<32x32xi1>
      %66 = tt.splat %arg2 : !tt.ptr<f32> -> tensor<32x32x!tt.ptr<f32>>
      %67:3 = scf.for %arg17 = %39 to %41 step %c32_i32 iter_args(%arg18 = %cst_4, %arg19 = %cst_0, %arg20 = %cst) -> (tensor<16x32xf32>, tensor<16xf32>, tensor<16xf32>)  : i32 {
        %100 = tt.splat %arg17 : i32 -> tensor<32xi32>
        %101 = arith.addi %100, %14 : tensor<32xi32>
        %102 = arith.cmpi slt, %101, %43 : tensor<32xi32>
        %103 = tt.addptr %46, %101 : tensor<32x!tt.ptr<i32>>, tensor<32xi32>
        %104 = tt.load %103, %102, %cst_6 : tensor<32x!tt.ptr<i32>>
        %105 = tt.expand_dims %104 {axis = 0 : i32} : tensor<32xi32> -> tensor<1x32xi32>
        %106 = arith.muli %105, %47 : tensor<1x32xi32>
        %107 = arith.addi %106, %49 : tensor<1x32xi32>
        %108 = tt.broadcast %107 : tensor<1x32xi32> -> tensor<64x32xi32>
        %109 = arith.addi %108, %51 : tensor<64x32xi32>
        %110 = tt.expand_dims %101 {axis = 0 : i32} : tensor<32xi32> -> tensor<1x32xi32>
        %111 = arith.cmpi slt, %110, %52 : tensor<1x32xi32>
        %112 = tt.broadcast %111 : tensor<1x32xi1> -> tensor<64x32xi1>
        %113 = arith.andi %112, %54 : tensor<64x32xi1>
        %114 = tt.addptr %55, %109 : tensor<64x32x!tt.ptr<f32>>, tensor<64x32xi32>
        %115 = tt.load %114, %113, %cst_5 : tensor<64x32x!tt.ptr<f32>>
        %116 = tt.dot %36, %115, %cst_4 : tensor<16x64xf32> * tensor<64x32xf32> -> tensor<16x32xf32>
        %117 = arith.mulf %116, %56 : tensor<16x32xf32>
        %118 = tt.broadcast %111 : tensor<1x32xi1> -> tensor<16x32xi1>
        %119 = arith.andi %57, %118 : tensor<16x32xi1>
        %120 = arith.select %119, %117, %cst_3 : tensor<16x32xi1>, tensor<16x32xf32>
        %121 = tt.expand_dims %104 {axis = 1 : i32} : tensor<32xi32> -> tensor<32x1xi32>
        %122 = arith.muli %121, %58 : tensor<32x1xi32>
        %123 = arith.addi %122, %60 : tensor<32x1xi32>
        %124 = tt.broadcast %123 : tensor<32x1xi32> -> tensor<32x32xi32>
        %125 = arith.addi %124, %62 : tensor<32x32xi32>
        %126 = tt.expand_dims %101 {axis = 1 : i32} : tensor<32xi32> -> tensor<32x1xi32>
        %127 = arith.cmpi slt, %126, %63 : tensor<32x1xi32>
        %128 = tt.broadcast %127 : tensor<32x1xi1> -> tensor<32x32xi1>
        %129 = arith.andi %128, %65 : tensor<32x32xi1>
        %130 = tt.addptr %66, %125 : tensor<32x32x!tt.ptr<f32>>, tensor<32x32xi32>
        %131 = tt.load %130, %129, %cst_2 : tensor<32x32x!tt.ptr<f32>>
        %132 = "tt.reduce"(%120) <{axis = 1 : i32}> ({
        ^bb0(%arg21: f32, %arg22: f32):
          %147 = arith.maxnumf %arg21, %arg22 : f32
          tt.reduce.return %147 : f32
        }) : (tensor<16x32xf32>) -> tensor<16xf32>
        %133 = arith.maxnumf %132, %arg20 : tensor<16xf32>
        %134 = arith.subf %arg20, %133 : tensor<16xf32>
        %135 = math.exp %134 : tensor<16xf32>
        %136 = tt.expand_dims %133 {axis = 1 : i32} : tensor<16xf32> -> tensor<16x1xf32>
        %137 = tt.broadcast %136 : tensor<16x1xf32> -> tensor<16x32xf32>
        %138 = arith.subf %120, %137 : tensor<16x32xf32>
        %139 = math.exp %138 : tensor<16x32xf32>
        %140 = tt.expand_dims %135 {axis = 1 : i32} : tensor<16xf32> -> tensor<16x1xf32>
        %141 = tt.broadcast %140 : tensor<16x1xf32> -> tensor<16x32xf32>
        %142 = arith.mulf %arg18, %141 : tensor<16x32xf32>
        %143 = tt.dot %139, %131, %142 : tensor<16x32xf32> * tensor<32x32xf32> -> tensor<16x32xf32>
        %144 = arith.mulf %arg19, %135 : tensor<16xf32>
        %145 = "tt.reduce"(%139) <{axis = 1 : i32}> ({
        ^bb0(%arg21: f32, %arg22: f32):
          %147 = arith.addf %arg21, %arg22 : f32
          tt.reduce.return %147 : f32
        }) : (tensor<16x32xf32>) -> tensor<16xf32>
        %146 = arith.addf %144, %145 : tensor<16xf32>
        scf.yield %143, %146, %133 : tensor<16x32xf32>, tensor<16xf32>, tensor<16xf32>
      }
      %68 = arith.muli %0, %arg14 : i32
      %69 = tt.splat %arg15 : i32 -> tensor<16x1xi32>
      %70 = arith.muli %20, %69 : tensor<16x1xi32>
      %71 = tt.splat %68 : i32 -> tensor<16x1xi32>
      %72 = arith.addi %71, %70 : tensor<16x1xi32>
      %73 = arith.muli %2, %arg16 : i32
      %74 = tt.splat %73 : i32 -> tensor<16x1xi32>
      %75 = arith.addi %72, %74 : tensor<16x1xi32>
      %76 = tt.expand_dims %14 {axis = 0 : i32} : tensor<32xi32> -> tensor<1x32xi32>
      %77 = tt.broadcast %75 : tensor<16x1xi32> -> tensor<16x32xi32>
      %78 = tt.broadcast %76 : tensor<1x32xi32> -> tensor<16x32xi32>
      %79 = arith.addi %77, %78 : tensor<16x32xi32>
      %80 = tt.expand_dims %16 {axis = 0 : i32} : tensor<32xi1> -> tensor<1x32xi1>
      %81 = tt.broadcast %29 : tensor<16x1xi1> -> tensor<16x32xi1>
      %82 = tt.broadcast %80 : tensor<1x32xi1> -> tensor<16x32xi1>
      %83 = arith.andi %81, %82 : tensor<16x32xi1>
      %84 = tt.splat %arg6 : !tt.ptr<f32> -> tensor<16x32x!tt.ptr<f32>>
      %85 = tt.addptr %84, %79 : tensor<16x32x!tt.ptr<f32>>, tensor<16x32xi32>
      %86 = tt.expand_dims %67#1 {axis = 1 : i32} : tensor<16xf32> -> tensor<16x1xf32>
      %87 = tt.broadcast %86 : tensor<16x1xf32> -> tensor<16x32xf32>
      %88 = arith.divf %67#0, %87 : tensor<16x32xf32>
      tt.store %85, %88, %83 : tensor<16x32x!tt.ptr<f32>>
      %89 = tt.splat %arg15 : i32 -> tensor<16xi32>
      %90 = arith.muli %6, %89 : tensor<16xi32>
      %91 = tt.splat %68 : i32 -> tensor<16xi32>
      %92 = arith.addi %91, %90 : tensor<16xi32>
      %93 = tt.splat %73 : i32 -> tensor<16xi32>
      %94 = arith.addi %92, %93 : tensor<16xi32>
      %95 = arith.addi %94, %cst_1 : tensor<16xi32>
      %96 = tt.splat %arg6 : !tt.ptr<f32> -> tensor<16x!tt.ptr<f32>>
      %97 = tt.addptr %96, %95 : tensor<16x!tt.ptr<f32>>, tensor<16xi32>
      %98 = math.log %67#1 : tensor<16xf32>
      %99 = arith.addf %67#2, %98 : tensor<16xf32>
      tt.store %97, %99, %12 : tensor<16x!tt.ptr<f32>>
    }
    tt.return
  }
}


Crash log

%260 = "tt.load"(%259, %253, %25) <{boundaryCheck = array<i32>, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operandSegmentSizes = array<i32: 1, 1, 1>}> {MetaUse} : (tensor<32x!tt.ptr<i32>>, tensor<32xi1>, tensor<32xi32>) -> tensor<32xi32>
encountered addptr operand produced by an unsupported operation
UNREACHABLE executed at /data03/yanxitan/triton_/develop/triton-shared/lib/Analysis/PtrAnalysis.cpp:723!
PLEASE submit a bug report to https://github.com/llvm/llvm-project/issues/ and include the crash backtrace.
Stack dump:
0.      Program arguments: /data03/yanxitan/triton_/develop/triton/build/cmake.linux-x86_64-cpython-3.10/third_party/triton_shared/tools/triton-shared-opt/triton-shared-opt ./stage1_tt.mlir --triton-to-linalg --mlir-print-ir-before-all --mlir-print-ir-after-all

Additional information

The right-hand operand of the addptr (%283) ultimately traces back to a tt.load, and PtrAnalysis does not support that case.
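
Concretely, in the kernel above the offset chain starts at the Req_to_tokens page-table lookup and ends at the K_Buffer / V_Buffer loads. Annotated excerpt (comments added for illustration, mask arguments abbreviated):

kv_page_number = tl.load(
    Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + offs_n // PAGE_SIZE,
    mask=offs_n < split_kv_end,
    other=0,
)  # data-dependent page indices produced by tt.load
kv_loc = kv_page_number * PAGE_SIZE + offs_n % PAGE_SIZE
offs_buf_k = (kv_loc[None, :] * stride_buf_kbs +
              cur_kv_head * stride_buf_kh + offs_d[:, None])
k = tl.load(K_Buffer + offs_buf_k, ...)  # this addptr's offset traces back to the tt.load above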
