|
23 | 23 | "from fastcore.dispatch import *\n",
|
24 | 24 | "import inspect\n",
|
25 | 25 | "from copy import copy\n",
|
26 |
| - "from plum import add_conversion_method, dispatch, Function\n", |
27 |
| - "from typing import get_args, get_origin" |
| 26 | + "from plum import add_conversion_method, dispatch, Function" |
28 | 27 | ]
|
29 | 28 | },
|
30 | 29 | {
|
|
78 | 77 | "\n",
|
79 | 78 | "def _dispatch(f): return dispatch(_annot_tuple_to_union(f))\n",
|
80 | 79 | "\n",
|
| 80 | + "def _pf_dispatch(pf, f): return pf.dispatch(_annot_tuple_to_union(f))\n", |
| 81 | + "\n", |
81 | 82 | "def _dispatch_method(f, cls):\n",
|
82 |
| - " f = copy(f)\n", |
83 | 83 | " n = f.__name__\n",
|
| 84 | + " # Use __dict__ to avoid searching base classes\n", |
| 85 | + " if n in cls.__dict__: return _pf_dispatch(getattr(cls, n), f)\n", |
| 86 | + " f = copy(f)\n", |
84 | 87 | " # plum uses __qualname__ to infer f's owner\n",
|
85 | 88 | " f.__qualname__ = f'{cls.__name__}.{n}'\n",
|
86 | 89 | " pf = _dispatch(f)\n",
|
|
89 | 92 | " # since we assign after class creation, __set_name__ must be called directly\n",
|
90 | 93 | " # source: https://docs.python.org/3/reference/datamodel.html#object.__set_name__\n",
|
91 | 94 | " pf.__set_name__(cls, n)\n",
|
92 |
| - " return pf\n", |
93 |
| - "\n", |
94 |
| - "def _pf_dispatch(pf, f): return pf.dispatch(_annot_tuple_to_union(f))" |
| 95 | + " return pf" |
95 | 96 | ]
|
96 | 97 | },
|
97 | 98 | {
|
|
225 | 226 | "test_eq(_pf_repr(_f), '_f1: (int,dict[str,float]) -> float\\n_f2: (int,tuple[str,float]) -> float')"
|
226 | 227 | ]
|
227 | 228 | },
|
| 229 | + { |
| 230 | + "cell_type": "code", |
| 231 | + "execution_count": null, |
| 232 | + "metadata": {}, |
| 233 | + "outputs": [], |
| 234 | + "source": [ |
| 235 | + "#export\n", |
| 236 | + "def _union_to_tuple(t):\n", |
| 237 | + " return t.__args__ if getattr(t,'__origin__',None) is Union else t" |
| 238 | + ] |
| 239 | + }, |
| 240 | + { |
| 241 | + "cell_type": "code", |
| 242 | + "execution_count": null, |
| 243 | + "metadata": {}, |
| 244 | + "outputs": [], |
| 245 | + "source": [ |
| 246 | + "test_eq(_union_to_tuple(Union[int,Union[str,None]]), (int,str,NoneType))\n", |
| 247 | + "test_eq(_union_to_tuple(Tuple[int,str]), Tuple[int,str])\n", |
| 248 | + "test_eq(_union_to_tuple(int), int)" |
| 249 | + ] |
| 250 | + }, |
228 | 251 | {
|
229 | 252 | "cell_type": "code",
|
230 | 253 | "execution_count": null,
|
|
247 | 270 | " _pf_dispatch(self.encodes, enc)\n",
|
248 | 271 | " self.order = getattr(enc,'order',self.order)\n",
|
249 | 272 | " if len(type_hints(enc)) > 0:\n",
|
250 |
| - " self.input_types = first(type_hints(enc).values())\n", |
251 | 273 | " # Convert Union to tuple, remove once the rest of fastai supports Union\n",
|
252 |
| - " if get_origin(self.input_types) is Union: self.input_types=get_args(self.input_types)\n", |
| 274 | + " self.input_types = _union_to_tuple(first(type_hints(enc).values()))\n", |
253 | 275 | " self._name = _get_name(enc)\n",
|
254 | 276 | " if dec: _pf_dispatch(self.decodes, dec)\n",
|
255 | 277 | "\n",
|
|
520 | 542 | "`Transform` can be used as a decorator to turn a function into a `Transform`."
|
521 | 543 | ]
|
522 | 544 | },
|
523 |
| - { |
524 |
| - "cell_type": "code", |
525 |
| - "execution_count": null, |
526 |
| - "metadata": {}, |
527 |
| - "outputs": [], |
528 |
| - "source": [ |
529 |
| - "from nbdev.showdoc import _format_cls_doc, _format_func_doc" |
530 |
| - ] |
531 |
| - }, |
532 | 545 | {
|
533 | 546 | "cell_type": "code",
|
534 | 547 | "execution_count": null,
|
|
680 | 693 | "test_eq(f(['a','b','c']), \"['a', 'b', 'c']_1\") # input is of type list"
|
681 | 694 | ]
|
682 | 695 | },
|
683 |
| - { |
684 |
| - "cell_type": "code", |
685 |
| - "execution_count": null, |
686 |
| - "metadata": {}, |
687 |
| - "outputs": [], |
688 |
| - "source": [ |
689 |
| - "@Transform\n", |
690 |
| - "def f(x:(int,float)): return x+1\n", |
691 |
| - "test_eq(f(0), 1)\n", |
692 |
| - "test_eq(f('a'), 'a')" |
693 |
| - ] |
694 |
| - }, |
695 | 696 | {
|
696 | 697 | "cell_type": "markdown",
|
697 | 698 | "metadata": {},
|
|
929 | 930 | "test_eq(f.decode(t), [1,2])"
|
930 | 931 | ]
|
931 | 932 | },
|
932 |
| - { |
933 |
| - "cell_type": "code", |
934 |
| - "execution_count": null, |
935 |
| - "metadata": {}, |
936 |
| - "outputs": [], |
937 |
| - "source": [ |
938 |
| - "def encodes(self, x): pass" |
939 |
| - ] |
940 |
| - }, |
941 |
| - { |
942 |
| - "cell_type": "code", |
943 |
| - "execution_count": null, |
944 |
| - "metadata": {}, |
945 |
| - "outputs": [ |
946 |
| - { |
947 |
| - "data": { |
948 |
| - "text/plain": [ |
949 |
| - "Promise(obj=<function <function AL.encodes at 0x11e3fb670> with 2 method(s)>)" |
950 |
| - ] |
951 |
| - }, |
952 |
| - "execution_count": null, |
953 |
| - "metadata": {}, |
954 |
| - "output_type": "execute_result" |
955 |
| - } |
956 |
| - ], |
957 |
| - "source": [ |
958 |
| - "AL(encodes)" |
959 |
| - ] |
960 |
| - }, |
961 |
| - { |
962 |
| - "cell_type": "code", |
963 |
| - "execution_count": null, |
964 |
| - "metadata": {}, |
965 |
| - "outputs": [ |
966 |
| - { |
967 |
| - "data": { |
968 |
| - "text/plain": [ |
969 |
| - "<lambda>:\n", |
970 |
| - "encodes: <lambda>: (object) -> object\n", |
971 |
| - "decodes: identity: (object) -> object" |
972 |
| - ] |
973 |
| - }, |
974 |
| - "execution_count": null, |
975 |
| - "metadata": {}, |
976 |
| - "output_type": "execute_result" |
977 |
| - } |
978 |
| - ], |
979 |
| - "source": [ |
980 |
| - "AL(lambda x: x)" |
981 |
| - ] |
982 |
| - }, |
983 |
| - { |
984 |
| - "cell_type": "code", |
985 |
| - "execution_count": null, |
986 |
| - "metadata": {}, |
987 |
| - "outputs": [ |
988 |
| - { |
989 |
| - "data": { |
990 |
| - "text/plain": [ |
991 |
| - "__main__.AL" |
992 |
| - ] |
993 |
| - }, |
994 |
| - "execution_count": null, |
995 |
| - "metadata": {}, |
996 |
| - "output_type": "execute_result" |
997 |
| - } |
998 |
| - ], |
999 |
| - "source": [ |
1000 |
| - "type(AL(lambda x: x))" |
1001 |
| - ] |
1002 |
| - }, |
1003 | 933 | {
|
1004 | 934 | "cell_type": "markdown",
|
1005 | 935 | "metadata": {},
|
|
0 commit comments