6
6
#include " context.hpp"
7
7
#include " fixtures.h"
8
8
#include " queue.hpp"
9
+ #include " uur/raii.h"
9
10
#include < thread>
10
11
11
12
using cudaUrContextCreateTest = uur::urDeviceTest;
@@ -14,14 +15,13 @@ UUR_INSTANTIATE_DEVICE_TEST_SUITE_P(cudaUrContextCreateTest);
14
15
constexpr unsigned int known_cuda_api_version = 3020 ;
15
16
16
17
TEST_P (cudaUrContextCreateTest, CreateWithChildThread) {
17
-
18
- ur_context_handle_t context = nullptr ;
19
- ASSERT_SUCCESS (urContextCreate (1 , &device, nullptr , &context));
18
+ uur::raii::Context context = nullptr ;
19
+ ASSERT_SUCCESS (urContextCreate (1 , &device, nullptr , context.ptr ()));
20
20
ASSERT_NE (context, nullptr );
21
21
22
22
// Retrieve the CUDA context to check information is correct
23
23
auto checkValue = [=] {
24
- CUcontext cudaContext = context->get ();
24
+ CUcontext cudaContext = context. handle ->get ();
25
25
unsigned int version = 0 ;
26
26
EXPECT_SUCCESS_CUDA (cuCtxGetApiVersion (cudaContext, &version));
27
27
EXPECT_EQ (version, known_cuda_api_version);
@@ -39,27 +39,26 @@ TEST_P(cudaUrContextCreateTest, CreateWithChildThread) {
39
39
40
40
auto callContextFromOtherThread = std::thread (checkValue);
41
41
callContextFromOtherThread.join ();
42
- ASSERT_SUCCESS (urContextRelease (context));
43
42
}
44
43
45
44
TEST_P (cudaUrContextCreateTest, ActiveContext) {
46
- ur_context_handle_t context = nullptr ;
47
- ASSERT_SUCCESS (urContextCreate (1 , &device, nullptr , & context));
45
+ uur::raii::Context context = nullptr ;
46
+ ASSERT_SUCCESS (urContextCreate (1 , &device, nullptr , context. ptr () ));
48
47
ASSERT_NE (context, nullptr );
49
48
50
- ur_queue_handle_t queue = nullptr ;
49
+ uur::raii::Queue queue = nullptr ;
51
50
ur_queue_properties_t queue_props{UR_STRUCTURE_TYPE_QUEUE_PROPERTIES,
52
51
nullptr , 0 };
53
- ASSERT_SUCCESS (urQueueCreate (context, device, &queue_props, & queue));
52
+ ASSERT_SUCCESS (urQueueCreate (context, device, &queue_props, queue. ptr () ));
54
53
ASSERT_NE (queue, nullptr );
55
54
56
55
// check that the queue has the correct context
57
56
ASSERT_EQ (context, queue->getContext ());
58
57
59
58
// create a buffer
60
- ur_mem_handle_t buffer = nullptr ;
59
+ uur::raii::Mem buffer = nullptr ;
61
60
ASSERT_SUCCESS (urMemBufferCreate (context, UR_MEM_FLAG_READ_WRITE, 1024 ,
62
- nullptr , & buffer));
61
+ nullptr , buffer. ptr () ));
63
62
ASSERT_NE (buffer, nullptr );
64
63
65
64
// check that the context is now the active CUDA context
@@ -71,11 +70,6 @@ TEST_P(cudaUrContextCreateTest, ActiveContext) {
71
70
ASSERT_SUCCESS (urContextGetNativeHandle (context, &native_context));
72
71
ASSERT_NE (native_context, nullptr );
73
72
ASSERT_EQ (cudaCtx, reinterpret_cast <CUcontext>(native_context));
74
-
75
- // release resources
76
- ASSERT_SUCCESS (urMemRelease (buffer));
77
- ASSERT_SUCCESS (urQueueRelease (queue));
78
- ASSERT_SUCCESS (urContextRelease (context));
79
73
}
80
74
81
75
TEST_P (cudaUrContextCreateTest, ContextLifetimeExisting) {
@@ -89,13 +83,13 @@ TEST_P(cudaUrContextCreateTest, ContextLifetimeExisting) {
89
83
ASSERT_EQ (original, current);
90
84
91
85
// create a UR context
92
- ur_context_handle_t context;
93
- ASSERT_SUCCESS (urContextCreate (1 , &device, nullptr , & context));
86
+ uur::raii::Context context;
87
+ ASSERT_SUCCESS (urContextCreate (1 , &device, nullptr , context. ptr () ));
94
88
ASSERT_NE (context, nullptr );
95
89
96
90
// create a queue with the context
97
- ur_queue_handle_t queue;
98
- ASSERT_SUCCESS (urQueueCreate (context, device, nullptr , & queue));
91
+ uur::raii::Queue queue;
92
+ ASSERT_SUCCESS (urQueueCreate (context, device, nullptr , queue. ptr () ));
99
93
ASSERT_NE (queue, nullptr );
100
94
101
95
// ensure the queue has the correct context
@@ -109,19 +103,16 @@ TEST_P(cudaUrContextCreateTest, ContextLifetimeExisting) {
109
103
// check that context is now the active cuda context
110
104
ASSERT_SUCCESS_CUDA (cuCtxGetCurrent (¤t));
111
105
ASSERT_EQ (current, context->get ());
112
-
113
- ASSERT_SUCCESS (urQueueRelease (queue));
114
- ASSERT_SUCCESS (urContextRelease (context));
115
106
}
116
107
117
108
TEST_P (cudaUrContextCreateTest, ThreadedContext) {
118
109
// create two new UR contexts
119
- ur_context_handle_t context1;
120
- ASSERT_SUCCESS (urContextCreate (1 , &device, nullptr , & context1));
110
+ uur::raii::Context context1;
111
+ ASSERT_SUCCESS (urContextCreate (1 , &device, nullptr , context1. ptr () ));
121
112
ASSERT_NE (context1, nullptr );
122
113
123
- ur_context_handle_t context2;
124
- ASSERT_SUCCESS (urContextCreate (1 , &device, nullptr , & context2));
114
+ uur::raii::Context context2;
115
+ ASSERT_SUCCESS (urContextCreate (1 , &device, nullptr , context2. ptr () ));
125
116
ASSERT_NE (context2, nullptr );
126
117
127
118
// setup synchronization variables between the main thread and
@@ -138,23 +129,22 @@ TEST_P(cudaUrContextCreateTest, ThreadedContext) {
138
129
auto test_thread = std::thread ([&] {
139
130
CUcontext current = nullptr ;
140
131
141
- // create a queue with the first context
142
- ur_queue_handle_t queue;
143
- ASSERT_SUCCESS (urQueueCreate (context1, device, nullptr , &queue));
144
- ASSERT_NE (queue, nullptr );
145
-
146
- // ensure that the queue has the correct context
147
- ASSERT_EQ (context1, queue->getContext ());
132
+ {
133
+ // create a queue with the first context
134
+ uur::raii::Queue queue;
135
+ ASSERT_SUCCESS (
136
+ urQueueCreate (context1, device, nullptr , queue.ptr ()));
137
+ ASSERT_NE (queue, nullptr );
148
138
149
- // create a buffer to set context1 as the active context
150
- ur_mem_handle_t buffer;
151
- ASSERT_SUCCESS (urMemBufferCreate (context1, UR_MEM_FLAG_READ_WRITE, 1024 ,
152
- nullptr , &buffer));
153
- ASSERT_NE (buffer, nullptr );
139
+ // ensure that the queue has the correct context
140
+ ASSERT_EQ (context1, queue->getContext ());
154
141
155
- // release the mem and queue
156
- ASSERT_SUCCESS (urMemRelease (buffer));
157
- ASSERT_SUCCESS (urQueueRelease (queue));
142
+ // create a buffer to set context1 as the active context
143
+ uur::raii::Mem buffer;
144
+ ASSERT_SUCCESS (urMemBufferCreate (context1, UR_MEM_FLAG_READ_WRITE,
145
+ 1024 , nullptr , buffer.ptr ()));
146
+ ASSERT_NE (buffer, nullptr );
147
+ }
158
148
159
149
// mark the first set of processing as done and notify the main thread
160
150
std::unique_lock<std::mutex> lock (m);
@@ -166,31 +156,31 @@ TEST_P(cudaUrContextCreateTest, ThreadedContext) {
166
156
lock.lock ();
167
157
cv.wait (lock, [&] { return released; });
168
158
169
- // create a queue with the 2nd context
170
- ASSERT_SUCCESS (urQueueCreate (context2, device, nullptr , &queue));
171
- ASSERT_NE (queue, nullptr );
172
-
173
- // ensure queue has correct context
174
- ASSERT_EQ (context2, queue->getContext ());
175
-
176
- // create a buffer to set the active context
177
- ASSERT_SUCCESS (urMemBufferCreate (context2, UR_MEM_FLAG_READ_WRITE, 1024 ,
178
- nullptr , &buffer));
179
-
180
- // check that the 2nd context is now tha active cuda context
181
- ASSERT_SUCCESS_CUDA (cuCtxGetCurrent (¤t));
182
- ASSERT_EQ (current, context2->get ());
183
-
184
- // release
185
- ASSERT_SUCCESS (urMemRelease (buffer));
186
- ASSERT_SUCCESS (urQueueRelease (queue));
159
+ {
160
+ // create a queue with the 2nd context
161
+ uur::raii::Queue queue = nullptr ;
162
+ ASSERT_SUCCESS (
163
+ urQueueCreate (context2, device, nullptr , queue.ptr ()));
164
+ ASSERT_NE (queue, nullptr );
165
+
166
+ // ensure queue has correct context
167
+ ASSERT_EQ (context2, queue->getContext ());
168
+
169
+ // create a buffer to set the active context
170
+ uur::raii::Mem buffer = nullptr ;
171
+ ASSERT_SUCCESS (urMemBufferCreate (context2, UR_MEM_FLAG_READ_WRITE,
172
+ 1024 , nullptr , buffer.ptr ()));
173
+
174
+ // check that the 2nd context is now tha active cuda context
175
+ ASSERT_SUCCESS_CUDA (cuCtxGetCurrent (¤t));
176
+ ASSERT_EQ (current, context2->get ());
177
+ }
187
178
});
188
179
189
180
// wait for the thread to be done with the first queue to release the first
190
181
// context
191
182
std::unique_lock<std::mutex> lock (m);
192
183
cv.wait (lock, [&] { return thread_done; });
193
- ASSERT_SUCCESS (urContextRelease (context1));
194
184
195
185
// notify the other thread that the context was released
196
186
released = true ;
@@ -199,6 +189,4 @@ TEST_P(cudaUrContextCreateTest, ThreadedContext) {
199
189
200
190
// wait for the thread to finish
201
191
test_thread.join ();
202
-
203
- ASSERT_SUCCESS (urContextRelease (context2));
204
192
}
0 commit comments