Skip to content

Commit 21b0eff

Browse files
committed
[mlir][shape] Add shape.from_extents.
Summary: This is a basic op needed for creating shapes from SSA values representing the extents. Differential Revision: https://reviews.llvm.org/D79833
1 parent 47650dc commit 21b0eff

File tree

3 files changed

+66
-0
lines changed

3 files changed

+66
-0
lines changed

mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,30 @@ def Shape_ConstSizeOp : Shape_Op<"const_size",
132132
let assemblyFormat = "attr-dict $value";
133133
}
134134

135+
def Shape_FromExtentsOp : Shape_Op<"from_extents", [
136+
NoSideEffect,
137+
DeclareOpInterfaceMethods<InferTypeOpInterface>
138+
]> {
139+
let summary = "Creates a shape from extents";
140+
let description = [{
141+
Creates a shape from multiple SSA values representing the extents of
142+
the shape.
143+
144+
```mlir
145+
// Rank 2 shape.
146+
%s0 = shape.from_extents %a, %b
147+
// Rank 0 shape.
148+
%s1 = shape.from_extents
149+
```
150+
}];
151+
let arguments = (ins Variadic<Index>:$extents);
152+
let results = (outs Shape_ShapeType:$shape);
153+
154+
let assemblyFormat = "attr-dict $extents";
155+
156+
let hasFolder = 1;
157+
}
158+
135159
def Shape_FromExtentTensorOp : Shape_Op<"from_extent_tensor", []> {
136160
let summary = "Creates a shape from a tensor of extents";
137161
let description = [{

mlir/lib/Dialect/Shape/IR/Shape.cpp

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,28 @@ ConstSizeOp::inferReturnTypes(MLIRContext *context, Optional<Location> location,
201201
return success();
202202
}
203203

204+
//===----------------------------------------------------------------------===//
205+
// FromExtentsOp
206+
//===----------------------------------------------------------------------===//
207+
208+
LogicalResult FromExtentsOp::inferReturnTypes(
209+
MLIRContext *context, Optional<Location> location, ValueRange operands,
210+
DictionaryAttr attributes, RegionRange regions,
211+
SmallVectorImpl<Type> &inferredReturnTypes) {
212+
inferredReturnTypes.push_back(ShapeType::get(context));
213+
return success();
214+
}
215+
216+
OpFoldResult FromExtentsOp::fold(ArrayRef<Attribute> operands) {
217+
if (llvm::any_of(operands, [](Attribute a) { return !a; }))
218+
return nullptr;
219+
SmallVector<int64_t, 6> extents;
220+
for (auto attr : operands)
221+
extents.push_back(attr.cast<IntegerAttr>().getInt());
222+
Builder builder(getContext());
223+
return builder.getI64TensorAttr(extents);
224+
}
225+
204226
//===----------------------------------------------------------------------===//
205227
// ShapeOfOp
206228
//===----------------------------------------------------------------------===//

mlir/test/Dialect/Shape/canonicalize.mlir

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,3 +86,23 @@ func @f() -> tensor<2xindex> {
8686
%0 = "shape.to_extent_tensor"(%cs) : (!shape.shape) -> tensor<2xindex>
8787
return %0 : tensor<2xindex>
8888
}
89+
90+
// -----
91+
// Basic case.
92+
// CHECK-LABEL: func @f()
93+
func @f() -> !shape.shape {
94+
// CHECK: shape.const_shape [3, 5, 11]
95+
%e0 = constant 3 : index
96+
%e1 = constant 5 : index
97+
%e2 = constant 11 : index
98+
%ret = shape.from_extents %e0, %e1, %e2
99+
return %ret : !shape.shape
100+
}
101+
102+
// CHECK-LABEL: func @no_fold
103+
func @no_fold(%arg0: index) -> !shape.shape {
104+
// CHECK-NOT: shape.const_shape
105+
%e0 = constant 3 : index
106+
%ret = shape.from_extents %e0, %arg0
107+
return %ret : !shape.shape
108+
}

0 commit comments

Comments
 (0)