Skip to content

Commit b6c58ec

Browse files
committed
[mlir] add producer fusion to structured transform ops
This relies on the existing TileAndFuse pattern for tensor-based structured ops. It complements pure tiling, from which some utilities are generalized. Depends On: D127300. Reviewed By: springerm. Differential Revision: https://reviews.llvm.org/D127319
1 parent 942f4e3 commit b6c58ec

File tree

4 files changed

+223
-98
lines changed

4 files changed

+223
-98
lines changed

mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,25 @@ def DecomposeOp : Op<Transform_Dialect, "structured.decompose",
3737
}];
3838
}
3939

40+
def FuseOp : Op<Transform_Dialect, "structured.fuse",
41+
[FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
42+
DeclareOpInterfaceMethods<TransformOpInterface>]> {
43+
let description = [{
44+
Tiles the operations pointed to by the target handle and fuses their
45+
producers greedily using the options provided as attributes.
46+
}];
47+
48+
let arguments =
49+
(ins PDL_Operation:$target,
50+
DefaultValuedAttr<I64ArrayAttr, "{}">:$tile_sizes,
51+
DefaultValuedAttr<I64ArrayAttr, "{}">:$tile_interchange);
52+
let results = (outs PDL_Operation:$transformed,
53+
Variadic<PDL_Operation>:$loops);
54+
55+
let hasCustomAssemblyFormat = 1;
56+
let hasVerifier = 1;
57+
}
58+
4059
def GeneralizeOp : Op<Transform_Dialect, "structured.generalize",
4160
[FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
4261
TransformOpInterface, TransformEachOpTrait]> {
@@ -136,7 +155,7 @@ def ScalarizeOp : Op<Transform_Dialect, "structured.scalarize",
136155

137156
def TileOp : Op<Transform_Dialect, "structured.tile",
138157
[DeclareOpInterfaceMethods<TransformOpInterface>,
139-
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
158+
FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface]> {
140159
let description = [{
141160
Indicates that the given `target` op should be tiled with the options
142161
provided as attributes. This transform generates a loop nest with a smaller

mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h

Lines changed: 7 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -451,29 +451,26 @@ struct PayloadIRResource
451451
StringRef getName() override { return "transform.payload_ir"; }
452452
};
453453

454-
/// Trait implementing the MemoryEffectOpInterface for single-operand zero- or
455-
/// single-result operations that "consume" their operand and produce a new
456-
/// result.
454+
/// Trait implementing the MemoryEffectOpInterface for single-operand operations
455+
/// that "consume" their operand and produce zero or more new results.
457456
template <typename OpTy>
458457
class FunctionalStyleTransformOpTrait
459458
: public OpTrait::TraitBase<OpTy, FunctionalStyleTransformOpTrait> {
460459
public:
461460
/// This op "consumes" the operand by reading and freeing it, "produces" the
462-
/// result by allocating and writing it and reads/writes the payload IR in the
463-
/// process.
461+
/// results by allocating and writing them and reads/writes the payload IR in
462+
/// the process.
464463
void getEffects(SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
465464
effects.emplace_back(MemoryEffects::Read::get(),
466465
this->getOperation()->getOperand(0),
467466
TransformMappingResource::get());
468467
effects.emplace_back(MemoryEffects::Free::get(),
469468
this->getOperation()->getOperand(0),
470469
TransformMappingResource::get());
471-
if (this->getOperation()->getNumResults() == 1) {
472-
effects.emplace_back(MemoryEffects::Allocate::get(),
473-
this->getOperation()->getResult(0),
470+
for (Value result : this->getOperation()->getResults()) {
471+
effects.emplace_back(MemoryEffects::Allocate::get(), result,
474472
TransformMappingResource::get());
475-
effects.emplace_back(MemoryEffects::Write::get(),
476-
this->getOperation()->getResult(0),
473+
effects.emplace_back(MemoryEffects::Write::get(), result,
477474
TransformMappingResource::get());
478475
}
479476
effects.emplace_back(MemoryEffects::Read::get(), PayloadIRResource::get());
@@ -484,9 +481,6 @@ class FunctionalStyleTransformOpTrait
484481
static LogicalResult verifyTrait(Operation *op) {
485482
static_assert(OpTy::template hasTrait<OpTrait::OneOperand>(),
486483
"expected single-operand op");
487-
static_assert(OpTy::template hasTrait<OpTrait::ZeroResults>() ||
488-
OpTy::template hasTrait<OpTrait::OneResult>(),
489-
"expected zero- or single-result op");
490484
if (!op->getName().getInterface<MemoryEffectOpInterface>()) {
491485
op->emitError()
492486
<< "FunctionalStyleTransformOpTrait should only be attached to ops "

mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp

Lines changed: 126 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,130 @@ FailureOr<LinalgOp> transform::DecomposeOp::applyToOne(LinalgOp target) {
9292
return reportUnknownTransformError(target);
9393
}
9494

95+
//===----------------------------------------------------------------------===//
96+
// FuseOp
97+
//===----------------------------------------------------------------------===//
98+
99+
/// Apply a tiling transformation to all payload ops and store both the
100+
/// tiled operation as well as the created tile loops.
101+
static LogicalResult
102+
applyTilingToAll(Operation *transformOp, Value target,
103+
ArrayRef<int64_t> tileSizes,
104+
transform::TransformResults &transformResults,
105+
transform::TransformState &state,
106+
function_ref<FailureOr<TiledLinalgOp>(LinalgOp)> applyFn) {
107+
// Number of loops: Number of tile sizes that are not zero.
108+
size_t numLoops = tileSizes.size() - llvm::count(tileSizes, 0);
109+
// All payload ops. These should all be LinalgOps for now.
110+
ArrayRef<Operation *> payloadOps = state.getPayloadOps(target);
111+
112+
SmallVector<Operation *> tiledLinalgOps;
113+
SmallVector<SmallVector<Operation *>> loopOps(numLoops);
114+
for (unsigned int i = 0; i < numLoops; ++i)
115+
loopOps[i].reserve(payloadOps.size());
116+
117+
for (Operation *target : payloadOps) {
118+
auto linalgOp = dyn_cast<linalg::LinalgOp>(target);
119+
if (!linalgOp)
120+
return transformOp->emitError("only LinalgOps are supported");
121+
122+
FailureOr<TiledLinalgOp> tiled = applyFn(linalgOp);
123+
if (failed(tiled))
124+
return failure();
125+
126+
tiledLinalgOps.push_back(tiled->op);
127+
if (tiled->loops.size() != numLoops)
128+
// Not enough loops were generated. This usually means that the input size
129+
// was smaller than the tiling size.
130+
// TODO: LinalgTilingPattern should return failure().
131+
return failure();
132+
for (unsigned int i = 0; i < numLoops; ++i)
133+
loopOps[i].push_back(tiled->loops[i]);
134+
}
135+
136+
transformResults.set(transformOp->getOpResult(0), tiledLinalgOps);
137+
for (unsigned int i = 0; i < numLoops; ++i)
138+
transformResults.set(transformOp->getOpResult(i + 1), loopOps[i]);
139+
return success();
140+
}
141+
142+
/// Parse a tiling-like operation that returns the tiled op as well as the
143+
/// created tile loops. The function counts the non-zero tile sizes to compute
144+
/// the number of results.
145+
static ParseResult parseTileLikeOp(OpAsmParser &parser, OperationState &result,
146+
StringRef sizesAttrName) {
147+
OpAsmParser::UnresolvedOperand targetOperand;
148+
SMLoc opLoc = parser.getCurrentLocation();
149+
if (parser.parseOperand(targetOperand) ||
150+
parser.parseOptionalAttrDict(result.attributes))
151+
return failure();
152+
Attribute sizesAttr = result.attributes.get(sizesAttrName);
153+
if (!sizesAttr)
154+
return parser.emitError(opLoc)
155+
<< "expected '" << sizesAttrName << "' attribute";
156+
auto sizesArrayAttr = sizesAttr.dyn_cast<ArrayAttr>();
157+
if (!sizesArrayAttr)
158+
return parser.emitError(opLoc)
159+
<< "'" << sizesAttrName << "' attribute must be an array";
160+
Type pdlOpType = parser.getBuilder().getType<pdl::OperationType>();
161+
size_t numExpectedLoops =
162+
sizesArrayAttr.size() - llvm::count(extractI64Array(sizesArrayAttr), 0);
163+
result.addTypes(SmallVector<Type>(numExpectedLoops + 1, pdlOpType));
164+
if (parser.resolveOperand(targetOperand, pdlOpType, result.operands))
165+
return failure();
166+
return success();
167+
}
168+
169+
LogicalResult
170+
transform::FuseOp::apply(mlir::transform::TransformResults &transformResults,
171+
mlir::transform::TransformState &state) {
172+
LinalgTilingAndFusionOptions fusionOptions;
173+
fusionOptions.tileSizes = extractI64Array(getTileSizes());
174+
fusionOptions.tileInterchange = extractI64Array(getTileInterchange());
175+
176+
return applyTilingToAll(
177+
getOperation(), getTarget(), fusionOptions.tileSizes, transformResults,
178+
state, [&](LinalgOp linalgOp) -> FailureOr<TiledLinalgOp> {
179+
LinalgTileAndFuseTensorOpsPattern pattern(getContext(), fusionOptions);
180+
SimpleRewriter rewriter(getContext());
181+
rewriter.setInsertionPoint(linalgOp);
182+
FailureOr<TileLoopNest> tileLoopNest =
183+
pattern.returningMatchAndRewrite(linalgOp, rewriter);
184+
if (failed(tileLoopNest))
185+
return failure();
186+
187+
TiledLinalgOp tiledLinalgOp;
188+
tiledLinalgOp.op = tileLoopNest->getRootOp();
189+
tiledLinalgOp.loops = {tileLoopNest->getLoopOps().begin(),
190+
tileLoopNest->getLoopOps().end()};
191+
return tiledLinalgOp;
192+
});
193+
}
194+
195+
ParseResult transform::FuseOp::parse(OpAsmParser &parser,
196+
OperationState &result) {
197+
return parseTileLikeOp(
198+
parser, result,
199+
transform::FuseOp::getTileSizesAttrName(result.name).getValue());
200+
}
201+
202+
void transform::FuseOp::print(OpAsmPrinter &p) {
203+
p << ' ';
204+
p << getTarget();
205+
p.printOptionalAttrDict((*this)->getAttrs());
206+
}
207+
208+
LogicalResult transform::FuseOp::verify() {
209+
SmallVector<int64_t> permutation = extractI64Array(getTileInterchange());
210+
auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, permutation.size()));
211+
if (!std::is_permutation(sequence.begin(), sequence.end(),
212+
permutation.begin(), permutation.end())) {
213+
return emitOpError() << "expects interchange to be a permutation, found "
214+
<< getTileInterchange();
215+
}
216+
return success();
217+
}
218+
95219
//===----------------------------------------------------------------------===//
96220
// GeneralizeOp
97221
//===----------------------------------------------------------------------===//
@@ -274,49 +398,6 @@ FailureOr<LinalgOp> transform::ScalarizeOp::applyToOne(LinalgOp target) {
274398
// TileOp
275399
//===----------------------------------------------------------------------===//
276400

277-
/// Apply a tiling transformation to all payload ops and store both the
278-
/// tiled operation as well as the created tile loops.
279-
static LogicalResult
280-
applyTilingToAll(Operation *transformOp, Value target,
281-
ArrayRef<int64_t> tileSizes,
282-
transform::TransformResults &transformResults,
283-
transform::TransformState &state,
284-
function_ref<FailureOr<TiledLinalgOp>(LinalgOp)> applyFn) {
285-
// Number of loops: Number of tiles sizes that are not zero.
286-
size_t numLoops = tileSizes.size() - llvm::count(tileSizes, 0);
287-
// All payload ops. These should all be LinalgOps for now.
288-
ArrayRef<Operation *> payloadOps = state.getPayloadOps(target);
289-
290-
SmallVector<Operation *> tiledLinalgOps;
291-
SmallVector<SmallVector<Operation *>> loopOps(numLoops);
292-
for (unsigned int i = 0; i < numLoops; ++i)
293-
loopOps[i].reserve(payloadOps.size());
294-
295-
for (Operation *target : payloadOps) {
296-
auto linalgOp = dyn_cast<linalg::LinalgOp>(target);
297-
if (!linalgOp)
298-
return transformOp->emitError("only LinalgOps are supported");
299-
300-
FailureOr<TiledLinalgOp> tiled = applyFn(linalgOp);
301-
if (failed(tiled))
302-
return failure();
303-
304-
tiledLinalgOps.push_back(tiled->op);
305-
if (tiled->loops.size() != numLoops)
306-
// Not enough loops were generated. This usually means that the input size
307-
// was smaller than the tiling size.
308-
// TODO: LinalgTilingPattern should return failure().
309-
return failure();
310-
for (unsigned int i = 0; i < numLoops; ++i)
311-
loopOps[i].push_back(tiled->loops[i]);
312-
}
313-
314-
transformResults.set(transformOp->getOpResult(0), tiledLinalgOps);
315-
for (unsigned int i = 0; i < numLoops; ++i)
316-
transformResults.set(transformOp->getOpResult(i + 1), loopOps[i]);
317-
return success();
318-
}
319-
320401
LogicalResult transform::TileOp::apply(TransformResults &transformResults,
321402
TransformState &state) {
322403
LinalgTilingOptions tilingOptions;
@@ -337,27 +418,8 @@ LogicalResult transform::TileOp::apply(TransformResults &transformResults,
337418

338419
ParseResult transform::TileOp::parse(OpAsmParser &parser,
339420
OperationState &result) {
340-
StringRef sizesAttrName = TileOp::getSizesAttrName(result.name).getValue();
341-
OpAsmParser::UnresolvedOperand targetOperand;
342-
SMLoc opLoc = parser.getCurrentLocation();
343-
if (parser.parseOperand(targetOperand) ||
344-
parser.parseOptionalAttrDict(result.attributes))
345-
return failure();
346-
Attribute sizesAttr = result.attributes.get(sizesAttrName);
347-
if (!sizesAttr)
348-
return parser.emitError(opLoc)
349-
<< "expected '" << sizesAttrName << "' attribute";
350-
auto sizesArrayAttr = sizesAttr.dyn_cast<ArrayAttr>();
351-
if (!sizesArrayAttr)
352-
return parser.emitError(opLoc)
353-
<< "'" << sizesAttrName << "' attribute must be an array";
354-
Type pdlOpType = parser.getBuilder().getType<pdl::OperationType>();
355-
size_t numExpectedLoops =
356-
sizesArrayAttr.size() - llvm::count(extractI64Array(sizesArrayAttr), 0);
357-
result.addTypes(SmallVector<Type>(numExpectedLoops + 1, pdlOpType));
358-
if (parser.resolveOperand(targetOperand, pdlOpType, result.operands))
359-
return failure();
360-
return success();
421+
return parseTileLikeOp(parser, result,
422+
TileOp::getSizesAttrName(result.name).getValue());
361423
}
362424

363425
void TileOp::print(OpAsmPrinter &p) {
@@ -366,26 +428,6 @@ void TileOp::print(OpAsmPrinter &p) {
366428
p.printOptionalAttrDict((*this)->getAttrs());
367429
}
368430

369-
void TileOp::getEffects(
370-
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
371-
&effects) {
372-
// `target` arg is consumed and can no longer be used.
373-
effects.emplace_back(MemoryEffects::Read::get(), getTarget(),
374-
TransformMappingResource::get());
375-
effects.emplace_back(MemoryEffects::Free::get(), getTarget(),
376-
TransformMappingResource::get());
377-
378-
for (Value r : getResults()) {
379-
effects.emplace_back(MemoryEffects::Write::get(), r,
380-
TransformMappingResource::get());
381-
effects.emplace_back(MemoryEffects::Allocate::get(), r,
382-
TransformMappingResource::get());
383-
}
384-
385-
effects.emplace_back(MemoryEffects::Read::get(), PayloadIRResource::get());
386-
effects.emplace_back(MemoryEffects::Write::get(), PayloadIRResource::get());
387-
}
388-
389431
//===----------------------------------------------------------------------===//
390432
// VectorizeOp
391433
//===----------------------------------------------------------------------===//
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
// RUN: mlir-opt %s --test-transform-dialect-interpreter --split-input-file | FileCheck %s
2+
3+
// CHECK-LABEL: func.func @fuse_unary
4+
func.func @fuse_unary(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
5+
6+
// CHECK: scf.for
7+
// CHECK: scf.for
8+
// CHECK: linalg.elemwise_unary
9+
// CHECK: linalg.elemwise_binary
10+
%0 = linalg.elemwise_unary ins(%arg0 : tensor<?x?xf32>)
11+
outs(%arg1: tensor<?x?xf32>) -> tensor<?x?xf32>
12+
%1 = linalg.elemwise_binary ins(%0, %arg0 : tensor<?x?xf32>, tensor<?x?xf32>)
13+
outs(%arg1: tensor<?x?xf32>) -> tensor<?x?xf32>
14+
return %1 : tensor<?x?xf32>
15+
}
16+
17+
transform.with_pdl_patterns {
18+
^bb0(%arg0: !pdl.operation):
19+
pdl.pattern @pdl_target : benefit(1) {
20+
%args = operands
21+
%results = types
22+
%0 = pdl.operation "linalg.elemwise_binary"(%args : !pdl.range<value>) -> (%results : !pdl.range<type>)
23+
// TODO: we don't want this, but it is the required terminator for pdl.pattern
24+
rewrite %0 with "transform.dialect"
25+
}
26+
27+
transform.sequence %arg0 {
28+
^bb1(%arg1: !pdl.operation):
29+
%0 = pdl_match @pdl_target in %arg1
30+
%1, %loops:2 = transform.structured.fuse %0 {tile_sizes = [32, 32], tile_interchange = [0, 1]}
31+
}
32+
}
33+
34+
// -----
35+
36+
// CHECK-LABEL: func.func @fuse_unary
37+
func.func @fuse_unary(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
38+
39+
// CHECK: scf.for
40+
// CHECK: scf.for
41+
// CHECK: linalg.elemwise_unary
42+
// CHECK: linalg.elemwise_binary
43+
// CHECK: scf.for
44+
// CHECK: scf.for
45+
// CHECK: linalg.elemwise_unary
46+
// CHECK: linalg.elemwise_binary
47+
%0 = linalg.elemwise_unary ins(%arg0 : tensor<?x?xf32>)
48+
outs(%arg1: tensor<?x?xf32>) -> tensor<?x?xf32>
49+
%1 = linalg.elemwise_binary ins(%0, %arg0 : tensor<?x?xf32>, tensor<?x?xf32>)
50+
outs(%arg1: tensor<?x?xf32>) -> tensor<?x?xf32>
51+
return %1 : tensor<?x?xf32>
52+
}
53+
54+
transform.with_pdl_patterns {
55+
^bb0(%arg0: !pdl.operation):
56+
pdl.pattern @pdl_target : benefit(1) {
57+
%args = operands
58+
%results = types
59+
%0 = pdl.operation "linalg.elemwise_binary"(%args : !pdl.range<value>) -> (%results : !pdl.range<type>)
60+
// TODO: we don't want this, but it is the required terminator for pdl.pattern
61+
rewrite %0 with "transform.dialect"
62+
}
63+
64+
transform.sequence %arg0 {
65+
^bb1(%arg1: !pdl.operation):
66+
%0 = pdl_match @pdl_target in %arg1
67+
%1, %loops:2 = transform.structured.fuse %0 {tile_sizes = [32, 32], tile_interchange = [0, 1]}
68+
transform.loop.peel %loops#0
69+
}
70+
}

0 commit comments

Comments
 (0)