Skip to content

Commit 0d992a0

Browse files
authored
Merge pull request #696 from veselypeta/petr/538/handle-unions
[UR] print unions correctly in ur_params
2 parents db40d34 + d66a850 commit 0d992a0

File tree

10 files changed

+217
-50
lines changed

10 files changed

+217
-50
lines changed

include/ur.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -848,15 +848,15 @@ class ur_device_partition_value_t(Structure):
848848
("count", c_ulong), ## [in] Number of compute units in a sub-device when partitioning with
849849
## ::UR_DEVICE_PARTITION_BY_COUNTS.
850850
("affinity_domain", ur_device_affinity_domain_flags_t) ## [in] The affinity domain to partition for when partitioning with
851-
## $UR_DEVICE_PARTITION_BY_AFFINITY_DOMAIN.
851+
## ::UR_DEVICE_PARTITION_BY_AFFINITY_DOMAIN.
852852
]
853853

854854
###############################################################################
855855
## @brief Device partition property
856856
class ur_device_partition_property_t(Structure):
857857
_fields_ = [
858858
("type", ur_device_partition_t), ## [in] The partitioning type to be used.
859-
("value", ur_device_partition_value_t) ## [in] The partitioning value.
859+
("value", ur_device_partition_value_t) ## [in][tagged_by(type)] The partitioning value.
860860
]
861861

862862
###############################################################################
@@ -1579,7 +1579,7 @@ class ur_program_metadata_t(Structure):
15791579
("type", ur_program_metadata_type_t), ## [in] the type of metadata value.
15801580
("size", c_size_t), ## [in] size in bytes of the data pointed to by value.pData, or 0 when
15811581
## value size is less than 64-bits and is stored directly in value.data.
1582-
("value", ur_program_metadata_value_t) ## [in] the metadata value storage.
1582+
("value", ur_program_metadata_value_t) ## [in][tagged_by(type)] the metadata value storage.
15831583
]
15841584

15851585
###############################################################################

include/ur_api.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1348,15 +1348,15 @@ typedef union ur_device_partition_value_t {
13481348
uint32_t count; ///< [in] Number of compute units in a sub-device when partitioning with
13491349
///< ::UR_DEVICE_PARTITION_BY_COUNTS.
13501350
ur_device_affinity_domain_flags_t affinity_domain; ///< [in] The affinity domain to partition for when partitioning with
1351-
///< $UR_DEVICE_PARTITION_BY_AFFINITY_DOMAIN.
1351+
///< ::UR_DEVICE_PARTITION_BY_AFFINITY_DOMAIN.
13521352

13531353
} ur_device_partition_value_t;
13541354

13551355
///////////////////////////////////////////////////////////////////////////////
13561356
/// @brief Device partition property
13571357
typedef struct ur_device_partition_property_t {
13581358
ur_device_partition_t type; ///< [in] The partitioning type to be used.
1359-
ur_device_partition_value_t value; ///< [in] The partitioning value.
1359+
ur_device_partition_value_t value; ///< [in][tagged_by(type)] The partitioning value.
13601360

13611361
} ur_device_partition_property_t;
13621362

