-
Notifications
You must be signed in to change notification settings - Fork 73
Open
Labels
bug: Something isn't working
Description
Triton python code
@triton.jit
def where_kernel(cond_ptr, a_ptr, b_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr,
                 TEST_POINTERS: tl.constexpr, TEST_SCALAR_POINTERS: tl.constexpr):
    """Exercise ``tl.where`` on values or pointers and store the selected result.

    Args:
        cond_ptr: Pointer to the condition data.
        a_ptr: Pointer to the data chosen where the condition is true.
        b_ptr: Pointer to the data chosen where the condition is false.
        output_ptr: Pointer the selected values are stored to.
        n_elements: Number of valid elements; offsets at or beyond it are masked out.
        BLOCK_SIZE: Compile-time number of elements handled per program instance.
        TEST_POINTERS: Compile-time flag; when set, the loaded ``a``/``b`` values
            are cast to ``tl.pi32_t`` before the element-wise select.
        TEST_SCALAR_POINTERS: Compile-time flag; when set, a single condition
            loaded from ``cond_ptr`` selects between the base pointers
            ``a_ptr`` and ``b_ptr`` instead of selecting element-wise.
    """
    # Offsets covered by this program instance, with a mask for the tail block.
    offsets = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    decide = tl.load(cond_ptr + offsets, mask=mask)
    if TEST_SCALAR_POINTERS:
        # Scalar select on the pointers themselves, then one load from the winner.
        ptr = tl.where(tl.load(cond_ptr), a_ptr, b_ptr)
        output = tl.load(ptr + offsets, mask=mask)
    else:
        if TEST_POINTERS:
            # Reinterpret the loaded values as tl.pi32_t so the select below
            # operates on that type rather than on the raw loaded dtype.
            a = tl.load(a_ptr + offsets, mask=mask).to(tl.pi32_t)
            b = tl.load(b_ptr + offsets, mask=mask).to(tl.pi32_t)
        else:
            a = tl.load(a_ptr + offsets, mask=mask)
            b = tl.load(b_ptr + offsets, mask=mask)
        # Element-wise select driven by the per-element condition.
        output = tl.where(decide, a, b)
    tl.store(output_ptr + offsets, output, mask=mask)
Triton IR
#loc = loc(unknown)
module {
tt.func public @where_kernel(%arg0: !tt.ptr<i1> {tt.divisibility = 16 : i32} loc(unknown), %arg1: !tt.ptr<i8> {tt.divisibility = 16 : i32} loc(unknown), %arg2: !tt.ptr<i8> {tt.divisibility = 16 : i32} loc(unknown), %arg3: !tt.ptr<i8> {tt.divisibility = 16 : i32} loc(unknown), %arg4: i32 loc(unknown)) attributes {noinline = false} {
%cst = arith.constant dense<0> : tensor<1024xi8> loc(#loc)
%c1024_i32 = arith.constant 1024 : i32 loc(#loc)
%0 = tt.get_program_id x : i32 loc(#loc)
%1 = arith.muli %0, %c1024_i32 : i32 loc(#loc)
%2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> loc(#loc)
%3 = tt.splat %1 : i32 -> tensor<1024xi32> loc(#loc)
%4 = arith.addi %3, %2 : tensor<1024xi32> loc(#loc)
%5 = tt.splat %arg4 : i32 -> tensor<1024xi32> loc(#loc)
%6 = arith.cmpi slt, %4, %5 : tensor<1024xi32> loc(#loc)
%7 = tt.splat %arg0 : !tt.ptr<i1> -> tensor<1024x!tt.ptr<i1>> loc(#loc)
%8 = tt.addptr %7, %4 : tensor<1024x!tt.ptr<i1>>, tensor<1024xi32> loc(#loc)
%9 = tt.bitcast %8 : tensor<1024x!tt.ptr<i1>> -> tensor<1024x!tt.ptr<i8>> loc(#loc)
%10 = tt.load %9, %6 : tensor<1024x!tt.ptr<i8>> loc(#loc)
%11 = tt.splat %arg1 : !tt.ptr<i8> -> tensor<1024x!tt.ptr<i8>> loc(#loc)
%12 = tt.addptr %11, %4 : tensor<1024x!tt.ptr<i8>>, tensor<1024xi32> loc(#loc)
%13 = tt.load %12, %6 : tensor<1024x!tt.ptr<i8>> loc(#loc)
%14 = tt.splat %arg2 : !tt.ptr<i8> -> tensor<1024x!tt.ptr<i8>> loc(#loc)
%15 = tt.addptr %14, %4 : tensor<1024x!tt.ptr<i8>>, tensor<1024xi32> loc(#loc)
%16 = tt.load %15, %6 : tensor<1024x!tt.ptr<i8>> loc(#loc)
%17 = arith.cmpi ne, %10, %cst : tensor<1024xi8> loc(#loc)
%18 = arith.select %17, %13, %16 : tensor<1024xi1>, tensor<1024xi8> loc(#loc)
%19 = tt.splat %arg3 : !tt.ptr<i8> -> tensor<1024x!tt.ptr<i8>> loc(#loc)
%20 = tt.addptr %19, %4 : tensor<1024x!tt.ptr<i8>>, tensor<1024xi32> loc(#loc)
tt.store %20, %18, %6 : tensor<1024x!tt.ptr<i8>> loc(#loc)
tt.return loc(#loc)
} loc(#loc)
} loc(#loc)
Crash log
<unknown>:0: remark: PtrAnalysis: pointer is not replace with tts.make_tptr so loadOp cannot be rewritten
<unknown>:0: note: see current operation: %12 = tt.load %11, %7 : tensor<1024x!tt.ptr<i8>>
<unknown>:0: remark: PtrAnalysis: Failed to rewrite LoadOp
<unknown>:0: note: see current operation: %12 = tt.load %11, %7 : tensor<1024x!tt.ptr<i8>>
<unknown>:0: error: unexpected op in ptr sequence
<unknown>:0: note: see current operation: %19 = "tt.bitcast"(%18) : (tensor<1024x!tt.ptr<i1>>) -> tensor<1024x!tt.ptr<i8>>
Additional information
This seems to happen in many different kernels, whenever a tensor of `tt.ptr` elements is involved. Are there any plans or ideas on how to fix this, @nhat-nguyen?
Metadata
Metadata
Assignees
Labels
bug: Something isn't working