From c0a51c0b178764bf10cb611c97820c529e11b2b9 Mon Sep 17 00:00:00 2001 From: Fabian Mora Date: Wed, 27 Mar 2024 19:26:30 +0000 Subject: [PATCH 1/5] [mlir][Ptr] Init the Ptr dialect with the `!ptr.ptr` type. This patch initializes the `ptr` dialect directories and some base files. It also add the `!ptr.ptr` type, together with the `DataLayoutTypeInterface` interface. The implementation of the `DataLayoutTypeInterface` interface clones the implementation from `LLVM::LLVMPointerType`. --- mlir/include/mlir/Dialect/CMakeLists.txt | 1 + mlir/include/mlir/Dialect/Ptr/CMakeLists.txt | 1 + .../mlir/Dialect/Ptr/IR/CMakeLists.txt | 2 + mlir/include/mlir/Dialect/Ptr/IR/PtrDialect.h | 20 ++ .../include/mlir/Dialect/Ptr/IR/PtrDialect.td | 83 ++++++++ mlir/include/mlir/Dialect/Ptr/IR/PtrOps.h | 24 +++ mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td | 15 ++ mlir/include/mlir/Dialect/Ptr/IR/PtrTypes.h | 37 ++++ mlir/include/mlir/InitAllDialects.h | 2 + mlir/lib/Dialect/CMakeLists.txt | 1 + mlir/lib/Dialect/Ptr/CMakeLists.txt | 1 + mlir/lib/Dialect/Ptr/IR/CMakeLists.txt | 14 ++ mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp | 48 +++++ mlir/lib/Dialect/Ptr/IR/PtrTypes.cpp | 184 ++++++++++++++++++ mlir/test/Dialect/Ptr/types.mlir | 17 ++ 15 files changed, 450 insertions(+) create mode 100644 mlir/include/mlir/Dialect/Ptr/CMakeLists.txt create mode 100644 mlir/include/mlir/Dialect/Ptr/IR/CMakeLists.txt create mode 100644 mlir/include/mlir/Dialect/Ptr/IR/PtrDialect.h create mode 100644 mlir/include/mlir/Dialect/Ptr/IR/PtrDialect.td create mode 100644 mlir/include/mlir/Dialect/Ptr/IR/PtrOps.h create mode 100644 mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td create mode 100644 mlir/include/mlir/Dialect/Ptr/IR/PtrTypes.h create mode 100644 mlir/lib/Dialect/Ptr/CMakeLists.txt create mode 100644 mlir/lib/Dialect/Ptr/IR/CMakeLists.txt create mode 100644 mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp create mode 100644 mlir/lib/Dialect/Ptr/IR/PtrTypes.cpp create mode 100644 mlir/test/Dialect/Ptr/types.mlir diff --git a/mlir/include/mlir/Dialect/CMakeLists.txt b/mlir/include/mlir/Dialect/CMakeLists.txt index 4bd7f12fabf7b..f710235197334 100644 --- a/mlir/include/mlir/Dialect/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/CMakeLists.txt @@ -29,6 +29,7 @@ add_subdirectory(OpenMP) add_subdirectory(PDL) add_subdirectory(PDLInterp) add_subdirectory(Polynomial) +add_subdirectory(Ptr) add_subdirectory(Quant) add_subdirectory(SCF) add_subdirectory(Shape) diff --git a/mlir/include/mlir/Dialect/Ptr/CMakeLists.txt b/mlir/include/mlir/Dialect/Ptr/CMakeLists.txt new file mode 100644 index 0000000000000..f33061b2d87cf --- /dev/null +++ b/mlir/include/mlir/Dialect/Ptr/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(IR) diff --git a/mlir/include/mlir/Dialect/Ptr/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/Ptr/IR/CMakeLists.txt new file mode 100644 index 0000000000000..c6ffa892e4ecb --- /dev/null +++ b/mlir/include/mlir/Dialect/Ptr/IR/CMakeLists.txt @@ -0,0 +1,2 @@ +add_mlir_dialect(PtrOps ptr) +add_mlir_doc(PtrOps PtrOps Dialects/ -gen-op-doc) diff --git a/mlir/include/mlir/Dialect/Ptr/IR/PtrDialect.h b/mlir/include/mlir/Dialect/Ptr/IR/PtrDialect.h new file mode 100644 index 0000000000000..92f877c20dbf0 --- /dev/null +++ b/mlir/include/mlir/Dialect/Ptr/IR/PtrDialect.h @@ -0,0 +1,20 @@ +//===- PtrDialect.h - Pointer dialect ---------------------------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines the Ptr dialect. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_PTR_IR_PTRDIALECT_H +#define MLIR_DIALECT_PTR_IR_PTRDIALECT_H + +#include "mlir/IR/Dialect.h" + +#include "mlir/Dialect/Ptr/IR/PtrOpsDialect.h.inc" + +#endif // MLIR_DIALECT_PTR_IR_PTRDIALECT_H diff --git a/mlir/include/mlir/Dialect/Ptr/IR/PtrDialect.td b/mlir/include/mlir/Dialect/Ptr/IR/PtrDialect.td new file mode 100644 index 0000000000000..bffae6b1ad71b --- /dev/null +++ b/mlir/include/mlir/Dialect/Ptr/IR/PtrDialect.td @@ -0,0 +1,83 @@ +//===- PtrDialect.td - Pointer dialect ---------------------*- tablegen -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef PTR_DIALECT +#define PTR_DIALECT + +include "mlir/Interfaces/DataLayoutInterfaces.td" +include "mlir/IR/AttrTypeBase.td" +include "mlir/IR/BuiltinTypeInterfaces.td" +include "mlir/IR/OpBase.td" + +//===----------------------------------------------------------------------===// +// Pointer dialect definition. +//===----------------------------------------------------------------------===// + +def Ptr_Dialect : Dialect { + let name = "ptr"; + let summary = "Pointer dialect"; + let cppNamespace = "::mlir::ptr"; + let useDefaultTypePrinterParser = 1; + let useDefaultAttributePrinterParser = 0; +} + +//===----------------------------------------------------------------------===// +// Pointer type definitions +//===----------------------------------------------------------------------===// + +class Ptr_Type traits = []> + : TypeDef { + let mnemonic = typeMnemonic; +} + +def Ptr_PtrType : Ptr_Type<"Ptr", "ptr", [ + MemRefElementTypeInterface, + DeclareTypeInterfaceMethods + ]> { + let summary = "pointer type"; + let description = [{ + The `ptr` type is an opaque pointer type. This type typically represents + a reference to an object in memory. Pointers are optionally parameterized + by a memory space. + Syntax: + + ```mlir + pointer ::= `ptr` (`<` memory-space `>`)? + memory-space ::= attribute-value + ``` + }]; + let parameters = (ins OptionalParameter<"Attribute">:$memorySpace); + let assemblyFormat = "(`<` $memorySpace^ `>`)?"; + let builders = [ + TypeBuilder<(ins CArg<"Attribute", "nullptr">:$addressSpace), [{ + return $_get($_ctxt, addressSpace); + }]>, + TypeBuilder<(ins CArg<"unsigned">:$addressSpace), [{ + return $_get($_ctxt, IntegerAttr::get(IntegerType::get($_ctxt, 32), + addressSpace)); + }]> + ]; + let skipDefaultBuilders = 1; + let extraClassDeclaration = [{ + /// Returns the default memory space. + Attribute getDefaultMemorySpace() const; + + /// Returns the memory space as an unsigned number. + int64_t getAddressSpace() const; + }]; +} + +//===----------------------------------------------------------------------===// +// Base address operation definition. +//===----------------------------------------------------------------------===// + +class Pointer_Op traits = []> : + Op; + +#endif // PTR_DIALECT diff --git a/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.h b/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.h new file mode 100644 index 0000000000000..ad8a2bbcbdd8d --- /dev/null +++ b/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.h @@ -0,0 +1,24 @@ +//===- PtrDialect.h - Pointer dialect ---------------------------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines the Ptr dialect. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_PTR_IR_PTROPS_H +#define MLIR_DIALECT_PTR_IR_PTROPS_H + +#include "mlir/Bytecode/BytecodeOpInterface.h" +#include "mlir/Dialect/Ptr/IR/PtrDialect.h" +#include "mlir/Dialect/Ptr/IR/PtrTypes.h" +#include "mlir/IR/OpDefinition.h" + +#define GET_OP_CLASSES +#include "mlir/Dialect/Ptr/IR/PtrOps.h.inc" + +#endif // MLIR_DIALECT_PTR_IR_PTROPS_H diff --git a/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td b/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td new file mode 100644 index 0000000000000..690941337bdfb --- /dev/null +++ b/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td @@ -0,0 +1,15 @@ +//===- PtrOps.td - Pointer dialect ops ---------------------*- tablegen -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://ptr.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef PTR_OPS +#define PTR_OPS + +include "mlir/Dialect/Ptr/IR/PtrDialect.td" +include "mlir/IR/OpAsmInterface.td" + +#endif // PTR_OPS diff --git a/mlir/include/mlir/Dialect/Ptr/IR/PtrTypes.h b/mlir/include/mlir/Dialect/Ptr/IR/PtrTypes.h new file mode 100644 index 0000000000000..9984aedcbf6ce --- /dev/null +++ b/mlir/include/mlir/Dialect/Ptr/IR/PtrTypes.h @@ -0,0 +1,37 @@ +//===- PtrTypes.h - Pointer types -------------------------------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines the Pointer dialect types. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_PTR_IR_PTRTYPES_H +#define MLIR_DIALECT_PTR_IR_PTRTYPES_H + +#include "mlir/IR/Types.h" +#include "mlir/Interfaces/DataLayoutInterfaces.h" + +namespace mlir { +namespace ptr { +/// The positions of different values in the data layout entry for pointers. +enum class PtrDLEntryPos { Size = 0, Abi = 1, Preferred = 2, Index = 3 }; + +/// Returns the value that corresponds to named position `pos` from the +/// data layout entry `attr` assuming it's a dense integer elements attribute. +/// Returns `std::nullopt` if `pos` is not present in the entry. +/// Currently only `PtrDLEntryPos::Index` is optional, and all other positions +/// may be assumed to be present. +std::optional extractPointerSpecValue(Attribute attr, + PtrDLEntryPos pos); +} // namespace ptr +} // namespace mlir + +#define GET_TYPEDEF_CLASSES +#include "mlir/Dialect/Ptr/IR/PtrOpsTypes.h.inc" + +#endif // MLIR_DIALECT_PTR_IR_PTRTYPES_H diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h index d9db21073e15c..549c26c72d8a1 100644 --- a/mlir/include/mlir/InitAllDialects.h +++ b/mlir/include/mlir/InitAllDialects.h @@ -63,6 +63,7 @@ #include "mlir/Dialect/PDL/IR/PDL.h" #include "mlir/Dialect/PDLInterp/IR/PDLInterp.h" #include "mlir/Dialect/Polynomial/IR/PolynomialDialect.h" +#include "mlir/Dialect/Ptr/IR/PtrDialect.h" #include "mlir/Dialect/Quant/QuantOps.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.h" @@ -134,6 +135,7 @@ inline void registerAllDialects(DialectRegistry ®istry) { pdl::PDLDialect, pdl_interp::PDLInterpDialect, polynomial::PolynomialDialect, + ptr::PtrDialect, quant::QuantizationDialect, ROCDL::ROCDLDialect, scf::SCFDialect, diff --git a/mlir/lib/Dialect/CMakeLists.txt b/mlir/lib/Dialect/CMakeLists.txt index a324ce7f9b19f..80b0ef068d96d 100644 --- a/mlir/lib/Dialect/CMakeLists.txt +++ b/mlir/lib/Dialect/CMakeLists.txt @@ -29,6 +29,7 @@ add_subdirectory(OpenMP) add_subdirectory(PDL) add_subdirectory(PDLInterp) add_subdirectory(Polynomial) +add_subdirectory(Ptr) add_subdirectory(Quant) add_subdirectory(SCF) add_subdirectory(Shape) diff --git a/mlir/lib/Dialect/Ptr/CMakeLists.txt b/mlir/lib/Dialect/Ptr/CMakeLists.txt new file mode 100644 index 0000000000000..f33061b2d87cf --- /dev/null +++ b/mlir/lib/Dialect/Ptr/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(IR) diff --git a/mlir/lib/Dialect/Ptr/IR/CMakeLists.txt b/mlir/lib/Dialect/Ptr/IR/CMakeLists.txt new file mode 100644 index 0000000000000..359b9f02a0626 --- /dev/null +++ b/mlir/lib/Dialect/Ptr/IR/CMakeLists.txt @@ -0,0 +1,14 @@ +add_mlir_dialect_library( + MLIRPtrDialect + PtrTypes.cpp + PtrDialect.cpp + ADDITIONAL_HEADER_DIRS + ${PROJECT_SOURCE_DIR}/mlir/Dialect/Pointer + DEPENDS + MLIRPtrOpsIncGen + LINK_LIBS + PUBLIC + MLIRIR + MLIRDataLayoutInterfaces + MLIRMemorySlotInterfaces +) diff --git a/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp b/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp new file mode 100644 index 0000000000000..59c97b22f332c --- /dev/null +++ b/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp @@ -0,0 +1,48 @@ +//===- PtrDialect.cpp - Pointer dialect ---------------------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the Pointer dialect. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Ptr/IR/PtrOps.h" +#include "mlir/IR/DialectImplementation.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/InliningUtils.h" +#include "llvm/ADT/SmallString.h" +#include "llvm/ADT/TypeSwitch.h" + +using namespace mlir; +using namespace mlir::ptr; + +//===----------------------------------------------------------------------===// +// Pointer dialect +//===----------------------------------------------------------------------===// + +void PtrDialect::initialize() { + addOperations< +#define GET_OP_LIST +#include "mlir/Dialect/Ptr/IR/PtrOps.cpp.inc" + >(); + addTypes< +#define GET_TYPEDEF_LIST +#include "mlir/Dialect/Ptr/IR/PtrOpsTypes.cpp.inc" + >(); +} + +//===----------------------------------------------------------------------===// +// Pointer API. +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Ptr/IR/PtrOpsDialect.cpp.inc" + +#define GET_TYPEDEF_CLASSES +#include "mlir/Dialect/Ptr/IR/PtrOpsTypes.cpp.inc" + +#define GET_OP_CLASSES +#include "mlir/Dialect/Ptr/IR/PtrOps.cpp.inc" diff --git a/mlir/lib/Dialect/Ptr/IR/PtrTypes.cpp b/mlir/lib/Dialect/Ptr/IR/PtrTypes.cpp new file mode 100644 index 0000000000000..51d0a45051b85 --- /dev/null +++ b/mlir/lib/Dialect/Ptr/IR/PtrTypes.cpp @@ -0,0 +1,184 @@ +//===- PtrTypes.cpp - Pointer dialect types ---------------------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines the Ptr dialect types. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Ptr/IR/PtrTypes.h" +#include "llvm/ADT/TypeSwitch.h" + +using namespace mlir; +using namespace mlir::ptr; + +//===----------------------------------------------------------------------===// +// Pointer type +//===----------------------------------------------------------------------===// + +constexpr const static unsigned kDefaultPointerSizeBits = 64; +constexpr const static unsigned kBitsInByte = 8; +constexpr const static unsigned kDefaultPointerAlignment = 8; + +/// Returns the part of the data layout entry that corresponds to `pos` for the +/// given `type` by interpreting the list of entries `params`. For the pointer +/// type in the default address space, returns the default value if the entries +/// do not provide a custom one, for other address spaces returns std::nullopt. +static std::optional +getPointerDataLayoutEntry(DataLayoutEntryListRef params, PtrType type, + PtrDLEntryPos pos) { + // First, look for the entry for the pointer in the current address space. + Attribute currentEntry; + for (DataLayoutEntryInterface entry : params) { + if (!entry.isTypeEntry()) + continue; + if (cast(entry.getKey().get()).getAddressSpace() == + type.getAddressSpace()) { + currentEntry = entry.getValue(); + break; + } + } + if (currentEntry) { + std::optional value = extractPointerSpecValue(currentEntry, pos); + // If the optional `PtrDLEntryPos::Index` entry is not available, use the + // pointer size as the index bitwidth. + if (!value && pos == PtrDLEntryPos::Index) + value = extractPointerSpecValue(currentEntry, PtrDLEntryPos::Size); + bool isSizeOrIndex = + pos == PtrDLEntryPos::Size || pos == PtrDLEntryPos::Index; + return *value / (isSizeOrIndex ? 1 : kBitsInByte); + } + + // If not found, and this is the pointer to the default memory space, assume + // 64-bit pointers. + if (type.getAddressSpace() == 0) { + bool isSizeOrIndex = + pos == PtrDLEntryPos::Size || pos == PtrDLEntryPos::Index; + return isSizeOrIndex ? kDefaultPointerSizeBits : kDefaultPointerAlignment; + } + + return std::nullopt; +} + +int64_t PtrType::getAddressSpace() const { return 0; } + +Attribute PtrType::getDefaultMemorySpace() const { return nullptr; } + +bool PtrType::areCompatible(DataLayoutEntryListRef oldLayout, + DataLayoutEntryListRef newLayout) const { + for (DataLayoutEntryInterface newEntry : newLayout) { + if (!newEntry.isTypeEntry()) + continue; + unsigned size = kDefaultPointerSizeBits; + unsigned abi = kDefaultPointerAlignment; + auto newType = llvm::cast(newEntry.getKey().get()); + const auto *it = + llvm::find_if(oldLayout, [&](DataLayoutEntryInterface entry) { + if (auto type = llvm::dyn_cast_if_present(entry.getKey())) { + return llvm::cast(type).getMemorySpace() == + newType.getMemorySpace(); + } + return false; + }); + if (it == oldLayout.end()) { + llvm::find_if(oldLayout, [&](DataLayoutEntryInterface entry) { + if (auto type = llvm::dyn_cast_if_present(entry.getKey())) { + return llvm::cast(type).getAddressSpace() == 0; + } + return false; + }); + } + if (it != oldLayout.end()) { + size = *extractPointerSpecValue(*it, PtrDLEntryPos::Size); + abi = *extractPointerSpecValue(*it, PtrDLEntryPos::Abi); + } + + Attribute newSpec = llvm::cast(newEntry.getValue()); + unsigned newSize = *extractPointerSpecValue(newSpec, PtrDLEntryPos::Size); + unsigned newAbi = *extractPointerSpecValue(newSpec, PtrDLEntryPos::Abi); + if (size != newSize || abi < newAbi || abi % newAbi != 0) + return false; + } + return true; +} + +uint64_t PtrType::getABIAlignment(const DataLayout &dataLayout, + DataLayoutEntryListRef params) const { + if (std::optional alignment = + getPointerDataLayoutEntry(params, *this, PtrDLEntryPos::Abi)) + return *alignment; + + return dataLayout.getTypeABIAlignment( + get(getContext(), getDefaultMemorySpace())); +} + +std::optional +PtrType::getIndexBitwidth(const DataLayout &dataLayout, + DataLayoutEntryListRef params) const { + if (std::optional indexBitwidth = + getPointerDataLayoutEntry(params, *this, PtrDLEntryPos::Index)) + return *indexBitwidth; + + return dataLayout.getTypeIndexBitwidth( + get(getContext(), getDefaultMemorySpace())); +} + +llvm::TypeSize PtrType::getTypeSizeInBits(const DataLayout &dataLayout, + DataLayoutEntryListRef params) const { + if (std::optional size = + getPointerDataLayoutEntry(params, *this, PtrDLEntryPos::Size)) + return llvm::TypeSize::getFixed(*size); + + // For other memory spaces, use the size of the pointer to the default memory + // space. + return dataLayout.getTypeSizeInBits( + get(getContext(), getDefaultMemorySpace())); +} + +uint64_t PtrType::getPreferredAlignment(const DataLayout &dataLayout, + DataLayoutEntryListRef params) const { + if (std::optional alignment = + getPointerDataLayoutEntry(params, *this, PtrDLEntryPos::Preferred)) + return *alignment; + + return dataLayout.getTypePreferredAlignment( + get(getContext(), getDefaultMemorySpace())); +} + +std::optional mlir::ptr::extractPointerSpecValue(Attribute attr, + PtrDLEntryPos pos) { + auto spec = cast(attr); + auto idx = static_cast(pos); + if (idx >= spec.size()) + return std::nullopt; + return spec.getValues()[idx]; +} + +LogicalResult PtrType::verifyEntries(DataLayoutEntryListRef entries, + Location loc) const { + for (DataLayoutEntryInterface entry : entries) { + if (!entry.isTypeEntry()) + continue; + auto key = entry.getKey().get(); + auto values = llvm::dyn_cast(entry.getValue()); + if (!values || (values.size() != 3 && values.size() != 4)) { + return emitError(loc) + << "expected layout attribute for " << key + << " to be a dense integer elements attribute with 3 or 4 " + "elements"; + } + if (!values.getElementType().isInteger(64)) + return emitError(loc) << "expected i64 parameters for " << key; + + if (extractPointerSpecValue(values, PtrDLEntryPos::Abi) > + extractPointerSpecValue(values, PtrDLEntryPos::Preferred)) { + return emitError(loc) << "preferred alignment is expected to be at least " + "as large as ABI alignment"; + } + } + return success(); +} diff --git a/mlir/test/Dialect/Ptr/types.mlir b/mlir/test/Dialect/Ptr/types.mlir new file mode 100644 index 0000000000000..279213bd6fc3e --- /dev/null +++ b/mlir/test/Dialect/Ptr/types.mlir @@ -0,0 +1,17 @@ +// RUN: mlir-opt %s -split-input-file | mlir-opt | FileCheck %s + +// CHECK-LABEL: func @ptr_test +// CHECK: (%[[ARG0:.*]]: !ptr.ptr, %[[ARG1:.*]]: !ptr.ptr<1 : i32>) +// CHECK: -> (!ptr.ptr<1 : i32>, !ptr.ptr) +func.func @ptr_test(%arg0: !ptr.ptr, %arg1: !ptr.ptr<1 : i32>) -> (!ptr.ptr<1 : i32>, !ptr.ptr) { + // CHECK: return %[[ARG1]], %[[ARG0]] : !ptr.ptr<1 : i32>, !ptr.ptr + return %arg1, %arg0 : !ptr.ptr<1 : i32>, !ptr.ptr +} + +// ----- + +// CHECK-LABEL: func @ptr_test +// CHECK: %[[ARG:.*]]: memref +func.func @ptr_test(%arg0: memref) { + return +} From f2ffe702650d66c3729242ef008552651746368f Mon Sep 17 00:00:00 2001 From: Fabian Mora Date: Tue, 2 Apr 2024 15:12:48 +0000 Subject: [PATCH 2/5] add layout test, address reviewer comments and fix a bug in layout methods --- .../include/mlir/Dialect/Ptr/IR/PtrDialect.td | 6 +- mlir/lib/Dialect/Ptr/IR/PtrTypes.cpp | 20 ++--- mlir/test/Dialect/Ptr/layout.mlir | 87 +++++++++++++++++++ 3 files changed, 99 insertions(+), 14 deletions(-) create mode 100644 mlir/test/Dialect/Ptr/layout.mlir diff --git a/mlir/include/mlir/Dialect/Ptr/IR/PtrDialect.td b/mlir/include/mlir/Dialect/Ptr/IR/PtrDialect.td index bffae6b1ad71b..315ecdfb6609e 100644 --- a/mlir/include/mlir/Dialect/Ptr/IR/PtrDialect.td +++ b/mlir/include/mlir/Dialect/Ptr/IR/PtrDialect.td @@ -55,8 +55,8 @@ def Ptr_PtrType : Ptr_Type<"Ptr", "ptr", [ let parameters = (ins OptionalParameter<"Attribute">:$memorySpace); let assemblyFormat = "(`<` $memorySpace^ `>`)?"; let builders = [ - TypeBuilder<(ins CArg<"Attribute", "nullptr">:$addressSpace), [{ - return $_get($_ctxt, addressSpace); + TypeBuilder<(ins CArg<"Attribute", "nullptr">:$memorySpace), [{ + return $_get($_ctxt, memorySpace); }]>, TypeBuilder<(ins CArg<"unsigned">:$addressSpace), [{ return $_get($_ctxt, IntegerAttr::get(IntegerType::get($_ctxt, 32), @@ -69,7 +69,7 @@ def Ptr_PtrType : Ptr_Type<"Ptr", "ptr", [ Attribute getDefaultMemorySpace() const; /// Returns the memory space as an unsigned number. - int64_t getAddressSpace() const; + uint64_t getAddressSpace() const; }]; } diff --git a/mlir/lib/Dialect/Ptr/IR/PtrTypes.cpp b/mlir/lib/Dialect/Ptr/IR/PtrTypes.cpp index 51d0a45051b85..95c72f5743afc 100644 --- a/mlir/lib/Dialect/Ptr/IR/PtrTypes.cpp +++ b/mlir/lib/Dialect/Ptr/IR/PtrTypes.cpp @@ -36,35 +36,32 @@ getPointerDataLayoutEntry(DataLayoutEntryListRef params, PtrType type, for (DataLayoutEntryInterface entry : params) { if (!entry.isTypeEntry()) continue; - if (cast(entry.getKey().get()).getAddressSpace() == - type.getAddressSpace()) { + if (cast(entry.getKey().get()).getMemorySpace() == + type.getMemorySpace()) { currentEntry = entry.getValue(); break; } } + bool isSizeOrIndex = + pos == PtrDLEntryPos::Size || pos == PtrDLEntryPos::Index; if (currentEntry) { std::optional value = extractPointerSpecValue(currentEntry, pos); // If the optional `PtrDLEntryPos::Index` entry is not available, use the // pointer size as the index bitwidth. if (!value && pos == PtrDLEntryPos::Index) value = extractPointerSpecValue(currentEntry, PtrDLEntryPos::Size); - bool isSizeOrIndex = - pos == PtrDLEntryPos::Size || pos == PtrDLEntryPos::Index; return *value / (isSizeOrIndex ? 1 : kBitsInByte); } // If not found, and this is the pointer to the default memory space, assume // 64-bit pointers. - if (type.getAddressSpace() == 0) { - bool isSizeOrIndex = - pos == PtrDLEntryPos::Size || pos == PtrDLEntryPos::Index; + if (type.getMemorySpace() == type.getDefaultMemorySpace()) return isSizeOrIndex ? kDefaultPointerSizeBits : kDefaultPointerAlignment; - } return std::nullopt; } -int64_t PtrType::getAddressSpace() const { return 0; } +uint64_t PtrType::getAddressSpace() const { return 0; } Attribute PtrType::getDefaultMemorySpace() const { return nullptr; } @@ -85,9 +82,10 @@ bool PtrType::areCompatible(DataLayoutEntryListRef oldLayout, return false; }); if (it == oldLayout.end()) { - llvm::find_if(oldLayout, [&](DataLayoutEntryInterface entry) { + it = llvm::find_if(oldLayout, [&](DataLayoutEntryInterface entry) { if (auto type = llvm::dyn_cast_if_present(entry.getKey())) { - return llvm::cast(type).getAddressSpace() == 0; + auto ptrTy = llvm::cast(type); + return ptrTy.getMemorySpace() == ptrTy.getDefaultMemorySpace(); } return false; }); diff --git a/mlir/test/Dialect/Ptr/layout.mlir b/mlir/test/Dialect/Ptr/layout.mlir new file mode 100644 index 0000000000000..b345fbd6f6fbb --- /dev/null +++ b/mlir/test/Dialect/Ptr/layout.mlir @@ -0,0 +1,87 @@ +// RUN: mlir-opt --test-data-layout-query --split-input-file --verify-diagnostics %s | FileCheck %s + +module attributes { dlti.dl_spec = #dlti.dl_spec< + #dlti.dl_entry : vector<3xi64>>, + #dlti.dl_entry, dense<[64, 64, 64]> : vector<3xi64>>, + #dlti.dl_entry, dense<[32, 64, 64, 24]> : vector<4xi64>>, + #dlti.dl_entry<"dlti.alloca_memory_space", 5 : ui64>, + #dlti.dl_entry<"dlti.global_memory_space", 2 : ui64>, + #dlti.dl_entry<"dlti.program_memory_space", 3 : ui64>, + #dlti.dl_entry<"dlti.stack_alignment", 128 : i64> +>} { + // CHECK: @spec + func.func @spec() { + // CHECK: alignment = 4 + // CHECK: alloca_memory_space = 5 + // CHECK: bitsize = 32 + // CHECK: global_memory_space = 2 + // CHECK: index = 32 + // CHECK: preferred = 8 + // CHECK: program_memory_space = 3 + // CHECK: size = 4 + // CHECK: stack_alignment = 128 + "test.data_layout_query"() : () -> !ptr.ptr + // CHECK: alignment = 4 + // CHECK: alloca_memory_space = 5 + // CHECK: bitsize = 32 + // CHECK: global_memory_space = 2 + // CHECK: index = 32 + // CHECK: preferred = 8 + // CHECK: program_memory_space = 3 + // CHECK: size = 4 + // CHECK: stack_alignment = 128 + "test.data_layout_query"() : () -> !ptr.ptr<3> + // CHECK: alignment = 8 + // CHECK: alloca_memory_space = 5 + // CHECK: bitsize = 64 + // CHECK: global_memory_space = 2 + // CHECK: index = 64 + // CHECK: preferred = 8 + // CHECK: program_memory_space = 3 + // CHECK: size = 8 + // CHECK: stack_alignment = 128 + "test.data_layout_query"() : () -> !ptr.ptr<5> + // CHECK: alignment = 8 + // CHECK: alloca_memory_space = 5 + // CHECK: bitsize = 32 + // CHECK: global_memory_space = 2 + // CHECK: index = 24 + // CHECK: preferred = 8 + // CHECK: program_memory_space = 3 + // CHECK: size = 4 + // CHECK: stack_alignment = 128 + "test.data_layout_query"() : () -> !ptr.ptr<4> + return + } +} + +// ----- + +// expected-error@below {{expected layout attribute for '!ptr.ptr' to be a dense integer elements attribute with 3 or 4 elements}} +module attributes { dlti.dl_spec = #dlti.dl_spec< + #dlti.dl_entry : vector<3xf32>> +>} { + func.func @pointer() { + return + } +} + +// ----- + +// expected-error@below {{preferred alignment is expected to be at least as large as ABI alignment}} +module attributes { dlti.dl_spec = #dlti.dl_spec< + #dlti.dl_entry : vector<3xi64>> +>} { + func.func @pointer() { + return + } +} + +// ----- + +// expected-error @below {{expected i64 parameters for '!ptr.ptr'}} +module attributes { dlti.dl_spec = #dlti.dl_spec< + #dlti.dl_entry : vector<3xi32>> +>} { +} + From 108501ea897d6ac00aa8f95d4c61d51024931187 Mon Sep 17 00:00:00 2001 From: Fabian Mora Date: Sat, 8 Jun 2024 18:19:53 +0000 Subject: [PATCH 3/5] Addressed reviewer comments --- .../mlir/Dialect/Ptr/IR/CMakeLists.txt | 5 + .../mlir/Dialect/Ptr/IR/PtrAttrDefs.td | 70 +++++++++++ mlir/include/mlir/Dialect/Ptr/IR/PtrAttrs.h | 21 ++++ .../include/mlir/Dialect/Ptr/IR/PtrDialect.td | 9 +- mlir/include/mlir/Dialect/Ptr/IR/PtrOps.h | 1 + mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td | 1 + mlir/include/mlir/Dialect/Ptr/IR/PtrTypes.h | 15 --- mlir/lib/Dialect/Ptr/IR/CMakeLists.txt | 2 + mlir/lib/Dialect/Ptr/IR/PtrAttrs.cpp | 40 ++++++ mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp | 7 ++ mlir/lib/Dialect/Ptr/IR/PtrTypes.cpp | 116 ++++++------------ mlir/test/Dialect/Ptr/layout.mlir | 45 +++++-- 12 files changed, 223 insertions(+), 109 deletions(-) create mode 100644 mlir/include/mlir/Dialect/Ptr/IR/PtrAttrDefs.td create mode 100644 mlir/include/mlir/Dialect/Ptr/IR/PtrAttrs.h create mode 100644 mlir/lib/Dialect/Ptr/IR/PtrAttrs.cpp diff --git a/mlir/include/mlir/Dialect/Ptr/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/Ptr/IR/CMakeLists.txt index c6ffa892e4ecb..df07b8d5a63d9 100644 --- a/mlir/include/mlir/Dialect/Ptr/IR/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/Ptr/IR/CMakeLists.txt @@ -1,2 +1,7 @@ add_mlir_dialect(PtrOps ptr) add_mlir_doc(PtrOps PtrOps Dialects/ -gen-op-doc) + +set(LLVM_TARGET_DEFINITIONS PtrOps.td) +mlir_tablegen(PtrOpsAttrs.h.inc -gen-attrdef-decls -attrdefs-dialect=ptr) +mlir_tablegen(PtrOpsAttrs.cpp.inc -gen-attrdef-defs -attrdefs-dialect=ptr) +add_public_tablegen_target(MLIRPtrOpsAttributesIncGen) diff --git a/mlir/include/mlir/Dialect/Ptr/IR/PtrAttrDefs.td b/mlir/include/mlir/Dialect/Ptr/IR/PtrAttrDefs.td new file mode 100644 index 0000000000000..e75038f300f1a --- /dev/null +++ b/mlir/include/mlir/Dialect/Ptr/IR/PtrAttrDefs.td @@ -0,0 +1,70 @@ +//===-- PtrAttrDefs.td - Ptr Attributes definition file ----*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef PTR_ATTRDEFS +#define PTR_ATTRDEFS + +include "mlir/Dialect/Ptr/IR/PtrDialect.td" +include "mlir/IR/AttrTypeBase.td" + +// All of the attributes will extend this class. +class Ptr_Attr traits = [], + string baseCppClass = "::mlir::Attribute"> + : AttrDef { + let mnemonic = attrMnemonic; +} + +//===----------------------------------------------------------------------===// +// SpecAttr +//===----------------------------------------------------------------------===// + +def Ptr_SpecAttr : Ptr_Attr<"Spec", "spec"> { + let summary = "ptr data layout spec"; + let description = [{ + Defines the data layout spec for a pointer type. This attribute has 4 + fields: + - [Required] size: size of the pointer in bits. + - [Required] abi: ABI-required alignment for the pointer in bits. + - [Required] preferred: preferred alignment for the pointer in bits. + - [Optional] index: bitwidth that should be used when performing index + computations for the type. Setting the field to `kOptionalSpecValue`, means + the field is optional. + + Furthermore, the attribute will verify that all present values are divisible + by 8 (number of bits in a byte), and that `preferred` > `abi`. + + Example: + ```mlir + // Spec for a 64 bit ptr, with a required alignment of 64 bits, but with + // a preferred alignment of 128 bits and an index bitwidth of 64 bits. + #ptr.spec + ``` + }]; + let parameters = (ins + "uint32_t":$size, + "uint32_t":$abi, + "uint32_t":$preferred, + DefaultValuedParameter<"uint32_t", "kOptionalSpecValue">:$index + ); + let skipDefaultBuilders = 1; + let builders = [ + AttrBuilder<(ins "uint32_t":$size, "uint32_t":$abi, "uint32_t":$preferred, + CArg<"uint32_t", "kOptionalSpecValue">:$index), [{ + return $_get($_ctxt, size, abi, preferred, index); + }]> + ]; + let assemblyFormat = "`<` struct(params) `>`"; + let extraClassDeclaration = [{ + /// Constant for specifying a spec entry is optional. + static constexpr uint32_t kOptionalSpecValue = std::numeric_limits::max(); + }]; + let genVerifyDecl = 1; +} + +#endif // PTR_ATTRDEFS diff --git a/mlir/include/mlir/Dialect/Ptr/IR/PtrAttrs.h b/mlir/include/mlir/Dialect/Ptr/IR/PtrAttrs.h new file mode 100644 index 0000000000000..72e767764d98b --- /dev/null +++ b/mlir/include/mlir/Dialect/Ptr/IR/PtrAttrs.h @@ -0,0 +1,21 @@ +//===- PtrAttrs.h - Pointer dialect attributes ------------------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file declares the Ptr dialect attributes. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_PTR_IR_PTRATTRS_H +#define MLIR_DIALECT_PTR_IR_PTRATTRS_H + +#include "mlir/IR/OpImplementation.h" + +#define GET_ATTRDEF_CLASSES +#include "mlir/Dialect/Ptr/IR/PtrOpsAttrs.h.inc" + +#endif // MLIR_DIALECT_PTR_IR_PTRATTRS_H diff --git a/mlir/include/mlir/Dialect/Ptr/IR/PtrDialect.td b/mlir/include/mlir/Dialect/Ptr/IR/PtrDialect.td index 315ecdfb6609e..2e5e0a14ae991 100644 --- a/mlir/include/mlir/Dialect/Ptr/IR/PtrDialect.td +++ b/mlir/include/mlir/Dialect/Ptr/IR/PtrDialect.td @@ -23,7 +23,7 @@ def Ptr_Dialect : Dialect { let summary = "Pointer dialect"; let cppNamespace = "::mlir::ptr"; let useDefaultTypePrinterParser = 1; - let useDefaultAttributePrinterParser = 0; + let useDefaultAttributePrinterParser = 1; } //===----------------------------------------------------------------------===// @@ -64,13 +64,6 @@ def Ptr_PtrType : Ptr_Type<"Ptr", "ptr", [ }]> ]; let skipDefaultBuilders = 1; - let extraClassDeclaration = [{ - /// Returns the default memory space. - Attribute getDefaultMemorySpace() const; - - /// Returns the memory space as an unsigned number. - uint64_t getAddressSpace() const; - }]; } //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.h b/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.h index ad8a2bbcbdd8d..6a0c1429c6be9 100644 --- a/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.h +++ b/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.h @@ -14,6 +14,7 @@ #define MLIR_DIALECT_PTR_IR_PTROPS_H #include "mlir/Bytecode/BytecodeOpInterface.h" +#include "mlir/Dialect/Ptr/IR/PtrAttrs.h" #include "mlir/Dialect/Ptr/IR/PtrDialect.h" #include "mlir/Dialect/Ptr/IR/PtrTypes.h" #include "mlir/IR/OpDefinition.h" diff --git a/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td b/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td index 690941337bdfb..c63a0b220e501 100644 --- a/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td +++ b/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td @@ -10,6 +10,7 @@ #define PTR_OPS include "mlir/Dialect/Ptr/IR/PtrDialect.td" +include "mlir/Dialect/Ptr/IR/PtrAttrDefs.td" include "mlir/IR/OpAsmInterface.td" #endif // PTR_OPS diff --git a/mlir/include/mlir/Dialect/Ptr/IR/PtrTypes.h b/mlir/include/mlir/Dialect/Ptr/IR/PtrTypes.h index 9984aedcbf6ce..264a97c80722a 100644 --- a/mlir/include/mlir/Dialect/Ptr/IR/PtrTypes.h +++ b/mlir/include/mlir/Dialect/Ptr/IR/PtrTypes.h @@ -16,21 +16,6 @@ #include "mlir/IR/Types.h" #include "mlir/Interfaces/DataLayoutInterfaces.h" -namespace mlir { -namespace ptr { -/// The positions of different values in the data layout entry for pointers. -enum class PtrDLEntryPos { Size = 0, Abi = 1, Preferred = 2, Index = 3 }; - -/// Returns the value that corresponds to named position `pos` from the -/// data layout entry `attr` assuming it's a dense integer elements attribute. -/// Returns `std::nullopt` if `pos` is not present in the entry. -/// Currently only `PtrDLEntryPos::Index` is optional, and all other positions -/// may be assumed to be present. -std::optional extractPointerSpecValue(Attribute attr, - PtrDLEntryPos pos); -} // namespace ptr -} // namespace mlir - #define GET_TYPEDEF_CLASSES #include "mlir/Dialect/Ptr/IR/PtrOpsTypes.h.inc" diff --git a/mlir/lib/Dialect/Ptr/IR/CMakeLists.txt b/mlir/lib/Dialect/Ptr/IR/CMakeLists.txt index 359b9f02a0626..fbdd08375f15c 100644 --- a/mlir/lib/Dialect/Ptr/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/Ptr/IR/CMakeLists.txt @@ -1,10 +1,12 @@ add_mlir_dialect_library( MLIRPtrDialect + PtrAttrs.cpp PtrTypes.cpp PtrDialect.cpp ADDITIONAL_HEADER_DIRS ${PROJECT_SOURCE_DIR}/mlir/Dialect/Pointer DEPENDS + MLIRPtrOpsAttributesIncGen MLIRPtrOpsIncGen LINK_LIBS PUBLIC diff --git a/mlir/lib/Dialect/Ptr/IR/PtrAttrs.cpp b/mlir/lib/Dialect/Ptr/IR/PtrAttrs.cpp new file mode 100644 index 0000000000000..2e5902f653cb8 --- /dev/null +++ b/mlir/lib/Dialect/Ptr/IR/PtrAttrs.cpp @@ -0,0 +1,40 @@ +//===- PtrAttrs.cpp - Pointer dialect attributes ----------------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines the Ptr dialect attributes. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Ptr/IR/PtrAttrs.h" +#include "llvm/ADT/TypeSwitch.h" + +using namespace mlir; +using namespace mlir::ptr; + +constexpr const static unsigned kBitsInByte = 8; + +//===----------------------------------------------------------------------===// +// SpecAttr +//===----------------------------------------------------------------------===// + +LogicalResult SpecAttr::verify(function_ref emitError, + uint32_t size, uint32_t abi, uint32_t preferred, + uint32_t index) { + if (size % kBitsInByte != 0) + return emitError() << "size entry must be divisible by 8"; + else if (abi % kBitsInByte != 0) + return emitError() << "abi entry must be divisible by 8"; + else if (preferred % kBitsInByte != 0) + return emitError() << "preferred entry must be divisible by 8"; + else if (index != kOptionalSpecValue && index % kBitsInByte != 0) + return emitError() << "index entry must be divisible by 8"; + if (abi > preferred) + return emitError() << "preferred alignment is expected to be at least " + "as large as ABI alignment"; + return success(); +} diff --git a/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp b/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp index 59c97b22f332c..7830ffe893dfd 100644 --- a/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp +++ b/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp @@ -29,6 +29,10 @@ void PtrDialect::initialize() { #define GET_OP_LIST #include "mlir/Dialect/Ptr/IR/PtrOps.cpp.inc" >(); + addAttributes< +#define GET_ATTRDEF_LIST +#include "mlir/Dialect/Ptr/IR/PtrOpsAttrs.cpp.inc" + >(); addTypes< #define GET_TYPEDEF_LIST #include "mlir/Dialect/Ptr/IR/PtrOpsTypes.cpp.inc" @@ -41,6 +45,9 @@ void PtrDialect::initialize() { #include "mlir/Dialect/Ptr/IR/PtrOpsDialect.cpp.inc" +#define GET_ATTRDEF_CLASSES +#include "mlir/Dialect/Ptr/IR/PtrOpsAttrs.cpp.inc" + #define GET_TYPEDEF_CLASSES #include "mlir/Dialect/Ptr/IR/PtrOpsTypes.cpp.inc" diff --git a/mlir/lib/Dialect/Ptr/IR/PtrTypes.cpp b/mlir/lib/Dialect/Ptr/IR/PtrTypes.cpp index 95c72f5743afc..2866d4eb10feb 100644 --- a/mlir/lib/Dialect/Ptr/IR/PtrTypes.cpp +++ b/mlir/lib/Dialect/Ptr/IR/PtrTypes.cpp @@ -11,6 +11,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/Ptr/IR/PtrTypes.h" +#include "mlir/Dialect/Ptr/IR/PtrAttrs.h" #include "llvm/ADT/TypeSwitch.h" using namespace mlir; @@ -24,54 +25,36 @@ constexpr const static unsigned kDefaultPointerSizeBits = 64; constexpr const static unsigned kBitsInByte = 8; constexpr const static unsigned kDefaultPointerAlignment = 8; -/// Returns the part of the data layout entry that corresponds to `pos` for the -/// given `type` by interpreting the list of entries `params`. For the pointer -/// type in the default address space, returns the default value if the entries -/// do not provide a custom one, for other address spaces returns std::nullopt. -static std::optional -getPointerDataLayoutEntry(DataLayoutEntryListRef params, PtrType type, - PtrDLEntryPos pos) { - // First, look for the entry for the pointer in the current address space. - Attribute currentEntry; +static Attribute getDefaultMemorySpace(PtrType ptr) { return nullptr; } + +/// Searches the data layout for the pointer spec, returns nullptr if it is not +/// found. +static SpecAttr getPointerSpec(DataLayoutEntryListRef params, PtrType type) { for (DataLayoutEntryInterface entry : params) { if (!entry.isTypeEntry()) continue; if (cast(entry.getKey().get()).getMemorySpace() == type.getMemorySpace()) { - currentEntry = entry.getValue(); - break; + if (auto spec = dyn_cast(entry.getValue())) + return spec; } } - bool isSizeOrIndex = - pos == PtrDLEntryPos::Size || pos == PtrDLEntryPos::Index; - if (currentEntry) { - std::optional value = extractPointerSpecValue(currentEntry, pos); - // If the optional `PtrDLEntryPos::Index` entry is not available, use the - // pointer size as the index bitwidth. - if (!value && pos == PtrDLEntryPos::Index) - value = extractPointerSpecValue(currentEntry, PtrDLEntryPos::Size); - return *value / (isSizeOrIndex ? 1 : kBitsInByte); - } - // If not found, and this is the pointer to the default memory space, assume // 64-bit pointers. - if (type.getMemorySpace() == type.getDefaultMemorySpace()) - return isSizeOrIndex ? kDefaultPointerSizeBits : kDefaultPointerAlignment; - - return std::nullopt; + if (type.getMemorySpace() == getDefaultMemorySpace(type)) + return SpecAttr::get(type.getContext(), kDefaultPointerSizeBits, + kDefaultPointerAlignment, kDefaultPointerAlignment, + kDefaultPointerSizeBits); + return nullptr; } -uint64_t PtrType::getAddressSpace() const { return 0; } - -Attribute PtrType::getDefaultMemorySpace() const { return nullptr; } - bool PtrType::areCompatible(DataLayoutEntryListRef oldLayout, DataLayoutEntryListRef newLayout) const { for (DataLayoutEntryInterface newEntry : newLayout) { if (!newEntry.isTypeEntry()) continue; - unsigned size = kDefaultPointerSizeBits; - unsigned abi = kDefaultPointerAlignment; + uint32_t size = kDefaultPointerSizeBits; + uint32_t abi = kDefaultPointerAlignment; auto newType = llvm::cast(newEntry.getKey().get()); const auto *it = llvm::find_if(oldLayout, [&](DataLayoutEntryInterface entry) { @@ -85,19 +68,20 @@ bool PtrType::areCompatible(DataLayoutEntryListRef oldLayout, it = llvm::find_if(oldLayout, [&](DataLayoutEntryInterface entry) { if (auto type = llvm::dyn_cast_if_present(entry.getKey())) { auto ptrTy = llvm::cast(type); - return ptrTy.getMemorySpace() == ptrTy.getDefaultMemorySpace(); + return ptrTy.getMemorySpace() == getDefaultMemorySpace(ptrTy); } return false; }); } if (it != oldLayout.end()) { - size = *extractPointerSpecValue(*it, PtrDLEntryPos::Size); - abi = *extractPointerSpecValue(*it, PtrDLEntryPos::Abi); + auto spec = llvm::cast(*it); + size = spec.getSize(); + abi = spec.getAbi(); } - Attribute newSpec = llvm::cast(newEntry.getValue()); - unsigned newSize = *extractPointerSpecValue(newSpec, PtrDLEntryPos::Size); - unsigned newAbi = *extractPointerSpecValue(newSpec, PtrDLEntryPos::Abi); + auto newSpec = llvm::cast(newEntry.getValue()); + uint32_t newSize = newSpec.getSize(); + uint32_t newAbi = newSpec.getAbi(); if (size != newSize || abi < newAbi || abi % newAbi != 0) return false; } @@ -106,54 +90,43 @@ bool PtrType::areCompatible(DataLayoutEntryListRef oldLayout, uint64_t PtrType::getABIAlignment(const DataLayout &dataLayout, DataLayoutEntryListRef params) const { - if (std::optional alignment = - getPointerDataLayoutEntry(params, *this, PtrDLEntryPos::Abi)) - return *alignment; + if (SpecAttr spec = getPointerSpec(params, *this)) + return spec.getAbi() / kBitsInByte; return dataLayout.getTypeABIAlignment( - get(getContext(), getDefaultMemorySpace())); + get(getContext(), getDefaultMemorySpace(*this))); } std::optional PtrType::getIndexBitwidth(const DataLayout &dataLayout, DataLayoutEntryListRef params) const { - if (std::optional indexBitwidth = - getPointerDataLayoutEntry(params, *this, PtrDLEntryPos::Index)) - return *indexBitwidth; + if (SpecAttr spec = getPointerSpec(params, *this)) { + return spec.getIndex() == SpecAttr::kOptionalSpecValue ? spec.getSize() + : spec.getIndex(); + } return dataLayout.getTypeIndexBitwidth( - get(getContext(), getDefaultMemorySpace())); + get(getContext(), getDefaultMemorySpace(*this))); } llvm::TypeSize PtrType::getTypeSizeInBits(const DataLayout &dataLayout, DataLayoutEntryListRef params) const { - if (std::optional size = - getPointerDataLayoutEntry(params, *this, PtrDLEntryPos::Size)) - return llvm::TypeSize::getFixed(*size); + if (SpecAttr spec = getPointerSpec(params, *this)) + return llvm::TypeSize::getFixed(spec.getSize()); // For other memory spaces, use the size of the pointer to the default memory // space. return dataLayout.getTypeSizeInBits( - get(getContext(), getDefaultMemorySpace())); + get(getContext(), getDefaultMemorySpace(*this))); } uint64_t PtrType::getPreferredAlignment(const DataLayout &dataLayout, DataLayoutEntryListRef params) const { - if (std::optional alignment = - getPointerDataLayoutEntry(params, *this, PtrDLEntryPos::Preferred)) - return *alignment; + if (SpecAttr spec = getPointerSpec(params, *this)) + return spec.getPreferred() / kBitsInByte; return dataLayout.getTypePreferredAlignment( - get(getContext(), getDefaultMemorySpace())); -} - -std::optional mlir::ptr::extractPointerSpecValue(Attribute attr, - PtrDLEntryPos pos) { - auto spec = cast(attr); - auto idx = static_cast(pos); - if (idx >= spec.size()) - return std::nullopt; - return spec.getValues()[idx]; + get(getContext(), getDefaultMemorySpace(*this))); } LogicalResult PtrType::verifyEntries(DataLayoutEntryListRef entries, @@ -162,20 +135,9 @@ LogicalResult PtrType::verifyEntries(DataLayoutEntryListRef entries, if (!entry.isTypeEntry()) continue; auto key = entry.getKey().get(); - auto values = llvm::dyn_cast(entry.getValue()); - if (!values || (values.size() != 3 && values.size() != 4)) { - return emitError(loc) - << "expected layout attribute for " << key - << " to be a dense integer elements attribute with 3 or 4 " - "elements"; - } - if (!values.getElementType().isInteger(64)) - return emitError(loc) << "expected i64 parameters for " << key; - - if (extractPointerSpecValue(values, PtrDLEntryPos::Abi) > - extractPointerSpecValue(values, PtrDLEntryPos::Preferred)) { - return emitError(loc) << "preferred alignment is expected to be at least " - "as large as ABI alignment"; + if (!llvm::isa(entry.getValue())) { + return emitError(loc) << "expected layout attribute for " << key + << " to be a #ptr.spec attribute"; } } return success(); diff --git a/mlir/test/Dialect/Ptr/layout.mlir b/mlir/test/Dialect/Ptr/layout.mlir index b345fbd6f6fbb..73189a388942a 100644 --- a/mlir/test/Dialect/Ptr/layout.mlir +++ b/mlir/test/Dialect/Ptr/layout.mlir @@ -1,9 +1,9 @@ // RUN: mlir-opt --test-data-layout-query --split-input-file --verify-diagnostics %s | FileCheck %s module attributes { dlti.dl_spec = #dlti.dl_spec< - #dlti.dl_entry : vector<3xi64>>, - #dlti.dl_entry, dense<[64, 64, 64]> : vector<3xi64>>, - #dlti.dl_entry, dense<[32, 64, 64, 24]> : vector<4xi64>>, + #dlti.dl_entry>, + #dlti.dl_entry,#ptr.spec>, + #dlti.dl_entry, #ptr.spec>, #dlti.dl_entry<"dlti.alloca_memory_space", 5 : ui64>, #dlti.dl_entry<"dlti.global_memory_space", 2 : ui64>, #dlti.dl_entry<"dlti.program_memory_space", 3 : ui64>, @@ -57,9 +57,9 @@ module attributes { dlti.dl_spec = #dlti.dl_spec< // ----- -// expected-error@below {{expected layout attribute for '!ptr.ptr' to be a dense integer elements attribute with 3 or 4 elements}} +// expected-error@+2 {{preferred alignment is expected to be at least as large as ABI alignment}} module attributes { dlti.dl_spec = #dlti.dl_spec< - #dlti.dl_entry : vector<3xf32>> + #dlti.dl_entry> >} { func.func @pointer() { return @@ -68,20 +68,47 @@ module attributes { dlti.dl_spec = #dlti.dl_spec< // ----- -// expected-error@below {{preferred alignment is expected to be at least as large as ABI alignment}} +// expected-error@+2 {{size entry must be divisible by 8}} module attributes { dlti.dl_spec = #dlti.dl_spec< - #dlti.dl_entry : vector<3xi64>> + #dlti.dl_entry> >} { func.func @pointer() { return } } + // ----- -// expected-error @below {{expected i64 parameters for '!ptr.ptr'}} +// expected-error@+2 {{abi entry must be divisible by 8}} module attributes { dlti.dl_spec = #dlti.dl_spec< - #dlti.dl_entry : vector<3xi32>> + #dlti.dl_entry> >} { + func.func @pointer() { + return + } } + +// ----- + +// expected-error@+2 {{preferred entry must be divisible by 8}} +module attributes { dlti.dl_spec = #dlti.dl_spec< + #dlti.dl_entry> +>} { + func.func @pointer() { + return + } +} + + +// ----- + +// expected-error@+2 {{index entry must be divisible by 8}} +module attributes { dlti.dl_spec = #dlti.dl_spec< + #dlti.dl_entry> +>} { + func.func @pointer() { + return + } +} From 4a42c6817e5620f6e4e418c3ff2f23dec24ea3f3 Mon Sep 17 00:00:00 2001 From: Fabian Mora Date: Tue, 25 Jun 2024 00:48:37 +0000 Subject: [PATCH 4/5] address reviewer comments --- mlir/include/mlir/Dialect/Ptr/IR/PtrDialect.td | 6 +----- mlir/lib/Dialect/Ptr/IR/CMakeLists.txt | 4 ++-- mlir/lib/Dialect/Ptr/IR/PtrAttrs.cpp | 6 +++--- 3 files changed, 6 insertions(+), 10 deletions(-) diff --git a/mlir/include/mlir/Dialect/Ptr/IR/PtrDialect.td b/mlir/include/mlir/Dialect/Ptr/IR/PtrDialect.td index 2e5e0a14ae991..b8f6baa929641 100644 --- a/mlir/include/mlir/Dialect/Ptr/IR/PtrDialect.td +++ b/mlir/include/mlir/Dialect/Ptr/IR/PtrDialect.td @@ -43,7 +43,7 @@ def Ptr_PtrType : Ptr_Type<"Ptr", "ptr", [ let summary = "pointer type"; let description = [{ The `ptr` type is an opaque pointer type. This type typically represents - a reference to an object in memory. Pointers are optionally parameterized + a handle to an object in memory. Pointers are optionally parameterized by a memory space. Syntax: @@ -57,10 +57,6 @@ def Ptr_PtrType : Ptr_Type<"Ptr", "ptr", [ let builders = [ TypeBuilder<(ins CArg<"Attribute", "nullptr">:$memorySpace), [{ return $_get($_ctxt, memorySpace); - }]>, - TypeBuilder<(ins CArg<"unsigned">:$addressSpace), [{ - return $_get($_ctxt, IntegerAttr::get(IntegerType::get($_ctxt, 32), - addressSpace)); }]> ]; let skipDefaultBuilders = 1; diff --git a/mlir/lib/Dialect/Ptr/IR/CMakeLists.txt b/mlir/lib/Dialect/Ptr/IR/CMakeLists.txt index fbdd08375f15c..9cf3643c73d3e 100644 --- a/mlir/lib/Dialect/Ptr/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/Ptr/IR/CMakeLists.txt @@ -3,11 +3,11 @@ add_mlir_dialect_library( PtrAttrs.cpp PtrTypes.cpp PtrDialect.cpp - ADDITIONAL_HEADER_DIRS - ${PROJECT_SOURCE_DIR}/mlir/Dialect/Pointer + DEPENDS MLIRPtrOpsAttributesIncGen MLIRPtrOpsIncGen + LINK_LIBS PUBLIC MLIRIR diff --git a/mlir/lib/Dialect/Ptr/IR/PtrAttrs.cpp b/mlir/lib/Dialect/Ptr/IR/PtrAttrs.cpp index 2e5902f653cb8..f8ce820d0bcbd 100644 --- a/mlir/lib/Dialect/Ptr/IR/PtrAttrs.cpp +++ b/mlir/lib/Dialect/Ptr/IR/PtrAttrs.cpp @@ -27,11 +27,11 @@ LogicalResult SpecAttr::verify(function_ref emitError, uint32_t index) { if (size % kBitsInByte != 0) return emitError() << "size entry must be divisible by 8"; - else if (abi % kBitsInByte != 0) + if (abi % kBitsInByte != 0) return emitError() << "abi entry must be divisible by 8"; - else if (preferred % kBitsInByte != 0) + if (preferred % kBitsInByte != 0) return emitError() << "preferred entry must be divisible by 8"; - else if (index != kOptionalSpecValue && index % kBitsInByte != 0) + if (index != kOptionalSpecValue && index % kBitsInByte != 0) return emitError() << "index entry must be divisible by 8"; if (abi > preferred) return emitError() << "preferred alignment is expected to be at least " From 57f4945cd44240b3ad25734bcab9fff4e7c161cb Mon Sep 17 00:00:00 2001 From: Fabian Mora Date: Wed, 26 Jun 2024 23:58:54 +0000 Subject: [PATCH 5/5] update PtrType description with mention of nullptr --- mlir/include/mlir/Dialect/Ptr/IR/PtrDialect.td | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/mlir/include/mlir/Dialect/Ptr/IR/PtrDialect.td b/mlir/include/mlir/Dialect/Ptr/IR/PtrDialect.td index b8f6baa929641..14d72c3001d91 100644 --- a/mlir/include/mlir/Dialect/Ptr/IR/PtrDialect.td +++ b/mlir/include/mlir/Dialect/Ptr/IR/PtrDialect.td @@ -42,9 +42,10 @@ def Ptr_PtrType : Ptr_Type<"Ptr", "ptr", [ ]> { let summary = "pointer type"; let description = [{ - The `ptr` type is an opaque pointer type. This type typically represents - a handle to an object in memory. Pointers are optionally parameterized - by a memory space. + The `ptr` type is an opaque pointer type. This type typically represents a + handle to an object in memory or target-dependent values like `nullptr`. + Pointers are optionally parameterized by a memory space. + Syntax: ```mlir