Skip to content

Commit 8a3481b

Browse files
author
Jeff Niu
committed
[mlir] Add AllOfType and ConfinedType constraints
`AllOfType` is a type constraint that satisfies all given type constraints and `ConfinedType` is a type that satisfies additional predicates. These shorthands simplify type constraint definition mostly by removing the need to deal with `myType.predicate` manipulation. Reviewed By: jpienaar Differential Revision: https://reviews.llvm.org/D131788
1 parent f62e60f commit 8a3481b

File tree

1 file changed

+40
-30
lines changed

1 file changed

+40
-30
lines changed

mlir/include/mlir/IR/OpBase.td

Lines changed: 40 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -340,13 +340,29 @@ def NoneType : Type<CPred<"$_self.isa<::mlir::NoneType>()">, "none type",
340340
// Any type from the given list
341341
class AnyTypeOf<list<Type> allowedTypes, string summary = "",
342342
string cppClassName = "::mlir::Type"> : Type<
343-
// Satisfy any of the allowed type's condition
343+
// Satisfy any of the allowed types' conditions.
344344
Or<!foreach(allowedtype, allowedTypes, allowedtype.predicate)>,
345345
!if(!eq(summary, ""),
346346
!interleave(!foreach(t, allowedTypes, t.summary), " or "),
347347
summary),
348348
cppClassName>;
349349

350+
// A type that satisfies the constraints of all given types.
351+
class AllOfType<list<Type> allowedTypes, string summary = "",
352+
string cppClassName = "::mlir::Type"> : Type<
353+
// Satisfy all of the allowedf types' conditions.
354+
And<!foreach(allowedType, allowedTypes, allowedType.predicate)>,
355+
!if(!eq(summary, ""),
356+
!interleave(!foreach(t, allowedTypes, t.summary), " and "),
357+
summary),
358+
cppClassName>;
359+
360+
// A type that satisfies additional predicates.
361+
class ConfinedType<Type type, list<Pred> predicates, string summary = "",
362+
string cppClassName = "::mlir::Type"> : Type<
363+
And<!listconcat([type.predicate], !foreach(pred, predicates, pred))>,
364+
summary, cppClassName>;
365+
350366
// Integer types.
351367

352368
// Any integer type irrespective of its width and signedness semantics.
@@ -475,22 +491,21 @@ def F128 : F<128>;
475491
def BF16 : Type<CPred<"$_self.isBF16()">, "bfloat16 type">,
476492
BuildableType<"$_builder.getBF16Type()">;
477493

494+
def AnyComplex : Type<CPred<"$_self.isa<::mlir::ComplexType>()">,
495+
"complex-type", "::mlir::ComplexType">;
496+
478497
class Complex<Type type>
479-
: Type<And<[
480-
CPred<"$_self.isa<::mlir::ComplexType>()">,
498+
: ConfinedType<AnyComplex, [
481499
SubstLeaves<"$_self",
482500
"$_self.cast<::mlir::ComplexType>().getElementType()",
483-
type.predicate>]>,
501+
type.predicate>],
484502
"complex type with " # type.summary # " elements",
485503
"::mlir::ComplexType">,
486504
SameBuildabilityAs<type, "::mlir::ComplexType::get($_builder.get" # type #
487505
"Type())"> {
488506
Type elementType = type;
489507
}
490508

491-
def AnyComplex : Type<CPred<"$_self.isa<::mlir::ComplexType>()">,
492-
"complex-type", "::mlir::ComplexType">;
493-
494509
class OpaqueType<string dialect, string name, string summary>
495510
: Type<CPred<"isOpaqueTypeWithName($_self, \""#dialect#"\", \""#name#"\")">,
496511
summary, "::mlir::OpaqueType">,
@@ -572,9 +587,8 @@ class VectorOfRank<list<int> allowedRanks> : Type<
572587
// Any vector where the rank is from the given `allowedRanks` list and the type
573588
// is from the given `allowedTypes` list
574589
class VectorOfRankAndType<list<int> allowedRanks,
575-
list<Type> allowedTypes> : Type<
576-
And<[VectorOf<allowedTypes>.predicate,
577-
VectorOfRank<allowedRanks>.predicate]>,
590+
list<Type> allowedTypes> : AllOfType<
591+
[VectorOf<allowedTypes>, VectorOfRank<allowedRanks>],
578592
VectorOf<allowedTypes>.summary # VectorOfRank<allowedRanks>.summary,
579593
"::mlir::VectorType">;
580594

@@ -630,28 +644,25 @@ class ScalableVectorOfLength<list<int> allowedLengths> : Type<
630644
// `allowedLengths` list and the type is from the given `allowedTypes`
631645
// list
632646
class VectorOfLengthAndType<list<int> allowedLengths,
633-
list<Type> allowedTypes> : Type<
634-
And<[VectorOf<allowedTypes>.predicate,
635-
VectorOfLength<allowedLengths>.predicate]>,
647+
list<Type> allowedTypes> : AllOfType<
648+
[VectorOf<allowedTypes>, VectorOfLength<allowedLengths>],
636649
VectorOf<allowedTypes>.summary # VectorOfLength<allowedLengths>.summary,
637650
"::mlir::VectorType">;
638651

