diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td index 16c14ef085d6d..311c57fb4446c 100644 --- a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td +++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td @@ -22,6 +22,31 @@ include "mlir/Dialect/OpenMP/OpenMPOpBase.td" include "mlir/IR/SymbolInterfaces.td" +//===----------------------------------------------------------------------===// +// V5.2: [6.3] `align` clause +//===----------------------------------------------------------------------===// + +class OpenMP_AlignClauseSkip< + bit traits = false, bit arguments = false, bit assemblyFormat = false, + bit description = false, bit extraClassDeclaration = false + > : OpenMP_Clause { + let arguments = (ins + ConfinedAttr, [IntPositive]>:$align + ); + + let optAssemblyFormat = [{ + `align` `(` $align `)` + }]; + + let description = [{ + The `align` clause is used to specify the byte alignment to use for + allocations associated with the construct on which the clause appears. + }]; +} + +def OpenMP_AlignClause : OpenMP_AlignClauseSkip<>; + //===----------------------------------------------------------------------===// // V5.2: [5.11] `aligned` clause //===----------------------------------------------------------------------===// @@ -84,6 +109,32 @@ class OpenMP_AllocateClauseSkip< def OpenMP_AllocateClause : OpenMP_AllocateClauseSkip<>; +//===----------------------------------------------------------------------===// +// V5.2: [6.4] `allocator` clause +//===----------------------------------------------------------------------===// + +class OpenMP_AllocatorClauseSkip< + bit traits = false, bit arguments = false, bit assemblyFormat = false, + bit description = false, bit extraClassDeclaration = false + > : OpenMP_Clause { + + let arguments = (ins + OptionalAttr:$allocator + ); + + let optAssemblyFormat = [{ + `allocator` `(` custom($allocator) `)` + }]; + + let description = [{ + `allocator` specifies the memory allocator to be used for allocations + associated with the construct on which the clause appears. + }]; +} + +def OpenMP_AllocatorClause : OpenMP_AllocatorClauseSkip<>; + //===----------------------------------------------------------------------===// // LLVM OpenMP extension `ompx_bare` clause //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td index 9dbe6897a3304..c080c3fac87d4 100644 --- a/mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td +++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td @@ -263,4 +263,34 @@ def VariableCaptureKindAttr : OpenMP_EnumAttr; +def OpenMP_AllocatorHandleDefaultMemAlloc : I32EnumAttrCase<"omp_default_mem_alloc", 1>; +def OpenMP_AllocatorHandleLargeCapMemAlloc : I32EnumAttrCase<"omp_large_cap_mem_alloc", 2>; +def OpenMP_AllocatorHandleConstMemAlloc : I32EnumAttrCase<"omp_const_mem_alloc", 3>; +def OpenMP_AllocatorHandleHighBwMemAlloc : I32EnumAttrCase<"omp_high_bw_mem_alloc", 4>; +def OpenMP_AllocatorHandleLowLatMemAlloc : I32EnumAttrCase<"omp_low_lat_mem_alloc", 5>; +def OpenMP_AllocatorHandleCgroupMemAlloc : I32EnumAttrCase<"omp_cgroup_mem_alloc", 6>; +def OpenMP_AllocatorHandlePteamMemAlloc : I32EnumAttrCase<"omp_pteam_mem_alloc", 7>; +def OpenMP_AllocatorHandlethreadMemAlloc : I32EnumAttrCase<"omp_thread_mem_alloc", 8>; + +def AllocatorHandle : OpenMP_I32EnumAttr< + "AllocatorHandle", + "OpenMP allocator_handle", [ + OpenMP_AllocatorHandleNullAllocator, + OpenMP_AllocatorHandleDefaultMemAlloc, + OpenMP_AllocatorHandleLargeCapMemAlloc, + OpenMP_AllocatorHandleConstMemAlloc, + OpenMP_AllocatorHandleHighBwMemAlloc, + OpenMP_AllocatorHandleLowLatMemAlloc, + OpenMP_AllocatorHandleCgroupMemAlloc, + OpenMP_AllocatorHandlePteamMemAlloc, + OpenMP_AllocatorHandlethreadMemAlloc + ]>; + +def AllocatorHandleAttr : OpenMP_EnumAttr; #endif // OPENMP_ENUMS diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td index e4f52777d8aa2..f9ebb17411533 100644 --- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td +++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td @@ -2090,4 +2090,31 @@ def MaskedOp : OpenMP_Op<"masked", clauses = [ ]; } +//===----------------------------------------------------------------------===// +// [Spec 5.2] 6.5 allocate Directive +//===----------------------------------------------------------------------===// +def AllocateDirOp : OpenMP_Op<"allocate_dir", clauses = [ + OpenMP_AlignClause, OpenMP_AllocatorClause + ]> { + let summary = "allocate directive"; + let description = [{ + The storage for each list item that appears in the allocate directive is + provided an allocation through the memory allocator. + }] # clausesDescription; + + let arguments = !con((ins Variadic:$varList), + clausesArgs); + + // Override inherited assembly format to include `varList`. + let assemblyFormat = " `(` $varList `:` type($varList) `)` oilist(" # + clausesOptAssemblyFormat # + ") attr-dict "; + + let builders = [ + OpBuilder<(ins CArg<"const AllocateDirOperands &">:$clauses)> + ]; + + let hasVerifier = 1; +} + #endif // OPENMP_OPS diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp index 9bea8b7a732a8..486935fe7341c 100644 --- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp +++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp @@ -33,6 +33,7 @@ #include "llvm/ADT/StringExtras.h" #include "llvm/ADT/StringRef.h" #include "llvm/ADT/TypeSwitch.h" +#include "llvm/ADT/bit.h" #include "llvm/Frontend/OpenMP/OMPConstants.h" #include "llvm/Frontend/OpenMP/OMPDeviceConstants.h" #include @@ -3863,6 +3864,20 @@ LogicalResult ScanOp::verify() { "reduction modifier"); } +/// Verifies align clause in allocate directive + +LogicalResult AllocateDirOp::verify() { + std::optional align = this->getAlign(); + + if (align.has_value()) { + if ((align.value() > 0) && !llvm::has_single_bit(align.value())) + return emitError() << "ALIGN value : " << align.value() + << " must be power of 2"; + } + + return success(); +} + #define GET_ATTRDEF_CLASSES #include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc" diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir index 7608ad57c7967..5088f2dfa7d7a 100644 --- a/mlir/test/Dialect/OpenMP/invalid.mlir +++ b/mlir/test/Dialect/OpenMP/invalid.mlir @@ -2993,3 +2993,27 @@ llvm.func @invalid_mapper(%0 : !llvm.ptr) { } llvm.return } + +// ----- +func.func @invalid_allocate_align_1(%arg0 : memref) -> () { + // expected-error @below {{failed to satisfy constraint: 64-bit signless integer attribute whose value is positive}} + omp.allocate_dir (%arg0 : memref) align(-1) + + return +} + +// ----- +func.func @invalid_allocate_align_2(%arg0 : memref) -> () { + // expected-error @below {{must be power of 2}} + omp.allocate_dir (%arg0 : memref) align(3) + + return +} + +// ----- +func.func @invalid_allocate_allocator(%arg0 : memref) -> () { + // expected-error @below {{invalid clause value}} + omp.allocate_dir (%arg0 : memref) allocator(omp_small_cap_mem_alloc) + + return +} diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir index 47cfc5278a5d0..4c50ed3230976 100644 --- a/mlir/test/Dialect/OpenMP/ops.mlir +++ b/mlir/test/Dialect/OpenMP/ops.mlir @@ -3197,3 +3197,36 @@ func.func @omp_workshare_loop_wrapper_attrs(%idx : index) { } return } + +// CHECK-LABEL: func.func @omp_allocate_dir( +// CHECK-SAME: %[[ARG0:.*]]: memref, +// CHECK-SAME: %[[ARG1:.*]]: memref) { +func.func @omp_allocate_dir(%arg0 : memref, %arg1 : memref) -> () { + + // Test with one data var + // CHECK: omp.allocate_dir(%[[ARG0]] : memref) + omp.allocate_dir (%arg0 : memref) + + // Test with two data vars + // CHECK: omp.allocate_dir(%[[ARG0]], %[[ARG1]] : memref, memref) + omp.allocate_dir (%arg0, %arg1: memref, memref) + + // Test with one data var and align clause + // CHECK: omp.allocate_dir(%[[ARG0]] : memref) align(2) + omp.allocate_dir (%arg0 : memref) align(2) + + // Test with one data var and allocator clause + // CHECK: omp.allocate_dir(%[[ARG0]] : memref) allocator(omp_pteam_mem_alloc) + omp.allocate_dir (%arg0 : memref) allocator(omp_pteam_mem_alloc) + + // Test with one data var, align clause and allocator clause + // CHECK: omp.allocate_dir(%[[ARG0]] : memref) align(2) allocator(omp_thread_mem_alloc) + omp.allocate_dir (%arg0 : memref) align(2) allocator(omp_thread_mem_alloc) + + // Test with two data vars, align clause and allocator clause + // CHECK: omp.allocate_dir(%[[ARG0]], %[[ARG1]] : memref, memref) align(2) allocator(omp_cgroup_mem_alloc) + omp.allocate_dir (%arg0, %arg1 : memref, memref) align(2) allocator(omp_cgroup_mem_alloc) + + return +} +