Commit 3e49b01

Merge pull request #1649 from frasercrmck/hip-multimap
[HIP] Add support for multiple active mappings
2 parents 396fb20 + 07ddcbf commit 3e49b01

3 files changed: +122 -166 lines changed

source/adapters/hip/enqueue.cpp

Lines changed: 50 additions & 59 deletions
```diff
@@ -15,6 +15,7 @@
 #include "kernel.hpp"
 #include "memory.hpp"
 #include "queue.hpp"
+#include "ur_api.h"

 #include <ur/ur.hpp>

@@ -1239,49 +1240,42 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferMap(
   UR_ASSERT(offset + size <= BufferImpl.getSize(),
             UR_RESULT_ERROR_INVALID_SIZE);

-  ur_result_t Result = UR_RESULT_ERROR_INVALID_OPERATION;
-  const bool IsPinned =
-      BufferImpl.MemAllocMode == BufferMem::AllocMode::AllocHostPtr;
-
-  // Currently no support for overlapping regions
-  if (BufferImpl.getMapPtr() != nullptr) {
-    return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
+  auto MapPtr = BufferImpl.mapToPtr(size, offset, mapFlags);
+  if (!MapPtr) {
+    return UR_RESULT_ERROR_INVALID_MEM_OBJECT;
   }

-  // Allocate a pointer in the host to store the mapped information
-  auto HostPtr = BufferImpl.mapToPtr(size, offset, mapFlags);
-  *ppRetMap = std::get<BufferMem>(hBuffer->Mem).getMapPtr();
-  if (HostPtr) {
-    Result = UR_RESULT_SUCCESS;
-  }
+  const bool IsPinned =
+      BufferImpl.MemAllocMode == BufferMem::AllocMode::AllocHostPtr;

-  if (!IsPinned &&
-      ((mapFlags & UR_MAP_FLAG_READ) || (mapFlags & UR_MAP_FLAG_WRITE))) {
-    // Pinned host memory is already on host so it doesn't need to be read.
-    Result = urEnqueueMemBufferRead(hQueue, hBuffer, blockingMap, offset, size,
-                                    HostPtr, numEventsInWaitList,
-                                    phEventWaitList, phEvent);
-  } else {
-    ScopedContext Active(hQueue->getDevice());
+  try {
+    if (!IsPinned && (mapFlags & (UR_MAP_FLAG_READ | UR_MAP_FLAG_WRITE))) {
+      // Pinned host memory is already on host so it doesn't need to be read.
+      UR_CHECK_ERROR(urEnqueueMemBufferRead(
+          hQueue, hBuffer, blockingMap, offset, size, MapPtr,
+          numEventsInWaitList, phEventWaitList, phEvent));
+    } else {
+      ScopedContext Active(hQueue->getDevice());

-    if (IsPinned) {
-      Result = urEnqueueEventsWait(hQueue, numEventsInWaitList, phEventWaitList,
-                                   nullptr);
-    }
+      if (IsPinned) {
+        UR_CHECK_ERROR(urEnqueueEventsWait(hQueue, numEventsInWaitList,
+                                           phEventWaitList, nullptr));
+      }

-    if (phEvent) {
-      try {
+      if (phEvent) {
         *phEvent = ur_event_handle_t_::makeNative(
             UR_COMMAND_MEM_BUFFER_MAP, hQueue, hQueue->getNextTransferStream());
         UR_CHECK_ERROR((*phEvent)->start());
         UR_CHECK_ERROR((*phEvent)->record());
-      } catch (ur_result_t Error) {
-        Result = Error;
       }
     }
+  } catch (ur_result_t Error) {
+    return Error;
   }

-  return Result;
+  *ppRetMap = MapPtr;
+
+  return UR_RESULT_SUCCESS;
 }

 /// Implements the unmap from the host, using a BufferWrite operation.
@@ -1292,47 +1286,44 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemUnmap(
     ur_queue_handle_t hQueue, ur_mem_handle_t hMem, void *pMappedPtr,
     uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList,
     ur_event_handle_t *phEvent) {
-  ur_result_t Result = UR_RESULT_SUCCESS;
   UR_ASSERT(hMem->isBuffer(), UR_RESULT_ERROR_INVALID_MEM_OBJECT);
-  UR_ASSERT(std::get<BufferMem>(hMem->Mem).getMapPtr() != nullptr,
-            UR_RESULT_ERROR_INVALID_MEM_OBJECT);
-  UR_ASSERT(std::get<BufferMem>(hMem->Mem).getMapPtr() == pMappedPtr,
-            UR_RESULT_ERROR_INVALID_MEM_OBJECT);
+  auto &BufferImpl = std::get<BufferMem>(hMem->Mem);

-  const bool IsPinned = std::get<BufferMem>(hMem->Mem).MemAllocMode ==
-                        BufferMem::AllocMode::AllocHostPtr;
-
-  if (!IsPinned &&
-      ((std::get<BufferMem>(hMem->Mem).getMapFlags() & UR_MAP_FLAG_WRITE) ||
-       (std::get<BufferMem>(hMem->Mem).getMapFlags() &
-        UR_MAP_FLAG_WRITE_INVALIDATE_REGION))) {
-    // Pinned host memory is only on host so it doesn't need to be written to.
-    Result = urEnqueueMemBufferWrite(
-        hQueue, hMem, true, std::get<BufferMem>(hMem->Mem).getMapOffset(),
-        std::get<BufferMem>(hMem->Mem).getMapSize(), pMappedPtr,
-        numEventsInWaitList, phEventWaitList, phEvent);
-  } else {
-    ScopedContext Active(hQueue->getDevice());
+  auto *Map = BufferImpl.getMapDetails(pMappedPtr);
+  UR_ASSERT(Map != nullptr, UR_RESULT_ERROR_INVALID_MEM_OBJECT);

-    if (IsPinned) {
-      Result = urEnqueueEventsWait(hQueue, numEventsInWaitList, phEventWaitList,
-                                   nullptr);
-    }
+  const bool IsPinned =
+      BufferImpl.MemAllocMode == BufferMem::AllocMode::AllocHostPtr;

-    if (phEvent) {
-      try {
+  try {
+    if (!IsPinned &&
+        (Map->getMapFlags() &
+         (UR_MAP_FLAG_WRITE | UR_MAP_FLAG_WRITE_INVALIDATE_REGION))) {
+      // Pinned host memory is only on host so it doesn't need to be written to.
+      UR_CHECK_ERROR(urEnqueueMemBufferWrite(
+          hQueue, hMem, true, Map->getMapOffset(), Map->getMapSize(),
+          pMappedPtr, numEventsInWaitList, phEventWaitList, phEvent));
+    } else {
+      ScopedContext Active(hQueue->getDevice());
+
+      if (IsPinned) {
+        UR_CHECK_ERROR(urEnqueueEventsWait(hQueue, numEventsInWaitList,
+                                           phEventWaitList, nullptr));
+      }
+
+      if (phEvent) {
         *phEvent = ur_event_handle_t_::makeNative(
             UR_COMMAND_MEM_UNMAP, hQueue, hQueue->getNextTransferStream());
         UR_CHECK_ERROR((*phEvent)->start());
         UR_CHECK_ERROR((*phEvent)->record());
-      } catch (ur_result_t Error) {
-        Result = Error;
       }
     }
+  } catch (ur_result_t Error) {
+    return Error;
   }

-  std::get<BufferMem>(hMem->Mem).unmap(pMappedPtr);
-  return Result;
+  BufferImpl.unmap(pMappedPtr);
+  return UR_RESULT_SUCCESS;
 }

 UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMFill(
```
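The reworked `urEnqueueMemBufferMap` asks the buffer for a fresh mapping on every call instead of rejecting a second map, and `urEnqueueMemUnmap` looks the mapping up by the pointer the caller passes back. Below is a minimal usage sketch of what this enables, assuming `hQueue` and an `hBuffer` of at least 1024 bytes were created elsewhere; return codes are left unchecked for brevity, and the argument order follows the `urEnqueueMemBufferMap`/`urEnqueueMemUnmap` declarations in `ur_api.h`.

```cpp
// Sketch: two disjoint regions of one buffer mapped at the same time.
// Before this change the second map call returned
// UR_RESULT_ERROR_UNSUPPORTED_FEATURE; now each call gets its own entry.
#include <ur_api.h>

void mapTwoRegions(ur_queue_handle_t hQueue, ur_mem_handle_t hBuffer) {
  void *Lo = nullptr;
  void *Hi = nullptr;

  // Map bytes [0, 512) for read/write and bytes [512, 1024) for write.
  urEnqueueMemBufferMap(hQueue, hBuffer, /*blockingMap=*/true,
                        UR_MAP_FLAG_READ | UR_MAP_FLAG_WRITE,
                        /*offset=*/0, /*size=*/512,
                        0, nullptr, nullptr, &Lo);
  urEnqueueMemBufferMap(hQueue, hBuffer, /*blockingMap=*/true,
                        UR_MAP_FLAG_WRITE,
                        /*offset=*/512, /*size=*/512,
                        0, nullptr, nullptr, &Hi);

  static_cast<unsigned char *>(Hi)[0] = 42; // write through the second mapping

  // Unmapping is keyed on the pointer returned for each region.
  urEnqueueMemUnmap(hQueue, hBuffer, Hi, 0, nullptr, nullptr);
  urEnqueueMemUnmap(hQueue, hBuffer, Lo, 0, nullptr, nullptr);
}
```

Each returned pointer identifies exactly one `BufferMap` entry in the adapter, so unmapping one region leaves the other mapping intact.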

source/adapters/hip/memory.hpp

Lines changed: 69 additions & 56 deletions
```diff
@@ -9,53 +9,72 @@
 //===----------------------------------------------------------------------===//
 #pragma once

-#include "common.hpp"
 #include "context.hpp"
 #include "event.hpp"
 #include <cassert>
+#include <memory>
+#include <unordered_map>
 #include <variant>

+#include "common.hpp"
+
 ur_result_t allocateMemObjOnDeviceIfNeeded(ur_mem_handle_t,
                                            const ur_device_handle_t);
 ur_result_t migrateMemoryToDeviceIfNeeded(ur_mem_handle_t,
                                           const ur_device_handle_t);

 // Handler for plain, pointer-based HIP allocations
 struct BufferMem {
+  struct BufferMap {
+    /// Size of the active mapped region.
+    size_t MapSize;
+    /// Offset of the active mapped region.
+    size_t MapOffset;
+    /// Original flags for the mapped region
+    ur_map_flags_t MapFlags;
+    /// Allocated host memory used exclusively for this map.
+    std::shared_ptr<unsigned char[]> MapMem;
+
+    BufferMap(size_t MapSize, size_t MapOffset, ur_map_flags_t MapFlags)
+        : MapSize(MapSize), MapOffset(MapOffset), MapFlags(MapFlags),
+          MapMem(nullptr) {}
+
+    BufferMap(size_t MapSize, size_t MapOffset, ur_map_flags_t MapFlags,
+              std::unique_ptr<unsigned char[]> &&MapMem)
+        : MapSize(MapSize), MapOffset(MapOffset), MapFlags(MapFlags),
+          MapMem(std::move(MapMem)) {}
+
+    size_t getMapSize() const noexcept { return MapSize; }
+
+    size_t getMapOffset() const noexcept { return MapOffset; }
+
+    ur_map_flags_t getMapFlags() const noexcept { return MapFlags; }
+  };
+
+  /** AllocMode
+   * Classic: Just a normal buffer allocated on the device via hip malloc
+   * UseHostPtr: Use an address on the host for the device
+   * CopyIn: The data for the device comes from the host but the host
+   pointer is not available later for re-use
+   * AllocHostPtr: Uses pinned-memory allocation
+   */
+  enum class AllocMode { Classic, UseHostPtr, CopyIn, AllocHostPtr };
+
   using native_type = hipDeviceptr_t;

   // If this allocation is a sub-buffer (i.e., a view on an existing
   // allocation), this is the pointer to the parent handler structure
   ur_mem_handle_t Parent = nullptr;
   // Outer mem holding this struct in variant
   ur_mem_handle_t OuterMemStruct;
-
   /// Pointer associated with this device on the host
   void *HostPtr;
   /// Size of the allocation in bytes
   size_t Size;
-  /// Size of the active mapped region.
-  size_t MapSize;
-  /// Offset of the active mapped region.
-  size_t MapOffset;
-  /// Pointer to the active mapped region, if any
-  void *MapPtr;
-  /// Original flags for the mapped region
-  ur_map_flags_t MapFlags;
+  /// A map that contains all the active mappings for this buffer.
+  std::unordered_map<void *, BufferMap> PtrToBufferMap;

-  /** AllocMode
-   * Classic: Just a normal buffer allocated on the device via hip malloc
-   * UseHostPtr: Use an address on the host for the device
-   * CopyIn: The data for the device comes from the host but the host
-   pointer is not available later for re-use
-   * AllocHostPtr: Uses pinned-memory allocation
-   */
-  enum class AllocMode {
-    Classic,
-    UseHostPtr,
-    CopyIn,
-    AllocHostPtr
-  } MemAllocMode;
+  AllocMode MemAllocMode;

 private:
   // Vector of HIP pointers
@@ -65,10 +84,8 @@ struct BufferMem {
   BufferMem(ur_context_handle_t Context, ur_mem_handle_t OuterMemStruct,
             AllocMode Mode, void *HostPtr, size_t Size)
       : OuterMemStruct{OuterMemStruct}, HostPtr{HostPtr}, Size{Size},
-        MapSize{0}, MapOffset{0}, MapPtr{nullptr}, MapFlags{UR_MAP_FLAG_WRITE},
-        MemAllocMode{Mode}, Ptrs(Context->Devices.size(), native_type{0}){};
-
-  BufferMem(const BufferMem &Buffer) = default;
+        PtrToBufferMap{}, MemAllocMode{Mode},
+        Ptrs(Context->Devices.size(), native_type{0}){};

   // This will allocate memory on device if there isn't already an active
   // allocation on the device
@@ -98,45 +115,41 @@ struct BufferMem {

   size_t getSize() const noexcept { return Size; }

-  void *getMapPtr() const noexcept { return MapPtr; }
-
-  size_t getMapSize() const noexcept { return MapSize; }
-
-  size_t getMapOffset() const noexcept { return MapOffset; }
+  BufferMap *getMapDetails(void *Map) {
+    auto details = PtrToBufferMap.find(Map);
+    if (details != PtrToBufferMap.end()) {
+      return &details->second;
+    }
+    return nullptr;
+  }

   /// Returns a pointer to data visible on the host that contains
   /// the data on the device associated with this allocation.
   /// The offset is used to index into the HIP allocation.
   ///
-  void *mapToPtr(size_t Size, size_t Offset, ur_map_flags_t Flags) noexcept {
-    assert(MapPtr == nullptr);
-    MapSize = Size;
-    MapOffset = Offset;
-    MapFlags = Flags;
-    if (HostPtr) {
-      MapPtr = static_cast<char *>(HostPtr) + Offset;
+  void *mapToPtr(size_t MapSize, size_t MapOffset,
+                 ur_map_flags_t MapFlags) noexcept {
+    void *MapPtr = nullptr;
+    if (HostPtr == nullptr) {
+      /// If HostPtr is invalid, we need to create a Mapping that owns its own
+      /// memory on the host.
+      auto MapMem = std::make_unique<unsigned char[]>(MapSize);
+      MapPtr = MapMem.get();
+      PtrToBufferMap.insert(
+          {MapPtr, BufferMap(MapSize, MapOffset, MapFlags, std::move(MapMem))});
     } else {
-      // TODO: Allocate only what is needed based on the offset
-      MapPtr = static_cast<void *>(malloc(this->getSize()));
+      /// However, if HostPtr already has valid memory (e.g. pinned allocation),
+      /// we can just use that memory for the mapping.
+      MapPtr = static_cast<char *>(HostPtr) + MapOffset;
+      PtrToBufferMap.insert({MapPtr, BufferMap(MapSize, MapOffset, MapFlags)});
     }
     return MapPtr;
   }

   /// Detach the allocation from the host memory.
-  void unmap(void *) noexcept {
+  void unmap(void *MapPtr) noexcept {
     assert(MapPtr != nullptr);
-
-    if (MapPtr != HostPtr) {
-      free(MapPtr);
-    }
-    MapPtr = nullptr;
-    MapSize = 0;
-    MapOffset = 0;
-  }
-
-  ur_map_flags_t getMapFlags() const noexcept {
-    assert(MapPtr != nullptr);
-    return MapFlags;
+    PtrToBufferMap.erase(MapPtr);
   }

   ur_result_t clear() {
@@ -414,7 +427,7 @@ struct ur_mem_handle_t_ {
         HaveMigratedToDeviceSinceLastWrite(Context->Devices.size(), false),
         Mem{std::in_place_type<BufferMem>, Ctxt, this, Mode, HostPtr, Size} {
     urContextRetain(Context);
-  };
+  }

   // Subbuffer constructor
   ur_mem_handle_t_(ur_mem_handle_t Parent, size_t SubBufferOffset)
@@ -435,7 +448,7 @@ struct ur_mem_handle_t_ {
      }
    }
    urMemRetain(Parent);
-  };
+  }

   /// Constructs the UR mem handler for an Image object
   ur_mem_handle_t_(ur_context_handle_t Ctxt, ur_mem_flags_t MemFlags,
```
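Stripped of the HIP and UR plumbing, the bookkeeping that the new `PtrToBufferMap` member introduces is an `unordered_map` keyed by the host pointer handed back to the caller, whose entries own any staging allocation they needed. The sketch below is illustrative only: the names `Mapping` and `MapTable` are not from the adapter, and it requires C++17 for `std::shared_ptr<unsigned char[]>`.

```cpp
#include <cstddef>
#include <iostream>
#include <memory>
#include <unordered_map>

// One active mapping: its extent plus the staging memory it owns, if any.
struct Mapping {
  size_t Size;
  size_t Offset;
  std::shared_ptr<unsigned char[]> Staging; // null when an existing host pointer is reused
};

int main() {
  std::unordered_map<void *, Mapping> MapTable;

  // First mapping: no host pointer to reuse, so allocate staging memory.
  auto Mem1 = std::make_unique<unsigned char[]>(512);
  void *P1 = Mem1.get();
  MapTable.emplace(P1, Mapping{512, 0, std::move(Mem1)});

  // A second mapping of a different region can coexist with the first.
  auto Mem2 = std::make_unique<unsigned char[]>(512);
  void *P2 = Mem2.get();
  MapTable.emplace(P2, Mapping{512, 512, std::move(Mem2)});

  std::cout << "active mappings: " << MapTable.size() << '\n'; // 2

  // Unmapping is just an erase keyed by the pointer; the staging memory
  // owned by the erased entry is released automatically.
  MapTable.erase(P1);
  MapTable.erase(P2);
  std::cout << "active mappings: " << MapTable.size() << '\n'; // 0
}
```

Keying entries on the returned pointer is what lets `unmap` drop exactly one region and makes `getMapDetails` a simple `find`.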
