Skip to content

Commit 1184a54

Browse files
committed
minor cleanup
1 parent e912573 commit 1184a54

File tree

2 files changed

+20
-46
lines changed

2 files changed

+20
-46
lines changed

fastcore/transform.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -39,9 +39,6 @@ def __setitem__(self,k,v):
3939

4040
# Cell
4141
class _TfmMeta(type):
42-
@classmethod
43-
def __prepare__(cls, name, bases): return _TfmDict()
44-
4542
def __new__(cls, name, bases, dict):
4643
res = super().__new__(cls, name, bases, dict)
4744
res.__signature__ = inspect.signature(res.__init__)
@@ -50,10 +47,16 @@ def __new__(cls, name, bases, dict):
5047
def __call__(cls, *args, **kwargs):
5148
f = first(args)
5249
n = getattr(f,'__name__',None)
53-
if not _is_tfm_method(f,n): return super().__call__(*args,**kwargs)
50+
if _is_tfm_method(f,n):
51+
if n in cls.__dict__: return getattr(cls,n).dispatch(f)
52+
53+
return super().__call__(*args,**kwargs)
5454
if n in cls.__dict__: getattr(cls,n).dispatch(f)
5555
return _mk_plum_func(dispatch,n,f,cls=cls)
5656

57+
@classmethod
58+
def __prepare__(cls, name, bases): return _TfmDict()
59+
5760
# Cell
5861
def _get_name(o):
5962
if hasattr(o,'__qualname__'): return o.__qualname__
@@ -76,11 +79,11 @@ def __init__(self, enc=None, dec=None, split_idx=None, order=None):
7679
self._d = Dispatcher() # TODO: do we need to hold this reference?
7780
for n in _tfm_methods: setattr(self,n,_mk_plum_func(self._d,n))
7881
if enc:
79-
self.encodes = self.encodes.dispatch(enc)
82+
self.encodes.dispatch(enc)
8083
self.order = getattr(enc,'order',self.order)
8184
if len(type_hints(enc)) > 0: self.input_types = first(type_hints(enc).values())
8285
self._name = _get_name(enc)
83-
if dec: self.decodes = self.decodes.dispatch(dec)
86+
if dec: self.decodes.dispatch(dec)
8487

8588
@property
8689
def name(self): return getattr(self, '_name', _get_name(self))

nbs/05_transform.ipynb

Lines changed: 11 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -108,9 +108,6 @@
108108
"source": [
109109
"#export\n",
110110
"class _TfmMeta(type):\n",
111-
" @classmethod\n",
112-
" def __prepare__(cls, name, bases): return _TfmDict()\n",
113-
"\n",
114111
" def __new__(cls, name, bases, dict):\n",
115112
" res = super().__new__(cls, name, bases, dict)\n",
116113
" res.__signature__ = inspect.signature(res.__init__)\n",
@@ -119,9 +116,13 @@
119116
" def __call__(cls, *args, **kwargs):\n",
120117
" f = first(args)\n",
121118
" n = getattr(f,'__name__',None)\n",
122-
" if not _is_tfm_method(f,n): return super().__call__(*args,**kwargs)\n",
123-
" if n in cls.__dict__: getattr(cls,n).dispatch(f)\n",
124-
" return _mk_plum_func(dispatch,n,f,cls=cls)"
119+
" if _is_tfm_method(f,n):\n",
120+
" if n in cls.__dict__: return getattr(cls,n).dispatch(f)\n",
121+
" return _mk_plum_func(dispatch,n,f,cls=cls)\n",
122+
" return super().__call__(*args,**kwargs)\n",
123+
"\n",
124+
" @classmethod\n",
125+
" def __prepare__(cls, name, bases): return _TfmDict()"
125126
]
126127
},
127128
{
@@ -166,11 +167,11 @@
166167
" self._d = Dispatcher() # TODO: do we need to hold this reference?\n",
167168
" for n in _tfm_methods: setattr(self,n,_mk_plum_func(self._d,n))\n",
168169
" if enc:\n",
169-
" self.encodes = self.encodes.dispatch(enc)\n",
170+
" self.encodes.dispatch(enc)\n",
170171
" self.order = getattr(enc,'order',self.order)\n",
171172
" if len(type_hints(enc)) > 0: self.input_types = first(type_hints(enc).values())\n",
172173
" self._name = _get_name(enc)\n",
173-
" if dec: self.decodes = self.decodes.dispatch(dec)\n",
174+
" if dec: self.decodes.dispatch(dec)\n",
174175
"\n",
175176
" @property\n",
176177
" def name(self): return getattr(self, '_name', _get_name(self))\n",
@@ -355,16 +356,6 @@
355356
"test_eq(f1(1), 2) # f1(1) is the same as f1.encode(1)"
356357
]
357358
},
358-
{
359-
"cell_type": "code",
360-
"execution_count": null,
361-
"metadata": {},
362-
"outputs": [],
363-
"source": [
364-
"class A(Transform):\n",
365-
" def encodes(self,x:int): return x+1"
366-
]
367-
},
368359
{
369360
"cell_type": "markdown",
370361
"metadata": {},
@@ -624,26 +615,6 @@
624615
"test_eq_type(f(FloatSubclass(3.0)), FloatSubclass(6.0))"
625616
]
626617
},
627-
{
628-
"cell_type": "code",
629-
"execution_count": null,
630-
"metadata": {},
631-
"outputs": [
632-
{
633-
"data": {
634-
"text/plain": [
635-
"__main__.FloatSubclass"
636-
]
637-
},
638-
"execution_count": null,
639-
"metadata": {},
640-
"output_type": "execute_result"
641-
}
642-
],
643-
"source": [
644-
"type(f(FloatSubclass(3.0)))"
645-
]
646-
},
647618
{
648619
"cell_type": "markdown",
649620
"metadata": {},
@@ -1083,7 +1054,7 @@
10831054
"data": {
10841055
"text/plain": [
10851056
"A:\n",
1086-
"encodes: <function <function noop at 0x111e2be50> with 2 method(s)>decodes: <function <function noop at 0x111e2be50> with 2 method(s)>"
1057+
"encodes: <function <function noop at 0x11fa2aa60> with 2 method(s)>decodes: <function <function noop at 0x11fa2aa60> with 2 method(s)>"
10871058
]
10881059
},
10891060
"execution_count": null,
@@ -1113,7 +1084,7 @@
11131084
"data": {
11141085
"text/plain": [
11151086
"A -- {'a': 1, 'b': 2}:\n",
1116-
"encodes: <function <function noop at 0x111e2be50> with 3 method(s)>decodes: <function <function Transform.decodes at 0x115652f70> with 1 method(s)>"
1087+
"encodes: <function <function noop at 0x11fa2aa60> with 3 method(s)>decodes: <function <function Transform.decodes at 0x123253f70> with 1 method(s)>"
11171088
]
11181089
},
11191090
"execution_count": null,

0 commit comments

Comments
 (0)