Test case test_hstu_attention_tma.HSTUAttentionTmaTest.test_attn_triton_tma consistently fails in my environment:

- fbgemm_gpu 1.3.0+cu128
- torch 2.8.0+cu128
- torchrec 1.3.0+cu128
- triton 3.4.0
- Python 3.12.3
- OS: Ubuntu 24.04
- GPU: an sm_89 (Ada) device, judging from target=cuda:89 in the reproducer pipeline below
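
For reference, this is a minimal sketch of how I run the failing case in isolation (the dotted test name is the one unittest reports below; it assumes the repository's src/ directory, which contains the ut/ package, is on sys.path):

```python
import unittest

# Load and run only the failing TMA attention test. The dotted name is
# copied from the unittest error header below.
suite = unittest.defaultTestLoader.loadTestsFromName(
    "ut.hstu.ops.test_hstu_attention_tma.HSTUAttentionTmaTest.test_attn_triton_tma"
)
unittest.TextTestRunner(verbosity=2).run(suite)
```

The run aborts with the LLVM assertion below, after which Triton prints the MLIR reproducer: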
python3: /root/.triton/llvm/llvm-8957e64a-almalinux-x64/include/llvm/ADT/SmallVector.h:296: const_reference llvm::SmallVectorTemplateCommon<long>::operator[](size_type) const [T = long]: Assertion `idx < size()' failed.
module {
tt.func public @_hstu_attn_fwd(%arg0: !tt.ptr<bf16>, %arg1: i64, %arg2: i64, %arg3: i64, %arg4: i64, %arg5: i32, %arg6: i32, %arg7: i64, %arg8: i64, %arg9: !tt.ptr<bf16>, %arg10: i64, %arg11: i64, %arg12: i64, %arg13: i64, %arg14: i32, %arg15: i32 {tt.divisibility = 16 : i32}, %arg16: i64 {tt.divisibility = 16 : i32}, %arg17: i64 {tt.divisibility = 16 : i32}, %arg18: !tt.ptr<bf16> {tt.divisibility = 16 : i32}, %arg19: i64 {tt.divisibility = 16 : i32}, %arg20: i64 {tt.divisibility = 16 : i32}, %arg21: i64 {tt.divisibility = 16 : i32}, %arg22: i64 {tt.divisibility = 16 : i32}, %arg23: i32 {tt.divisibility = 16 : i32}, %arg24: i32 {tt.divisibility = 16 : i32}, %arg25: i64 {tt.divisibility = 16 : i32}, %arg26: i64, %arg27: !tt.ptr<i64>, %arg28: !tt.ptr<i64>, %arg29: !tt.ptr<bf16>, %arg30: i32 {tt.divisibility = 16 : i32}, %arg31: i32 {tt.divisibility = 16 : i32}, %arg32: i32 {tt.divisibility = 16 : i32}, %arg33: i32, %arg34: i32, %arg35: i32, %arg36: i32, %arg37: i32, %arg38: f32, %arg39: i32, %arg40: i32, %arg41: i32, %arg42: i32, %arg43: i32, %arg44: i32, %arg45: i32, %arg46: i32) attributes {noinline = false} {
%cst = arith.constant dense<0> : tensor<32x1xi64>
%cst_0 = arith.constant dense<0> : tensor<1x16xi64>
%cst_1 = arith.constant dense<0> : tensor<16x1xi64>
%cst_2 = arith.constant dense<0.000000e+00> : tensor<32x16xbf16>
%cst_3 = arith.constant dense<0.000000e+00> : tensor<16x16xbf16>
%c1_i32 = arith.constant 1 : i32
%c16_i32 = arith.constant 16 : i32
%c32_i32 = arith.constant 32 : i32
%c31_i32 = arith.constant 31 : i32
%cst_4 = arith.constant dense<0.000000e+00> : tensor<16x16xf32>
%cst_5 = arith.constant dense<1.000000e+00> : tensor<16x32xf32>
%cst_6 = arith.constant 1.000000e+00 : f32
%cst_7 = arith.constant dense<0> : tensor<16x1xi32>
%cst_8 = arith.constant dense<0> : tensor<16x32xi32>
%cst_9 = arith.constant dense<0> : tensor<32xi32>
%cst_10 = arith.constant dense<1> : tensor<32xi32>
%cst_11 = arith.constant dense<0> : tensor<16xi32>
%cst_12 = arith.constant dense<1> : tensor<16xi32>
%cst_13 = arith.constant dense<0.000000e+00> : tensor<16x32xf32>
%c0_i32 = arith.constant 0 : i32
%0 = tt.get_program_id y : i32
%1 = tt.get_program_id x : i32
%2 = tt.addptr %arg27, %0 : !tt.ptr<i64>, i32
%3 = tt.load %2 : !tt.ptr<i64>
%4 = arith.extsi %0 : i32 to i64
%5 = tt.addptr %arg27, %4 : !tt.ptr<i64>, i64
%6 = tt.addptr %5, %c1_i32 : !tt.ptr<i64>, i32
%7 = tt.load %6 : !tt.ptr<i64>
%8 = arith.subi %7, %3 : i64
%9 = arith.trunci %8 : i64 to i32
%10 = arith.muli %1, %c16_i32 : i32
%11 = arith.cmpi slt, %10, %9 : i32
scf.if %11 {
%12 = tt.addptr %arg28, %4 : !tt.ptr<i64>, i64
%13 = tt.load %12 : !tt.ptr<i64>
%14 = arith.trunci %13 : i64 to i32
%15 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32>
%16 = tt.splat %10 : i32 -> tensor<16xi32>
%17 = arith.addi %16, %15 : tensor<16xi32>
%18 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32>
%19 = arith.extsi %10 : i32 to i64
%20 = arith.addi %3, %19 : i64
%21 = arith.trunci %20 : i64 to i32
%22 = arith.extsi %21 : i32 to i64
%23 = tt.splat %22 : i64 -> tensor<16xi64>
%24 = arith.extsi %15 : tensor<16xi32> to tensor<16xi64>
%25 = arith.addi %23, %24 : tensor<16xi64>
%26 = tt.expand_dims %25 {axis = 1 : i32} : tensor<16xi64> -> tensor<16x1xi64>
%27 = tt.expand_dims %24 {axis = 0 : i32} : tensor<16xi64> -> tensor<1x16xi64>
%28 = tt.splat %arg0 : !tt.ptr<bf16> -> tensor<16x16x!tt.ptr<bf16>>
%29 = tt.splat %arg3 : i64 -> tensor<16x1xi64>
%30 = arith.muli %26, %29 : tensor<16x1xi64>
%31 = tt.broadcast %30 : tensor<16x1xi64> -> tensor<16x16xi64>
%32 = tt.splat %arg4 : i64 -> tensor<1x16xi64>
%33 = arith.muli %27, %32 : tensor<1x16xi64>
%34 = tt.broadcast %33 : tensor<1x16xi64> -> tensor<16x16xi64>
%35 = arith.addi %31, %34 : tensor<16x16xi64>
%36 = tt.addptr %28, %35 : tensor<16x16x!tt.ptr<bf16>>, tensor<16x16xi64>
%37 = arith.cmpi sge, %26, %cst_1 : tensor<16x1xi64>
%38 = tt.splat %arg1 : i64 -> tensor<16x1xi64>
%39 = arith.cmpi slt, %26, %38 : tensor<16x1xi64>
%40 = arith.andi %37, %39 : tensor<16x1xi1>
%41 = tt.broadcast %40 : tensor<16x1xi1> -> tensor<16x16xi1>
%42 = arith.cmpi sge, %27, %cst_0 : tensor<1x16xi64>
%43 = tt.splat %arg2 : i64 -> tensor<1x16xi64>
%44 = arith.cmpi slt, %27, %43 : tensor<1x16xi64>
%45 = arith.andi %42, %44 : tensor<1x16xi1>
%46 = tt.broadcast %45 : tensor<1x16xi1> -> tensor<16x16xi1>
%47 = arith.andi %41, %46 : tensor<16x16xi1>
%48 = tt.load %36, %47, %cst_3 : tensor<16x16x!tt.ptr<bf16>>
%49 = arith.subi %9, %14 : i32
%50 = arith.cmpi slt, %10, %arg45 : i32
%51:3 = scf.if %50 -> (i32, i32, i32) {
scf.yield %49, %9, %c0_i32 : i32, i32, i32
} else {
%72 = arith.addi %10, %c16_i32 : i32
%73 = arith.cmpi sgt, %10, %49 : i32
%74 = scf.if %73 -> (i32) {
%82 = arith.subi %49, %arg46 : i32
scf.yield %82 : i32
} else {
%82 = arith.subi %10, %arg46 : i32
scf.yield %82 : i32
}
%75 = arith.cmpi sgt, %74, %arg45 : i32
%76 = arith.select %75, %74, %c0_i32 : i32
%77 = arith.addi %49, %c31_i32 : i32
%78 = arith.divsi %77, %c32_i32 : i32
%79 = arith.muli %78, %c32_i32 : i32
%80 = arith.cmpi slt, %79, %10 : i32
%81 = scf.if %80 -> (i32) {
scf.yield %49 : i32
} else {
scf.yield %72 : i32
}
scf.yield %79, %81, %76 : i32, i32, i32
}
%52:2 = scf.for %arg47 = %51#2 to %51#1 step %c32_i32 iter_args(%arg48 = %cst_4, %arg49 = %51#2) -> (tensor<16x16xf32>, i32) : i32 {
%72 = tt.splat %arg47 : i32 -> tensor<32xi32>
%73 = arith.addi %18, %72 : tensor<32xi32>
%74 = arith.extsi %arg47 : i32 to i64
%75 = arith.addi %3, %74 : i64
%76 = arith.trunci %75 : i64 to i32
%77 = arith.extsi %76 : i32 to i64
%78 = tt.splat %77 : i64 -> tensor<32xi64>
%79 = arith.extsi %18 : tensor<32xi32> to tensor<32xi64>
%80 = arith.addi %78, %79 : tensor<32xi64>
%81 = tt.expand_dims %80 {axis = 1 : i32} : tensor<32xi64> -> tensor<32x1xi64>
%82 = tt.splat %arg9 : !tt.ptr<bf16> -> tensor<32x16x!tt.ptr<bf16>>
%83 = tt.splat %arg12 : i64 -> tensor<32x1xi64>
%84 = arith.muli %81, %83 : tensor<32x1xi64>
%85 = tt.broadcast %84 : tensor<32x1xi64> -> tensor<32x16xi64>
%86 = tt.splat %arg13 : i64 -> tensor<1x16xi64>
%87 = arith.muli %27, %86 : tensor<1x16xi64>
%88 = tt.broadcast %87 : tensor<1x16xi64> -> tensor<32x16xi64>
%89 = arith.addi %85, %88 : tensor<32x16xi64>
%90 = tt.addptr %82, %89 : tensor<32x16x!tt.ptr<bf16>>, tensor<32x16xi64>
%91 = arith.cmpi sge, %81, %cst : tensor<32x1xi64>
%92 = tt.splat %arg10 : i64 -> tensor<32x1xi64>
%93 = arith.cmpi slt, %81, %92 : tensor<32x1xi64>
%94 = arith.andi %91, %93 : tensor<32x1xi1>
%95 = tt.broadcast %94 : tensor<32x1xi1> -> tensor<32x16xi1>
%96 = tt.splat %arg11 : i64 -> tensor<1x16xi64>
%97 = arith.cmpi slt, %27, %96 : tensor<1x16xi64>
%98 = arith.andi %42, %97 : tensor<1x16xi1>
%99 = tt.broadcast %98 : tensor<1x16xi1> -> tensor<32x16xi1>
%100 = arith.andi %95, %99 : tensor<32x16xi1>
%101 = tt.load %90, %100, %cst_2 : tensor<32x16x!tt.ptr<bf16>>
%102 = tt.trans %101 {order = array<i32: 1, 0>} : tensor<32x16xbf16> -> tensor<16x32xbf16>
%103 = tt.dot %48, %102, %cst_13, inputPrecision = tf32 : tensor<16x16xbf16> * tensor<16x32xbf16> -> tensor<16x32xf32>
%104 = tt.splat %arg38 : f32 -> tensor<16x32xf32>
%105 = arith.mulf %103, %104 : tensor<16x32xf32>
%106 = tt.expand_dims %17 {axis = 1 : i32} : tensor<16xi32> -> tensor<16x1xi32>
%107 = tt.expand_dims %73 {axis = 0 : i32} : tensor<32xi32> -> tensor<1x32xi32>
%108 = tt.broadcast %106 : tensor<16x1xi32> -> tensor<16x32xi32>
%109 = tt.broadcast %107 : tensor<1x32xi32> -> tensor<16x32xi32>
%110 = arith.cmpi eq, %108, %109 : tensor<16x32xi32>
%111 = tt.splat %arg45 : i32 -> tensor<16xi32>
%112 = arith.subi %17, %111 : tensor<16xi32>
%113 = arith.addi %112, %cst_12 : tensor<16xi32>
%114 = arith.cmpi sgt, %113, %cst_11 : tensor<16xi32>
%115 = arith.select %114, %113, %cst_11 : tensor<16xi1>, tensor<16xi32>
%116 = tt.splat %arg45 : i32 -> tensor<32xi32>
%117 = arith.subi %73, %116 : tensor<32xi32>
%118 = arith.addi %117, %cst_10 : tensor<32xi32>
%119 = arith.cmpi sgt, %118, %cst_9 : tensor<32xi32>
%120 = arith.select %119, %118, %cst_9 : tensor<32xi1>, tensor<32xi32>
%121 = arith.subi %9, %arg45 : i32
%122 = arith.addi %121, %c1_i32 : i32
%123 = arith.subi %122, %14 : i32
%124 = tt.splat %123 : i32 -> tensor<16xi32>
%125 = arith.cmpi slt, %115, %124 : tensor<16xi32>
%126 = arith.select %125, %115, %124 : tensor<16xi1>, tensor<16xi32>
%127 = tt.splat %123 : i32 -> tensor<32xi32>
%128 = arith.cmpi slt, %120, %127 : tensor<32xi32>
%129 = arith.select %128, %120, %127 : tensor<32xi1>, tensor<32xi32>
%130 = tt.expand_dims %126 {axis = 1 : i32} : tensor<16xi32> -> tensor<16x1xi32>
%131 = tt.expand_dims %129 {axis = 0 : i32} : tensor<32xi32> -> tensor<1x32xi32>
%132 = tt.broadcast %130 : tensor<16x1xi32> -> tensor<16x32xi32>
%133 = tt.broadcast %131 : tensor<1x32xi32> -> tensor<16x32xi32>
%134 = arith.subi %132, %133 : tensor<16x32xi32>
%135 = arith.cmpi sgt, %134, %cst_8 : tensor<16x32xi32>
%136 = arith.ori %110, %135 : tensor<16x32xi1>
%137 = tt.splat %arg46 : i32 -> tensor<16x32xi32>
%138 = arith.cmpi sle, %134, %137 : tensor<16x32xi32>
%139 = arith.andi %136, %138 : tensor<16x32xi1>
%140 = arith.cmpi eq, %130, %cst_7 : tensor<16x1xi32>
%141 = tt.splat %123 : i32 -> tensor<1x32xi32>
%142 = arith.cmpi slt, %131, %141 : tensor<1x32xi32>
%143 = tt.broadcast %140 : tensor<16x1xi1> -> tensor<16x32xi1>
%144 = tt.broadcast %142 : tensor<1x32xi1> -> tensor<16x32xi1>
%145 = arith.andi %143, %144 : tensor<16x32xi1>
%146 = arith.ori %139, %145 : tensor<16x32xi1>
%147 = arith.sitofp %arg41 : i32 to f32
%148 = arith.divf %cst_6, %147 : f32
%149 = tt.splat %148 : f32 -> tensor<16x32xf32>
%150 = arith.select %146, %149, %cst_13 : tensor<16x32xi1>, tensor<16x32xf32>
%151 = arith.subf %cst_13, %105 : tensor<16x32xf32>
%152 = tt.extern_elementwise %151 {libname = "", libpath = "", pure = true, symbol = "__nv_fast_expf"} : (tensor<16x32xf32>) -> tensor<16x32xf32>
%153 = arith.addf %152, %cst_5 : tensor<16x32xf32>
%154 = tt.extern_elementwise %105, %153 {libname = "", libpath = "", pure = true, symbol = "__nv_fast_fdividef"} : (tensor<16x32xf32>, tensor<16x32xf32>) -> tensor<16x32xf32>
%155 = arith.mulf %154, %150 : tensor<16x32xf32>
%156 = tt.splat %arg18 : !tt.ptr<bf16> -> tensor<32x16x!tt.ptr<bf16>>
%157 = tt.splat %arg21 : i64 -> tensor<32x1xi64>
%158 = arith.muli %81, %157 : tensor<32x1xi64>
%159 = tt.broadcast %158 : tensor<32x1xi64> -> tensor<32x16xi64>
%160 = tt.splat %arg22 : i64 -> tensor<1x16xi64>
%161 = arith.muli %27, %160 : tensor<1x16xi64>
%162 = tt.broadcast %161 : tensor<1x16xi64> -> tensor<32x16xi64>
%163 = arith.addi %159, %162 : tensor<32x16xi64>
%164 = tt.addptr %156, %163 : tensor<32x16x!tt.ptr<bf16>>, tensor<32x16xi64>
%165 = tt.splat %arg19 : i64 -> tensor<32x1xi64>
%166 = arith.cmpi slt, %81, %165 : tensor<32x1xi64>
%167 = arith.andi %91, %166 : tensor<32x1xi1>
%168 = tt.broadcast %167 : tensor<32x1xi1> -> tensor<32x16xi1>
%169 = tt.splat %arg20 : i64 -> tensor<1x16xi64>
%170 = arith.cmpi slt, %27, %169 : tensor<1x16xi64>
%171 = arith.andi %42, %170 : tensor<1x16xi1>
%172 = tt.broadcast %171 : tensor<1x16xi1> -> tensor<32x16xi1>
%173 = arith.andi %168, %172 : tensor<32x16xi1>
%174 = tt.load %164, %173, %cst_2 : tensor<32x16x!tt.ptr<bf16>>
%175 = arith.truncf %155 : tensor<16x32xf32> to tensor<16x32xbf16>
%176 = tt.dot %175, %174, %arg48, inputPrecision = tf32 : tensor<16x32xbf16> * tensor<32x16xbf16> -> tensor<16x16xf32>
%177 = arith.addi %arg49, %c32_i32 : i32
scf.yield %176, %177 : tensor<16x16xf32>, i32
}
%53 = arith.cmpi slt, %51#0, %10 : i32
%54 = scf.if %53 -> (tensor<16x16xf32>) {
%72 = tt.splat %10 : i32 -> tensor<32xi32>
%73 = arith.addi %18, %72 : tensor<32xi32>
%74 = tt.splat %22 : i64 -> tensor<32xi64>
%75 = arith.extsi %18 : tensor<32xi32> to tensor<32xi64>
%76 = arith.addi %74, %75 : tensor<32xi64>
%77 = tt.expand_dims %76 {axis = 1 : i32} : tensor<32xi64> -> tensor<32x1xi64>
%78 = tt.splat %arg9 : !tt.ptr<bf16> -> tensor<32x16x!tt.ptr<bf16>>
%79 = tt.splat %arg12 : i64 -> tensor<32x1xi64>
%80 = arith.muli %77, %79 : tensor<32x1xi64>
%81 = tt.broadcast %80 : tensor<32x1xi64> -> tensor<32x16xi64>
%82 = tt.splat %arg13 : i64 -> tensor<1x16xi64>
%83 = arith.muli %27, %82 : tensor<1x16xi64>
%84 = tt.broadcast %83 : tensor<1x16xi64> -> tensor<32x16xi64>
%85 = arith.addi %81, %84 : tensor<32x16xi64>
%86 = tt.addptr %78, %85 : tensor<32x16x!tt.ptr<bf16>>, tensor<32x16xi64>
%87 = arith.cmpi sge, %77, %cst : tensor<32x1xi64>
%88 = tt.splat %arg10 : i64 -> tensor<32x1xi64>
%89 = arith.cmpi slt, %77, %88 : tensor<32x1xi64>
%90 = arith.andi %87, %89 : tensor<32x1xi1>
%91 = tt.broadcast %90 : tensor<32x1xi1> -> tensor<32x16xi1>
%92 = tt.splat %arg11 : i64 -> tensor<1x16xi64>
%93 = arith.cmpi slt, %27, %92 : tensor<1x16xi64>
%94 = arith.andi %42, %93 : tensor<1x16xi1>
%95 = tt.broadcast %94 : tensor<1x16xi1> -> tensor<32x16xi1>
%96 = arith.andi %91, %95 : tensor<32x16xi1>
%97 = tt.load %86, %96, %cst_2 : tensor<32x16x!tt.ptr<bf16>>
%98 = tt.trans %97 {order = array<i32: 1, 0>} : tensor<32x16xbf16> -> tensor<16x32xbf16>
%99 = tt.dot %48, %98, %cst_13, inputPrecision = tf32 : tensor<16x16xbf16> * tensor<16x32xbf16> -> tensor<16x32xf32>
%100 = tt.splat %arg38 : f32 -> tensor<16x32xf32>
%101 = arith.mulf %99, %100 : tensor<16x32xf32>
%102 = tt.expand_dims %17 {axis = 1 : i32} : tensor<16xi32> -> tensor<16x1xi32>
%103 = tt.expand_dims %73 {axis = 0 : i32} : tensor<32xi32> -> tensor<1x32xi32>
%104 = tt.broadcast %102 : tensor<16x1xi32> -> tensor<16x32xi32>
%105 = tt.broadcast %103 : tensor<1x32xi32> -> tensor<16x32xi32>
%106 = arith.cmpi eq, %104, %105 : tensor<16x32xi32>
%107 = tt.splat %arg45 : i32 -> tensor<16xi32>
%108 = arith.subi %17, %107 : tensor<16xi32>
%109 = arith.addi %108, %cst_12 : tensor<16xi32>
%110 = arith.cmpi sgt, %109, %cst_11 : tensor<16xi32>
%111 = arith.select %110, %109, %cst_11 : tensor<16xi1>, tensor<16xi32>
%112 = tt.splat %arg45 : i32 -> tensor<32xi32>
%113 = arith.subi %73, %112 : tensor<32xi32>
%114 = arith.addi %113, %cst_10 : tensor<32xi32>
%115 = arith.cmpi sgt, %114, %cst_9 : tensor<32xi32>
%116 = arith.select %115, %114, %cst_9 : tensor<32xi1>, tensor<32xi32>
%117 = arith.subi %9, %arg45 : i32
%118 = arith.addi %117, %c1_i32 : i32
%119 = arith.subi %118, %14 : i32
%120 = tt.splat %119 : i32 -> tensor<16xi32>
%121 = arith.cmpi slt, %111, %120 : tensor<16xi32>
%122 = arith.select %121, %111, %120 : tensor<16xi1>, tensor<16xi32>
%123 = tt.splat %119 : i32 -> tensor<32xi32>
%124 = arith.cmpi slt, %116, %123 : tensor<32xi32>
%125 = arith.select %124, %116, %123 : tensor<32xi1>, tensor<32xi32>
%126 = tt.expand_dims %122 {axis = 1 : i32} : tensor<16xi32> -> tensor<16x1xi32>
%127 = tt.expand_dims %125 {axis = 0 : i32} : tensor<32xi32> -> tensor<1x32xi32>
%128 = tt.broadcast %126 : tensor<16x1xi32> -> tensor<16x32xi32>
%129 = tt.broadcast %127 : tensor<1x32xi32> -> tensor<16x32xi32>
%130 = arith.subi %128, %129 : tensor<16x32xi32>
%131 = arith.cmpi sgt, %130, %cst_8 : tensor<16x32xi32>
%132 = arith.ori %106, %131 : tensor<16x32xi1>
%133 = tt.splat %arg46 : i32 -> tensor<16x32xi32>
%134 = arith.cmpi sle, %130, %133 : tensor<16x32xi32>
%135 = arith.andi %132, %134 : tensor<16x32xi1>
%136 = arith.cmpi eq, %126, %cst_7 : tensor<16x1xi32>
%137 = tt.splat %119 : i32 -> tensor<1x32xi32>
%138 = arith.cmpi slt, %127, %137 : tensor<1x32xi32>
%139 = tt.broadcast %136 : tensor<16x1xi1> -> tensor<16x32xi1>
%140 = tt.broadcast %138 : tensor<1x32xi1> -> tensor<16x32xi1>
%141 = arith.andi %139, %140 : tensor<16x32xi1>
%142 = arith.ori %135, %141 : tensor<16x32xi1>
%143 = arith.sitofp %arg41 : i32 to f32
%144 = arith.divf %cst_6, %143 : f32
%145 = tt.splat %144 : f32 -> tensor<16x32xf32>
%146 = arith.select %142, %145, %cst_13 : tensor<16x32xi1>, tensor<16x32xf32>
%147 = arith.subf %cst_13, %101 : tensor<16x32xf32>
%148 = tt.extern_elementwise %147 {libname = "", libpath = "", pure = true, symbol = "__nv_fast_expf"} : (tensor<16x32xf32>) -> tensor<16x32xf32>
%149 = arith.addf %148, %cst_5 : tensor<16x32xf32>
%150 = tt.extern_elementwise %101, %149 {libname = "", libpath = "", pure = true, symbol = "__nv_fast_fdividef"} : (tensor<16x32xf32>, tensor<16x32xf32>) -> tensor<16x32xf32>
%151 = arith.mulf %150, %146 : tensor<16x32xf32>
%152 = tt.splat %arg18 : !tt.ptr<bf16> -> tensor<32x16x!tt.ptr<bf16>>
%153 = tt.splat %arg21 : i64 -> tensor<32x1xi64>
%154 = arith.muli %77, %153 : tensor<32x1xi64>
%155 = tt.broadcast %154 : tensor<32x1xi64> -> tensor<32x16xi64>
%156 = tt.splat %arg22 : i64 -> tensor<1x16xi64>
%157 = arith.muli %27, %156 : tensor<1x16xi64>
%158 = tt.broadcast %157 : tensor<1x16xi64> -> tensor<32x16xi64>
%159 = arith.addi %155, %158 : tensor<32x16xi64>
%160 = tt.addptr %152, %159 : tensor<32x16x!tt.ptr<bf16>>, tensor<32x16xi64>
%161 = tt.splat %arg19 : i64 -> tensor<32x1xi64>
%162 = arith.cmpi slt, %77, %161 : tensor<32x1xi64>
%163 = arith.andi %87, %162 : tensor<32x1xi1>
%164 = tt.broadcast %163 : tensor<32x1xi1> -> tensor<32x16xi1>
%165 = tt.splat %arg20 : i64 -> tensor<1x16xi64>
%166 = arith.cmpi slt, %27, %165 : tensor<1x16xi64>
%167 = arith.andi %42, %166 : tensor<1x16xi1>
%168 = tt.broadcast %167 : tensor<1x16xi1> -> tensor<32x16xi1>
%169 = arith.andi %164, %168 : tensor<32x16xi1>
%170 = tt.load %160, %169, %cst_2 : tensor<32x16x!tt.ptr<bf16>>
%171 = arith.truncf %151 : tensor<16x32xf32> to tensor<16x32xbf16>
%172 = tt.dot %171, %170, %52#0, inputPrecision = tf32 : tensor<16x32xbf16> * tensor<32x16xbf16> -> tensor<16x16xf32>
scf.yield %172 : tensor<16x16xf32>
} else {
scf.yield %52#0 : tensor<16x16xf32>
}
%55 = arith.extsi %arg36 : i32 to i64
%56 = arith.muli %3, %55 : i64
%57 = tt.addptr %arg29, %56 : !tt.ptr<bf16>, i64
%58 = tt.expand_dims %17 {axis = 1 : i32} : tensor<16xi32> -> tensor<16x1xi32>
%59 = tt.splat %arg36 : i32 -> tensor<16x1xi32>
%60 = arith.muli %58, %59 : tensor<16x1xi32>
%61 = tt.splat %57 : !tt.ptr<bf16> -> tensor<16x1x!tt.ptr<bf16>>
%62 = tt.addptr %61, %60 : tensor<16x1x!tt.ptr<bf16>>, tensor<16x1xi32>
%63 = tt.expand_dims %15 {axis = 0 : i32} : tensor<16xi32> -> tensor<1x16xi32>
%64 = tt.broadcast %62 : tensor<16x1x!tt.ptr<bf16>> -> tensor<16x16x!tt.ptr<bf16>>
%65 = tt.broadcast %63 : tensor<1x16xi32> -> tensor<16x16xi32>
%66 = tt.addptr %64, %65 : tensor<16x16x!tt.ptr<bf16>>, tensor<16x16xi32>
%67 = tt.splat %9 : i32 -> tensor<16xi32>
%68 = arith.cmpi slt, %17, %67 : tensor<16xi32>
%69 = tt.expand_dims %68 {axis = 1 : i32} : tensor<16xi1> -> tensor<16x1xi1>
%70 = tt.broadcast %69 : tensor<16x1xi1> -> tensor<16x16xi1>
%71 = arith.truncf %54 : tensor<16x16xf32> to tensor<16x16xbf16>
tt.store %66, %71, %70 : tensor<16x16x!tt.ptr<bf16>>
}
tt.return
}
}
{-#
external_resources: {
mlir_reproducer: {
pipeline: "builtin.module(convert-triton-to-tritongpu{enable-source-remat=false num-ctas=1 num-warps=2 target=cuda:89 threads-per-warp=32}, tritongpu-coalesce, tritongpu-F32DotTC, triton-nvidia-gpu-plan-cta, tritongpu-remove-layout-conversions, tritongpu-optimize-thread-locality, tritongpu-accelerate-matmul, tritongpu-remove-layout-conversions, tritongpu-optimize-dot-operands{hoist-layout-conversion=true}, triton-nvidia-optimize-descriptor-encoding, triton-loop-aware-cse, tritongpu-fuse-nested-loops, canonicalize{ max-iterations=10 max-num-rewrites=-1 region-simplify=normal test-convergence=false top-down=true}, triton-licm, canonicalize{ max-iterations=10 max-num-rewrites=-1 region-simplify=normal test-convergence=false top-down=true}, tritongpu-combine-tensor-select-and-if, nvgpu-warp-specialization{dump-intermediate-steps=false num-stages=2}, tritongpu-assign-latencies{num-stages=2}, tritongpu-schedule-loops, tritongpu-pipeline{dump-intermediate-steps=false num-stages=2}, canonicalize{ max-iterations=10 max-num-rewrites=-1 region-simplify=normal test-convergence=false top-down=true}, triton-loop-aware-cse, tritongpu-prefetch, tritongpu-optimize-dot-operands{hoist-layout-conversion=true}, tritongpu-coalesce-async-copy, triton-nvidia-optimize-tmem-layouts, tritongpu-remove-layout-conversions, triton-nvidia-interleave-tmem, tritongpu-reduce-data-duplication, tritongpu-reorder-instructions, triton-loop-aware-cse, symbol-dce, sccp, canonicalize{ max-iterations=10 max-num-rewrites=-1 region-simplify=normal test-convergence=false top-down=true})",
disable_threading: false,
verify_each: true
}
}
#-}
/models/dengyao/Code/northstar-recsys/src/northstar/module/hstu/ops/triton/triton_hstu_attention.py:1508:0: error: Failures have been detected while processing an MLIR pass pipeline
/models/dengyao/Code/northstar-recsys/src/northstar/module/hstu/ops/triton/triton_hstu_attention.py:1508:0: note: Pipeline failed while executing [`TritonGPUCoalesce` on 'builtin.module' operation]: reproducer generated at `std::errs, please share the reproducer above with Triton project.`
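
(So the crash happens while lowering TTIR to TTGIR, inside the TritonGPUCoalesce pass; see pm.run in make_ttgir in the traceback below. No PTX is ever generated.)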
E
======================================================================
ERROR: test_attn_triton_tma (ut.hstu.ops.test_hstu_attention_tma.HSTUAttentionTmaTest.test_attn_triton_tma)
----------------------------------------------------------------------
Traceback (most recent call last):
File "/models/dengyao/Code/northstar-recsys/src/ut/hstu/ops/test_hstu_attention_tma.py", line 36, in test_attn_triton_tma
# pyre-ignore
^^^
File "/usr/local/lib/python3.12/dist-packages/hypothesis/core.py", line 2124, in wrapped_test
raise the_error_hypothesis_found
File "/models/dengyao/Code/northstar-recsys/src/ut/hstu/ops/test_hstu_attention_tma.py", line 61, in test_attn_triton_tma
test_attn(
File "/models/dengyao/Code/northstar-recsys/src/ut/hstu/ops/test_hstu_attention.py", line 132, in test_attn
real_out = hstu_mha(
^^^^^^^^^
File "/models/dengyao/Code/northstar-recsys/src/northstar/module/hstu/ops/hstu_attention.py", line 90, in hstu_mha
return triton_hstu_mha(
^^^^^^^^^^^^^^^^
File "/models/dengyao/Code/northstar-recsys/src/northstar/module/hstu/ops/triton/triton_hstu_attention.py", line 2893, in triton_hstu_mha
return _AttentionFunction.apply(
^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/autograd/function.py", line 576, in apply
return super().apply(*args, **kwargs) # type: ignore[misc]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/models/dengyao/Code/northstar-recsys/src/northstar/module/hstu/ops/triton/triton_hstu_attention.py", line 2800, in forward
return triton_hstu_attention_fwd(
^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/models/dengyao/Code/northstar-recsys/src/northstar/module/hstu/ops/triton/triton_hstu_attention.py", line 2619, in triton_hstu_attention_fwd
_hstu_attn_fwd[grid](
File "/usr/local/lib/python3.12/dist-packages/triton/runtime/jit.py", line 390, in <lambda>
return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/triton/runtime/autotuner.py", line 251, in run
ret = self.fn.run(
^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/triton/runtime/jit.py", line 594, in run
kernel = self.compile(src, target=target, options=options.__dict__)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/triton/compiler/compiler.py", line 359, in compile
next_module = compile_ir(module, metadata)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/triton/backends/nvidia/compiler.py", line 456, in <lambda>
stages["ttgir"] = lambda src, metadata: self.make_ttgir(src, metadata, options, capability)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/triton/backends/nvidia/compiler.py", line 298, in make_ttgir
pm.run(mod)
RuntimeError: PassManager::run failed
Falsifying example: test_attn_triton_tma(
self=<ut.hstu.ops.test_hstu_attention_tma.HSTUAttentionTmaTest testMethod=test_attn_triton_tma>,
batch_size=4, # or any other generated value
heads=1, # or any other generated value
max_uih_len=20, # or any other generated value
max_targets=20, # or any other generated value
attn_dim=16, # or any other generated value
hidden_dim=16, # or any other generated value
causal=True,
has_multiple_targets=True, # or any other generated value
dtype=torch.bfloat16, # or any other generated value
has_max_attn_len=True,
contextual_seq_len=10,
)
Explanation:
These lines were always and only run by failing examples:
/models/dengyao/Code/northstar-recsys/src/northstar/module/hstu/ops/pytorch/pt_hstu_attention.py:68
/models/dengyao/Code/northstar-recsys/src/ut/hstu/ops/test_hstu_attention.py:75
/usr/local/lib/python3.12/dist-packages/triton/compiler/code_generator.py:827
/usr/local/lib/python3.12/dist-packages/triton/compiler/code_generator.py:829
/usr/local/lib/python3.12/dist-packages/triton/compiler/code_generator.py:830
/usr/local/lib/python3.12/dist-packages/triton/compiler/code_generator.py:832
/usr/local/lib/python3.12/dist-packages/triton/compiler/code_generator.py:835
/usr/local/lib/python3.12/dist-packages/triton/compiler/code_generator.py:842
/usr/local/lib/python3.12/dist-packages/triton/compiler/code_generator.py:846
/usr/local/lib/python3.12/dist-packages/triton/compiler/code_generator.py:848
/usr/local/lib/python3.12/dist-packages/triton/compiler/code_generator.py:851
/usr/local/lib/python3.12/dist-packages/triton/language/core.py:1059
/usr/local/lib/python3.12/dist-packages/triton/language/core.py:1060
----------------------------------------------------------------------
Ran 1 test in 215.469s
FAILED (errors=1)
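
To make the failure deterministic while debugging, the falsifying parameters can be pinned with Hypothesis's @example decorator. The sketch below is schematic: the real @given strategies live in test_hstu_attention_tma.py, and the placeholder strategies here only need to cover the pinned values.

```python
import torch
from hypothesis import example, given, strategies as st

# Schematic placeholder strategies; the real test defines its own @given.
@given(
    batch_size=st.integers(1, 8),
    heads=st.integers(1, 4),
    max_uih_len=st.integers(10, 100),
    max_targets=st.integers(10, 100),
    attn_dim=st.sampled_from([16, 32]),
    hidden_dim=st.sampled_from([16, 32]),
    causal=st.booleans(),
    has_multiple_targets=st.booleans(),
    dtype=st.just(torch.bfloat16),
    has_max_attn_len=st.booleans(),
    contextual_seq_len=st.integers(0, 10),
)
# Pin the exact falsifying example reported above so every run hits it.
@example(
    batch_size=4, heads=1, max_uih_len=20, max_targets=20,
    attn_dim=16, hidden_dim=16, causal=True, has_multiple_targets=True,
    dtype=torch.bfloat16, has_max_attn_len=True, contextual_seq_len=10,
)
def check_attn(batch_size, heads, max_uih_len, max_targets, attn_dim,
               hidden_dim, causal, has_multiple_targets, dtype,
               has_max_attn_len, contextual_seq_len):
    # In the real test this body calls test_attn(...) from
    # test_hstu_attention.py, as shown in the traceback above.
    pass
```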