Skip to content

Commit 4e4273c

Browse files
authored
[MLIR][NVVM] Add dot.accumulate.2way Op (#140518)
This change adds the `dot.accumulate.2way` Op to the NVVM dialect for 16-bit to 8-bit dot-product accumulate operation. PTX Spec Reference: https://docs.nvidia.com/cuda/parallel-thread-execution/#integer-arithmetic-instructions-dp2a
1 parent 11a9dad commit 4e4273c

File tree

4 files changed

+123
-0
lines changed

4 files changed

+123
-0
lines changed

mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3703,6 +3703,60 @@ def NVVM_DotAccumulate4WayOp : NVVM_Op<"dot.accumulate.4way"> {
37033703
}];
37043704
}
37053705

3706+
def NVVM_DotAccumulate2WayOp : NVVM_Op<"dot.accumulate.2way"> {
3707+
let summary = "Two-way 16-bit to 8-bit dot product-accumulate instruction";
3708+
let description = [{
3709+
Performs a two-way 16-bit to 8-bit dot-product which is accumulated in a
3710+
32-bit result.
3711+
Operand `a` is a vector of two 16-bit elements and operand `b` a vector
3712+
of four 8-bit elements between which the dot product is computed.
3713+
3714+
The `a_type` and `b_type` attributes specify the type of the elements in `a`
3715+
and `b` respectively.
3716+
If `a_type` or `b_type` is `s`, then the elements in the corresponding
3717+
vector are sign-extended to 32-bit before the dot product is computed.
3718+
If `a_type` or `b_type` is `u`, then the elements in the corresponding
3719+
vector are zero-extended to 32-bit instead.
3720+
3721+
The `b_hi` boolean attribute specifies which two bytes of `b` are used for
3722+
the dot product. If `b_hi` is true, then the dot product is computed
3723+
between `a` and elements at indices 2 and 3 of `b`. If `b_hi` is false,
3724+
then the dot product is computed between `a` and elements at indices 0 and
3725+
1 of `b`.
3726+
3727+
Operand `c` is a 32-bit integer to which the result is accumulated. It is
3728+
treated as holding a signed integer if any of `a_type` or `b_type` is
3729+
signed.
3730+
3731+
[For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#integer-arithmetic-instructions-dp2a)
3732+
}];
3733+
3734+
let arguments = (ins
3735+
VectorOfLengthAndType<[2], [I16]>:$a,
3736+
DotAccumulateTypeAttr:$a_type,
3737+
VectorOfLengthAndType<[4], [I8]>:$b,
3738+
DotAccumulateTypeAttr:$b_type,
3739+
I32:$c,
3740+
BoolAttr:$b_hi
3741+
);
3742+
3743+
let results = (outs I32:$res);
3744+
3745+
let assemblyFormat = "$a $a_type `,` $b $b_type `,` $c attr-dict `:` type($a) `,` type($b)";
3746+
3747+
let extraClassDeclaration = [{
3748+
static mlir::NVVM::IDArgPair
3749+
getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
3750+
llvm::IRBuilderBase &builder);
3751+
}];
3752+
3753+
string llvmBuilder = [{
3754+
auto [id, args] = NVVM::DotAccumulate2WayOp::getIntrinsicIDAndArgs(
3755+
*op, moduleTranslation, builder);
3756+
$res = createIntrinsicCall(builder, id, args);
3757+
}];
3758+
}
3759+
37063760
//===----------------------------------------------------------------------===//
37073761
// NVVM target attribute.
37083762
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1712,6 +1712,28 @@ NVVM::IDArgPair DotAccumulate4WayOp::getIntrinsicIDAndArgs(
17121712
return {ids[type], args};
17131713
}
17141714

