Skip to content

Commit 327d627

Browse files
authored
[mlir] share argument attributes interface between calls and callables (#123176)
This patch shares core interface methods dealing with argument and result attributes from CallableOpInterface with the CallOpInterface and makes them mandatory to gives more consistent guarantees about concrete operations using these interfaces. This allows adding argument attributes on call like operations, which is sometimes required to get proper ABI, like with llvm.call (and llvm.invoke). The patch adds optional `arg_attrs` and `res_attrs` attributes to operations using these interfaces that did not have that already. They can then re-use the common "rich function signature" printing/parsing helpers if they want (for the LLVM dialect, this is done in the next patch). Part of RFC: https://discourse.llvm.org/t/mlir-rfc-adding-argument-and-result-attributes-to-llvm-call/84107
1 parent 8f025f2 commit 327d627

File tree

32 files changed

+452
-256
lines changed

32 files changed

+452
-256
lines changed

flang/include/flang/Optimizer/Dialect/CUF/CUFOps.td

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -207,7 +207,9 @@ def cuf_KernelLaunchOp : cuf_Op<"kernel_launch", [CallOpInterface,
207207
I32:$block_z,
208208
Optional<I32>:$bytes,
209209
Optional<I32>:$stream,
210-
Variadic<AnyType>:$args
210+
Variadic<AnyType>:$args,
211+
OptionalAttr<DictArrayAttr>:$arg_attrs,
212+
OptionalAttr<DictArrayAttr>:$res_attrs
211213
);
212214

213215
let assemblyFormat = [{

flang/include/flang/Optimizer/Dialect/FIROps.td

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2432,6 +2432,8 @@ def fir_CallOp : fir_Op<"call",
24322432
let arguments = (ins
24332433
OptionalAttr<SymbolRefAttr>:$callee,
24342434
Variadic<AnyType>:$args,
2435+
OptionalAttr<DictArrayAttr>:$arg_attrs,
2436+
OptionalAttr<DictArrayAttr>:$res_attrs,
24352437
OptionalAttr<fir_FortranProcedureFlagsAttr>:$procedure_attrs,
24362438
DefaultValuedAttr<Arith_FastMathAttr,
24372439
"::mlir::arith::FastMathFlags::none">:$fastmath
@@ -2518,6 +2520,8 @@ def fir_DispatchOp : fir_Op<"dispatch", []> {
25182520
fir_ClassType:$object,
25192521
Variadic<AnyType>:$args,
25202522
OptionalAttr<I32Attr>:$pass_arg_pos,
2523+
OptionalAttr<DictArrayAttr>:$arg_attrs,
2524+
OptionalAttr<DictArrayAttr>:$res_attrs,
25212525
OptionalAttr<fir_FortranProcedureFlagsAttr>:$procedure_attrs
25222526
);
25232527

flang/lib/Lower/ConvertCall.cpp

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -594,7 +594,8 @@ Fortran::lower::genCallOpAndResult(
594594

595595
builder.create<cuf::KernelLaunchOp>(
596596
loc, funcType.getResults(), funcSymbolAttr, grid_x, grid_y, grid_z,
597-
block_x, block_y, block_z, bytes, stream, operands);
597+
block_x, block_y, block_z, bytes, stream, operands,
598+
/*arg_attrs=*/nullptr, /*res_attrs=*/nullptr);
598599
callNumResults = 0;
599600
} else if (caller.requireDispatchCall()) {
600601
// Procedure call requiring a dynamic dispatch. Call is created with
@@ -621,7 +622,8 @@ Fortran::lower::genCallOpAndResult(
621622
dispatch = builder.create<fir::DispatchOp>(
622623
loc, funcType.getResults(), builder.getStringAttr(procName),
623624
caller.getInputs()[*passArg], operands,
624-
builder.getI32IntegerAttr(*passArg), procAttrs);
625+
builder.getI32IntegerAttr(*passArg), /*arg_attrs=*/nullptr,
626+
/*res_attrs=*/nullptr, procAttrs);
625627
} else {
626628
// NOPASS
627629
const Fortran::evaluate::Component *component =
@@ -636,15 +638,17 @@ Fortran::lower::genCallOpAndResult(
636638
passObject = builder.create<fir::LoadOp>(loc, passObject);
637639
dispatch = builder.create<fir::DispatchOp>(
638640
loc, funcType.getResults(), builder.getStringAttr(procName),
639-
passObject, operands, nullptr, procAttrs);
641+
passObject, operands, nullptr, /*arg_attrs=*/nullptr,
642+
/*res_attrs=*/nullptr, procAttrs);
640643
}
641644
callNumResults = dispatch.getNumResults();
642645
if (callNumResults != 0)
643646
callResult = dispatch.getResult(0);
644647
} else {
645648
// Standard procedure call with fir.call.
646649
auto call = builder.create<fir::CallOp>(
647-
loc, funcType.getResults(), funcSymbolAttr, operands, procAttrs);
650+
loc, funcType.getResults(), funcSymbolAttr, operands,
651+
/*arg_attrs=*/nullptr, /*res_attrs=*/nullptr, procAttrs);
648652

649653
callNumResults = call.getNumResults();
650654
if (callNumResults != 0)

flang/lib/Optimizer/CodeGen/TargetRewrite.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -518,6 +518,7 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
518518
newOpers.insert(newOpers.end(), trailingOpers.begin(), trailingOpers.end());
519519

520520
llvm::SmallVector<mlir::Value, 1> newCallResults;
521+
// TODO propagate/update call argument and result attributes.
521522
if constexpr (std::is_same_v<std::decay_t<A>, mlir::gpu::LaunchFuncOp>) {
522523
auto newCall = rewriter->create<A>(
523524
loc, callOp.getKernel(), callOp.getGridSizeOperandValues(),
@@ -557,6 +558,7 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
557558
loc, newResTys, rewriter->getStringAttr(callOp.getMethod()),
558559
callOp.getOperands()[0], newOpers,
559560
rewriter->getI32IntegerAttr(*callOp.getPassArgPos() + passArgShift),
561+
/*arg_attrs=*/nullptr, /*res_attrs=*/nullptr,
560562
callOp.getProcedureAttrsAttr());
561563
if (wrap)
562564
newCallResults.push_back((*wrap)(dispatchOp.getOperation()));

flang/lib/Optimizer/Transforms/AbstractResult.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,7 @@ class CallConversion : public mlir::OpRewritePattern<Op> {
147147
newResultTypes.emplace_back(getVoidPtrType(result.getContext()));
148148

149149
Op newOp;
150+
// TODO: propagate argument and result attributes (need to be shifted).
150151
// fir::CallOp specific handling.
151152
if constexpr (std::is_same_v<Op, fir::CallOp>) {
152153
if (op.getCallee()) {
@@ -189,9 +190,11 @@ class CallConversion : public mlir::OpRewritePattern<Op> {
189190
if (op.getPassArgPos())
190191
passArgPos =
191192
rewriter.getI32IntegerAttr(*op.getPassArgPos() + passArgShift);
193+
// TODO: propagate argument and result attributes (need to be shifted).
192194
newOp = rewriter.create<fir::DispatchOp>(
193195
loc, newResultTypes, rewriter.getStringAttr(op.getMethod()),
194196
op.getOperands()[0], newOperands, passArgPos,
197+
/*arg_attrs=*/nullptr, /*res_attrs=*/nullptr,
195198
op.getProcedureAttrsAttr());
196199
}
197200

flang/lib/Optimizer/Transforms/PolymorphicOpConversion.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -205,8 +205,9 @@ struct DispatchOpConv : public OpConversionPattern<fir::DispatchOp> {
205205
// Make the call.
206206
llvm::SmallVector<mlir::Value> args{funcPtr};
207207
args.append(dispatch.getArgs().begin(), dispatch.getArgs().end());
208-
rewriter.replaceOpWithNewOp<fir::CallOp>(dispatch, resTypes, nullptr, args,
209-
dispatch.getProcedureAttrsAttr());
208+
rewriter.replaceOpWithNewOp<fir::CallOp>(
209+
dispatch, resTypes, nullptr, args, dispatch.getArgAttrsAttr(),
210+
dispatch.getResAttrsAttr(), dispatch.getProcedureAttrsAttr());
210211
return mlir::success();
211212
}
212213

mlir/docs/Interfaces.md

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -753,10 +753,15 @@ interface section goes as follows:
753753
- (`C++ class` -- `ODS class`(if applicable))
754754

755755
##### CallInterfaces
756-
757756
* `CallOpInterface` - Used to represent operations like 'call'
758757
- `CallInterfaceCallable getCallableForCallee()`
759758
- `void setCalleeFromCallable(CallInterfaceCallable)`
759+
- `ArrayAttr getArgAttrsAttr()`
760+
- `ArrayAttr getResAttrsAttr()`
761+
- `void setArgAttrsAttr(ArrayAttr)`
762+
- `void setResAttrsAttr(ArrayAttr)`
763+
- `Attribute removeArgAttrsAttr()`
764+
- `Attribute removeResAttrsAttr()`
760765
* `CallableOpInterface` - Used to represent the target callee of call.
761766
- `Region * getCallableRegion()`
762767
- `ArrayRef<Type> getArgumentTypes()`

mlir/examples/toy/Ch4/include/toy/Ops.td

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -215,7 +215,12 @@ def GenericCallOp : Toy_Op<"generic_call",
215215

216216
// The generic call operation takes a symbol reference attribute as the
217217
// callee, and inputs for the call.
218-
let arguments = (ins FlatSymbolRefAttr:$callee, Variadic<F64Tensor>:$inputs);
218+
let arguments = (ins
219+
FlatSymbolRefAttr:$callee,
220+
Variadic<F64Tensor>:$inputs,
221+
OptionalAttr<DictArrayAttr>:$arg_attrs,
222+
OptionalAttr<DictArrayAttr>:$res_attrs
223+
);
219224

220225
// The generic call operation returns a single value of TensorType.
221226
let results = (outs F64Tensor);

mlir/examples/toy/Ch5/include/toy/Ops.td

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -214,7 +214,12 @@ def GenericCallOp : Toy_Op<"generic_call",
214214

215215
// The generic call operation takes a symbol reference attribute as the
216216
// callee, and inputs for the call.
217-
let arguments = (ins FlatSymbolRefAttr:$callee, Variadic<F64Tensor>:$inputs);
217+
let arguments = (ins
218+
FlatSymbolRefAttr:$callee,
219+
Variadic<F64Tensor>:$inputs,
220+
OptionalAttr<DictArrayAttr>:$arg_attrs,
221+
OptionalAttr<DictArrayAttr>:$res_attrs
222+
);
218223

219224
// The generic call operation returns a single value of TensorType.
220225
let results = (outs F64Tensor);

mlir/examples/toy/Ch6/include/toy/Ops.td

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -214,7 +214,12 @@ def GenericCallOp : Toy_Op<"generic_call",
214214

215215
// The generic call operation takes a symbol reference attribute as the
216216
// callee, and inputs for the call.
217-
let arguments = (ins FlatSymbolRefAttr:$callee, Variadic<F64Tensor>:$inputs);
217+
let arguments = (ins
218+
FlatSymbolRefAttr:$callee,
219+
Variadic<F64Tensor>:$inputs,
220+
OptionalAttr<DictArrayAttr>:$arg_attrs,
221+
OptionalAttr<DictArrayAttr>:$res_attrs
222+
);
218223

219224
// The generic call operation returns a single value of TensorType.
220225
let results = (outs F64Tensor);

0 commit comments

Comments
 (0)