Skip to content

[Bug]: tt.storeOp cannot be rewritten #311

@Dasor

Description

@Dasor

Triton python code

# Reproduction kernel: element-wise select between two inputs, written to output_ptr.
# The two constexpr flags choose which of three code paths is compiled:
#   TEST_SCALAR_POINTERS -> select between the base pointers a_ptr / b_ptr, then load
#   TEST_POINTERS        -> load a and b, convert them with .to(tl.pi32_t), then select
#   neither              -> load a and b, then select element-wise
@triton.jit
    def where_kernel(cond_ptr, a_ptr, b_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr,
                     TEST_POINTERS: tl.constexpr, TEST_SCALAR_POINTERS: tl.constexpr):
        # Indices handled by this program: one BLOCK_SIZE-wide slice chosen by the axis-0 program id.
        offsets = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
        # Lanes whose index is >= n_elements are masked off in every load/store below.
        mask = offsets < n_elements
        # Per-element condition values read from cond_ptr.
        decide = tl.load(cond_ptr + offsets, mask=mask)
        if TEST_SCALAR_POINTERS:
            # A single unmasked load from cond_ptr picks which base pointer to read from;
            # `decide` is not used on this path.
            ptr = tl.where(tl.load(cond_ptr), a_ptr, b_ptr)
            output = tl.load(ptr + offsets, mask=mask)
        else:
            if TEST_POINTERS:
                # NOTE(review): tl.pi32_t is presumably a pointer-to-int32 type, so tl.where
                # below would operate on tensors of pointers -- confirm against the Triton API.
                a = tl.load(a_ptr + offsets, mask=mask).to(tl.pi32_t)
                b = tl.load(b_ptr + offsets, mask=mask).to(tl.pi32_t)
            else:
                a = tl.load(a_ptr + offsets, mask=mask)
                b = tl.load(b_ptr + offsets, mask=mask)
            # Element-wise select driven by the per-element condition.
            output = tl.where(decide, a, b)
        # Masked write of the selected values.
        tl.store(output_ptr + offsets, output, mask=mask)

Triton IR

#loc = loc(unknown)
module {
  tt.func public @where_kernel(%arg0: !tt.ptr<i1> {tt.divisibility = 16 : i32} loc(unknown), %arg1: !tt.ptr<i8> {tt.divisibility = 16 : i32} loc(unknown), %arg2: !tt.ptr<i8> {tt.divisibility = 16 : i32} loc(unknown), %arg3: !tt.ptr<i8> {tt.divisibility = 16 : i32} loc(unknown), %arg4: i32 loc(unknown)) attributes {noinline = false} {
    %cst = arith.constant dense<0> : tensor<1024xi8> loc(#loc)
    %c1024_i32 = arith.constant 1024 : i32 loc(#loc)
    %0 = tt.get_program_id x : i32 loc(#loc)
    %1 = arith.muli %0, %c1024_i32 : i32 loc(#loc)
    %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> loc(#loc)
    %3 = tt.splat %1 : i32 -> tensor<1024xi32> loc(#loc)
    %4 = arith.addi %3, %2 : tensor<1024xi32> loc(#loc)
    %5 = tt.splat %arg4 : i32 -> tensor<1024xi32> loc(#loc)
    %6 = arith.cmpi slt, %4, %5 : tensor<1024xi32> loc(#loc)
    %7 = tt.splat %arg0 : !tt.ptr<i1> -> tensor<1024x!tt.ptr<i1>> loc(#loc)
    %8 = tt.addptr %7, %4 : tensor<1024x!tt.ptr<i1>>, tensor<1024xi32> loc(#loc)
    %9 = tt.bitcast %8 : tensor<1024x!tt.ptr<i1>> -> tensor<1024x!tt.ptr<i8>> loc(#loc)
    %10 = tt.load %9, %6 : tensor<1024x!tt.ptr<i8>> loc(#loc)
    %11 = tt.splat %arg1 : !tt.ptr<i8> -> tensor<1024x!tt.ptr<i8>> loc(#loc)
    %12 = tt.addptr %11, %4 : tensor<1024x!tt.ptr<i8>>, tensor<1024xi32> loc(#loc)
    %13 = tt.load %12, %6 : tensor<1024x!tt.ptr<i8>> loc(#loc)
    %14 = tt.splat %arg2 : !tt.ptr<i8> -> tensor<1024x!tt.ptr<i8>> loc(#loc)
    %15 = tt.addptr %14, %4 : tensor<1024x!tt.ptr<i8>>, tensor<1024xi32> loc(#loc)
    %16 = tt.load %15, %6 : tensor<1024x!tt.ptr<i8>> loc(#loc)
    %17 = arith.cmpi ne, %10, %cst : tensor<1024xi8> loc(#loc)
    %18 = arith.select %17, %13, %16 : tensor<1024xi1>, tensor<1024xi8> loc(#loc)
    %19 = tt.splat %arg3 : !tt.ptr<i8> -> tensor<1024x!tt.ptr<i8>> loc(#loc)
    %20 = tt.addptr %19, %4 : tensor<1024x!tt.ptr<i8>>, tensor<1024xi32> loc(#loc)
    tt.store %20, %18, %6 : tensor<1024x!tt.ptr<i8>> loc(#loc)
    tt.return loc(#loc)
  } loc(#loc)
} loc(#loc)

Crash log

<unknown>:0: remark: PtrAnalysis: pointer is not replace with tts.make_tptr so loadOp cannot be rewritten
<unknown>:0: note: see current operation: %12 = tt.load %11, %7 : tensor<1024x!tt.ptr<i8>>
<unknown>:0: remark: PtrAnalysis: Failed to rewrite LoadOp
<unknown>:0: note: see current operation: %12 = tt.load %11, %7 : tensor<1024x!tt.ptr<i8>>
<unknown>:0: error: unexpected op in ptr sequence
<unknown>:0: note: see current operation: %19 = "tt.bitcast"(%18) : (tensor<1024x!tt.ptr<i1>>) -> tensor<1024x!tt.ptr<i8>>

Additional information

This seems to happen with many different kernels, whenever there is a tensor of `tt.ptr` values. Are there any plans or ideas on how to fix this, @nhat-nguyen?

Metadata

Metadata

Assignees

No one assigned

    Labels

    bug: Something isn't working

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions