Skip to content

Commit f65473d

Browse files
committed
[OpenCL] Add bounds checking to the Enqueue memory operations.
This allows us to return UR_RESULT_ERROR_INVALID_SIZE when we should. Extra checks are only performed when the OpenCL call returns a non-success error code. Also adds a missing bounds check to urMemBufferPartition.
1 parent 39eec0c commit f65473d

File tree

2 files changed

+152
-39
lines changed

2 files changed

+152
-39
lines changed

source/adapters/opencl/enqueue.cpp

Lines changed: 143 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,77 @@ cl_map_flags convertURMapFlagsToCL(ur_map_flags_t URFlags) {
2525
return CLFlags;
2626
}
2727

28+
// Checks that the byte range [Origin, Origin + Size) lies inside Buffer.
//
// The buffer's size is queried from OpenCL with clGetMemObjectInfo(CL_MEM_SIZE).
// A failing query is handled by CL_RETURN_ON_FAILURE (defined elsewhere;
// presumably it returns the mapped UR error).
//
// Returns UR_RESULT_ERROR_INVALID_SIZE when the range does not fit in the
// buffer, UR_RESULT_SUCCESS otherwise.
ur_result_t ValidateBufferSize(ur_mem_handle_t Buffer, size_t Size,
                               size_t Origin) {
  size_t BufferSize = 0;
  CL_RETURN_ON_FAILURE(clGetMemObjectInfo(cl_adapter::cast<cl_mem>(Buffer),
                                          CL_MEM_SIZE, sizeof(BufferSize),
                                          &BufferSize, nullptr));

  // Equivalent to `Size + Origin > BufferSize`, but written with a subtraction
  // so that a huge Size/Origin pair cannot wrap around and slip through.
  if (Origin > BufferSize || Size > BufferSize - Origin)
    return UR_RESULT_ERROR_INVALID_SIZE;

  return UR_RESULT_SUCCESS;
}
38+
39+
// Checks a rectangular region/offset pair against the byte size of Buffer.
//
// The buffer's size is queried with clGetMemObjectInfo(CL_MEM_SIZE). The
// rectangle is rejected with UR_RESULT_ERROR_INVALID_SIZE when
//   * any offset component is >= the buffer size, or
//   * (width + x) * (height + y) * (depth + z) exceeds the buffer size.
// Otherwise UR_RESULT_SUCCESS is returned.
//
// NOTE(review): row/slice pitches are not available here, so this is only an
// approximation of the bounds rule OpenCL itself applies — confirm whether the
// pitches should be taken into account.
ur_result_t ValidateBufferRectSize(ur_mem_handle_t Buffer,
                                   ur_rect_region_t Region,
                                   ur_rect_offset_t Offset) {
  size_t BufferSize = 0;
  CL_RETURN_ON_FAILURE(clGetMemObjectInfo(cl_adapter::cast<cl_mem>(Buffer),
                                          CL_MEM_SIZE, sizeof(BufferSize),
                                          &BufferSize, nullptr));
  if (Offset.x >= BufferSize || Offset.y >= BufferSize ||
      Offset.z >= BufferSize) {
    return UR_RESULT_ERROR_INVALID_SIZE;
  }

  // From here on BufferSize >= 1 and every offset is < BufferSize.

  // If any factor of the product is zero the product is zero, which can never
  // exceed the buffer size.
  if ((Region.width == 0 && Offset.x == 0) ||
      (Region.height == 0 && Offset.y == 0) ||
      (Region.depth == 0 && Offset.z == 0)) {
    return UR_RESULT_SUCCESS;
  }

  // All factors are now >= 1, so a single factor larger than the buffer
  // already pushes the product past it. The subtractions cannot underflow
  // because each offset is < BufferSize.
  if (Region.width > BufferSize - Offset.x ||
      Region.height > BufferSize - Offset.y ||
      Region.depth > BufferSize - Offset.z) {
    return UR_RESULT_ERROR_INVALID_SIZE;
  }

  // Each factor is within [1, BufferSize]; multiply them up while guarding
  // every step with a division so the running product can never wrap around.
  const auto ExtentX = Region.width + Offset.x;
  const auto ExtentY = Region.height + Offset.y;
  const auto ExtentZ = Region.depth + Offset.z;

  auto Volume = ExtentX;
  if (ExtentY > BufferSize / Volume) {
    return UR_RESULT_ERROR_INVALID_SIZE;
  }
  Volume *= ExtentY;
  if (ExtentZ > BufferSize / Volume) {
    return UR_RESULT_ERROR_INVALID_SIZE;
  }

  return UR_RESULT_SUCCESS;
}
59+
60+
// Checks that Region placed at Origin fits inside Image in every dimension.
//
// Width, height and depth are queried one at a time with clGetImageInfo; each
// dimension is validated before the next one is queried. A failing query is
// handled by CL_RETURN_ON_FAILURE (defined elsewhere; presumably it returns
// the mapped UR error).
//
// Returns UR_RESULT_ERROR_INVALID_SIZE as soon as one dimension does not fit,
// UR_RESULT_SUCCESS when all three fit.
ur_result_t ValidateImageSize(ur_mem_handle_t Image, ur_rect_region_t Region,
                              ur_rect_offset_t Origin) {
  // Each bounds test below is `extent + origin > limit` rewritten with a
  // subtraction so that very large inputs cannot wrap around and pass.

  size_t Width = 0;
  CL_RETURN_ON_FAILURE(clGetImageInfo(cl_adapter::cast<cl_mem>(Image),
                                      CL_IMAGE_WIDTH, sizeof(Width), &Width,
                                      nullptr));
  if (Origin.x > Width || Region.width > Width - Origin.x) {
    return UR_RESULT_ERROR_INVALID_SIZE;
  }

  size_t Height = 0;
  CL_RETURN_ON_FAILURE(clGetImageInfo(cl_adapter::cast<cl_mem>(Image),
                                      CL_IMAGE_HEIGHT, sizeof(Height), &Height,
                                      nullptr));

  // CL returns a height and depth of 0 for images that don't have those
  // dimensions, but regions for enqueue operations must set these to 1, so we
  // need to make this adjustment to validate.
  if (Height == 0) {
    Height = 1;
  }

  if (Origin.y > Height || Region.height > Height - Origin.y) {
    return UR_RESULT_ERROR_INVALID_SIZE;
  }

  size_t Depth = 0;
  CL_RETURN_ON_FAILURE(clGetImageInfo(cl_adapter::cast<cl_mem>(Image),
                                      CL_IMAGE_DEPTH, sizeof(Depth), &Depth,
                                      nullptr));
  // Same 0 -> 1 adjustment as for the height above.
  if (Depth == 0) {
    Depth = 1;
  }

  if (Origin.z > Depth || Region.depth > Depth - Origin.z) {
    return UR_RESULT_ERROR_INVALID_SIZE;
  }

  return UR_RESULT_SUCCESS;
}
98+
2899
UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
29100
ur_queue_handle_t hQueue, ur_kernel_handle_t hKernel, uint32_t workDim,
30101
const size_t *pGlobalWorkOffset, const size_t *pGlobalWorkSize,
@@ -70,27 +141,33 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferRead(
70141
size_t offset, size_t size, void *pDst, uint32_t numEventsInWaitList,
71142
const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent) {
72143

73-
CL_RETURN_ON_FAILURE(clEnqueueReadBuffer(
144+
auto ClErr = clEnqueueReadBuffer(
74145
cl_adapter::cast<cl_command_queue>(hQueue),
75146
cl_adapter::cast<cl_mem>(hBuffer), blockingRead, offset, size, pDst,
76147
numEventsInWaitList, cl_adapter::cast<const cl_event *>(phEventWaitList),
77-
cl_adapter::cast<cl_event *>(phEvent)));
148+
cl_adapter::cast<cl_event *>(phEvent));
78149

79-
return UR_RESULT_SUCCESS;
150+
if (ClErr == CL_INVALID_VALUE) {
151+
UR_RETURN_ON_FAILURE(ValidateBufferSize(hBuffer, size, offset));
152+
}
153+
return mapCLErrorToUR(ClErr);
80154
}
81155

