Skip to content

[flang] Introduce omp.target_allocmem and omp.target_freemem omp dialect ops. #145464

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
167 changes: 167 additions & 0 deletions flang/lib/Optimizer/CodeGen/CodeGenOpenMP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -125,10 +125,177 @@ struct PrivateClauseOpConversion
return mlir::success();
}
};

static mlir::LLVM::LLVMFuncOp getOmpTargetAlloc(mlir::Operation *op) {
auto module = op->getParentOfType<mlir::ModuleOp>();
if (mlir::LLVM::LLVMFuncOp mallocFunc =
module.lookupSymbol<mlir::LLVM::LLVMFuncOp>("omp_target_alloc"))
return mallocFunc;
mlir::OpBuilder moduleBuilder(module.getBodyRegion());
auto i64Ty = mlir::IntegerType::get(module->getContext(), 64);
auto i32Ty = mlir::IntegerType::get(module->getContext(), 32);
return moduleBuilder.create<mlir::LLVM::LLVMFuncOp>(
moduleBuilder.getUnknownLoc(), "omp_target_alloc",
mlir::LLVM::LLVMFunctionType::get(
mlir::LLVM::LLVMPointerType::get(module->getContext()),
{i64Ty, i32Ty},
/*isVarArg=*/false));
}

static mlir::Type convertObjectType(const fir::LLVMTypeConverter &converter,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems this function and some of the ones below are shared between CodeGen.cpp and CodeGenOpenMP.cpp. Can we move them to a shared location, e.g. flang/Optimizer/Support/Utils.h/.cpp?

mlir::Type firType) {
if (auto boxTy = mlir::dyn_cast<fir::BaseBoxType>(firType))
return converter.convertBoxTypeAsStruct(boxTy);
return converter.convertType(firType);
}

static llvm::SmallVector<mlir::NamedAttribute>
addLLVMOpBundleAttrs(mlir::ConversionPatternRewriter &rewriter,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems this function does more than what the name suggests (i.e. copies old attributes, drops the old operandSegmentSizes and recreates it, and adds and empty op_bundle_sizes).

Should we do the attribute changes in-place (i.e. at the call site) instead? Since this function does not seem reusable in other scenarios.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see this is adapted from CodeGen.cpp though, so see my comment below.

llvm::ArrayRef<mlir::NamedAttribute> attrs,
int32_t numCallOperands) {
llvm::SmallVector<mlir::NamedAttribute> newAttrs;
newAttrs.reserve(attrs.size() + 2);

for (mlir::NamedAttribute attr : attrs) {
if (attr.getName() != "operandSegmentSizes")
newAttrs.push_back(attr);
}

newAttrs.push_back(rewriter.getNamedAttr(
"operandSegmentSizes",
rewriter.getDenseI32ArrayAttr({numCallOperands, 0})));
newAttrs.push_back(rewriter.getNamedAttr("op_bundle_sizes",
rewriter.getDenseI32ArrayAttr({})));
return newAttrs;
}

static mlir::LLVM::ConstantOp
genConstantIndex(mlir::Location loc, mlir::Type ity,
mlir::ConversionPatternRewriter &rewriter,
std::int64_t offset) {
auto cattr = rewriter.getI64IntegerAttr(offset);
return rewriter.create<mlir::LLVM::ConstantOp>(loc, ity, cattr);
}

