-
Notifications
You must be signed in to change notification settings - Fork 14.4k
[mlir][x86vector] AVX2 i8/i32 Dot Op #147908
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
@llvm/pr-subscribers-mlir-llvm @llvm/pr-subscribers-mlir Author: None (arun-thmn) ChangesAdds AVX2 i8/i32 dot-product operation and defines lowering to LLVM intrinsics. Target assembly instruction: Full diff: https://github.com/llvm/llvm-project/pull/147908.diff 4 Files Affected:
diff --git a/mlir/include/mlir/Dialect/X86Vector/X86Vector.td b/mlir/include/mlir/Dialect/X86Vector/X86Vector.td
index 3bf0be0a716aa..688f26211e48d 100644
--- a/mlir/include/mlir/Dialect/X86Vector/X86Vector.td
+++ b/mlir/include/mlir/Dialect/X86Vector/X86Vector.td
@@ -420,6 +420,57 @@ def DotOp : AVX_LowOp<"dot", [Pure,
}];
}
+//----------------------------------------------------------------------------//
+// AVX Int8 Dot
+//----------------------------------------------------------------------------//
+
+def DotInt8Op : AVX_Op<"dot.i32", [Pure,
+ X86IntrinsicOpInterface,
+ AllTypesMatch<["a", "b"]>,
+ AllTypesMatch<["src", "dst"]>,
+ TypesMatchWith<"`a` has same elements as `src`",
+ "src", "a",
+ "VectorType::get({::llvm::cast<VectorType>($_self).getShape()[0]}, "
+ "IntegerType::get($_self.getContext(), 32))">
+ ]> {
+ let summary = "Dot Int8 op";
+ let description = [{
+ The `dot` op is an AVX2-I32/I8 specific op that can lower to the proper
+ LLVMAVX2-INT8/32 operation `llvm.vpdpbssd` depending on the width of MLIR
+ vectors it is applied to.
+
+ #### From the Intel Intrinsics Guide:
+
+ Multiply groups of 4 adjacent pairs of signed 8-bit integers in `a` with
+ corresponding signed 8-bit integers in `b`, producing 4 intermediate signed 16-bit
+ results. Sum these 4 results with the corresponding 32-bit integer in `src`, and
+ store the packed 32-bit results in `dst`.
+
+ Example:
+ ```mlir
+ %dst = x86vector.avx.dot %src, %a, %b : vector<8xi32> -> vector<8xi32>
+ ```
+ }];
+ let arguments = (ins VectorOfLengthAndType<[4, 8], [I32]>:$src,
+ VectorOfLengthAndType<[4, 8], [I32]>:$a,
+ VectorOfLengthAndType<[4, 8], [I32]>:$b
+ );
+ let results = (outs VectorOfLengthAndType<[4, 8], [I32]>:$dst);
+ let assemblyFormat =
+ "$src `,` $a `,` $b attr-dict `:` type($a) `->` type($src)";
+
+ let extraClassDeclaration = [{
+ std::string getIntrinsicName() {
+ std::string intr = "llvm.x86.avx2.vpdpbssd";
+ VectorType vecType = getSrc().getType();
+ unsigned elemBitWidth = vecType.getElementTypeBitWidth();
+ unsigned opBitWidth = vecType.getShape()[0] * elemBitWidth;
+ intr += "." + std::to_string(opBitWidth);
+ return intr;
+ }
+ }];
+}
+
//----------------------------------------------------------------------------//
// AVX: Convert BF16/F16 to F32 and broadcast into packed F32
//----------------------------------------------------------------------------//
diff --git a/mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir b/mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir
index 63f06624ef897..a70410497acbd 100644
--- a/mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir
+++ b/mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir
@@ -219,3 +219,19 @@ func.func @avx_dot(%a: vector<8xf32>, %b: vector<8xf32>) -> (vector<8xf32>)
%0 = x86vector.avx.intr.dot %a, %b : vector<8xf32>
return %0 : vector<8xf32>
}
+
+// CHECK-LABEL: func @avx_dot_i32_128
+func.func @avx_dot_i32_128(%src: vector<4xi32>, %a: vector<4xi32>,
+ %b: vector<4xi32>) -> vector<4xi32> {
+ // CHECK: llvm.call_intrinsic "llvm.x86.avx2.vpdpbssd.128"
+ %0 = x86vector.avx.dot.i32 %src, %a, %b : vector<4xi32> -> vector<4xi32>
+ return %0 : vector<4xi32>
+}
+
+// CHECK-LABEL: func @avx_dot_i32_256
+func.func @avx_dot_i32_256(%src: vector<8xi32>, %a: vector<8xi32>,
+ %b: vector<8xi32>) -> vector<8xi32> {
+ // CHECK: llvm.call_intrinsic "llvm.x86.avx2.vpdpbssd.256"
+ %0 = x86vector.avx.dot.i32 %src, %a, %b : vector<8xi32> -> vector<8xi32>
+ return %0 : vector<8xi32>
+}
diff --git a/mlir/test/Dialect/X86Vector/roundtrip.mlir b/mlir/test/Dialect/X86Vector/roundtrip.mlir
index 7dcab3eb4dcb8..bd3509eb07b2b 100644
--- a/mlir/test/Dialect/X86Vector/roundtrip.mlir
+++ b/mlir/test/Dialect/X86Vector/roundtrip.mlir
@@ -229,3 +229,19 @@ func.func @avx_dot(%a: vector<8xf32>, %b: vector<8xf32>) -> (vector<8xf32>)
%0 = x86vector.avx.intr.dot %a, %b : vector<8xf32>
return %0 : vector<8xf32>
}
+
+// CHECK-LABEL: func @avx_dot_i32_128
+func.func @avx_dot_i32_128(%src: vector<4xi32>, %a: vector<4xi32>,
+ %b: vector<4xi32>) -> vector<4xi32> {
+ // CHECK: x86vector.avx.dot.i32 {{.*}} : vector<4xi32> -> vector<4xi32>
+ %0 = x86vector.avx.dot.i32 %src, %a, %b : vector<4xi32> -> vector<4xi32>
+ return %0 : vector<4xi32>
+}
+
+// CHECK-LABEL: func @avx_dot_i32_256
+func.func @avx_dot_i32_256(%src: vector<8xi32>, %a: vector<8xi32>,
+ %b: vector<8xi32>) -> vector<8xi32> {
+ // CHECK: x86vector.avx.dot.i32 {{.*}} : vector<8xi32> -> vector<8xi32>
+ %0 = x86vector.avx.dot.i32 %src, %a, %b : vector<8xi32> -> vector<8xi32>
+ return %0 : vector<8xi32>
+}
diff --git a/mlir/test/Target/LLVMIR/x86vector.mlir b/mlir/test/Target/LLVMIR/x86vector.mlir
index d11dc89bdc7c9..ac2f3e1277df3 100644
--- a/mlir/test/Target/LLVMIR/x86vector.mlir
+++ b/mlir/test/Target/LLVMIR/x86vector.mlir
@@ -234,3 +234,19 @@ func.func @LLVM_x86_avx_dp_ps_256(
%0 = x86vector.avx.intr.dot %a, %b : vector<8xf32>
return %0 : vector<8xf32>
}
+
+// CHECK-LABEL: define <4 x i32> @LLVM_x86_avx2_vpdpbssd_128
+func.func @LLVM_x86_avx2_vpdpbssd_128(%src: vector<4xi32>, %a: vector<4xi32>,
+ %b: vector<4xi32>) -> vector<4xi32> {
+ // CHECK: call <4 x i32> @llvm.x86.avx2.vpdpbssd.128(
+ %0 = x86vector.avx.dot.i32 %src, %a, %b : vector<4xi32> -> vector<4xi32>
+ return %0 : vector<4xi32>
+}
+
+// CHECK-LABEL: define <8 x i32> @LLVM_x86_avx2_vpdpbssd_256
+func.func @LLVM_x86_avx2_vpdpbssd_256(%src: vector<8xi32>, %a: vector<8xi32>,
+ %b: vector<8xi32>) -> vector<8xi32> {
+ // CHECK: call <8 x i32> @llvm.x86.avx2.vpdpbssd.256(
+ %0 = x86vector.avx.dot.i32 %src, %a, %b : vector<8xi32> -> vector<8xi32>
+ return %0 : vector<8xi32>
+}
|
@@ -420,6 +420,57 @@ def DotOp : AVX_LowOp<"dot", [Pure, | |||
}]; | |||
} | |||
|
|||
//----------------------------------------------------------------------------// | |||
// AVX Int8 Dot |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let's lean more into its intended i8 * i8 -> i32
semantics and align closer to the dot bf16.
I'd suggest renaming the op to dot.i8
and make the op inputs take appropriately sized vectors of i8
.
The inputs can be packed into i32
as a part of getIntrinsicOperands
through bitcasts to align with the underlying intrinsic API.
|
||
Example: | ||
```mlir | ||
%dst = x86vector.avx.dot %src, %a, %b : vector<8xi32> -> vector<8xi32> |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: incomplete op name in the example
%dst = x86vector.avx.dot %src, %a, %b : vector<8xi32> -> vector<8xi32> | ||
``` | ||
}]; | ||
let arguments = (ins VectorOfLengthAndType<[4, 8], [I32]>:$src, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let's call it w
to be in line with the intrinsic docs
Adds AVX2 i8/i32 dot-product operation and defines lowering to LLVM intrinsics.
Target assembly instruction:
vpdpbssd.128/256