[mlir][linalg] Extend FuseElementwiseOps pattern to work with named ops #144922

Status: Open. Wants to merge 14 commits into base: main.
4 changes: 2 additions & 2 deletions mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -529,8 +529,8 @@ fuseElementwiseOps(RewriterBase &rewriter, OpOperand *fusedOperand);
/// * There is a chance that the implementation of the transformation does not
/// agree with the result of this method. This function gives a prediction based
/// on an optimized fusion.
llvm::SmallDenseSet<int> getPreservedProducerResults(GenericOp producer,
GenericOp consumer,
llvm::SmallDenseSet<int> getPreservedProducerResults(LinalgOp producer,
LinalgOp consumer,
OpOperand *fusedOperand);

/// Try to peel and canonicalize loop `op` and return the new result.
69 changes: 48 additions & 21 deletions mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -77,11 +77,11 @@ static AffineMap getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
// of the fused producer & consumer after the fusion can still compute the
// bounds of the op.
static bool isOpOperandCanBeDroppedAfterFusedLinalgs(
GenericOp producer, GenericOp consumer,
LinalgOp producer, LinalgOp consumer,
ArrayRef<OpOperand *> opOperandsToIgnore) {
SmallVector<AffineMap> indexingMaps;

SmallVector<GenericOp> ops = {producer, consumer};
SmallVector<LinalgOp> ops = {producer, consumer};
for (auto &op : ops) {
for (auto &opOperand : op->getOpOperands()) {
if (llvm::is_contained(opOperandsToIgnore, &opOperand)) {
@@ -109,8 +109,9 @@ static bool isOpOperandCanBeDroppedAfterFusedLinalgs(
/// * There is a chance that the implementation of the transformation does not
/// agree with the result of this method. This function gives a prediction based
/// on an optimized fusion.
llvm::SmallDenseSet<int> mlir::linalg::getPreservedProducerResults(
GenericOp producer, GenericOp consumer, OpOperand *fusedOperand) {
llvm::SmallDenseSet<int>
mlir::linalg::getPreservedProducerResults(LinalgOp producer, LinalgOp consumer,
OpOperand *fusedOperand) {
llvm::SmallDenseSet<int> preservedProducerResults;
llvm::SmallVector<OpOperand *> opOperandsToIgnore;

@@ -140,8 +141,8 @@ bool mlir::linalg::areElementwiseOpsFusable(OpOperand *fusedOperand) {
if (!fusedOperand)
return false;

auto producer = fusedOperand->get().getDefiningOp<GenericOp>();
auto consumer = dyn_cast<GenericOp>(fusedOperand->getOwner());
auto producer = fusedOperand->get().getDefiningOp<LinalgOp>();
auto consumer = dyn_cast<LinalgOp>(fusedOperand->getOwner());

// Check producer and consumer are generic ops.
if (!producer || !consumer)
@@ -215,16 +216,37 @@ bool mlir::linalg::areElementwiseOpsFusable(OpOperand *fusedOperand) {
/// Generate the region of the fused tensor operation. The region of the fused
/// op must be empty.
static void generateFusedElementwiseOpRegion(
RewriterBase &rewriter, GenericOp fusedOp,
RewriterBase &rewriter, LinalgOp fusedOp,
AffineMap consumerToProducerLoopsMap, OpOperand *fusedOperand,
unsigned nloops, llvm::SmallDenseSet<int> &preservedProducerResults) {
auto producer = cast<GenericOp>(fusedOperand->get().getDefiningOp());
auto consumer = cast<GenericOp>(fusedOperand->getOwner());
auto producer = cast<LinalgOp>(fusedOperand->get().getDefiningOp());
auto consumer = cast<LinalgOp>(fusedOperand->getOwner());
// Build the region of the fused op.

// Since some ops, like `linalg.map`, do not have block arguments for init
Review thread:

Member: This looks like a missed opportunity. IIRC, neither contract nor elementwise has that problem. Perhaps we can update map to behave like the new ones first? @javedabsar1 @shahidact

Contributor: I agree that it needs to be updated, or even better, deprecated in favour of linalg.elementwise, since linalg.map is semantically the same as linalg.elementwise; in fact, linalg.elementwise seems more general, IIRC.

srcarroll (Author, Jun 20, 2025): @rengolin and I had a discussion of that here. It seems we want to keep map, and map does support multi-op regions whereas elementwise does not, so elementwise isn't strictly more general; each op has something the other doesn't. But yes, I agree this is an odd feature of map, so if we do keep it, it would be nice to canonicalize this kind of thing.

srcarroll (Author): So I actually found another bug, and I think it's caused by what I'm doing here: modifying the blocks of the producers and consumers in place. As long as the original ops don't stick around, this works fine; but if the producer has more than one user, so that it has to stick around, this logic converts it into an invalid op. I'm in the process of confirming this. In any event, it's probably not a good idea to modify the blocks in place like this.

srcarroll (Author, Jun 20, 2025):
> Perhaps we can update map to behave like the new ones first?

I'm fine with waiting for that change so we don't have to figure out special logic here.

srcarroll (Author, Jun 20, 2025): I fixed the bug here; the problem was what I expected. I made a note that this is a hack and should be changed before merging, so I'll either come up with something better or wait for the map op change (if we decide to go with that).

I can't actually reproduce the bug with builtin passes. I only discovered it when I applied my own control function for populateElementwiseOpsFusionPatterns.

Contributor:
> I can't actually reproduce the bug with builtin passes. I only discovered it when I applied my own control function for populateElementwiseOpsFusionPatterns.

If there's a bug present in the current fuser, it'd be worth adding such a test case to TestLinalgElementwiseFusion. You could extend it with another option flag that adds the necessary custom control logic.

srcarroll (Author): Sure. However, since the bug was an artifact of my own changes, which I noted are hacky and will go away before this merges anyway, it would be irrelevant by that point.

srcarroll (Author): @rengolin do you know if anyone will make the change to map, or should I do it?
// operands then we first "generalize" the block by adding arguments for init
// operands when they aren't present. We detect this case by checking if
// `getOpOperandsMatchingBBargs() == getDpsInputOperands()`.
// TODO: This is hacky and should not be merged. Keeping for now for testing
// purposes in the meantime, but need a better way
Block &producerBlock = producer->getRegion(0).front();
bool addOutputArgsProducer =
producer.getOpOperandsMatchingBBargs() == producer.getDpsInputOperands();
if (addOutputArgsProducer) {
for (auto init : producer.getDpsInits())
producerBlock.addArgument(getElementTypeOrSelf(init.getType()),
producer.getLoc());
}
Block &consumerBlock = consumer->getRegion(0).front();
bool addOutputArgsConsumer =
consumer.getOpOperandsMatchingBBargs() == consumer.getDpsInputOperands();
if (addOutputArgsConsumer) {
for (auto init : consumer.getDpsInits())
consumerBlock.addArgument(getElementTypeOrSelf(init.getType()),
consumer.getLoc());
}
OpBuilder::InsertionGuard guard(rewriter);
Block *fusedBlock = rewriter.createBlock(&fusedOp.getRegion());
Block *fusedBlock = rewriter.createBlock(&fusedOp->getRegion(0));
IRMapping mapper;

// 2. Add an index operation for every fused loop dimension and use the
@@ -330,8 +352,16 @@ static void generateFusedElementwiseOpRegion(
rewriter.create<YieldOp>(fusedOp.getLoc(), fusedYieldValues);

// Sanity checks.
assert(fusedBlock->getNumArguments() == fusedOp.getNumOperands() &&
assert(fusedBlock->getNumArguments() == fusedOp->getNumOperands() &&
"Ill-formed GenericOp region");
// Erase added args in case that the ops are still live after fusion.
// TODO: Remove along with hacky code above.
if (addOutputArgsProducer)
producerBlock.eraseArguments(producer.getNumDpsInputs(),
producer.getNumDpsInits());
if (addOutputArgsConsumer)
consumerBlock.eraseArguments(consumer.getNumDpsInputs(),
consumer.getNumDpsInits());
}

FailureOr<mlir::linalg::ElementwiseOpFusionResult>
Expand All @@ -340,8 +370,8 @@ mlir::linalg::fuseElementwiseOps(RewriterBase &rewriter,
assert(areElementwiseOpsFusable(fusedOperand) &&
"expected elementwise operation pre-conditions to pass");
auto producerResult = cast<OpResult>(fusedOperand->get());
auto producer = cast<GenericOp>(producerResult.getOwner());
auto consumer = cast<GenericOp>(fusedOperand->getOwner());
auto producer = cast<LinalgOp>(producerResult.getOwner());
auto consumer = cast<LinalgOp>(fusedOperand->getOwner());
// TODO: allow fusing the producer of an output operand.
assert(consumer.isDpsInput(fusedOperand) &&
"expected producer of input operand");
@@ -418,10 +448,7 @@ mlir::linalg::fuseElementwiseOps(RewriterBase &rewriter,
// Generate the fused op.
auto fusedOp = rewriter.create<GenericOp>(
consumer.getLoc(), fusedResultTypes, fusedInputOperands,
fusedOutputOperands, rewriter.getAffineMapArrayAttr(fusedIndexMaps),
consumer.getIteratorTypes(),
/*doc=*/nullptr,
/*library_call=*/nullptr);
fusedOutputOperands, fusedIndexMaps, consumer.getIteratorTypesArray());
if (!fusedOp.getShapesToLoopsMap()) {
// Fused op has invalid indexing maps. Typically this means something is off
// in the input, but going ahead here would result in verification errors.
Expand Down Expand Up @@ -460,14 +487,14 @@ mlir::linalg::fuseElementwiseOps(RewriterBase &rewriter,

namespace {
/// Patterns to fuse a generic op, with the producer of its operands.
class FuseElementwiseOps : public OpRewritePattern<GenericOp> {
class FuseElementwiseOps : public OpInterfaceRewritePattern<LinalgOp> {
public:
FuseElementwiseOps(MLIRContext *context, ControlFusionFn fun,
PatternBenefit benefit = 1)
: OpRewritePattern<GenericOp>(context, benefit),
: OpInterfaceRewritePattern<LinalgOp>(context, benefit),
controlFn(std::move(fun)) {}

LogicalResult matchAndRewrite(GenericOp genericOp,
LogicalResult matchAndRewrite(LinalgOp genericOp,
PatternRewriter &rewriter) const override {
// Find the first operand that is defined by another generic op on tensors.
for (OpOperand &opOperand : genericOp->getOpOperands()) {
Expand All @@ -494,7 +521,7 @@ class FuseElementwiseOps : public OpRewritePattern<GenericOp> {
rewriter.eraseOp(genericOp);
return success();
}
return failure();
return rewriter.notifyMatchFailure(genericOp, "no fusable operands");
}

private:
21 changes: 21 additions & 0 deletions mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
@@ -1014,3 +1014,24 @@ module {
// CHECK-DAG: %[[T3:.+]] = arith.addf %[[T2]], %[[B1]]
// CHECK: linalg.yield %[[T3]] : f32
// CHECK: return %[[GENERIC]]

// -----

func.func @map_ops(%in1: tensor<8xf32>, %in2: tensor<8xf32>) -> tensor<8xf32> {
Review thread:

Member: Could you also add a test where the map operations each have more than one op in the region? The fuser should be able to cope with that.

srcarroll (Author): Sure. I initially didn't go that far with tests because my reasoning at the time was that the logic is unchanged and so should generalize, but then I ran across that issue with the bb args, so I'm less convinced of that.

srcarroll (Author): Done.

%fill = tensor.empty() : tensor<8xf32>
%add = linalg.map {arith.addf} ins(%in1, %in2: tensor<8xf32>, tensor<8xf32>) outs(%fill: tensor<8xf32>)
%mapped_65 = linalg.map { math.sqrt } ins(%add : tensor<8xf32>) outs(%fill : tensor<8xf32>)
return %mapped_65 : tensor<8xf32>
}

// CHECK-LABEL: func @map_ops
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<8xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<8xf32>
// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<8xf32>
// CHECK: %[[FUSED_OP:.+]] = linalg.generic
// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : {{.*}}) outs(%[[EMPTY]] :
// CHECK-NEXT: ^bb0(%[[IN0:.*]]: f32, %[[IN1:.*]]: f32, %[[OUT:.*]]: f32):
// CHECK-NEXT: %[[ADD:.*]] = arith.addf %[[IN0]], %[[IN1]]
// CHECK-NEXT: %[[SQRT:.*]] = math.sqrt %[[ADD]]
// CHECK-NEXT: linalg.yield %[[SQRT]]
// CHECK-NOT: linalg.generic
151 changes: 151 additions & 0 deletions mlir/test/Dialect/Linalg/fusion-elementwise.mlir
@@ -59,3 +59,154 @@ func.func @handle_unused_operands(%arg0: tensor<8xf32>, %arg1: tensor<8xf32>) ->
// CHECK: %[[FUSED_OP:.+]] = linalg.generic
// CHECK-SAME: outs(%[[EMPTY]] :
// CHECK-NOT: linalg.generic

// -----

func.func @map_ops(%in1: tensor<8xf32>, %in2: tensor<8xf32>) -> tensor<8xf32> {
%fill = tensor.empty() : tensor<8xf32>
%add = linalg.map {arith.addf} ins(%in1, %in2: tensor<8xf32>, tensor<8xf32>) outs(%fill: tensor<8xf32>)
%sqrt = linalg.map { math.sqrt } ins(%add : tensor<8xf32>) outs(%fill : tensor<8xf32>)
return %sqrt : tensor<8xf32>
}

// CHECK-LABEL: func @map_ops
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<8xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<8xf32>
// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<8xf32>
// CHECK: %[[FUSED_OP:.+]] = linalg.generic
Review thread:

Member: Maybe not for this PR, but we have discussed keeping the named ops for as long as possible. Here, since both ops are maps, we could still fuse into a map. Technically they're the same (as discussed in the forum), but if I have a chain of matchers and fusers, I'd have to match against all possible representations.

srcarroll (Author): Yeah, I mentioned in a comment that I wanted to try that, but I don't know of a clean way to do it yet. I've done something like that before, by cloning and modifying, as a way of generalizing a transform from a named op back to that same named op. It seems more complicated here, but maybe not.

srcarroll (Author): Also, that wouldn't help with elementwise + elementwise -> map anyway.

Member:
> Also, that wouldn't help with elementwise + elementwise -> map anyway.

Right, the idea is that ew + ew -> map is still better than fusing to generic. So we only walk up the tree when needed (assuming generic -> map -> ew is the branch we're walking).

srcarroll (Author, Jun 21, 2025): One thing to note: ew + ew -> map only works if the indexing maps of both ew ops are same-rank identities on all operands, since the indexing maps of map are limited.
// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : {{.*}}) outs(%[[EMPTY]] :
// CHECK-NEXT: ^bb0(%[[IN0:.*]]: f32, %[[IN1:.*]]: f32, %[[OUT:.*]]: f32):
// CHECK-NEXT: %[[ADD:.*]] = arith.addf %[[IN0]], %[[IN1]]
// CHECK-NEXT: %[[SQRT:.*]] = math.sqrt %[[ADD]]
// CHECK-NEXT: linalg.yield %[[SQRT]]
// CHECK-NOT: linalg.map

// -----

func.func @map_ops_mixed_types(%arg0: tensor<8xf32>, %arg1: tensor<8xf32>) -> tensor<8xf32> {
Review thread:

Member: Is it possible to fuse map and generic together, if their affine maps and iterator types are compatible? If yes, we should have a quick test for it. If not, this should eventually be supported (in a separate PR), so a FIXME in the code would help.

srcarroll (Author): My thinking is that if it works for the generic form, it should work for the map form. But again, that block-arg oddity kind of ruins that. Nevertheless, I expect it to work and will add a test.

srcarroll (Author): Done.

%init = tensor.empty() : tensor<8xi1>
%initf = tensor.empty() : tensor<8xf32>
%0 = linalg.map {math.sqrt} ins(%arg0 : tensor<8xf32>) outs(%initf : tensor<8xf32>)
%1 = linalg.map {math.exp} ins(%arg1 : tensor<8xf32>) outs(%initf : tensor<8xf32>)
%2 = linalg.map ins(%0, %1 : tensor<8xf32>, tensor<8xf32>) outs (%init : tensor<8xi1>)
(%in0 : f32, %in1 : f32) {
%cmp = arith.cmpf olt, %in0, %in1 : f32
linalg.yield %cmp : i1
}
%3 = linalg.map { arith.select } ins(%2, %0, %1 : tensor<8xi1>, tensor<8xf32>, tensor<8xf32>) outs(%initf : tensor<8xf32>)
return %3 : tensor<8xf32>
}

// CHECK-LABEL: func @map_ops_mixed_types
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<8xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<8xf32>
// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<8xf32>
// CHECK: %[[FUSED_OP:.+]] = linalg.generic
// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : {{.*}}) outs(%[[EMPTY]] :
// CHECK-NEXT: ^bb0(%[[IN0:.*]]: f32, %[[IN1:.*]]: f32, %[[OUT:.*]]: f32):
// CHECK-NEXT: %[[EXP0:.*]] = math.exp %[[IN1]]
// CHECK-NEXT: %[[SQRT0:.*]] = math.sqrt %[[IN0]]
// CHECK-NEXT: %[[EXP1:.*]] = math.exp %[[IN1]]
// CHECK-NEXT: %[[SQRT1:.*]] = math.sqrt %[[IN0]]
Review thread:

srcarroll (Author): I think the duplicate computations are an old artifact. They do go away with CSE, but let me know if this is something that should be looked at.

Member: This is indeed odd; it looks like a bug in the fuser. It could be related to the map vs. generic issue you've seen above.

srcarroll (Author): I made a generic version of this and ran the old version of the pass on it, and got the same results, confirming it's a pre-existing issue.

srcarroll (Author): Here's the generic version:
#map = affine_map<(d0)->(d0)>
func.func @map_ops_mixed_types(%arg0: tensor<8xf32>, %arg1: tensor<8xf32>, %arg3: tensor<8xf32>) -> tensor<8xf32> {
  %init = tensor.empty() : tensor<8xi1>
  %initf = tensor.empty() : tensor<8xf32>
  %0 = linalg.generic {
      indexing_maps = [#map, #map],
      iterator_types = ["parallel"]} ins(%arg0 : tensor<8xf32>) outs(%initf : tensor<8xf32>) {
    ^bb0(%in0 : f32, %out : f32):
        %sqrt = math.sqrt %in0 : f32
        linalg.yield %sqrt : f32
    } -> tensor<8xf32>
  %1 = linalg.generic {
      indexing_maps = [#map, #map],
      iterator_types = ["parallel"]} ins(%arg1 : tensor<8xf32>) outs(%initf : tensor<8xf32>) {
    ^bb0(%in0 : f32, %out : f32):
        %sqrt = math.exp %in0 : f32
        linalg.yield %sqrt : f32
    } -> tensor<8xf32>
  %2 = linalg.generic {
      indexing_maps = [#map, #map, #map],
      iterator_types = ["parallel"]} ins(%0, %1 : tensor<8xf32>, tensor<8xf32>) outs(%init : tensor<8xi1>) {
    ^bb0(%in0 : f32, %in1 : f32, %out: i1):
      %cmp = arith.cmpf olt, %in0, %in1 : f32
      linalg.yield %cmp : i1
  } -> tensor<8xi1>
  %3 = linalg.generic {
      indexing_maps = [#map, #map, #map, #map],
      iterator_types = ["parallel"]} ins(%2, %0, %1 : tensor<8xi1>, tensor<8xf32>, tensor<8xf32>) outs(%initf : tensor<8xf32>) { 
    ^bb0(%in0 : i1, %in1 : f32, %in2 : f32, %out: f32):
      %select = arith.select %in0, %in1, %in2 : f32
      linalg.yield %select : f32
  } -> tensor<8xf32>
  return %3 : tensor<8xf32>
}

// CHECK-NEXT: %[[CMP:.*]] = arith.cmpf olt, %[[SQRT1]], %[[EXP1]]
// CHECK-NEXT: %[[RES:.*]] = arith.select %[[CMP]], %[[SQRT0]], %[[EXP0]]
// CHECK-NEXT: linalg.yield %[[RES]]
// CHECK-NOT: linalg.map

// -----

#identity = affine_map<(d0, d1) -> (d0, d1)>
#bcast = affine_map<(d0, d1) -> (d0)>
func.func @elementwise_ops(%in1: tensor<8xf32>, %in2: tensor<8x10xf32>) -> tensor<8x10xf32> {
%fill = tensor.empty() : tensor<8x10xf32>
%add = linalg.elementwise
kind=#linalg.elementwise_kind<add>
indexing_maps = [#bcast, #identity, #identity]
ins(%in1, %in2: tensor<8xf32>, tensor<8x10xf32>) outs(%fill: tensor<8x10xf32>) -> tensor<8x10xf32>
%sqrt = linalg.elementwise
kind=#linalg.elementwise_kind<sqrt>
indexing_maps = [#identity, #identity]
ins(%add : tensor<8x10xf32>) outs(%fill : tensor<8x10xf32>) -> tensor<8x10xf32>
return %sqrt : tensor<8x10xf32>
}

// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1) -> (d0, d1)>
// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1) -> (d0)>
// CHECK-LABEL: func @elementwise_ops
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<8xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<8x10xf32>
// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<8x10xf32>
// CHECK: %[[FUSED_OP:.+]] = linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP1]], #[[MAP0]], #[[MAP0]]]
// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : {{.*}}) outs(%[[EMPTY]] :
// CHECK-NEXT: ^bb0(%[[IN0:.*]]: f32, %[[IN1:.*]]: f32, %[[OUT:.*]]: f32):
// CHECK-NEXT: %[[ADD:.*]] = arith.addf %[[IN0]], %[[IN1]]
// CHECK-NEXT: %[[SQRT:.*]] = math.sqrt %[[ADD]]
// CHECK-NEXT: linalg.yield %[[SQRT]]
// CHECK-NOT: linalg.map

// -----

func.func @map_multi_ops(%arg0: tensor<8xf32>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>) -> tensor<8xf32> {
%fill = tensor.empty() : tensor<8xf32>
%add_exp = linalg.map ins(%arg0, %arg1: tensor<8xf32>, tensor<8xf32>) outs(%fill: tensor<8xf32>)
(%in0 : f32, %in1 : f32) {
%add = arith.addf %in0, %in1 : f32
%exp = math.exp %add : f32
linalg.yield %exp : f32
}
%sqrt_mul = linalg.map ins(%add_exp, %arg2 : tensor<8xf32>, tensor<8xf32>) outs(%fill : tensor<8xf32>)
(%in0 : f32, %in1 : f32) {
%sqrt = math.sqrt %in0 : f32
%mul = arith.mulf %sqrt, %in1 : f32
linalg.yield %mul : f32
}
return %sqrt_mul : tensor<8xf32>
}

// CHECK-LABEL: func @map_multi_ops
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<8xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<8xf32>
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<8xf32>
// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<8xf32>
// CHECK: %[[FUSED_OP:.+]] = linalg.generic
// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]], %[[ARG2]] : {{.*}}) outs(%[[EMPTY]] :
// CHECK-NEXT: ^bb0(%[[IN0:.*]]: f32, %[[IN1:.*]]: f32, %[[IN2:.*]]: f32, %[[OUT:.*]]: f32):
// CHECK-NEXT: %[[ADD:.*]] = arith.addf %[[IN0]], %[[IN1]]
// CHECK-NEXT: %[[EXP:.*]] = math.exp %[[ADD]]
// CHECK-NEXT: %[[SQRT:.*]] = math.sqrt %[[EXP]]
// CHECK-NEXT: %[[MUL:.*]] = arith.mulf %[[SQRT]], %[[IN2]]
// CHECK-NEXT: linalg.yield %[[MUL]]
// CHECK-NOT: linalg.map

// -----

#identity = affine_map<(d0, d1) -> (d0, d1)>
#bcast = affine_map<(d0, d1) -> (d0)>
func.func @map_genric_ops(%arg0: tensor<8xf32>, %arg1: tensor<8x10xf32>) -> tensor<8x10xf32> {
%fill = tensor.empty() : tensor<8x10xf32>
%add = linalg.generic
{indexing_maps = [#bcast, #identity, #identity], iterator_types = ["parallel", "parallel"]}
ins(%arg0, %arg1: tensor<8xf32>, tensor<8x10xf32>) outs(%fill: tensor<8x10xf32>) {
^bb0(%in0: f32, %in1: f32, %out: f32):
%add = arith.addf %in0, %in1 : f32
linalg.yield %add : f32
} -> tensor<8x10xf32>
%sqrt = linalg.map { math.sqrt } ins(%add : tensor<8x10xf32>) outs(%fill : tensor<8x10xf32>)
return %sqrt : tensor<8x10xf32>
}

// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1) -> (d0, d1)>
// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1) -> (d0)>
// CHECK-LABEL: func @map_genric_ops
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<8xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<8x10xf32>
// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<8x10xf32>
// CHECK: %[[FUSED_OP:.+]] = linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP1]], #[[MAP0]], #[[MAP0]]]
// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : {{.*}}) outs(%[[EMPTY]] :
// CHECK-NEXT: ^bb0(%[[IN0:.*]]: f32, %[[IN1:.*]]: f32, %[[OUT:.*]]: f32):
// CHECK-NEXT: %[[ADD:.*]] = arith.addf %[[IN0]], %[[IN1]]
// CHECK-NEXT: %[[SQRT:.*]] = math.sqrt %[[ADD]]
// CHECK-NEXT: linalg.yield %[[SQRT]]
// CHECK-NOT: linalg.map