static mlir::Value
computeElementDistance(mlir::Location loc, mlir::Type llvmObjectType,
mlir::Type idxTy,
mlir::ConversionPatternRewriter &rewriter,
const mlir::DataLayout &dataLayout) {
llvm::TypeSize size = dataLayout.getTypeSize(llvmObjectType);
unsigned short alignment = dataLayout.getTypeABIAlignment(llvmObjectType);
std::int64_t distance = llvm::alignTo(size, alignment);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we are using a deprected conversion here (see: https://github.com/llvm/llvm-project/blob/main/llvm/include/llvm/Support/TypeSize.h#L360).

return genConstantIndex(loc, idxTy, rewriter, distance);
}

static mlir::Value genTypeSizeInBytes(mlir::Location loc, mlir::Type idxTy,
mlir::ConversionPatternRewriter &rewriter,
mlir::Type llTy,
const mlir::DataLayout &dataLayout) {
return computeElementDistance(loc, llTy, idxTy, rewriter, dataLayout);
}

template <typename OP>
static mlir::Value
genAllocationScaleSize(OP op, mlir::Type ity,
mlir::ConversionPatternRewriter &rewriter) {
mlir::Location loc = op.getLoc();
mlir::Type dataTy = op.getInType();
auto seqTy = mlir::dyn_cast<fir::SequenceType>(dataTy);
fir::SequenceType::Extent constSize = 1;
if (seqTy) {
int constRows = seqTy.getConstantRows();
const fir::SequenceType::ShapeRef &shape = seqTy.getShape();
if (constRows != static_cast<int>(shape.size())) {
for (auto extent : shape) {
if (constRows-- > 0)
continue;
if (extent != fir::SequenceType::getUnknownExtent())
constSize *= extent;
}
}
}

if (constSize != 1) {
mlir::Value constVal{
genConstantIndex(loc, ity, rewriter, constSize).getResult()};
return constVal;
}
return nullptr;
}

static mlir::Value integerCast(const fir::LLVMTypeConverter &converter,
mlir::Location loc,
mlir::ConversionPatternRewriter &rewriter,
mlir::Type ty, mlir::Value val,
bool fold = false) {
auto valTy = val.getType();
// If the value was not yet lowered, lower its type so that it can
// be used in getPrimitiveTypeSizeInBits.
if (!mlir::isa<mlir::IntegerType>(valTy))
valTy = converter.convertType(valTy);
auto toSize = mlir::LLVM::getPrimitiveTypeSizeInBits(ty);
auto fromSize = mlir::LLVM::getPrimitiveTypeSizeInBits(valTy);
if (fold) {
if (toSize < fromSize)
return rewriter.createOrFold<mlir::LLVM::TruncOp>(loc, ty, val);
if (toSize > fromSize)
return rewriter.createOrFold<mlir::LLVM::SExtOp>(loc, ty, val);
} else {
if (toSize < fromSize)
return rewriter.create<mlir::LLVM::TruncOp>(loc, ty, val);
if (toSize > fromSize)
return rewriter.create<mlir::LLVM::SExtOp>(loc, ty, val);
}
return val;
}

// FIR Op specific conversion for TargetAllocMemOp
struct TargetAllocMemOpConversion
Copy link
Member

@ergawy ergawy Jul 9, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this conversion pattern should not replace omp.target_allocmem. Instead it should only convert the fir-specific allocation type to i64 and leave the lowering of omp.target_allocmem to call ptr @omp_target_alloc(...) to convertTargetAllocMemOp in OpenMPToLLVMIRTranslation.cpp. What we have now means that we do the same conversion in 2 different places. The preferred place would be OpenMPToLLVMIRTranslation.cpp since this is where all OpenMP to LLVM conversions are done.

: public OpenMPFIROpConversion<mlir::omp::TargetAllocMemOp> {
using OpenMPFIROpConversion::OpenMPFIROpConversion;

llvm::LogicalResult
matchAndRewrite(mlir::omp::TargetAllocMemOp allocmemOp, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {
mlir::Type heapTy = allocmemOp.getAllocatedType();
mlir::LLVM::LLVMFuncOp mallocFunc = getOmpTargetAlloc(allocmemOp);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: move before use location below.

mlir::Location loc = allocmemOp.getLoc();
auto ity = lowerTy().indexType();
mlir::Type dataTy = fir::unwrapRefType(heapTy);
mlir::Type llvmObjectTy = convertObjectType(lowerTy(), dataTy);
mlir::Type llvmPtrTy =
mlir::LLVM::LLVMPointerType::get(allocmemOp.getContext(), 0);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
mlir::LLVM::LLVMPointerType::get(allocmemOp.getContext(), 0);
mlir::LLVM::LLVMPointerType::get(allocmemOp.getContext());

0 is the default value for the address space.

if (fir::isRecordWithTypeParameters(fir::unwrapSequenceType(dataTy)))
TODO(loc, "omp.target_allocmem codegen of derived type with length "
"parameters");
mlir::Value size = genTypeSizeInBytes(loc, ity, rewriter, llvmObjectTy,
lowerTy().getDataLayout());
if (auto scaleSize = genAllocationScaleSize(allocmemOp, ity, rewriter))
size = rewriter.create<mlir::LLVM::MulOp>(loc, ity, size, scaleSize);
for (mlir::Value opnd : adaptor.getOperands().drop_front())
size = rewriter.create<mlir::LLVM::MulOp>(
loc, ity, size, integerCast(lowerTy(), loc, rewriter, ity, opnd));
auto mallocTyWidth = lowerTy().getIndexTypeBitwidth();
auto mallocTy =
mlir::IntegerType::get(rewriter.getContext(), mallocTyWidth);
if (mallocTyWidth != ity.getIntOrFloatBitWidth())
size = integerCast(lowerTy(), loc, rewriter, mallocTy, size);
Comment on lines +282 to +283
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this if condition is not covered by the introduced tests.

allocmemOp->setAttr("callee", mlir::SymbolRefAttr::get(mallocFunc));
auto callOp = rewriter.create<mlir::LLVM::CallOp>(
loc, llvmPtrTy,
mlir::SmallVector<mlir::Value, 2>({size, allocmemOp.getDevice()}),
addLLVMOpBundleAttrs(rewriter, allocmemOp->getAttrs(), 2));
rewriter.replaceOpWithNewOp<mlir::LLVM::PtrToIntOp>(
allocmemOp, rewriter.getIntegerType(64), callOp.getResult());
return mlir::success();
}
};
} // namespace

void fir::populateOpenMPFIRToLLVMConversionPatterns(
const LLVMTypeConverter &converter, mlir::RewritePatternSet &patterns) {
patterns.add<MapInfoOpConversion>(converter);
patterns.add<PrivateClauseOpConversion>(converter);
patterns.add<TargetAllocMemOpConversion>(converter);
}
1 change: 0 additions & 1 deletion flang/lib/Optimizer/Dialect/FIROps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,6 @@ static bool verifyTypeParamCount(mlir::Type inType, unsigned numParams) {
}

/// Parser shared by Alloca and Allocmem
///
/// operation ::= %res = (`fir.alloca` | `fir.allocmem`) $in_type
/// ( `(` $typeparams `)` )? ( `,` $shape )?
/// attr-dict-without-keyword
Expand Down
33 changes: 33 additions & 0 deletions flang/test/Fir/omp_target_allocmem_freemem.fir
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
// RUN: %flang_fc1 -emit-llvm %s -o - | FileCheck %s

// CHECK-LABEL: define void @omp_target_allocmem_array_of_nonchar(
// CHECK: call ptr @omp_target_alloc(i64 36, i32 0)
// CHECK: call void @omp_target_free(ptr {{.*}}, i32 0)
func.func @omp_target_allocmem_array_of_nonchar() -> () {
%device = arith.constant 0 : i32
%1 = omp.target_allocmem %device : i32, !fir.array<3x3xi32>
omp.target_freemem %device, %1 : i32, i64
return
}

// CHECK-LABEL: define void @omp_target_allocmem_array_of_char(
// CHECK: call ptr @omp_target_alloc(i64 90, i32 0)
// CHECK: call void @omp_target_free(ptr {{.*}}, i32 0)
func.func @omp_target_allocmem_array_of_char() -> () {
%device = arith.constant 0 : i32
%1 = omp.target_allocmem %device : i32, !fir.array<3x3x!fir.char<1,10>>
omp.target_freemem %device, %1 : i32, i64
return
}

// CHECK-LABEL: define void @omp_target_allocmem_array_of_dynchar(
// CHECK-SAME: i32 %[[len:.*]])
// CHECK: %[[mul1:.*]] = sext i32 %[[len]] to i64
// CHECK: %[[mul2:.*]] = mul i64 9, %[[mul1]]
// CHECK: call ptr @omp_target_alloc(i64 %[[mul2]], i32 0)
func.func @omp_target_allocmem_array_of_dynchar(%l: i32) -> () {
%device = arith.constant 0 : i32
%1 = omp.target_allocmem %device : i32, !fir.array<3x3x!fir.char<1,?>>(%l : i32)
omp.target_freemem %device, %1 : i32, i64
return
}
64 changes: 64 additions & 0 deletions mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1887,4 +1887,68 @@ def MaskedOp : OpenMP_Op<"masked", clauses = [
];
}