@@ -3553,11 +3553,11 @@ typedef union ur_program_metadata_value_t {
35533553
///////////////////////////////////////////////////////////////////////////////
35543554
/// @brief Program metadata property.
35553555
typedef struct ur_program_metadata_t {
3556-
char *pName; ///< [in] null-terminated metadata name.
3556+
const char *pName; ///< [in] null-terminated metadata name.
35573557
ur_program_metadata_type_t type; ///< [in] the type of metadata value.
35583558
size_t size; ///< [in] size in bytes of the data pointed to by value.pData, or 0 when
35593559
///< value size is less than 64-bits and is stored directly in value.data.
3560-
ur_program_metadata_value_t value; ///< [in] the metadata value storage.
3560+
ur_program_metadata_value_t value; ///< [in][tagged_by(type)] the metadata value storage.
35613561

35623562
} ur_program_metadata_t;
35633563

scripts/YaML.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -454,6 +454,8 @@ class ur_name_flags_v(IntEnum):
454454
- `name` must be a unique ISO-C standard identifier, start with `$` tag, be snake_case and end with `_t`
455455
+ The special-case descriptor struct should always end with `_desc_t`
456456
+ The special-case property struct should always end with `_properties_t`
457+
* A union requires the following
458+
- `tag` is a reference to an enum type that will be used to describe which field of the union to access.
457459
* A struct|union may take the following optional scalar fields: {`class`, `base`, `condition`, `ordinal`, `version`}
458460
- `class` will be used to scope the struct|union declaration within the specified C++ class
459461
- `base` will be used as the base type of the structure
@@ -468,6 +470,8 @@ class ur_name_flags_v(IntEnum):
468470
- `out` is used for members that are write-only; if the member is a pointer, then the memory being pointed to is also write-only
469471
- `in,out` is used for members that are both read and write; typically this is used for pointers to other data structures that contain both read and write members
470472
- `nocheck` is used to specify that no additional validation checks will be generated.
473+
+ `desc` must also include the following annotation when describing a union: {`"tagged_by(param)"`}
474+
- `tagged_by` is used to specify which parameter will be used as the tag for accessing the union.
471475
+ `desc` may include one the following annotations: {`"[optional]"`, `"[typename(typeVarName, sizeVarName)]"`}
472476
- `optional` is used for members that are pointers where it is legal for the value to be `nullptr`
473477
- `typename` is used to denote the type enum for params that are opaque pointers to values of tagged data types.
@@ -477,6 +481,7 @@ class ur_name_flags_v(IntEnum):
477481
+ `init` will be used to initialize the C++ struct|union member's value
478482
+ `init` must be an ISO-C standard identifier or literal
479483
+ `version` will be used to define the minimum API version in which the member will appear; `default="1.0"` This will also affect the order in which the member appears within the struct|union.
484+
+ `tag` applies only to unions and refers to a value for when this member can be accessed.
480485
* A struct|union may take the following optional field which can be a scalar, a sequence of scalars or scalars to sequences: {`details`}
481486
- `details` will be used as the struct|union's detailed comment
482487

scripts/core/device.yml

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -518,16 +518,20 @@ type: union
518518
desc: "Device partition value."
519519
name: $x_device_partition_value_t
520520
class: $xDevice
521+
tag: $x_device_partition_t
521522
members:
522523
- type: uint32_t
523524
name: equally
524525
desc: "[in] Number of compute units per sub-device when partitioning with $X_DEVICE_PARTITION_EQUALLY."
526+
tag: $X_DEVICE_PARTITION_EQUALLY
525527
- type: uint32_t
526528
name: count
527529
desc: "[in] Number of compute units in a sub-device when partitioning with $X_DEVICE_PARTITION_BY_COUNTS."
530+
tag: $X_DEVICE_PARTITION_BY_COUNTS
528531
- type: $x_device_affinity_domain_flags_t
529532
name: affinity_domain
530-
desc: "[in] The affinity domain to partition for when partitioning with $UR_DEVICE_PARTITION_BY_AFFINITY_DOMAIN."
533+
desc: "[in] The affinity domain to partition for when partitioning with $X_DEVICE_PARTITION_BY_AFFINITY_DOMAIN."
534+
tag: $X_DEVICE_PARTITION_BY_AFFINITY_DOMAIN
531535
--- #--------------------------------------------------------------------------
532536
type: struct
533537
desc: "Device partition property"
@@ -539,7 +543,7 @@ members:
539543
desc: "[in] The partitioning type to be used."
540544
- type: $x_device_partition_value_t
541545
name: value
542-
desc: "[in] The partitioning value."
546+
desc: "[in][tagged_by(type)] The partitioning value."
543547
--- #--------------------------------------------------------------------------
544548
type: struct
545549
desc: "Device Partition Properties"

scripts/core/program.yml

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,26 +29,31 @@ type: union
2929
desc: "Program metadata value union."
3030
class: $xProgram
3131
name: $x_program_metadata_value_t
32+
tag: $x_program_metadata_type_t
3233
members:
3334
- type: uint32_t
3435
name: data32
3536
desc: "[in] inline storage for the 32-bit data, type $X_PROGRAM_METADATA_TYPE_UINT32."
37+
tag: $X_PROGRAM_METADATA_TYPE_UINT32
3638
- type: uint64_t
3739
name: data64
3840
desc: "[in] inline storage for the 64-bit data, type $X_PROGRAM_METADATA_TYPE_UINT64."
41+
tag: $X_PROGRAM_METADATA_TYPE_UINT64
3942
- type: char*
4043
name: pString
4144
desc: "[in] pointer to null-terminated string data, type $X_PROGRAM_METADATA_TYPE_STRING."
45+
tag: $X_PROGRAM_METADATA_TYPE_STRING
4246
- type: void*
4347
name: pData
4448
desc: "[in] pointer to binary data, type $X_PROGRAM_METADATA_TYPE_BYTE_ARRAY."
49+
tag: $X_PROGRAM_METADATA_TYPE_BYTE_ARRAY
4550
--- #--------------------------------------------------------------------------
4651
type: struct
4752
desc: "Program metadata property."
4853
class: $xProgram
4954
name: $x_program_metadata_t
5055
members:
51-
- type: char*
56+
- type: const char*
5257
name: pName
5358
desc: "[in] null-terminated metadata name."
5459
- type: $x_program_metadata_type_t
@@ -59,7 +64,7 @@ members:
5964
desc: "[in] size in bytes of the data pointed to by value.pData, or 0 when value size is less than 64-bits and is stored directly in value.data."
6065
- type: $x_program_metadata_value_t
6166
name: value
62-
desc: "[in] the metadata value storage."
67+
desc: "[in][tagged_by(type)] the metadata value storage."
6368
--- #--------------------------------------------------------------------------
6469
type: struct
6570
desc: "Program creation properties."

scripts/parse_specs.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -292,6 +292,9 @@ def __validate_members(d, tags):
292292
if item['type'].endswith("flag_t"):
293293
raise Exception(prefix+"'type' must not be '*_flag_t': %s"%item['type'])
294294

295+
if d['type'] == 'union'and item.get('tag') is None:
296+
raise Exception(prefix + f"union member {item['name']} must include a 'tag' annotation")
297+
295298
ver = __validate_version(item, prefix=prefix, base_version=d_ver)
296299
if ver < max_ver:
297300
raise Exception(prefix+"'version' must be increasing: %s"%item['version'])
@@ -339,7 +342,11 @@ def __validate_params(d, tags):
339342
if ver < max_ver:
340343
raise Exception(prefix+"'version' must be increasing: %s"%item['version'])
341344
max_ver = ver
342-
345+
346+
def __validate_union_tag(d):
347+
if d.get('tag') is None:
348+
raise Exception(f"{d['name']} must include a 'tag' part of the union.")
349+
343350
try:
344351
if 'type' not in d:
345352
raise Exception("every document must have 'type'")
@@ -401,6 +408,8 @@ def __validate_params(d, tags):
401408
if ('desc' not in d) or ('name' not in d):
402409
raise Exception("'%s' requires the following scalar fields: {`desc`, `name`}"%d['type'])
403410

411+
if d['type'] == 'union':
412+
__validate_union_tag(d)
404413
__validate_type(d, 'name', tags)
405414
__validate_base(d)
406415
__validate_members(d, tags)

scripts/templates/helper.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -319,6 +319,7 @@ class param_traits:
319319
RE_RANGE = r".*\[range\((.+),\s*(.+)\)\][\S\s]*"
320320
RE_RELEASE = r".*\[release\].*"
321321
RE_TYPENAME = r".*\[typename\((.+),\s(.+)\)\].*"
322+
RE_TAGGED = r".*\[tagged_by\((.+)\)].*"
322323

323324
@classmethod
324325
def is_mbz(cls, item):
@@ -369,6 +370,20 @@ def is_range(cls, item):
369370
except:
370371
return False
371372

373+
@classmethod
374+
def is_tagged(cls, item):
375+
try:
376+
return True if re.match(cls.RE_TAGGED, item['desc']) else False
377+
except:
378+
return False
379+
380+
@classmethod
381+
def tagged_member(cls, item):
382+
try:
383+
return re.sub(cls.RE_TAGGED, r"\1", item['desc'])
384+
except:
385+
return None
386+
372387
@classmethod
373388
def range_start(cls, item):
374389
try:

scripts/templates/params.hpp.mako

Lines changed: 52 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,16 @@ from templates import helper as th
4141
%endif
4242
</%def>
4343

44+
<%
45+
def findUnionTag(_union):
46+
tag = [_obj for _s in specs for _obj in _s['objects'] if _obj['name'] == _union['tag']]
47+
return tag[0] if len(tag) > 0 else None
48+
49+
def findMemberType(_item):
50+
query = [_o for _s in specs for _o in _s['objects'] if _o['name'] == _item['type']]
51+
return query[0] if len(query) > 0 else None
52+
%>
53+
4454
<%def name="line(item, n, params, params_dict)">
4555
<%
4656
iname = th._get_param_name(n, tags, item)
@@ -57,7 +67,7 @@ from templates import helper as th
5767
%if n != 0:
5868
os << ", ";
5969
%endif
60-
## can't iterate over 'void *'...
70+
## can't iterate over 'void *'...
6171
%if th.param_traits.is_range(item) and "void*" not in itype:
6272
os << ".${iname} = {";
6373
for (size_t i = ${th.param_traits.range_start(item)}; ${deref}(params${access}${pname}) != NULL && i < ${deref}params${access}${prefix + th.param_traits.range_end(item)}; ++i) {
@@ -69,6 +79,9 @@ from templates import helper as th
6979
</%call>
7080
}
7181
os << "}";
82+
%elif findMemberType(item) is not None and findMemberType(item)['type'] == "union":
83+
os << ".${iname} = ";
84+
${x}_params::serializeUnion(os, ${deref}(params${access}${item['name']}), params${access}${th.param_traits.tagged_member(item)});
7285
%elif typename is not None:
7386
os << ".${iname} = ";
7487
${x}_params::serializeTagged(os, ${deref}(params${access}${pname}), ${deref}(params${access}${prefix}${typename}), ${deref}(params${access}${prefix}${typename_size}));
@@ -96,6 +109,15 @@ template <typename T> inline void serializeTagged(std::ostream &os, const void *
96109
%endif
97110
%endif
98111

112+
%if re.match(r"union", obj['type']) and obj['name']:
113+
<% tag = [_obj for _s in specs for _obj in _s['objects'] if _obj['name'] == obj['tag']][0] %>
114+
inline void serializeUnion(
115+
std::ostream &os,
116+
const ${obj['type']} ${th.make_type_name(n, tags, obj)} params,
117+
const ${tag['type']} ${th.make_type_name(n, tags, tag)} tag
118+
);
119+
%endif
120+
99121

100122
%if th.type_traits.is_flags(obj['name']):
101123
template<> inline void serializeFlag<${th.make_enum_name(n, tags, obj)}>(std::ostream &os, uint32_t flag);
@@ -109,7 +131,7 @@ template <typename T> inline void serializeTagged(std::ostream &os, const void *
109131
## ENUM #######################################################################
110132
%if re.match(r"enum", obj['type']):
111133
inline std::ostream &operator<<(std::ostream &os, enum ${th.make_enum_name(n, tags, obj)} value);
112-
%elif re.match(r"struct|union", obj['type']):
134+
%elif re.match(r"struct", obj['type']):
113135
inline std::ostream &operator<<(std::ostream &os, const ${obj['type']} ${th.make_type_name(n, tags, obj)} params);
114136
%endif
115137
%endfor # obj in spec['objects']
@@ -260,7 +282,7 @@ inline void serializeFlag<${th.make_enum_name(n, tags, obj)}>(std::ostream &os,
260282
} // namespace ${x}_params
261283
%endif
262284
## STRUCT/UNION ###############################################################
263-
%elif re.match(r"struct|union", obj['type']):
285+
%elif re.match(r"struct", obj['type']):
264286
inline std::ostream &operator<<(std::ostream &os, const ${obj['type']} ${th.make_type_name(n, tags, obj)} params) {
265287
os << "(${obj['type']} ${th.make_type_name(n, tags, obj)}){";
266288
<%
@@ -277,6 +299,33 @@ inline std::ostream &operator<<(std::ostream &os, const ${obj['type']} ${th.make
277299
os << "}";
278300
return os;
279301
}
302+
%elif re.match(r"union", obj['type']) and obj['name']:
303+
<% tag = findUnionTag(obj) %>
304+
inline void ${x}_params::serializeUnion(
305+
std::ostream &os,
306+
const ${obj['type']} ${th.make_type_name(n, tags, obj)} params,
307+
const ${tag['type']} ${th.make_type_name(n, tags, tag)} tag
308+
){
309+
os << "(${obj['type']} ${th.make_type_name(n, tags, obj)}){";
310+
<%
311+
params_dict = dict()
312+
for item in obj['members']:
313+
iname = th._get_param_name(n, tags, item)
314+
itype = th._get_type_name(n, tags, obj, item)
315+
params_dict[iname] = itype
316+
%>
317+
switch(tag){
318+
%for mem in obj['members']:
319+
case ${th.subt(n, tags, mem['tag'])}:
320+
${line(mem, 0, False, params_dict)}
321+
break;
322+
%endfor
323+
default:
324+
os << "<unknown>";
325+
break;
326+
}
327+
os << "}";
328+
}
280329
%endif
281330
%endfor # obj in spec['objects']
282331
%endfor

0 commit comments

Comments
 (0)