Skip to content

Commit 587ba75

Browse files
authored
[mlir][x86vector] AVX2 I8 Dot Op (#147908)
Adds AVX2 i8 dot-product operation and defines lowering to LLVM intrinsics. Target assembly instruction: `vpdpbssd.128/256`
1 parent 6630cde commit 587ba75

File tree

5 files changed

+127
-0
lines changed

5 files changed

+127
-0
lines changed

mlir/include/mlir/Dialect/X86Vector/X86Vector.td

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -420,6 +420,62 @@ def DotOp : AVX_LowOp<"dot", [Pure,
420420
}];
421421
}
422422

423+
//----------------------------------------------------------------------------//
424+
// AVX Int8 Dot
425+
//----------------------------------------------------------------------------//
426+
427+
// Signed 8-bit integer dot-product with 32-bit accumulation.
// Type constraints:
//   * `a` and `b` share one type (vector<16xi8> or vector<32xi8>);
//   * `w` and `dst` share one type (vector<4xi32> or vector<8xi32>);
//   * `a` holds exactly four i8 elements for every i32 element of `w`.
def DotInt8Op : AVX_Op<"dot.i8", [Pure,
    X86IntrinsicOpInterface,
    AllTypesMatch<["a", "b"]>,
    AllTypesMatch<["w", "dst"]>,
    TypesMatchWith<"`a` has four times elements as `w`",
                   "w", "a",
                   "VectorType::get({::llvm::cast<VectorType>($_self).getShape()[0] * 4}, "
                   "IntegerType::get($_self.getContext(), 8))">
  ]> {
  let summary = "Dot Int8 op";
  let description = [{
    The `dot.i8` op is an AVX2-INT8 specific op that lowers to the LLVM
    intrinsic `llvm.x86.avx2.vpdpbssd.128` or `llvm.x86.avx2.vpdpbssd.256`,
    depending on the width of the MLIR vectors it is applied to.

    The accumulator `w` and the result `dst` have the same type
    (`vector<4xi32>` or `vector<8xi32>`). The inputs `a` and `b` have the same
    type and hold four `i8` elements for every `i32` element of `w`
    (`vector<16xi8>` or `vector<32xi8>`).

    #### From the Intel Intrinsics Guide:

    Multiply groups of 4 adjacent pairs of signed 8-bit integers in `a` with
    corresponding signed 8-bit integers in `b`, producing 4 intermediate signed 16-bit
    results. Sum these 4 results with the corresponding 32-bit integer in `w`, and
    store the packed 32-bit results in `dst`.

    Example:
    ```mlir
    %dst = x86vector.avx.dot.i8 %w, %a, %b : vector<32xi8> -> vector<8xi32>
    ```
  }];
  let arguments = (ins VectorOfLengthAndType<[4, 8], [I32]>:$w,
                       VectorOfLengthAndType<[16, 32], [I8]>:$a,
                       VectorOfLengthAndType<[16, 32], [I8]>:$b
                       );
  let results = (outs VectorOfLengthAndType<[4, 8], [I32]>:$dst);
  // Only the types of `a` and `w` are spelled out; `b` and `dst` follow from
  // the AllTypesMatch constraints above.
  let assemblyFormat =
    "$w `,` $a `,` $b attr-dict `:` type($a) `->` type($w)";

  let extraClassDeclaration = [{
    // Returns the name of the LLVM intrinsic to call. The suffix is the total
    // width of the accumulator vector in bits (number of i32 elements times
    // the element bit width), i.e. 128 or 256 for the allowed types of `w`.
    std::string getIntrinsicName() {
      VectorType vecType = getW().getType();
      unsigned elemBitWidth = vecType.getElementTypeBitWidth();
      unsigned opBitWidth = vecType.getShape()[0] * elemBitWidth;
      return "llvm.x86.avx2.vpdpbssd." + std::to_string(opBitWidth);
    }

    // Builds the operand list passed to the intrinsic call (defined in
    // X86VectorDialect.cpp).
    SmallVector<Value> getIntrinsicOperands(
        ::mlir::ArrayRef<Value> operands,
        const ::mlir::LLVMTypeConverter &typeConverter,
        ::mlir::RewriterBase &rewriter);
  }];
}
478+
423479
//----------------------------------------------------------------------------//
424480
// AVX: Convert BF16/F16 to F32 and broadcast into packed F32
425481
//----------------------------------------------------------------------------//

mlir/lib/Dialect/X86Vector/IR/X86VectorDialect.cpp

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,29 @@ x86vector::DotOp::getIntrinsicOperands(ArrayRef<Value> operands,
8686
return intrinsicOperands;
8787
}
8888

