Skip to content

Commit db69f90

Browse files
authored
Add extra checks for optional parameters and invalid flags (#857)
- Refactors the error code generation script - Adds checks for optional parameters - Adds checks for invalid combinations of flags in urQueueCreate Closes #856
1 parent a346a30 commit db69f90

File tree

7 files changed

+159
-58
lines changed

7 files changed

+159
-58
lines changed

include/ur_api.h

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1717,6 +1717,7 @@ typedef struct ur_device_partition_properties_t {
17171717
/// + `NULL == hDevice`
17181718
/// - ::UR_RESULT_ERROR_INVALID_NULL_POINTER
17191719
/// + `NULL == pProperties`
1720+
/// + `NULL == pProperties->pProperties`
17201721
/// - ::UR_RESULT_ERROR_DEVICE_PARTITION_FAILED
17211722
/// - ::UR_RESULT_ERROR_INVALID_DEVICE_PARTITION_COUNT
17221723
UR_APIEXPORT ur_result_t UR_APICALL
@@ -2029,6 +2030,8 @@ typedef struct ur_context_properties_t {
20292030
/// - ::UR_RESULT_ERROR_INVALID_NULL_POINTER
20302031
/// + `NULL == phDevices`
20312032
/// + `NULL == phContext`
2033+
/// - ::UR_RESULT_ERROR_INVALID_ENUMERATION
2034+
/// + `NULL != pProperties && ::UR_CONTEXT_FLAGS_MASK & pProperties->flags`
20322035
/// - ::UR_RESULT_ERROR_OUT_OF_HOST_MEMORY
20332036
/// - ::UR_RESULT_ERROR_OUT_OF_DEVICE_MEMORY
20342037
UR_APIEXPORT ur_result_t UR_APICALL
@@ -3273,6 +3276,8 @@ typedef struct ur_usm_pool_limits_desc_t {
32733276
/// - ::UR_RESULT_ERROR_ADAPTER_SPECIFIC
32743277
/// - ::UR_RESULT_ERROR_INVALID_NULL_HANDLE
32753278
/// + `NULL == hContext`
3279+
/// - ::UR_RESULT_ERROR_INVALID_ENUMERATION
3280+
/// + `NULL != pUSMDesc && ::UR_USM_ADVICE_FLAGS_MASK & pUSMDesc->hints`
32763281
/// - ::UR_RESULT_ERROR_INVALID_NULL_POINTER
32773282
/// + `NULL == ppMem`
32783283
/// - ::UR_RESULT_ERROR_INVALID_CONTEXT
@@ -3317,6 +3322,8 @@ urUSMHostAlloc(
33173322
/// - ::UR_RESULT_ERROR_INVALID_NULL_HANDLE
33183323
/// + `NULL == hContext`
33193324
/// + `NULL == hDevice`
3325+
/// - ::UR_RESULT_ERROR_INVALID_ENUMERATION
3326+
/// + `NULL != pUSMDesc && ::UR_USM_ADVICE_FLAGS_MASK & pUSMDesc->hints`
33203327
/// - ::UR_RESULT_ERROR_INVALID_NULL_POINTER
33213328
/// + `NULL == ppMem`
33223329
/// - ::UR_RESULT_ERROR_INVALID_CONTEXT
@@ -3363,6 +3370,8 @@ urUSMDeviceAlloc(
33633370
/// - ::UR_RESULT_ERROR_INVALID_NULL_HANDLE
33643371
/// + `NULL == hContext`
33653372
/// + `NULL == hDevice`
3373+
/// - ::UR_RESULT_ERROR_INVALID_ENUMERATION
3374+
/// + `NULL != pUSMDesc && ::UR_USM_ADVICE_FLAGS_MASK & pUSMDesc->hints`
33663375
/// - ::UR_RESULT_ERROR_INVALID_NULL_POINTER
33673376
/// + `NULL == ppMem`
33683377
/// - ::UR_RESULT_ERROR_INVALID_CONTEXT
@@ -3805,6 +3814,8 @@ typedef struct ur_physical_mem_properties_t {
38053814
/// - ::UR_RESULT_ERROR_INVALID_NULL_HANDLE
38063815
/// + `NULL == hContext`
38073816
/// + `NULL == hDevice`
3817+
/// - ::UR_RESULT_ERROR_INVALID_ENUMERATION
3818+
/// + `NULL != pProperties && ::UR_PHYSICAL_MEM_FLAGS_MASK & pProperties->flags`
38083819
/// - ::UR_RESULT_ERROR_INVALID_NULL_POINTER
38093820
/// + `NULL == phPhysicalMem`
38103821
/// - ::UR_RESULT_ERROR_INVALID_SIZE
@@ -4882,6 +4893,8 @@ typedef struct ur_kernel_arg_mem_obj_properties_t {
48824893
/// - ::UR_RESULT_ERROR_ADAPTER_SPECIFIC
48834894
/// - ::UR_RESULT_ERROR_INVALID_NULL_HANDLE
48844895
/// + `NULL == hKernel`
4896+
/// - ::UR_RESULT_ERROR_INVALID_ENUMERATION
4897+
/// + `NULL != pProperties && ::UR_MEM_FLAGS_MASK & pProperties->memoryAccess`
48854898
/// - ::UR_RESULT_ERROR_INVALID_KERNEL_ARGUMENT_INDEX
48864899
UR_APIEXPORT ur_result_t UR_APICALL
48874900
urKernelSetArgMemObj(
@@ -5135,12 +5148,15 @@ typedef struct ur_queue_index_properties_t {
51355148
/// - ::UR_RESULT_ERROR_INVALID_NULL_HANDLE
51365149
/// + `NULL == hContext`
51375150
/// + `NULL == hDevice`
5151+
/// - ::UR_RESULT_ERROR_INVALID_ENUMERATION
5152+
/// + `NULL != pProperties && ::UR_QUEUE_FLAGS_MASK & pProperties->flags`
51385153
/// - ::UR_RESULT_ERROR_INVALID_NULL_POINTER
51395154
/// + `NULL == phQueue`
51405155
/// - ::UR_RESULT_ERROR_INVALID_CONTEXT
51415156
/// - ::UR_RESULT_ERROR_INVALID_DEVICE
5142-
/// - ::UR_RESULT_ERROR_INVALID_VALUE
51435157
/// - ::UR_RESULT_ERROR_INVALID_QUEUE_PROPERTIES
5158+
/// + `pProperties != NULL && pProperties->flags & UR_QUEUE_FLAG_PRIORITY_HIGH && pProperties->flags & UR_QUEUE_FLAG_PRIORITY_LOW`
5159+
/// + `pProperties != NULL && pProperties->flags & UR_QUEUE_FLAG_SUBMISSION_BATCHED && pProperties->flags & UR_QUEUE_FLAG_SUBMISSION_IMMEDIATE`
51445160
/// - ::UR_RESULT_ERROR_OUT_OF_HOST_MEMORY
51455161
/// - ::UR_RESULT_ERROR_OUT_OF_RESOURCES
51465162
UR_APIEXPORT ur_result_t UR_APICALL
@@ -7069,6 +7085,8 @@ typedef struct ur_exp_layered_image_properties_t {
70697085
/// - ::UR_RESULT_ERROR_INVALID_NULL_HANDLE
70707086
/// + `NULL == hContext`
70717087
/// + `NULL == hDevice`
7088+
/// - ::UR_RESULT_ERROR_INVALID_ENUMERATION
7089+
/// + `NULL != pUSMDesc && ::UR_USM_ADVICE_FLAGS_MASK & pUSMDesc->hints`
70727090
/// - ::UR_RESULT_ERROR_INVALID_NULL_POINTER
70737091
/// + `NULL == ppMem`
70747092
/// + `NULL == pResultPitch`

scripts/core/queue.yml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -161,8 +161,9 @@ params:
161161
returns:
162162
- $X_RESULT_ERROR_INVALID_CONTEXT
163163
- $X_RESULT_ERROR_INVALID_DEVICE
164-
- $X_RESULT_ERROR_INVALID_VALUE
165-
- $X_RESULT_ERROR_INVALID_QUEUE_PROPERTIES
164+
- $X_RESULT_ERROR_INVALID_QUEUE_PROPERTIES:
165+
- "`pProperties != NULL && pProperties->flags & UR_QUEUE_FLAG_PRIORITY_HIGH && pProperties->flags & UR_QUEUE_FLAG_PRIORITY_LOW`"
166+
- "`pProperties != NULL && pProperties->flags & UR_QUEUE_FLAG_SUBMISSION_BATCHED && pProperties->flags & UR_QUEUE_FLAG_SUBMISSION_IMMEDIATE`"
166167
- $X_RESULT_ERROR_OUT_OF_HOST_MEMORY
167168
- $X_RESULT_ERROR_OUT_OF_RESOURCES
168169
--- #--------------------------------------------------------------------------

scripts/parse_specs.py

Lines changed: 48 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ def __validate_ordinal(d):
9797
ordinal = None
9898

9999
if ordinal != d['ordinal']:
100-
raise Exception("'ordinal' invalid value: '%s'"%d['ordinal'])
100+
raise Exception("'ordinal' invalid value: '%s'"%d['ordinal'])
101101

102102
def __validate_version(d, prefix="", base_version=default_version):
103103
if 'version' in d:
@@ -333,7 +333,7 @@ def __validate_params(d, tags):
333333

334334
if item['type'].endswith("flag_t"):
335335
raise Exception(prefix+"'type' must not be '*_flag_t': %s"%item['type'])
336-
336+
337337
if type_traits.is_pointer(item['type']) and "_handle_t" in item['type'] and "[in]" in item['desc']:
338338
if not param_traits.is_range(item):
339339
raise Exception(prefix+"handle type must include a range(start, end) as part of 'desc'")
@@ -342,11 +342,11 @@ def __validate_params(d, tags):
342342
if ver < max_ver:
343343
raise Exception(prefix+"'version' must be increasing: %s"%item['version'])
344344
max_ver = ver
345-
345+
346346
def __validate_union_tag(d):
347347
if d.get('tag') is None:
348348
raise Exception(f"{d['name']} must include a 'tag' part of the union.")
349-
349+
350350
try:
351351
if 'type' not in d:
352352
raise Exception("every document must have 'type'")
@@ -466,7 +466,7 @@ def __filter_desc(d):
466466
return d
467467

468468
flt = []
469-
type = d['type']
469+
type = d['type']
470470
if 'enum' == type:
471471
for e in d['etors']:
472472
ver = float(e.get('version', default_version))
@@ -706,58 +706,54 @@ def _append(lst, key, val):
706706
if val and val not in rets[idx][key]:
707707
rets[idx][key].append(val)
708708

709+
def append_nullchecks(param, accessor: str):
710+
if type_traits.is_pointer(param['type']):
711+
_append(rets, "$X_RESULT_ERROR_INVALID_NULL_POINTER", "`NULL == %s`" % accessor)
712+
713+
elif type_traits.is_funcptr(param['type'], meta):
714+
_append(rets, "$X_RESULT_ERROR_INVALID_NULL_POINTER", "`NULL == %s`" % accessor)
715+
716+
elif type_traits.is_handle(param['type']) and not type_traits.is_ipc_handle(item['type']):
717+
_append(rets, "$X_RESULT_ERROR_INVALID_NULL_HANDLE", "`NULL == %s`" % accessor)
718+
719+
def append_enum_checks(param, accessor: str):
720+
ptypename = type_traits.base(param['type'])
721+
722+
prefix = "`"
723+
if param_traits.is_optional(item):
724+
prefix = "`NULL != %s && " % item['name']
725+
726+
if re.match(r"stype", param['name']):
727+
_append(rets, "$X_RESULT_ERROR_UNSUPPORTED_VERSION", prefix + "%s != %s`"%(re.sub(r"(\$\w)_(.*)_t.*", r"\1_STRUCTURE_TYPE_\2", typename).upper(), accessor))
728+
else:
729+
if type_traits.is_flags(param['type']) and 'bit_mask' in meta['enum'][ptypename].keys():
730+
_append(rets, "$X_RESULT_ERROR_INVALID_ENUMERATION", prefix + "%s & %s`"%(ptypename.upper()[:-2]+ "_MASK", accessor))
731+
else:
732+
_append(rets, "$X_RESULT_ERROR_INVALID_ENUMERATION", prefix + "%s < %s`"%(meta['enum'][ptypename]['max'], accessor))
733+
709734
# generate results based on parameters
710735
for item in obj['params']:
711736
if param_traits.is_nocheck(item):
712737
continue
713738

714739
if not param_traits.is_optional(item):
740+
append_nullchecks(item, item['name'])
741+
742+
if type_traits.is_enum(item['type'], meta) and not type_traits.is_pointer(item['type']):
743+
append_enum_checks(item, item['name'])
744+
745+
if type_traits.is_descriptor(item['type']) or type_traits.is_properties(item['type']):
715746
typename = type_traits.base(item['type'])
747+
# walk each entry in the desc for pointers and enums
748+
for i, m in enumerate(meta['struct'][typename]['members']):
749+
if param_traits.is_nocheck(m):
750+
continue
751+
752+
if not param_traits.is_optional(m):
753+
append_nullchecks(m, "%s->%s" % (item['name'], m['name']))
716754

717-
if type_traits.is_pointer(item['type']):
718-
_append(rets, "$X_RESULT_ERROR_INVALID_NULL_POINTER", "`NULL == %s`"%item['name'])
719-
720-
elif type_traits.is_funcptr(item['type'], meta):
721-
_append(rets, "$X_RESULT_ERROR_INVALID_NULL_POINTER", "`NULL == %s`"%item['name'])
722-
723-
elif type_traits.is_handle(item['type']) and not type_traits.is_ipc_handle(item['type']):
724-
_append(rets, "$X_RESULT_ERROR_INVALID_NULL_HANDLE", "`NULL == %s`"%item['name'])
725-
726-
elif type_traits.is_enum(item['type'], meta):
727-
if type_traits.is_flags(item['type']) and 'bit_mask' in meta['enum'][typename].keys():
728-
_append(rets, "$X_RESULT_ERROR_INVALID_ENUMERATION", "`%s & %s`"%(typename.upper()[:-2]+ "_MASK", item['name']))
729-
else:
730-
_append(rets, "$X_RESULT_ERROR_INVALID_ENUMERATION", "`%s < %s`"%(meta['enum'][typename]['max'], item['name']))
731-
732-
if type_traits.is_descriptor(item['type']):
733-
# walk each entry in the desc for pointers and enums
734-
for i, m in enumerate(meta['struct'][typename]['members']):
735-
if param_traits.is_nocheck(m):
736-
continue
737-
mtypename = type_traits.base(m['type'])
738-
739-
if type_traits.is_pointer(m['type']) and not param_traits.is_optional({'desc': m['desc']}):
740-
_append(rets,
741-
"$X_RESULT_ERROR_INVALID_NULL_POINTER",
742-
"`NULL == %s->%s`"%(item['name'], m['name']))
743-
744-
elif type_traits.is_enum(m['type'], meta):
745-
if re.match(r"stype", m['name']):
746-
_append(rets, "$X_RESULT_ERROR_UNSUPPORTED_VERSION", "`%s != %s->stype`"%(re.sub(r"(\$\w)_(.*)_t.*", r"\1_STRUCTURE_TYPE_\2", typename).upper(), item['name']))
747-
else:
748-
if type_traits.is_flags(m['type']) and 'bit_mask' in meta['enum'][mtypename].keys():
749-
_append(rets, "$X_RESULT_ERROR_INVALID_ENUMERATION", "`%s & %s->%s`"%(mtypename.upper()[:-2]+ "_MASK", item['name'], m['name']))
750-
else:
751-
_append(rets, "$X_RESULT_ERROR_INVALID_ENUMERATION", "`%s < %s->%s`"%(meta['enum'][mtypename]['max'], item['name'], m['name']))
752-
753-
elif type_traits.is_properties(item['type']):
754-
# walk each entry in the properties
755-
for i, m in enumerate(meta['struct'][typename]['members']):
756-
if param_traits.is_nocheck(m):
757-
continue
758-
if type_traits.is_enum(m['type'], meta):
759-
if re.match(r"stype", m['name']):
760-
_append(rets, "$X_RESULT_ERROR_UNSUPPORTED_VERSION", "`%s != %s->stype`"%(re.sub(r"(\$\w)_(.*)_t.*", r"\1_STRUCTURE_TYPE_\2", typename).upper(), item['name']))
755+
if type_traits.is_enum(m['type'], meta) and not type_traits.is_pointer(m['type']):
756+
append_enum_checks(m, "%s->%s" % (item['name'], m['name']))
761757

762758
# finally, append all user entries
763759
for item in obj.get('returns', []):
@@ -823,7 +819,7 @@ def _refresh_enum_meta(obj, meta):
823819
## remove the existing meta records
824820
if obj.get('class'):
825821
meta['class'][obj['class']]['enum'].remove(obj['name'])
826-
822+
827823
if meta['enum'].get(obj['name']):
828824
del meta['enum'][obj['name']]
829825
## re-generate meta
@@ -851,13 +847,13 @@ def _extend_enums(enum_extensions, specs, meta):
851847
if not _validate_ext_enum_range(extension, matching_enum):
852848
raise Exception(f"Invalid enum values.")
853849
matching_enum['etors'].extend(extension['etors'])
854-
850+
855851
_refresh_enum_meta(matching_enum, meta)
856852

857853
## Sort the etors
858854
value = -1
859855
def sort_etors(x):
860-
nonlocal value
856+
nonlocal value
861857
value = _get_etor_value(x.get('value'), value)
862858
return value
863859
matching_enum['etors'] = sorted(matching_enum['etors'], key=sort_etors)

source/loader/layers/validation/ur_valddi.cpp

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -566,6 +566,10 @@ __urdlllocal ur_result_t UR_APICALL urDevicePartition(
566566
if (NULL == pProperties) {
567567
return UR_RESULT_ERROR_INVALID_NULL_POINTER;
568568
}
569+
570+
if (NULL == pProperties->pProperties) {
571+
return UR_RESULT_ERROR_INVALID_NULL_POINTER;
572+
}
569573
}
570574

571575
ur_result_t result = pfnPartition(hDevice, pProperties, NumDevices,
@@ -739,6 +743,10 @@ __urdlllocal ur_result_t UR_APICALL urContextCreate(
739743
if (NULL == phContext) {
740744
return UR_RESULT_ERROR_INVALID_NULL_POINTER;
741745
}
746+
747+
if (NULL != pProperties && UR_CONTEXT_FLAGS_MASK & pProperties->flags) {
748+
return UR_RESULT_ERROR_INVALID_ENUMERATION;
749+
}
742750
}
743751

744752
ur_result_t result =
@@ -1616,6 +1624,10 @@ __urdlllocal ur_result_t UR_APICALL urUSMHostAlloc(
16161624
return UR_RESULT_ERROR_INVALID_NULL_POINTER;
16171625
}
16181626

1627+
if (NULL != pUSMDesc && UR_USM_ADVICE_FLAGS_MASK & pUSMDesc->hints) {
1628+
return UR_RESULT_ERROR_INVALID_ENUMERATION;
1629+
}
1630+
16191631
if (pUSMDesc && pUSMDesc->align != 0 &&
16201632
((pUSMDesc->align & (pUSMDesc->align - 1)) != 0)) {
16211633
return UR_RESULT_ERROR_INVALID_VALUE;
@@ -1663,6 +1675,10 @@ __urdlllocal ur_result_t UR_APICALL urUSMDeviceAlloc(
16631675
return UR_RESULT_ERROR_INVALID_NULL_POINTER;
16641676
}
16651677

1678+
if (NULL != pUSMDesc && UR_USM_ADVICE_FLAGS_MASK & pUSMDesc->hints) {
1679+
return UR_RESULT_ERROR_INVALID_ENUMERATION;
1680+
}
1681+
16661682
if (pUSMDesc && pUSMDesc->align != 0 &&
16671683
((pUSMDesc->align & (pUSMDesc->align - 1)) != 0)) {
16681684
return UR_RESULT_ERROR_INVALID_VALUE;
@@ -1711,6 +1727,10 @@ __urdlllocal ur_result_t UR_APICALL urUSMSharedAlloc(
17111727
return UR_RESULT_ERROR_INVALID_NULL_POINTER;
17121728
}
17131729

1730+
if (NULL != pUSMDesc && UR_USM_ADVICE_FLAGS_MASK & pUSMDesc->hints) {
1731+
return UR_RESULT_ERROR_INVALID_ENUMERATION;
1732+
}
1733+
17141734
if (pUSMDesc && pUSMDesc->align != 0 &&
17151735
((pUSMDesc->align & (pUSMDesc->align - 1)) != 0)) {
17161736
return UR_RESULT_ERROR_INVALID_VALUE;
@@ -2236,6 +2256,11 @@ __urdlllocal ur_result_t UR_APICALL urPhysicalMemCreate(
22362256
if (NULL == phPhysicalMem) {
22372257
return UR_RESULT_ERROR_INVALID_NULL_POINTER;
22382258
}
2259+
2260+
if (NULL != pProperties &&
2261+
UR_PHYSICAL_MEM_FLAGS_MASK & pProperties->flags) {
2262+
return UR_RESULT_ERROR_INVALID_ENUMERATION;
2263+
}
22392264
}
22402265

22412266
ur_result_t result =
@@ -3208,6 +3233,11 @@ __urdlllocal ur_result_t UR_APICALL urKernelSetArgMemObj(
32083233
if (NULL == hKernel) {
32093234
return UR_RESULT_ERROR_INVALID_NULL_HANDLE;
32103235
}
3236+
3237+
if (NULL != pProperties &&
3238+
UR_MEM_FLAGS_MASK & pProperties->memoryAccess) {
3239+
return UR_RESULT_ERROR_INVALID_ENUMERATION;
3240+
}
32113241
}
32123242

32133243
ur_result_t result =
@@ -3398,6 +3428,22 @@ __urdlllocal ur_result_t UR_APICALL urQueueCreate(
33983428
if (NULL == phQueue) {
33993429
return UR_RESULT_ERROR_INVALID_NULL_POINTER;
34003430
}
3431+
3432+
if (NULL != pProperties && UR_QUEUE_FLAGS_MASK & pProperties->flags) {
3433+
return UR_RESULT_ERROR_INVALID_ENUMERATION;
3434+
}
3435+
3436+
if (pProperties != NULL &&
3437+
pProperties->flags & UR_QUEUE_FLAG_PRIORITY_HIGH &&
3438+
pProperties->flags & UR_QUEUE_FLAG_PRIORITY_LOW) {
3439+
return UR_RESULT_ERROR_INVALID_QUEUE_PROPERTIES;
3440+
}
3441+
3442+
if (pProperties != NULL &&
3443+
pProperties->flags & UR_QUEUE_FLAG_SUBMISSION_BATCHED &&
3444+
pProperties->flags & UR_QUEUE_FLAG_SUBMISSION_IMMEDIATE) {
3445+
return UR_RESULT_ERROR_INVALID_QUEUE_PROPERTIES;
3446+
}
34013447
}
34023448

34033449
ur_result_t result = pfnCreate(hContext, hDevice, pProperties, phQueue);
@@ -5556,6 +5602,10 @@ __urdlllocal ur_result_t UR_APICALL urUSMPitchedAllocExp(
55565602
return UR_RESULT_ERROR_INVALID_NULL_POINTER;
55575603
}
55585604

5605+
if (NULL != pUSMDesc && UR_USM_ADVICE_FLAGS_MASK & pUSMDesc->hints) {
5606+
return UR_RESULT_ERROR_INVALID_ENUMERATION;
5607+
}
5608+
55595609
if (pUSMDesc && pUSMDesc->align != 0 &&
55605610
((pUSMDesc->align & (pUSMDesc->align - 1)) != 0)) {
55615611
return UR_RESULT_ERROR_INVALID_VALUE;

0 commit comments

Comments
 (0)