Skip to content

Commit 8544523

Browse files
[mlir][tensor] Promote extract(from_elements(...)) to folding pattern
Differential Revision: https://reviews.llvm.org/D123617
1 parent ff087d7 commit 8544523

File tree

2 files changed

+33
-59
lines changed

2 files changed

+33
-59
lines changed

mlir/lib/Dialect/Tensor/IR/TensorOps.cpp

Lines changed: 32 additions & 54 deletions
Original file line number | Diff line number | Diff line change
@@ -361,27 +361,48 @@ LogicalResult ExtractOp::verify() {
361361
}
362362

363363
OpFoldResult ExtractOp::fold(ArrayRef<Attribute> operands) {
364-
// The tensor operand must be a known constant.
365-
Attribute tensor = operands.front();
366-
if (!tensor)
367-
return {};
368364
// If this is a splat elements attribute, simply return the value. All of the
369365
// elements of a splat attribute are the same.
370-
if (auto splatTensor = tensor.dyn_cast<SplatElementsAttr>())
371-
return splatTensor.getSplatValue<Attribute>();
366+
if (Attribute tensor = operands.front())
367+
if (auto splatTensor = tensor.dyn_cast<SplatElementsAttr>())
368+
return splatTensor.getSplatValue<Attribute>();
372369

373-
// Otherwise, collect the constant indices into the tensor.
370+
// Collect the constant indices into the tensor.
374371
SmallVector<uint64_t, 8> indices;
375372
for (Attribute indice : llvm::drop_begin(operands, 1)) {
376373
if (!indice || !indice.isa<IntegerAttr>())
377374
return {};
378375
indices.push_back(indice.cast<IntegerAttr>().getInt());
379376
}
380377

378+
// Fold extract(from_elements(...)).
379+
if (auto fromElementsOp = tensor().getDefiningOp<FromElementsOp>()) {
380+
auto tensorType = fromElementsOp.getType().cast<RankedTensorType>();
381+
auto rank = tensorType.getRank();
382+
assert(static_cast<int64_t>(indices.size()) == tensorType.getRank() &&
383+
"rank mismatch");
384+
int flatIndex = 0;
385+
int stride = 1;
386+
for (int i = rank - 1; i >= 0; --i) {
387+
if (i < rank - 1)
388+
stride *= tensorType.getDimSize(i);
389+
flatIndex += indices[i] * stride;
390+
}
391+
// Prevent out of bounds accesses. This can happen in invalid code that will
392+
// never execute.
393+
if (static_cast<int>(fromElementsOp.elements().size()) <= flatIndex ||
394+
flatIndex < 0)
395+
return {};
396+
return fromElementsOp.elements()[flatIndex];
397+
}
398+
381399
// If this is an elements attribute, query the value at the given indices.
382-
auto elementsAttr = tensor.dyn_cast<ElementsAttr>();
383-
if (elementsAttr && elementsAttr.isValidIndex(indices))
384-
return elementsAttr.getValues<Attribute>()[indices];
400+
if (Attribute tensor = operands.front()) {
401+
auto elementsAttr = tensor.dyn_cast<ElementsAttr>();
402+
if (elementsAttr && elementsAttr.isValidIndex(indices))
403+
return elementsAttr.getValues<Attribute>()[indices];
404+
}
405+
385406
return {};
386407
}
387408

@@ -411,47 +432,6 @@ OpFoldResult FromElementsOp::fold(ArrayRef<Attribute> operands) {
411432

412433
namespace {
413434

414-
// Canonicalizes the pattern of the form
415-
//
416-
// %tensor = tensor.from_elements(%element) : (i32) -> tensor<1xi32>
417-
// %extracted_element = tensor.extract %tensor[%c0] : tensor<1xi32>
418-
//
419-
// to just %element.
420-
struct ExtractElementFromTensorFromElements
421-
: public OpRewritePattern<tensor::ExtractOp> {
422-
using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
423-
424-
LogicalResult matchAndRewrite(tensor::ExtractOp extract,
425-
PatternRewriter &rewriter) const final {
426-
auto tensorFromElements = extract.tensor().getDefiningOp<FromElementsOp>();
427-
if (!tensorFromElements)
428-
return failure();
429-
auto tensorType = tensorFromElements.getType().cast<RankedTensorType>();
430-
auto rank = tensorType.getRank();
431-
if (rank == 0) {
432-
rewriter.replaceOp(extract, tensorFromElements.getOperand(0));
433-
return success();
434-
}
435-
SmallVector<APInt, 3> indices(rank);
436-
int64_t flatIndex = 0;
437-
int64_t stride = 1;
438-
for (int i = rank - 1; i >= 0; --i) {
439-
APInt index;
440-
if (!matchPattern(extract.indices()[i], m_ConstantInt(&index)))
441-
return failure();
442-
if (i < rank - 1)
443-
stride *= tensorType.getDimSize(i);
444-
flatIndex += index.getSExtValue() * stride;
445-
}
446-
// Prevent out of bounds accesses. This can happen in invalid code that will
447-
// never execute.
448-
if (tensorFromElements->getNumOperands() <= flatIndex || flatIndex < 0)
449-
return failure();
450-
rewriter.replaceOp(extract, tensorFromElements.getOperand(flatIndex));
451-
return success();
452-
}
453-
};
454-
455435
// Pushes the index_casts that occur before extractions to after the extract.
456436
// This minimizes type conversion in some cases and enables the extract
457437
// canonicalizer. This changes:
@@ -494,9 +474,7 @@ struct ExtractElementFromIndexCast
494474

495475
void FromElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
496476
MLIRContext *context) {
497-
results
498-
.add<ExtractElementFromIndexCast, ExtractElementFromTensorFromElements>(
499-
context);
477+
results.add<ExtractElementFromIndexCast>(context);
500478
}
501479

502480
//===----------------------------------------------------------------------===//

mlir/test/Interfaces/InferShapedTypeOpInterface/resolve-shaped-type-result-dims.mlir

Lines changed: 1 addition & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -22,12 +22,8 @@ func @result_shape(%arg0 : tensor<2x3x?xf32>, %arg1 : tensor<?x5xf32>)
2222
// CHECK-DAG: %[[C3:.+]] = arith.constant 3 : index
2323
// CHECK-DAG: %[[C5:.+]] = arith.constant 5 : index
2424
// CHECK-DAG: %[[D0:.+]] = tensor.dim %[[ARG_1]], %[[C0]]
25-
// CHECK-DAG: %[[S0:.+]] = tensor.from_elements %[[D0]], %[[C5]]
26-
// CHECK-DAG: %[[D0_OUT:.+]] = tensor.extract %[[S0]][%[[C0]]]
2725
// CHECK-DAG: %[[D1:.+]] = tensor.dim %[[ARG_0]], %[[C2]]
28-
// CHECK-DAG: %[[S1:.+]] = tensor.from_elements %[[C2]], %[[C3]], %[[D1]]
29-
// CHECK-DAG: %[[D1_OUT:.+]] = tensor.extract %[[S1]][%[[C2]]]
30-
// CHECK: return %[[D0_OUT]], %[[C5]], %[[C2]], %[[C3]], %[[D1_OUT]]
26+
// CHECK: return %[[D0]], %[[C5]], %[[C2]], %[[C3]], %[[D1]]
3127

3228
// -----
3329

0 commit comments

Comments
 (0)