Skip to content

Commit fce1adb

Browse files
committed
fix python 3.7 compatibility; address review comments
1 parent 5d1a117 commit fce1adb

File tree

3 files changed

+41
-107
lines changed

3 files changed

+41
-107
lines changed

fastcore/transform.py

+10-6
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
import inspect
1717
from copy import copy
1818
from plum import add_conversion_method, dispatch, Function
19-
from typing import get_args, get_origin
2019

2120
# Cell
2221
# Convert tuple annotations to unions to work with plum
@@ -27,9 +26,13 @@ def _annot_tuple_to_union(f):
2726

2827
def _dispatch(f): return dispatch(_annot_tuple_to_union(f))
2928

29+
def _pf_dispatch(pf, f): return pf.dispatch(_annot_tuple_to_union(f))
30+
3031
def _dispatch_method(f, cls):
31-
f = copy(f)
3232
n = f.__name__
33+
# Use __dict__ to avoid searching base classes
34+
if n in cls.__dict__: return _pf_dispatch(getattr(cls, n), f)
35+
f = copy(f)
3336
# plum uses __qualname__ to infer f's owner
3437
f.__qualname__ = f'{cls.__name__}.{n}'
3538
pf = _dispatch(f)
@@ -40,8 +43,6 @@ def _dispatch_method(f, cls):
4043
pf.__set_name__(cls, n)
4144
return pf
4245

43-
def _pf_dispatch(pf, f): return pf.dispatch(_annot_tuple_to_union(f))
44-
4546
# Cell
4647
_tfm_methods = 'encodes','decodes','setups'
4748

@@ -98,6 +99,10 @@ def _pt_repr(o):
9899
def _pf_repr(pf): return '\n'.join(f"{f.__name__}: ({','.join(_pt_repr(t) for t in s.types)}) -> {_pt_repr(r)}"
99100
for s, (f, r) in pf.methods.items())
100101

102+
# Cell
103+
def _union_to_tuple(t):
104+
return t.__args__ if getattr(t,'__origin__',None) is Union else t
105+
101106
# Cell
102107
class Transform(metaclass=_TfmMeta):
103108
"Delegates (`__call__`,`decode`,`setup`) to (<code>encodes</code>,<code>decodes</code>,<code>setups</code>) if `split_idx` matches"
@@ -114,9 +119,8 @@ def identity(x): return x
114119
_pf_dispatch(self.encodes, enc)
115120
self.order = getattr(enc,'order',self.order)
116121
if len(type_hints(enc)) > 0:
117-
self.input_types = first(type_hints(enc).values())
118122
# Convert Union to tuple, remove once the rest of fastai supports Union
119-
if get_origin(self.input_types) is Union: self.input_types=get_args(self.input_types)
123+
self.input_types = _union_to_tuple(first(type_hints(enc).values()))
120124
self._name = _get_name(enc)
121125
if dec: _pf_dispatch(self.decodes, dec)
122126

nbs/05_transform.ipynb

