From bf558772085f53fc3ef3a2722aefaf5fd4c4bcb1 Mon Sep 17 00:00:00 2001 From: Arun Thangamani Date: Thu, 10 Jul 2025 00:40:53 -0700 Subject: [PATCH 1/5] MLIR support for VPDPBSSD instruction through llvm intrinsics --- .../mlir/Dialect/X86Vector/X86Vector.td | 53 +++++++++++++++++++ .../Dialect/X86Vector/legalize-for-llvm.mlir | 16 ++++++ mlir/test/Dialect/X86Vector/roundtrip.mlir | 16 ++++++ mlir/test/Target/LLVMIR/x86vector.mlir | 16 ++++++ 4 files changed, 101 insertions(+) diff --git a/mlir/include/mlir/Dialect/X86Vector/X86Vector.td b/mlir/include/mlir/Dialect/X86Vector/X86Vector.td index 3bf0be0a716aa..c3f7904fa1249 100644 --- a/mlir/include/mlir/Dialect/X86Vector/X86Vector.td +++ b/mlir/include/mlir/Dialect/X86Vector/X86Vector.td @@ -302,6 +302,8 @@ def DotBF16Op : AVX512_Op<"dot", [Pure, }]; } + + //----------------------------------------------------------------------------// // Convert packed F32 to packed BF16 //----------------------------------------------------------------------------// @@ -420,6 +422,57 @@ def DotOp : AVX_LowOp<"dot", [Pure, }]; } +//----------------------------------------------------------------------------// +// AVX Int8 Dot +//----------------------------------------------------------------------------// + +def DotInt8Op : AVX_Op<"dot.i32", [Pure, + X86IntrinsicOpInterface, + AllTypesMatch<["a", "b"]>, + AllTypesMatch<["src", "dst"]>, + TypesMatchWith<"`a` has same elements as `src`", + "src", "a", + "VectorType::get({::llvm::cast($_self).getShape()[0]}, " + "IntegerType::get($_self.getContext(), 32))"> + ]> { + let summary = "Dot Int8 op"; + let description = [{ + The `dot` op is an AVX2-I32/I8 specific op that can lower to the proper + LLVMAVX2-INT8/32 operation `llvm.vpdpbssd` depending on the width of MLIR vectors it is applied to. 
+ + #### From the Intel Intrinsics Guide: + + Multiply groups of 4 adjacent pairs of signed 8-bit integers in `a` with + corresponding signed 8-bit integers in `b`, producing 4 intermediate signed 16-bit + results. Sum these 4 results with the corresponding 32-bit integer in `src`, and + store the packed 32-bit results in `dst`. + + Example: + ```mlir + %dst = x86vector.avx.dot %src, %a, %b : vector<8xi32> -> vector<8xi32> + ``` + }]; + let arguments = (ins VectorOfLengthAndType<[4, 8], [I32]>:$src, + VectorOfLengthAndType<[4, 8], [I32]>:$a, + VectorOfLengthAndType<[4, 8], [I32]>:$b + ); + let results = (outs VectorOfLengthAndType<[4, 8], [I32]>:$dst); + let assemblyFormat = + "$src `,` $a `,` $b attr-dict `:` type($a) `->` type($src)"; + + let extraClassDeclaration = [{ + std::string getIntrinsicName() { + std::string intr = "llvm.x86.avx2.vpdpbssd"; + VectorType vecType = getSrc().getType(); + unsigned elemBitWidth = vecType.getElementTypeBitWidth(); + unsigned opBitWidth = vecType.getShape()[0] * elemBitWidth; + intr += "." 
+ std::to_string(opBitWidth); + return intr; + } + }]; +} + //----------------------------------------------------------------------------// // AVX: Convert BF16/F16 to F32 and broadcast into packed F32 //----------------------------------------------------------------------------// diff --git a/mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir b/mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir index 63f06624ef897..a70410497acbd 100644 --- a/mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir +++ b/mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir @@ -219,3 +219,19 @@ func.func @avx_dot(%a: vector<8xf32>, %b: vector<8xf32>) -> (vector<8xf32>) %0 = x86vector.avx.intr.dot %a, %b : vector<8xf32> return %0 : vector<8xf32> } + +// CHECK-LABEL: func @avx_dot_i32_128 +func.func @avx_dot_i32_128(%src: vector<4xi32>, %a: vector<4xi32>, + %b: vector<4xi32>) -> vector<4xi32> { + // CHECK: llvm.call_intrinsic "llvm.x86.avx2.vpdpbssd.128" + %0 = x86vector.avx.dot.i32 %src, %a, %b : vector<4xi32> -> vector<4xi32> + return %0 : vector<4xi32> +} + +// CHECK-LABEL: func @avx_dot_i32_256 +func.func @avx_dot_i32_256(%src: vector<8xi32>, %a: vector<8xi32>, + %b: vector<8xi32>) -> vector<8xi32> { + // CHECK: llvm.call_intrinsic "llvm.x86.avx2.vpdpbssd.256" + %0 = x86vector.avx.dot.i32 %src, %a, %b : vector<8xi32> -> vector<8xi32> + return %0 : vector<8xi32> +} diff --git a/mlir/test/Dialect/X86Vector/roundtrip.mlir b/mlir/test/Dialect/X86Vector/roundtrip.mlir index 7dcab3eb4dcb8..bd3509eb07b2b 100644 --- a/mlir/test/Dialect/X86Vector/roundtrip.mlir +++ b/mlir/test/Dialect/X86Vector/roundtrip.mlir @@ -229,3 +229,19 @@ func.func @avx_dot(%a: vector<8xf32>, %b: vector<8xf32>) -> (vector<8xf32>) %0 = x86vector.avx.intr.dot %a, %b : vector<8xf32> return %0 : vector<8xf32> } + +// CHECK-LABEL: func @avx_dot_i32_128 +func.func @avx_dot_i32_128(%src: vector<4xi32>, %a: vector<4xi32>, + %b: vector<4xi32>) -> vector<4xi32> { + // CHECK: x86vector.avx.dot.i32 {{.*}} : vector<4xi32> -> 
vector<4xi32> + %0 = x86vector.avx.dot.i32 %src, %a, %b : vector<4xi32> -> vector<4xi32> + return %0 : vector<4xi32> +} + +// CHECK-LABEL: func @avx_dot_i32_256 +func.func @avx_dot_i32_256(%src: vector<8xi32>, %a: vector<8xi32>, + %b: vector<8xi32>) -> vector<8xi32> { + // CHECK: x86vector.avx.dot.i32 {{.*}} : vector<8xi32> -> vector<8xi32> + %0 = x86vector.avx.dot.i32 %src, %a, %b : vector<8xi32> -> vector<8xi32> + return %0 : vector<8xi32> +} diff --git a/mlir/test/Target/LLVMIR/x86vector.mlir b/mlir/test/Target/LLVMIR/x86vector.mlir index d11dc89bdc7c9..ac2f3e1277df3 100644 --- a/mlir/test/Target/LLVMIR/x86vector.mlir +++ b/mlir/test/Target/LLVMIR/x86vector.mlir @@ -234,3 +234,19 @@ func.func @LLVM_x86_avx_dp_ps_256( %0 = x86vector.avx.intr.dot %a, %b : vector<8xf32> return %0 : vector<8xf32> } + +// CHECK-LABEL: define <4 x i32> @LLVM_x86_avx2_vpdpbssd_128 +func.func @LLVM_x86_avx2_vpdpbssd_128(%src: vector<4xi32>, %a: vector<4xi32>, + %b: vector<4xi32>) -> vector<4xi32> { + // CHECK: call <4 x i32> @llvm.x86.avx2.vpdpbssd.128( + %0 = x86vector.avx.dot.i32 %src, %a, %b : vector<4xi32> -> vector<4xi32> + return %0 : vector<4xi32> +} + +// CHECK-LABEL: define <8 x i32> @LLVM_x86_avx2_vpdpbssd_256 +func.func @LLVM_x86_avx2_vpdpbssd_256(%src: vector<8xi32>, %a: vector<8xi32>, + %b: vector<8xi32>) -> vector<8xi32> { + // CHECK: call <8 x i32> @llvm.x86.avx2.vpdpbssd.256( + %0 = x86vector.avx.dot.i32 %src, %a, %b : vector<8xi32> -> vector<8xi32> + return %0 : vector<8xi32> +} From df153cfc5d8e9c60a5fb38b058a3bfffcce37516 Mon Sep 17 00:00:00 2001 From: Arun Thangamani Date: Thu, 10 Jul 2025 00:49:48 -0700 Subject: [PATCH 2/5] removing extra space --- mlir/include/mlir/Dialect/X86Vector/X86Vector.td | 2 -- 1 file changed, 2 deletions(-) diff --git a/mlir/include/mlir/Dialect/X86Vector/X86Vector.td b/mlir/include/mlir/Dialect/X86Vector/X86Vector.td index c3f7904fa1249..688f26211e48d 100644 --- a/mlir/include/mlir/Dialect/X86Vector/X86Vector.td +++ 
b/mlir/include/mlir/Dialect/X86Vector/X86Vector.td @@ -302,8 +302,6 @@ def DotBF16Op : AVX512_Op<"dot", [Pure, }]; } - - //----------------------------------------------------------------------------// // Convert packed F32 to packed BF16 //----------------------------------------------------------------------------// From 72933a035fe3dbe25b33d2ae64ea2bbf7191ac6f Mon Sep 17 00:00:00 2001 From: Arun Thangamani Date: Fri, 11 Jul 2025 00:32:13 -0700 Subject: [PATCH 3/5] changed the logic to i8*i8=+i32 with llvm.bitcast --- .../mlir/Dialect/X86Vector/X86Vector.td | 35 +++++++++++-------- .../Dialect/X86Vector/IR/X86VectorDialect.cpp | 19 ++++++++++ .../Dialect/X86Vector/legalize-for-llvm.mlir | 16 ++++----- mlir/test/Dialect/X86Vector/roundtrip.mlir | 20 +++++------ mlir/test/Target/LLVMIR/x86vector.mlir | 16 ++++----- 5 files changed, 65 insertions(+), 41 deletions(-) diff --git a/mlir/include/mlir/Dialect/X86Vector/X86Vector.td b/mlir/include/mlir/Dialect/X86Vector/X86Vector.td index 688f26211e48d..73f6877c12fab 100644 --- a/mlir/include/mlir/Dialect/X86Vector/X86Vector.td +++ b/mlir/include/mlir/Dialect/X86Vector/X86Vector.td @@ -424,50 +424,55 @@ def DotOp : AVX_LowOp<"dot", [Pure, // AVX Int8 Dot //----------------------------------------------------------------------------// -def DotInt8Op : AVX_Op<"dot.i32", [Pure, +def DotInt8Op : AVX_Op<"dot.i8", [Pure, X86IntrinsicOpInterface, AllTypesMatch<["a", "b"]>, - AllTypesMatch<["src", "dst"]>, - TypesMatchWith<"`a` has same elements as `src`", - "src", "a", - "VectorType::get({::llvm::cast($_self).getShape()[0]}, " - "IntegerType::get($_self.getContext(), 32))"> + AllTypesMatch<["w", "dst"]>, + TypesMatchWith<"`a` has four times elements as `w`", + "w", "a", + "VectorType::get({::llvm::cast($_self).getShape()[0] * 4}, " + "IntegerType::get($_self.getContext(), 8))"> ]> { let summary = "Dot Int8 op"; let description = [{ - The `dot` op is an AVX2-I32/I8 specific op that can lower to the proper - LLVMAVX2-INT8/32 
operation `llvm.vpdpbssd` depending on the width of MLIR + The `dot.i8` op is an AVX2-Int8 specific op that lowers to the LLVM + `llvm.x86.avx2.vpdpbssd.128` or `.256` intrinsic, depending on the width of the MLIR vectors it is applied to. #### From the Intel Intrinsics Guide: Multiply groups of 4 adjacent pairs of signed 8-bit integers in `a` with corresponding signed 8-bit integers in `b`, producing 4 intermediate signed 16-bit - results. Sum these 4 results with the corresponding 32-bit integer in `src`, and + results. Sum these 4 results with the corresponding 32-bit integer in `w`, and store the packed 32-bit results in `dst`. Example: ```mlir - %dst = x86vector.avx.dot %src, %a, %b : vector<8xi32> -> vector<8xi32> + %dst = x86vector.avx.dot.i8 %w, %a, %b : vector<32xi8> -> vector<8xi32> ``` }]; - let arguments = (ins VectorOfLengthAndType<[4, 8], [I32]>:$src, - VectorOfLengthAndType<[4, 8], [I32]>:$a, - VectorOfLengthAndType<[4, 8], [I32]>:$b + let arguments = (ins VectorOfLengthAndType<[4, 8], [I32]>:$w, + VectorOfLengthAndType<[16, 32], [I8]>:$a, + VectorOfLengthAndType<[16, 32], [I8]>:$b ); let results = (outs VectorOfLengthAndType<[4, 8], [I32]>:$dst); let assemblyFormat = - "$src `,` $a `,` $b attr-dict `:` type($a) `->` type($src)"; + "$w `,` $a `,` $b attr-dict `:` type($a) `->` type($w)"; let extraClassDeclaration = [{ std::string getIntrinsicName() { std::string intr = "llvm.x86.avx2.vpdpbssd"; - VectorType vecType = getSrc().getType(); + VectorType vecType = getW().getType(); unsigned elemBitWidth = vecType.getElementTypeBitWidth(); unsigned opBitWidth = vecType.getShape()[0] * elemBitWidth; intr += "." 
+ std::to_string(opBitWidth); return intr; } + + SmallVector getIntrinsicOperands( + ::mlir::ArrayRef operands, + const ::mlir::LLVMTypeConverter &typeConverter, + ::mlir::RewriterBase &rewriter); }]; } diff --git a/mlir/lib/Dialect/X86Vector/IR/X86VectorDialect.cpp b/mlir/lib/Dialect/X86Vector/IR/X86VectorDialect.cpp index cc7ab7f3f3895..7dddf55010800 100644 --- a/mlir/lib/Dialect/X86Vector/IR/X86VectorDialect.cpp +++ b/mlir/lib/Dialect/X86Vector/IR/X86VectorDialect.cpp @@ -86,6 +86,25 @@ x86vector::DotOp::getIntrinsicOperands(ArrayRef operands, return intrinsicOperands; } +SmallVector +x86vector::DotInt8Op::getIntrinsicOperands(ArrayRef operands, + const LLVMTypeConverter &typeConverter, + RewriterBase &rewriter) { + SmallVector intrinsicOprnds; + intrinsicOprnds.push_back(operands[0]); + + //Bit-cast `a` and `b` to i32 + Value bitcast_a = rewriter.create( + getLoc(), VectorType::get((getA().getType().getShape()[0]/4), rewriter.getIntegerType(32)), + operands[1]); + intrinsicOprnds.push_back(bitcast_a); + Value bitcast_b = rewriter.create( + getLoc(), VectorType::get((getA().getType().getShape()[0]/4), rewriter.getIntegerType(32)), + operands[2]); + intrinsicOprnds.push_back(bitcast_b); + return intrinsicOprnds; +} + SmallVector x86vector::BcstToPackedF32Op::getIntrinsicOperands( ArrayRef operands, const LLVMTypeConverter &typeConverter, RewriterBase &rewriter) { diff --git a/mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir b/mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir index a70410497acbd..72dc899f4f0a6 100644 --- a/mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir +++ b/mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir @@ -220,18 +220,18 @@ func.func @avx_dot(%a: vector<8xf32>, %b: vector<8xf32>) -> (vector<8xf32>) return %0 : vector<8xf32> } -// CHECK-LABEL: func @avx_dot_i32_128 -func.func @avx_dot_i32_128(%src: vector<4xi32>, %a: vector<4xi32>, - %b: vector<4xi32>) -> vector<4xi32> { +// CHECK-LABEL: func @avx_dot_i8_128 +func.func 
@avx_dot_i8_128(%w: vector<4xi32>, %a: vector<16xi8>, + %b: vector<16xi8>) -> vector<4xi32> { // CHECK: llvm.call_intrinsic "llvm.x86.avx2.vpdpbssd.128" - %0 = x86vector.avx.dot.i32 %src, %a, %b : vector<4xi32> -> vector<4xi32> + %0 = x86vector.avx.dot.i8 %w, %a, %b : vector<16xi8> -> vector<4xi32> return %0 : vector<4xi32> } -// CHECK-LABEL: func @avx_dot_i32_256 -func.func @avx_dot_i32_256(%src: vector<8xi32>, %a: vector<8xi32>, - %b: vector<8xi32>) -> vector<8xi32> { +// CHECK-LABEL: func @avx_dot_i8_256 +func.func @avx_dot_i8_256(%w: vector<8xi32>, %a: vector<32xi8>, + %b: vector<32xi8>) -> vector<8xi32> { // CHECK: llvm.call_intrinsic "llvm.x86.avx2.vpdpbssd.256" - %0 = x86vector.avx.dot.i32 %src, %a, %b : vector<8xi32> -> vector<8xi32> + %0 = x86vector.avx.dot.i8 %w, %a, %b : vector<32xi8> -> vector<8xi32> return %0 : vector<8xi32> } diff --git a/mlir/test/Dialect/X86Vector/roundtrip.mlir b/mlir/test/Dialect/X86Vector/roundtrip.mlir index bd3509eb07b2b..959177b27c7ea 100644 --- a/mlir/test/Dialect/X86Vector/roundtrip.mlir +++ b/mlir/test/Dialect/X86Vector/roundtrip.mlir @@ -230,18 +230,18 @@ func.func @avx_dot(%a: vector<8xf32>, %b: vector<8xf32>) -> (vector<8xf32>) return %0 : vector<8xf32> } -// CHECK-LABEL: func @avx_dot_i32_128 -func.func @avx_dot_i32_128(%src: vector<4xi32>, %a: vector<4xi32>, - %b: vector<4xi32>) -> vector<4xi32> { - // CHECK: x86vector.avx.dot.i32 {{.*}} : vector<4xi32> -> vector<4xi32> - %0 = x86vector.avx.dot.i32 %src, %a, %b : vector<4xi32> -> vector<4xi32> +// CHECK-LABEL: func @avx_dot_i8_128 +func.func @avx_dot_i8_128(%w: vector<4xi32>, %a: vector<16xi8>, + %b: vector<16xi8>) -> vector<4xi32> { + // CHECK: x86vector.avx.dot.i8 {{.*}} : vector<16xi8> -> vector<4xi32> + %0 = x86vector.avx.dot.i8 %w, %a, %b : vector<16xi8> -> vector<4xi32> return %0 : vector<4xi32> } -// CHECK-LABEL: func @avx_dot_i32_256 -func.func @avx_dot_i32_256(%src: vector<8xi32>, %a: vector<8xi32>, - %b: vector<8xi32>) -> vector<8xi32> { - // CHECK: 
x86vector.avx.dot.i32 {{.*}} : vector<8xi32> -> vector<8xi32> - %0 = x86vector.avx.dot.i32 %src, %a, %b : vector<8xi32> -> vector<8xi32> +// CHECK-LABEL: func @avx_dot_i8_256 +func.func @avx_dot_i8_256(%w: vector<8xi32>, %a: vector<32xi8>, + %b: vector<32xi8>) -> vector<8xi32> { + // CHECK: x86vector.avx.dot.i8 {{.*}} : vector<32xi8> -> vector<8xi32> + %0 = x86vector.avx.dot.i8 %w, %a, %b : vector<32xi8> -> vector<8xi32> return %0 : vector<8xi32> } diff --git a/mlir/test/Target/LLVMIR/x86vector.mlir b/mlir/test/Target/LLVMIR/x86vector.mlir index ac2f3e1277df3..74ae2424964b1 100644 --- a/mlir/test/Target/LLVMIR/x86vector.mlir +++ b/mlir/test/Target/LLVMIR/x86vector.mlir @@ -236,17 +236,17 @@ func.func @LLVM_x86_avx_dp_ps_256( } // CHECK-LABEL: define <4 x i32> @LLVM_x86_avx2_vpdpbssd_128 -func.func @LLVM_x86_avx2_vpdpbssd_128(%src: vector<4xi32>, %a: vector<4xi32>, - %b: vector<4xi32>) -> vector<4xi32> { - // CHECK: call <4 x i32> @llvm.x86.avx2.vpdpbssd.128( - %0 = x86vector.avx.dot.i32 %src, %a, %b : vector<4xi32> -> vector<4xi32> +func.func @LLVM_x86_avx2_vpdpbssd_128(%w: vector<4xi32>, %a: vector<16xi8>, + %b: vector<16xi8>) -> vector<4xi32> { + // CHECK: call <4 x i32> @llvm.x86.avx2.vpdpbssd.128( + %0 = x86vector.avx.dot.i8 %w, %a, %b : vector<16xi8> -> vector<4xi32> return %0 : vector<4xi32> } // CHECK-LABEL: define <8 x i32> @LLVM_x86_avx2_vpdpbssd_256 -func.func @LLVM_x86_avx2_vpdpbssd_256(%src: vector<8xi32>, %a: vector<8xi32>, - %b: vector<8xi32>) -> vector<8xi32> { - // CHECK: call <8 x i32> @llvm.x86.avx2.vpdpbssd.256( - %0 = x86vector.avx.dot.i32 %src, %a, %b : vector<8xi32> -> vector<8xi32> +func.func @LLVM_x86_avx2_vpdpbssd_256(%w: vector<8xi32>, %a: vector<32xi8>, + %b: vector<32xi8>) -> vector<8xi32> { + // CHECK: call <8 x i32> @llvm.x86.avx2.vpdpbssd.256( + %0 = x86vector.avx.dot.i8 %w, %a, %b : vector<32xi8> -> vector<8xi32> return %0 : vector<8xi32> } From c902a46f59414d20a824b905f741398698378400 Mon Sep 17 00:00:00 2001 From: Arun Thangamani 
Date: Fri, 11 Jul 2025 00:35:02 -0700 Subject: [PATCH 4/5] clang-format on c++ file --- .../Dialect/X86Vector/IR/X86VectorDialect.cpp | 23 +++++++++++-------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/mlir/lib/Dialect/X86Vector/IR/X86VectorDialect.cpp b/mlir/lib/Dialect/X86Vector/IR/X86VectorDialect.cpp index 7dddf55010800..64d6790306941 100644 --- a/mlir/lib/Dialect/X86Vector/IR/X86VectorDialect.cpp +++ b/mlir/lib/Dialect/X86Vector/IR/X86VectorDialect.cpp @@ -86,22 +86,25 @@ x86vector::DotOp::getIntrinsicOperands(ArrayRef operands, return intrinsicOperands; } -SmallVector -x86vector::DotInt8Op::getIntrinsicOperands(ArrayRef operands, - const LLVMTypeConverter &typeConverter, - RewriterBase &rewriter) { +SmallVector x86vector::DotInt8Op::getIntrinsicOperands( + ArrayRef operands, const LLVMTypeConverter &typeConverter, + RewriterBase &rewriter) { SmallVector intrinsicOprnds; intrinsicOprnds.push_back(operands[0]); - - //Bit-cast `a` and `b` to i32 + // Bitcast `a` and `b` to i32 Value bitcast_a = rewriter.create( - getLoc(), VectorType::get((getA().getType().getShape()[0]/4), rewriter.getIntegerType(32)), - operands[1]); + getLoc(), + VectorType::get((getA().getType().getShape()[0] / 4), + rewriter.getIntegerType(32)), + operands[1]); intrinsicOprnds.push_back(bitcast_a); Value bitcast_b = rewriter.create( - getLoc(), VectorType::get((getA().getType().getShape()[0]/4), rewriter.getIntegerType(32)), - operands[2]); + getLoc(), + VectorType::get((getB().getType().getShape()[0] / 4), + rewriter.getIntegerType(32)), + operands[2]); intrinsicOprnds.push_back(bitcast_b); + return intrinsicOprnds; } From e7ae0717a4d152127de9d139a7fa4bab0d2d4b26 Mon Sep 17 00:00:00 2001 From: Arun Thangamani Date: Fri, 11 Jul 2025 04:05:20 -0700 Subject: [PATCH 5/5] cleanup on operands in cpp --- mlir/lib/Dialect/X86Vector/IR/X86VectorDialect.cpp | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/mlir/lib/Dialect/X86Vector/IR/X86VectorDialect.cpp 
b/mlir/lib/Dialect/X86Vector/IR/X86VectorDialect.cpp index 64d6790306941..68aea48561283 100644 --- a/mlir/lib/Dialect/X86Vector/IR/X86VectorDialect.cpp +++ b/mlir/lib/Dialect/X86Vector/IR/X86VectorDialect.cpp @@ -90,19 +90,20 @@ SmallVector x86vector::DotInt8Op::getIntrinsicOperands( ArrayRef operands, const LLVMTypeConverter &typeConverter, RewriterBase &rewriter) { SmallVector intrinsicOprnds; - intrinsicOprnds.push_back(operands[0]); + Adaptor adaptor(operands, *this); + intrinsicOprnds.push_back(adaptor.getW()); // Bitcast `a` and `b` to i32 Value bitcast_a = rewriter.create( getLoc(), VectorType::get((getA().getType().getShape()[0] / 4), rewriter.getIntegerType(32)), - operands[1]); + adaptor.getA()); intrinsicOprnds.push_back(bitcast_a); Value bitcast_b = rewriter.create( getLoc(), VectorType::get((getB().getType().getShape()[0] / 4), rewriter.getIntegerType(32)), - operands[2]); + adaptor.getB()); intrinsicOprnds.push_back(bitcast_b); return intrinsicOprnds;