diff --git a/mlir/include/mlir/Dialect/X86Vector/X86Vector.td b/mlir/include/mlir/Dialect/X86Vector/X86Vector.td index 3bf0be0a716aa..73f6877c12fab 100644 --- a/mlir/include/mlir/Dialect/X86Vector/X86Vector.td +++ b/mlir/include/mlir/Dialect/X86Vector/X86Vector.td @@ -420,6 +420,62 @@ def DotOp : AVX_LowOp<"dot", [Pure, }]; } +//----------------------------------------------------------------------------// +// AVX Int8 Dot +//----------------------------------------------------------------------------// + +def DotInt8Op : AVX_Op<"dot.i8", [Pure, + X86IntrinsicOpInterface, + AllTypesMatch<["a", "b"]>, + AllTypesMatch<["w", "dst"]>, + TypesMatchWith<"`a` has four times elements as `w`", + "w", "a", + "VectorType::get({::llvm::cast($_self).getShape()[0] * 4}, " + "IntegerType::get($_self.getContext(), 8))"> + ]> { + let summary = "Dot Int8 op"; + let description = [{ + The `dot` op is an AVX2-Int8 specific op that can lower to the proper + LLVMAVX2-INT8 operation `llvm.vpdpbssd` depending on the width of MLIR + vectors it is applied to. + + #### From the Intel Intrinsics Guide: + + Multiply groups of 4 adjacent pairs of signed 8-bit integers in `a` with + corresponding signed 8-bit integers in `b`, producing 4 intermediate signed 16-bit + results. Sum these 4 results with the corresponding 32-bit integer in `w`, and + store the packed 32-bit results in `dst`. + + Example: + ```mlir + %dst = x86vector.avx.dot.i8 %w, %a, %b : vector<32xi8> -> vector<8xi32> + ``` + }]; + let arguments = (ins VectorOfLengthAndType<[4, 8], [I32]>:$w, + VectorOfLengthAndType<[16, 32], [I8]>:$a, + VectorOfLengthAndType<[16, 32], [I8]>:$b + ); + let results = (outs VectorOfLengthAndType<[4, 8], [I32]>:$dst); + let assemblyFormat = + "$w `,` $a `,` $b attr-dict `:` type($a) `->` type($w)"; + + let extraClassDeclaration = [{ + std::string getIntrinsicName() { + std::string intr = "llvm.x86.avx2.vpdpbssd"; + VectorType vecType = getW().getType(); + unsigned elemBitWidth = vecType.getElementTypeBitWidth(); + unsigned opBitWidth = vecType.getShape()[0] * elemBitWidth; + intr += "." + std::to_string(opBitWidth); + return intr; + } + + SmallVector getIntrinsicOperands( + ::mlir::ArrayRef operands, + const ::mlir::LLVMTypeConverter &typeConverter, + ::mlir::RewriterBase &rewriter); + }]; +} + //----------------------------------------------------------------------------// // AVX: Convert BF16/F16 to F32 and broadcast into packed F32 //----------------------------------------------------------------------------// diff --git a/mlir/lib/Dialect/X86Vector/IR/X86VectorDialect.cpp b/mlir/lib/Dialect/X86Vector/IR/X86VectorDialect.cpp index cc7ab7f3f3895..68aea48561283 100644 --- a/mlir/lib/Dialect/X86Vector/IR/X86VectorDialect.cpp +++ b/mlir/lib/Dialect/X86Vector/IR/X86VectorDialect.cpp @@ -86,6 +86,29 @@ x86vector::DotOp::getIntrinsicOperands(ArrayRef operands, return intrinsicOperands; } +SmallVector x86vector::DotInt8Op::getIntrinsicOperands( + ArrayRef operands, const LLVMTypeConverter &typeConverter, + RewriterBase &rewriter) { + SmallVector intrinsicOprnds; + Adaptor adaptor(operands, *this); + intrinsicOprnds.push_back(adaptor.getW()); + // Bitcast `a` and `b` to i32 + Value bitcast_a = rewriter.create( + getLoc(), + VectorType::get((getA().getType().getShape()[0] / 4), + rewriter.getIntegerType(32)), + adaptor.getA()); + intrinsicOprnds.push_back(bitcast_a); + Value bitcast_b = rewriter.create( + getLoc(), + VectorType::get((getB().getType().getShape()[0] / 4), + rewriter.getIntegerType(32)), + adaptor.getB()); + intrinsicOprnds.push_back(bitcast_b); + + return intrinsicOprnds; +} + SmallVector x86vector::BcstToPackedF32Op::getIntrinsicOperands( ArrayRef operands, const LLVMTypeConverter &typeConverter, RewriterBase &rewriter) { diff --git a/mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir b/mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir index 63f06624ef897..72dc899f4f0a6 100644 --- a/mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir +++ b/mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir @@ -219,3 +219,19 @@ func.func @avx_dot(%a: vector<8xf32>, %b: vector<8xf32>) -> (vector<8xf32>) %0 = x86vector.avx.intr.dot %a, %b : vector<8xf32> return %0 : vector<8xf32> } + +// CHECK-LABEL: func @avx_dot_i8_128 +func.func @avx_dot_i8_128(%w: vector<4xi32>, %a: vector<16xi8>, + %b: vector<16xi8>) -> vector<4xi32> { + // CHECK: llvm.call_intrinsic "llvm.x86.avx2.vpdpbssd.128" + %0 = x86vector.avx.dot.i8 %w, %a, %b : vector<16xi8> -> vector<4xi32> + return %0 : vector<4xi32> +} + +// CHECK-LABEL: func @avx_dot_i8_256 +func.func @avx_dot_i8_256(%w: vector<8xi32>, %a: vector<32xi8>, + %b: vector<32xi8>) -> vector<8xi32> { + // CHECK: llvm.call_intrinsic "llvm.x86.avx2.vpdpbssd.256" + %0 = x86vector.avx.dot.i8 %w, %a, %b : vector<32xi8> -> vector<8xi32> + return %0 : vector<8xi32> +} diff --git a/mlir/test/Dialect/X86Vector/roundtrip.mlir b/mlir/test/Dialect/X86Vector/roundtrip.mlir index 7dcab3eb4dcb8..959177b27c7ea 100644 --- a/mlir/test/Dialect/X86Vector/roundtrip.mlir +++ b/mlir/test/Dialect/X86Vector/roundtrip.mlir @@ -229,3 +229,19 @@ func.func @avx_dot(%a: vector<8xf32>, %b: vector<8xf32>) -> (vector<8xf32>) %0 = x86vector.avx.intr.dot %a, %b : vector<8xf32> return %0 : vector<8xf32> } + +// CHECK-LABEL: func @avx_dot_i8_128 +func.func @avx_dot_i8_128(%w: vector<4xi32>, %a: vector<16xi8>, + %b: vector<16xi8>) -> vector<4xi32> { + // CHECK: x86vector.avx.dot.i8 {{.*}} : vector<16xi8> -> vector<4xi32> + %0 = x86vector.avx.dot.i8 %w, %a, %b : vector<16xi8> -> vector<4xi32> + return %0 : vector<4xi32> +} + +// CHECK-LABEL: func @avx_dot_i8_256 +func.func @avx_dot_i8_256(%w: vector<8xi32>, %a: vector<32xi8>, + %b: vector<32xi8>) -> vector<8xi32> { + // CHECK: x86vector.avx.dot.i8 {{.*}} : vector<32xi8> -> vector<8xi32> + %0 = x86vector.avx.dot.i8 %w, %a, %b : vector<32xi8> -> vector<8xi32> + return %0 : vector<8xi32> +} diff --git a/mlir/test/Target/LLVMIR/x86vector.mlir b/mlir/test/Target/LLVMIR/x86vector.mlir index d11dc89bdc7c9..74ae2424964b1 100644 --- a/mlir/test/Target/LLVMIR/x86vector.mlir +++ b/mlir/test/Target/LLVMIR/x86vector.mlir @@ -234,3 +234,19 @@ func.func @LLVM_x86_avx_dp_ps_256( %0 = x86vector.avx.intr.dot %a, %b : vector<8xf32> return %0 : vector<8xf32> } + +// CHECK-LABEL: define <4 x i32> @LLVM_x86_avx2_vpdpbssd_128 +func.func @LLVM_x86_avx2_vpdpbssd_128(%w: vector<4xi32>, %a: vector<16xi8>, + %b: vector<16xi8>) -> vector<4xi32> { + // CHECK: call <4 x i32> @llvm.x86.avx2.vpdpbssd.128( + %0 = x86vector.avx.dot.i8 %w, %a, %b : vector<16xi8> -> vector<4xi32> + return %0 : vector<4xi32> +} + +// CHECK-LABEL: define <8 x i32> @LLVM_x86_avx2_vpdpbssd_256 +func.func @LLVM_x86_avx2_vpdpbssd_256(%w: vector<8xi32>, %a: vector<32xi8>, + %b: vector<32xi8>) -> vector<8xi32> { + // CHECK: call <8 x i32> @llvm.x86.avx2.vpdpbssd.256( + %0 = x86vector.avx.dot.i8 %w, %a, %b : vector<32xi8> -> vector<8xi32> + return %0 : vector<8xi32> +}