Skip to content

Commit 3ba7a87

Browse files
authored
[mlir][func]: Introduce ReplaceFuncSignature tranform operation (#143381)
This transform takes a module and a function name, and replaces the signature of the function by reordering the arguments and results according to the interchange arrays. The function is expected to be defined in the module, and the interchange arrays must match the number of arguments and results of the function.
1 parent 37eb465 commit 3ba7a87

File tree

9 files changed

+545
-2
lines changed

9 files changed

+545
-2
lines changed

mlir/include/mlir/Dialect/Func/TransformOps/FuncTransformOps.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
//===- FuncTransformOps.h - CF transformation ops --------*- C++ -*-===//
1+
//===- FuncTransformOps.h - Function transformation ops --------*- C++ -*--===//
22
//
33
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
44
// See https://llvm.org/LICENSE.txt for license information.

mlir/include/mlir/Dialect/Func/TransformOps/FuncTransformOps.td

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,4 +98,40 @@ def CastAndCallOp : Op<Transform_Dialect,
9898
let hasVerifier = 1;
9999
}
100100

101+
def ReplaceFuncSignatureOp
102+
: Op<Transform_Dialect, "func.replace_func_signature",
103+
[DeclareOpInterfaceMethods<TransformOpInterface>,
104+
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
105+
let description = [{
106+
This transform takes a module and a function name, and replaces the
107+
signature of the function by reordering the arguments and results
108+
according to the interchange arrays. The function is expected to be
109+
defined in the module, and the interchange arrays must match the number
110+
of arguments and results of the function.
111+
112+
The `adjust_func_calls` attribute indicates whether the function calls
113+
should be adjusted to match the new signature. If set to `true`, the
114+
function calls will be adjusted to match the new signature, otherwise
115+
they will not be adjusted.
116+
117+
This transform will emit a silenceable failure if:
118+
- The function with the given name does not exist in the module.
119+
- The interchange arrays do not match the number of arguments/results.
120+
- The interchange arrays contain out of bound indices.
121+
}];
122+
123+
let arguments = (ins TransformHandleTypeInterface:$module,
124+
SymbolRefAttr:$function_name, DenseI32ArrayAttr:$args_interchange,
125+
DenseI32ArrayAttr:$results_interchange, UnitAttr:$adjust_func_calls);
126+
let results = (outs TransformHandleTypeInterface:$transformed_module,
127+
TransformHandleTypeInterface:$transformed_function);
128+
129+
let assemblyFormat = [{
130+
$function_name
131+
`args_interchange` `=` $args_interchange
132+
`results_interchange` `=` $results_interchange
133+
`at` $module attr-dict `:` functional-type(operands, results)
134+
}];
135+
}
136+
101137
#endif // FUNC_TRANSFORM_OPS
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
//===- Utils.h - General Func transformation utilities ----*- C++ -*-------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This header file defines prototypes for various transformation utilities for
10+
// the Func dialect. These are not passes by themselves but are used
11+
// either by passes, optimization sequences, or in turn by other transformation
12+
// utilities.
13+
//
14+
//===----------------------------------------------------------------------===//
15+
16+
#ifndef MLIR_DIALECT_FUNC_UTILS_H
17+
#define MLIR_DIALECT_FUNC_UTILS_H
18+
19+
#include "mlir/IR/PatternMatch.h"
20+
#include "llvm/ADT/ArrayRef.h"
21+
22+
namespace mlir {
23+
24+
namespace func {
25+
26+
class FuncOp;
27+
class CallOp;
28+
29+
/// Creates a new function operation with the same name as the original
30+
/// function operation, but with the arguments reordered according to
31+
/// the `newArgsOrder` and `newResultsOrder`.
32+
/// The `funcOp` operation must have exactly one block.
33+
/// Returns the new function operation or failure if `funcOp` doesn't
34+
/// have exactly one block.
35+
FailureOr<FuncOp>
36+
replaceFuncWithNewOrder(RewriterBase &rewriter, FuncOp funcOp,
37+
llvm::ArrayRef<unsigned> newArgsOrder,
38+
llvm::ArrayRef<unsigned> newResultsOrder);
39+
/// Creates a new call operation with the values as the original
40+
/// call operation, but with the arguments reordered according to
41+
/// the `newArgsOrder` and `newResultsOrder`.
42+
CallOp replaceCallOpWithNewOrder(RewriterBase &rewriter, CallOp callOp,
43+
llvm::ArrayRef<unsigned> newArgsOrder,
44+
llvm::ArrayRef<unsigned> newResultsOrder);
45+
46+
} // namespace func
47+
} // namespace mlir
48+
49+
#endif // MLIR_DIALECT_FUNC_UTILS_H

