Skip to content

Commit 9a9d2e8

Browse files
authored
Merge pull request #870 from veselypeta/petr/860/support-raw-c-arrays
[UR] Support C style arrays
2 parents e40a321 + e9aa0f4 commit 9a9d2e8

File tree

7 files changed

+86
-36
lines changed

7 files changed

+86
-36
lines changed

include/ur.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2228,9 +2228,7 @@ class ur_exp_sampler_addr_modes_t(Structure):
22282228
("stype", ur_structure_type_t), ## [in] type of this structure, must be
22292229
## ::UR_STRUCTURE_TYPE_EXP_SAMPLER_ADDR_MODES
22302230
("pNext", c_void_p), ## [in,out][optional] pointer to extension-specific structure
2231-
("addrModeX", ur_sampler_addressing_mode_t), ## [in] Specify the addressing mode of the x-dimension.
2232-
("addrModeY", ur_sampler_addressing_mode_t), ## [in] Specify the addressing mode of the y-dimension.
2233-
("addrModeZ", ur_sampler_addressing_mode_t) ## [in] Specify the addressing mode of the z-dimension.
2231+
("addrModes", ur_sampler_addressing_mode_t * 3) ## [in] Specify the address mode of the sampler per dimension
22342232
]
22352233

22362234
###############################################################################