1715+
NVVM::IDArgPair DotAccumulate2WayOp::getIntrinsicIDAndArgs(
1716+
Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
1717+
auto curOp = cast<NVVM::DotAccumulate2WayOp>(op);
1718+
1719+
llvm::SmallVector<llvm::Value *> args;
1720+
args.push_back(getAsPackedI32(mt.lookupValue(curOp.getA()), builder));
1721+
args.push_back(getAsPackedI32(mt.lookupValue(curOp.getB()), builder));
1722+
args.push_back(builder.getInt1(curOp.getBHi()));
1723+
args.push_back(mt.lookupValue(curOp.getC()));
1724+
1725+
bool isASigned = curOp.getAType() == NVVM::DotAccumulateType::SIGNED;
1726+
bool isBSigned = curOp.getBType() == NVVM::DotAccumulateType::SIGNED;
1727+
unsigned type = (isASigned << 1) | isBSigned;
1728+
const llvm::Intrinsic::ID ids[] = {
1729+
llvm::Intrinsic::nvvm_idp2a_u_u,
1730+
llvm::Intrinsic::nvvm_idp2a_u_s,
1731+
llvm::Intrinsic::nvvm_idp2a_s_u,
1732+
llvm::Intrinsic::nvvm_idp2a_s_s,
1733+
};
1734+
return {ids[type], args};
1735+
}
1736+
17151737
//===----------------------------------------------------------------------===//
17161738
// NVVMDialect initialization, type parsing, and registration.
17171739
//===----------------------------------------------------------------------===//

mlir/test/Dialect/LLVMIR/nvvm.mlir

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -587,6 +587,15 @@ func.func @dot_accumulate_4way(%a_vec: vector<4xi8>, %b_vec: vector<4xi8>, %c: i
587587
return
588588
}
589589

590+
// CHECK-LABEL: @dot_accumulate_2way
591+
func.func @dot_accumulate_2way(%a_vec: vector<2xi16>, %b_vec: vector<4xi8>, %c: i32) {
592+
// CHECK: nvvm.dot.accumulate.2way %{{.*}}, %{{.*}}, %{{.*}} {b_hi = false} : vector<2xi16>, vector<4xi8>
593+
%1 = nvvm.dot.accumulate.2way %a_vec <unsigned>, %b_vec <unsigned>, %c {b_hi = false}: vector<2xi16>, vector<4xi8>
594+
// CHECK: nvvm.dot.accumulate.2way %{{.*}}, %{{.*}}, %{{.*}} {b_hi = true} : vector<2xi16>, vector<4xi8>
595+
%3 = nvvm.dot.accumulate.2way %a_vec <signed>, %b_vec <signed>, %c {b_hi = true}: vector<2xi16>, vector<4xi8>
596+
return
597+
}
598+
590599
// -----
591600

592601
// Just check these don't emit errors.

