Skip to content

[mlir][x86vector] AVX2 i8/i32 Dot Op #147908

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open

Conversation

arun-thmn
Copy link
Contributor

Adds AVX2 i8/i32 dot-product operation and defines lowering to LLVM intrinsics.

Target assembly instruction: vpdpbssd.128/256

@llvmbot
Copy link
Member

llvmbot commented Jul 10, 2025

@llvm/pr-subscribers-mlir-llvm
@llvm/pr-subscribers-mlir-vector

@llvm/pr-subscribers-mlir

Author: None (arun-thmn)

Changes

Adds AVX2 i8/i32 dot-product operation and defines lowering to LLVM intrinsics.

Target assembly instruction: vpdpbssd.128/256


Full diff: https://github.com/llvm/llvm-project/pull/147908.diff

4 Files Affected:

  • (modified) mlir/include/mlir/Dialect/X86Vector/X86Vector.td (+51)
  • (modified) mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir (+16)
  • (modified) mlir/test/Dialect/X86Vector/roundtrip.mlir (+16)
  • (modified) mlir/test/Target/LLVMIR/x86vector.mlir (+16)
diff --git a/mlir/include/mlir/Dialect/X86Vector/X86Vector.td b/mlir/include/mlir/Dialect/X86Vector/X86Vector.td
index 3bf0be0a716aa..688f26211e48d 100644
--- a/mlir/include/mlir/Dialect/X86Vector/X86Vector.td
+++ b/mlir/include/mlir/Dialect/X86Vector/X86Vector.td
@@ -420,6 +420,57 @@ def DotOp : AVX_LowOp<"dot", [Pure,
   }];
 }
 
+//----------------------------------------------------------------------------//
+// AVX Int8 Dot
+//----------------------------------------------------------------------------//
+
+def DotInt8Op : AVX_Op<"dot.i32", [Pure,
+    X86IntrinsicOpInterface,
+    AllTypesMatch<["a", "b"]>,
+    AllTypesMatch<["src", "dst"]>,
+    TypesMatchWith<"`a` has same elements as `src`",
+                   "src", "a",
+                   "VectorType::get({::llvm::cast<VectorType>($_self).getShape()[0]}, "
+                   "IntegerType::get($_self.getContext(), 32))">
+  ]> {
+  let summary = "Dot Int8 op";
+  let description = [{
+    The `dot` op is an AVX2-I32/I8 specific op that can lower to the proper
+    LLVMAVX2-INT8/32 operation `llvm.vpdpbssd` depending on the width of MLIR
+    vectors it is applied to.
+
+    #### From the Intel Intrinsics Guide:
+
+    Multiply groups of 4 adjacent pairs of signed 8-bit integers in `a` with 
+    corresponding signed 8-bit integers in `b`, producing 4 intermediate signed 16-bit 
+    results. Sum these 4 results with the corresponding 32-bit integer in `src`, and 
+    store the packed 32-bit results in `dst`.
+
+    Example:
+    ```mlir
+    %dst = x86vector.avx.dot %src, %a, %b : vector<8xi32> -> vector<8xi32>
+    ```
+  }];
+  let arguments = (ins VectorOfLengthAndType<[4, 8], [I32]>:$src,
+                   VectorOfLengthAndType<[4, 8], [I32]>:$a,
+                   VectorOfLengthAndType<[4, 8], [I32]>:$b
+                   );
+  let results = (outs VectorOfLengthAndType<[4, 8], [I32]>:$dst);
+  let assemblyFormat =
+    "$src `,` $a `,` $b attr-dict `:` type($a) `->` type($src)";
+
+  let extraClassDeclaration = [{
+    std::string getIntrinsicName() {
+      std::string intr = "llvm.x86.avx2.vpdpbssd";
+      VectorType vecType = getSrc().getType();
+      unsigned elemBitWidth = vecType.getElementTypeBitWidth();
+      unsigned opBitWidth = vecType.getShape()[0] * elemBitWidth;
+      intr += "." + std::to_string(opBitWidth);
+      return intr;
+    }
+  }];
+}
+
 //----------------------------------------------------------------------------//
 // AVX: Convert BF16/F16 to F32 and broadcast into packed F32
 //----------------------------------------------------------------------------//
diff --git a/mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir b/mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir
index 63f06624ef897..a70410497acbd 100644
--- a/mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir
+++ b/mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir
@@ -219,3 +219,19 @@ func.func @avx_dot(%a: vector<8xf32>, %b: vector<8xf32>) -> (vector<8xf32>)
   %0 = x86vector.avx.intr.dot %a, %b : vector<8xf32>
   return %0 : vector<8xf32>
 }
