Skip to content

Commit 30f7a6c

Browse files
[flang] Correctly prepare allocatable runtime call arguments (llvm#138727)
When lowering allocatables, the generated calls to runtime functions were not using the runtime::createArguments utility which handles the required conversions. createArguments is where I added the implicit volatile casts to handle converting volatile variables to the appropriate type based on their volatility in the callee. Because the calls to allocatable runtime functions were not using this function, their arguments were not casted to have the appropriate volatility. Add a test to demonstrate that volatile and allocatable class/box/reference types are appropriately casted before calling into the runtime library. Instead of using a recursive variadic template to perform the conversions in createArguments, map over the arguments directly so that createArguments can be called with an ArrayRef of arguments. Some cases in Allocatable.cpp already had a vector of values at the point where createArguments needed to be called - the new overload allows calling with a vector of args or the variadic version with each argument spelled out at the callsite. This change resulted in the allocatable runtime calls having their arguments converted left-to-right, which changed some of the test results. I used CHECK-DAG to ignore the order. Add some missing handling of volatile class entities, which I previously missed because I had not yet enabled volatile class entities in Lower.
1 parent be6c6e2 commit 30f7a6c

File tree

9 files changed

+334
-165
lines changed

9 files changed

+334
-165
lines changed

flang/include/flang/Optimizer/Builder/Runtime/RTBuilder.h

Lines changed: 11 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
#include "flang/Support/Fortran.h"
2727
#include "mlir/IR/BuiltinTypes.h"
2828
#include "mlir/IR/MLIRContext.h"
29+
#include "llvm/ADT/STLExtras.h"
2930
#include "llvm/ADT/SmallVector.h"
3031
#include <cstdint>
3132
#include <functional>
@@ -824,33 +825,23 @@ static mlir::func::FuncOp getIORuntimeFunc(mlir::Location loc,
824825
return getRuntimeFunc<E>(loc, builder, /*isIO=*/true);
825826
}
826827

827-
namespace helper {
828-
template <int N, typename A>
829-
void createArguments(llvm::SmallVectorImpl<mlir::Value> &result,
830-
fir::FirOpBuilder &builder, mlir::Location loc,
831-
mlir::FunctionType fTy, A arg) {
832-
result.emplace_back(
833-
builder.createConvertWithVolatileCast(loc, fTy.getInput(N), arg));
834-
}
835-
836-
template <int N, typename A, typename... As>
837-
void createArguments(llvm::SmallVectorImpl<mlir::Value> &result,
838-
fir::FirOpBuilder &builder, mlir::Location loc,
839-
mlir::FunctionType fTy, A arg, As... args) {
840-
result.emplace_back(
841-
builder.createConvertWithVolatileCast(loc, fTy.getInput(N), arg));
842-
createArguments<N + 1>(result, builder, loc, fTy, args...);
828+
inline llvm::SmallVector<mlir::Value>
829+
createArguments(fir::FirOpBuilder &builder, mlir::Location loc,
830+
mlir::FunctionType fTy, llvm::ArrayRef<mlir::Value> args) {
831+
return llvm::map_to_vector(llvm::zip_equal(fTy.getInputs(), args),
832+
[&](const auto &pair) -> mlir::Value {
833+
auto [type, argument] = pair;
834+
return builder.createConvertWithVolatileCast(
835+
loc, type, argument);
836+
});
843837
}
844-
} // namespace helper
845838

846839
/// Create a SmallVector of arguments for a runtime call.
847840
template <typename... As>
848841
llvm::SmallVector<mlir::Value>
849842
createArguments(fir::FirOpBuilder &builder, mlir::Location loc,
850843
mlir::FunctionType fTy, As... args) {
851-
llvm::SmallVector<mlir::Value> result;
852-
helper::createArguments<0>(result, builder, loc, fTy, args...);
853-
return result;
844+
return createArguments(builder, loc, fTy, {args...});
854845
}
855846

856847
} // namespace fir::runtime

flang/lib/Lower/Allocatable.cpp

Lines changed: 26 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -138,12 +138,10 @@ static void genRuntimeSetBounds(fir::FirOpBuilder &builder, mlir::Location loc,
138138
builder)
139139
: fir::runtime::getRuntimeFunc<mkRTKey(AllocatableSetBounds)>(
140140
loc, builder);
141-
llvm::SmallVector<mlir::Value> args{box.getAddr(), dimIndex, lowerBound,
142-
upperBound};
143-
llvm::SmallVector<mlir::Value> operands;
144-
for (auto [fst, snd] : llvm::zip(args, callee.getFunctionType().getInputs()))
145-
operands.emplace_back(builder.createConvert(loc, snd, fst));
146-
builder.create<fir::CallOp>(loc, callee, operands);
141+
const auto args = fir::runtime::createArguments(
142+
builder, loc, callee.getFunctionType(), box.getAddr(), dimIndex,
143+
lowerBound, upperBound);
144+
builder.create<fir::CallOp>(loc, callee, args);
147145
}
148146

