[mlir][x86vector] AVX2 i8/i32 Dot Op #147908

Open: 2 commits to be merged into base: main
51 changes: 51 additions & 0 deletions mlir/include/mlir/Dialect/X86Vector/X86Vector.td
@@ -420,6 +420,57 @@ def DotOp : AVX_LowOp<"dot", [Pure,
}];
}

//----------------------------------------------------------------------------//
// AVX Int8 Dot
Review comment (Contributor):

Let's lean more into the intended i8 * i8 -> i32 semantics and align more closely with the bf16 dot op.
I'd suggest renaming the op to dot.i8 and making the op inputs take appropriately sized vectors of i8.

The inputs can be packed into i32 as a part of getIntrinsicOperands through bitcasts to align with the underlying intrinsic API.
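
To illustrate the packing this comment describes, here is a minimal C++ sketch of what a bitcast between a vector of i8 and a vector of i32 does to the data: the bytes are reinterpreted, not converted. The helper names `packI8ToI32` and `unpackI32ToI8` are hypothetical and are not part of the PR or of MLIR; the actual change would presumably use a bitcast op inside `getIntrinsicOperands`.

```cpp
#include <array>
#include <cstddef>
#include <cstdint>
#include <cstring>

// Hypothetical helper (not from the PR): reinterpret 4*N signed bytes as
// N 32-bit lanes, the way a vector bitcast from i8 to i32 would.
template <std::size_t N>
std::array<std::int32_t, N>
packI8ToI32(const std::array<std::int8_t, 4 * N> &bytes) {
  std::array<std::int32_t, N> lanes{};
  static_assert(sizeof(lanes) == sizeof(bytes), "bitcast must preserve size");
  // memcpy copies the raw bytes unchanged; no per-element conversion happens.
  std::memcpy(lanes.data(), bytes.data(), sizeof(lanes));
  return lanes;
}

// Hypothetical inverse: view N 32-bit lanes as 4*N signed bytes again.
template <std::size_t N>
std::array<std::int8_t, 4 * N>
unpackI32ToI8(const std::array<std::int32_t, N> &lanes) {
  std::array<std::int8_t, 4 * N> bytes{};
  std::memcpy(bytes.data(), lanes.data(), sizeof(bytes));
  return bytes;
}
```

Under this view, a 128-bit operand is 16 i8 elements or 4 i32 lanes, and a 256-bit operand is 32 i8 elements or 8 i32 lanes, so the op could accept i8 vectors while the intrinsic still receives i32 vectors.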

//----------------------------------------------------------------------------//

def DotInt8Op : AVX_Op<"dot.i32", [Pure,
X86IntrinsicOpInterface,
AllTypesMatch<["a", "b"]>,
AllTypesMatch<["src", "dst"]>,
TypesMatchWith<"`a` has same elements as `src`",
"src", "a",
"VectorType::get({::llvm::cast<VectorType>($_self).getShape()[0]}, "
"IntegerType::get($_self.getContext(), 32))">
]> {
let summary = "Dot Int8 op";
let description = [{
The `dot` op is an AVX2-I32/I8 specific op that can lower to the proper
LLVMAVX2-INT8/32 operation `llvm.vpdpbssd` depending on the width of MLIR
vectors it is applied to.

#### From the Intel Intrinsics Guide:

Multiply groups of 4 adjacent pairs of signed 8-bit integers in `a` with
corresponding signed 8-bit integers in `b`, producing 4 intermediate signed 16-bit
results. Sum these 4 results with the corresponding 32-bit integer in `src`, and
store the packed 32-bit results in `dst`.

Example:
```mlir
%dst = x86vector.avx.dot %src, %a, %b : vector<8xi32> -> vector<8xi32>
Review comment (Contributor):

nit: incomplete op name in the example

```
}];
let arguments = (ins VectorOfLengthAndType<[4, 8], [I32]>:$src,
Review comment (Contributor):

Let's call it w to be in line with the intrinsic docs

VectorOfLengthAndType<[4, 8], [I32]>:$a,
VectorOfLengthAndType<[4, 8], [I32]>:$b
);
let results = (outs VectorOfLengthAndType<[4, 8], [I32]>:$dst);
let assemblyFormat =
"$src `,` $a `,` $b attr-dict `:` type($a) `->` type($src)";

let extraClassDeclaration = [{
std::string getIntrinsicName() {
std::string intr = "llvm.x86.avx2.vpdpbssd";
VectorType vecType = getSrc().getType();
unsigned elemBitWidth = vecType.getElementTypeBitWidth();
unsigned opBitWidth = vecType.getShape()[0] * elemBitWidth;
intr += "." + std::to_string(opBitWidth);
return intr;
}
}];
}
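
As a reading aid for the semantics quoted in the op description, the following is a scalar C++ sketch of the computation: each 32-bit lane of `src` accumulates four signed i8 * i8 products. The function names `dotI8Lane` and `dotI8` are hypothetical, and the sketch assumes the non-saturating behavior of `vpdpbssd` (wrap-around on overflow); it is a model for checking expectations, not the lowering itself.

```cpp
#include <array>
#include <cstddef>
#include <cstdint>

// Hypothetical scalar model of one 32-bit lane: add four signed
// i8 * i8 products to the accumulator.
inline std::int32_t dotI8Lane(std::int32_t acc, const std::int8_t *a,
                              const std::int8_t *b) {
  // Accumulate as unsigned so overflow wraps instead of being undefined.
  // Assumption: the instruction modelled here does not saturate.
  std::uint32_t sum = static_cast<std::uint32_t>(acc);
  for (int k = 0; k < 4; ++k) {
    // Each product lies in [-16256, 16384] and so fits in 16 signed bits.
    std::int32_t prod = std::int32_t(a[k]) * std::int32_t(b[k]);
    sum += static_cast<std::uint32_t>(prod);
  }
  // Modular conversion back to signed (guaranteed since C++20).
  return static_cast<std::int32_t>(sum);
}

// Hypothetical whole-vector model: Lanes is 4 for 128-bit and 8 for 256-bit.
template <std::size_t Lanes>
std::array<std::int32_t, Lanes>
dotI8(const std::array<std::int32_t, Lanes> &src,
      const std::array<std::int8_t, 4 * Lanes> &a,
      const std::array<std::int8_t, 4 * Lanes> &b) {
  std::array<std::int32_t, Lanes> dst{};
  for (std::size_t i = 0; i < Lanes; ++i)
    dst[i] = dotI8Lane(src[i], &a[4 * i], &b[4 * i]);
  return dst;
}
```

For example, with an accumulator of 10 and byte groups {1, 2, 3, 4} and {5, 6, 7, 8}, the lane result is 10 + 5 + 12 + 21 + 32 = 80.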

//----------------------------------------------------------------------------//
// AVX: Convert BF16/F16 to F32 and broadcast into packed F32
//----------------------------------------------------------------------------//
16 changes: 16 additions & 0 deletions mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir
@@ -219,3 +219,19 @@ func.func @avx_dot(%a: vector<8xf32>, %b: vector<8xf32>) -> (vector<8xf32>)
%0 = x86vector.avx.intr.dot %a, %b : vector<8xf32>
return %0 : vector<8xf32>
}

// CHECK-LABEL: func @avx_dot_i32_128
func.func @avx_dot_i32_128(%src: vector<4xi32>, %a: vector<4xi32>,
%b: vector<4xi32>) -> vector<4xi32> {
// CHECK: llvm.call_intrinsic "llvm.x86.avx2.vpdpbssd.128"
%0 = x86vector.avx.dot.i32 %src, %a, %b : vector<4xi32> -> vector<4xi32>
return %0 : vector<4xi32>
}

// CHECK-LABEL: func @avx_dot_i32_256
func.func @avx_dot_i32_256(%src: vector<8xi32>, %a: vector<8xi32>,
%b: vector<8xi32>) -> vector<8xi32> {
// CHECK: llvm.call_intrinsic "llvm.x86.avx2.vpdpbssd.256"
%0 = x86vector.avx.dot.i32 %src, %a, %b : vector<8xi32> -> vector<8xi32>
return %0 : vector<8xi32>
}
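
The intrinsic names checked in these two tests follow from the width computation in `getIntrinsicName()`. A standalone C++ sketch of that computation is shown here; `dotI8IntrinsicName` is a hypothetical free function that takes the lane count directly instead of reading it from a `VectorType`.

```cpp
#include <string>

// Hypothetical mirror of the PR's getIntrinsicName(): the suffix is the
// total operand width in bits, i.e. lane count times element bit width.
std::string dotI8IntrinsicName(unsigned numLanes, unsigned elemBitWidth = 32) {
  std::string intr = "llvm.x86.avx2.vpdpbssd";
  unsigned opBitWidth = numLanes * elemBitWidth;
  intr += "." + std::to_string(opBitWidth);
  return intr;
}
```

With 4 lanes of i32 this yields the `.128` name and with 8 lanes the `.256` name, matching the `CHECK` lines above.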
16 changes: 16 additions & 0 deletions mlir/test/Dialect/X86Vector/roundtrip.mlir
@@ -229,3 +229,19 @@ func.func @avx_dot(%a: vector<8xf32>, %b: vector<8xf32>) -> (vector<8xf32>)
%0 = x86vector.avx.intr.dot %a, %b : vector<8xf32>
return %0 : vector<8xf32>
}

// CHECK-LABEL: func @avx_dot_i32_128
func.func @avx_dot_i32_128(%src: vector<4xi32>, %a: vector<4xi32>,
%b: vector<4xi32>) -> vector<4xi32> {
// CHECK: x86vector.avx.dot.i32 {{.*}} : vector<4xi32> -> vector<4xi32>
%0 = x86vector.avx.dot.i32 %src, %a, %b : vector<4xi32> -> vector<4xi32>
return %0 : vector<4xi32>
}

// CHECK-LABEL: func @avx_dot_i32_256
func.func @avx_dot_i32_256(%src: vector<8xi32>, %a: vector<8xi32>,
%b: vector<8xi32>) -> vector<8xi32> {
// CHECK: x86vector.avx.dot.i32 {{.*}} : vector<8xi32> -> vector<8xi32>
%0 = x86vector.avx.dot.i32 %src, %a, %b : vector<8xi32> -> vector<8xi32>
return %0 : vector<8xi32>
}
16 changes: 16 additions & 0 deletions mlir/test/Target/LLVMIR/x86vector.mlir
@@ -234,3 +234,19 @@ func.func @LLVM_x86_avx_dp_ps_256(
%0 = x86vector.avx.intr.dot %a, %b : vector<8xf32>
return %0 : vector<8xf32>
}

// CHECK-LABEL: define <4 x i32> @LLVM_x86_avx2_vpdpbssd_128
func.func @LLVM_x86_avx2_vpdpbssd_128(%src: vector<4xi32>, %a: vector<4xi32>,
%b: vector<4xi32>) -> vector<4xi32> {
// CHECK: call <4 x i32> @llvm.x86.avx2.vpdpbssd.128(
%0 = x86vector.avx.dot.i32 %src, %a, %b : vector<4xi32> -> vector<4xi32>
return %0 : vector<4xi32>
}

// CHECK-LABEL: define <8 x i32> @LLVM_x86_avx2_vpdpbssd_256
func.func @LLVM_x86_avx2_vpdpbssd_256(%src: vector<8xi32>, %a: vector<8xi32>,
%b: vector<8xi32>) -> vector<8xi32> {
// CHECK: call <8 x i32> @llvm.x86.avx2.vpdpbssd.256(
%0 = x86vector.avx.dot.i32 %src, %a, %b : vector<8xi32> -> vector<8xi32>
return %0 : vector<8xi32>
}