Skip to content

Commit 6df0281

Browse files
committed
enable more robust multiple dispatch with plum
1 parent 06922b7 commit 6df0281

10 files changed

+385
-1038
lines changed

fastcore/_nbdev.py

+2-4
Original file line numberDiff line numberDiff line change
@@ -227,10 +227,8 @@
227227
"do_request": "03b_net.ipynb",
228228
"start_server": "03b_net.ipynb",
229229
"start_client": "03b_net.ipynb",
230-
"lenient_issubclass": "04_dispatch.ipynb",
231-
"sorted_topologically": "04_dispatch.ipynb",
232-
"TypeDispatch": "04_dispatch.ipynb",
233-
"DispatchReg": "04_dispatch.ipynb",
230+
"FastFunction": "04_dispatch.ipynb",
231+
"FastDispatcher": "04_dispatch.ipynb",
234232
"typedispatch": "04_dispatch.ipynb",
235233
"retain_meta": "04_dispatch.ipynb",
236234
"default_set_meta": "04_dispatch.ipynb",

fastcore/basics.py

+2
Original file line numberDiff line numberDiff line change
@@ -888,6 +888,8 @@ def copy_func(f):
888888
fn = FunctionType(f.__code__, f.__globals__, f.__name__, f.__defaults__, f.__closure__)
889889
fn.__kwdefaults__ = f.__kwdefaults__
890890
fn.__dict__.update(f.__dict__)
891+
fn.__annotations__.update(f.__annotations__)
892+
fn.__qualname__ = f.__qualname__
891893
return fn
892894

893895
# Cell

fastcore/dispatch.py

+72-130
Original file line numberDiff line numberDiff line change
@@ -4,154 +4,96 @@
44
from __future__ import annotations
55

66

7-
__all__ = ['lenient_issubclass', 'sorted_topologically', 'TypeDispatch', 'DispatchReg', 'typedispatch', 'cast',
8-
'retain_meta', 'default_set_meta', 'retain_type', 'retain_types', 'explode_types']
7+
__all__ = ['FastFunction', 'FastDispatcher', 'typedispatch', 'cast', 'retain_meta', 'default_set_meta', 'retain_type',
8+
'retain_types', 'explode_types']
99

1010
# Cell
1111
#nbdev_comment from __future__ import annotations
1212
from .imports import *
1313
from .foundation import *
1414
from .utils import *
15+
from .meta import delegates
1516

1617
from collections import defaultdict
18+
from plum import Function, Dispatcher
1719

1820
# Cell
19-
def lenient_issubclass(cls, types):
20-
"If possible return whether `cls` is a subclass of `types`, otherwise return False."
21-
if cls is object and types is not object: return False # treat `object` as highest level
22-
try: return isinstance(cls, types) or issubclass(cls, types)
23-
except: return False
21+
def _eval_annotations(f):
22+
"Evaluate future annotations before passing to plum to support backported union operator `|`"
23+
f = copy_func(f)
24+
for k, v in type_hints(f).items(): f.__annotations__[k] = Union[v] if isinstance(v, tuple) else v
25+
return f
2426

2527
# Cell
26-
def sorted_topologically(iterable, *, cmp=operator.lt, reverse=False):
27-
"Return a new list containing all items from the iterable sorted topologically"
28-
l,res = L(list(iterable)),[]
29-
for _ in range(len(l)):
30-
t = l.reduce(lambda x,y: y if cmp(y,x) else x)
31-
res.append(t), l.remove(t)
32-
return res[::-1] if reverse else res
28+
def _pt_repr(o):
29+
"Concise repr of plum types"
30+
n = type(o).__name__
31+
if n == 'Tuple': return f"{n.lower()}[{','.join(_pt_repr(t) for t in o._el_types)}]"
32+
if n == 'List': return f'{n.lower()}[{_pt_repr(o._el_type)}]'
33+
if n == 'Dict': return f'{n.lower()}[{_pt_repr(o._key_type)},{_pt_repr(o._value_type)}]'
34+
if n in ('Sequence','Iterable'): return f'{n}[{_pt_repr(o._el_type)}]'
35+
if n == 'VarArgs': return f'{n}[{_pt_repr(o.type)}]'
36+
if n == 'Union': return '|'.join(sorted(t.__name__ for t in (o.get_types())))
37+
assert len(o.get_types()) == 1
38+
return o.get_types()[0].__name__
3339

