Skip to content

Commit 10976bb

Browse files
committed
Swift: use IPA layer in generated classes
1 parent ea07255 commit 10976bb

File tree

326 files changed

+5526
-2405
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

326 files changed

+5526
-2405
lines changed

swift/codegen/generators/cppgen.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ def get_classes(self):
8383

8484
def generate(opts, renderer):
8585
assert opts.cpp_output
86-
processor = Processor({cls.name: cls for cls in schema.load(opts.schema).classes})
86+
processor = Processor(schema.load(opts.schema).classes)
8787
out = opts.cpp_output
8888
for dir, classes in processor.get_classes().items():
8989
include_parent = (dir != pathlib.Path())

swift/codegen/generators/dbschemegen.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ def cls_to_dbscheme(cls: schema.Class):
8888

8989

9090
def get_declarations(data: schema.Schema):
91-
return [d for cls in data.classes for d in cls_to_dbscheme(cls)]
91+
return [d for cls in data.classes.values() for d in cls_to_dbscheme(cls)]
9292

9393

9494
def get_includes(data: schema.Schema, include_dir: pathlib.Path, swift_dir: pathlib.Path):

swift/codegen/generators/qlgen.py

Lines changed: 60 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
#!/usr/bin/env python3
2+
# TODO this should probably be split in different generators now: ql, qltest, maybe qlipa
23

34
import logging
45
import pathlib
@@ -8,6 +9,7 @@
89
import itertools
910

1011
import inflection
12+
from toposort import toposort_flatten
1113

1214
from swift.codegen.lib import schema, ql
1315

@@ -27,55 +29,55 @@ class ModifiedStubMarkedAsGeneratedError(Error):
2729
pass
2830

2931

