@@ -361,27 +361,48 @@ LogicalResult ExtractOp::verify() {
 }
 
 OpFoldResult ExtractOp::fold(ArrayRef<Attribute> operands) {
-  // The tensor operand must be a known constant.
-  Attribute tensor = operands.front();
-  if (!tensor)
-    return {};
-
   // If this is a splat elements attribute, simply return the value. All of the
   // elements of a splat attribute are the same.
-  if (auto splatTensor = tensor.dyn_cast<SplatElementsAttr>())
-    return splatTensor.getSplatValue<Attribute>();
+  if (Attribute tensor = operands.front())
+    if (auto splatTensor = tensor.dyn_cast<SplatElementsAttr>())
+      return splatTensor.getSplatValue<Attribute>();
 
-  // Otherwise, collect the constant indices into the tensor.
+  // Collect the constant indices into the tensor.
   SmallVector<uint64_t, 8> indices;
   for (Attribute indice : llvm::drop_begin(operands, 1)) {
     if (!indice || !indice.isa<IntegerAttr>())
       return {};
     indices.push_back(indice.cast<IntegerAttr>().getInt());
   }
 
+  // Fold extract(from_elements(...)).
+  if (auto fromElementsOp = tensor().getDefiningOp<FromElementsOp>()) {
+    auto tensorType = fromElementsOp.getType().cast<RankedTensorType>();
+    auto rank = tensorType.getRank();
+    assert(static_cast<int64_t>(indices.size()) == tensorType.getRank() &&
+           "rank mismatch");
+    int flatIndex = 0;
+    int stride = 1;
+    for (int i = rank - 1; i >= 0; --i) {
+      if (i < rank - 1)
+        stride *= tensorType.getDimSize(i);
+      flatIndex += indices[i] * stride;
+    }
+    // Prevent out of bounds accesses. This can happen in invalid code that will
+    // never execute.
+    if (static_cast<int>(fromElementsOp.elements().size()) <= flatIndex ||
+        flatIndex < 0)
+      return {};
+    return fromElementsOp.elements()[flatIndex];
+  }
+
   // If this is an elements attribute, query the value at the given indices.
-  auto elementsAttr = tensor.dyn_cast<ElementsAttr>();
-  if (elementsAttr && elementsAttr.isValidIndex(indices))
-    return elementsAttr.getValues<Attribute>()[indices];
+  if (Attribute tensor = operands.front()) {
+    auto elementsAttr = tensor.dyn_cast<ElementsAttr>();
+    if (elementsAttr && elementsAttr.isValidIndex(indices))
+      return elementsAttr.getValues<Attribute>()[indices];
+  }
+
   return {};
 }
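Taken together, this hunk makes extract-of-from_elements fold as soon as all index operands are known constants, instead of relying on the canonicalization pattern removed below. A minimal IR sketch of the case this covers (value names and the constant-materialization syntax are illustrative, not taken from the commit):

    %t = tensor.from_elements %a, %b, %c, %d : tensor<2x2xf32>
    %e = tensor.extract %t[%i1, %i0] : tensor<2x2xf32>  // %i1 = 1, %i0 = 0

With indices [1, 0], the loop walks from the innermost dimension outward: stride starts at 1 and flatIndex += 0 * 1, then stride *= 2 and flatIndex += 1 * 2, giving flatIndex = 2, so the fold returns the third element, %c.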
@@ -411,47 +432,6 @@ OpFoldResult FromElementsOp::fold(ArrayRef<Attribute> operands) {
 
 namespace {
 
-// Canonicalizes the pattern of the form
-//
-// %tensor = tensor.from_elements(%element) : (i32) -> tensor<1xi32>
-// %extracted_element = tensor.extract %tensor[%c0] : tensor<1xi32>
-//
-// to just %element.
-struct ExtractElementFromTensorFromElements
-    : public OpRewritePattern<tensor::ExtractOp> {
-  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(tensor::ExtractOp extract,
-                                PatternRewriter &rewriter) const final {
-    auto tensorFromElements = extract.tensor().getDefiningOp<FromElementsOp>();
-    if (!tensorFromElements)
-      return failure();
-    auto tensorType = tensorFromElements.getType().cast<RankedTensorType>();
-    auto rank = tensorType.getRank();
-    if (rank == 0) {
-      rewriter.replaceOp(extract, tensorFromElements.getOperand(0));
-      return success();
-    }
-    SmallVector<APInt, 3> indices(rank);
-    int64_t flatIndex = 0;
-    int64_t stride = 1;
-    for (int i = rank - 1; i >= 0; --i) {
-      APInt index;
-      if (!matchPattern(extract.indices()[i], m_ConstantInt(&index)))
-        return failure();
-      if (i < rank - 1)
-        stride *= tensorType.getDimSize(i);
-      flatIndex += index.getSExtValue() * stride;
-    }
-    // Prevent out of bounds accesses. This can happen in invalid code that will
-    // never execute.
-    if (tensorFromElements->getNumOperands() <= flatIndex || flatIndex < 0)
-      return failure();
-    rewriter.replaceOp(extract, tensorFromElements.getOperand(flatIndex));
-    return success();
-  }
-};
-
 // Pushes the index_casts that occur before extractions to after the extract.
 // This minimizes type conversion in some cases and enables the extract
 // canonicalizer. This changes:
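The ExtractElementFromIndexCast pattern kept here is the one whose comment begins above; per that comment, it moves an index_cast that feeds an extraction to after the extract, roughly (a sketch; the types, names, and the cast op's exact spelling depend on the dialect revision and are illustrative):

    %cast = index_cast %tensor : tensor<3xi32> to tensor<3xindex>
    %e = tensor.extract %cast[%idx] : tensor<3xindex>

becomes

    %e2 = tensor.extract %tensor[%idx] : tensor<3xi32>
    %cast2 = index_cast %e2 : i32 to index

which lets the extract see the original producer and thus fold via the logic added above.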
@@ -494,9 +474,7 @@ struct ExtractElementFromIndexCast
 
 void FromElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                  MLIRContext *context) {
-  results
-      .add<ExtractElementFromIndexCast, ExtractElementFromTensorFromElements>(
-          context);
+  results.add<ExtractElementFromIndexCast>(context);
 }
 
 //===----------------------------------------------------------------------===//
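Design note: expressing this rewrite as ExtractOp::fold rather than as a RewritePattern means it applies wherever folding happens, not only under the canonicalizer, e.g. during OpBuilder::createOrFold and in the constant folder. The rank-0 special case the deleted pattern handled explicitly still works: with no index operands the linearization loop never runs, flatIndex stays 0, and the single element is returned.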