Skip to content

Commit e912573

Browse files
committed
prototype plum-dispatch for fastcore.transform
1 parent 6db6aaa commit e912573

File tree

3 files changed

+146
-52
lines changed

3 files changed

+146
-52
lines changed

fastcore/transform.py

Lines changed: 45 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -9,37 +9,50 @@
99
from .utils import *
1010
from .dispatch import *
1111
import inspect
12+
from plum import add_conversion_method, dispatch, Dispatcher
13+
14+
# Cell
15+
def _mk_plum_func(d, n, f=None, cls=None):
16+
f = (lambda x: x) if f is None else copy(f)
17+
f.__name__ = n
18+
# plum uses __qualname__ to infer f's owner
19+
f.__qualname__ = n if cls is None else '.'.join([cls.__name__,n])
20+
pf = d(f)
21+
if cls is not None:
22+
setattr(cls,n,pf)
23+
# plum uses __set_name__ to resolve a Function's owner.
24+
# since class variable is assigned after class is created, __set_name__ must be called directly
25+
# source: https://docs.python.org/3/reference/datamodel.html#object.__set_name__
26+
pf.__set_name__(cls,n)
27+
return pf
1228

1329
# Cell
1430
_tfm_methods = 'encodes','decodes','setups'
1531

32+
def _is_tfm_method(f,n):
33+
return n in _tfm_methods and callable(f)
34+
1635
class _TfmDict(dict):
1736
def __setitem__(self,k,v):
18-
if k not in _tfm_methods or not callable(v): return super().__setitem__(k,v)
19-
if k not in self: super().__setitem__(k,TypeDispatch())
20-
self[k].add(v)
37+
if _is_tfm_method(v,k): v = dispatch(v)
38+
super().__setitem__(k,v)
2139

2240
# Cell
2341
class _TfmMeta(type):
42+
@classmethod
43+
def __prepare__(cls, name, bases): return _TfmDict()
44+
2445
def __new__(cls, name, bases, dict):
2546
res = super().__new__(cls, name, bases, dict)
26-
for nm in _tfm_methods:
27-
base_td = [getattr(b,nm,None) for b in bases]
28-
if nm in res.__dict__: getattr(res,nm).bases = base_td
29-
else: setattr(res, nm, TypeDispatch(bases=base_td))
3047
res.__signature__ = inspect.signature(res.__init__)
3148
return res
3249

3350
def __call__(cls, *args, **kwargs):
34-
f = args[0] if args else None
51+
f = first(args)
3552
n = getattr(f,'__name__',None)
36-
if callable(f) and n in _tfm_methods:
37-
getattr(cls,n).add(f)
38-
return f
39-
return super().__call__(*args, **kwargs)
40-
41-
@classmethod
42-
def __prepare__(cls, name, bases): return _TfmDict()
53+
if not _is_tfm_method(f,n): return super().__call__(*args,**kwargs)
54+
if n in cls.__dict__: getattr(cls,n).dispatch(f)
55+
return _mk_plum_func(dispatch,n,f,cls=cls)
4356

4457
# Cell
4558
def _get_name(o):
@@ -60,13 +73,14 @@ def __init__(self, enc=None, dec=None, split_idx=None, order=None):
6073
self.init_enc = enc or dec
6174
if not self.init_enc: return
6275

63-
self.encodes,self.decodes,self.setups = TypeDispatch(),TypeDispatch(),TypeDispatch()
76+
self._d = Dispatcher() # TODO: do we need to hold this reference?
77+
for n in _tfm_methods: setattr(self,n,_mk_plum_func(self._d,n))
6478
if enc:
65-
self.encodes.add(enc)
79+
self.encodes = self.encodes.dispatch(enc)
6680
self.order = getattr(enc,'order',self.order)
6781
if len(type_hints(enc)) > 0: self.input_types = first(type_hints(enc).values())
6882
self._name = _get_name(enc)
69-
if dec: self.decodes.add(dec)
83+
if dec: self.decodes = self.decodes.dispatch(dec)
7084

7185
@property
7286
def name(self): return getattr(self, '_name', _get_name(self))
@@ -85,13 +99,24 @@ def _call(self, fn, x, split_idx=None, **kwargs):
8599
def _do_call(self, f, x, **kwargs):
86100
if not _is_tuple(x):
87101
if f is None: return x
88-
ret = f.returns(x) if hasattr(f,'returns') else None
89-
return retain_type(f(x, **kwargs), x, ret)
102+
ts = [type(self),type(x)] if hasattr(f,'instance') else [type(x)]
103+
_, ret = f.resolve_method(*ts)
104+
ret = ret._type
105+
# plum reads empty return annot as object, fastcore reads as None
106+
if ret is object: ret = None
107+
return retain_type(f(x,**kwargs), x, ret)
90108
res = tuple(self._do_call(f, x_, **kwargs) for x_ in x)
91109
return retain_type(res, x)
110+
def encodes(self, x): return x
111+
def decodes(self, x): return x
112+
def setups(self, dl): return dl
92113

93114
add_docs(Transform, decode="Delegate to <code>decodes</code> to undo transform", setup="Delegate to <code>setups</code> to set up transform")
94115

116+
# Cell
117+
#Transform interpret's None return type as no conversion
118+
add_conversion_method(object, NoneType, lambda x: x)
119+
95120
# Cell
96121
class InplaceTransform(Transform):
97122
"A `Transform` that modifies in-place and just returns whatever it's passed"

nbs/05_transform.ipynb

Lines changed: 100 additions & 31 deletions
Large diffs are not rendered by default.

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
min_python = cfg['min_python']
2727
lic = licenses[cfg['license']]
2828

29-
requirements = ['pip', 'packaging']
29+
requirements = ['pip', 'packaging', 'plum-dispatch']
3030
if cfg.get('requirements'): requirements += cfg.get('requirements','').split()
3131
if cfg.get('pip_requirements'): requirements += cfg.get('pip_requirements','').split()
3232
dev_requirements = (cfg.get('dev_requirements') or '').split()

0 commit comments

Comments
 (0)