include/ur_api.h

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7047,12 +7047,10 @@ typedef struct ur_exp_sampler_mip_properties_t {
70477047
/// - Specify these properties in ::urSamplerCreate via ::ur_sampler_desc_t
70487048
/// as part of a `pNext` chain.
70497049
typedef struct ur_exp_sampler_addr_modes_t {
7050-
ur_structure_type_t stype; ///< [in] type of this structure, must be
7051-
///< ::UR_STRUCTURE_TYPE_EXP_SAMPLER_ADDR_MODES
7052-
void *pNext; ///< [in,out][optional] pointer to extension-specific structure
7053-
ur_sampler_addressing_mode_t addrModeX; ///< [in] Specify the addressing mode of the x-dimension.
7054-
ur_sampler_addressing_mode_t addrModeY; ///< [in] Specify the addressing mode of the y-dimension.
7055-
ur_sampler_addressing_mode_t addrModeZ; ///< [in] Specify the addressing mode of the z-dimension.
7050+
ur_structure_type_t stype; ///< [in] type of this structure, must be
7051+
///< ::UR_STRUCTURE_TYPE_EXP_SAMPLER_ADDR_MODES
7052+
void *pNext; ///< [in,out][optional] pointer to extension-specific structure
7053+
ur_sampler_addressing_mode_t addrModes[3]; ///< [in] Specify the address mode of the sampler per dimension
70567054

70577055
} ur_exp_sampler_addr_modes_t;
70587056

scripts/core/exp-bindless-images.yml

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -187,15 +187,9 @@ class: $xBindlessImages
187187
name: $x_exp_sampler_addr_modes_t
188188
base: $x_base_properties_t
189189
members:
190-
- type: $x_sampler_addressing_mode_t
191-
name: addrModeX
192-
desc: "[in] Specify the addressing mode of the x-dimension."
193-
- type: $x_sampler_addressing_mode_t
194-
name: addrModeY
195-
desc: "[in] Specify the addressing mode of the y-dimension."
196-
- type: $x_sampler_addressing_mode_t
197-
name: addrModeZ
198-
desc: "[in] Specify the addressing mode of the z-dimension."
190+
- type: $x_sampler_addressing_mode_t[3]
191+
name: addrModes
192+
desc: "[in] Specify the address mode of the sampler per dimension"
199193
--- #--------------------------------------------------------------------------
200194
type: struct
201195
desc: "Describes an interop memory resource descriptor"

scripts/templates/helper.py

Lines changed: 33 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,7 @@ class type_traits:
105105
RE_DESC = r"(.*)desc_t.*"
106106
RE_PROPS = r"(.*)properties_t.*"
107107
RE_FLAGS = r"(.*)flags_t"
108+
RE_ARRAY = r"(.*)\[([1-9][0-9]*)\]"
108109

109110
@staticmethod
110111
def base(name):
@@ -217,6 +218,29 @@ def find_class_name(name, meta):
217218
except:
218219
return None
219220

221+
@classmethod
222+
def is_array(cls, name):
223+
try:
224+
return True if re.match(cls.RE_ARRAY, name) else False
225+
except:
226+
return False
227+
228+
@classmethod
229+
def get_array_length(cls, name):
230+
if not cls.is_array(name):
231+
raise Exception("Cannot find array length of non-array type.")
232+
233+
match = re.match(cls.RE_ARRAY, name)
234+
return match.groups()[1]
235+
236+
@classmethod
237+
def get_array_element_type(cls, name):
238+
if not cls.is_array(name):
239+
raise Exception("Cannot find array type of non-array type.")
240+
241+
match = re.match(cls.RE_ARRAY, name)
242+
return match.groups()[0]
243+
220244
"""
221245
Extracts traits from a value name
222246
"""
@@ -729,7 +753,10 @@ def make_etor_lines(namespace, tags, obj, py=False, meta=None):
729753
returns c/c++ name of any type
730754
"""
731755
def _get_type_name(namespace, tags, obj, item):
732-
name = subt(namespace, tags, item['type'],)
756+
type = item['type']
757+
if type_traits.is_array(type):
758+
type = type_traits.get_array_element_type(type)
759+
name = subt(namespace, tags, type,)
733760
return name
734761

735762
"""
@@ -763,9 +790,9 @@ def get_ctype_name(namespace, tags, item):
763790
while type_traits.is_pointer(name):
764791
name = "POINTER(%s)"%_remove_ptr(name)
765792

766-
if 'name' in item and value_traits.is_array(item['name']):
767-
length = subt(namespace, tags, value_traits.get_array_length(item['name']))
768-
name = "%s * %s"%(name, length)
793+
if 'name' in item and type_traits.is_array(item['type']):
794+
length = subt(namespace, tags, type_traits.get_array_length(item['type']))
795+
name = "%s * %s"%(type_traits.get_array_element_type(name), length)
769796

770797
return name
771798

@@ -804,7 +831,8 @@ def make_member_lines(namespace, tags, obj, prefix="", py=False, meta=None):
804831
delim = "," if i < (len(obj['members'])-1) else ""
805832
prologue = "(\"%s\", %s)%s"%(name, tname, delim)
806833
else:
807-
prologue = "%s %s;"%(tname, name)
834+
array_suffix = f"[{type_traits.get_array_length(item['type'])}]" if type_traits.is_array(item['type']) else ""
835+
prologue = "%s %s %s;"%(tname, name, array_suffix)
808836

809837
comment_style = "##" if py else "///<"
810838
ws_count = 64 if py else 48

scripts/templates/params.hpp.mako

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,17 @@ def findMemberType(_item):
8484
%elif findMemberType(item) is not None and findMemberType(item)['type'] == "union":
8585
os << ".${iname} = ";
8686
${x}_params::serializeUnion(os, ${deref}(params${access}${item['name']}), params${access}${th.param_traits.tagged_member(item)});
87+
%elif th.type_traits.is_array(item['type']):
88+
os << ".${iname} = {";
89+
for(auto i = 0; i < ${th.type_traits.get_array_length(item['type'])}; i++){
90+
if(i != 0){
91+
os << ", ";
92+
}
93+
<%call expr="member(iname, itype, True)">
94+
${deref}(params${access}${item['name']}[i])
95+
</%call>
96+
}
97+
os << "}";
8798
%elif typename is not None:
8899
os << ".${iname} = ";
89100
${x}_params::serializeTagged(os, ${deref}(params${access}${pname}), ${deref}(params${access}${prefix}${typename}), ${deref}(params${access}${prefix}${typename_size}));

source/common/ur_params.hpp

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -9937,19 +9937,15 @@ operator<<(std::ostream &os, const struct ur_exp_sampler_addr_modes_t params) {
99379937
ur_params::serializeStruct(os, (params.pNext));
99389938

99399939
os << ", ";
9940-
os << ".addrModeX = ";
9941-
9942-
os << (params.addrModeX);
9943-
9944-
os << ", ";
9945-
os << ".addrModeY = ";
9946-
9947-
os << (params.addrModeY);
9948-
9949-
os << ", ";
9950-
os << ".addrModeZ = ";
9940+
os << ".addrModes = {";
9941+
for (auto i = 0; i < 3; i++) {
9942+
if (i != 0) {
9943+
os << ", ";
9944+
}
99519945

9952-
os << (params.addrModeZ);
9946+
os << (params.addrModes[i]);
9947+
}
9948+
os << "}";
99539949

99549950
os << "}";
99559951
return os;

test/unit/utils/params.cpp

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -367,6 +367,31 @@ struct UrDevicePartitionPropertyTest {
367367
ur_device_partition_property_t prop;
368368
};
369369

370+
struct UrSamplerAddressModesTest {
371+
UrSamplerAddressModesTest() {
372+
prop.addrModes[0] = UR_SAMPLER_ADDRESSING_MODE_CLAMP;
373+
prop.addrModes[1] = UR_SAMPLER_ADDRESSING_MODE_MIRRORED_REPEAT;
374+
prop.addrModes[2] = UR_SAMPLER_ADDRESSING_MODE_REPEAT;
375+
prop.pNext = nullptr;
376+
prop.stype = UR_STRUCTURE_TYPE_EXP_SAMPLER_ADDR_MODES;
377+
}
378+
ur_exp_sampler_addr_modes_t &get_struct() { return prop; }
379+
const char *get_expected() {
380+
return "\\(struct ur_exp_sampler_addr_modes_t\\)"
381+
"\\{"
382+
".stype = UR_STRUCTURE_TYPE_EXP_SAMPLER_ADDR_MODES, "
383+
".pNext = nullptr, "
384+
".addrModes = \\{"
385+
"UR_SAMPLER_ADDRESSING_MODE_CLAMP, "
386+
"UR_SAMPLER_ADDRESSING_MODE_MIRRORED_REPEAT, "
387+
"UR_SAMPLER_ADDRESSING_MODE_REPEAT"
388+
"\\}"
389+
"\\}";
390+
}
391+
392+
ur_exp_sampler_addr_modes_t prop;
393+
};
394+
370395
using testing::Types;
371396
typedef Types<UrLoaderInitParamsNoFlags, UrLoaderInitParamsInvalidFlags,
372397
UrUsmHostAllocParamsEmpty, UrPlatformGetEmptyArray,
@@ -376,7 +401,7 @@ typedef Types<UrLoaderInitParamsNoFlags, UrLoaderInitParamsInvalidFlags,
376401
UrDeviceGetInfoParamsPartitionArray,
377402
UrContextGetInfoParamsDevicesArray,
378403
UrDeviceGetInfoParamsInvalidSize, UrProgramMetadataTest,
379-
UrDevicePartitionPropertyTest>
404+
UrDevicePartitionPropertyTest, UrSamplerAddressModesTest>
380405
Implementations;
381406

382407
using ::testing::MatchesRegex;

0 commit comments

Comments
 (0)