mlir/test/Target/LLVMIR/nvvmir.mlir

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -866,3 +866,41 @@ llvm.func @nvvm_dot_accumulate_4way(%a: vector<4xi8>, %b: vector<4xi8>, %c: i32)
866866
%3 = nvvm.dot.accumulate.4way %a <signed>, %b <signed>, %c: vector<4xi8>, vector<4xi8>
867867
llvm.return
868868
}
869+
870+
// -----
871+
// CHECK-LABEL: @nvvm_dot_accumulate_2way
872+
llvm.func @nvvm_dot_accumulate_2way(%a: vector<2xi16>, %b: vector<4xi8>, %c: i32) {
873+
// CHECK: %[[a_cast:.*]] = bitcast <2 x i16> %{{.*}} to i32
874+
// CHECK: %[[b_cast:.*]] = bitcast <4 x i8> %{{.*}} to i32
875+
// CHECK: call i32 @llvm.nvvm.idp2a.u.u(i32 %[[a_cast]], i32 %[[b_cast]], i1 false, i32 %{{.*}})
876+
%0 = nvvm.dot.accumulate.2way %a <unsigned>, %b <unsigned>, %c {b_hi = false} : vector<2xi16>, vector<4xi8>
877+
// CHECK: %[[a_cast:.*]] = bitcast <2 x i16> %{{.*}} to i32
878+
// CHECK: %[[b_cast:.*]] = bitcast <4 x i8> %{{.*}} to i32
879+
// CHECK: call i32 @llvm.nvvm.idp2a.u.u(i32 %[[a_cast]], i32 %[[b_cast]], i1 true, i32 %{{.*}})
880+
%1 = nvvm.dot.accumulate.2way %a <unsigned>, %b <unsigned>, %c {b_hi = true}: vector<2xi16>, vector<4xi8>
881+
// CHECK: %[[a_cast:.*]] = bitcast <2 x i16> %{{.*}} to i32
882+
// CHECK: %[[b_cast:.*]] = bitcast <4 x i8> %{{.*}} to i32
883+
// CHECK: call i32 @llvm.nvvm.idp2a.s.u(i32 %[[a_cast]], i32 %[[b_cast]], i1 false, i32 %{{.*}})
884+
%2 = nvvm.dot.accumulate.2way %a <signed>, %b <unsigned>, %c {b_hi = false}: vector<2xi16>, vector<4xi8>
885+
// CHECK: %[[a_cast:.*]] = bitcast <2 x i16> %{{.*}} to i32
886+
// CHECK: %[[b_cast:.*]] = bitcast <4 x i8> %{{.*}} to i32
887+
// CHECK: call i32 @llvm.nvvm.idp2a.s.u(i32 %[[a_cast]], i32 %[[b_cast]], i1 true, i32 %{{.*}})
888+
%3 = nvvm.dot.accumulate.2way %a <signed>, %b <unsigned>, %c {b_hi = true}: vector<2xi16>, vector<4xi8>
889+
// CHECK: %[[a_cast:.*]] = bitcast <2 x i16> %{{.*}} to i32
890+
// CHECK: %[[b_cast:.*]] = bitcast <4 x i8> %{{.*}} to i32
891+
// CHECK: call i32 @llvm.nvvm.idp2a.u.s(i32 %[[a_cast]], i32 %[[b_cast]], i1 false, i32 %{{.*}})
892+
%4 = nvvm.dot.accumulate.2way %a <unsigned>, %b <signed>, %c {b_hi = false}: vector<2xi16>, vector<4xi8>
893+
// CHECK: %[[a_cast:.*]] = bitcast <2 x i16> %{{.*}} to i32
894+
// CHECK: %[[b_cast:.*]] = bitcast <4 x i8> %{{.*}} to i32
895+
// CHECK: call i32 @llvm.nvvm.idp2a.u.s(i32 %[[a_cast]], i32 %[[b_cast]], i1 true, i32 %{{.*}})
896+
%5 = nvvm.dot.accumulate.2way %a <unsigned>, %b <signed>, %c {b_hi = true}: vector<2xi16>, vector<4xi8>
897+
// CHECK: %[[a_cast:.*]] = bitcast <2 x i16> %{{.*}} to i32
898+
// CHECK: %[[b_cast:.*]] = bitcast <4 x i8> %{{.*}} to i32
899+
// CHECK: call i32 @llvm.nvvm.idp2a.s.s(i32 %[[a_cast]], i32 %[[b_cast]], i1 false, i32 %{{.*}})
900+
%6 = nvvm.dot.accumulate.2way %a <signed>, %b <signed>, %c {b_hi = false}: vector<2xi16>, vector<4xi8>
901+
// CHECK: %[[a_cast:.*]] = bitcast <2 x i16> %{{.*}} to i32
902+
// CHECK: %[[b_cast:.*]] = bitcast <4 x i8> %{{.*}} to i32
903+
// CHECK: call i32 @llvm.nvvm.idp2a.s.s(i32 %[[a_cast]], i32 %[[b_cast]], i1 true, i32 %{{.*}})
904+
%7 = nvvm.dot.accumulate.2way %a <signed>, %b <signed>, %c {b_hi = true}: vector<2xi16>, vector<4xi8>
905+
llvm.return
906+
}

0 commit comments

Comments
 (0)