Skip to content

Commit cf49e2e

Browse files
authored
[UR] Cleanup Python linter issues add type hints (#18694)
Continue cleanup of the Python generator scripts, resolving Pyright issues and adding typing hints. This should enable easier refactoring of the scripts in future commits.
1 parent ba7e8d5 commit cf49e2e

File tree

6 files changed

+117
-96
lines changed

6 files changed

+117
-96
lines changed

unified-runtime/scripts/generate_code.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
66

77
import os
8-
import re
98
import util
109

1110

@@ -428,7 +427,7 @@ def _mako_interface_loader_api(
428427
template = f"ur_interface_loader.{ext}.mako"
429428
fin = os.path.join(templates_dir, template)
430429

431-
name = f"ur_interface_loader"
430+
name = "ur_interface_loader"
432431

433432
filename = f"{name}.{ext}"
434433
fout = os.path.join(dstpath, filename)

unified-runtime/scripts/generate_ids.py

Lines changed: 22 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,9 @@
66

77
"""Generates a unique id for each spec function that doesn't have it."""
88

9-
from fileinput import FileInput
9+
from typing import Callable, List
10+
from yaml.dumper import Dumper
11+
from yaml.representer import Node
1012
import util
1113
import yaml
1214
import re
@@ -15,46 +17,48 @@
1517
ENUM_NAME = "$x_function_t"
1618

1719

18-
class quoted(str):
20+
class Quoted(str):
1921
pass
2022

2123

22-
def quoted_presenter(dumper, data):
24+
def quoted_presenter(dumper: Dumper, data: Quoted) -> Node:
2325
return dumper.represent_scalar("tag:yaml.org,2002:str", data, style='"')
2426

2527

26-
def get_registry_header():
28+
def get_registry_header() -> dict:
2729
return {
2830
"type": "header",
29-
"desc": quoted("Intel $OneApi Unified Runtime function registry"),
30-
"ordinal": quoted(-1),
31+
"desc": Quoted("Intel $OneApi Unified Runtime function registry"),
32+
"ordinal": Quoted(-1),
3133
}
3234

3335

34-
def write_registry(data, path):
36+
def write_registry(data: List[dict], path: str) -> None:
3537
with open(path, "w") as fout:
36-
yaml.add_representer(quoted, quoted_presenter)
38+
yaml.add_representer(Quoted, quoted_presenter)
3739
yaml.dump_all(
3840
data, fout, default_flow_style=False, sort_keys=False, explicit_start=True
3941
)
4042

4143

42-
def find_type_in_specs(specs, type):
44+
def find_type_in_specs(specs: List[dict], type: str) -> dict:
4345
return [obj for s in specs for obj in s["objects"] if obj["name"] == type][0]
4446

4547

46-
def get_max_enum(enum):
48+
def get_max_enum(enum: dict) -> int:
4749
return int(max(enum["etors"], key=lambda x: int(x["value"]))["value"])
4850

4951

50-
def copy_and_strip_prefix_from_enums(enum, prefix):
52+
def copy_and_strip_prefix_from_enums(enum: dict, prefix: str) -> dict:
5153
cpy = copy.deepcopy(enum)
5254
for etor in cpy["etors"]:
5355
etor["name"] = etor["name"][len(prefix) :]
5456
return cpy
5557

5658

57-
def generate_function_type(specs, meta, update_fn) -> dict:
59+
def generate_function_type(
60+
specs: List[dict], meta: dict, update_fn: Callable[[dict, dict], None]
61+
) -> dict:
5862
existing_function_type = find_type_in_specs(specs, "$x_function_t")
5963
existing_etors = {
6064
etor["name"]: etor["value"] for etor in existing_function_type["etors"]
@@ -84,7 +88,9 @@ def generate_function_type(specs, meta, update_fn) -> dict:
8488
return copy_and_strip_prefix_from_enums(existing_function_type, "$X_FUNCTION_")
8589

8690

87-
def generate_structure_type(specs, meta, refresh_fn) -> dict:
91+
def generate_structure_type(
92+
specs: List[dict], meta: dict, refresh_fn: Callable[[dict, dict], None]
93+
) -> dict:
8894
structure_type = find_type_in_specs(specs, "$x_structure_type_t")
8995
extended_structs = [
9096
obj
@@ -124,7 +130,9 @@ def generate_structure_type(specs, meta, refresh_fn) -> dict:
124130
return copy_and_strip_prefix_from_enums(structure_type, "$X_STRUCTURE_TYPE_")
125131

126132

127-
def generate_registry(path, specs, meta, update_fn):
133+
def generate_registry(
134+
path: str, specs: List[dict], meta: dict, update_fn: Callable[[dict, dict], None]
135+
) -> None:
128136
try:
129137
write_registry(
130138
[

unified-runtime/scripts/parse_specs.py

Lines changed: 49 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from templates.helper import param_traits, type_traits, value_traits
1515
import ctypes
1616
import itertools
17-
from typing import Dict, List, Optional
17+
from typing import Dict, List, Optional, Union
1818
from version import Version
1919

2020

@@ -86,7 +86,9 @@ def _validate_doc(f, d, tags, line_num, meta):
8686
validate documents meet some basic (easily detectable) requirements of code
8787
generation
8888
"""
89-
is_iso = lambda x: re.match(r"[_a-zA-Z][_a-zA-Z0-9]{0,30}", x)
89+
90+
def is_iso(x):
91+
return re.match(r"[_a-zA-Z][_a-zA-Z0-9]{0,30}", x)
9092

9193
def __validate_ordinal(d):
9294
if "ordinal" in d:
@@ -95,7 +97,7 @@ def __validate_ordinal(d):
9597

9698
try:
9799
ordinal = str(int(d["ordinal"]))
98-
except:
100+
except BaseException:
99101
ordinal = None
100102

101103
if ordinal != d["ordinal"]:
@@ -110,7 +112,7 @@ def __validate_version(d, prefix="", base_version=default_version):
110112

111113
try:
112114
version = str(d["version"])
113-
except:
115+
except BaseException:
114116
version = None
115117

116118
if version != d["version"]:
@@ -124,7 +126,7 @@ def __validate_tag(d, key, tags, case):
124126
return x
125127
return None
126128

127-
def __validate_desc(desc):
129+
def __validate_desc(desc: Union[dict, str], prefix: str):
128130
if isinstance(desc, dict):
129131
for k, v in desc.items():
130132
if not isinstance(k, str):
@@ -134,7 +136,7 @@ def __validate_desc(desc):
134136

135137
try:
136138
version = str(k)
137-
except:
139+
except BaseException:
138140
version = None
139141

140142
if version != k:
@@ -244,34 +246,36 @@ def __validate_etors(d, tags):
244246
for i, item in enumerate(d["etors"]):
245247
prefix = "'etors'[%s] " % i
246248
if not isinstance(item, dict):
247-
raise Exception(prefix + "must be a mapping: '%s'" % (i, type(item)))
249+
raise Exception(
250+
prefix + "must be a mapping: %d, '%s'" % (i, type(item))
251+
)
248252

249253
if ("desc" not in item) or ("name" not in item):
250254
raise Exception(
251255
prefix + "requires the following scalar fields: {`desc`, `name`}"
252256
)
253257

254-
if "extend" in d and d.get("extend") == True and "value" not in item:
258+
if "extend" in d and d.get("extend") and "value" not in item:
255259
raise Exception(
256260
prefix
257261
+ "must include a value for experimental features: {`value`: `0xabcd`}"
258262
)
259263

260264
if typed:
261-
type = extract_type(item["desc"])
262-
if type is None:
265+
ty = extract_type(item["desc"])
266+
if ty is None:
263267
raise Exception(
264268
prefix
265269
+ "typed etor "
266270
+ item["name"]
267271
+ " must begin with a type identifier: [type]"
268272
)
269-
type_name = _subt(type, tags)
273+
type_name = _subt(ty, tags)
270274
if not is_iso(type_name):
271275
raise Exception(
272276
prefix
273277
+ "type "
274-
+ str(type)
278+
+ str(ty)
275279
+ " in a typed etor "
276280
+ item["name"]
277281
+ " must be a valid ISO C identifier"
@@ -331,7 +335,9 @@ def has_handle(members, meta):
331335
if type_traits.is_handle(m):
332336
return True
333337
if type_traits.is_struct(m, meta):
334-
return has_handle(type_traits.get_struct_members(m["type"]), meta)
338+
return has_handle(
339+
type_traits.get_struct_members(m["type"], meta), meta
340+
)
335341
return False
336342

337343
for m in members:
@@ -350,7 +356,7 @@ def has_handle(members, meta):
350356
# exception messages.
351357
__validate_struct_range_members(name, member_members, meta)
352358

353-
def __validate_members(d, tags, meta):
359+
def __validate_members(d, meta):
354360
if "members" not in d:
355361
raise Exception(
356362
"'%s' requires the following sequence of mappings: {`members`}"
@@ -365,15 +371,17 @@ def __validate_members(d, tags, meta):
365371
for i, item in enumerate(d["members"]):
366372
prefix = "'members'[%s] " % i
367373
if not isinstance(item, dict):
368-
raise Exception(prefix + "must be a mapping: '%s'" % (i, type(item)))
374+
raise Exception(
375+
prefix + "must be a mapping: %d, '%s'" % (i, type(item))
376+
)
369377

370378
if ("desc" not in item) or ("type" not in item) or ("name" not in item):
371379
raise Exception(
372380
prefix
373381
+ "requires the following scalar fields: {`desc`, 'type', `name`}"
374382
)
375383

376-
annotation = __validate_desc(item["desc"])
384+
annotation = __validate_desc(item["desc"], prefix)
377385
if not annotation:
378386
raise Exception(
379387
prefix + "'desc' must start with {'[in]', '[out]', '[in,out]'}"
@@ -409,7 +417,7 @@ def __validate_members(d, tags, meta):
409417
)
410418
max_ver = ver
411419

412-
def __validate_params(d, tags, meta):
420+
def __validate_params(d, meta):
413421
if "params" not in d:
414422
raise Exception(
415423
"'function' requires the following sequence of mappings: {`params`}"
@@ -420,19 +428,25 @@ def __validate_params(d, tags, meta):
420428

421429
d_ver = Version(d.get("version", default_version))
422430
max_ver = d_ver
423-
min = {"[in]": None, "[out]": None, "[in,out]": None}
431+
min: Dict[str, Union[int, None]] = {
432+
"[in]": None,
433+
"[out]": None,
434+
"[in,out]": None,
435+
}
424436
for i, item in enumerate(d["params"]):
425437
prefix = "'params'[%s] " % i
426438
if not isinstance(item, dict):
427-
raise Exception(prefix + "must be a mapping: '%s'" % (i, type(item)))
439+
raise Exception(
440+
prefix + "must be a mapping: %d, '%s'" % (i, type(item))
441+
)
428442

429443
if ("desc" not in item) or ("type" not in item) or ("name" not in item):
430444
raise Exception(
431445
prefix
432446
+ "requires the following scalar fields: {`desc`, 'type', `name`}"
433447
)
434448

435-
annotation = __validate_desc(item["desc"])
449+
annotation = __validate_desc(item["desc"], prefix)
436450
if not annotation:
437451
raise Exception(
438452
prefix + "'desc' must start with {'[in]', '[out]', '[in,out]'}"
@@ -576,7 +590,7 @@ def __validate_union_tag(d):
576590
__validate_union_tag(d)
577591
__validate_type(d, "name", tags)
578592
__validate_base(d)
579-
__validate_members(d, tags, meta)
593+
__validate_members(d, meta)
580594
__validate_details(d)
581595
__validate_ordinal(d)
582596
__validate_version(d)
@@ -594,7 +608,7 @@ def __validate_union_tag(d):
594608
else:
595609
__validate_name(d, "name", tags, case="camel")
596610

597-
__validate_params(d, tags, meta)
611+
__validate_params(d, meta)
598612
__validate_details(d)
599613
__validate_ordinal(d)
600614
__validate_version(d)
@@ -734,6 +748,7 @@ def _generate_meta(d, ordinal, meta):
734748
if "enum" == type:
735749
value = -1
736750
max_value = -1
751+
max_index = -1
737752
bit_mask = 0
738753
meta[type][name]["etors"] = []
739754
for idx, etor in enumerate(d["etors"]):
@@ -902,7 +917,8 @@ def append_nullchecks(param, accessor: str):
902917
)
903918

904919
def append_enum_checks(param, accessor: str):
905-
ptypename = type_traits.base(param["type"])
920+
typename = type_traits.base(param["type"])
921+
assert typename
906922

907923
prefix = "`"
908924
if param_traits.is_optional(item):
@@ -924,20 +940,19 @@ def append_enum_checks(param, accessor: str):
924940
else:
925941
if (
926942
type_traits.is_flags(param["type"])
927-
and "bit_mask" in meta["enum"][ptypename].keys()
943+
and "bit_mask" in meta["enum"][typename].keys()
928944
):
929945
_append(
930946
rets,
931947
"$X_RESULT_ERROR_INVALID_ENUMERATION",
932948
prefix
933-
+ "%s & %s`" % (ptypename.upper()[:-2] + "_MASK", accessor),
949+
+ "%s & %s`" % (typename.upper()[:-2] + "_MASK", accessor),
934950
)
935951
else:
936952
_append(
937953
rets,
938954
"$X_RESULT_ERROR_INVALID_ENUMERATION",
939-
prefix
940-
+ "%s < %s`" % (meta["enum"][ptypename]["max"], accessor),
955+
prefix + "%s < %s`" % (meta["enum"][typename]["max"], accessor),
941956
)
942957

943958
# generate results based on parameters
@@ -958,7 +973,7 @@ def append_enum_checks(param, accessor: str):
958973
):
959974
typename = type_traits.base(item["type"])
960975
# walk each entry in the desc for pointers and enums
961-
for i, m in enumerate(meta["struct"][typename]["members"]):
976+
for m in meta["struct"][typename]["members"]:
962977
if param_traits.is_nocheck(m):
963978
continue
964979

@@ -1050,8 +1065,9 @@ def _validate_ext_enum_range(extension, enum) -> bool:
10501065
if value in existing_values:
10511066
return False
10521067
return True
1053-
except:
1054-
return False
1068+
except BaseException:
1069+
pass
1070+
return False
10551071

10561072

10571073
def _extend_enums(enum_extensions, specs, meta):
@@ -1068,9 +1084,9 @@ def _extend_enums(enum_extensions, specs, meta):
10681084
for obj in s["objects"]
10691085
if obj["type"] == "enum" and k == obj["name"] and not obj.get("extend")
10701086
][0]
1071-
for i, extension in enumerate(group):
1087+
for extension in group:
10721088
if not _validate_ext_enum_range(extension, matching_enum):
1073-
raise Exception(f"Invalid enum values.")
1089+
raise Exception("Invalid enum values.")
10741090
matching_enum["etors"].extend(extension["etors"])
10751091

10761092
_refresh_enum_meta(matching_enum, meta)
@@ -1118,7 +1134,7 @@ def parse(section, version, tags, meta, ref):
11181134
if not d:
11191135
continue
11201136

1121-
if d["type"] == "enum" and d.get("extend") == True:
1137+
if d["type"] == "enum" and d.get("extend"):
11221138
# enum extensions are resolved later
11231139
enum_extensions.append(d)
11241140
continue

0 commit comments

Comments
 (0)