Skip to content

Unit test failure #372

@torshie

Description

@torshie

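The test case `test_hstu_attention_tma.HSTUAttentionTmaTest.test_attn_triton_tma` always seems to fail in my environment. Compiling `_hstu_attn_fwd` aborts with an LLVM `SmallVector` index assertion (`idx < size()`) while Triton runs the `TritonGPUCoalesce` pass. Environment and full output below: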

fbgemm_gpu 1.3.0+cu128
torch 2.8.0+cu128
torchrec 1.3.0+cu128
triton 3.4.0
Python 3.12.3
OS: Ubuntu 24.04
GPU target: cuda:89 (compute capability 8.9, as shown in the MLIR reproducer pipeline below)

python3: /root/.triton/llvm/llvm-8957e64a-almalinux-x64/include/llvm/ADT/SmallVector.h:296: const_reference llvm::SmallVectorTemplateCommon<long>::operator[](size_type) const [T = long]: Assertion `idx < size()' failed.
module {
  tt.func public @_hstu_attn_fwd(%arg0: !tt.ptr<bf16>, %arg1: i64, %arg2: i64, %arg3: i64, %arg4: i64, %arg5: i32, %arg6: i32, %arg7: i64, %arg8: i64, %arg9: !tt.ptr<bf16>, %arg10: i64, %arg11: i64, %arg12: i64, %arg13: i64, %arg14: i32, %arg15: i32 {tt.divisibility = 16 : i32}, %arg16: i64 {tt.divisibility = 16 : i32}, %arg17: i64 {tt.divisibility = 16 : i32}, %arg18: !tt.ptr<bf16> {tt.divisibility = 16 : i32}, %arg19: i64 {tt.divisibility = 16 : i32}, %arg20: i64 {tt.divisibility = 16 : i32}, %arg21: i64 {tt.divisibility = 16 : i32}, %arg22: i64 {tt.divisibility = 16 : i32}, %arg23: i32 {tt.divisibility = 16 : i32}, %arg24: i32 {tt.divisibility = 16 : i32}, %arg25: i64 {tt.divisibility = 16 : i32}, %arg26: i64, %arg27: !tt.ptr<i64>, %arg28: !tt.ptr<i64>, %arg29: !tt.ptr<bf16>, %arg30: i32 {tt.divisibility = 16 : i32}, %arg31: i32 {tt.divisibility = 16 : i32}, %arg32: i32 {tt.divisibility = 16 : i32}, %arg33: i32, %arg34: i32, %arg35: i32, %arg36: i32, %arg37: i32, %arg38: f32, %arg39: i32, %arg40: i32, %arg41: i32, %arg42: i32, %arg43: i32, %arg44: i32, %arg45: i32, %arg46: i32) attributes {noinline = false} {
    %cst = arith.constant dense<0> : tensor<32x1xi64>
    %cst_0 = arith.constant dense<0> : tensor<1x16xi64>
    %cst_1 = arith.constant dense<0> : tensor<16x1xi64>
    %cst_2 = arith.constant dense<0.000000e+00> : tensor<32x16xbf16>
    %cst_3 = arith.constant dense<0.000000e+00> : tensor<16x16xbf16>
    %c1_i32 = arith.constant 1 : i32
    %c16_i32 = arith.constant 16 : i32
    %c32_i32 = arith.constant 32 : i32
    %c31_i32 = arith.constant 31 : i32
    %cst_4 = arith.constant dense<0.000000e+00> : tensor<16x16xf32>
    %cst_5 = arith.constant dense<1.000000e+00> : tensor<16x32xf32>
    %cst_6 = arith.constant 1.000000e+00 : f32
    %cst_7 = arith.constant dense<0> : tensor<16x1xi32>
    %cst_8 = arith.constant dense<0> : tensor<16x32xi32>
    %cst_9 = arith.constant dense<0> : tensor<32xi32>
    %cst_10 = arith.constant dense<1> : tensor<32xi32>
    %cst_11 = arith.constant dense<0> : tensor<16xi32>
    %cst_12 = arith.constant dense<1> : tensor<16xi32>
    %cst_13 = arith.constant dense<0.000000e+00> : tensor<16x32xf32>
    %c0_i32 = arith.constant 0 : i32
    %0 = tt.get_program_id y : i32
    %1 = tt.get_program_id x : i32
    %2 = tt.addptr %arg27, %0 : !tt.ptr<i64>, i32
    %3 = tt.load %2 : !tt.ptr<i64>
    %4 = arith.extsi %0 : i32 to i64
    %5 = tt.addptr %arg27, %4 : !tt.ptr<i64>, i64
    %6 = tt.addptr %5, %c1_i32 : !tt.ptr<i64>, i32
    %7 = tt.load %6 : !tt.ptr<i64>
    %8 = arith.subi %7, %3 : i64
    %9 = arith.trunci %8 : i64 to i32
    %10 = arith.muli %1, %c16_i32 : i32
    %11 = arith.cmpi slt, %10, %9 : i32
    scf.if %11 {
      %12 = tt.addptr %arg28, %4 : !tt.ptr<i64>, i64
      %13 = tt.load %12 : !tt.ptr<i64>
      %14 = arith.trunci %13 : i64 to i32
      %15 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32>
      %16 = tt.splat %10 : i32 -> tensor<16xi32>
      %17 = arith.addi %16, %15 : tensor<16xi32>
      %18 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32>
      %19 = arith.extsi %10 : i32 to i64
      %20 = arith.addi %3, %19 : i64
      %21 = arith.trunci %20 : i64 to i32
      %22 = arith.extsi %21 : i32 to i64
      %23 = tt.splat %22 : i64 -> tensor<16xi64>
      %24 = arith.extsi %15 : tensor<16xi32> to tensor<16xi64>
      %25 = arith.addi %23, %24 : tensor<16xi64>
      %26 = tt.expand_dims %25 {axis = 1 : i32} : tensor<16xi64> -> tensor<16x1xi64>
      %27 = tt.expand_dims %24 {axis = 0 : i32} : tensor<16xi64> -> tensor<1x16xi64>
      %28 = tt.splat %arg0 : !tt.ptr<bf16> -> tensor<16x16x!tt.ptr<bf16>>
      %29 = tt.splat %arg3 : i64 -> tensor<16x1xi64>
      %30 = arith.muli %26, %29 : tensor<16x1xi64>
      %31 = tt.broadcast %30 : tensor<16x1xi64> -> tensor<16x16xi64>
      %32 = tt.splat %arg4 : i64 -> tensor<1x16xi64>
      %33 = arith.muli %27, %32 : tensor<1x16xi64>
      %34 = tt.broadcast %33 : tensor<1x16xi64> -> tensor<16x16xi64>
      %35 = arith.addi %31, %34 : tensor<16x16xi64>
      %36 = tt.addptr %28, %35 : tensor<16x16x!tt.ptr<bf16>>, tensor<16x16xi64>
      %37 = arith.cmpi sge, %26, %cst_1 : tensor<16x1xi64>
      %38 = tt.splat %arg1 : i64 -> tensor<16x1xi64>
      %39 = arith.cmpi slt, %26, %38 : tensor<16x1xi64>
      %40 = arith.andi %37, %39 : tensor<16x1xi1>
      %41 = tt.broadcast %40 : tensor<16x1xi1> -> tensor<16x16xi1>
      %42 = arith.cmpi sge, %27, %cst_0 : tensor<1x16xi64>
      %43 = tt.splat %arg2 : i64 -> tensor<1x16xi64>
      %44 = arith.cmpi slt, %27, %43 : tensor<1x16xi64>
      %45 = arith.andi %42, %44 : tensor<1x16xi1>
      %46 = tt.broadcast %45 : tensor<1x16xi1> -> tensor<16x16xi1>
      %47 = arith.andi %41, %46 : tensor<16x16xi1>
      %48 = tt.load %36, %47, %cst_3 : tensor<16x16x!tt.ptr<bf16>>
      %49 = arith.subi %9, %14 : i32
      %50 = arith.cmpi slt, %10, %arg45 : i32
      %51:3 = scf.if %50 -> (i32, i32, i32) {
        scf.yield %49, %9, %c0_i32 : i32, i32, i32
      } else {
        %72 = arith.addi %10, %c16_i32 : i32
        %73 = arith.cmpi sgt, %10, %49 : i32
        %74 = scf.if %73 -> (i32) {
          %82 = arith.subi %49, %arg46 : i32
          scf.yield %82 : i32
        } else {
          %82 = arith.subi %10, %arg46 : i32
          scf.yield %82 : i32
        }
        %75 = arith.cmpi sgt, %74, %arg45 : i32
        %76 = arith.select %75, %74, %c0_i32 : i32
        %77 = arith.addi %49, %c31_i32 : i32
        %78 = arith.divsi %77, %c32_i32 : i32
        %79 = arith.muli %78, %c32_i32 : i32
        %80 = arith.cmpi slt, %79, %10 : i32
        %81 = scf.if %80 -> (i32) {
          scf.yield %49 : i32
        } else {
          scf.yield %72 : i32
        }
        scf.yield %79, %81, %76 : i32, i32, i32
      }
      %52:2 = scf.for %arg47 = %51#2 to %51#1 step %c32_i32 iter_args(%arg48 = %cst_4, %arg49 = %51#2) -> (tensor<16x16xf32>, i32)  : i32 {
        %72 = tt.splat %arg47 : i32 -> tensor<32xi32>
        %73 = arith.addi %18, %72 : tensor<32xi32>
        %74 = arith.extsi %arg47 : i32 to i64
        %75 = arith.addi %3, %74 : i64
        %76 = arith.trunci %75 : i64 to i32
        %77 = arith.extsi %76 : i32 to i64
        %78 = tt.splat %77 : i64 -> tensor<32xi64>
        %79 = arith.extsi %18 : tensor<32xi32> to tensor<32xi64>
        %80 = arith.addi %78, %79 : tensor<32xi64>
        %81 = tt.expand_dims %80 {axis = 1 : i32} : tensor<32xi64> -> tensor<32x1xi64>
        %82 = tt.splat %arg9 : !tt.ptr<bf16> -> tensor<32x16x!tt.ptr<bf16>>
        %83 = tt.splat %arg12 : i64 -> tensor<32x1xi64>
        %84 = arith.muli %81, %83 : tensor<32x1xi64>
        %85 = tt.broadcast %84 : tensor<32x1xi64> -> tensor<32x16xi64>
        %86 = tt.splat %arg13 : i64 -> tensor<1x16xi64>
        %87 = arith.muli %27, %86 : tensor<1x16xi64>
        %88 = tt.broadcast %87 : tensor<1x16xi64> -> tensor<32x16xi64>
        %89 = arith.addi %85, %88 : tensor<32x16xi64>
        %90 = tt.addptr %82, %89 : tensor<32x16x!tt.ptr<bf16>>, tensor<32x16xi64>
        %91 = arith.cmpi sge, %81, %cst : tensor<32x1xi64>
        %92 = tt.splat %arg10 : i64 -> tensor<32x1xi64>
        %93 = arith.cmpi slt, %81, %92 : tensor<32x1xi64>
        %94 = arith.andi %91, %93 : tensor<32x1xi1>
        %95 = tt.broadcast %94 : tensor<32x1xi1> -> tensor<32x16xi1>
        %96 = tt.splat %arg11 : i64 -> tensor<1x16xi64>
        %97 = arith.cmpi slt, %27, %96 : tensor<1x16xi64>
        %98 = arith.andi %42, %97 : tensor<1x16xi1>
        %99 = tt.broadcast %98 : tensor<1x16xi1> -> tensor<32x16xi1>
        %100 = arith.andi %95, %99 : tensor<32x16xi1>
        %101 = tt.load %90, %100, %cst_2 : tensor<32x16x!tt.ptr<bf16>>
        %102 = tt.trans %101 {order = array<i32: 1, 0>} : tensor<32x16xbf16> -> tensor<16x32xbf16>
        %103 = tt.dot %48, %102, %cst_13, inputPrecision = tf32 : tensor<16x16xbf16> * tensor<16x32xbf16> -> tensor<16x32xf32>
        %104 = tt.splat %arg38 : f32 -> tensor<16x32xf32>
        %105 = arith.mulf %103, %104 : tensor<16x32xf32>
        %106 = tt.expand_dims %17 {axis = 1 : i32} : tensor<16xi32> -> tensor<16x1xi32>
        %107 = tt.expand_dims %73 {axis = 0 : i32} : tensor<32xi32> -> tensor<1x32xi32>
        %108 = tt.broadcast %106 : tensor<16x1xi32> -> tensor<16x32xi32>
        %109 = tt.broadcast %107 : tensor<1x32xi32> -> tensor<16x32xi32>
        %110 = arith.cmpi eq, %108, %109 : tensor<16x32xi32>
        %111 = tt.splat %arg45 : i32 -> tensor<16xi32>
        %112 = arith.subi %17, %111 : tensor<16xi32>
        %113 = arith.addi %112, %cst_12 : tensor<16xi32>
        %114 = arith.cmpi sgt, %113, %cst_11 : tensor<16xi32>
        %115 = arith.select %114, %113, %cst_11 : tensor<16xi1>, tensor<16xi32>
        %116 = tt.splat %arg45 : i32 -> tensor<32xi32>
        %117 = arith.subi %73, %116 : tensor<32xi32>
        %118 = arith.addi %117, %cst_10 : tensor<32xi32>
        %119 = arith.cmpi sgt, %118, %cst_9 : tensor<32xi32>
        %120 = arith.select %119, %118, %cst_9 : tensor<32xi1>, tensor<32xi32>
        %121 = arith.subi %9, %arg45 : i32
        %122 = arith.addi %121, %c1_i32 : i32
        %123 = arith.subi %122, %14 : i32
        %124 = tt.splat %123 : i32 -> tensor<16xi32>
        %125 = arith.cmpi slt, %115, %124 : tensor<16xi32>
        %126 = arith.select %125, %115, %124 : tensor<16xi1>, tensor<16xi32>
        %127 = tt.splat %123 : i32 -> tensor<32xi32>
        %128 = arith.cmpi slt, %120, %127 : tensor<32xi32>
        %129 = arith.select %128, %120, %127 : tensor<32xi1>, tensor<32xi32>
        %130 = tt.expand_dims %126 {axis = 1 : i32} : tensor<16xi32> -> tensor<16x1xi32>
        %131 = tt.expand_dims %129 {axis = 0 : i32} : tensor<32xi32> -> tensor<1x32xi32>
        %132 = tt.broadcast %130 : tensor<16x1xi32> -> tensor<16x32xi32>
        %133 = tt.broadcast %131 : tensor<1x32xi32> -> tensor<16x32xi32>
        %134 = arith.subi %132, %133 : tensor<16x32xi32>
        %135 = arith.cmpi sgt, %134, %cst_8 : tensor<16x32xi32>
        %136 = arith.ori %110, %135 : tensor<16x32xi1>
        %137 = tt.splat %arg46 : i32 -> tensor<16x32xi32>
        %138 = arith.cmpi sle, %134, %137 : tensor<16x32xi32>
        %139 = arith.andi %136, %138 : tensor<16x32xi1>
        %140 = arith.cmpi eq, %130, %cst_7 : tensor<16x1xi32>
        %141 = tt.splat %123 : i32 -> tensor<1x32xi32>
        %142 = arith.cmpi slt, %131, %141 : tensor<1x32xi32>
        %143 = tt.broadcast %140 : tensor<16x1xi1> -> tensor<16x32xi1>
        %144 = tt.broadcast %142 : tensor<1x32xi1> -> tensor<16x32xi1>
        %145 = arith.andi %143, %144 : tensor<16x32xi1>
        %146 = arith.ori %139, %145 : tensor<16x32xi1>
        %147 = arith.sitofp %arg41 : i32 to f32
        %148 = arith.divf %cst_6, %147 : f32
        %149 = tt.splat %148 : f32 -> tensor<16x32xf32>
        %150 = arith.select %146, %149, %cst_13 : tensor<16x32xi1>, tensor<16x32xf32>
        %151 = arith.subf %cst_13, %105 : tensor<16x32xf32>
        %152 = tt.extern_elementwise %151 {libname = "", libpath = "", pure = true, symbol = "__nv_fast_expf"} : (tensor<16x32xf32>) -> tensor<16x32xf32>
        %153 = arith.addf %152, %cst_5 : tensor<16x32xf32>
        %154 = tt.extern_elementwise %105, %153 {libname = "", libpath = "", pure = true, symbol = "__nv_fast_fdividef"} : (tensor<16x32xf32>, tensor<16x32xf32>) -> tensor<16x32xf32>
        %155 = arith.mulf %154, %150 : tensor<16x32xf32>
        %156 = tt.splat %arg18 : !tt.ptr<bf16> -> tensor<32x16x!tt.ptr<bf16>>
        %157 = tt.splat %arg21 : i64 -> tensor<32x1xi64>
        %158 = arith.muli %81, %157 : tensor<32x1xi64>
        %159 = tt.broadcast %158 : tensor<32x1xi64> -> tensor<32x16xi64>
        %160 = tt.splat %arg22 : i64 -> tensor<1x16xi64>
        %161 = arith.muli %27, %160 : tensor<1x16xi64>
        %162 = tt.broadcast %161 : tensor<1x16xi64> -> tensor<32x16xi64>
        %163 = arith.addi %159, %162 : tensor<32x16xi64>
        %164 = tt.addptr %156, %163 : tensor<32x16x!tt.ptr<bf16>>, tensor<32x16xi64>
        %165 = tt.splat %arg19 : i64 -> tensor<32x1xi64>
        %166 = arith.cmpi slt, %81, %165 : tensor<32x1xi64>
        %167 = arith.andi %91, %166 : tensor<32x1xi1>
        %168 = tt.broadcast %167 : tensor<32x1xi1> -> tensor<32x16xi1>
        %169 = tt.splat %arg20 : i64 -> tensor<1x16xi64>
        %170 = arith.cmpi slt, %27, %169 : tensor<1x16xi64>
        %171 = arith.andi %42, %170 : tensor<1x16xi1>
        %172 = tt.broadcast %171 : tensor<1x16xi1> -> tensor<32x16xi1>
        %173 = arith.andi %168, %172 : tensor<32x16xi1>
        %174 = tt.load %164, %173, %cst_2 : tensor<32x16x!tt.ptr<bf16>>
        %175 = arith.truncf %155 : tensor<16x32xf32> to tensor<16x32xbf16>
        %176 = tt.dot %175, %174, %arg48, inputPrecision = tf32 : tensor<16x32xbf16> * tensor<32x16xbf16> -> tensor<16x16xf32>
        %177 = arith.addi %arg49, %c32_i32 : i32
        scf.yield %176, %177 : tensor<16x16xf32>, i32
      }
      %53 = arith.cmpi slt, %51#0, %10 : i32
      %54 = scf.if %53 -> (tensor<16x16xf32>) {
        %72 = tt.splat %10 : i32 -> tensor<32xi32>
        %73 = arith.addi %18, %72 : tensor<32xi32>
        %74 = tt.splat %22 : i64 -> tensor<32xi64>
        %75 = arith.extsi %18 : tensor<32xi32> to tensor<32xi64>
        %76 = arith.addi %74, %75 : tensor<32xi64>
        %77 = tt.expand_dims %76 {axis = 1 : i32} : tensor<32xi64> -> tensor<32x1xi64>
        %78 = tt.splat %arg9 : !tt.ptr<bf16> -> tensor<32x16x!tt.ptr<bf16>>
        %79 = tt.splat %arg12 : i64 -> tensor<32x1xi64>
        %80 = arith.muli %77, %79 : tensor<32x1xi64>
        %81 = tt.broadcast %80 : tensor<32x1xi64> -> tensor<32x16xi64>
        %82 = tt.splat %arg13 : i64 -> tensor<1x16xi64>
        %83 = arith.muli %27, %82 : tensor<1x16xi64>
        %84 = tt.broadcast %83 : tensor<1x16xi64> -> tensor<32x16xi64>
        %85 = arith.addi %81, %84 : tensor<32x16xi64>
        %86 = tt.addptr %78, %85 : tensor<32x16x!tt.ptr<bf16>>, tensor<32x16xi64>
        %87 = arith.cmpi sge, %77, %cst : tensor<32x1xi64>
        %88 = tt.splat %arg10 : i64 -> tensor<32x1xi64>
        %89 = arith.cmpi slt, %77, %88 : tensor<32x1xi64>
        %90 = arith.andi %87, %89 : tensor<32x1xi1>
        %91 = tt.broadcast %90 : tensor<32x1xi1> -> tensor<32x16xi1>
        %92 = tt.splat %arg11 : i64 -> tensor<1x16xi64>
        %93 = arith.cmpi slt, %27, %92 : tensor<1x16xi64>
        %94 = arith.andi %42, %93 : tensor<1x16xi1>
        %95 = tt.broadcast %94 : tensor<1x16xi1> -> tensor<32x16xi1>
        %96 = arith.andi %91, %95 : tensor<32x16xi1>
        %97 = tt.load %86, %96, %cst_2 : tensor<32x16x!tt.ptr<bf16>>
        %98 = tt.trans %97 {order = array<i32: 1, 0>} : tensor<32x16xbf16> -> tensor<16x32xbf16>
        %99 = tt.dot %48, %98, %cst_13, inputPrecision = tf32 : tensor<16x16xbf16> * tensor<16x32xbf16> -> tensor<16x32xf32>
        %100 = tt.splat %arg38 : f32 -> tensor<16x32xf32>
        %101 = arith.mulf %99, %100 : tensor<16x32xf32>
        %102 = tt.expand_dims %17 {axis = 1 : i32} : tensor<16xi32> -> tensor<16x1xi32>
        %103 = tt.expand_dims %73 {axis = 0 : i32} : tensor<32xi32> -> tensor<1x32xi32>
        %104 = tt.broadcast %102 : tensor<16x1xi32> -> tensor<16x32xi32>
        %105 = tt.broadcast %103 : tensor<1x32xi32> -> tensor<16x32xi32>
        %106 = arith.cmpi eq, %104, %105 : tensor<16x32xi32>
        %107 = tt.splat %arg45 : i32 -> tensor<16xi32>
        %108 = arith.subi %17, %107 : tensor<16xi32>
        %109 = arith.addi %108, %cst_12 : tensor<16xi32>
        %110 = arith.cmpi sgt, %109, %cst_11 : tensor<16xi32>
        %111 = arith.select %110, %109, %cst_11 : tensor<16xi1>, tensor<16xi32>
        %112 = tt.splat %arg45 : i32 -> tensor<32xi32>
        %113 = arith.subi %73, %112 : tensor<32xi32>
        %114 = arith.addi %113, %cst_10 : tensor<32xi32>
        %115 = arith.cmpi sgt, %114, %cst_9 : tensor<32xi32>
        %116 = arith.select %115, %114, %cst_9 : tensor<32xi1>, tensor<32xi32>
        %117 = arith.subi %9, %arg45 : i32
        %118 = arith.addi %117, %c1_i32 : i32
        %119 = arith.subi %118, %14 : i32
        %120 = tt.splat %119 : i32 -> tensor<16xi32>
        %121 = arith.cmpi slt, %111, %120 : tensor<16xi32>
        %122 = arith.select %121, %111, %120 : tensor<16xi1>, tensor<16xi32>
        %123 = tt.splat %119 : i32 -> tensor<32xi32>
        %124 = arith.cmpi slt, %116, %123 : tensor<32xi32>
        %125 = arith.select %124, %116, %123 : tensor<32xi1>, tensor<32xi32>
        %126 = tt.expand_dims %122 {axis = 1 : i32} : tensor<16xi32> -> tensor<16x1xi32>
        %127 = tt.expand_dims %125 {axis = 0 : i32} : tensor<32xi32> -> tensor<1x32xi32>
        %128 = tt.broadcast %126 : tensor<16x1xi32> -> tensor<16x32xi32>
        %129 = tt.broadcast %127 : tensor<1x32xi32> -> tensor<16x32xi32>
        %130 = arith.subi %128, %129 : tensor<16x32xi32>
        %131 = arith.cmpi sgt, %130, %cst_8 : tensor<16x32xi32>
        %132 = arith.ori %106, %131 : tensor<16x32xi1>
        %133 = tt.splat %arg46 : i32 -> tensor<16x32xi32>
        %134 = arith.cmpi sle, %130, %133 : tensor<16x32xi32>
        %135 = arith.andi %132, %134 : tensor<16x32xi1>
        %136 = arith.cmpi eq, %126, %cst_7 : tensor<16x1xi32>
        %137 = tt.splat %119 : i32 -> tensor<1x32xi32>
        %138 = arith.cmpi slt, %127, %137 : tensor<1x32xi32>
        %139 = tt.broadcast %136 : tensor<16x1xi1> -> tensor<16x32xi1>
        %140 = tt.broadcast %138 : tensor<1x32xi1> -> tensor<16x32xi1>
        %141 = arith.andi %139, %140 : tensor<16x32xi1>
        %142 = arith.ori %135, %141 : tensor<16x32xi1>
        %143 = arith.sitofp %arg41 : i32 to f32
        %144 = arith.divf %cst_6, %143 : f32
        %145 = tt.splat %144 : f32 -> tensor<16x32xf32>
        %146 = arith.select %142, %145, %cst_13 : tensor<16x32xi1>, tensor<16x32xf32>
        %147 = arith.subf %cst_13, %101 : tensor<16x32xf32>
        %148 = tt.extern_elementwise %147 {libname = "", libpath = "", pure = true, symbol = "__nv_fast_expf"} : (tensor<16x32xf32>) -> tensor<16x32xf32>
        %149 = arith.addf %148, %cst_5 : tensor<16x32xf32>
        %150 = tt.extern_elementwise %101, %149 {libname = "", libpath = "", pure = true, symbol = "__nv_fast_fdividef"} : (tensor<16x32xf32>, tensor<16x32xf32>) -> tensor<16x32xf32>
        %151 = arith.mulf %150, %146 : tensor<16x32xf32>
        %152 = tt.splat %arg18 : !tt.ptr<bf16> -> tensor<32x16x!tt.ptr<bf16>>
        %153 = tt.splat %arg21 : i64 -> tensor<32x1xi64>
        %154 = arith.muli %77, %153 : tensor<32x1xi64>
        %155 = tt.broadcast %154 : tensor<32x1xi64> -> tensor<32x16xi64>
        %156 = tt.splat %arg22 : i64 -> tensor<1x16xi64>
        %157 = arith.muli %27, %156 : tensor<1x16xi64>
        %158 = tt.broadcast %157 : tensor<1x16xi64> -> tensor<32x16xi64>
        %159 = arith.addi %155, %158 : tensor<32x16xi64>
        %160 = tt.addptr %152, %159 : tensor<32x16x!tt.ptr<bf16>>, tensor<32x16xi64>
        %161 = tt.splat %arg19 : i64 -> tensor<32x1xi64>
        %162 = arith.cmpi slt, %77, %161 : tensor<32x1xi64>
        %163 = arith.andi %87, %162 : tensor<32x1xi1>
        %164 = tt.broadcast %163 : tensor<32x1xi1> -> tensor<32x16xi1>
        %165 = tt.splat %arg20 : i64 -> tensor<1x16xi64>
        %166 = arith.cmpi slt, %27, %165 : tensor<1x16xi64>
        %167 = arith.andi %42, %166 : tensor<1x16xi1>
        %168 = tt.broadcast %167 : tensor<1x16xi1> -> tensor<32x16xi1>
        %169 = arith.andi %164, %168 : tensor<32x16xi1>
        %170 = tt.load %160, %169, %cst_2 : tensor<32x16x!tt.ptr<bf16>>
        %171 = arith.truncf %151 : tensor<16x32xf32> to tensor<16x32xbf16>
        %172 = tt.dot %171, %170, %52#0, inputPrecision = tf32 : tensor<16x32xbf16> * tensor<32x16xbf16> -> tensor<16x16xf32>
        scf.yield %172 : tensor<16x16xf32>
      } else {
        scf.yield %52#0 : tensor<16x16xf32>
      }
      %55 = arith.extsi %arg36 : i32 to i64
      %56 = arith.muli %3, %55 : i64
      %57 = tt.addptr %arg29, %56 : !tt.ptr<bf16>, i64
      %58 = tt.expand_dims %17 {axis = 1 : i32} : tensor<16xi32> -> tensor<16x1xi32>
      %59 = tt.splat %arg36 : i32 -> tensor<16x1xi32>
      %60 = arith.muli %58, %59 : tensor<16x1xi32>
      %61 = tt.splat %57 : !tt.ptr<bf16> -> tensor<16x1x!tt.ptr<bf16>>
      %62 = tt.addptr %61, %60 : tensor<16x1x!tt.ptr<bf16>>, tensor<16x1xi32>
      %63 = tt.expand_dims %15 {axis = 0 : i32} : tensor<16xi32> -> tensor<1x16xi32>
      %64 = tt.broadcast %62 : tensor<16x1x!tt.ptr<bf16>> -> tensor<16x16x!tt.ptr<bf16>>
      %65 = tt.broadcast %63 : tensor<1x16xi32> -> tensor<16x16xi32>
      %66 = tt.addptr %64, %65 : tensor<16x16x!tt.ptr<bf16>>, tensor<16x16xi32>
      %67 = tt.splat %9 : i32 -> tensor<16xi32>
      %68 = arith.cmpi slt, %17, %67 : tensor<16xi32>
      %69 = tt.expand_dims %68 {axis = 1 : i32} : tensor<16xi1> -> tensor<16x1xi1>
      %70 = tt.broadcast %69 : tensor<16x1xi1> -> tensor<16x16xi1>
      %71 = arith.truncf %54 : tensor<16x16xf32> to tensor<16x16xbf16>
      tt.store %66, %71, %70 : tensor<16x16x!tt.ptr<bf16>>
    }
    tt.return
  }
}

{-#
  external_resources: {
    mlir_reproducer: {
      pipeline: "builtin.module(convert-triton-to-tritongpu{enable-source-remat=false num-ctas=1 num-warps=2 target=cuda:89 threads-per-warp=32}, tritongpu-coalesce, tritongpu-F32DotTC, triton-nvidia-gpu-plan-cta, tritongpu-remove-layout-conversions, tritongpu-optimize-thread-locality, tritongpu-accelerate-matmul, tritongpu-remove-layout-conversions, tritongpu-optimize-dot-operands{hoist-layout-conversion=true}, triton-nvidia-optimize-descriptor-encoding, triton-loop-aware-cse, tritongpu-fuse-nested-loops, canonicalize{  max-iterations=10 max-num-rewrites=-1 region-simplify=normal test-convergence=false top-down=true}, triton-licm, canonicalize{  max-iterations=10 max-num-rewrites=-1 region-simplify=normal test-convergence=false top-down=true}, tritongpu-combine-tensor-select-and-if, nvgpu-warp-specialization{dump-intermediate-steps=false num-stages=2}, tritongpu-assign-latencies{num-stages=2}, tritongpu-schedule-loops, tritongpu-pipeline{dump-intermediate-steps=false num-stages=2}, canonicalize{  max-iterations=10 max-num-rewrites=-1 region-simplify=normal test-convergence=false top-down=true}, triton-loop-aware-cse, tritongpu-prefetch, tritongpu-optimize-dot-operands{hoist-layout-conversion=true}, tritongpu-coalesce-async-copy, triton-nvidia-optimize-tmem-layouts, tritongpu-remove-layout-conversions, triton-nvidia-interleave-tmem, tritongpu-reduce-data-duplication, tritongpu-reorder-instructions, triton-loop-aware-cse, symbol-dce, sccp, canonicalize{  max-iterations=10 max-num-rewrites=-1 region-simplify=normal test-convergence=false top-down=true})",
      disable_threading: false,
      verify_each: true
    }
  }
#-}
/models/dengyao/Code/northstar-recsys/src/northstar/module/hstu/ops/triton/triton_hstu_attention.py:1508:0: error: Failures have been detected while processing an MLIR pass pipeline
/models/dengyao/Code/northstar-recsys/src/northstar/module/hstu/ops/triton/triton_hstu_attention.py:1508:0: note: Pipeline failed while executing [`TritonGPUCoalesce` on 'builtin.module' operation]: reproducer generated at `std::errs, please share the reproducer above with Triton project.`
E
======================================================================
ERROR: test_attn_triton_tma (ut.hstu.ops.test_hstu_attention_tma.HSTUAttentionTmaTest.test_attn_triton_tma)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/models/dengyao/Code/northstar-recsys/src/ut/hstu/ops/test_hstu_attention_tma.py", line 36, in test_attn_triton_tma
    # pyre-ignore
               ^^^
  File "/usr/local/lib/python3.12/dist-packages/hypothesis/core.py", line 2124, in wrapped_test
    raise the_error_hypothesis_found
  File "/models/dengyao/Code/northstar-recsys/src/ut/hstu/ops/test_hstu_attention_tma.py", line 61, in test_attn_triton_tma
    test_attn(
  File "/models/dengyao/Code/northstar-recsys/src/ut/hstu/ops/test_hstu_attention.py", line 132, in test_attn
    real_out = hstu_mha(
               ^^^^^^^^^
  File "/models/dengyao/Code/northstar-recsys/src/northstar/module/hstu/ops/hstu_attention.py", line 90, in hstu_mha
    return triton_hstu_mha(
           ^^^^^^^^^^^^^^^^
  File "/models/dengyao/Code/northstar-recsys/src/northstar/module/hstu/ops/triton/triton_hstu_attention.py", line 2893, in triton_hstu_mha
    return _AttentionFunction.apply(
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/autograd/function.py", line 576, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/models/dengyao/Code/northstar-recsys/src/northstar/module/hstu/ops/triton/triton_hstu_attention.py", line 2800, in forward
    return triton_hstu_attention_fwd(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/models/dengyao/Code/northstar-recsys/src/northstar/module/hstu/ops/triton/triton_hstu_attention.py", line 2619, in triton_hstu_attention_fwd
    _hstu_attn_fwd[grid](
  File "/usr/local/lib/python3.12/dist-packages/triton/runtime/jit.py", line 390, in <lambda>
    return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
                                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/triton/runtime/autotuner.py", line 251, in run
    ret = self.fn.run(
          ^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/triton/runtime/jit.py", line 594, in run
    kernel = self.compile(src, target=target, options=options.__dict__)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/triton/compiler/compiler.py", line 359, in compile
    next_module = compile_ir(module, metadata)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/triton/backends/nvidia/compiler.py", line 456, in <lambda>
    stages["ttgir"] = lambda src, metadata: self.make_ttgir(src, metadata, options, capability)
                                            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/triton/backends/nvidia/compiler.py", line 298, in make_ttgir
    pm.run(mod)
RuntimeError: PassManager::run failed
Falsifying example: test_attn_triton_tma(
    self=<ut.hstu.ops.test_hstu_attention_tma.HSTUAttentionTmaTest testMethod=test_attn_triton_tma>,
    batch_size=4,  # or any other generated value
    heads=1,  # or any other generated value
    max_uih_len=20,  # or any other generated value
    max_targets=20,  # or any other generated value
    attn_dim=16,  # or any other generated value
    hidden_dim=16,  # or any other generated value
    causal=True,
    has_multiple_targets=True,  # or any other generated value
    dtype=torch.bfloat16,  # or any other generated value
    has_max_attn_len=True,
    contextual_seq_len=10,
)
Explanation:
    These lines were always and only run by failing examples:
        /models/dengyao/Code/northstar-recsys/src/northstar/module/hstu/ops/pytorch/pt_hstu_attention.py:68
        /models/dengyao/Code/northstar-recsys/src/ut/hstu/ops/test_hstu_attention.py:75
        /usr/local/lib/python3.12/dist-packages/triton/compiler/code_generator.py:827
        /usr/local/lib/python3.12/dist-packages/triton/compiler/code_generator.py:829
        /usr/local/lib/python3.12/dist-packages/triton/compiler/code_generator.py:830
        /usr/local/lib/python3.12/dist-packages/triton/compiler/code_generator.py:832
        /usr/local/lib/python3.12/dist-packages/triton/compiler/code_generator.py:835
        /usr/local/lib/python3.12/dist-packages/triton/compiler/code_generator.py:842
        /usr/local/lib/python3.12/dist-packages/triton/compiler/code_generator.py:846
        /usr/local/lib/python3.12/dist-packages/triton/compiler/code_generator.py:848
        /usr/local/lib/python3.12/dist-packages/triton/compiler/code_generator.py:851
        /usr/local/lib/python3.12/dist-packages/triton/language/core.py:1059
        /usr/local/lib/python3.12/dist-packages/triton/language/core.py:1060

----------------------------------------------------------------------
Ran 1 test in 215.469s

FAILED (errors=1)

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions