Skip to content

Commit a4f241a

Browse files
authored
Merge pull request #19 from klauer/enh_deps_and_serialization
ENH: dependencies, first pass at serialization, and more
2 parents 21e35de + f486129 commit a4f241a

File tree

11 files changed

+1735
-1396
lines changed

11 files changed

+1735
-1396
lines changed

blark/apischema_compat.py

Lines changed: 150 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,150 @@
1+
"""
2+
Serialization helpers for apischema, an optional dependency.
3+
"""
4+
# Largely based on issue discussions regarding tagged unions.
5+
from __future__ import annotations
6+
7+
from collections import defaultdict
8+
from collections.abc import Callable, Iterator
9+
from types import new_class
10+
from typing import Any, Dict, Generic, List, Tuple, TypeVar, get_type_hints
11+
12+
import lark
13+
from apischema import deserializer, serializer, type_name
14+
from apischema.conversions import Conversion
15+
from apischema.metadata import conversion
16+
from apischema.objects import object_deserialization
17+
from apischema.tagged_unions import Tagged, TaggedUnion, get_tagged
18+
from apischema.typing import get_origin
19+
from apischema.utils import to_pascal_case
20+
21+
_alternative_constructors: Dict[type, List[Callable]] = defaultdict(list)
22+
Func = TypeVar("Func", bound=Callable)
23+
24+
25+
def alternative_constructor(func: Func) -> Func:
26+
"""Alternative constructor for a given type."""
27+
return_type = get_type_hints(func)["return"]
28+
_alternative_constructors[get_origin(return_type) or return_type].append(func)
29+
return func
30+
31+
32+
def get_all_subclasses(cls: type) -> Iterator[type]:
33+
"""Recursive implementation of type.__subclasses__"""
34+
for sub_cls in cls.__subclasses__():
35+
yield sub_cls
36+
yield from get_all_subclasses(sub_cls)
37+
38+
39+
Cls = TypeVar("Cls", bound=type)
40+
41+
42+
def _get_generic_name_factory(cls: type, *args: type):
43+
def _capitalized(name: str) -> str:
44+
return name[0].upper() + name[1:]
45+
46+
return "".join((cls.__name__, *(_capitalized(arg.__name__) for arg in args)))
47+
48+
49+
generic_name = type_name(_get_generic_name_factory)
50+
51+
52+
def as_tagged_union(cls: Cls) -> Cls:
53+
"""
54+
Tagged union decorator, to be used on base class.
55+
56+
Supports generics as well, with names generated by way of
57+
`_get_generic_name_factory`.
58+
"""
59+
params = tuple(getattr(cls, "__parameters__", ()))
60+
tagged_union_bases: Tuple[type, ...] = (TaggedUnion,)
61+
62+
# Generic handling is here:
63+
if params:
64+
tagged_union_bases = (TaggedUnion, Generic[params])
65+
generic_name(cls)
66+
prev_init_subclass = getattr(cls, "__init_subclass__", None)
67+
68+
def __init_subclass__(cls, **kwargs):
69+
if prev_init_subclass is not None:
70+
prev_init_subclass(**kwargs)
71+
generic_name(cls)
72+
73+
cls.__init_subclass__ = classmethod(__init_subclass__)
74+
75+
def with_params(cls: type) -> Any:
76+
"""Specify type of Generic if set."""
77+
return cls[params] if params else cls
78+
79+
def serialization() -> Conversion:
80+
"""
81+
Define the serializer Conversion for the tagged union.
82+
83+
source is the base ``cls`` (or ``cls[T]``).
84+
target is the new tagged union class ``TaggedUnion`` which gets the
85+
dictionary {cls.__name__: obj} as its arguments.
86+
"""
87+
annotations = {
88+
# Assume that subclasses have same generic parameters than cls
89+
sub.__name__: Tagged[with_params(sub)]
90+
for sub in get_all_subclasses(cls)
91+
}
92+
namespace = {"__annotations__": annotations}
93+
tagged_union = new_class(
94+
cls.__name__, tagged_union_bases, exec_body=lambda ns: ns.update(namespace)
95+
)
96+
return Conversion(
97+
lambda obj: tagged_union(**{obj.__class__.__name__: obj}),
98+
source=with_params(cls),
99+
target=with_params(tagged_union),
100+
# Conversion must not be inherited because it would lead to
101+
# infinite recursion otherwise
102+
inherited=False,
103+
)
104+
105+
def deserialization() -> Conversion:
106+
"""
107+
Define the deserializer Conversion for the tagged union.
108+
109+
Allows for alternative standalone constructors as per the apischema
110+
example.
111+
"""
112+
annotations: dict[str, Any] = {}
113+
namespace: dict[str, Any] = {"__annotations__": annotations}
114+
for sub in get_all_subclasses(cls):
115+
annotations[sub.__name__] = Tagged[with_params(sub)]
116+
for constructor in _alternative_constructors.get(sub, ()):
117+
# Build the alias of the field
118+
alias = to_pascal_case(constructor.__name__)
119+
# object_deserialization uses get_type_hints, but the constructor
120+
# return type is stringified and the class not defined yet,
121+
# so it must be assigned manually
122+
constructor.__annotations__["return"] = with_params(sub)
123+
# Use object_deserialization to wrap constructor as deserializer
124+
deserialization = object_deserialization(constructor, generic_name)
125+
# Add constructor tagged field with its conversion
126+
annotations[alias] = Tagged[with_params(sub)]
127+
namespace[alias] = Tagged(conversion(deserialization=deserialization))
128+
# Create the deserialization tagged union class
129+
tagged_union = new_class(
130+
cls.__name__, tagged_union_bases, exec_body=lambda ns: ns.update(namespace)
131+
)
132+
return Conversion(
133+
lambda obj: get_tagged(obj)[1],
134+
source=with_params(tagged_union),
135+
target=with_params(cls),
136+
)
137+
138+
deserializer(lazy=deserialization, target=cls)
139+
serializer(lazy=serialization, source=cls)
140+
return cls
141+
142+
143+
@serializer
144+
def token_serializer(token: lark.Token) -> List[str]:
145+
return [token.type, token.value]
146+
147+
148+
@deserializer
149+
def token_deserializer(parts: List[str]) -> lark.Token:
150+
return lark.Token(*parts)

blark/config.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
import os
2+
3+
BLARK_TWINCAT_ROOT = os.environ.get("BLARK_TWINCAT_ROOT", ".")

0 commit comments

Comments
 (0)