[mlir][linalg] Extend FuseElementwiseOps pattern to work with named ops #144922

Open · wants to merge 14 commits into base: main

Changes from 6 commits
4 changes: 2 additions & 2 deletions mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -529,8 +529,8 @@ fuseElementwiseOps(RewriterBase &rewriter, OpOperand *fusedOperand);
/// * There is a chance that the implementation of the transformation does not
/// agree with the result of this method. This function gives a prediction based
/// on an optimized fusion.
-llvm::SmallDenseSet<int> getPreservedProducerResults(GenericOp producer,
-GenericOp consumer,
+llvm::SmallDenseSet<int> getPreservedProducerResults(LinalgOp producer,
+LinalgOp consumer,
OpOperand *fusedOperand);

/// Try to peel and canonicalize loop `op` and return the new result.
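
To make the prediction concrete under the relaxed `LinalgOp` signature: in a hypothetical snippet like the one below (illustrative IR, not taken from this patch), the producer's result has a use besides the consumer, so fusing along that edge cannot drop the producer result and it would be reported as preserved.

func.func @preserved_result(%arg0: tensor<8xf32>, %init: tensor<8xf32>) -> (tensor<8xf32>, tensor<8xf32>) {
  // Producer: a named elementwise op whose result is used twice below.
  %sqrt = linalg.map { math.sqrt } ins(%arg0 : tensor<8xf32>) outs(%init : tensor<8xf32>)
  // Consumer: its body can absorb the producer's, but %sqrt is also returned,
  // so the producer result stays in the preserved set.
  %exp = linalg.map { math.exp } ins(%sqrt : tensor<8xf32>) outs(%init : tensor<8xf32>)
  return %sqrt, %exp : tensor<8xf32>, tensor<8xf32>
}
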
63 changes: 42 additions & 21 deletions mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -77,11 +77,11 @@ static AffineMap getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
// of the fused producer & consumer after the fusion can still compute the
// bounds of the op.
static bool isOpOperandCanBeDroppedAfterFusedLinalgs(
-GenericOp producer, GenericOp consumer,
+LinalgOp producer, LinalgOp consumer,
ArrayRef<OpOperand *> opOperandsToIgnore) {
SmallVector<AffineMap> indexingMaps;

-SmallVector<GenericOp> ops = {producer, consumer};
+SmallVector<LinalgOp> ops = {producer, consumer};
for (auto &op : ops) {
for (auto &opOperand : op->getOpOperands()) {
if (llvm::is_contained(opOperandsToIgnore, &opOperand)) {
@@ -109,8 +109,9 @@ static bool isOpOperandCanBeDroppedAfterFusedLinalgs(
/// * There is a chance that the implementation of the transformation does not
/// agree with the result of this method. This function gives a prediction based
/// on an optimized fusion.
-llvm::SmallDenseSet<int> mlir::linalg::getPreservedProducerResults(
-GenericOp producer, GenericOp consumer, OpOperand *fusedOperand) {
+llvm::SmallDenseSet<int>
+mlir::linalg::getPreservedProducerResults(LinalgOp producer, LinalgOp consumer,
+OpOperand *fusedOperand) {
llvm::SmallDenseSet<int> preservedProducerResults;
llvm::SmallVector<OpOperand *> opOperandsToIgnore;

@@ -140,8 +141,8 @@ bool mlir::linalg::areElementwiseOpsFusable(OpOperand *fusedOperand) {
if (!fusedOperand)
return false;

-auto producer = fusedOperand->get().getDefiningOp<GenericOp>();
-auto consumer = dyn_cast<GenericOp>(fusedOperand->getOwner());
+auto producer = fusedOperand->get().getDefiningOp<LinalgOp>();
+auto consumer = dyn_cast<LinalgOp>(fusedOperand->getOwner());

// Check producer and consumer are generic ops.
if (!producer || !consumer)
@@ -215,16 +216,39 @@ bool mlir::linalg::areElementwiseOpsFusable(OpOperand *fusedOperand) {
/// Generate the region of the fused tensor operation. The region of the fused
/// op must be empty.
static void generateFusedElementwiseOpRegion(
-RewriterBase &rewriter, GenericOp fusedOp,
+RewriterBase &rewriter, LinalgOp fusedOp,
AffineMap consumerToProducerLoopsMap, OpOperand *fusedOperand,
unsigned nloops, llvm::SmallDenseSet<int> &preservedProducerResults) {
-auto producer = cast<GenericOp>(fusedOperand->get().getDefiningOp());
-auto consumer = cast<GenericOp>(fusedOperand->getOwner());
+auto producer = cast<LinalgOp>(fusedOperand->get().getDefiningOp());
+auto consumer = cast<LinalgOp>(fusedOperand->getOwner());
// Build the region of the fused op.

// Some ops, like `linalg.map`, do not have block arguments for init operands,
// so we first "generalize" the block by adding arguments for init operands when
// they aren't present. We detect this case by checking whether
// `getOpOperandsMatchingBBargs() == getDpsInputOperands()`.
Block &producerBlock = producer->getRegion(0).front();
if (producer.getOpOperandsMatchingBBargs() ==
producer.getDpsInputOperands()) {
for (auto init : producer.getDpsInits()) {
Type bbType = isa<ShapedType>(init.getType())
? cast<ShapedType>(init.getType()).getElementType()
: init.getType();
producerBlock.addArgument(bbType, producer.getLoc());
}
}
Block &consumerBlock = consumer->getRegion(0).front();
if (consumer.getOpOperandsMatchingBBargs() ==
consumer.getDpsInputOperands()) {
for (auto init : consumer.getDpsInits()) {
Type bbType = isa<ShapedType>(init.getType())
? cast<ShapedType>(init.getType()).getElementType()
: init.getType();
consumerBlock.addArgument(bbType, consumer.getLoc());
}
}
OpBuilder::InsertionGuard guard(rewriter);
-Block *fusedBlock = rewriter.createBlock(&fusedOp.getRegion());
+Block *fusedBlock = rewriter.createBlock(&fusedOp->getRegion(0));
IRMapping mapper;

// 2. Add an index operation for every fused loop dimension and use the
@@ -330,7 +354,7 @@ static void generateFusedElementwiseOpRegion(
rewriter.create<YieldOp>(fusedOp.getLoc(), fusedYieldValues);

// Sanity checks.
-assert(fusedBlock->getNumArguments() == fusedOp.getNumOperands() &&
+assert(fusedBlock->getNumArguments() == fusedOp->getNumOperands() &&
"Ill-formed GenericOp region");
}

@@ -340,8 +364,8 @@ mlir::linalg::fuseElementwiseOps(RewriterBase &rewriter,
assert(areElementwiseOpsFusable(fusedOperand) &&
"expected elementwise operation pre-conditions to pass");
auto producerResult = cast<OpResult>(fusedOperand->get());
-auto producer = cast<GenericOp>(producerResult.getOwner());
-auto consumer = cast<GenericOp>(fusedOperand->getOwner());
+auto producer = cast<LinalgOp>(producerResult.getOwner());
+auto consumer = cast<LinalgOp>(fusedOperand->getOwner());
// TODO: allow fusing the producer of an output operand.
assert(consumer.isDpsInput(fusedOperand) &&
"expected producer of input operand");
@@ -418,10 +442,7 @@ mlir::linalg::fuseElementwiseOps(RewriterBase &rewriter,
// Generate the fused op.
auto fusedOp = rewriter.create<GenericOp>(
consumer.getLoc(), fusedResultTypes, fusedInputOperands,
-fusedOutputOperands, rewriter.getAffineMapArrayAttr(fusedIndexMaps),
-consumer.getIteratorTypes(),
-/*doc=*/nullptr,
-/*library_call=*/nullptr);
+fusedOutputOperands, fusedIndexMaps, consumer.getIteratorTypesArray());
if (!fusedOp.getShapesToLoopsMap()) {
// Fused op has invalid indexing maps. Typically this means something is off
// in the input, but going ahead here would result in verification errors.
@@ -460,14 +481,14 @@ mlir::linalg::fuseElementwiseOps(RewriterBase &rewriter,

namespace {
/// Patterns to fuse a generic op, with the producer of its operands.
-class FuseElementwiseOps : public OpRewritePattern<GenericOp> {
+class FuseElementwiseOps : public OpInterfaceRewritePattern<LinalgOp> {
public:
FuseElementwiseOps(MLIRContext *context, ControlFusionFn fun,
PatternBenefit benefit = 1)
-: OpRewritePattern<GenericOp>(context, benefit),
+: OpInterfaceRewritePattern<LinalgOp>(context, benefit),
controlFn(std::move(fun)) {}

-LogicalResult matchAndRewrite(GenericOp genericOp,
+LogicalResult matchAndRewrite(LinalgOp genericOp,
PatternRewriter &rewriter) const override {
// Find the first operand that is defined by another generic op on tensors.
for (OpOperand &opOperand : genericOp->getOpOperands()) {
@@ -494,7 +515,7 @@ class FuseElementwiseOps : public OpRewritePattern<GenericOp> {
rewriter.eraseOp(genericOp);
return success();
}
-return failure();
+return rewriter.notifyMatchFailure(genericOp, "no fusable operands");
}

private:
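
The init block-argument handling added to generateFusedElementwiseOpRegion above exists because the region of a `linalg.map` carries block arguments only for its inputs, while a `linalg.generic` region also has one block argument per init operand. A rough side-by-side of the same computation in both forms (illustrative fragment, not taken from this patch; %a, %b, %init are assumed to be defined elsewhere):

// linalg.map form: the body only sees the two inputs (%in0, %in1).
%add = linalg.map ins(%a, %b : tensor<8xf32>, tensor<8xf32>) outs(%init : tensor<8xf32>)
  (%in0 : f32, %in1 : f32) {
    %sum = arith.addf %in0, %in1 : f32
    linalg.yield %sum : f32
  }

// Equivalent linalg.generic form: the body has an extra block argument (%out) for the init.
#id = affine_map<(d0) -> (d0)>
%add_generic = linalg.generic {
    indexing_maps = [#id, #id, #id],
    iterator_types = ["parallel"]}
    ins(%a, %b : tensor<8xf32>, tensor<8xf32>) outs(%init : tensor<8xf32>) {
  ^bb0(%in0 : f32, %in1 : f32, %out : f32):
    %sum = arith.addf %in0, %in1 : f32
    linalg.yield %sum : f32
} -> tensor<8xf32>
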
21 changes: 21 additions & 0 deletions mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
@@ -1014,3 +1014,24 @@ module {
// CHECK-DAG: %[[T3:.+]] = arith.addf %[[T2]], %[[B1]]
// CHECK: linalg.yield %[[T3]] : f32
// CHECK: return %[[GENERIC]]

// -----

func.func @map_ops(%in1: tensor<8xf32>, %in2: tensor<8xf32>) -> tensor<8xf32> {
Member: Could you also add a test where the map operations each have more than one op in their region? The fuser should be able to cope with it.

Contributor Author: Sure. I initially didn't go that far with tests because at the time my reasoning was that the logic is unchanged, so it should generalize. But then I ran across that one issue with the bb args, so I'm less convinced of that.

Contributor Author: Done.

%fill = tensor.empty() : tensor<8xf32>
%add = linalg.map {arith.addf} ins(%in1, %in2: tensor<8xf32>, tensor<8xf32>) outs(%fill: tensor<8xf32>)
%mapped_65 = linalg.map { math.sqrt } ins(%add : tensor<8xf32>) outs(%fill : tensor<8xf32>)
return %mapped_65 : tensor<8xf32>
}

// CHECK-LABEL: func @map_ops
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<8xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<8xf32>
// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<8xf32>
// CHECK: %[[FUSED_OP:.+]] = linalg.generic
// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : {{.*}}) outs(%[[EMPTY]] :
// CHECK-NEXT: ^bb0(%[[IN0:.*]]: f32, %[[IN1:.*]]: f32, %[[OUT:.*]]: f32):
// CHECK-NEXT: %[[ADD:.*]] = arith.addf %[[IN0]], %[[IN1]]
// CHECK-NEXT: %[[SQRT:.*]] = math.sqrt %[[ADD]]
// CHECK-NEXT: linalg.yield %[[SQRT]]
// CHECK-NOT: linalg.generic
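
For reference, the fused op that the CHECK lines above describe looks roughly like the following (reconstructed by hand from the expected output; SSA names and the affine-map alias are illustrative):

#id = affine_map<(d0) -> (d0)>
func.func @map_ops(%in1: tensor<8xf32>, %in2: tensor<8xf32>) -> tensor<8xf32> {
  %fill = tensor.empty() : tensor<8xf32>
  // Both linalg.map bodies are merged into one linalg.generic region.
  %fused = linalg.generic {
      indexing_maps = [#id, #id, #id],
      iterator_types = ["parallel"]}
      ins(%in1, %in2 : tensor<8xf32>, tensor<8xf32>) outs(%fill : tensor<8xf32>) {
    ^bb0(%b0 : f32, %b1 : f32, %b2 : f32):
      %add = arith.addf %b0, %b1 : f32
      %sqrt = math.sqrt %add : f32
      linalg.yield %sqrt : f32
  } -> tensor<8xf32>
  return %fused : tensor<8xf32>
}
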
54 changes: 54 additions & 0 deletions mlir/test/Dialect/Linalg/fusion-elementwise.mlir
@@ -59,3 +59,57 @@ func.func @handle_unused_operands(%arg0: tensor<8xf32>, %arg1: tensor<8xf32>) ->
// CHECK: %[[FUSED_OP:.+]] = linalg.generic
// CHECK-SAME: outs(%[[EMPTY]] :
// CHECK-NOT: linalg.generic

// -----

func.func @map_ops(%in1: tensor<8xf32>, %in2: tensor<8xf32>) -> tensor<8xf32> {
%fill = tensor.empty() : tensor<8xf32>
%add = linalg.map {arith.addf} ins(%in1, %in2: tensor<8xf32>, tensor<8xf32>) outs(%fill: tensor<8xf32>)
%mapped_65 = linalg.map { math.sqrt } ins(%add : tensor<8xf32>) outs(%fill : tensor<8xf32>)
return %mapped_65 : tensor<8xf32>
}

// CHECK-LABEL: func @map_ops
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<8xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<8xf32>
// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<8xf32>
// CHECK: %[[FUSED_OP:.+]] = linalg.generic
Member: Maybe not for this PR, but we have discussed keeping the named ops for as long as possible. Here, since they're both maps, we could still fuse into a map. Technically they're the same (as discussed in the forum), but if I have a chain of matches and fusers, I'd have to match against all possible representations.

Contributor Author: Yeah, I mentioned in a comment that I wanted to try to do that, but I don't know of a clean way to do it yet. I've done something like that before by just using clone-and-modify as a way of generalizing a transform from a named op back to that same named op. It seems more complicated here, but maybe not.

Contributor Author: Also, that wouldn't help with elementwise + elementwise -> map anyway.

Member: > Also, that wouldn't help with elementwise + elementwise -> map anyway.

Right, the idea is that ew + ew -> map is still better than fusing to generic. So we only walk up the tree when needed (assuming generic -> map -> ew is the branch we're walking).

Contributor Author (@srcarroll, Jun 21, 2025): One thing to note: ew + ew -> map only works if the indexing maps for both ew ops are same-rank identities on all operands, since the indexing maps allowed on map are limited.

// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : {{.*}}) outs(%[[EMPTY]] :
// CHECK-NEXT: ^bb0(%[[IN0:.*]]: f32, %[[IN1:.*]]: f32, %[[OUT:.*]]: f32):
// CHECK-NEXT: %[[ADD:.*]] = arith.addf %[[IN0]], %[[IN1]]
// CHECK-NEXT: %[[SQRT:.*]] = math.sqrt %[[ADD]]
// CHECK-NEXT: linalg.yield %[[SQRT]]
// CHECK-NOT: linalg.map
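
On the ew + ew -> map point in the thread above: `linalg.map` has no indexing_maps attribute and implies identity maps on every operand, so a fusion result like the fragment below, where one input is broadcast with a non-identity map, can only be expressed as `linalg.generic` (hypothetical fragment, not part of this test file; %x, %bias, %init are assumed to be defined elsewhere):

#id = affine_map<(d0, d1) -> (d0, d1)>
#bcast = affine_map<(d0, d1) -> (d1)>
// %bias is broadcast along d0; the non-identity map #bcast rules out linalg.map.
%0 = linalg.generic {
    indexing_maps = [#id, #bcast, #id],
    iterator_types = ["parallel", "parallel"]}
    ins(%x, %bias : tensor<4x8xf32>, tensor<8xf32>) outs(%init : tensor<4x8xf32>) {
  ^bb0(%a : f32, %b : f32, %out : f32):
    %sum = arith.addf %a, %b : f32
    linalg.yield %sum : f32
} -> tensor<4x8xf32>
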

// -----

func.func @map_ops_mixed_types(%arg0: tensor<8xf32>, %arg1: tensor<8xf32>) -> tensor<8xf32> {
Member: Is it possible to fuse map and generic together, if their affine maps and iterator types are compatible? If yes, we should have a quick test for it. If not, this should eventually be supported (separate PR), so a FIXME in the code would help.

Contributor Author: My thinking is that if it works for the generic form, it should work for the map form. But again, that block-arg oddity kind of ruins that. Nevertheless, I expect it to work, but I will add a test.

Contributor Author: Done.

%init = tensor.empty() : tensor<8xi1>
%initf = tensor.empty() : tensor<8xf32>
%0 = linalg.map {math.sqrt} ins(%arg0 : tensor<8xf32>) outs(%initf : tensor<8xf32>)
%1 = linalg.map {math.exp} ins(%arg1 : tensor<8xf32>) outs(%initf : tensor<8xf32>)
%2 = linalg.map ins(%0, %1 : tensor<8xf32>, tensor<8xf32>) outs (%init : tensor<8xi1>)
(%in0 : f32, %in1 : f32) {
%cmp = arith.cmpf olt, %in0, %in1 : f32
linalg.yield %cmp : i1
}
%3 = linalg.map { arith.select } ins(%2, %0, %1 : tensor<8xi1>, tensor<8xf32>, tensor<8xf32>) outs(%initf : tensor<8xf32>)
return %3 : tensor<8xf32>
}

// CHECK-LABEL: func @map_ops_mixed_types
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<8xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<8xf32>
// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<8xf32>
// CHECK: %[[FUSED_OP:.+]] = linalg.generic
// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : {{.*}}) outs(%[[EMPTY]] :
// CHECK-NEXT: ^bb0(%[[IN0:.*]]: f32, %[[IN1:.*]]: f32, %[[OUT:.*]]: f32):
// CHECK-NEXT: %[[EXP0:.*]] = math.exp %[[IN1]]
// CHECK-NEXT: %[[SQRT0:.*]] = math.sqrt %[[IN0]]
// CHECK-NEXT: %[[EXP1:.*]] = math.exp %[[IN1]]
// CHECK-NEXT: %[[SQRT1:.*]] = math.sqrt %[[IN0]]
Contributor Author: I think the duplicate computations are an old artifact. They do go away with CSE, but let me know if this is something that should be looked at.

Member: This is indeed odd. Looks like a bug in the fuser. Could be related to the map vs generic issue you've seen above.

Contributor Author: I made a generic version of this and ran the old version of the pass, and got the same results, confirming it's a pre-existing issue.

Contributor Author: Here's the generic version:

#map = affine_map<(d0)->(d0)>
func.func @map_ops_mixed_types(%arg0: tensor<8xf32>, %arg1: tensor<8xf32>, %arg3: tensor<8xf32>) -> tensor<8xf32> {
  %init = tensor.empty() : tensor<8xi1>
  %initf = tensor.empty() : tensor<8xf32>
  %0 = linalg.generic {
      indexing_maps = [#map, #map],
      iterator_types = ["parallel"]} ins(%arg0 : tensor<8xf32>) outs(%initf : tensor<8xf32>) {
    ^bb0(%in0 : f32, %out : f32):
        %sqrt = math.sqrt %in0 : f32
        linalg.yield %sqrt : f32
    } -> tensor<8xf32>
  %1 = linalg.generic {
      indexing_maps = [#map, #map],
      iterator_types = ["parallel"]} ins(%arg1 : tensor<8xf32>) outs(%initf : tensor<8xf32>) {
    ^bb0(%in0 : f32, %out : f32):
        %sqrt = math.exp %in0 : f32
        linalg.yield %sqrt : f32
    } -> tensor<8xf32>
  %2 = linalg.generic {
      indexing_maps = [#map, #map, #map],
      iterator_types = ["parallel"]} ins(%0, %1 : tensor<8xf32>, tensor<8xf32>) outs(%init : tensor<8xi1>) {
    ^bb0(%in0 : f32, %in1 : f32, %out: i1):
      %cmp = arith.cmpf olt, %in0, %in1 : f32
      linalg.yield %cmp : i1
  } -> tensor<8xi1>
  %3 = linalg.generic {
      indexing_maps = [#map, #map, #map, #map],
      iterator_types = ["parallel"]} ins(%2, %0, %1 : tensor<8xi1>, tensor<8xf32>, tensor<8xf32>) outs(%initf : tensor<8xf32>) { 
    ^bb0(%in0 : i1, %in1 : f32, %in2 : f32, %out: f32):
      %select = arith.select %in0, %in1, %in2 : f32
      linalg.yield %select : f32
  } -> tensor<8xf32>
  return %3 : tensor<8xf32>
}

// CHECK-NEXT: %[[CMP:.*]] = arith.cmpf olt, %[[SQRT1]], %[[EXP1]]
// CHECK-NEXT: %[[RES:.*]] = arith.select %[[CMP]], %[[SQRT0]], %[[EXP0]]
// CHECK-NEXT: linalg.yield %[[RES]]
// CHECK-NOT: linalg.map
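
As noted in the thread above, the duplicated `math.sqrt`/`math.exp` in the fused body are a pre-existing artifact and disappear once CSE runs after fusion; the fused region would then reduce to roughly the following (a sketch, not something this test checks):

^bb0(%in0 : f32, %in1 : f32, %out : f32):
  // One sqrt and one exp remain after CSE folds the duplicates.
  %exp = math.exp %in1 : f32
  %sqrt = math.sqrt %in0 : f32
  %cmp = arith.cmpf olt, %sqrt, %exp : f32
  %sel = arith.select %cmp, %sqrt, %exp : f32
  linalg.yield %sel : f32
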
