Skip to content

Commit c0c607c

Browse files
Merge pull request #1667 from nrspruit/fix_multi_device_event_cache
[UR] Fix Multi Device Event Cache for shared Root Device
2 parents e18c691 + 0f2d1f4 commit c0c607c

File tree

3 files changed

+135
-4
lines changed

3 files changed

+135
-4
lines changed

source/adapters/level_zero/event.cpp

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1500,7 +1500,20 @@ ur_result_t _ur_ze_event_list_t::createAndRetainUrZeEventList(
15001500

15011501
std::shared_lock<ur_shared_mutex> Lock(EventList[I]->Mutex);
15021502

1503-
if (Queue && Queue->Device != CurQueueDevice &&
1503+
ur_device_handle_t QueueRootDevice;
1504+
ur_device_handle_t CurrentQueueRootDevice;
1505+
if (Queue) {
1506+
QueueRootDevice = Queue->Device;
1507+
CurrentQueueRootDevice = CurQueueDevice;
1508+
if (Queue->Device->isSubDevice()) {
1509+
QueueRootDevice = Queue->Device->RootDevice;
1510+
}
1511+
if (CurQueueDevice->isSubDevice()) {
1512+
CurrentQueueRootDevice = CurQueueDevice->RootDevice;
1513+
}
1514+
}
1515+
1516+
if (Queue && QueueRootDevice != CurrentQueueRootDevice &&
15041517
!EventList[I]->IsMultiDevice) {
15051518
ze_event_handle_t MultiDeviceZeEvent = nullptr;
15061519
ur_event_handle_t MultiDeviceEvent;
@@ -1514,10 +1527,10 @@ ur_result_t _ur_ze_event_list_t::createAndRetainUrZeEventList(
15141527
const auto &ZeCommandList = CommandList->first;
15151528
EventList[I]->RefCount.increment();
15161529

1517-
zeCommandListAppendWaitOnEvents(ZeCommandList, 1u,
1518-
&EventList[I]->ZeEvent);
1530+
ZE2UR_CALL(zeCommandListAppendWaitOnEvents,
1531+
(ZeCommandList, 1u, &EventList[I]->ZeEvent));
15191532
if (!MultiDeviceEvent->CounterBasedEventsEnabled)
1520-
zeEventHostSignal(MultiDeviceZeEvent);
1533+
ZE2UR_CALL(zeEventHostSignal, (MultiDeviceZeEvent));
15211534

15221535
UR_CALL(Queue->executeCommandList(CommandList, /* IsBlocking */ false,
15231536
/* OkToBatchCommand */ true));

test/adapters/level_zero/CMakeLists.txt

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,4 +56,15 @@ if(NOT WIN32)
5656
)
5757

5858
target_link_libraries(test-adapter-level_zero_ze_calls PRIVATE zeCallMap)
59+
60+
add_adapter_test(level_zero_multi_queue
61+
FIXTURE DEVICES
62+
SOURCES
63+
multi_device_event_cache_tests.cpp
64+
ENVIRONMENT
65+
"UR_ADAPTERS_FORCE_LOAD=\"$<TARGET_FILE:ur_adapter_level_zero>\""
66+
"UR_L0_LEAKS_DEBUG=1"
67+
)
68+
69+
target_link_libraries(test-adapter-level_zero_multi_queue PRIVATE zeCallMap)
5970
endif()
Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
// Copyright (C) 2024 Intel Corporation
2+
// Part of the Unified-Runtime Project, under the Apache License v2.0 with LLVM Exceptions.
3+
// See LICENSE.TXT
4+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
5+
6+
#include "ur_print.hpp"
7+
#include "uur/fixtures.h"
8+
#include "uur/raii.h"
9+
10+
#include <map>
11+
#include <string>
12+
13+
extern std::map<std::string, int> *ZeCallCount;
14+
15+
using urMultiQueueMultiDeviceEventCacheTest = uur::urAllDevicesTest;
16+
// Partitions devices[0] into sub-devices, creates one context and one queue
// on each of the first two sub-devices, chains two urEnqueueEventsWait calls
// across those queues, and checks how many times
// zeCommandListAppendWaitOnEvents was recorded in ZeCallCount.
//
// The test is skipped when the device reports fewer than two possible
// sub-devices, or when the partition query yields fewer than two.
TEST_F(urMultiQueueMultiDeviceEventCacheTest,
       GivenMultiSubDeviceWithQueuePerSubDeviceThenEventIsSharedBetweenQueues) {
    uint32_t max_sub_devices = 0;
    ASSERT_SUCCESS(
        uur::GetDevicePartitionMaxSubDevices(devices[0], max_sub_devices));
    if (max_sub_devices < 2) {
        GTEST_SKIP();
    }

    ur_device_partition_property_t prop;
    prop.type = UR_DEVICE_PARTITION_BY_AFFINITY_DOMAIN;
    prop.value.affinity_domain =
        UR_DEVICE_AFFINITY_DOMAIN_FLAG_NEXT_PARTITIONABLE;

    ur_device_partition_properties_t properties{
        UR_STRUCTURE_TYPE_DEVICE_PARTITION_PROPERTIES,
        nullptr,
        &prop,
        1,
    };

    // First call only queries how many sub-devices the partition produces.
    uint32_t numSubDevices = 0;
    ASSERT_SUCCESS(
        urDevicePartition(devices[0], &properties, 0, nullptr, &numSubDevices));
    // sub_devices[0] and sub_devices[1] are both used below, so fewer than
    // two results cannot exercise this scenario.
    if (numSubDevices < 2) {
        GTEST_SKIP();
    }

    // The vector must actually hold numSubDevices elements before the handles
    // are written through data() and read back via operator[]; reserve() only
    // changes capacity and would leave size() == 0.
    std::vector<ur_device_handle_t> sub_devices(numSubDevices);
    ASSERT_SUCCESS(urDevicePartition(devices[0], &properties, numSubDevices,
                                     sub_devices.data(), nullptr));

    uur::raii::Context context1 = nullptr;
    ASSERT_SUCCESS(
        urContextCreate(1, &sub_devices[0], nullptr, context1.ptr()));
    ASSERT_NE(nullptr, context1);
    uur::raii::Context context2 = nullptr;
    ASSERT_SUCCESS(
        urContextCreate(1, &sub_devices[1], nullptr, context2.ptr()));
    ASSERT_NE(nullptr, context2);

    ur_queue_handle_t queue1 = nullptr;
    ASSERT_SUCCESS(urQueueCreate(context1, sub_devices[0], 0, &queue1));
    ur_queue_handle_t queue2 = nullptr;
    ASSERT_SUCCESS(urQueueCreate(context2, sub_devices[1], 0, &queue2));

    uur::raii::Event event = nullptr;
    uur::raii::Event eventWait = nullptr;
    uur::raii::Event eventWaitDummy = nullptr;

    // Reset the counter so only the calls made below are measured.
    (*ZeCallCount)["zeCommandListAppendWaitOnEvents"] = 0;
    EXPECT_SUCCESS(urEventCreateWithNativeHandle(nullptr, context2, nullptr,
                                                 eventWait.ptr()));
    EXPECT_SUCCESS(urEventCreateWithNativeHandle(nullptr, context1, nullptr,
                                                 eventWaitDummy.ptr()));
    // queue1 waits on the dummy event and produces eventWait; queue2 then
    // waits on eventWait, which was produced on the other sub-device's queue.
    EXPECT_SUCCESS(
        urEnqueueEventsWait(queue1, 1, eventWaitDummy.ptr(), eventWait.ptr()));
    EXPECT_SUCCESS(
        urEnqueueEventsWait(queue2, 1, eventWait.ptr(), event.ptr()));
    // Two recorded calls are expected for the two enqueues above.
    // NOTE(review): presumably a third call would mean the adapter took its
    // multi-device event path for sub-devices of one root — confirm against
    // createAndRetainUrZeEventList.
    EXPECT_EQ((*ZeCallCount)["zeCommandListAppendWaitOnEvents"], 2);

    // NOTE(review): these events are held in uur::raii wrappers and are also
    // released explicitly here — confirm the wrapper does not release again.
    ASSERT_SUCCESS(urEventRelease(eventWaitDummy.get()));
    ASSERT_SUCCESS(urEventRelease(eventWait.get()));
    ASSERT_SUCCESS(urEventRelease(event.get()));
    ASSERT_SUCCESS(urQueueRelease(queue2));
    ASSERT_SUCCESS(urQueueRelease(queue1));
}
73+
74+
// Creates one context and one queue on each of the first two devices, chains
// two urEnqueueEventsWait calls across those queues, and checks how many
// times zeCommandListAppendWaitOnEvents was recorded in ZeCallCount.
// Skipped when fewer than two devices are available.
TEST_F(urMultiQueueMultiDeviceEventCacheTest,
       GivenMultiDeviceWithQueuePerDeviceThenMultiDeviceEventIsCreated) {
    if (devices.size() < 2) {
        GTEST_SKIP();
    }

    // One context per device.
    uur::raii::Context firstContext = nullptr;
    ASSERT_SUCCESS(
        urContextCreate(1, &devices[0], nullptr, firstContext.ptr()));
    ASSERT_NE(nullptr, firstContext);
    uur::raii::Context secondContext = nullptr;
    ASSERT_SUCCESS(
        urContextCreate(1, &devices[1], nullptr, secondContext.ptr()));
    ASSERT_NE(nullptr, secondContext);

    // One queue per device, each in its own context.
    ur_queue_handle_t firstQueue = nullptr;
    ASSERT_SUCCESS(urQueueCreate(firstContext, devices[0], 0, &firstQueue));
    ur_queue_handle_t secondQueue = nullptr;
    ASSERT_SUCCESS(urQueueCreate(secondContext, devices[1], 0, &secondQueue));

    uur::raii::Event finalEvent = nullptr;
    uur::raii::Event crossQueueEvent = nullptr;
    uur::raii::Event seedEvent = nullptr;

    // Start counting from zero so only the enqueues below contribute.
    (*ZeCallCount)["zeCommandListAppendWaitOnEvents"] = 0;
    EXPECT_SUCCESS(urEventCreateWithNativeHandle(
        nullptr, secondContext, nullptr, crossQueueEvent.ptr()));
    EXPECT_SUCCESS(urEventCreateWithNativeHandle(nullptr, firstContext, nullptr,
                                                 seedEvent.ptr()));
    // The first queue waits on the seed event and emits crossQueueEvent; the
    // second queue (on the other device) then waits on crossQueueEvent.
    EXPECT_SUCCESS(urEnqueueEventsWait(firstQueue, 1, seedEvent.ptr(),
                                       crossQueueEvent.ptr()));
    EXPECT_SUCCESS(urEnqueueEventsWait(secondQueue, 1, crossQueueEvent.ptr(),
                                       finalEvent.ptr()));
    // Three recorded calls are expected here, one more than the number of
    // enqueues issued by this test.
    EXPECT_EQ((*ZeCallCount)["zeCommandListAppendWaitOnEvents"], 3);

    ASSERT_SUCCESS(urEventRelease(seedEvent.get()));
    ASSERT_SUCCESS(urEventRelease(crossQueueEvent.get()));
    ASSERT_SUCCESS(urEventRelease(finalEvent.get()));
    ASSERT_SUCCESS(urQueueRelease(secondQueue));
    ASSERT_SUCCESS(urQueueRelease(firstQueue));
}

0 commit comments

Comments
 (0)