149147
/// Generate runtime call to set the lengths of a character allocatable or
@@ -162,9 +160,7 @@ static void genRuntimeInitCharacter(fir::FirOpBuilder &builder,
162160
if (inputTypes.size() != 5)
163161
fir::emitFatalError(
164162
loc, "AllocatableInitCharacter runtime interface not as expected");
165-
llvm::SmallVector<mlir::Value> args;
166-
args.push_back(builder.createConvert(loc, inputTypes[0], box.getAddr()));
167-
args.push_back(builder.createConvert(loc, inputTypes[1], len));
163+
llvm::SmallVector<mlir::Value> args = {box.getAddr(), len};
168164
if (kind == 0)
169165
kind = mlir::cast<fir::CharacterType>(box.getEleTy()).getFKind();
170166
args.push_back(builder.createIntegerConstant(loc, inputTypes[2], kind));
@@ -173,7 +169,9 @@ static void genRuntimeInitCharacter(fir::FirOpBuilder &builder,
173169
// TODO: coarrays
174170
int corank = 0;
175171
args.push_back(builder.createIntegerConstant(loc, inputTypes[4], corank));
176-
builder.create<fir::CallOp>(loc, callee, args);
172+
const auto convertedArgs = fir::runtime::createArguments(
173+
builder, loc, callee.getFunctionType(), args);
174+
builder.create<fir::CallOp>(loc, callee, convertedArgs);
177175
}
178176

179177
/// Generate a sequence of runtime calls to allocate memory.
@@ -194,10 +192,9 @@ static mlir::Value genRuntimeAllocate(fir::FirOpBuilder &builder,
194192
args.push_back(errorManager.errMsgAddr);
195193
args.push_back(errorManager.sourceFile);
196194
args.push_back(errorManager.sourceLine);
197-
llvm::SmallVector<mlir::Value> operands;
198-
for (auto [fst, snd] : llvm::zip(args, callee.getFunctionType().getInputs()))
199-
operands.emplace_back(builder.createConvert(loc, snd, fst));
200-
return builder.create<fir::CallOp>(loc, callee, operands).getResult(0);
195+
const auto convertedArgs = fir::runtime::createArguments(
196+
builder, loc, callee.getFunctionType(), args);
197+
return builder.create<fir::CallOp>(loc, callee, convertedArgs).getResult(0);
201198
}
202199

203200
/// Generate a sequence of runtime calls to allocate memory and assign with the
@@ -213,14 +210,11 @@ static mlir::Value genRuntimeAllocateSource(fir::FirOpBuilder &builder,
213210
loc, builder)
214211
: fir::runtime::getRuntimeFunc<mkRTKey(AllocatableAllocateSource)>(
215212
loc, builder);
216-
llvm::SmallVector<mlir::Value> args{
217-
box.getAddr(), fir::getBase(source),
218-
errorManager.hasStat, errorManager.errMsgAddr,
219-
errorManager.sourceFile, errorManager.sourceLine};
220-
llvm::SmallVector<mlir::Value> operands;
221-
for (auto [fst, snd] : llvm::zip(args, callee.getFunctionType().getInputs()))
222-
operands.emplace_back(builder.createConvert(loc, snd, fst));
223-
return builder.create<fir::CallOp>(loc, callee, operands).getResult(0);
213+
const auto args = fir::runtime::createArguments(
214+
builder, loc, callee.getFunctionType(), box.getAddr(),
215+
fir::getBase(source), errorManager.hasStat, errorManager.errMsgAddr,
216+
errorManager.sourceFile, errorManager.sourceLine);
217+
return builder.create<fir::CallOp>(loc, callee, args).getResult(0);
224218
}
225219

226220
/// Generate runtime call to apply mold to the descriptor.
@@ -234,14 +228,12 @@ static void genRuntimeAllocateApplyMold(fir::FirOpBuilder &builder,
234228
builder)
235229
: fir::runtime::getRuntimeFunc<mkRTKey(AllocatableApplyMold)>(
236230
loc, builder);
237-
llvm::SmallVector<mlir::Value> args{
231+
const auto args = fir::runtime::createArguments(
232+
builder, loc, callee.getFunctionType(),
238233
fir::factory::getMutableIRBox(builder, loc, box), fir::getBase(mold),
239234
builder.createIntegerConstant(
240-
loc, callee.getFunctionType().getInputs()[2], rank)};
241-
llvm::SmallVector<mlir::Value> operands;
242-
for (auto [fst, snd] : llvm::zip(args, callee.getFunctionType().getInputs()))
243-
operands.emplace_back(builder.createConvert(loc, snd, fst));
244-
builder.create<fir::CallOp>(loc, callee, operands);
235+
loc, callee.getFunctionType().getInputs()[2], rank));
236+
builder.create<fir::CallOp>(loc, callee, args);
245237
}
246238