mlir/lib/Dialect/Func/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,3 +2,4 @@ add_subdirectory(Extensions)
22
add_subdirectory(IR)
33
add_subdirectory(Transforms)
44
add_subdirectory(TransformOps)
5+
add_subdirectory(Utils)

mlir/lib/Dialect/Func/TransformOps/FuncTransformOps.cpp

Lines changed: 106 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
//===- FuncTransformOps.cpp - Implementation of CF transform ops ---===//
1+
//===- FuncTransformOps.cpp - Implementation of CF transform ops ----------===//
22
//
33
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
44
// See https://llvm.org/LICENSE.txt for license information.
@@ -11,10 +11,12 @@
1111
#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
1212
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
1313
#include "mlir/Dialect/Func/IR/FuncOps.h"
14+
#include "mlir/Dialect/Func/Utils/Utils.h"
1415
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
1516
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
1617
#include "mlir/Dialect/Transform/IR/TransformOps.h"
1718
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
19+
#include "mlir/IR/PatternMatch.h"
1820
#include "mlir/Transforms/DialectConversion.h"
1921

2022
using namespace mlir;
@@ -226,6 +228,109 @@ void transform::CastAndCallOp::getEffects(
226228
transform::modifiesPayload(effects);
227229
}
228230

231+
//===----------------------------------------------------------------------===//
232+
// ReplaceFuncSignatureOp
233+
//===----------------------------------------------------------------------===//
234+
235+
DiagnosedSilenceableFailure
236+
transform::ReplaceFuncSignatureOp::apply(transform::TransformRewriter &rewriter,
237+
transform::TransformResults &results,
238+
transform::TransformState &state) {
239+
auto payloadOps = state.getPayloadOps(getModule());
240+
if (!llvm::hasSingleElement(payloadOps))
241+
return emitDefiniteFailure() << "requires a single module to operate on";
242+
243+
auto targetModuleOp = dyn_cast<ModuleOp>(*payloadOps.begin());
244+
if (!targetModuleOp)
245+
return emitSilenceableFailure(getLoc())
246+
<< "target is expected to be module operation";
247+
248+
func::FuncOp funcOp =
249+
targetModuleOp.lookupSymbol<func::FuncOp>(getFunctionName());
250+
if (!funcOp)
251+
return emitSilenceableFailure(getLoc())
252+
<< "function with name '" << getFunctionName() << "' not found";
253+
254+
unsigned numArgs = funcOp.getNumArguments();
255+
unsigned numResults = funcOp.getNumResults();
256+
// Check that the number of arguments and results matches the
257+
// interchange sizes.
258+
if (numArgs != getArgsInterchange().size())
259+
return emitSilenceableFailure(getLoc())
260+
<< "function with name '" << getFunctionName() << "' has " << numArgs
261+
<< " arguments, but " << getArgsInterchange().size()
262+
<< " args interchange were given";
263+
264+
if (numResults != getResultsInterchange().size())
265+
return emitSilenceableFailure(getLoc())
266+
<< "function with name '" << getFunctionName() << "' has "
267+
<< numResults << " results, but " << getResultsInterchange().size()
268+
<< " results interchange were given";
269+
270+
// Check that the args and results interchanges are unique.
271+
SetVector<unsigned> argsInterchange, resultsInterchange;
272+
argsInterchange.insert_range(getArgsInterchange());
273+
resultsInterchange.insert_range(getResultsInterchange());
274+
if (argsInterchange.size() != getArgsInterchange().size())
275+
return emitSilenceableFailure(getLoc())
276+
<< "args interchange must be unique";
277+
278+
if (resultsInterchange.size() != getResultsInterchange().size())
279+
return emitSilenceableFailure(getLoc())
280+
<< "results interchange must be unique";
281+
282+
// Check that the args and results interchange indices are in bounds.
283+
for (unsigned index : argsInterchange) {
284+
if (index >= numArgs) {
285+
return emitSilenceableFailure(getLoc())
286+
<< "args interchange index " << index
287+
<< " is out of bounds for function with name '"
288+
<< getFunctionName() << "' with " << numArgs << " arguments";
289+
}
290+
}
291+
for (unsigned index : resultsInterchange) {
292+
if (index >= numResults) {
293+
return emitSilenceableFailure(getLoc())
294+
<< "results interchange index " << index
295+
<< " is out of bounds for function with name '"
296+
<< getFunctionName() << "' with " << numResults << " results";
297+
}
298+
}
299+
300+
FailureOr<func::FuncOp> newFuncOpOrFailure = func::replaceFuncWithNewOrder(
301+
rewriter, funcOp, argsInterchange.getArrayRef(),
302+
resultsInterchange.getArrayRef());
303+
if (failed(newFuncOpOrFailure))
304+
return emitSilenceableFailure(getLoc())
305+
<< "failed to replace function signature '" << getFunctionName()
306+
<< "' with new order";
307+
308+
if (getAdjustFuncCalls()) {
309+
SmallVector<func::CallOp> callOps;
310+
targetModuleOp.walk([&](func::CallOp callOp) {
311+
if (callOp.getCallee() == getFunctionName().getRootReference().getValue())
312+
callOps.push_back(callOp);
313+
});
314+
315+
for (func::CallOp callOp : callOps)
316+
func::replaceCallOpWithNewOrder(rewriter, callOp,
317+
argsInterchange.getArrayRef(),
318+
resultsInterchange.getArrayRef());
319+
}
320+
321+
results.set(cast<OpResult>(getTransformedModule()), {targetModuleOp});
322+
results.set(cast<OpResult>(getTransformedFunction()), {*newFuncOpOrFailure});
323+
324+
return DiagnosedSilenceableFailure::success();
325+
}
326+
327+
void transform::ReplaceFuncSignatureOp::getEffects(
328+
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
329+
transform::consumesHandle(getModuleMutable(), effects);
330+
transform::producesHandle(getOperation()->getOpResults(), effects);
331+
transform::modifiesPayload(effects);
332+
}
333+
229334
//===----------------------------------------------------------------------===//
230335
// Transform op registration
231336
//===----------------------------------------------------------------------===//
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
add_mlir_dialect_library(MLIRFuncUtils
2+
Utils.cpp
3+
4+
ADDITIONAL_HEADER_DIRS
5+
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Func/Utils
6+
7+
LINK_LIBS PUBLIC
8+
MLIRFuncDialect
9+
MLIRDialect
10+
MLIRDialectUtils
11+
MLIRIR
12+
)