//===----------------------------------------------------------------------===//
// TargetAllocMemOp
//===----------------------------------------------------------------------===//

def TargetAllocMemOp : OpenMP_Op<"target_allocmem",
[MemoryEffects<[MemAlloc<DefaultResource>]>, AttrSizedOperandSegments]> {
let summary = "allocate storage on an openmp device for an object of a given type";

let description = [{
Allocates memory on the specified OpenMP device for an object of the given type.
Returns an integer value representing the device pointer to the allocated memory.
The memory is uninitialized after allocation. Operations must be paired with
`omp.target_freemem` to avoid memory leaks.

```mlir
%device = arith.constant 0 : i32
%ptr = omp.target_allocmem %device : i32, vector<3x3xi32>
```
}];

let arguments = (ins
Arg<AnyInteger>:$device,
TypeAttr:$in_type,
OptionalAttr<StrAttr>:$uniq_name,
OptionalAttr<StrAttr>:$bindc_name,
Variadic<AnyInteger>:$typeparams,
Variadic<AnyInteger>:$shape
);
let results = (outs I64);

let hasCustomAssemblyFormat = 1;
let hasVerifier = 1;

let extraClassDeclaration = [{
mlir::Type getAllocatedType();
}];
}

//===----------------------------------------------------------------------===//
// TargetFreeMemOp
//===----------------------------------------------------------------------===//