247239
/// Generate a runtime call to deallocate memory.
@@ -669,15 +661,13 @@ class AllocateStmtHelper {
669661

670662
llvm::ArrayRef<mlir::Type> inputTypes =
671663
callee.getFunctionType().getInputs();
672-
llvm::SmallVector<mlir::Value> args;
673-
args.push_back(builder.createConvert(loc, inputTypes[0], box.getAddr()));
674-
args.push_back(builder.createConvert(loc, inputTypes[1], typeDescAddr));
675664
mlir::Value rankValue =
676665
builder.createIntegerConstant(loc, inputTypes[2], rank);
677666
mlir::Value corankValue =
678667
builder.createIntegerConstant(loc, inputTypes[3], corank);
679-
args.push_back(rankValue);
680-
args.push_back(corankValue);
668+
const auto args = fir::runtime::createArguments(
669+
builder, loc, callee.getFunctionType(), box.getAddr(), typeDescAddr,
670+
rankValue, corankValue);
681671
builder.create<fir::CallOp>(loc, callee, args);
682672
}
683673

@@ -696,8 +686,6 @@ class AllocateStmtHelper {
696686

697687
llvm::ArrayRef<mlir::Type> inputTypes =
698688
callee.getFunctionType().getInputs();
699-
llvm::SmallVector<mlir::Value> args;
700-
args.push_back(builder.createConvert(loc, inputTypes[0], box.getAddr()));
701689
mlir::Value categoryValue = builder.createIntegerConstant(
702690
loc, inputTypes[1], static_cast<int32_t>(category));
703691
mlir::Value kindValue =
@@ -706,10 +694,9 @@ class AllocateStmtHelper {
706694
builder.createIntegerConstant(loc, inputTypes[3], rank);
707695
mlir::Value corankValue =
708696
builder.createIntegerConstant(loc, inputTypes[4], corank);
709-
args.push_back(categoryValue);
710-
args.push_back(kindValue);
711-
args.push_back(rankValue);
712-
args.push_back(corankValue);
697+
const auto args = fir::runtime::createArguments(
698+
builder, loc, callee.getFunctionType(), box.getAddr(), categoryValue,
699+
kindValue, rankValue, corankValue);
713700
builder.create<fir::CallOp>(loc, callee, args);
714701
}
715702

flang/lib/Lower/ConvertExprToHLFIR.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,12 @@ class HlfirDesignatorBuilder {
227227
isVolatile = true;
228228
}
229229

230+
// Check if the base type is volatile
231+
if (partInfo.base.has_value()) {
232+
mlir::Type baseType = partInfo.base.value().getType();
233+
isVolatile = isVolatile || fir::isa_volatile_type(baseType);
234+
}
235+
230236
// Dynamic type of polymorphic base must be kept if the designator is
231237
// polymorphic.
232238
if (isPolymorphic(designatorNode))
@@ -238,12 +244,6 @@ class HlfirDesignatorBuilder {
238244
if (charType && charType.hasDynamicLen())
239245
return fir::BoxCharType::get(charType.getContext(), charType.getFKind());
240246

241-
// Check if the base type is volatile
242-
if (partInfo.base.has_value()) {
243-
mlir::Type baseType = partInfo.base.value().getType();
244-
isVolatile = isVolatile || fir::isa_volatile_type(baseType);
245-
}
246-
247247
// Arrays with non default lower bounds or dynamic length or dynamic extent
248248
// need a fir.box to hold the dynamic or lower bound information.
249249
if (fir::hasDynamicSize(resultValueType) ||

flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -210,15 +210,14 @@ static bool hasExplicitLowerBounds(mlir::Value shape) {
210210
static std::pair<mlir::Type, mlir::Value> updateDeclareInputTypeWithVolatility(
211211
mlir::Type inputType, mlir::Value memref, mlir::OpBuilder &builder,
212212
fir::FortranVariableFlagsAttr fortran_attrs) {
213-
if (mlir::isa<fir::BoxType, fir::ReferenceType>(inputType) && fortran_attrs &&
213+
if (fortran_attrs &&
214214
bitEnumContainsAny(fortran_attrs.getFlags(),
215215
fir::FortranVariableFlagsEnum::fortran_volatile)) {
216216
const bool isPointer = bitEnumContainsAny(
217217
fortran_attrs.getFlags(), fir::FortranVariableFlagsEnum::pointer);
218218
auto updateType = [&](auto t) {
219219
using FIRT = decltype(t);
220-
// If an entity is a pointer, the entity it points to is volatile, as far
221-
// as consumers of the pointer are concerned.
220+
// A volatile pointer's pointee is volatile.
222221
auto elementType = t.getEleTy();
223222
const bool elementTypeIsVolatile =
224223
isPointer || fir::isa_volatile_type(elementType);
@@ -227,8 +226,7 @@ static std::pair<mlir::Type, mlir::Value> updateDeclareInputTypeWithVolatility(
227226
inputType = FIRT::get(newEleTy, true);
228227
};
229228
llvm::TypeSwitch<mlir::Type>(inputType)
230-
.Case<fir::ReferenceType, fir::BoxType>(updateType)
231-
.Default([](mlir::Type t) { return t; });
229+
.Case<fir::ReferenceType, fir::BoxType, fir::ClassType>(updateType);
232230
memref =
233231
builder.create<fir::VolatileCastOp>(memref.getLoc(), inputType, memref);
234232
}

0 commit comments

Comments
 (0)