3440
# Cell
35-
def _chk_defaults(f, ann):
36-
pass
37-
# Implementation removed until we can figure out how to do this without `inspect` module
38-
# try: # Some callables don't have signatures, so ignore those errors
39-
# params = list(inspect.signature(f).parameters.values())[:min(len(ann),2)]
40-
# if any(p.default!=inspect.Parameter.empty for p in params):
41-
# warn(f"{f.__name__} has default params. These will be ignored.")
42-
# except ValueError: pass
43-
44-
# Cell
45-
def _p2_anno(f):
46-
"Get the 1st 2 annotations of `f`, defaulting to `object`"
47-
hints = type_hints(f)
48-
ann = [o for n,o in hints.items() if n!='return']
49-
if callable(f): _chk_defaults(f, ann)
50-
while len(ann)<2: ann.append(object)
51-
return ann[:2]
41+
class FastFunction(Function):
42+
def __repr__(self):
43+
return '\n'.join(f"{f.__name__}: ({','.join(_pt_repr(t) for t in s.types)}) -> {_pt_repr(r)}"
44+
for s, (f, r) in self.methods.items())
5245

53-
# Cell
54-
class _TypeDict:
55-
def __init__(self): self.d,self.cache = {},{}
56-
57-
def _reset(self):
58-
self.d = {k:self.d[k] for k in sorted_topologically(self.d, cmp=lenient_issubclass)}
59-
self.cache = {}
60-
61-
def add(self, t, f):
62-
"Add type `t` and function `f`"
63-
if not isinstance(t, tuple): t = tuple(L(union2tuple(t)))
64-
for t_ in t: self.d[t_] = f
65-
self._reset()
66-
67-
def all_matches(self, k):
68-
"Find first matching type that is a super-class of `k`"
69-
if k not in self.cache:
70-
types = [f for f in self.d if lenient_issubclass(k,f)]
71-
self.cache[k] = [self.d[o] for o in types]
72-
return self.cache[k]
73-
74-
def __getitem__(self, k):
75-
"Find first matching type that is a super-class of `k`"
76-
res = self.all_matches(k)
77-
return res[0] if len(res) else None
78-
79-
def __repr__(self): return self.d.__repr__()
80-
def first(self): return first(self.d.values())
46+
@delegates(Function.dispatch)
47+
def dispatch(self, f=None, **kwargs): return super().dispatch(_eval_annotations(f), **kwargs)
8148

82-
# Cell
83-
class TypeDispatch:
84-
"Dictionary-like object; `__getitem__` matches keys of types using `issubclass`"
85-
def __init__(self, funcs=(), bases=()):
86-
self.funcs,self.bases = _TypeDict(),L(bases).filter(is_not(None))
87-
for o in L(funcs): self.add(o)
88-
self.inst = None
89-
self.owner = None
90-
91-
def add(self, f):
92-
"Add type `t` and function `f`"
93-
if isinstance(f, staticmethod): a0,a1 = _p2_anno(f.__func__)
94-
else: a0,a1 = _p2_anno(f)
95-
t = self.funcs.d.get(a0)
96-
if t is None:
97-
t = _TypeDict()
98-
self.funcs.add(a0, t)
99-
t.add(a1, f)
100-
101-
def first(self):
102-
"Get first function in ordered dict of type:func."
103-
return self.funcs.first().first()
104-
105-
def returns(self, x):
106-
"Get the return type of annotation of `x`."
107-
return anno_ret(self[type(x)])
108-
109-
def _attname(self,k): return getattr(k,'__name__',str(k))
110-
def __repr__(self):
111-
r = [f'({self._attname(k)},{self._attname(l)}) -> {getattr(v, "__name__", type(v).__name__)}'
112-
for k in self.funcs.d for l,v in self.funcs[k].d.items()]
113-
r = r + [o.__repr__() for o in self.bases]
114-
return '\n'.join(r)
115-
116-
def __call__(self, *args, **kwargs):
117-
ts = L(args).map(type)[:2]
118-
f = self[tuple(ts)]
119-
if not f: return args[0]
120-
if isinstance(f, staticmethod): f = f.__func__
121-
elif self.inst is not None: f = MethodType(f, self.inst)
122-
elif self.owner is not None: f = MethodType(f, self.owner)
123-
return f(*args, **kwargs)
124-
125-
def __get__(self, inst, owner):
126-
self.inst = inst
127-
self.owner = owner
128-
return self
129-
130-
def __getitem__(self, k):
131-
"Find first matching type that is a super-class of `k`"
132-
k = L(k)
133-
while len(k)<2: k.append(object)
134-
r = self.funcs.all_matches(k[0])
135-
for t in r:
136-
o = t[k[1]]
137-
if o is not None: return o
138-
for base in self.bases:
139-
res = base[k]
140-
if res is not None: return res
141-
return None
49+
def __getitem__(self, ts):
50+
"Return the most-specific matching method with fewest parameters"
51+
ts = L(ts)
52+
nargs = min(len(o) for o in self.methods.keys())
53+
while len(ts) < nargs: ts.append(object)
54+
return self.invoke(*ts)
14255