639652
// Any fixed-length vector where the number of elements is from the given
640653
// `allowedLengths` list and the type is from the given `allowedTypes` list
641654
class FixedVectorOfLengthAndType<list<int> allowedLengths,
642-
list<Type> allowedTypes> : Type<
643-
And<[FixedVectorOf<allowedTypes>.predicate,
644-
FixedVectorOfLength<allowedLengths>.predicate]>,
655+
list<Type> allowedTypes> : AllOfType<
656+
[FixedVectorOf<allowedTypes>, FixedVectorOfLength<allowedLengths>],
645657
FixedVectorOf<allowedTypes>.summary #
646658
FixedVectorOfLength<allowedLengths>.summary,
647659
"::mlir::VectorType">;
648660

649661
// Any scalable vector where the number of elements is from the given
650662
// `allowedLengths` list and the type is from the given `allowedTypes` list
651663
class ScalableVectorOfLengthAndType<list<int> allowedLengths,
652-
list<Type> allowedTypes> : Type<
653-
And<[ScalableVectorOf<allowedTypes>.predicate,
654-
ScalableVectorOfLength<allowedLengths>.predicate]>,
664+
list<Type> allowedTypes> : AllOfType<
665+
[ScalableVectorOf<allowedTypes>, ScalableVectorOfLength<allowedLengths>],
655666
ScalableVectorOf<allowedTypes>.summary #
656667
ScalableVectorOfLength<allowedLengths>.summary,
657668
"::mlir::VectorType">;
@@ -768,34 +779,33 @@ def F64MemRef : MemRefOf<[F64]>;
768779

769780
// TODO: Have an easy way to add another constraint to a type.
770781
class MemRefRankOf<list<Type> allowedTypes, list<int> ranks> :
771-
Type<And<[MemRefOf<allowedTypes>.predicate, HasAnyRankOfPred<ranks>]>,
782+
ConfinedType<MemRefOf<allowedTypes>, [HasAnyRankOfPred<ranks>],
772783
!interleave(!foreach(rank, ranks, rank # "D"), "/") # " " #
773784
MemRefOf<allowedTypes>.summary,
774785
"::mlir::MemRefType">;
775786

776-
class StaticShapeMemRefOf<list<Type> allowedTypes>
777-
: Type<And<[MemRefOf<allowedTypes>.predicate, HasStaticShapePred]>,
778-
"statically shaped " # MemRefOf<allowedTypes>.summary,
779-
"::mlir::MemRefType">;
787+
class StaticShapeMemRefOf<list<Type> allowedTypes> :
788+
ConfinedType<MemRefOf<allowedTypes>, [HasStaticShapePred],
789+
"statically shaped " # MemRefOf<allowedTypes>.summary,
790+
"::mlir::MemRefType">;
780791

781792
def AnyStaticShapeMemRef : StaticShapeMemRefOf<[AnyType]>;
782793

783794
// For a MemRefType, verify that it has strides.
784795
def HasStridesPred : CPred<[{ isStrided($_self.cast<::mlir::MemRefType>()) }]>;
785796

786-
class StridedMemRefOf<list<Type> allowedTypes>
787-
: Type<And<[MemRefOf<allowedTypes>.predicate, HasStridesPred]>,
788-
"strided " # MemRefOf<allowedTypes>.summary>;
797+
class StridedMemRefOf<list<Type> allowedTypes> :
798+
ConfinedType<MemRefOf<allowedTypes>, [HasStridesPred],
799+
"strided " # MemRefOf<allowedTypes>.summary>;
789800

790801
def AnyStridedMemRef : StridedMemRefOf<[AnyType]>;
791802

792803
class AnyStridedMemRefOfRank<int rank> :
793-
Type<And<[AnyStridedMemRef.predicate,
794-
MemRefRankOf<[AnyType], [rank]>.predicate]>,
804+
AllOfType<[AnyStridedMemRef, MemRefRankOf<[AnyType], [rank]>],
795805
AnyStridedMemRef.summary # " of rank " # rank>;
796806

797807
class StridedMemRefRankOf<list<Type> allowedTypes, list<int> ranks> :
798-
Type<And<[MemRefOf<allowedTypes>.predicate, HasAnyRankOfPred<ranks>]>,
808+
ConfinedType<MemRefOf<allowedTypes>, [HasAnyRankOfPred<ranks>],
799809
!interleave(!foreach(rank, ranks, rank # "D"), "/") # " " #
800810
MemRefOf<allowedTypes>.summary>;
801811

0 commit comments

Comments
 (0)