9
9
from .utils import *
10
10
from .dispatch import *
11
11
import inspect
12
+ from plum import add_conversion_method , dispatch , Dispatcher
13
+
14
+ # Cell
15
+ def _mk_plum_func (d , n , f = None , cls = None ):
16
+ f = (lambda x : x ) if f is None else copy (f )
17
+ f .__name__ = n
18
+ # plum uses __qualname__ to infer f's owner
19
+ f .__qualname__ = n if cls is None else '.' .join ([cls .__name__ ,n ])
20
+ pf = d (f )
21
+ if cls is not None :
22
+ setattr (cls ,n ,pf )
23
+ # plum uses __set_name__ to resolve a Function's owner.
24
+ # since class variable is assigned after class is created, __set_name__ must be called directly
25
+ # source: https://docs.python.org/3/reference/datamodel.html#object.__set_name__
26
+ pf .__set_name__ (cls ,n )
27
+ return pf
12
28
13
29
# Cell
14
30
_tfm_methods = 'encodes' ,'decodes' ,'setups'
15
31
32
+ def _is_tfm_method (f ,n ):
33
+ return n in _tfm_methods and callable (f )
34
+
16
35
class _TfmDict (dict ):
17
36
def __setitem__ (self ,k ,v ):
18
- if k not in _tfm_methods or not callable (v ): return super ().__setitem__ (k ,v )
19
- if k not in self : super ().__setitem__ (k ,TypeDispatch ())
20
- self [k ].add (v )
37
+ if _is_tfm_method (v ,k ): v = dispatch (v )
38
+ super ().__setitem__ (k ,v )
21
39
22
40
# Cell
23
41
class _TfmMeta (type ):
42
+ @classmethod
43
+ def __prepare__ (cls , name , bases ): return _TfmDict ()
44
+
24
45
def __new__ (cls , name , bases , dict ):
25
46
res = super ().__new__ (cls , name , bases , dict )
26
- for nm in _tfm_methods :
27
- base_td = [getattr (b ,nm ,None ) for b in bases ]
28
- if nm in res .__dict__ : getattr (res ,nm ).bases = base_td
29
- else : setattr (res , nm , TypeDispatch (bases = base_td ))
30
47
res .__signature__ = inspect .signature (res .__init__ )
31
48
return res
32
49
33
50
def __call__ (cls , * args , ** kwargs ):
34
- f = args [ 0 ] if args else None
51
+ f = first ( args )
35
52
n = getattr (f ,'__name__' ,None )
36
- if callable (f ) and n in _tfm_methods :
37
- getattr (cls ,n ).add (f )
38
- return f
39
- return super ().__call__ (* args , ** kwargs )
40
-
41
- @classmethod
42
- def __prepare__ (cls , name , bases ): return _TfmDict ()
53
+ if not _is_tfm_method (f ,n ): return super ().__call__ (* args ,** kwargs )
54
+ if n in cls .__dict__ : getattr (cls ,n ).dispatch (f )
55
+ return _mk_plum_func (dispatch ,n ,f ,cls = cls )
43
56
44
57
# Cell
45
58
def _get_name (o ):
@@ -60,13 +73,14 @@ def __init__(self, enc=None, dec=None, split_idx=None, order=None):
60
73
self .init_enc = enc or dec
61
74
if not self .init_enc : return
62
75
63
- self .encodes ,self .decodes ,self .setups = TypeDispatch (),TypeDispatch (),TypeDispatch ()
76
+ self ._d = Dispatcher () # TODO: do we need to hold this reference?
77
+ for n in _tfm_methods : setattr (self ,n ,_mk_plum_func (self ._d ,n ))
64
78
if enc :
65
- self .encodes . add (enc )
79
+ self .encodes = self . encodes . dispatch (enc )
66
80
self .order = getattr (enc ,'order' ,self .order )
67
81
if len (type_hints (enc )) > 0 : self .input_types = first (type_hints (enc ).values ())
68
82
self ._name = _get_name (enc )
69
- if dec : self .decodes . add (dec )
83
+ if dec : self .decodes = self . decodes . dispatch (dec )
70
84
71
85
@property
72
86
def name (self ): return getattr (self , '_name' , _get_name (self ))
@@ -85,13 +99,24 @@ def _call(self, fn, x, split_idx=None, **kwargs):
85
99
def _do_call (self , f , x , ** kwargs ):
86
100
if not _is_tuple (x ):
87
101
if f is None : return x
88
- ret = f .returns (x ) if hasattr (f ,'returns' ) else None
89
- return retain_type (f (x , ** kwargs ), x , ret )
102
+ ts = [type (self ),type (x )] if hasattr (f ,'instance' ) else [type (x )]
103
+ _ , ret = f .resolve_method (* ts )
104
+ ret = ret ._type
105
+ # plum reads empty return annot as object, fastcore reads as None
106
+ if ret is object : ret = None
107
+ return retain_type (f (x ,** kwargs ), x , ret )
90
108
res = tuple (self ._do_call (f , x_ , ** kwargs ) for x_ in x )
91
109
return retain_type (res , x )
110
+ def encodes (self , x ): return x
111
+ def decodes (self , x ): return x
112
+ def setups (self , dl ): return dl
92
113
93
114
add_docs (Transform , decode = "Delegate to <code>decodes</code> to undo transform" , setup = "Delegate to <code>setups</code> to set up transform" )
94
115
116
+ # Cell
117
+ #Transform interpret's None return type as no conversion
118
+ add_conversion_method (object , NoneType , lambda x : x )
119
+
95
120
# Cell
96
121
class InplaceTransform (Transform ):
97
122
"A `Transform` that modifies in-place and just returns whatever it's passed"
0 commit comments