14356
# Cell
144-
class DispatchReg:
145-
"A global registry for `TypeDispatch` objects keyed by function name"
146-
def __init__(self): self.d = defaultdict(TypeDispatch)
147-
def __call__(self, f):
148-
if isinstance(f, (classmethod, staticmethod)): nm = f'{f.__func__.__qualname__}'
149-
else: nm = f'{f.__qualname__}'
150-
if isinstance(f, classmethod): f=f.__func__
151-
self.d[nm].add(f)
152-
return self.d[nm]
153-
154-
typedispatch = DispatchReg()
57+
class FastDispatcher(Dispatcher):
58+
def _get_function(self, method, owner):
59+
"Adapted from `Dispatcher._get_function` to use `FastFunction`"
60+
name = method.__name__
61+
if owner:
62+
if owner not in self._classes: self._classes[owner] = {}
63+
namespace = self._classes[owner]
64+
else: namespace = self._functions
65+
if name not in namespace: namespace[name] = FastFunction(method, owner=owner)
66+
return namespace[name]
67+
68+
@delegates(Dispatcher.__call__, but='method')
69+
def __call__(self, f, **kwargs): return super().__call__(_eval_annotations(f), **kwargs)
70+
71+
def _to(self, cls, nm, f, **kwargs):
72+
nf = copy_func(f)
73+
nf.__qualname__ = f'{cls.__name__}.{nm}' # plum uses __qualname__ to infer f's owner
74+
pf = self(nf, **kwargs)
75+
# plum uses __set_name__ to resolve a plum.Function's owner
76+
# since we assign after class creation, __set_name__ must be called directly
77+
# source: https://docs.python.org/3/reference/datamodel.html#object.__set_name__
78+
pf.__set_name__(cls, nm)
79+
pf = pf.resolve()
80+
setattr(cls, nm, pf)
81+
return pf
82+
83+
def to(self, cls):
84+
"Decorator: dispatch `f` to `cls.f`"
85+
def _inner(f, **kwargs):
86+
nm = f.__name__
87+
# check __dict__ to avoid inherited methods but use getattr so pf.__get__ is called, which plum relies on
88+
if nm in cls.__dict__:
89+
pf = getattr(cls, nm)
90+
if not hasattr(pf, 'dispatch'): pf = self._to(cls, nm, pf, **kwargs)
91+
pf.dispatch(f)
92+
else: pf = self._to(cls, nm, f, **kwargs)
93+
return pf
94+
return _inner
95+
96+
typedispatch = FastDispatcher()
15597

15698
# Cell
15799
#nbdev_comment _all_=['cast']

fastcore/imports.py

+10
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import sys,os,re,typing,itertools,operator,functools,math,warnings,functools,io,enum
22

3+
from copy import copy
34
from operator import itemgetter,attrgetter
45
from warnings import warn
56
from typing import Iterable,Generator,Sequence,Iterator,List,Set,Dict,Union,Optional,Tuple
@@ -14,6 +15,15 @@
1415
MethodDescriptorType = type(str.join)
1516
from types import BuiltinFunctionType,BuiltinMethodType,MethodType,FunctionType,SimpleNamespace
1617

18+
#Patch autoreload (if its loaded) to work with plum
19+
try: from IPython import get_ipython
20+
except ImportError: pass
21+
else:
22+
ip = get_ipython()
23+
if ip is not None and 'IPython.extensions.storemagic' in ip.extension_manager.loaded:
24+
from plum.autoreload import activate
25+
activate()
26+
1727
NoneType = type(None)
1828
string_classes = (str,bytes)
1929

fastcore/transform.py

+30-19
Original file line numberDiff line numberDiff line change
@@ -9,34 +9,33 @@
99
from .utils import *
1010
from .dispatch import *
1111
import inspect
12+
from plum import add_conversion_method
1213