+
+// CHECK-LABEL: func @avx_dot_i32_128
+func.func @avx_dot_i32_128(%src: vector<4xi32>, %a: vector<4xi32>,
+    %b: vector<4xi32>) -> vector<4xi32> {
+  // CHECK: llvm.call_intrinsic "llvm.x86.avx2.vpdpbssd.128"
+  %0 = x86vector.avx.dot.i32 %src, %a, %b : vector<4xi32> -> vector<4xi32>
+  return %0 : vector<4xi32>
+}
+
+// CHECK-LABEL: func @avx_dot_i32_256
+func.func @avx_dot_i32_256(%src: vector<8xi32>, %a: vector<8xi32>,
+    %b: vector<8xi32>) -> vector<8xi32> {
+  // CHECK: llvm.call_intrinsic "llvm.x86.avx2.vpdpbssd.256"
+  %0 = x86vector.avx.dot.i32 %src, %a, %b : vector<8xi32> -> vector<8xi32>
+  return %0 : vector<8xi32>
+}
diff --git a/mlir/test/Dialect/X86Vector/roundtrip.mlir b/mlir/test/Dialect/X86Vector/roundtrip.mlir
index 7dcab3eb4dcb8..bd3509eb07b2b 100644
--- a/mlir/test/Dialect/X86Vector/roundtrip.mlir
+++ b/mlir/test/Dialect/X86Vector/roundtrip.mlir
@@ -229,3 +229,19 @@ func.func @avx_dot(%a: vector<8xf32>, %b: vector<8xf32>) -> (vector<8xf32>)
   %0 = x86vector.avx.intr.dot %a, %b : vector<8xf32>
   return %0 : vector<8xf32>
 }
+
+// CHECK-LABEL: func @avx_dot_i32_128
+func.func @avx_dot_i32_128(%src: vector<4xi32>, %a: vector<4xi32>,
+    %b: vector<4xi32>) -> vector<4xi32> {
+  // CHECK: x86vector.avx.dot.i32 {{.*}} : vector<4xi32> -> vector<4xi32>
+  %0 = x86vector.avx.dot.i32 %src, %a, %b : vector<4xi32> -> vector<4xi32>
+  return %0 : vector<4xi32>
+}
+
+// CHECK-LABEL: func @avx_dot_i32_256
+func.func @avx_dot_i32_256(%src: vector<8xi32>, %a: vector<8xi32>,
+    %b: vector<8xi32>) -> vector<8xi32> {
+  // CHECK: x86vector.avx.dot.i32 {{.*}} : vector<8xi32> -> vector<8xi32>
+  %0 = x86vector.avx.dot.i32 %src, %a, %b : vector<8xi32> -> vector<8xi32>
+  return %0 : vector<8xi32>
+}
diff --git a/mlir/test/Target/LLVMIR/x86vector.mlir b/mlir/test/Target/LLVMIR/x86vector.mlir
index d11dc89bdc7c9..ac2f3e1277df3 100644
--- a/mlir/test/Target/LLVMIR/x86vector.mlir
+++ b/mlir/test/Target/LLVMIR/x86vector.mlir
@@ -234,3 +234,19 @@ func.func @LLVM_x86_avx_dp_ps_256(
   %0 = x86vector.avx.intr.dot %a, %b : vector<8xf32>
   return %0 : vector<8xf32>
 }
+
+// CHECK-LABEL: define <4 x i32> @LLVM_x86_avx2_vpdpbssd_128
+func.func @LLVM_x86_avx2_vpdpbssd_128(%src: vector<4xi32>, %a: vector<4xi32>,
+    %b: vector<4xi32>) -> vector<4xi32> {
+    // CHECK: call <4 x i32> @llvm.x86.avx2.vpdpbssd.128(
+  %0 = x86vector.avx.dot.i32 %src, %a, %b : vector<4xi32> -> vector<4xi32>
+  return %0 : vector<4xi32>
+}
+
+// CHECK-LABEL: define <8 x i32> @LLVM_x86_avx2_vpdpbssd_256
+func.func @LLVM_x86_avx2_vpdpbssd_256(%src: vector<8xi32>, %a: vector<8xi32>,
+    %b: vector<8xi32>) -> vector<8xi32> {
+    // CHECK: call <8 x i32> @llvm.x86.avx2.vpdpbssd.256(
+  %0 = x86vector.avx.dot.i32 %src, %a, %b : vector<8xi32> -> vector<8xi32>
+  return %0 : vector<8xi32>
+}

@@ -420,6 +420,57 @@ def DotOp : AVX_LowOp<"dot", [Pure,
}];
}

//----------------------------------------------------------------------------//
// AVX Int8 Dot
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's lean more into its intended i8 * i8 -> i32 semantics and align closer to the dot bf16.
I'd suggest renaming the op to dot.i8 and make the op inputs take appropriately sized vectors of i8.

The inputs can be packed into i32 as a part of getIntrinsicOperands through bitcasts to align with the underlying intrinsic API.


Example:
```mlir
%dst = x86vector.avx.dot %src, %a, %b : vector<8xi32> -> vector<8xi32>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: incomplete op name in the example

%dst = x86vector.avx.dot %src, %a, %b : vector<8xi32> -> vector<8xi32>
```
}];
let arguments = (ins VectorOfLengthAndType<[4, 8], [I32]>:$src,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's call it w to be in line with the intrinsic docs

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants