diff --git a/mlir/include/mlir/Dialect/X86Vector/X86Vector.td b/mlir/include/mlir/Dialect/X86Vector/X86Vector.td index 3bf0be0a716aa..688f26211e48d 100644 --- a/mlir/include/mlir/Dialect/X86Vector/X86Vector.td +++ b/mlir/include/mlir/Dialect/X86Vector/X86Vector.td @@ -420,6 +420,57 @@ def DotOp : AVX_LowOp<"dot", [Pure, }]; } +//----------------------------------------------------------------------------// +// AVX Int8 Dot +//----------------------------------------------------------------------------// + +def DotInt8Op : AVX_Op<"dot.i32", [Pure, + X86IntrinsicOpInterface, + AllTypesMatch<["a", "b"]>, + AllTypesMatch<["src", "dst"]>, + TypesMatchWith<"`a` has same elements as `src`", + "src", "a", + "VectorType::get({::llvm::cast($_self).getShape()[0]}, " + "IntegerType::get($_self.getContext(), 32))"> + ]> { + let summary = "Dot Int8 op"; + let description = [{ + The `dot` op is an AVX2-I32/I8 specific op that can lower to the proper + LLVMAVX2-INT8/32 operation `llvm.vpdpbssd` depending on the width of MLIR + vectors it is applied to. + + #### From the Intel Intrinsics Guide: + + Multiply groups of 4 adjacent pairs of signed 8-bit integers in `a` with + corresponding signed 8-bit integers in `b`, producing 4 intermediate signed 16-bit + results. Sum these 4 results with the corresponding 32-bit integer in `src`, and + store the packed 32-bit results in `dst`. + + Example: + ```mlir + %dst = x86vector.avx.dot %src, %a, %b : vector<8xi32> -> vector<8xi32> + ``` + }]; + let arguments = (ins VectorOfLengthAndType<[4, 8], [I32]>:$src, + VectorOfLengthAndType<[4, 8], [I32]>:$a, + VectorOfLengthAndType<[4, 8], [I32]>:$b + ); + let results = (outs VectorOfLengthAndType<[4, 8], [I32]>:$dst); + let assemblyFormat = + "$src `,` $a `,` $b attr-dict `:` type($a) `->` type($src)"; + + let extraClassDeclaration = [{ + std::string getIntrinsicName() { + std::string intr = "llvm.x86.avx2.vpdpbssd"; + VectorType vecType = getSrc().getType(); + unsigned elemBitWidth = vecType.getElementTypeBitWidth(); + unsigned opBitWidth = vecType.getShape()[0] * elemBitWidth; + intr += "." + std::to_string(opBitWidth); + return intr; + } + }]; +} + //----------------------------------------------------------------------------// // AVX: Convert BF16/F16 to F32 and broadcast into packed F32 //----------------------------------------------------------------------------// diff --git a/mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir b/mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir index 63f06624ef897..a70410497acbd 100644 --- a/mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir +++ b/mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir @@ -219,3 +219,19 @@ func.func @avx_dot(%a: vector<8xf32>, %b: vector<8xf32>) -> (vector<8xf32>) %0 = x86vector.avx.intr.dot %a, %b : vector<8xf32> return %0 : vector<8xf32> } + +// CHECK-LABEL: func @avx_dot_i32_128 +func.func @avx_dot_i32_128(%src: vector<4xi32>, %a: vector<4xi32>, + %b: vector<4xi32>) -> vector<4xi32> { + // CHECK: llvm.call_intrinsic "llvm.x86.avx2.vpdpbssd.128" + %0 = x86vector.avx.dot.i32 %src, %a, %b : vector<4xi32> -> vector<4xi32> + return %0 : vector<4xi32> +} + +// CHECK-LABEL: func @avx_dot_i32_256 +func.func @avx_dot_i32_256(%src: vector<8xi32>, %a: vector<8xi32>, + %b: vector<8xi32>) -> vector<8xi32> { + // CHECK: llvm.call_intrinsic "llvm.x86.avx2.vpdpbssd.256" + %0 = x86vector.avx.dot.i32 %src, %a, %b : vector<8xi32> -> vector<8xi32> + return %0 : vector<8xi32> +} diff --git a/mlir/test/Dialect/X86Vector/roundtrip.mlir b/mlir/test/Dialect/X86Vector/roundtrip.mlir index 7dcab3eb4dcb8..bd3509eb07b2b 100644 --- a/mlir/test/Dialect/X86Vector/roundtrip.mlir +++ b/mlir/test/Dialect/X86Vector/roundtrip.mlir @@ -229,3 +229,19 @@ func.func @avx_dot(%a: vector<8xf32>, %b: vector<8xf32>) -> (vector<8xf32>) %0 = x86vector.avx.intr.dot %a, %b : vector<8xf32> return %0 : vector<8xf32> } + +// CHECK-LABEL: func @avx_dot_i32_128 +func.func @avx_dot_i32_128(%src: vector<4xi32>, %a: vector<4xi32>, + %b: vector<4xi32>) -> vector<4xi32> { + // CHECK: x86vector.avx.dot.i32 {{.*}} : vector<4xi32> -> vector<4xi32> + %0 = x86vector.avx.dot.i32 %src, %a, %b : vector<4xi32> -> vector<4xi32> + return %0 : vector<4xi32> +} + +// CHECK-LABEL: func @avx_dot_i32_256 +func.func @avx_dot_i32_256(%src: vector<8xi32>, %a: vector<8xi32>, + %b: vector<8xi32>) -> vector<8xi32> { + // CHECK: x86vector.avx.dot.i32 {{.*}} : vector<8xi32> -> vector<8xi32> + %0 = x86vector.avx.dot.i32 %src, %a, %b : vector<8xi32> -> vector<8xi32> + return %0 : vector<8xi32> +} diff --git a/mlir/test/Target/LLVMIR/x86vector.mlir b/mlir/test/Target/LLVMIR/x86vector.mlir index d11dc89bdc7c9..ac2f3e1277df3 100644 --- a/mlir/test/Target/LLVMIR/x86vector.mlir +++ b/mlir/test/Target/LLVMIR/x86vector.mlir @@ -234,3 +234,19 @@ func.func @LLVM_x86_avx_dp_ps_256( %0 = x86vector.avx.intr.dot %a, %b : vector<8xf32> return %0 : vector<8xf32> } + +// CHECK-LABEL: define <4 x i32> @LLVM_x86_avx2_vpdpbssd_128 +func.func @LLVM_x86_avx2_vpdpbssd_128(%src: vector<4xi32>, %a: vector<4xi32>, + %b: vector<4xi32>) -> vector<4xi32> { + // CHECK: call <4 x i32> @llvm.x86.avx2.vpdpbssd.128( + %0 = x86vector.avx.dot.i32 %src, %a, %b : vector<4xi32> -> vector<4xi32> + return %0 : vector<4xi32> +} + +// CHECK-LABEL: define <8 x i32> @LLVM_x86_avx2_vpdpbssd_256 +func.func @LLVM_x86_avx2_vpdpbssd_256(%src: vector<8xi32>, %a: vector<8xi32>, + %b: vector<8xi32>) -> vector<8xi32> { + // CHECK: call <8 x i32> @llvm.x86.avx2.vpdpbssd.256( + %0 = x86vector.avx.dot.i32 %src, %a, %b : vector<8xi32> -> vector<8xi32> + return %0 : vector<8xi32> +}