1314
# Cell
1415
_tfm_methods = 'encodes','decodes','setups'
1516

17+
def _is_tfm_method(n, f): return n in _tfm_methods and callable(f)
18+
1619
class _TfmDict(dict):
17-
def __setitem__(self,k,v):
18-
if k not in _tfm_methods or not callable(v): return super().__setitem__(k,v)
19-
if k not in self: super().__setitem__(k,TypeDispatch())
20-
self[k].add(v)
20+
def __setitem__(self, k, v): super().__setitem__(k, typedispatch(v) if _is_tfm_method(k, v) else v)
2121

2222
# Cell
2323
class _TfmMeta(type):
2424
def __new__(cls, name, bases, dict):
25+
# _TfmMeta.__call__ shadows the signature of inheriting classes, set it back
2526
res = super().__new__(cls, name, bases, dict)
26-
for nm in _tfm_methods:
27-
base_td = [getattr(b,nm,None) for b in bases]
28-
if nm in res.__dict__: getattr(res,nm).bases = base_td
29-
else: setattr(res, nm, TypeDispatch(bases=base_td))
3027
res.__signature__ = inspect.signature(res.__init__)
3128
return res
3229

3330
def __call__(cls, *args, **kwargs):
34-
f = args[0] if args else None
35-
n = getattr(f,'__name__',None)
36-
if callable(f) and n in _tfm_methods:
37-
getattr(cls,n).add(f)
38-
return f
39-
return super().__call__(*args, **kwargs)
31+
f = first(args)
32+
n = getattr(f, '__name__', None)
33+
if _is_tfm_method(n, f): return typedispatch.to(cls)(f)
34+
obj = super().__call__(*args, **kwargs)
35+
# _TfmMeta.__new__ replaces cls.__signature__ which breaks the signature of a callable
36+
# instances of cls, fix it
37+
if hasattr(obj, '__call__'): obj.__signature__ = inspect.signature(obj.__call__)
38+
return obj
4039

4140
@classmethod
4241
def __prepare__(cls, name, bases): return _TfmDict()
@@ -60,13 +59,14 @@ def __init__(self, enc=None, dec=None, split_idx=None, order=None):
6059
self.init_enc = enc or dec
6160
if not self.init_enc: return
6261

63-
self.encodes,self.decodes,self.setups = TypeDispatch(),TypeDispatch(),TypeDispatch()
62+
def identity(x): return x
63+
for n in _tfm_methods: setattr(self,n,FastFunction(identity).dispatch(identity))
6464
if enc:
65-
self.encodes.add(enc)
65+
self.encodes.dispatch(enc)
6666
self.order = getattr(enc,'order',self.order)
6767
if len(type_hints(enc)) > 0: self.input_types = union2tuple(first(type_hints(enc).values()))
6868
self._name = _get_name(enc)
69-
if dec: self.decodes.add(dec)
69+
if dec: self.decodes.dispatch(dec)
7070

7171
@property
7272
def name(self): return getattr(self, '_name', _get_name(self))
@@ -85,13 +85,24 @@ def _call(self, fn, x, split_idx=None, **kwargs):
8585
def _do_call(self, f, x, **kwargs):
8686
if not _is_tuple(x):
8787
if f is None: return x
88-
ret = f.returns(x) if hasattr(f,'returns') else None
89-
return retain_type(f(x, **kwargs), x, ret)
88+
ts = [type(self),type(x)] if hasattr(f,'instance') else [type(x)]
89+
_, ret = f.resolve_method(*ts)
90+
ret = ret._type
91+
# plum reads empty return annotation as object, retain_type expects it as None
92+
if ret is object: ret = None
93+
return retain_type(f(x,**kwargs), x, ret)
9094
res = tuple(self._do_call(f, x_, **kwargs) for x_ in x)
9195
return retain_type(res, x)
96+
def encodes(self, x): return x
97+
def decodes(self, x): return x
98+
def setups(self, dl): return dl
9299

93100
add_docs(Transform, decode="Delegate to <code>decodes</code> to undo transform", setup="Delegate to <code>setups</code> to set up transform")
94101

102+
# Cell
103+
#Implement the Transform convention that a None return annotation disables conversion
104+
add_conversion_method(object, NoneType, lambda x: x)
105+
95106
# Cell
96107
class InplaceTransform(Transform):
97108
"A `Transform` that modifies in-place and just returns whatever it's passed"

0 commit comments

Comments
 (0)