82156
UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferWrite(
83157
ur_queue_handle_t hQueue, ur_mem_handle_t hBuffer, bool blockingWrite,
84158
size_t offset, size_t size, const void *pSrc, uint32_t numEventsInWaitList,
85159
const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent) {
86160

87-
CL_RETURN_ON_FAILURE(clEnqueueWriteBuffer(
161+
auto ClErr = clEnqueueWriteBuffer(
88162
cl_adapter::cast<cl_command_queue>(hQueue),
89163
cl_adapter::cast<cl_mem>(hBuffer), blockingWrite, offset, size, pSrc,
90164
numEventsInWaitList, cl_adapter::cast<const cl_event *>(phEventWaitList),
91-
cl_adapter::cast<cl_event *>(phEvent)));
165+
cl_adapter::cast<cl_event *>(phEvent));
92166

93-
return UR_RESULT_SUCCESS;
167+
if (ClErr == CL_INVALID_VALUE) {
168+
UR_RETURN_ON_FAILURE(ValidateBufferSize(hBuffer, size, offset));
169+
}
170+
return mapCLErrorToUR(ClErr);
94171
}
95172

96173
UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferReadRect(
@@ -101,17 +178,20 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferReadRect(
101178
uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList,
102179
ur_event_handle_t *phEvent) {
103180

104-
CL_RETURN_ON_FAILURE(clEnqueueReadBufferRect(
181+
auto ClErr = clEnqueueReadBufferRect(
105182
cl_adapter::cast<cl_command_queue>(hQueue),
106183
cl_adapter::cast<cl_mem>(hBuffer), blockingRead,
107184
cl_adapter::cast<const size_t *>(&bufferOrigin),
108185
cl_adapter::cast<const size_t *>(&hostOrigin),
109186
cl_adapter::cast<const size_t *>(&region), bufferRowPitch,
110187
bufferSlicePitch, hostRowPitch, hostSlicePitch, pDst, numEventsInWaitList,
111188
cl_adapter::cast<const cl_event *>(phEventWaitList),
112-
cl_adapter::cast<cl_event *>(phEvent)));
189+
cl_adapter::cast<cl_event *>(phEvent));
113190

114-
return UR_RESULT_SUCCESS;
191+
if (ClErr == CL_INVALID_VALUE) {
192+
UR_RETURN_ON_FAILURE(ValidateBufferRectSize(hBuffer, region, bufferOrigin));
193+
}
194+
return mapCLErrorToUR(ClErr);
115195
}
116196

117197
UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferWriteRect(
@@ -122,17 +202,20 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferWriteRect(
122202
uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList,
123203
ur_event_handle_t *phEvent) {
124204

125-
CL_RETURN_ON_FAILURE(clEnqueueWriteBufferRect(
205+
auto ClErr = clEnqueueWriteBufferRect(
126206
cl_adapter::cast<cl_command_queue>(hQueue),
127207
cl_adapter::cast<cl_mem>(hBuffer), blockingWrite,
128208
cl_adapter::cast<const size_t *>(&bufferOrigin),
129209
cl_adapter::cast<const size_t *>(&hostOrigin),
130210
cl_adapter::cast<const size_t *>(&region), bufferRowPitch,
131211
bufferSlicePitch, hostRowPitch, hostSlicePitch, pSrc, numEventsInWaitList,
132212
cl_adapter::cast<const cl_event *>(phEventWaitList),
133-
cl_adapter::cast<cl_event *>(phEvent)));
213+
cl_adapter::cast<cl_event *>(phEvent));
134214

135-
return UR_RESULT_SUCCESS;
215+
if (ClErr == CL_INVALID_VALUE) {
216+
UR_RETURN_ON_FAILURE(ValidateBufferRectSize(hBuffer, region, bufferOrigin));
217+
}
218+
return mapCLErrorToUR(ClErr);
136219
}
137220

138221
UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferCopy(
@@ -141,14 +224,18 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferCopy(
141224
uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList,
142225
ur_event_handle_t *phEvent) {
143226

144-
CL_RETURN_ON_FAILURE(clEnqueueCopyBuffer(
227+
auto ClErr = clEnqueueCopyBuffer(
145228
cl_adapter::cast<cl_command_queue>(hQueue),
146229
cl_adapter::cast<cl_mem>(hBufferSrc),
147230
cl_adapter::cast<cl_mem>(hBufferDst), srcOffset, dstOffset, size,
148231
numEventsInWaitList, cl_adapter::cast<const cl_event *>(phEventWaitList),
149-
cl_adapter::cast<cl_event *>(phEvent)));
232+
cl_adapter::cast<cl_event *>(phEvent));
150233

151-
return UR_RESULT_SUCCESS;
234+
if (ClErr == CL_INVALID_VALUE) {
235+
UR_RETURN_ON_FAILURE(ValidateBufferSize(hBufferSrc, size, srcOffset));
236+
UR_RETURN_ON_FAILURE(ValidateBufferSize(hBufferDst, size, dstOffset));
237+
}
238+
return mapCLErrorToUR(ClErr);
152239
}
153240

154241
UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferCopyRect(
@@ -159,7 +246,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferCopyRect(
159246
uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList,
160247
ur_event_handle_t *phEvent) {
161248

162-
CL_RETURN_ON_FAILURE(clEnqueueCopyBufferRect(
249+
auto ClErr = clEnqueueCopyBufferRect(
163250
cl_adapter::cast<cl_command_queue>(hQueue),
164251
cl_adapter::cast<cl_mem>(hBufferSrc),
165252
cl_adapter::cast<cl_mem>(hBufferDst),
@@ -168,9 +255,13 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferCopyRect(
168255
cl_adapter::cast<const size_t *>(&region), srcRowPitch, srcSlicePitch,
169256
dstRowPitch, dstSlicePitch, numEventsInWaitList,
170257
cl_adapter::cast<const cl_event *>(phEventWaitList),
171-
cl_adapter::cast<cl_event *>(phEvent)));
258+
cl_adapter::cast<cl_event *>(phEvent));
172259

173-
return UR_RESULT_SUCCESS;
260+
if (ClErr == CL_INVALID_VALUE) {
261+
UR_RETURN_ON_FAILURE(ValidateBufferRectSize(hBufferSrc, region, srcOrigin));
262+
UR_RETURN_ON_FAILURE(ValidateBufferRectSize(hBufferDst, region, dstOrigin));
263+
}
264+
return mapCLErrorToUR(ClErr);
174265
}
175266

176267
UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferFill(
@@ -181,13 +272,16 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferFill(
181272
// CL FillBuffer only allows pattern sizes up to the largest CL type:
182273
// long16/double16
183274
if (patternSize <= 128) {
184-
CL_RETURN_ON_FAILURE(
185-
clEnqueueFillBuffer(cl_adapter::cast<cl_command_queue>(hQueue),
186-
cl_adapter::cast<cl_mem>(hBuffer), pPattern,
187-
patternSize, offset, size, numEventsInWaitList,
188-
cl_adapter::cast<const cl_event *>(phEventWaitList),
189-
cl_adapter::cast<cl_event *>(phEvent)));
190-
return UR_RESULT_SUCCESS;
275+
auto ClErr = (clEnqueueFillBuffer(
276+
cl_adapter::cast<cl_command_queue>(hQueue),
277+
cl_adapter::cast<cl_mem>(hBuffer), pPattern, patternSize, offset, size,
278+
numEventsInWaitList,
279+
cl_adapter::cast<const cl_event *>(phEventWaitList),
280+
cl_adapter::cast<cl_event *>(phEvent)));
281+
if (ClErr != CL_SUCCESS) {
282+
UR_RETURN_ON_FAILURE(ValidateBufferSize(hBuffer, size, offset));
283+
}
284+
return mapCLErrorToUR(ClErr);
191285
}
192286

193287
auto NumValues = size / sizeof(uint64_t);
@@ -205,6 +299,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferFill(
205299
&WriteEvent);
206300
if (ClErr != CL_SUCCESS) {
207301
delete[] HostBuffer;
302+
UR_RETURN_ON_FAILURE(ValidateBufferSize(hBuffer, offset, size));
208303
CL_RETURN_ON_FAILURE(ClErr);
209304
}
210305

@@ -237,15 +332,18 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemImageRead(
237332
size_t slicePitch, void *pDst, uint32_t numEventsInWaitList,
238333
const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent) {
239334

240-
CL_RETURN_ON_FAILURE(clEnqueueReadImage(
335+
auto ClErr = clEnqueueReadImage(
241336
cl_adapter::cast<cl_command_queue>(hQueue),
242337
cl_adapter::cast<cl_mem>(hImage), blockingRead,
243338
cl_adapter::cast<const size_t *>(&origin),
244339
cl_adapter::cast<const size_t *>(&region), rowPitch, slicePitch, pDst,
245340
numEventsInWaitList, cl_adapter::cast<const cl_event *>(phEventWaitList),
246-
cl_adapter::cast<cl_event *>(phEvent)));
341+
cl_adapter::cast<cl_event *>(phEvent));
247342

248-
return UR_RESULT_SUCCESS;
343+
if (ClErr == CL_INVALID_VALUE) {
344+
UR_RETURN_ON_FAILURE(ValidateImageSize(hImage, region, origin));
345+
}
346+
return mapCLErrorToUR(ClErr);
249347
}
250348

251349
UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemImageWrite(
@@ -254,15 +352,18 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemImageWrite(
254352
size_t slicePitch, void *pSrc, uint32_t numEventsInWaitList,
255353
const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent) {
256354

257-
CL_RETURN_ON_FAILURE(clEnqueueWriteImage(
355+
auto ClErr = clEnqueueWriteImage(
258356
cl_adapter::cast<cl_command_queue>(hQueue),
259357
cl_adapter::cast<cl_mem>(hImage), blockingWrite,
260358
cl_adapter::cast<const size_t *>(&origin),
261359
cl_adapter::cast<const size_t *>(&region), rowPitch, slicePitch, pSrc,
262360
numEventsInWaitList, cl_adapter::cast<const cl_event *>(phEventWaitList),
263-
cl_adapter::cast<cl_event *>(phEvent)));
361+
cl_adapter::cast<cl_event *>(phEvent));
264362

265-
return UR_RESULT_SUCCESS;
363+
if (ClErr == CL_INVALID_VALUE) {
364+
UR_RETURN_ON_FAILURE(ValidateImageSize(hImage, region, origin));
365+
}
366+
return mapCLErrorToUR(ClErr);
266367
}
267368

268369
UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemImageCopy(
@@ -272,16 +373,20 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemImageCopy(
272373
uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList,
273374
ur_event_handle_t *phEvent) {
274375

275-
CL_RETURN_ON_FAILURE(clEnqueueCopyImage(
376+
auto ClErr = clEnqueueCopyImage(
276377
cl_adapter::cast<cl_command_queue>(hQueue),
277378
cl_adapter::cast<cl_mem>(hImageSrc), cl_adapter::cast<cl_mem>(hImageDst),
278379
cl_adapter::cast<const size_t *>(&srcOrigin),
279380
cl_adapter::cast<const size_t *>(&dstOrigin),
280381
cl_adapter::cast<const size_t *>(&region), numEventsInWaitList,
281382
cl_adapter::cast<const cl_event *>(phEventWaitList),
282-
cl_adapter::cast<cl_event *>(phEvent)));
383+
cl_adapter::cast<cl_event *>(phEvent));
283384

284-
return UR_RESULT_SUCCESS;
385+
if (ClErr == CL_INVALID_VALUE) {
386+
UR_RETURN_ON_FAILURE(ValidateImageSize(hImageSrc, region, srcOrigin));
387+
UR_RETURN_ON_FAILURE(ValidateImageSize(hImageDst, region, dstOrigin));
388+
}
389+
return mapCLErrorToUR(ClErr);
285390
}
286391

287392
UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferMap(
@@ -298,9 +403,10 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferMap(
298403
cl_adapter::cast<const cl_event *>(phEventWaitList),
299404
cl_adapter::cast<cl_event *>(phEvent), &Err);
300405

301-
CL_RETURN_ON_FAILURE(Err);
302-
303-
return UR_RESULT_SUCCESS;
406+
if (Err == CL_INVALID_VALUE) {
407+
UR_RETURN_ON_FAILURE(ValidateBufferSize(hBuffer, size, offset));
408+
}
409+
return mapCLErrorToUR(Err);
304410
}
305411

306412
UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemUnmap(

source/adapters/opencl/memory.cpp

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -319,9 +319,16 @@ UR_APIEXPORT ur_result_t UR_APICALL urMemBufferPartition(
319319
*phMem = reinterpret_cast<ur_mem_handle_t>(clCreateSubBuffer(
320320
cl_adapter::cast<cl_mem>(hBuffer), static_cast<cl_mem_flags>(flags),
321321
BufferCreateType, &BufferRegion, cl_adapter::cast<cl_int *>(&RetErr)));
322-
CL_RETURN_ON_FAILURE(RetErr);
323322

324-
return UR_RESULT_SUCCESS;
323+
if (RetErr == CL_INVALID_VALUE) {
324+
size_t BufferSize = 0;
325+
CL_RETURN_ON_FAILURE(clGetMemObjectInfo(cl_adapter::cast<cl_mem>(hBuffer),
326+
CL_MEM_SIZE, sizeof(BufferSize),
327+
&BufferSize, nullptr));
328+
if (BufferRegion.size + BufferRegion.origin > BufferSize)
329+
return UR_RESULT_ERROR_INVALID_BUFFER_SIZE;
330+
}
331+
return mapCLErrorToUR(RetErr);
325332
}
326333

327334
UR_APIEXPORT ur_result_t UR_APICALL

0 commit comments

Comments
 (0)