+30-100
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,7 @@
2323
"from fastcore.dispatch import *\n",
2424
"import inspect\n",
2525
"from copy import copy\n",
26-
"from plum import add_conversion_method, dispatch, Function\n",
27-
"from typing import get_args, get_origin"
26+
"from plum import add_conversion_method, dispatch, Function"
2827
]
2928
},
3029
{
@@ -78,9 +77,13 @@
7877
"\n",
7978
"def _dispatch(f): return dispatch(_annot_tuple_to_union(f))\n",
8079
"\n",
80+
"def _pf_dispatch(pf, f): return pf.dispatch(_annot_tuple_to_union(f))\n",
81+
"\n",
8182
"def _dispatch_method(f, cls):\n",
82-
" f = copy(f)\n",
8383
" n = f.__name__\n",
84+
" # Use __dict__ to avoid searching base classes\n",
85+
" if n in cls.__dict__: return _pf_dispatch(getattr(cls, n), f)\n",
86+
" f = copy(f)\n",
8487
" # plum uses __qualname__ to infer f's owner\n",
8588
" f.__qualname__ = f'{cls.__name__}.{n}'\n",
8689
" pf = _dispatch(f)\n",
@@ -89,9 +92,7 @@
8992
" # since we assign after class creation, __set_name__ must be called directly\n",
9093
" # source: https://docs.python.org/3/reference/datamodel.html#object.__set_name__\n",
9194
" pf.__set_name__(cls, n)\n",
92-
" return pf\n",
93-
"\n",
94-
"def _pf_dispatch(pf, f): return pf.dispatch(_annot_tuple_to_union(f))"
95+
" return pf"
9596
]
9697
},
9798
{
@@ -225,6 +226,28 @@
225226
"test_eq(_pf_repr(_f), '_f1: (int,dict[str,float]) -> float\\n_f2: (int,tuple[str,float]) -> float')"
226227
]
227228
},
229+
{
230+
"cell_type": "code",
231+
"execution_count": null,
232+
"metadata": {},
233+
"outputs": [],
234+
"source": [
235+
"#export\n",
236+
"def _union_to_tuple(t):\n",
237+
" return t.__args__ if getattr(t,'__origin__',None) is Union else t"
238+
]
239+
},
240+
{
241+
"cell_type": "code",
242+
"execution_count": null,
243+
"metadata": {},
244+
"outputs": [],
245+
"source": [
246+
"test_eq(_union_to_tuple(Union[int,Union[str,None]]), (int,str,NoneType))\n",
247+
"test_eq(_union_to_tuple(Tuple[int,str]), Tuple[int,str])\n",
248+
"test_eq(_union_to_tuple(int), int)"
249+
]
250+
},
228251
{
229252
"cell_type": "code",
230253
"execution_count": null,
@@ -247,9 +270,8 @@
247270
" _pf_dispatch(self.encodes, enc)\n",
248271
" self.order = getattr(enc,'order',self.order)\n",
249272
" if len(type_hints(enc)) > 0:\n",
250-
" self.input_types = first(type_hints(enc).values())\n",
251273
" # Convert Union to tuple, remove once the rest of fastai supports Union\n",
252-
" if get_origin(self.input_types) is Union: self.input_types=get_args(self.input_types)\n",
274+
" self.input_types = _union_to_tuple(first(type_hints(enc).values()))\n",
253275
" self._name = _get_name(enc)\n",
254276
" if dec: _pf_dispatch(self.decodes, dec)\n",
255277
"\n",
@@ -520,15 +542,6 @@
520542
"`Transform` can be used as a decorator to turn a function into a `Transform`."
521543
]
522544
},
523-
{
524-
"cell_type": "code",
525-
"execution_count": null,
526-
"metadata": {},
527-
"outputs": [],
528-
"source": [
529-
"from nbdev.showdoc import _format_cls_doc, _format_func_doc"
530-
]
531-
},
532545
{
533546
"cell_type": "code",
534547
"execution_count": null,
@@ -680,18 +693,6 @@
680693
"test_eq(f(['a','b','c']), \"['a', 'b', 'c']_1\") # input is of type list"
681694
]
682695
},
683-
{
684-
"cell_type": "code",
685-
"execution_count": null,
686-
"metadata": {},
687-
"outputs": [],
688-
"source": [
689-
"@Transform\n",
690-
"def f(x:(int,float)): return x+1\n",
691-
"test_eq(f(0), 1)\n",
692-
"test_eq(f('a'), 'a')"
693-
]
694-
},
695696
{
696697
"cell_type": "markdown",
697698
"metadata": {},
@@ -929,77 +930,6 @@
929930
"test_eq(f.decode(t), [1,2])"
930931
]
931932
},
932-
{
933-
"cell_type": "code",
934-
"execution_count": null,
935-
"metadata": {},
936-
"outputs": [],
937-
"source": [
938-
"def encodes(self, x): pass"
939-
]
940-
},
941-
{
942-
"cell_type": "code",
943-
"execution_count": null,
944-
"metadata": {},
945-
"outputs": [
946-
{
947-
"data": {
948-
"text/plain": [
949-
"Promise(obj=<function <function AL.encodes at 0x11e3fb670> with 2 method(s)>)"
950-
]
951-
},
952-
"execution_count": null,
953-
"metadata": {},
954-
"output_type": "execute_result"
955-
}
956-
],
957-
"source": [
958-
"AL(encodes)"
959-
]
960-
},
961-
{
962-
"cell_type": "code",
963-
"execution_count": null,
964-
"metadata": {},
965-
"outputs": [
966-
{
967-
"data": {
968-
"text/plain": [
969-
"<lambda>:\n",
970-
"encodes: <lambda>: (object) -> object\n",
971-
"decodes: identity: (object) -> object"
972-
]
973-
},
974-
"execution_count": null,
975-
"metadata": {},
976-
"output_type": "execute_result"
977-
}
978-
],
979-
"source": [
980-
"AL(lambda x: x)"
981-
]
982-
},
983-
{
984-
"cell_type": "code",
985-
"execution_count": null,
986-
"metadata": {},
987-
"outputs": [
988-
{
989-
"data": {
990-
"text/plain": [
991-
"__main__.AL"
992-
]
993-
},
994-
"execution_count": null,
995-
"metadata": {},
996-
"output_type": "execute_result"
997-
}
998-
],
999-
"source": [
1000-
"type(AL(lambda x: x))"
1001-
]
1002-
},
1003933
{
1004934
"cell_type": "markdown",
1005935
"metadata": {},

setup.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
min_python = cfg['min_python']
2727
lic = licenses[cfg['license']]
2828

29-
requirements = ['pip', 'packaging', 'plum-dispatch>=1.5.16']
29+
requirements = ['pip', 'packaging', 'plum-dispatch>=1.6']
3030
if cfg.get('requirements'): requirements += cfg.get('requirements','').split()
3131
if cfg.get('pip_requirements'): requirements += cfg.get('pip_requirements','').split()
3232
dev_requirements = (cfg.get('dev_requirements') or '').split()

0 commit comments

Comments
 (0)