
Commit 9e3ca79

lhutton1 authored and rsuderman committed
[mlir][tosa] Canonicalize concatenate->slice sequence
Adds a canonicalizer for the concatenate->slice sequence where an output of slice can be replaced with an input of concatenate. This is useful in the context of operations with complex inputs and outputs that are legalized from a framework such as TFL.

For example, a TFL graph (FFT->FFT) will be legalized to the following TOSA graph:

    <complex input>
       /       \
    slice     slice
       \       /
          FFT
        /     \     -+
    concatenate      |
       /      \      |  Redundant
    slice    slice   |
       \      /     -+
          FFT
        /     \
    concatenate
          |
    <complex output>

Concatenate and slice operations at the boundaries of the graph are useful as they maintain the correct correspondence of input/output tensors to the original TFL graph. However, consecutive complex operations will result in redundant concatenate->slice sequences which should be removed from the final TOSA graph.

The canonicalization does not currently handle dynamic types.

Signed-off-by: Luke Hutton <luke.hutton@arm.com>

Reviewed By: rsuderman

Differential Revision: https://reviews.llvm.org/D144545
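As a minimal before/after sketch of the rewrite (illustrative shapes and a hypothetical function name, written in the style of the tests added below; the actual coverage is in mlir/test/Dialect/Tosa/canonicalize.mlir):

// Before -canonicalize: %1 reads back exactly the first operand of the concatenate.
func.func @concat_slice_example(%a : tensor<1x8xf32>, %b : tensor<1x8xf32>) -> tensor<1x8xf32> {
  %0 = "tosa.concat"(%a, %b) {axis = 1 : i64} : (tensor<1x8xf32>, tensor<1x8xf32>) -> tensor<1x16xf32>
  %1 = "tosa.slice"(%0) {size = array<i64: 1, 8>, start = array<i64: 0, 0>} : (tensor<1x16xf32>) -> tensor<1x8xf32>
  return %1 : tensor<1x8xf32>
}

// After -canonicalize: the rewritten slice of %a folds away (matching types) and
// the concatenate becomes dead, leaving a direct use of %a.
func.func @concat_slice_example(%a : tensor<1x8xf32>, %b : tensor<1x8xf32>) -> tensor<1x8xf32> {
  return %a : tensor<1x8xf32>
}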
1 parent 83e420c · commit 9e3ca79

File tree: 3 files changed (+113, -0 lines)


mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td

Lines changed: 1 addition & 0 deletions
@@ -1556,6 +1556,7 @@ def Tosa_SliceOp: Tosa_Op<"slice", [
     Tosa_Tensor1Dto6D:$output
   );
 
+  let hasCanonicalizer = 1;
   let hasFolder = 1;
 }

mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp

Lines changed: 59 additions & 0 deletions
@@ -519,6 +519,65 @@ void ClampOp::getCanonicalizationPatterns(RewritePatternSet &results,
   results.add<ClampClampOptimization>(context);
 }
 
+struct ConcatSliceOptimization : public OpRewritePattern<tosa::SliceOp> {
+  using OpRewritePattern<tosa::SliceOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tosa::SliceOp sliceOp,
+                                PatternRewriter &rewriter) const override {
+    Value sliceInput = sliceOp.getInput();
+    auto concatOp = sliceInput.getDefiningOp<tosa::ConcatOp>();
+    if (!concatOp)
+      return rewriter.notifyMatchFailure(
+          sliceOp, "slice input must be concat operation");
+
+    OperandRange inputs = concatOp.getInput1();
+    auto concatType = dyn_cast<RankedTensorType>(concatOp.getType());
+    if (!concatType || !concatType.hasStaticShape())
+      return rewriter.notifyMatchFailure(
+          sliceOp, "slice input must be a static ranked tensor");
+    int32_t axis = concatOp.getAxis();
+
+    llvm::SmallVector<int64_t> sliceStart(sliceOp.getStart());
+    llvm::ArrayRef<int64_t> sliceSize = sliceOp.getSize();
+
+    // Validate slice on the concatenated axis. Slicing along this
+    // axis should span only one of the inputs to the concatenate
+    // operation.
+    std::optional<Value> replaceWithSlice;
+    for (auto input : inputs) {
+      auto inputType = dyn_cast<RankedTensorType>(input.getType());
+      if (!inputType || !inputType.hasStaticShape())
+        return rewriter.notifyMatchFailure(
+            sliceOp, "concat input must be a static ranked tensor");
+
+      if (sliceStart[axis] >= 0 &&
+          (sliceStart[axis] + sliceSize[axis]) <= inputType.getDimSize(axis)) {
+        replaceWithSlice =
+            rewriter
+                .create<tosa::SliceOp>(
+                    sliceOp.getLoc(), sliceOp.getType(), input,
+                    rewriter.getDenseI64ArrayAttr(sliceOp.getStart()),
+                    rewriter.getDenseI64ArrayAttr(sliceSize))
+                .getResult();
+        break;
+      }
+      sliceStart[axis] -= inputType.getDimSize(axis);
+    }
+
+    if (!replaceWithSlice)
+      return rewriter.notifyMatchFailure(
+          sliceOp, "corresponding concat input not found for slice");
+
+    rewriter.replaceOp(sliceOp, replaceWithSlice.value());
+    return success();
+  }
+};
+
+void SliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                          MLIRContext *context) {
+  results.add<ConcatSliceOptimization>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // Operator Folders.
 //===----------------------------------------------------------------------===//

mlir/test/Dialect/Tosa/canonicalize.mlir

Lines changed: 53 additions & 0 deletions
@@ -434,3 +434,56 @@ func.func @fold_resize_bilinear(%arg0 : tensor<1x15x13x1xi8>) -> tensor<1x15x13x
   %resize = "tosa.resize"(%arg0) {mode = "BILINEAR", scale = array<i64: 2, 2, 1, 1>, offset = array<i64: 0, 0>, border = array<i64: 0, 0>} : (tensor<1x15x13x1xi8>) -> tensor<1x15x13x1xi8>
   return %resize : tensor<1x15x13x1xi8>
 }
+
+// -----
+
+// CHECK-LABEL: @canonicalize_concat_slice_final_axis
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x12x12x1xf32>, %[[VAL_1:.*]]: tensor<1x12x12x1xf32>
+// CHECK: return %[[VAL_0]], %[[VAL_1]] : tensor<1x12x12x1xf32>, tensor<1x12x12x1xf32>
+func.func @canonicalize_concat_slice_final_axis(%arg0 : tensor<1x12x12x1xf32>, %arg1 : tensor<1x12x12x1xf32>) -> (tensor<1x12x12x1xf32>, tensor<1x12x12x1xf32>) {
+  %0 = "tosa.concat"(%arg0, %arg1) {axis = 3 : i64} : (tensor<1x12x12x1xf32>, tensor<1x12x12x1xf32>) -> tensor<1x12x12x2xf32>
+  %1 = "tosa.slice"(%0) {size = array<i64: 1, 12, 12, 1>, start = array<i64: 0, 0, 0, 0>} : (tensor<1x12x12x2xf32>) -> tensor<1x12x12x1xf32>
+  %2 = "tosa.slice"(%0) {size = array<i64: 1, 12, 12, 1>, start = array<i64: 0, 0, 0, 1>} : (tensor<1x12x12x2xf32>) -> tensor<1x12x12x1xf32>
+  return %1, %2 : tensor<1x12x12x1xf32>, tensor<1x12x12x1xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @canonicalize_concat_slice_middle_axis
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x12x12xf32>, %[[VAL_1:.*]]: tensor<1x12x12xf32>
+// CHECK: return %[[VAL_0]], %[[VAL_1]] : tensor<1x12x12xf32>, tensor<1x12x12xf32>
+func.func @canonicalize_concat_slice_middle_axis(%arg0 : tensor<1x12x12xf32>, %arg1 : tensor<1x12x12xf32>) -> (tensor<1x12x12xf32>, tensor<1x12x12xf32>) {
+  %0 = "tosa.concat"(%arg0, %arg1) {axis = 1 : i64} : (tensor<1x12x12xf32>, tensor<1x12x12xf32>) -> tensor<1x24x12xf32>
+  %1 = "tosa.slice"(%0) {size = array<i64: 1, 12, 12>, start = array<i64: 0, 0, 0>} : (tensor<1x24x12xf32>) -> tensor<1x12x12xf32>
+  %2 = "tosa.slice"(%0) {size = array<i64: 1, 12, 12>, start = array<i64: 0, 12, 0>} : (tensor<1x24x12xf32>) -> tensor<1x12x12xf32>
+  return %1, %2 : tensor<1x12x12xf32>, tensor<1x12x12xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @canonicalize_cross_concat_inputs
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x12x12xf32>, %[[VAL_1:.*]]: tensor<1x12x12xf32>
+// CHECK: %[[VAL_2:.*]] = "tosa.concat"(%[[VAL_0]], %[[VAL_1]]) {axis = 2 : i64} : (tensor<1x12x12xf32>, tensor<1x12x12xf32>) -> tensor<1x12x24xf32>
+// CHECK: %[[VAL_3:.*]] = "tosa.slice"(%[[VAL_2]]) {size = array<i64: 1, 12, 15>, start = array<i64: 0, 0, 0>} : (tensor<1x12x24xf32>) -> tensor<1x12x15xf32>
+// CHECK: %[[VAL_4:.*]] = "tosa.slice"(%[[VAL_2]]) {size = array<i64: 1, 12, 20>, start = array<i64: 0, 0, 4>} : (tensor<1x12x24xf32>) -> tensor<1x12x20xf32>
+// CHECK: return %[[VAL_3]], %[[VAL_4]] : tensor<1x12x15xf32>, tensor<1x12x20xf32>
+func.func @canonicalize_cross_concat_inputs(%arg0 : tensor<1x12x12xf32>, %arg1 : tensor<1x12x12xf32>) -> (tensor<1x12x15xf32>, tensor<1x12x20xf32>) {
+  %0 = "tosa.concat"(%arg0, %arg1) {axis = 2 : i64} : (tensor<1x12x12xf32>, tensor<1x12x12xf32>) -> tensor<1x12x24xf32>
+  %1 = "tosa.slice"(%0) {size = array<i64: 1, 12, 15>, start = array<i64: 0, 0, 0>} : (tensor<1x12x24xf32>) -> tensor<1x12x15xf32>
+  %2 = "tosa.slice"(%0) {size = array<i64: 1, 12, 20>, start = array<i64: 0, 0, 4>} : (tensor<1x12x24xf32>) -> tensor<1x12x20xf32>
+  return %1, %2 : tensor<1x12x15xf32>, tensor<1x12x20xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @canonicalize_concat_slice_on_non_concat_axis
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x12x12xf32>, %[[VAL_1:.*]]: tensor<1x12x12xf32>
+// CHECK: %[[VAL_2:.*]] = "tosa.slice"(%[[VAL_0]]) {size = array<i64: 1, 6, 12>, start = array<i64: 0, 0, 0>} : (tensor<1x12x12xf32>) -> tensor<1x6x12xf32>
+// CHECK: %[[VAL_3:.*]] = "tosa.slice"(%[[VAL_1]]) {size = array<i64: 1, 3, 12>, start = array<i64: 1, 3, 12>} : (tensor<1x12x12xf32>) -> tensor<1x3x12xf32>
+// CHECK: return %[[VAL_2]], %[[VAL_3]] : tensor<1x6x12xf32>, tensor<1x3x12xf32>
+func.func @canonicalize_concat_slice_on_non_concat_axis(%arg0 : tensor<1x12x12xf32>, %arg1 : tensor<1x12x12xf32>) -> (tensor<1x6x12xf32>, tensor<1x3x12xf32>) {
+  %0 = "tosa.concat"(%arg0, %arg1) {axis = 2 : i64} : (tensor<1x12x12xf32>, tensor<1x12x12xf32>) -> tensor<1x12x24xf32>
+  %1 = "tosa.slice"(%0) {size = array<i64: 1, 6, 12>, start = array<i64: 0, 0, 0>} : (tensor<1x12x24xf32>) -> tensor<1x6x12xf32>
+  %2 = "tosa.slice"(%0) {size = array<i64: 1, 3, 12>, start = array<i64: 1, 3, 12>} : (tensor<1x12x24xf32>) -> tensor<1x3x12xf32>
+  return %1, %2 : tensor<1x6x12xf32>, tensor<1x3x12xf32>
+}
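As the commit message notes, dynamic shapes are not handled: the pattern rejects the match when the concatenate result (or an operand it inspects) lacks a static shape, and such IR is left untouched by -canonicalize. A hedged sketch with a hypothetical function name and illustrative shapes, not part of the tests above:

// Left unchanged: the concatenate result has a dynamic dimension, so
// ConcatSliceOptimization bails out with a match failure.
func.func @concat_slice_dynamic_example(%a : tensor<1x?xf32>, %b : tensor<1x8xf32>) -> tensor<1x4xf32> {
  %0 = "tosa.concat"(%a, %b) {axis = 1 : i64} : (tensor<1x?xf32>, tensor<1x8xf32>) -> tensor<1x?xf32>
  %1 = "tosa.slice"(%0) {size = array<i64: 1, 4>, start = array<i64: 0, 0>} : (tensor<1x?xf32>) -> tensor<1x4xf32>
  return %1 : tensor<1x4xf32>
}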
