diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td index 6895e946b8a45..9ebaac6f0fd80 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td @@ -3009,10 +3009,12 @@ def NVVM_GriddepcontrolLaunchDependentsOp // NVVM Mapa Op //===----------------------------------------------------------------------===// -def NVVM_MapaOp: NVVM_Op<"mapa", - [TypesMatchWith<"`res` and `a` should have the same type", - "a", "res", "$_self">, NVVMRequiresSM<90>]> { - let results = (outs AnyTypeOf<[LLVM_PointerGeneric, LLVM_PointerShared]>:$res); +def NVVM_MapaASCheck : PredOpTrait<"Valid address-space check(or mapping) for mapa Op", + Or<[InputMatchesTypes<["a", "res"], [LLVM_PointerShared, LLVM_PointerSharedCluster]>.predicate, + InputMatchesTypes<["a", "res"], [LLVM_PointerGeneric, LLVM_PointerGeneric]>.predicate]>>; + +def NVVM_MapaOp: NVVM_Op<"mapa", [NVVM_MapaASCheck, NVVMRequiresSM<90>]> { + let results = (outs AnyTypeOf<[LLVM_PointerGeneric, LLVM_PointerSharedCluster]>:$res); let arguments = (ins AnyTypeOf<[LLVM_PointerGeneric, LLVM_PointerShared]>:$a, I32:$b); string llvmBuilder = [{ diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td index 43ef28624fb19..b21603a410c0c 100644 --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -603,6 +603,22 @@ class RangedTypesMatchWith : TypesMatchWith; +// Checks that each inputArg has the same type as the corresponding entry +// in allowedTypes +class InputMatchesTypes inputArgs, list allowedTypes> : + PredOpTrait<"operands {" # !interleave(inputArgs, ", ") # "} match expected types", + !foldl(TruePred, !range(!size(inputArgs)), acc, i, + And<[acc, + SubstLeaves<"$_self", "$" # inputArgs[i] # ".getType()", + allowedTypes[i].predicate> + ]>)> { + assert !eq(!size(inputArgs), !size(allowedTypes)), + "inputArgs and allowedTypes lists must have the same length"; + + list inputArgList = inputArgs; + list allowedTypeList = allowedTypes; +} + // Type Constraint operand `idx`'s Element type is `type`. class TCopVTEtIs : And<[ CPred<"$_op.getNumOperands() > " # idx>, diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir index 251ca716c7a7a..2c1c3071c456e 100644 --- a/mlir/test/Dialect/LLVMIR/invalid.mlir +++ b/mlir/test/Dialect/LLVMIR/invalid.mlir @@ -1201,8 +1201,8 @@ func.func @cp_async(%arg0: !llvm.ptr<3>, %arg1: !llvm.ptr<1>) { // ----- func.func @mapa(%a: !llvm.ptr, %b : i32) { - // expected-error @below {{`res` and `a` should have the same type}} - %0 = nvvm.mapa %a, %b: !llvm.ptr -> !llvm.ptr<3> + // expected-error @below {{'nvvm.mapa' op failed to verify that Valid address-space check(or mapping) for mapa Op}} + %0 = nvvm.mapa %a, %b: !llvm.ptr -> !llvm.ptr<7> return } diff --git a/mlir/test/Dialect/LLVMIR/nvvm.mlir b/mlir/test/Dialect/LLVMIR/nvvm.mlir index c7fa41c98ac92..4349193aa1a45 100644 --- a/mlir/test/Dialect/LLVMIR/nvvm.mlir +++ b/mlir/test/Dialect/LLVMIR/nvvm.mlir @@ -552,7 +552,7 @@ func.func @mapa(%a: !llvm.ptr, %a_shared: !llvm.ptr<3>, %b : i32) { // CHECK: nvvm.mapa %{{.*}} %0 = nvvm.mapa %a, %b: !llvm.ptr -> !llvm.ptr // CHECK: nvvm.mapa %{{.*}} - %1 = nvvm.mapa %a_shared, %b: !llvm.ptr<3> -> !llvm.ptr<3> + %1 = nvvm.mapa %a_shared, %b: !llvm.ptr<3> -> !llvm.ptr<7> return } diff --git a/mlir/test/Target/LLVMIR/nvvmir.mlir b/mlir/test/Target/LLVMIR/nvvmir.mlir index f86a04186f512..c119c1a0fd21f 100644 --- a/mlir/test/Target/LLVMIR/nvvmir.mlir +++ b/mlir/test/Target/LLVMIR/nvvmir.mlir @@ -760,8 +760,8 @@ llvm.func @nvvm_griddepcontrol_launch_dependents() { llvm.func @nvvm_mapa(%a: !llvm.ptr, %a_shared: !llvm.ptr<3>, %b : i32) { // CHECK-LLVM: call ptr @llvm.nvvm.mapa(ptr %{{.*}}, i32 %{{.*}}) %0 = nvvm.mapa %a, %b: !llvm.ptr -> !llvm.ptr - // CHECK-LLVM: call ptr addrspace(3) @llvm.nvvm.mapa.shared.cluster(ptr addrspace(3) %{{.*}}, i32 %{{.*}}) - %1 = nvvm.mapa %a_shared, %b: !llvm.ptr<3> -> !llvm.ptr<3> + // CHECK-LLVM: call ptr addrspace(7) @llvm.nvvm.mapa.shared.cluster(ptr addrspace(3) %{{.*}}, i32 %{{.*}}) + %1 = nvvm.mapa %a_shared, %b: !llvm.ptr<3> -> !llvm.ptr<7> llvm.return }