From 8f40426eb97dad07474500d6e37706a186642331 Mon Sep 17 00:00:00 2001 From: Andrey Turetskiy Date: Thu, 6 Mar 2025 15:17:08 -0800 Subject: [PATCH] [mlir] Introduction of LocalEffectsOpInterface. --- .../Dialect/Linalg/IR/LinalgStructuredOps.td | 1 + .../mlir/Dialect/MemRef/IR/MemRefOps.td | 5 +++- .../mlir/Interfaces/ControlFlowInterfaces.h | 1 + .../mlir/Interfaces/ControlFlowInterfaces.td | 4 ++- .../mlir/Interfaces/SideEffectInterfaces.h | 9 ++++++ .../mlir/Interfaces/SideEffectInterfaces.td | 25 ++++++++++++++++ mlir/lib/Dialect/Affine/Analysis/Utils.cpp | 29 ++++++++---------- mlir/lib/Interfaces/SideEffectInterfaces.cpp | 21 +++++++++++++ mlir/test/Dialect/Affine/loop-fusion-4.mlir | 30 +++++++++++++++++++ 9 files changed, 107 insertions(+), 18 deletions(-) diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td index e4dd458eaff84..66911c578c1c1 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td @@ -31,6 +31,7 @@ class LinalgStructuredBase_Op props> DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, RecursiveMemoryEffects, + LocalEffectsOpInterface, DestinationStyleOpInterface, LinalgStructuredInterface, ReifyRankedShapedTypeOpInterface], props)> { diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td index 4c8a214049ea9..f0f93b3ca4a89 100644 --- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td +++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td @@ -547,7 +547,8 @@ def CopyOp : MemRef_Op<"copy", [CopyOpInterface, SameOperandsElementType, // DeallocOp //===----------------------------------------------------------------------===// -def MemRef_DeallocOp : MemRef_Op<"dealloc", [MemRefsNormalizable]> { +def MemRef_DeallocOp + : MemRef_Op<"dealloc", [MemRefsNormalizable, LocalEffectsOpInterface]> { let summary = "memory deallocation operation"; let description = [{ The `dealloc` operation frees the region of memory referenced by a memref @@ -1180,6 +1181,7 @@ def LoadOp : MemRef_Op<"load", "memref", "result", "::llvm::cast($_self).getElementType()">, MemRefsNormalizable, + LocalEffectsOpInterface, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "load operation"; @@ -1813,6 +1815,7 @@ def MemRef_StoreOp : MemRef_Op<"store", "memref", "value", "::llvm::cast($_self).getElementType()">, MemRefsNormalizable, + LocalEffectsOpInterface, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "store operation"; diff --git a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h index 7f6967f11444f..8a20d4ce52e21 100644 --- a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h +++ b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h @@ -15,6 +15,7 @@ #define MLIR_INTERFACES_CONTROLFLOWINTERFACES_H #include "mlir/IR/OpDefinition.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" namespace mlir { class BranchOpInterface; diff --git a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td index 69bce78e946c8..2e7c46e3dd184 100644 --- a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td +++ b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td @@ -14,6 +14,7 @@ #ifndef MLIR_INTERFACES_CONTROLFLOWINTERFACES #define MLIR_INTERFACES_CONTROLFLOWINTERFACES +include "mlir/Interfaces/SideEffectInterfaces.td" include "mlir/IR/OpBase.td" //===----------------------------------------------------------------------===// @@ -115,7 +116,8 @@ def BranchOpInterface : OpInterface<"BranchOpInterface"> { // RegionBranchOpInterface //===----------------------------------------------------------------------===// -def RegionBranchOpInterface : OpInterface<"RegionBranchOpInterface"> { +def RegionBranchOpInterface + : OpInterface<"RegionBranchOpInterface", [LocalEffectsOpInterface]> { let description = [{ This interface provides information for region operations that exhibit branching behavior between held regions. I.e., this interface allows for diff --git a/mlir/include/mlir/Interfaces/SideEffectInterfaces.h b/mlir/include/mlir/Interfaces/SideEffectInterfaces.h index aef7ec622fe4f..fde49166b017c 100644 --- a/mlir/include/mlir/Interfaces/SideEffectInterfaces.h +++ b/mlir/include/mlir/Interfaces/SideEffectInterfaces.h @@ -457,6 +457,15 @@ bool isSpeculatable(Operation *op); /// This function is the C++ equivalent of the `Pure` trait. bool isPure(Operation *op); +//===----------------------------------------------------------------------===// +// LocalEffects Utilities +//===----------------------------------------------------------------------===// + +namespace detail { +/// Default implementation of `hasLocalEffects` method. +bool hasLocalEffectsDefaultImpl(Operation *op); +} // namespace detail + } // namespace mlir //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Interfaces/SideEffectInterfaces.td b/mlir/include/mlir/Interfaces/SideEffectInterfaces.td index b2ab4fee9d29c..fc676ce2ee9cb 100644 --- a/mlir/include/mlir/Interfaces/SideEffectInterfaces.td +++ b/mlir/include/mlir/Interfaces/SideEffectInterfaces.td @@ -145,4 +145,29 @@ def RecursivelySpeculatable : TraitList<[ // are always legal to hoist or sink. def Pure : TraitList<[AlwaysSpeculatable, NoMemoryEffect]>; +//===----------------------------------------------------------------------===// +// LocalEffects +//===----------------------------------------------------------------------===// + +// Interface which could be implemented by imperative operators that have no +// effects on state outside of what’s directly available through their operands +// (for example, they can’t access a `memref.global`, can’t make a call to +// another function that can potentially do so, can’t perform a +// synchronization/wait on other pending memory operations, etc.), including +// through operators in their regions. +def LocalEffectsOpInterface : OpInterface<"LocalEffectsOpInterface"> { + let description = [{An interface for operators which have no effects on state + outside of what's directly available through their own + operands or operands of the operators inside their regions. + }]; + let cppNamespace = "::mlir"; + + let methods = + [InterfaceMethod<[{ Returns true if operator has only local effects. }], + "bool", "hasLocalEffects", (ins), [{}], [{ + return mlir::detail::hasLocalEffectsDefaultImpl( + $_op.getOperation()); + }]>]; +} + #endif // MLIR_INTERFACES_SIDEEFFECTS diff --git a/mlir/lib/Dialect/Affine/Analysis/Utils.cpp b/mlir/lib/Dialect/Affine/Analysis/Utils.cpp index ba6f045cff408..e4462dd310ccf 100644 --- a/mlir/lib/Dialect/Affine/Analysis/Utils.cpp +++ b/mlir/lib/Dialect/Affine/Analysis/Utils.cpp @@ -252,6 +252,7 @@ bool MemRefDependenceGraph::init() { // Create graph nodes. DenseMap forToNodeMap; for (Operation &op : block) { + auto localEffectsOp = dyn_cast(op); if (auto forOp = dyn_cast(op)) { Node *node = addNodeToMDG(&op, *this, memrefAccesses); if (!node) @@ -277,27 +278,23 @@ bool MemRefDependenceGraph::init() { Node *node = addNodeToMDG(&op, *this, memrefAccesses); if (!node) return false; - } else if (!isMemoryEffectFree(&op) && - (op.getNumRegions() == 0 || isa(op))) { - // Create graph node for top-level op unless it is known to be - // memory-effect free. This covers all unknown/unregistered ops, - // non-affine ops with memory effects, and region-holding ops with a - // well-defined control flow. During the fusion validity checks, edges - // to/from these ops get looked at. + } else if (isMemoryEffectFree(&op)) { + // Do not create nodes for memory-effect free ops w/o uses. + ; + } else if (localEffectsOp && localEffectsOp.hasLocalEffects()) { + // Create graph node for top-level op which are known to have only local + // effects. Node *node = addNodeToMDG(&op, *this, memrefAccesses); if (!node) return false; - } else if (op.getNumRegions() != 0 && !isa(op)) { - // Return false if non-handled/unknown region-holding ops are found. We - // won't know what such ops do or what its regions mean; for e.g., it may - // not be an imperative op. - LLVM_DEBUG(llvm::dbgs() - << "MDG init failed; unknown region-holding op found!\n"); + } else { + // Return false if non-handled/unknown ops are found. We won't know what + // such ops do or what its regions mean; for e.g., it may not be an + // imperative op. + LLVM_DEBUG(llvm::dbgs() << "MDG init failed; unknown operator found:\n" + << op << "\n"); return false; } - // We aren't creating nodes for memory-effect free ops either with no - // regions (unless it has results being used) or those with branch op - // interface. } LLVM_DEBUG(llvm::dbgs() << "Created " << nodes.size() << " nodes\n"); diff --git a/mlir/lib/Interfaces/SideEffectInterfaces.cpp b/mlir/lib/Interfaces/SideEffectInterfaces.cpp index 59fd19310cea5..50d62a03e8536 100644 --- a/mlir/lib/Interfaces/SideEffectInterfaces.cpp +++ b/mlir/lib/Interfaces/SideEffectInterfaces.cpp @@ -399,3 +399,24 @@ bool mlir::isSpeculatable(Operation *op) { bool mlir::isPure(Operation *op) { return isSpeculatable(op) && isMemoryEffectFree(op); } + +//===----------------------------------------------------------------------===// +// LocalEffects Utilities +//===----------------------------------------------------------------------===// + +bool mlir::detail::hasLocalEffectsDefaultImpl(Operation *op) { + assert(isa(op) && + "Operator does not implement LocalEffectsOpInterface"); + + // Recurse into the regions and ensure that all nested ops have local effects. + for (auto ®ion : op->getRegions()) { + for (auto &nestedOp : region.getOps()) { + auto localEffectsOp = dyn_cast(nestedOp); + auto hasLocalEffects = localEffectsOp && localEffectsOp.hasLocalEffects(); + if (!isPure(&nestedOp) && !hasLocalEffects) { + return false; + } + } + } + return true; +} diff --git a/mlir/test/Dialect/Affine/loop-fusion-4.mlir b/mlir/test/Dialect/Affine/loop-fusion-4.mlir index b5b951ad5eb0e..69c7c250404e1 100644 --- a/mlir/test/Dialect/Affine/loop-fusion-4.mlir +++ b/mlir/test/Dialect/Affine/loop-fusion-4.mlir @@ -548,6 +548,36 @@ func.func @sibling_reduction(%input : memref<10xf32>, %output : memref<10xf32>, // ----- +// Check that presence of a Linalg operator in a block does not prevent +// fusion from happening in this block. + +// ALL-LABEL: func @fusion_in_block_containing_linalg +func.func @fusion_in_block_containing_linalg(%arg0: memref<5xi8>, %arg1: memref<5xi8>) { + %c15_i8 = arith.constant 15 : i8 + %alloc = memref.alloc() : memref<5xi8> + affine.for %arg3 = 0 to 5 { + affine.store %c15_i8, %alloc[%arg3] : memref<5xi8> + } + affine.for %arg3 = 0 to 5 { + %0 = affine.load %alloc[%arg3] : memref<5xi8> + %1 = affine.load %arg0[%arg3] : memref<5xi8> + %2 = arith.muli %0, %1 : i8 + affine.store %2, %alloc[%arg3] : memref<5xi8> + } + // ALL: affine.for + // ALL-NEXT: affine.store + // ALL-NEXT: affine.load + // ALL-NEXT: affine.load + // ALL-NEXT: arith.muli + // ALL-NEXT: affine.store + // ALL-NEXT: } + linalg.elemwise_binary ins(%alloc, %alloc: memref<5xi8>, memref<5xi8>) outs(%arg1: memref<5xi8>) + // ALL-NEXT: linalg.elemwise_binary + return +} + +// ----- + // From https://github.com/llvm/llvm-project/issues/54541 #map = affine_map<(d0) -> (d0 mod 65536)>