89+
/// Assembles the operands for the `vpdpbssd` intrinsic call: the accumulator
/// `w` is forwarded unchanged, while `a` and `b` are each bitcast to a vector
/// of i32 holding a quarter as many elements as the original i8 vector.
SmallVector<Value> x86vector::DotInt8Op::getIntrinsicOperands(
    ArrayRef<Value> operands, const LLVMTypeConverter &typeConverter,
    RewriterBase &rewriter) {
  Adaptor adaptor(operands, *this);

  // Emits a bitcast of `packed` to a vector of i32 whose length is the
  // leading dimension of `sourceType` divided by four.
  auto castToI32Lanes = [&](VectorType sourceType, Value packed) -> Value {
    auto laneType = VectorType::get(sourceType.getShape()[0] / 4,
                                    rewriter.getIntegerType(32));
    return rewriter.create<LLVM::BitcastOp>(getLoc(), laneType, packed);
  };

  SmallVector<Value, 3> callOperands;
  callOperands.push_back(adaptor.getW());
  // The bitcast for `a` is created before the one for `b`.
  callOperands.push_back(castToI32Lanes(getA().getType(), adaptor.getA()));
  callOperands.push_back(castToI32Lanes(getB().getType(), adaptor.getB()));
  return callOperands;
}
111+
89112
SmallVector<Value> x86vector::BcstToPackedF32Op::getIntrinsicOperands(
90113
ArrayRef<Value> operands, const LLVMTypeConverter &typeConverter,
91114
RewriterBase &rewriter) {

mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -219,3 +219,19 @@ func.func @avx_dot(%a: vector<8xf32>, %b: vector<8xf32>) -> (vector<8xf32>)
219219
%0 = x86vector.avx.intr.dot %a, %b : vector<8xf32>
220220
return %0 : vector<8xf32>
221221
}
222+
223+
// 128-bit variant: 4 x i32 accumulator with 16 x i8 inputs.
// CHECK-LABEL: func @avx_dot_i8_128
func.func @avx_dot_i8_128(%acc: vector<4xi32>, %lhs: vector<16xi8>,
    %rhs: vector<16xi8>) -> vector<4xi32> {
  // CHECK: llvm.call_intrinsic "llvm.x86.avx2.vpdpbssd.128"
  %res = x86vector.avx.dot.i8 %acc, %lhs, %rhs : vector<16xi8> -> vector<4xi32>
  return %res : vector<4xi32>
}
230+
231+
// 256-bit variant: 8 x i32 accumulator with 32 x i8 inputs.
// CHECK-LABEL: func @avx_dot_i8_256
func.func @avx_dot_i8_256(%acc: vector<8xi32>, %lhs: vector<32xi8>,
    %rhs: vector<32xi8>) -> vector<8xi32> {
  // CHECK: llvm.call_intrinsic "llvm.x86.avx2.vpdpbssd.256"
  %res = x86vector.avx.dot.i8 %acc, %lhs, %rhs : vector<32xi8> -> vector<8xi32>
  return %res : vector<8xi32>
}

mlir/test/Dialect/X86Vector/roundtrip.mlir

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -229,3 +229,19 @@ func.func @avx_dot(%a: vector<8xf32>, %b: vector<8xf32>) -> (vector<8xf32>)
229229
%0 = x86vector.avx.intr.dot %a, %b : vector<8xf32>
230230
return %0 : vector<8xf32>
231231
}
232+
233+
// Round-trip of the 128-bit form: 4 x i32 accumulator with 16 x i8 inputs.
// CHECK-LABEL: func @avx_dot_i8_128
func.func @avx_dot_i8_128(%acc: vector<4xi32>, %lhs: vector<16xi8>,
    %rhs: vector<16xi8>) -> vector<4xi32> {
  // CHECK: x86vector.avx.dot.i8 {{.*}} : vector<16xi8> -> vector<4xi32>
  %res = x86vector.avx.dot.i8 %acc, %lhs, %rhs : vector<16xi8> -> vector<4xi32>
  return %res : vector<4xi32>
}
240+
241+
// Round-trip of the 256-bit form: 8 x i32 accumulator with 32 x i8 inputs.
// CHECK-LABEL: func @avx_dot_i8_256
func.func @avx_dot_i8_256(%acc: vector<8xi32>, %lhs: vector<32xi8>,
    %rhs: vector<32xi8>) -> vector<8xi32> {
  // CHECK: x86vector.avx.dot.i8 {{.*}} : vector<32xi8> -> vector<8xi32>
  %res = x86vector.avx.dot.i8 %acc, %lhs, %rhs : vector<32xi8> -> vector<8xi32>
  return %res : vector<8xi32>
}

mlir/test/Target/LLVMIR/x86vector.mlir

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -234,3 +234,19 @@ func.func @LLVM_x86_avx_dp_ps_256(
234234
%0 = x86vector.avx.intr.dot %a, %b : vector<8xf32>
235235
return %0 : vector<8xf32>
236236
}
237+
238+
// Translation to LLVM IR of the 128-bit form (4 x i32 result).
// CHECK-LABEL: define <4 x i32> @LLVM_x86_avx2_vpdpbssd_128
func.func @LLVM_x86_avx2_vpdpbssd_128(%acc: vector<4xi32>, %lhs: vector<16xi8>,
    %rhs: vector<16xi8>) -> vector<4xi32> {
  // CHECK: call <4 x i32> @llvm.x86.avx2.vpdpbssd.128(
  %res = x86vector.avx.dot.i8 %acc, %lhs, %rhs : vector<16xi8> -> vector<4xi32>
  return %res : vector<4xi32>
}
245+
246+
// Translation to LLVM IR of the 256-bit form (8 x i32 result).
// CHECK-LABEL: define <8 x i32> @LLVM_x86_avx2_vpdpbssd_256
func.func @LLVM_x86_avx2_vpdpbssd_256(%acc: vector<8xi32>, %lhs: vector<32xi8>,
    %rhs: vector<32xi8>) -> vector<8xi32> {
  // CHECK: call <8 x i32> @llvm.x86.avx2.vpdpbssd.256(
  %res = x86vector.avx.dot.i8 %acc, %lhs, %rhs : vector<32xi8> -> vector<8xi32>
  return %res : vector<8xi32>
}

0 commit comments

Comments
 (0)