Skip to content

Commit c0261eb

Browse files
[mlir][IR] Improve clone function return type of shaped types
There are `clone` overloads that take a shape as a parameter. These overloads are guaranteed to return a ranked shaped type. `TensorType::clone`/`BaseMemRefType::clone` used to always return a `TensorType`/`BaseMemRefType`. The variants that take a shape parameter now return a `RankedTensorType`/`MemRefType`. Better static type information can make extra casts at the call site obsolete. E.g.: ``` {TensorType/RankedTensorType} t; t.clone({1, 2}) // now returns RankedTensorType instead of TensorType ``` Also improve documentation for `clone`. Differential Revision: https://reviews.llvm.org/D150865
1 parent 92723d5 commit c0261eb

File tree

4 files changed

+78
-13
lines changed

4 files changed

+78
-13
lines changed

mlir/include/mlir/IR/BuiltinTypeInterfaces.td

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -59,8 +59,13 @@ def ShapedTypeInterface : TypeInterface<"ShapedType"> {
5959
}];
6060
let methods = [
6161
InterfaceMethod<[{
62-
Returns a clone of this type with the given shape and element
63-
type. If a shape is not provided, the current shape of the type is used.
62+
Returns a clone of this type with the given shape and element type.
63+
64+
If no shape is provided, the shape of this type is used. In that case, if
65+
this type is unranked, so is the resulting type.
66+
67+
If a shape is provided, the resulting type is always ranked, even if this
68+
type is unranked.
6469
}],
6570
"::mlir::ShapedType", "cloneWith", (ins
6671
"::std::optional<::llvm::ArrayRef<int64_t>>":$shape,
@@ -89,7 +94,7 @@ def ShapedTypeInterface : TypeInterface<"ShapedType"> {
8994

9095
/// Whether the given dimension size indicates a dynamic dimension.
9196
static constexpr bool isDynamic(int64_t dValue) {
92-
return dValue == kDynamic;
97+
return dValue == kDynamic;
9398
}
9499

95100
/// Whether the given shape has any size that indicates a dynamic dimension.
@@ -99,18 +104,24 @@ def ShapedTypeInterface : TypeInterface<"ShapedType"> {
99104

100105
/// Return the number of elements present in the given shape.
101106
static int64_t getNumElements(ArrayRef<int64_t> shape);
102-
}];
103107

104-
let extraSharedClassDeclaration = [{
105108
/// Return a clone of this type with the given new shape and element type.
109+
/// The returned type is ranked, even if this type is unranked.
106110
auto clone(::llvm::ArrayRef<int64_t> shape, Type elementType) {
107-
return $_type.cloneWith(shape, elementType);
111+
return cloneWith(shape, elementType);
108112
}
109-
/// Return a clone of this type with the given new shape.
113+
114+
/// Return a clone of this type with the given new shape. The returned type
115+
/// is ranked, even if this type is unranked.
110116
auto clone(::llvm::ArrayRef<int64_t> shape) {
111-
return $_type.cloneWith(shape, $_type.getElementType());
117+
return cloneWith(shape, getElementType());
112118
}
113-
/// Return a clone of this type with the given new element type.
119+
}];
120+
121+
let extraSharedClassDeclaration = [{
122+
/// Return a clone of this type with the given new element type. The
123+
/// returned type is ranked if and only if this type is ranked. In that
124+
/// case, the returned type has the same shape as this type.
114125
auto clone(::mlir::Type elementType) {
115126
return $_type.cloneWith(/*shape=*/std::nullopt, elementType);
116127
}

mlir/include/mlir/IR/BuiltinTypes.h

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@ class AffineMap;
2727
class FloatType;
2828
class IndexType;
2929
class IntegerType;
30+
class MemRefType;
31+
class RankedTensorType;
3032
class StringAttr;
3133
class TypeRange;
3234

@@ -95,6 +97,17 @@ class TensorType : public Type, public ShapedType::Trait<TensorType> {
9597
TensorType cloneWith(std::optional<ArrayRef<int64_t>> shape,
9698
Type elementType) const;
9799

100+
// Make sure that base class overloads are visible.
101+
using ShapedType::Trait<TensorType>::clone;
102+
103+
/// Return a clone of this type with the given new shape and element type.
104+
/// The returned type is ranked, even if this type is unranked.
105+
RankedTensorType clone(ArrayRef<int64_t> shape, Type elementType) const;
106+
107+
/// Return a clone of this type with the given new shape. The returned type
108+
/// is ranked, even if this type is unranked.
109+
RankedTensorType clone(ArrayRef<int64_t> shape) const;
110+
98111
/// Return true if the specified element type is ok in a tensor.
99112
static bool isValidElementType(Type type);
100113

@@ -131,6 +144,17 @@ class BaseMemRefType : public Type, public ShapedType::Trait<BaseMemRefType> {
131144
BaseMemRefType cloneWith(std::optional<ArrayRef<int64_t>> shape,
132145
Type elementType) const;
133146

147+
// Make sure that base class overloads are visible.
148+
using ShapedType::Trait<BaseMemRefType>::clone;
149+
150+
/// Return a clone of this type with the given new shape and element type.
151+
/// The returned type is ranked, even if this type is unranked.
152+
MemRefType clone(ArrayRef<int64_t> shape, Type elementType) const;
153+
154+
/// Return a clone of this type with the given new shape. The returned type
155+
/// is ranked, even if this type is unranked.
156+
MemRefType clone(ArrayRef<int64_t> shape) const;
157+
134158
/// Return true if the specified element type is ok in a memref.
135159
static bool isValidElementType(Type type);
136160

mlir/include/mlir/IR/BuiltinTypes.td

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -629,7 +629,7 @@ def Builtin_MemRef : Builtin_Type<"MemRef", [
629629
"unsigned":$memorySpaceInd)>
630630
];
631631
let extraClassDeclaration = [{
632-
using ShapedType::Trait<MemRefType>::clone;
632+
using BaseMemRefType::clone;
633633
using ShapedType::Trait<MemRefType>::getElementTypeBitWidth;
634634
using ShapedType::Trait<MemRefType>::getRank;
635635
using ShapedType::Trait<MemRefType>::getNumElements;
@@ -794,7 +794,7 @@ def Builtin_RankedTensor : Builtin_Type<"RankedTensor", [
794794
}]>
795795
];
796796
let extraClassDeclaration = [{
797-
using ShapedType::Trait<RankedTensorType>::clone;
797+
using TensorType::clone;
798798
using ShapedType::Trait<RankedTensorType>::getElementTypeBitWidth;
799799
using ShapedType::Trait<RankedTensorType>::getRank;
800800
using ShapedType::Trait<RankedTensorType>::getNumElements;
@@ -807,6 +807,12 @@ def Builtin_RankedTensor : Builtin_Type<"RankedTensor", [
807807
/// This is a builder type that keeps local references to arguments.
808808
/// Arguments that are passed into the builder must outlive the builder.
809809
class Builder;
810+
811+
/// Return a clone of this type with the given new element type and the same
812+
/// shape as this type.
813+
RankedTensorType clone(::mlir::Type elementType) {
814+
return ::llvm::cast<RankedTensorType>(cloneWith(getShape(), elementType));
815+
}
810816
}];
811817
let skipDefaultBuilders = 1;
812818
let genVerifyDecl = 1;
@@ -931,7 +937,7 @@ def Builtin_UnrankedMemRef : Builtin_Type<"UnrankedMemRef", [
931937
}]>
932938
];
933939
let extraClassDeclaration = [{
934-
using ShapedType::Trait<UnrankedMemRefType>::clone;
940+
using BaseMemRefType::clone;
935941
using ShapedType::Trait<UnrankedMemRefType>::getElementTypeBitWidth;
936942
using ShapedType::Trait<UnrankedMemRefType>::getRank;
937943
using ShapedType::Trait<UnrankedMemRefType>::getNumElements;
@@ -946,6 +952,12 @@ def Builtin_UnrankedMemRef : Builtin_Type<"UnrankedMemRef", [
946952
/// [deprecated] Returns the memory space in old raw integer representation.
947953
/// New `Attribute getMemorySpace()` method should be used instead.
948954
unsigned getMemorySpaceAsInt() const;
955+
956+
/// Return a clone of this type with the given new element type and the same
957+
/// shape as this type.
958+
MemRefType clone(::mlir::Type elementType) {
959+
return ::llvm::cast<MemRefType>(cloneWith(getShape(), elementType));
960+
}
949961
}];
950962
let skipDefaultBuilders = 1;
951963
let genVerifyDecl = 1;
@@ -984,7 +996,7 @@ def Builtin_UnrankedTensor : Builtin_Type<"UnrankedTensor", [
984996
}]>
985997
];
986998
let extraClassDeclaration = [{
987-
using ShapedType::Trait<UnrankedTensorType>::clone;
999+
using TensorType::clone;
9881000
using ShapedType::Trait<UnrankedTensorType>::getElementTypeBitWidth;
9891001
using ShapedType::Trait<UnrankedTensorType>::getRank;
9901002
using ShapedType::Trait<UnrankedTensorType>::getNumElements;

mlir/lib/IR/BuiltinTypes.cpp

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -291,6 +291,15 @@ TensorType TensorType::cloneWith(std::optional<ArrayRef<int64_t>> shape,
291291
rankedTy.getEncoding());
292292
}
293293

294+
RankedTensorType TensorType::clone(::llvm::ArrayRef<int64_t> shape,
295+
Type elementType) const {
296+
return ::llvm::cast<RankedTensorType>(cloneWith(shape, elementType));
297+
}
298+
299+
RankedTensorType TensorType::clone(::llvm::ArrayRef<int64_t> shape) const {
300+
return ::llvm::cast<RankedTensorType>(cloneWith(shape, getElementType()));
301+
}
302+
294303
// Check if "elementType" can be an element type of a tensor.
295304
static LogicalResult
296305
checkTensorElementType(function_ref<InFlightDiagnostic()> emitError,
@@ -370,6 +379,15 @@ BaseMemRefType BaseMemRefType::cloneWith(std::optional<ArrayRef<int64_t>> shape,
370379
return builder;
371380
}
372381

382+
MemRefType BaseMemRefType::clone(::llvm::ArrayRef<int64_t> shape,
383+
Type elementType) const {
384+
return ::llvm::cast<MemRefType>(cloneWith(shape, elementType));
385+
}
386+
387+
MemRefType BaseMemRefType::clone(::llvm::ArrayRef<int64_t> shape) const {
388+
return ::llvm::cast<MemRefType>(cloneWith(shape, getElementType()));
389+
}
390+
373391
Attribute BaseMemRefType::getMemorySpace() const {
374392
if (auto rankedMemRefTy = dyn_cast<MemRefType>())
375393
return rankedMemRefTy.getMemorySpace();

0 commit comments

Comments
 (0)