def TargetFreeMemOp : OpenMP_Op<"target_freemem",
[MemoryEffects<[MemFree]>]> {
let summary = "free memory on an openmp device";

let description = [{
Deallocates memory on the specified OpenMP device that was previously
allocated by an `omp.target_allocmem` operation. The memory is placed
in an undefined state after deallocation.
```
%device = arith.constant 0 : i32
%ptr = omp.target_allocmem %device : i32, vector<3x3xi32>
omp.target_freemem %device, %ptr : i32, i64
```
}];

let arguments = (ins
Arg<AnyInteger, "", [MemFree]>:$device,
Arg<I64, "", [MemFree]>:$heapref
);
let assemblyFormat = "$device `,` $heapref attr-dict `:` type($device) `,` qualified(type($heapref))";
}

#endif // OPENMP_OPS
102 changes: 102 additions & 0 deletions mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3493,6 +3493,108 @@ LogicalResult ScanOp::verify() {
"reduction modifier");
}

//===----------------------------------------------------------------------===//
// TargetAllocMemOp
//===----------------------------------------------------------------------===//

mlir::Type omp::TargetAllocMemOp::getAllocatedType() {
return getInTypeAttr().getValue();
}

/// operation ::= %res = (`omp.target_alloc_mem`) $device : devicetype,
/// $in_type ( `(` $typeparams `)` )? ( `,` $shape )?
/// attr-dict-without-keyword
static mlir::ParseResult parseTargetAllocMemOp(mlir::OpAsmParser &parser,
mlir::OperationState &result) {
auto &builder = parser.getBuilder();
bool hasOperands = false;
std::int32_t typeparamsSize = 0;

// Parse device number as a new operand
mlir::OpAsmParser::UnresolvedOperand deviceOperand;
mlir::Type deviceType;
if (parser.parseOperand(deviceOperand) || parser.parseColonType(deviceType))
return mlir::failure();
if (parser.resolveOperand(deviceOperand, deviceType, result.operands))
return mlir::failure();
if (parser.parseComma())
return mlir::failure();

mlir::Type intype;
if (parser.parseType(intype))
return mlir::failure();
result.addAttribute("in_type", mlir::TypeAttr::get(intype));
llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand> operands;
llvm::SmallVector<mlir::Type> typeVec;
if (!parser.parseOptionalLParen()) {
// parse the LEN params of the derived type. (<params> : <types>)
if (parser.parseOperandList(operands, mlir::OpAsmParser::Delimiter::None) ||
parser.parseColonTypeList(typeVec) || parser.parseRParen())
return mlir::failure();
typeparamsSize = operands.size();
hasOperands = true;
}
std::int32_t shapeSize = 0;
if (!parser.parseOptionalComma()) {
// parse size to scale by, vector of n dimensions of type index
if (parser.parseOperandList(operands, mlir::OpAsmParser::Delimiter::None))
return mlir::failure();
shapeSize = operands.size() - typeparamsSize;
auto idxTy = builder.getIndexType();
for (std::int32_t i = typeparamsSize, end = operands.size(); i != end; ++i)
typeVec.push_back(idxTy);
hasOperands = true;
}
if (hasOperands &&
parser.resolveOperands(operands, typeVec, parser.getNameLoc(),
result.operands))
return mlir::failure();

mlir::Type restype = builder.getIntegerType(64);
;
if (!restype) {
parser.emitError(parser.getNameLoc(), "invalid allocate type: ") << intype;
return mlir::failure();
}
llvm::SmallVector<std::int32_t> segmentSizes{1, typeparamsSize, shapeSize};
result.addAttribute("operandSegmentSizes",
builder.getDenseI32ArrayAttr(segmentSizes));
if (parser.parseOptionalAttrDict(result.attributes) ||
parser.addTypeToList(restype, result.types))
return mlir::failure();
return mlir::success();
}

mlir::ParseResult omp::TargetAllocMemOp::parse(mlir::OpAsmParser &parser,
mlir::OperationState &result) {
return parseTargetAllocMemOp(parser, result);
}

void omp::TargetAllocMemOp::print(mlir::OpAsmPrinter &p) {
p << " ";
p.printOperand(getDevice());
p << " : ";
p << getDevice().getType();
p << ", ";
p << getInType();
if (!getTypeparams().empty()) {
p << '(' << getTypeparams() << " : " << getTypeparams().getTypes() << ')';
}
for (auto sh : getShape()) {
p << ", ";
p.printOperand(sh);
}
p.printOptionalAttrDict((*this)->getAttrs(),
{"in_type", "operandSegmentSizes"});
}

llvm::LogicalResult omp::TargetAllocMemOp::verify() {
mlir::Type outType = getType();
if (!mlir::dyn_cast<IntegerType>(outType))
return emitOpError("must be a integer type");
return mlir::success();
}

#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc"

Expand Down
Loading