Skip to content

Commit 088fddf

Browse files
committed
make generic helper and update usage to it
1 parent 65552af commit 088fddf

File tree

2 files changed

+21
-22
lines changed

2 files changed

+21
-22
lines changed

mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td

Lines changed: 5 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -3009,28 +3009,11 @@ def NVVM_GriddepcontrolLaunchDependentsOp
30093009
// NVVM Mapa Op
30103010
//===----------------------------------------------------------------------===//
30113011

3012-
// Helper predicates for address space checking
3013-
def IsGenericAddressSpace : CPred<"llvm::cast<LLVM::LLVMPointerType>($_self).getAddressSpace() == 0">;
3014-
def IsSharedAddressSpace : CPred<"llvm::cast<LLVM::LLVMPointerType>($_self).getAddressSpace() == 3">;
3015-
def IsSharedClusterAddressSpace : CPred<"llvm::cast<LLVM::LLVMPointerType>($_self).getAddressSpace() == 7">;
3016-
3017-
class NVVM_MapaASCheck<string inputArg, string resultArg> :
3018-
PredOpTrait<"Valid address-space check(or mapping) for mapa Op",
3019-
Or<[
3020-
// Generic -> Generic
3021-
And<[
3022-
SubstLeaves<"$_self", "$" # inputArg # ".getType()", IsGenericAddressSpace>,
3023-
SubstLeaves<"$_self", "$" # resultArg # ".getType()", IsGenericAddressSpace>
3024-
]>,
3025-
// Shared -> SharedCluster
3026-
And<[
3027-
SubstLeaves<"$_self", "$" # inputArg # ".getType()", IsSharedAddressSpace>,
3028-
SubstLeaves<"$_self", "$" # resultArg # ".getType()", IsSharedClusterAddressSpace>
3029-
]>
3030-
]>>;
3031-
3032-
def NVVM_MapaOp: NVVM_Op<"mapa",
3033-
[NVVM_MapaASCheck<"a", "res">, NVVMRequiresSM<90>]> {
3012+
def NVVM_MapaASCheck : PredOpTrait<"Valid address-space check(or mapping) for mapa Op",
3013+
Or<[InputMatchesTypes<["a", "res"], [LLVM_PointerShared, LLVM_PointerSharedCluster]>.predicate,
3014+
InputMatchesTypes<["a", "res"], [LLVM_PointerGeneric, LLVM_PointerGeneric]>.predicate]>>;
3015+
3016+
def NVVM_MapaOp: NVVM_Op<"mapa", [NVVM_MapaASCheck, NVVMRequiresSM<90>]> {
30343017
let results = (outs AnyTypeOf<[LLVM_PointerGeneric, LLVM_PointerSharedCluster]>:$res);
30353018
let arguments = (ins AnyTypeOf<[LLVM_PointerGeneric, LLVM_PointerShared]>:$a, I32:$b);
30363019

mlir/include/mlir/IR/OpBase.td

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -603,6 +603,22 @@ class RangedTypesMatchWith<string summary, string lhsArg, string rhsArg,
603603
string transform>
604604
: TypesMatchWith<summary, lhsArg, rhsArg, transform, "llvm::equal">;
605605

606+
// Checks that each inputArg has the same type as the corresponding entry
607+
// in allowedTypes
608+
class InputMatchesTypes<list<string> inputArgs, list<Type> allowedTypes> :
609+
PredOpTrait<"operands {" # !interleave(inputArgs, ", ") # "} match expected types",
610+
!foldl(TruePred, !range(!size(inputArgs)), acc, i,
611+
And<[acc,
612+
SubstLeaves<"$_self", "$" # inputArgs[i] # ".getType()",
613+
allowedTypes[i].predicate>
614+
]>)> {
615+
assert !eq(!size(inputArgs), !size(allowedTypes)),
616+
"inputArgs and allowedTypes lists must have the same length";
617+
618+
list<string> inputArgList = inputArgs;
619+
list<Type> allowedTypeList = allowedTypes;
620+
}
621+
606622
// Type Constraint operand `idx`'s Element type is `type`.
607623
class TCopVTEtIs<int idx, Type type> : And<[
608624
CPred<"$_op.getNumOperands() > " # idx>,

0 commit comments

Comments
 (0)