Skip to content

[MLIR] [OpenMP] Initial support for OMP ALLOCATE directive op. #147900

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 51 additions & 0 deletions mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,31 @@
include "mlir/Dialect/OpenMP/OpenMPOpBase.td"
include "mlir/IR/SymbolInterfaces.td"

//===----------------------------------------------------------------------===//
// V5.2: [6.3] `align` clause
//===----------------------------------------------------------------------===//

class OpenMP_AlignClauseSkip<
bit traits = false, bit arguments = false, bit assemblyFormat = false,
bit description = false, bit extraClassDeclaration = false
> : OpenMP_Clause<traits, arguments, assemblyFormat, description,
extraClassDeclaration> {
let arguments = (ins
ConfinedAttr<OptionalAttr<I64Attr>, [IntPositive]>:$align
);

let optAssemblyFormat = [{
`align` `(` $align `)`
}];

let description = [{
The `align` clause is used to specify the byte alignment to use for
allocations associated with the construct on which the clause appears.
}];
}

def OpenMP_AlignClause : OpenMP_AlignClauseSkip<>;

//===----------------------------------------------------------------------===//
// V5.2: [5.11] `aligned` clause
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -84,6 +109,32 @@ class OpenMP_AllocateClauseSkip<

def OpenMP_AllocateClause : OpenMP_AllocateClauseSkip<>;

//===----------------------------------------------------------------------===//
// V5.2: [6.4] `allocator` clause
//===----------------------------------------------------------------------===//

class OpenMP_AllocatorClauseSkip<
bit traits = false, bit arguments = false, bit assemblyFormat = false,
bit description = false, bit extraClassDeclaration = false
> : OpenMP_Clause<traits, arguments, assemblyFormat, description,
extraClassDeclaration> {

let arguments = (ins
OptionalAttr<AllocatorHandleAttr>:$allocator
);

let optAssemblyFormat = [{
`allocator` `(` custom<ClauseAttr>($allocator) `)`
}];

let description = [{
`allocator` specifies the memory allocator to be used for allocations
associated with the construct on which the clause appears.
}];
}

def OpenMP_AllocatorClause : OpenMP_AllocatorClauseSkip<>;

//===----------------------------------------------------------------------===//
// LLVM OpenMP extension `ompx_bare` clause
//===----------------------------------------------------------------------===//
Expand Down
30 changes: 30 additions & 0 deletions mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td
Original file line number Diff line number Diff line change
Expand Up @@ -263,4 +263,34 @@ def VariableCaptureKindAttr : OpenMP_EnumAttr<VariableCaptureKind,
let assemblyFormat = "`(` $value `)`";
}


//===----------------------------------------------------------------------===//
// allocator_handle enum.
//===----------------------------------------------------------------------===//

def OpenMP_AllocatorHandleNullAllocator : I32EnumAttrCase<"omp_null_allocator", 0>;
def OpenMP_AllocatorHandleDefaultMemAlloc : I32EnumAttrCase<"omp_default_mem_alloc", 1>;
def OpenMP_AllocatorHandleLargeCapMemAlloc : I32EnumAttrCase<"omp_large_cap_mem_alloc", 2>;
def OpenMP_AllocatorHandleConstMemAlloc : I32EnumAttrCase<"omp_const_mem_alloc", 3>;
def OpenMP_AllocatorHandleHighBwMemAlloc : I32EnumAttrCase<"omp_high_bw_mem_alloc", 4>;
def OpenMP_AllocatorHandleLowLatMemAlloc : I32EnumAttrCase<"omp_low_lat_mem_alloc", 5>;
def OpenMP_AllocatorHandleCgroupMemAlloc : I32EnumAttrCase<"omp_cgroup_mem_alloc", 6>;
def OpenMP_AllocatorHandlePteamMemAlloc : I32EnumAttrCase<"omp_pteam_mem_alloc", 7>;
def OpenMP_AllocatorHandlethreadMemAlloc : I32EnumAttrCase<"omp_thread_mem_alloc", 8>;

def AllocatorHandle : OpenMP_I32EnumAttr<
"AllocatorHandle",
"OpenMP allocator_handle", [
OpenMP_AllocatorHandleNullAllocator,
OpenMP_AllocatorHandleDefaultMemAlloc,
OpenMP_AllocatorHandleLargeCapMemAlloc,
OpenMP_AllocatorHandleConstMemAlloc,
OpenMP_AllocatorHandleHighBwMemAlloc,
OpenMP_AllocatorHandleLowLatMemAlloc,
OpenMP_AllocatorHandleCgroupMemAlloc,
OpenMP_AllocatorHandlePteamMemAlloc,
OpenMP_AllocatorHandlethreadMemAlloc
]>;

