|
4 | 4 | from __future__ import annotations
|
5 | 5 |
|
6 | 6 |
|
7 |
| -__all__ = ['lenient_issubclass', 'sorted_topologically', 'TypeDispatch', 'DispatchReg', 'typedispatch', 'cast', |
8 |
| - 'retain_meta', 'default_set_meta', 'retain_type', 'retain_types', 'explode_types'] |
| 7 | +__all__ = ['FastFunction', 'FastDispatcher', 'typedispatch', 'cast', 'retain_meta', 'default_set_meta', 'retain_type', |
| 8 | + 'retain_types', 'explode_types'] |
9 | 9 |
|
10 | 10 | # Cell
|
11 | 11 | #nbdev_comment from __future__ import annotations
|
12 | 12 | from .imports import *
|
13 | 13 | from .foundation import *
|
14 | 14 | from .utils import *
|
| 15 | +from .meta import delegates |
15 | 16 |
|
16 | 17 | from collections import defaultdict
|
| 18 | +from plum import Function, Dispatcher |
17 | 19 |
|
18 | 20 | # Cell
|
19 |
| -def lenient_issubclass(cls, types): |
20 |
| - "If possible return whether `cls` is a subclass of `types`, otherwise return False." |
21 |
| - if cls is object and types is not object: return False # treat `object` as highest level |
22 |
| - try: return isinstance(cls, types) or issubclass(cls, types) |
23 |
| - except: return False |
| 21 | +def _eval_annotations(f): |
| 22 | + "Evaluate future annotations before passing to plum to support backported union operator `|`" |
| 23 | + f = copy_func(f) |
| 24 | + for k, v in type_hints(f).items(): f.__annotations__[k] = Union[v] if isinstance(v, tuple) else v |
| 25 | + return f |
24 | 26 |
|
25 | 27 | # Cell
|
26 |
| -def sorted_topologically(iterable, *, cmp=operator.lt, reverse=False): |
27 |
| - "Return a new list containing all items from the iterable sorted topologically" |
28 |
| - l,res = L(list(iterable)),[] |
29 |
| - for _ in range(len(l)): |
30 |
| - t = l.reduce(lambda x,y: y if cmp(y,x) else x) |
31 |
| - res.append(t), l.remove(t) |
32 |
| - return res[::-1] if reverse else res |
| 28 | +def _pt_repr(o): |
| 29 | + "Concise repr of plum types" |
| 30 | + n = type(o).__name__ |
| 31 | + if n == 'Tuple': return f"{n.lower()}[{','.join(_pt_repr(t) for t in o._el_types)}]" |
| 32 | + if n == 'List': return f'{n.lower()}[{_pt_repr(o._el_type)}]' |
| 33 | + if n == 'Dict': return f'{n.lower()}[{_pt_repr(o._key_type)},{_pt_repr(o._value_type)}]' |
| 34 | + if n in ('Sequence','Iterable'): return f'{n}[{_pt_repr(o._el_type)}]' |
| 35 | + if n == 'VarArgs': return f'{n}[{_pt_repr(o.type)}]' |
| 36 | + if n == 'Union': return '|'.join(sorted(t.__name__ for t in (o.get_types()))) |
| 37 | + assert len(o.get_types()) == 1 |
| 38 | + return o.get_types()[0].__name__ |
33 | 39 |
|
34 | 40 | # Cell
|
35 |
| -def _chk_defaults(f, ann): |
36 |
| - pass |
37 |
| -# Implementation removed until we can figure out how to do this without `inspect` module |
38 |
| -# try: # Some callables don't have signatures, so ignore those errors |
39 |
| -# params = list(inspect.signature(f).parameters.values())[:min(len(ann),2)] |
40 |
| -# if any(p.default!=inspect.Parameter.empty for p in params): |
41 |
| -# warn(f"{f.__name__} has default params. These will be ignored.") |
42 |
| -# except ValueError: pass |
43 |
| - |
44 |
| -# Cell |
45 |
| -def _p2_anno(f): |
46 |
| - "Get the 1st 2 annotations of `f`, defaulting to `object`" |
47 |
| - hints = type_hints(f) |
48 |
| - ann = [o for n,o in hints.items() if n!='return'] |
49 |
| - if callable(f): _chk_defaults(f, ann) |
50 |
| - while len(ann)<2: ann.append(object) |
51 |
| - return ann[:2] |
| 41 | +class FastFunction(Function): |
| 42 | + def __repr__(self): |
| 43 | + return '\n'.join(f"{f.__name__}: ({','.join(_pt_repr(t) for t in s.types)}) -> {_pt_repr(r)}" |
| 44 | + for s, (f, r) in self.methods.items()) |
52 | 45 |
|
53 |
| -# Cell |
54 |
| -class _TypeDict: |
55 |
| - def __init__(self): self.d,self.cache = {},{} |
56 |
| - |
57 |
| - def _reset(self): |
58 |
| - self.d = {k:self.d[k] for k in sorted_topologically(self.d, cmp=lenient_issubclass)} |
59 |
| - self.cache = {} |
60 |
| - |
61 |
| - def add(self, t, f): |
62 |
| - "Add type `t` and function `f`" |
63 |
| - if not isinstance(t, tuple): t = tuple(L(union2tuple(t))) |
64 |
| - for t_ in t: self.d[t_] = f |
65 |
| - self._reset() |
66 |
| - |
67 |
| - def all_matches(self, k): |
68 |
| - "Find first matching type that is a super-class of `k`" |
69 |
| - if k not in self.cache: |
70 |
| - types = [f for f in self.d if lenient_issubclass(k,f)] |
71 |
| - self.cache[k] = [self.d[o] for o in types] |
72 |
| - return self.cache[k] |
73 |
| - |
74 |
| - def __getitem__(self, k): |
75 |
| - "Find first matching type that is a super-class of `k`" |
76 |
| - res = self.all_matches(k) |
77 |
| - return res[0] if len(res) else None |
78 |
| - |
79 |
| - def __repr__(self): return self.d.__repr__() |
80 |
| - def first(self): return first(self.d.values()) |
| 46 | + @delegates(Function.dispatch) |
| 47 | + def dispatch(self, f=None, **kwargs): return super().dispatch(_eval_annotations(f), **kwargs) |
81 | 48 |
|
82 |
| -# Cell |
83 |
| -class TypeDispatch: |
84 |
| - "Dictionary-like object; `__getitem__` matches keys of types using `issubclass`" |
85 |
| - def __init__(self, funcs=(), bases=()): |
86 |
| - self.funcs,self.bases = _TypeDict(),L(bases).filter(is_not(None)) |
87 |
| - for o in L(funcs): self.add(o) |
88 |
| - self.inst = None |
89 |
| - self.owner = None |
90 |
| - |
91 |
| - def add(self, f): |
92 |
| - "Add type `t` and function `f`" |
93 |
| - if isinstance(f, staticmethod): a0,a1 = _p2_anno(f.__func__) |
94 |
| - else: a0,a1 = _p2_anno(f) |
95 |
| - t = self.funcs.d.get(a0) |
96 |
| - if t is None: |
97 |
| - t = _TypeDict() |
98 |
| - self.funcs.add(a0, t) |
99 |
| - t.add(a1, f) |
100 |
| - |
101 |
| - def first(self): |
102 |
| - "Get first function in ordered dict of type:func." |
103 |
| - return self.funcs.first().first() |
104 |
| - |
105 |
| - def returns(self, x): |
106 |
| - "Get the return type of annotation of `x`." |
107 |
| - return anno_ret(self[type(x)]) |
108 |
| - |
109 |
| - def _attname(self,k): return getattr(k,'__name__',str(k)) |
110 |
| - def __repr__(self): |
111 |
| - r = [f'({self._attname(k)},{self._attname(l)}) -> {getattr(v, "__name__", type(v).__name__)}' |
112 |
| - for k in self.funcs.d for l,v in self.funcs[k].d.items()] |
113 |
| - r = r + [o.__repr__() for o in self.bases] |
114 |
| - return '\n'.join(r) |
115 |
| - |
116 |
| - def __call__(self, *args, **kwargs): |
117 |
| - ts = L(args).map(type)[:2] |
118 |
| - f = self[tuple(ts)] |
119 |
| - if not f: return args[0] |
120 |
| - if isinstance(f, staticmethod): f = f.__func__ |
121 |
| - elif self.inst is not None: f = MethodType(f, self.inst) |
122 |
| - elif self.owner is not None: f = MethodType(f, self.owner) |
123 |
| - return f(*args, **kwargs) |
124 |
| - |
125 |
| - def __get__(self, inst, owner): |
126 |
| - self.inst = inst |
127 |
| - self.owner = owner |
128 |
| - return self |
129 |
| - |
130 |
| - def __getitem__(self, k): |
131 |
| - "Find first matching type that is a super-class of `k`" |
132 |
| - k = L(k) |
133 |
| - while len(k)<2: k.append(object) |
134 |
| - r = self.funcs.all_matches(k[0]) |
135 |
| - for t in r: |
136 |
| - o = t[k[1]] |
137 |
| - if o is not None: return o |
138 |
| - for base in self.bases: |
139 |
| - res = base[k] |
140 |
| - if res is not None: return res |
141 |
| - return None |
| 49 | + def __getitem__(self, ts): |
| 50 | + "Return the most-specific matching method with fewest parameters" |
| 51 | + ts = L(ts) |
| 52 | + nargs = min(len(o) for o in self.methods.keys()) |
| 53 | + while len(ts) < nargs: ts.append(object) |
| 54 | + return self.invoke(*ts) |
142 | 55 |
|
143 | 56 | # Cell
|
144 |
| -class DispatchReg: |
145 |
| - "A global registry for `TypeDispatch` objects keyed by function name" |
146 |
| - def __init__(self): self.d = defaultdict(TypeDispatch) |
147 |
| - def __call__(self, f): |
148 |
| - if isinstance(f, (classmethod, staticmethod)): nm = f'{f.__func__.__qualname__}' |
149 |
| - else: nm = f'{f.__qualname__}' |
150 |
| - if isinstance(f, classmethod): f=f.__func__ |
151 |
| - self.d[nm].add(f) |
152 |
| - return self.d[nm] |
153 |
| - |
154 |
| -typedispatch = DispatchReg() |
| 57 | +class FastDispatcher(Dispatcher): |
| 58 | + def _get_function(self, method, owner): |
| 59 | + "Adapted from `Dispatcher._get_function` to use `FastFunction`" |
| 60 | + name = method.__name__ |
| 61 | + if owner: |
| 62 | + if owner not in self._classes: self._classes[owner] = {} |
| 63 | + namespace = self._classes[owner] |
| 64 | + else: namespace = self._functions |
| 65 | + if name not in namespace: namespace[name] = FastFunction(method, owner=owner) |
| 66 | + return namespace[name] |
| 67 | + |
| 68 | + @delegates(Dispatcher.__call__, but='method') |
| 69 | + def __call__(self, f, **kwargs): return super().__call__(_eval_annotations(f), **kwargs) |
| 70 | + |
| 71 | + def _to(self, cls, nm, f, **kwargs): |
| 72 | + nf = copy_func(f) |
| 73 | + nf.__qualname__ = f'{cls.__name__}.{nm}' # plum uses __qualname__ to infer f's owner |
| 74 | + pf = self(nf, **kwargs) |
| 75 | + # plum uses __set_name__ to resolve a plum.Function's owner |
| 76 | + # since we assign after class creation, __set_name__ must be called directly |
| 77 | + # source: https://docs.python.org/3/reference/datamodel.html#object.__set_name__ |
| 78 | + pf.__set_name__(cls, nm) |
| 79 | + pf = pf.resolve() |
| 80 | + setattr(cls, nm, pf) |
| 81 | + return pf |
| 82 | + |
| 83 | + def to(self, cls): |
| 84 | + "Decorator: dispatch `f` to `cls.f`" |
| 85 | + def _inner(f, **kwargs): |
| 86 | + nm = f.__name__ |
| 87 | + # check __dict__ to avoid inherited methods but use getattr so pf.__get__ is called, which plum relies on |
| 88 | + if nm in cls.__dict__: |
| 89 | + pf = getattr(cls, nm) |
| 90 | + if not hasattr(pf, 'dispatch'): pf = self._to(cls, nm, pf, **kwargs) |
| 91 | + pf.dispatch(f) |
| 92 | + else: pf = self._to(cls, nm, f, **kwargs) |
| 93 | + return pf |
| 94 | + return _inner |
| 95 | + |
| 96 | +typedispatch = FastDispatcher() |
155 | 97 |
|
156 | 98 | # Cell
|
157 | 99 | #nbdev_comment _all_=['cast']
|
|
0 commit comments