30-
def get_ql_property(cls: schema.Class, prop: schema.Property):
31-
common_args = dict(
32+
def get_ql_property(cls: schema.Class, source: schema.Class, prop: schema.Property) -> ql.Property:
33+
args = dict(
3234
type=prop.type if not prop.is_predicate else "predicate",
3335
qltest_skip="qltest_skip" in prop.pragmas,
3436
is_child=prop.is_child,
3537
is_optional=prop.is_optional,
3638
is_predicate=prop.is_predicate,
3739
)
3840
if prop.is_single:
39-
return ql.Property(
40-
**common_args,
41+
args.update(
4142
singular=inflection.camelize(prop.name),
42-
tablename=inflection.tableize(cls.name),
43-
tableparams=[
44-
"this"] + ["result" if p is prop else "_" for p in cls.properties if p.is_single],
43+
tablename=inflection.tableize(source.name),
44+
tableparams=["this"] + ["result" if p is prop else "_" for p in source.properties if p.is_single],
4545
)
4646
elif prop.is_repeated:
47-
return ql.Property(
48-
**common_args,
47+
args.update(
4948
singular=inflection.singularize(inflection.camelize(prop.name)),
5049
plural=inflection.pluralize(inflection.camelize(prop.name)),
51-
tablename=inflection.tableize(f"{cls.name}_{prop.name}"),
50+
tablename=inflection.tableize(f"{source.name}_{prop.name}"),
5251
tableparams=["this", "index", "result"],
5352
)
5453
elif prop.is_optional:
55-
return ql.Property(
56-
**common_args,
54+
args.update(
5755
singular=inflection.camelize(prop.name),
58-
tablename=inflection.tableize(f"{cls.name}_{prop.name}"),
56+
tablename=inflection.tableize(f"{source.name}_{prop.name}"),
5957
tableparams=["this", "result"],
6058
)
6159
elif prop.is_predicate:
62-
return ql.Property(
63-
**common_args,
64-
singular=inflection.camelize(
65-
prop.name, uppercase_first_letter=False),
66-
tablename=inflection.underscore(f"{cls.name}_{prop.name}"),
60+
args.update(
61+
singular=inflection.camelize(prop.name, uppercase_first_letter=False),
62+
tablename=inflection.underscore(f"{source.name}_{prop.name}"),
6763
tableparams=["this"],
6864
)
65+
else:
66+
raise ValueError(f"unknown property kind for {prop.name} from {source.name}")
67+
return ql.Property(**args)
6968

7069

71-
def get_ql_class(cls: schema.Class):
70+
def get_ql_class(cls: schema.Class, lookup: typing.Dict[str, schema.Class]):
7271
pragmas = {k: True for k in cls.pragmas if k.startswith("ql")}
7372
return ql.Class(
7473
name=cls.name,
7574
bases=cls.bases,
7675
final=not cls.derived,
77-
properties=[get_ql_property(cls, p) for p in cls.properties],
76+
77+
properties=[get_ql_property(cls, cls, p) for p in cls.properties],
7878
dir=cls.dir,
79+
has_db_id=not cls.ipa or cls.ipa.from_class,
80+
ipa=bool(cls.ipa),
7981
**pragmas,
8082
)
8183

@@ -92,11 +94,11 @@ def _to_db_type(x: str) -> str:
9294
def get_ql_ipa_class(cls: schema.Class):
9395
if cls.derived:
9496
return ql.Ipa.NonFinalClass(name=cls.name, derived=sorted(cls.derived))
95-
if cls.ipa and cls.ipa.from_class:
97+
if cls.ipa and cls.ipa.from_class is not None:
9698
source = cls.ipa.from_class
9799
_final_db_class_lookup.setdefault(source, ql.Ipa.FinalClassDb(source)).subtract_type(cls.name)
98100
return ql.Ipa.FinalClassIpaFrom(name=cls.name, type=_to_db_type(source))
99-
if cls.ipa and cls.ipa.on_arguments:
101+
if cls.ipa and cls.ipa.on_arguments is not None:
100102
return ql.Ipa.FinalClassIpaOn(name=cls.name,
101103
params=[ql.Ipa.Param(k, _to_db_type(v)) for k, v in cls.ipa.on_arguments.items()])
102104
return _final_db_class_lookup.setdefault(cls.name, ql.Ipa.FinalClassDb(cls.name))
@@ -136,7 +138,6 @@ def _is_generated_stub(file):
136138
line_threshold = 5
137139
first_lines = list(itertools.islice(contents, line_threshold))
138140
if len(first_lines) == line_threshold or not _generated_stub_re.match("".join(first_lines)):
139-
print("".join(first_lines))
140141
raise ModifiedStubMarkedAsGeneratedError(
141142
f"{file.name} stub was modified but is still marked as generated")
142143
return True
@@ -154,23 +155,28 @@ def format(codeql, files):
154155
log.debug(line.strip())
155156

156157

157-
def _get_all_properties(cls: ql.Class, lookup: typing.Dict[str, ql.Class]) -> typing.Iterable[
158-
typing.Tuple[ql.Class, ql.Property]]:
159-
for b in cls.bases:
158+
def _get_all_properties(cls: schema.Class, lookup: typing.Dict[str, schema.Class],
159+
already_seen: typing.Optional[typing.Set[int]] = None) -> \
160+
typing.Iterable[typing.Tuple[schema.Class, schema.Property]]:
161+
# deduplicate using ids
162+
if already_seen is None:
163+
already_seen = set()
164+
for b in sorted(cls.bases):
160165
base = lookup[b]
161-
for item in _get_all_properties(base, lookup):
166+
for item in _get_all_properties(base, lookup, already_seen):
162167
yield item
163168
for p in cls.properties:
164-
yield cls, p
169+
if id(p) not in already_seen:
170+
already_seen.add(id(p))
171+
yield cls, p
165172

166173

167-
def _get_all_properties_to_be_tested(cls: ql.Class, lookup: typing.Dict[str, ql.Class]) -> typing.Iterable[
168-
ql.PropertyForTest]:
169-
# deduplicate using id
170-
already_seen = set()
174+
def _get_all_properties_to_be_tested(cls: schema.Class, lookup: typing.Dict[str, schema.Class]) -> \
175+
typing.Iterable[ql.PropertyForTest]:
171176
for c, p in _get_all_properties(cls, lookup):
172-
if not (c.qltest_skip or p.qltest_skip or id(p) in already_seen):
173-
already_seen.add(id(p))
177+
if not ("qltest_skip" in c.pragmas or "qltest_skip" in p.pragmas):
178+
# TODO here operations are duplicated, but should be better if we split ql and qltest generation
179+
p = get_ql_property(cls, c, p)
174180
yield ql.PropertyForTest(p.getter, p.type, p.is_single, p.is_predicate, p.is_repeated)
175181

176182

@@ -184,17 +190,18 @@ def _partition(l, pred):
184190
return map(list, _partition_iter(l, pred))
185191

186192

187-
def _is_in_qltest_collapsed_hierachy(cls: ql.Class, lookup: typing.Dict[str, ql.Class]):
188-
return cls.qltest_collapse_hierarchy or _is_under_qltest_collapsed_hierachy(cls, lookup)
193+
def _is_in_qltest_collapsed_hierachy(cls: schema.Class, lookup: typing.Dict[str, schema.Class]):
194+
return "qltest_collapse_hierarchy" in cls.pragmas or _is_under_qltest_collapsed_hierachy(cls, lookup)
189195

190196

191-
def _is_under_qltest_collapsed_hierachy(cls: ql.Class, lookup: typing.Dict[str, ql.Class]):
192-
return not cls.qltest_uncollapse_hierarchy and any(
197+
def _is_under_qltest_collapsed_hierachy(cls: schema.Class, lookup: typing.Dict[str, schema.Class]):
198+
return "qltest_uncollapse_hierarchy" not in cls.pragmas and any(
193199
_is_in_qltest_collapsed_hierachy(lookup[b], lookup) for b in cls.bases)
194200

195201

196-
def _should_skip_qltest(cls: ql.Class, lookup: typing.Dict[str, ql.Class]):
197-
return cls.qltest_skip or not (cls.final or cls.qltest_collapse_hierarchy) or _is_under_qltest_collapsed_hierachy(
202+
def _should_skip_qltest(cls: schema.Class, lookup: typing.Dict[str, schema.Class]):
203+
return "qltest_skip" in cls.pragmas or not (
204+
cls.final or "qltest_collapse_hierarchy" in cls.pragmas) or _is_under_qltest_collapsed_hierachy(
198205
cls, lookup)
199206

200207

@@ -211,12 +218,14 @@ def generate(opts, renderer):
211218
existing |= {q for q in test_out.rglob(missing_test_source_filename)}
212219

213220
data = schema.load(input)
214-
data.classes.sort(key=lambda cls: (cls.dir, cls.name))
221+
inheritance_graph = {name: cls.bases for name, cls in data.classes.items()}
222+
input_classes = [data.classes[name] for name in toposort_flatten(inheritance_graph)]
215223

216-
classes = [get_ql_class(cls) for cls in data.classes]
217-
lookup = {cls.name: cls for cls in classes}
224+
classes = [get_ql_class(cls, data.classes) for cls in input_classes]
218225
imports = {}
219226

227+
renderer.render(ql.DbClasses(cls for cls in classes if not cls.ipa), out / "Db.qll")
228+
220229
for c in classes:
221230
imports[c.name] = get_import(stub_out / c.path, opts.swift_dir)
222231

@@ -237,17 +246,17 @@ def generate(opts, renderer):
237246
renderer.render(ql.GetParentImplementation(
238247
classes), out / 'GetImmediateParent.qll')
239248

240-
for c in classes:
241-
if _should_skip_qltest(c, lookup):
249+
for c in input_classes:
250+
if _should_skip_qltest(c, data.classes):
242251
continue
243-
test_dir = test_out / c.path
252+
test_dir = test_out / c.dir / c.name
244253
test_dir.mkdir(parents=True, exist_ok=True)
245254
if not any(test_dir.glob("*.swift")):
246-
log.warning(f"no test source in {c.path}")
255+
log.warning(f"no test source in {c.dir / c.name}")
247256
renderer.render(ql.MissingTestInstructions(),
248257
test_dir / missing_test_source_filename)
249258
continue
250-
total_props, partial_props = _partition(_get_all_properties_to_be_tested(c, lookup),
259+
total_props, partial_props = _partition(_get_all_properties_to_be_tested(c, data.classes),
251260
lambda p: p.is_single or p.is_predicate)
252261
renderer.render(ql.ClassTester(class_name=c.name,
253262
properties=total_props), test_dir / f"{c.name}.ql")
@@ -258,16 +267,16 @@ def generate(opts, renderer):
258267
final_ipa_types = []
259268
non_final_ipa_types = []
260269
constructor_imports = []
261-
for cls in data.classes:
270+
for cls in input_classes:
262271
ipa_type = get_ql_ipa_class(cls)
263272
if ipa_type.is_final:
264273
final_ipa_types.append(ipa_type)
265-
if ipa_type.is_ipa:
274+
if ipa_type.is_ipa_from or (ipa_type.is_ipa_on and ipa_type.has_params):
266275
stub_file = stub_out / cls.dir / f"{cls.name}Constructor.qll"
267276
if not stub_file.is_file() or _is_generated_stub(stub_file):
268277
renderer.render(ql.Ipa.ConstructorStub(ipa_type), stub_file)
269278
constructor_imports.append(get_import(stub_file, opts.swift_dir))
270-
else:
279+
elif cls.name != schema.root_class_name:
271280
non_final_ipa_types.append(ipa_type)
272281

273282
renderer.render(ql.Ipa.Types(schema.root_class_name, final_ipa_types, non_final_ipa_types), out / "IpaTypes.qll")

swift/codegen/lib/ql.py

Lines changed: 25 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
import pathlib
1616
from dataclasses import dataclass, field
17-
from typing import List, ClassVar, Union
17+
from typing import List, ClassVar, Union, Optional
1818

1919
import inflection
2020

@@ -28,21 +28,18 @@ class Param:
2828
@dataclass
2929
class Property:
3030
singular: str
31-
type: str = None
32-
tablename: str = None
31+
type: Optional[str] = None
32+
tablename: Optional[str] = None
3333
tableparams: List[Param] = field(default_factory=list)
34-
plural: str = None
34+
plural: Optional[str] = None
3535
first: bool = False
36-
local_var: str = "x"
3736
is_optional: bool = False
3837
is_predicate: bool = False
3938
is_child: bool = False
4039
qltest_skip: bool = False
4140

4241
def __post_init__(self):
4342
if self.tableparams:
44-
if self.type_is_class:
45-
self.tableparams = [x if x != "result" else self.local_var for x in self.tableparams]
4643
self.tableparams = [Param(x) for x in self.tableparams]
4744
self.tableparams[0].first = True
4845

@@ -82,24 +79,27 @@ class Class:
8279
qltest_skip: bool = False
8380
qltest_collapse_hierarchy: bool = False
8481
qltest_uncollapse_hierarchy: bool = False
82+
has_db_id: bool = False
83+
ipa: bool = False
8584

8685
def __post_init__(self):
8786
self.bases = sorted(self.bases)
8887
if self.properties:
8988
self.properties[0].first = True
9089

9190
@property
92-
def db_id(self):
93-
return "@" + inflection.underscore(self.name)
94-
95-
@property
96-
def root(self):
91+
def root(self) -> bool:
9792
return not self.bases
9893

9994
@property
100-
def path(self):
95+
def path(self) -> pathlib.Path:
10196
return self.dir / self.name
10297

98+
@property
99+
def db_id(self) -> Optional[str]:
100+
if self.has_db_id:
101+
return "@" + inflection.underscore(self.name)
102+
103103

104104
@dataclass
105105
class Stub:
@@ -109,6 +109,13 @@ class Stub:
109109
base_import: str
110110

111111

112+
@dataclass
113+
class DbClasses:
114+
template: ClassVar = 'ql_db'
115+
116+
classes: List[Class] = field(default_factory=list)
117+
118+
112119
@dataclass
113120
class ImportList:
114121
template: ClassVar = 'ql_imports'
@@ -126,7 +133,7 @@ class GetParentImplementation:
126133
@dataclass
127134
class PropertyForTest:
128135
getter: str
129-
type: str = None
136+
type: Optional[str] = None
130137
is_single: bool = False
131138
is_predicate: bool = False
132139
is_repeated: bool = False
@@ -198,6 +205,10 @@ def __post_init__(self):
198205
if self.params:
199206
self.params[0].first = True
200207

208+
@property
209+
def has_params(self) -> bool:
210+
return bool(self.params)
211+
201212
@dataclass
202213
class FinalClassDb(FinalClass):
203214
is_db: ClassVar = True

0 commit comments

Comments
 (0)