def AllocatorHandleAttr : OpenMP_EnumAttr<AllocatorHandle, "allocator_handle">;
#endif // OPENMP_ENUMS
27 changes: 27 additions & 0 deletions mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -2090,4 +2090,31 @@ def MaskedOp : OpenMP_Op<"masked", clauses = [
];
}

//===----------------------------------------------------------------------===//
// [Spec 5.2] 6.5 allocate Directive
//===----------------------------------------------------------------------===//
def AllocateDirOp : OpenMP_Op<"allocate_dir", clauses = [
OpenMP_AlignClause, OpenMP_AllocatorClause
]> {
let summary = "allocate directive";
let description = [{
The storage for each list item that appears in the allocate directive is
provided an allocation through the memory allocator.
}] # clausesDescription;

let arguments = !con((ins Variadic<AnyType>:$varList),
clausesArgs);

// Override inherited assembly format to include `varList`.
let assemblyFormat = " `(` $varList `:` type($varList) `)` oilist(" #
clausesOptAssemblyFormat #
") attr-dict ";

let builders = [
OpBuilder<(ins CArg<"const AllocateDirOperands &">:$clauses)>
];

let hasVerifier = 1;
}

#endif // OPENMP_OPS
14 changes: 14 additions & 0 deletions mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3865,6 +3865,20 @@ LogicalResult ScanOp::verify() {
"reduction modifier");
}

/// Verifies align clause in allocate directive

LogicalResult AllocateDirOp::verify() {
std::optional<u_int64_t> align = this->getAlign();

if (align.has_value()) {
if ((align.value() > 0) && ((align.value() & (align.value() - 1)) != 0))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

#include "llvm/ADT/bit.h"

if (llvm::has_single_bit(align.value()))

return emitError() << "ALIGN value : " << align.value()
<< " must be power of 2";
}

return success();
}

#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc"

Expand Down
24 changes: 24 additions & 0 deletions mlir/test/Dialect/OpenMP/invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -2993,3 +2993,27 @@ llvm.func @invalid_mapper(%0 : !llvm.ptr) {
}
llvm.return
}

// -----
func.func @invalid_allocate_align_1(%arg0 : memref<i32>) -> () {
// expected-error @below {{failed to satisfy constraint: 64-bit signless integer attribute whose value is positive}}
omp.allocate_dir (%arg0 : memref<i32>) align(-1)

return
}

// -----
func.func @invalid_allocate_align_2(%arg0 : memref<i32>) -> () {
// expected-error @below {{must be power of 2}}
omp.allocate_dir (%arg0 : memref<i32>) align(3)

return
}

// -----
func.func @invalid_allocate_allocator(%arg0 : memref<i32>) -> () {
// expected-error @below {{invalid clause value}}
omp.allocate_dir (%arg0 : memref<i32>) allocator(omp_small_cap_mem_alloc)

return
}
33 changes: 33 additions & 0 deletions mlir/test/Dialect/OpenMP/ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -3197,3 +3197,36 @@ func.func @omp_workshare_loop_wrapper_attrs(%idx : index) {
}
return
}

// CHECK-LABEL: func.func @omp_allocate_dir(
// CHECK-SAME: %[[ARG0:.*]]: memref<i32>,
// CHECK-SAME: %[[ARG1:.*]]: memref<i32>) {
func.func @omp_allocate_dir(%arg0 : memref<i32>, %arg1 : memref<i32>) -> () {

// Test with one data var
// CHECK: omp.allocate_dir(%[[ARG0]] : memref<i32>)
omp.allocate_dir (%arg0 : memref<i32>)

// Test with two data vars
// CHECK: omp.allocate_dir(%[[ARG0]], %[[ARG1]] : memref<i32>, memref<i32>)
omp.allocate_dir (%arg0, %arg1: memref<i32>, memref<i32>)

// Test with one data var and align clause
// CHECK: omp.allocate_dir(%[[ARG0]] : memref<i32>) align(2)
omp.allocate_dir (%arg0 : memref<i32>) align(2)

// Test with one data var and allocator clause
// CHECK: omp.allocate_dir(%[[ARG0]] : memref<i32>) allocator(omp_pteam_mem_alloc)
omp.allocate_dir (%arg0 : memref<i32>) allocator(omp_pteam_mem_alloc)

// Test with one data var, align clause and allocator clause
// CHECK: omp.allocate_dir(%[[ARG0]] : memref<i32>) align(2) allocator(omp_thread_mem_alloc)
omp.allocate_dir (%arg0 : memref<i32>) align(2) allocator(omp_thread_mem_alloc)

// Test with two data vars, align clause and allocator clause
// CHECK: omp.allocate_dir(%[[ARG0]], %[[ARG1]] : memref<i32>, memref<i32>) align(2) allocator(omp_cgroup_mem_alloc)
omp.allocate_dir (%arg0, %arg1 : memref<i32>, memref<i32>) align(2) allocator(omp_cgroup_mem_alloc)

return
}

Loading