diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 147a2907f52e4..f0c8f0de06637 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -529,8 +529,8 @@ fuseElementwiseOps(RewriterBase &rewriter, OpOperand *fusedOperand);
 /// * There is a chance that the implementation of the transformation does not
 /// agree with the result of this method. This function gives a prediction based
 /// on an optimized fusion.
-llvm::SmallDenseSet<int> getPreservedProducerResults(GenericOp producer,
-                                                     GenericOp consumer,
+llvm::SmallDenseSet<int> getPreservedProducerResults(LinalgOp producer,
+                                                     LinalgOp consumer,
                                                      OpOperand *fusedOperand);
 
 /// Try to peel and canonicalize loop `op` and return the new result.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index f97ed3d6d5111..c3b5765a5c4ad 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -77,11 +77,11 @@ static AffineMap getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
 // of the fused producer & consumer after the fusion can still compute the
 // bounds of the op.
 static bool isOpOperandCanBeDroppedAfterFusedLinalgs(
-    GenericOp producer, GenericOp consumer,
+    LinalgOp producer, LinalgOp consumer,
     ArrayRef<OpOperand *> opOperandsToIgnore) {
   SmallVector<AffineMap> indexingMaps;
 
-  SmallVector<GenericOp> ops = {producer, consumer};
+  SmallVector<LinalgOp> ops = {producer, consumer};
   for (auto &op : ops) {
     for (auto &opOperand : op->getOpOperands()) {
       if (llvm::is_contained(opOperandsToIgnore, &opOperand)) {
@@ -109,8 +109,9 @@ static bool isOpOperandCanBeDroppedAfterFusedLinalgs(
 /// * There is a chance that the implementation of the transformation does not
 /// agree with the result of this method. This function gives a prediction based
 /// on an optimized fusion.
-llvm::SmallDenseSet<int> mlir::linalg::getPreservedProducerResults(
-    GenericOp producer, GenericOp consumer, OpOperand *fusedOperand) {
+llvm::SmallDenseSet<int>
+mlir::linalg::getPreservedProducerResults(LinalgOp producer, LinalgOp consumer,
+                                          OpOperand *fusedOperand) {
   llvm::SmallDenseSet<int> preservedProducerResults;
   llvm::SmallVector<OpOperand *> opOperandsToIgnore;
 
@@ -140,8 +141,8 @@ bool mlir::linalg::areElementwiseOpsFusable(OpOperand *fusedOperand) {
   if (!fusedOperand)
     return false;
 
-  auto producer = fusedOperand->get().getDefiningOp<GenericOp>();
-  auto consumer = dyn_cast<GenericOp>(fusedOperand->getOwner());
+  auto producer = fusedOperand->get().getDefiningOp<LinalgOp>();
+  auto consumer = dyn_cast<LinalgOp>(fusedOperand->getOwner());
   // Check producer and consumer are generic ops.
   if (!producer || !consumer)
     return false;
@@ -215,16 +216,37 @@ bool mlir::linalg::areElementwiseOpsFusable(OpOperand *fusedOperand) {
 /// Generate the region of the fused tensor operation. The region of the fused
 /// op must be empty.
 static void generateFusedElementwiseOpRegion(
-    RewriterBase &rewriter, GenericOp fusedOp,
+    RewriterBase &rewriter, LinalgOp fusedOp,
     AffineMap consumerToProducerLoopsMap, OpOperand *fusedOperand,
     unsigned nloops, llvm::SmallDenseSet<int> &preservedProducerResults) {
-  auto producer = cast<GenericOp>(fusedOperand->get().getDefiningOp());
-  auto consumer = cast<GenericOp>(fusedOperand->getOwner());
+  auto producer = cast<LinalgOp>(fusedOperand->get().getDefiningOp());
+  auto consumer = cast<LinalgOp>(fusedOperand->getOwner());
   // Build the region of the fused op.
+
+  // Since some ops, like `linalg.map`, do not have block arguments for init
+  // operands then we first "generalize" the block by adding arguments for init
+  // operands when they aren't present. We detect this case by checking if
+  // `getOpOperandsMatchingBBargs() == getDpsInputOperands()`.
+  // TODO: This is hacky and should not be merged. Keeping for now for testing
+  // purposes in the meantime, but need a better way
   Block &producerBlock = producer->getRegion(0).front();
+  bool addOutputArgsProducer =
+      producer.getOpOperandsMatchingBBargs() == producer.getDpsInputOperands();
+  if (addOutputArgsProducer) {
+    for (auto init : producer.getDpsInits())
+      producerBlock.addArgument(getElementTypeOrSelf(init.getType()),
+                                producer.getLoc());
+  }
   Block &consumerBlock = consumer->getRegion(0).front();
+  bool addOutputArgsConsumer =
+      consumer.getOpOperandsMatchingBBargs() == consumer.getDpsInputOperands();
+  if (addOutputArgsConsumer) {
+    for (auto init : consumer.getDpsInits())
+      consumerBlock.addArgument(getElementTypeOrSelf(init.getType()),
+                                consumer.getLoc());
+  }
   OpBuilder::InsertionGuard guard(rewriter);
-  Block *fusedBlock = rewriter.createBlock(&fusedOp.getRegion());
+  Block *fusedBlock = rewriter.createBlock(&fusedOp->getRegion(0));
   IRMapping mapper;
 
   // 2. Add an index operation for every fused loop dimension and use the
@@ -330,8 +352,16 @@ static void generateFusedElementwiseOpRegion(
   rewriter.create<YieldOp>(fusedOp.getLoc(), fusedYieldValues);
 
   // Sanity checks.
-  assert(fusedBlock->getNumArguments() == fusedOp.getNumOperands() &&
+  assert(fusedBlock->getNumArguments() == fusedOp->getNumOperands() &&
          "Ill-formed GenericOp region");
+  // Erase added args in case that the ops are still live after fusion.
+  // TODO: Remove along with hacky code above.
+  if (addOutputArgsProducer)
+    producerBlock.eraseArguments(producer.getNumDpsInputs(),
+                                 producer.getNumDpsInits());
+  if (addOutputArgsConsumer)
+    consumerBlock.eraseArguments(consumer.getNumDpsInputs(),
+                                 consumer.getNumDpsInits());
 }
 
 FailureOr<ElementwiseOpFusionResult>
@@ -340,8 +370,8 @@ mlir::linalg::fuseElementwiseOps(RewriterBase &rewriter,
   assert(areElementwiseOpsFusable(fusedOperand) &&
          "expected elementwise operation pre-conditions to pass");
   auto producerResult = cast<OpResult>(fusedOperand->get());
-  auto producer = cast<GenericOp>(producerResult.getOwner());
-  auto consumer = cast<GenericOp>(fusedOperand->getOwner());
+  auto producer = cast<LinalgOp>(producerResult.getOwner());
+  auto consumer = cast<LinalgOp>(fusedOperand->getOwner());
   // TODO: allow fusing the producer of an output operand.
   assert(consumer.isDpsInput(fusedOperand) &&
          "expected producer of input operand");
@@ -418,10 +448,7 @@ mlir::linalg::fuseElementwiseOps(RewriterBase &rewriter,
   // Generate the fused op.
   auto fusedOp = rewriter.create<GenericOp>(
       consumer.getLoc(), fusedResultTypes, fusedInputOperands,
-      fusedOutputOperands, rewriter.getAffineMapArrayAttr(fusedIndexMaps),
-      consumer.getIteratorTypes(),
-      /*doc=*/nullptr,
-      /*library_call=*/nullptr);
+      fusedOutputOperands, fusedIndexMaps, consumer.getIteratorTypesArray());
   if (!fusedOp.getShapesToLoopsMap()) {
     // Fused op has invalid indexing maps. Typically this means something is off
     // in the input, but going ahead here would result in verification errors.
@@ -460,14 +487,14 @@ mlir::linalg::fuseElementwiseOps(RewriterBase &rewriter,
 namespace {
 
 /// Patterns to fuse a generic op, with the producer of its operands.
-class FuseElementwiseOps : public OpRewritePattern<GenericOp> {
+class FuseElementwiseOps : public OpInterfaceRewritePattern<LinalgOp> {
 public:
   FuseElementwiseOps(MLIRContext *context, ControlFusionFn fun,
                      PatternBenefit benefit = 1)
-      : OpRewritePattern<GenericOp>(context, benefit),
+      : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
         controlFn(std::move(fun)) {}
 
-  LogicalResult matchAndRewrite(GenericOp genericOp,
+  LogicalResult matchAndRewrite(LinalgOp genericOp,
                                 PatternRewriter &rewriter) const override {
     // Find the first operand that is defined by another generic op on tensors.
     for (OpOperand &opOperand : genericOp->getOpOperands()) {
@@ -494,7 +521,7 @@ class FuseElementwiseOps : public OpRewritePattern<GenericOp> {
       rewriter.eraseOp(genericOp);
       return success();
     }
-    return failure();
+    return rewriter.notifyMatchFailure(genericOp, "no fusable operands");
   }
 
 private:
diff --git a/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir b/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
index 66fc55fadf8fa..b581567cf57a7 100644
--- a/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
@@ -1014,3 +1014,24 @@ module {
 // CHECK-DAG: %[[T3:.+]] = arith.addf %[[T2]], %[[B1]]
 // CHECK: linalg.yield %[[T3]] : f32
 // CHECK: return %[[GENERIC]]
+
+// -----
+
+func.func @map_ops(%in1: tensor<8xf32>, %in2: tensor<8xf32>) -> tensor<8xf32> {
+  %fill = tensor.empty() : tensor<8xf32>
+  %add = linalg.map {arith.addf} ins(%in1, %in2: tensor<8xf32>, tensor<8xf32>) outs(%fill: tensor<8xf32>)
+  %mapped_65 = linalg.map { math.sqrt } ins(%add : tensor<8xf32>) outs(%fill : tensor<8xf32>)
+  return %mapped_65 : tensor<8xf32>
+}
+
+// CHECK-LABEL: func @map_ops
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<8xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<8xf32>
+// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<8xf32>
+// CHECK: %[[FUSED_OP:.+]] = linalg.generic
+// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : {{.*}}) outs(%[[EMPTY]] :
+// CHECK-NEXT: ^bb0(%[[IN0:.*]]: f32, %[[IN1:.*]]: f32, %[[OUT:.*]]: f32):
+// CHECK-NEXT: %[[ADD:.*]] = arith.addf %[[IN0]], %[[IN1]]
+// CHECK-NEXT: %[[SQRT:.*]] = math.sqrt %[[ADD]]
+// CHECK-NEXT: linalg.yield %[[SQRT]]
+// CHECK-NOT: linalg.generic
diff --git a/mlir/test/Dialect/Linalg/fusion-elementwise.mlir b/mlir/test/Dialect/Linalg/fusion-elementwise.mlir
index bd9977f1410b9..d4b25eb4be691 100644
--- a/mlir/test/Dialect/Linalg/fusion-elementwise.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-elementwise.mlir
@@ -59,3 +59,154 @@ func.func @handle_unused_operands(%arg0: tensor<8xf32>, %arg1: tensor<8xf32>) ->
 // CHECK: %[[FUSED_OP:.+]] = linalg.generic
 // CHECK-SAME: outs(%[[EMPTY]] :
 // CHECK-NOT: linalg.generic
+
+// -----
+
+func.func @map_ops(%in1: tensor<8xf32>, %in2: tensor<8xf32>) -> tensor<8xf32> {
+  %fill = tensor.empty() : tensor<8xf32>
+  %add = linalg.map {arith.addf} ins(%in1, %in2: tensor<8xf32>, tensor<8xf32>) outs(%fill: tensor<8xf32>)
+  %sqrt = linalg.map { math.sqrt } ins(%add : tensor<8xf32>) outs(%fill : tensor<8xf32>)
+  return %sqrt : tensor<8xf32>
+}
+
+// CHECK-LABEL: func @map_ops
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<8xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<8xf32>
+// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<8xf32>
+// CHECK: %[[FUSED_OP:.+]] = linalg.generic
+// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : {{.*}}) outs(%[[EMPTY]] :
+// CHECK-NEXT: ^bb0(%[[IN0:.*]]: f32, %[[IN1:.*]]: f32, %[[OUT:.*]]: f32):
+// CHECK-NEXT: %[[ADD:.*]] = arith.addf %[[IN0]], %[[IN1]]
+// CHECK-NEXT: %[[SQRT:.*]] = math.sqrt %[[ADD]]
+// CHECK-NEXT: linalg.yield %[[SQRT]]
+// CHECK-NOT: linalg.map
+
+// -----
+
+func.func @map_ops_mixed_types(%arg0: tensor<8xf32>, %arg1: tensor<8xf32>) -> tensor<8xf32> {
+  %init = tensor.empty() : tensor<8xi1>
+  %initf = tensor.empty() : tensor<8xf32>
+  %0 = linalg.map {math.sqrt} ins(%arg0 : tensor<8xf32>) outs(%initf : tensor<8xf32>)
+  %1 = linalg.map {math.exp} ins(%arg1 : tensor<8xf32>) outs(%initf : tensor<8xf32>)
+  %2 = linalg.map ins(%0, %1 : tensor<8xf32>, tensor<8xf32>) outs (%init : tensor<8xi1>)
+    (%in0 : f32, %in1 : f32) {
+      %cmp = arith.cmpf olt, %in0, %in1 : f32
+      linalg.yield %cmp : i1
+    }
+  %3 = linalg.map { arith.select } ins(%2, %0, %1 : tensor<8xi1>, tensor<8xf32>, tensor<8xf32>) outs(%initf : tensor<8xf32>)
+  return %3 : tensor<8xf32>
+}
+
+// CHECK-LABEL: func @map_ops_mixed_types
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<8xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<8xf32>
+// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<8xf32>
+// CHECK: %[[FUSED_OP:.+]] = linalg.generic
+// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : {{.*}}) outs(%[[EMPTY]] :
+// CHECK-NEXT: ^bb0(%[[IN0:.*]]: f32, %[[IN1:.*]]: f32, %[[OUT:.*]]: f32):
+// CHECK-NEXT: %[[EXP0:.*]] = math.exp %[[IN1]]
+// CHECK-NEXT: %[[SQRT0:.*]] = math.sqrt %[[IN0]]
+// CHECK-NEXT: %[[EXP1:.*]] = math.exp %[[IN1]]
+// CHECK-NEXT: %[[SQRT1:.*]] = math.sqrt %[[IN0]]
+// CHECK-NEXT: %[[CMP:.*]] = arith.cmpf olt, %[[SQRT1]], %[[EXP1]]
+// CHECK-NEXT: %[[RES:.*]] = arith.select %[[CMP]], %[[SQRT0]], %[[EXP0]]
+// CHECK-NEXT: linalg.yield %[[RES]]
+// CHECK-NOT: linalg.map
+
+// -----
+
+#identity = affine_map<(d0, d1) -> (d0, d1)>
+#bcast = affine_map<(d0, d1) -> (d0)>
+func.func @elementwise_ops(%in1: tensor<8xf32>, %in2: tensor<8x10xf32>) -> tensor<8x10xf32> {
+  %fill = tensor.empty() : tensor<8x10xf32>
+  %add = linalg.elementwise
+      kind=#linalg.elementwise_kind<add>
+      indexing_maps = [#bcast, #identity, #identity]
+      ins(%in1, %in2: tensor<8xf32>, tensor<8x10xf32>) outs(%fill: tensor<8x10xf32>) -> tensor<8x10xf32>
+  %sqrt = linalg.elementwise
+      kind=#linalg.elementwise_kind<sqrt>
+      indexing_maps = [#identity, #identity]
+      ins(%add : tensor<8x10xf32>) outs(%fill : tensor<8x10xf32>) -> tensor<8x10xf32>
+  return %sqrt : tensor<8x10xf32>
+}
+
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1) -> (d0)>
+// CHECK-LABEL: func @elementwise_ops
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<8xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<8x10xf32>
+// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<8x10xf32>
+// CHECK: %[[FUSED_OP:.+]] = linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP1]], #[[MAP0]], #[[MAP0]]]
+// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : {{.*}}) outs(%[[EMPTY]] :
+// CHECK-NEXT: ^bb0(%[[IN0:.*]]: f32, %[[IN1:.*]]: f32, %[[OUT:.*]]: f32):
+// CHECK-NEXT: %[[ADD:.*]] = arith.addf %[[IN0]], %[[IN1]]
+// CHECK-NEXT: %[[SQRT:.*]] = math.sqrt %[[ADD]]
+// CHECK-NEXT: linalg.yield %[[SQRT]]
+// CHECK-NOT: linalg.map
+
+// -----
+
+func.func @map_multi_ops(%arg0: tensor<8xf32>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>) -> tensor<8xf32> {
+  %fill = tensor.empty() : tensor<8xf32>
+  %add_exp = linalg.map ins(%arg0, %arg1: tensor<8xf32>, tensor<8xf32>) outs(%fill: tensor<8xf32>)
+    (%in0 : f32, %in1 : f32) {
+      %add = arith.addf %in0, %in1 : f32
+      %exp = math.exp %add : f32
+      linalg.yield %exp : f32
+    }
+  %sqrt_mul = linalg.map ins(%add_exp, %arg2 : tensor<8xf32>, tensor<8xf32>) outs(%fill : tensor<8xf32>)
+    (%in0 : f32, %in1 : f32) {
+      %sqrt = math.sqrt %in0 : f32
+      %mul = arith.mulf %sqrt, %in1 : f32
+      linalg.yield %mul : f32
+    }
+  return %sqrt_mul : tensor<8xf32>
+}
+
+// CHECK-LABEL: func @map_multi_ops
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<8xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<8xf32>
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<8xf32>
+// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<8xf32>
+// CHECK: %[[FUSED_OP:.+]] = linalg.generic
+// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]], %[[ARG2]] : {{.*}}) outs(%[[EMPTY]] :
+// CHECK-NEXT: ^bb0(%[[IN0:.*]]: f32, %[[IN1:.*]]: f32, %[[IN2:.*]]: f32, %[[OUT:.*]]: f32):
+// CHECK-NEXT: %[[ADD:.*]] = arith.addf %[[IN0]], %[[IN1]]
+// CHECK-NEXT: %[[EXP:.*]] = math.exp %[[ADD]]
+// CHECK-NEXT: %[[SQRT:.*]] = math.sqrt %[[EXP]]
+// CHECK-NEXT: %[[MUL:.*]] = arith.mulf %[[SQRT]], %[[IN2]]
+// CHECK-NEXT: linalg.yield %[[MUL]]
+// CHECK-NOT: linalg.map
+
+// -----
+
+#identity = affine_map<(d0, d1) -> (d0, d1)>
+#bcast = affine_map<(d0, d1) -> (d0)>
+func.func @map_generic_ops(%arg0: tensor<8xf32>, %arg1: tensor<8x10xf32>) -> tensor<8x10xf32> {
+  %fill = tensor.empty() : tensor<8x10xf32>
+  %add = linalg.generic
+      {indexing_maps = [#bcast, #identity, #identity], iterator_types = ["parallel", "parallel"]}
+      ins(%arg0, %arg1: tensor<8xf32>, tensor<8x10xf32>) outs(%fill: tensor<8x10xf32>) {
+  ^bb0(%in0: f32, %in1: f32, %out: f32):
+    %add = arith.addf %in0, %in1 : f32
+    linalg.yield %add : f32
+  } -> tensor<8x10xf32>
+  %sqrt = linalg.map { math.sqrt } ins(%add : tensor<8x10xf32>) outs(%fill : tensor<8x10xf32>)
+  return %sqrt : tensor<8x10xf32>
+}
+
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1) -> (d0)>
+// CHECK-LABEL: func @map_generic_ops
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<8xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<8x10xf32>
+// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<8x10xf32>
+// CHECK: %[[FUSED_OP:.+]] = linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP1]], #[[MAP0]], #[[MAP0]]]
+// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : {{.*}}) outs(%[[EMPTY]] :
+// CHECK-NEXT: ^bb0(%[[IN0:.*]]: f32, %[[IN1:.*]]: f32, %[[OUT:.*]]: f32):
+// CHECK-NEXT: %[[ADD:.*]] = arith.addf %[[IN0]], %[[IN1]]
+// CHECK-NEXT: %[[SQRT:.*]] = math.sqrt %[[ADD]]
+// CHECK-NEXT: linalg.yield %[[SQRT]]
+// CHECK-NOT: linalg.map
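
Note (not part of the patch): the pattern changed above is normally driven through the existing upstream entry points `linalg::populateElementwiseOpsFusionPatterns` and `linalg::ControlFusionFn` together with the greedy rewriter. The sketch below is illustrative only; the wrapper name `fuseElementwiseGreedily` is made up, and a real pass would usually gate fusion in the control function instead of accepting every operand.

// Illustrative sketch: run the elementwise fusion patterns (including the
// FuseElementwiseOps pattern touched by this patch) over a module.
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

static LogicalResult fuseElementwiseGreedily(ModuleOp module) {
  RewritePatternSet patterns(module.getContext());
  // Accept every candidate operand; production passes typically restrict this,
  // e.g. to producers with a single use.
  linalg::ControlFusionFn controlFn = [](OpOperand *) { return true; };
  linalg::populateElementwiseOpsFusionPatterns(patterns, controlFn);
  // Older trees spell this applyPatternsAndFoldGreedily.
  return applyPatternsGreedily(module.getOperation(), std::move(patterns));
}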