mlir/lib/Dialect/Func/Utils/Utils.cpp

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
//===- Utils.cpp - Utilities to support the Func dialect ------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file implements utilities for the Func dialect.
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#include "mlir/Dialect/Func/Utils/Utils.h"
14+
#include "mlir/Dialect/Func/IR/FuncOps.h"
15+
#include "mlir/IR/IRMapping.h"
16+
#include "mlir/IR/PatternMatch.h"
17+
#include "llvm/ADT/SmallVector.h"
18+
19+
using namespace mlir;
20+
21+
FailureOr<func::FuncOp>
22+
func::replaceFuncWithNewOrder(RewriterBase &rewriter, func::FuncOp funcOp,
23+
ArrayRef<unsigned> newArgsOrder,
24+
ArrayRef<unsigned> newResultsOrder) {
25+
// Generate an empty new function operation with the same name as the
26+
// original.
27+
assert(funcOp.getNumArguments() == newArgsOrder.size() &&
28+
"newArgsOrder must match the number of arguments in the function");
29+
assert(funcOp.getNumResults() == newResultsOrder.size() &&
30+
"newResultsOrder must match the number of results in the function");
31+
32+
if (!funcOp.getBody().hasOneBlock())
33+
return rewriter.notifyMatchFailure(
34+
funcOp, "expected function to have exactly one block");
35+
36+
ArrayRef<Type> origInputTypes = funcOp.getFunctionType().getInputs();
37+
ArrayRef<Type> origOutputTypes = funcOp.getFunctionType().getResults();
38+
SmallVector<Type> newInputTypes, newOutputTypes;
39+
SmallVector<Location> locs;
40+
for (unsigned int idx : newArgsOrder) {
41+
newInputTypes.push_back(origInputTypes[idx]);
42+
locs.push_back(funcOp.getArgument(newArgsOrder[idx]).getLoc());
43+
}
44+
for (unsigned int idx : newResultsOrder)
45+
newOutputTypes.push_back(origOutputTypes[idx]);
46+
rewriter.setInsertionPoint(funcOp);
47+
auto newFuncOp = rewriter.create<func::FuncOp>(
48+
funcOp.getLoc(), funcOp.getName(),
49+
rewriter.getFunctionType(newInputTypes, newOutputTypes));
50+
51+
Region &newRegion = newFuncOp.getBody();
52+
rewriter.createBlock(&newRegion, newRegion.begin(), newInputTypes, locs);
53+
newFuncOp.setVisibility(funcOp.getVisibility());
54+
newFuncOp->setDiscardableAttrs(funcOp->getDiscardableAttrDictionary());
55+
56+
// Map the arguments of the original function to the new function in
57+
// the new order and adjust the attributes accordingly.
58+
IRMapping operandMapper;
59+
SmallVector<DictionaryAttr> argAttrs, resultAttrs;
60+
funcOp.getAllArgAttrs(argAttrs);
61+
for (unsigned int i = 0; i < newArgsOrder.size(); ++i) {
62+
operandMapper.map(funcOp.getArgument(newArgsOrder[i]),
63+
newFuncOp.getArgument(i));
64+
newFuncOp.setArgAttrs(i, argAttrs[newArgsOrder[i]]);
65+
}
66+
funcOp.getAllResultAttrs(resultAttrs);
67+
for (unsigned int i = 0; i < newResultsOrder.size(); ++i)
68+
newFuncOp.setResultAttrs(i, resultAttrs[newResultsOrder[i]]);
69+
70+
// Clone the operations from the original function to the new function.
71+
rewriter.setInsertionPointToStart(&newFuncOp.getBody().front());
72+
for (Operation &op : funcOp.getOps())
73+
rewriter.clone(op, operandMapper);
74+
75+
// Handle the return operation.
76+
auto returnOp = cast<func::ReturnOp>(
77+
newFuncOp.getFunctionBody().begin()->getTerminator());
78+
SmallVector<Value> newReturnValues;
79+
for (unsigned int idx : newResultsOrder)
80+
newReturnValues.push_back(returnOp.getOperand(idx));
81+
rewriter.setInsertionPoint(returnOp);
82+
auto newReturnOp =
83+
rewriter.create<func::ReturnOp>(newFuncOp.getLoc(), newReturnValues);
84+
newReturnOp->setDiscardableAttrs(returnOp->getDiscardableAttrDictionary());
85+
rewriter.eraseOp(returnOp);
86+
87+
rewriter.eraseOp(funcOp);
88+
89+
return newFuncOp;
90+
}
91+
92+
func::CallOp
93+
func::replaceCallOpWithNewOrder(RewriterBase &rewriter, func::CallOp callOp,
94+
ArrayRef<unsigned> newArgsOrder,
95+
ArrayRef<unsigned> newResultsOrder) {
96+
assert(
97+
callOp.getNumOperands() == newArgsOrder.size() &&
98+
"newArgsOrder must match the number of operands in the call operation");
99+
assert(
100+
callOp.getNumResults() == newResultsOrder.size() &&
101+
"newResultsOrder must match the number of results in the call operation");
102+
SmallVector<Value> newArgsOrderValues;
103+
for (unsigned int argIdx : newArgsOrder)
104+
newArgsOrderValues.push_back(callOp.getOperand(argIdx));
105+
SmallVector<Type> newResultTypes;
106+
for (unsigned int resIdx : newResultsOrder)
107+
newResultTypes.push_back(callOp.getResult(resIdx).getType());
108+
109+
// Replace the kernel call operation with a new one that has the
110+
// reordered arguments.
111+
rewriter.setInsertionPoint(callOp);
112+
auto newCallOp = rewriter.create<func::CallOp>(
113+
callOp.getLoc(), callOp.getCallee(), newResultTypes, newArgsOrderValues);
114+
newCallOp.setNoInlineAttr(callOp.getNoInlineAttr());
115+
for (auto &&[newIndex, origIndex] : llvm::enumerate(newResultsOrder))
116+
rewriter.replaceAllUsesWith(callOp.getResult(origIndex),
117+
newCallOp.getResult(newIndex));
118+
rewriter.eraseOp(callOp);
119+
120+
return newCallOp;
121+
}

0 commit comments

Comments
 (0)