@@ -340,13 +340,29 @@ def NoneType : Type<CPred<"$_self.isa<::mlir::NoneType>()">, "none type",
340
340
// Any type from the given list
341
341
class AnyTypeOf<list<Type> allowedTypes, string summary = "",
342
342
string cppClassName = "::mlir::Type"> : Type<
343
- // Satisfy any of the allowed type's condition
343
+ // Satisfy any of the allowed types' conditions.
344
344
Or<!foreach(allowedtype, allowedTypes, allowedtype.predicate)>,
345
345
!if(!eq(summary, ""),
346
346
!interleave(!foreach(t, allowedTypes, t.summary), " or "),
347
347
summary),
348
348
cppClassName>;
349
349
350
+ // A type that satisfies the constraints of all given types.
351
+ class AllOfType<list<Type> allowedTypes, string summary = "",
352
+ string cppClassName = "::mlir::Type"> : Type<
353
+ // Satisfy all of the allowedf types' conditions.
354
+ And<!foreach(allowedType, allowedTypes, allowedType.predicate)>,
355
+ !if(!eq(summary, ""),
356
+ !interleave(!foreach(t, allowedTypes, t.summary), " and "),
357
+ summary),
358
+ cppClassName>;
359
+
360
+ // A type that satisfies additional predicates.
361
+ class ConfinedType<Type type, list<Pred> predicates, string summary = "",
362
+ string cppClassName = "::mlir::Type"> : Type<
363
+ And<!listconcat([type.predicate], !foreach(pred, predicates, pred))>,
364
+ summary, cppClassName>;
365
+
350
366
// Integer types.
351
367
352
368
// Any integer type irrespective of its width and signedness semantics.
@@ -475,22 +491,21 @@ def F128 : F<128>;
475
491
def BF16 : Type<CPred<"$_self.isBF16()">, "bfloat16 type">,
476
492
BuildableType<"$_builder.getBF16Type()">;
477
493
494
+ def AnyComplex : Type<CPred<"$_self.isa<::mlir::ComplexType>()">,
495
+ "complex-type", "::mlir::ComplexType">;
496
+
478
497
class Complex<Type type>
479
- : Type<And<[
480
- CPred<"$_self.isa<::mlir::ComplexType>()">,
498
+ : ConfinedType<AnyComplex, [
481
499
SubstLeaves<"$_self",
482
500
"$_self.cast<::mlir::ComplexType>().getElementType()",
483
- type.predicate>]> ,
501
+ type.predicate>],
484
502
"complex type with " # type.summary # " elements",
485
503
"::mlir::ComplexType">,
486
504
SameBuildabilityAs<type, "::mlir::ComplexType::get($_builder.get" # type #
487
505
"Type())"> {
488
506
Type elementType = type;
489
507
}
490
508
491
- def AnyComplex : Type<CPred<"$_self.isa<::mlir::ComplexType>()">,
492
- "complex-type", "::mlir::ComplexType">;
493
-
494
509
class OpaqueType<string dialect, string name, string summary>
495
510
: Type<CPred<"isOpaqueTypeWithName($_self, \""#dialect#"\", \""#name#"\")">,
496
511
summary, "::mlir::OpaqueType">,
@@ -572,9 +587,8 @@ class VectorOfRank<list<int> allowedRanks> : Type<
572
587
// Any vector where the rank is from the given `allowedRanks` list and the type
573
588
// is from the given `allowedTypes` list
574
589
class VectorOfRankAndType<list<int> allowedRanks,
575
- list<Type> allowedTypes> : Type<
576
- And<[VectorOf<allowedTypes>.predicate,
577
- VectorOfRank<allowedRanks>.predicate]>,
590
+ list<Type> allowedTypes> : AllOfType<
591
+ [VectorOf<allowedTypes>, VectorOfRank<allowedRanks>],
578
592
VectorOf<allowedTypes>.summary # VectorOfRank<allowedRanks>.summary,
579
593
"::mlir::VectorType">;
580
594
@@ -630,28 +644,25 @@ class ScalableVectorOfLength<list<int> allowedLengths> : Type<
630
644
// `allowedLengths` list and the type is from the given `allowedTypes`
631
645
// list
632
646
class VectorOfLengthAndType<list<int> allowedLengths,
633
- list<Type> allowedTypes> : Type<
634
- And<[VectorOf<allowedTypes>.predicate,
635
- VectorOfLength<allowedLengths>.predicate]>,
647
+ list<Type> allowedTypes> : AllOfType<
648
+ [VectorOf<allowedTypes>, VectorOfLength<allowedLengths>],
636
649
VectorOf<allowedTypes>.summary # VectorOfLength<allowedLengths>.summary,
637
650
"::mlir::VectorType">;
638
651
639
652
// Any fixed-length vector where the number of elements is from the given
640
653
// `allowedLengths` list and the type is from the given `allowedTypes` list
641
654
class FixedVectorOfLengthAndType<list<int> allowedLengths,
642
- list<Type> allowedTypes> : Type<
643
- And<[FixedVectorOf<allowedTypes>.predicate,
644
- FixedVectorOfLength<allowedLengths>.predicate]>,
655
+ list<Type> allowedTypes> : AllOfType<
656
+ [FixedVectorOf<allowedTypes>, FixedVectorOfLength<allowedLengths>],
645
657
FixedVectorOf<allowedTypes>.summary #
646
658
FixedVectorOfLength<allowedLengths>.summary,
647
659
"::mlir::VectorType">;
648
660
649
661
// Any scalable vector where the number of elements is from the given
650
662
// `allowedLengths` list and the type is from the given `allowedTypes` list
651
663
class ScalableVectorOfLengthAndType<list<int> allowedLengths,
652
- list<Type> allowedTypes> : Type<
653
- And<[ScalableVectorOf<allowedTypes>.predicate,
654
- ScalableVectorOfLength<allowedLengths>.predicate]>,
664
+ list<Type> allowedTypes> : AllOfType<
665
+ [ScalableVectorOf<allowedTypes>, ScalableVectorOfLength<allowedLengths>],
655
666
ScalableVectorOf<allowedTypes>.summary #
656
667
ScalableVectorOfLength<allowedLengths>.summary,
657
668
"::mlir::VectorType">;
@@ -768,34 +779,33 @@ def F64MemRef : MemRefOf<[F64]>;
768
779
769
780
// TODO: Have an easy way to add another constraint to a type.
770
781
class MemRefRankOf<list<Type> allowedTypes, list<int> ranks> :
771
- Type<And<[ MemRefOf<allowedTypes>.predicate, HasAnyRankOfPred<ranks>]> ,
782
+ ConfinedType< MemRefOf<allowedTypes>, [ HasAnyRankOfPred<ranks>],
772
783
!interleave(!foreach(rank, ranks, rank # "D"), "/") # " " #
773
784
MemRefOf<allowedTypes>.summary,
774
785
"::mlir::MemRefType">;
775
786
776
- class StaticShapeMemRefOf<list<Type> allowedTypes>
777
- : Type<And<[ MemRefOf<allowedTypes>.predicate, HasStaticShapePred]> ,
778
- "statically shaped " # MemRefOf<allowedTypes>.summary,
779
- "::mlir::MemRefType">;
787
+ class StaticShapeMemRefOf<list<Type> allowedTypes> :
788
+ ConfinedType< MemRefOf<allowedTypes>, [ HasStaticShapePred],
789
+ "statically shaped " # MemRefOf<allowedTypes>.summary,
790
+ "::mlir::MemRefType">;
780
791
781
792
def AnyStaticShapeMemRef : StaticShapeMemRefOf<[AnyType]>;
782
793
783
794
// For a MemRefType, verify that it has strides.
784
795
def HasStridesPred : CPred<[{ isStrided($_self.cast<::mlir::MemRefType>()) }]>;
785
796
786
- class StridedMemRefOf<list<Type> allowedTypes>
787
- : Type<And<[ MemRefOf<allowedTypes>.predicate, HasStridesPred]> ,
788
- "strided " # MemRefOf<allowedTypes>.summary>;
797
+ class StridedMemRefOf<list<Type> allowedTypes> :
798
+ ConfinedType< MemRefOf<allowedTypes>, [ HasStridesPred],
799
+ "strided " # MemRefOf<allowedTypes>.summary>;
789
800
790
801
def AnyStridedMemRef : StridedMemRefOf<[AnyType]>;
791
802
792
803
class AnyStridedMemRefOfRank<int rank> :
793
- Type<And<[AnyStridedMemRef.predicate,
794
- MemRefRankOf<[AnyType], [rank]>.predicate]>,
804
+ AllOfType<[AnyStridedMemRef, MemRefRankOf<[AnyType], [rank]>],
795
805
AnyStridedMemRef.summary # " of rank " # rank>;
796
806
797
807
class StridedMemRefRankOf<list<Type> allowedTypes, list<int> ranks> :
798
- Type<And<[ MemRefOf<allowedTypes>.predicate, HasAnyRankOfPred<ranks>]> ,
808
+ ConfinedType< MemRefOf<allowedTypes>, [ HasAnyRankOfPred<ranks>],
799
809
!interleave(!foreach(rank, ranks, rank # "D"), "/") # " " #
800
810
MemRefOf<allowedTypes>.summary>;
801
811
0 commit comments