From deb508b68193f85d587c29d01cc78abd39773cad Mon Sep 17 00:00:00 2001 From: seem Date: Mon, 30 May 2022 13:45:37 +0200 Subject: [PATCH] refactor `_TypeDict` Make it more dict-like with fewer lines where appropriate: - Rename `add` to `__setitem__` - Add `setdefault` which cleans up `TypeDispatch.add` Also renamed args for consistency: vars referring to keys are named `t` since they will always be types, and vars referring to values are named `v` (or `vs`) because they may be functions or nested `_TypeDict`s. --- fastcore/dispatch.py | 45 +++++++++++++++------------------- nbs/04_dispatch.ipynb | 56 ++++++++++++++++--------------------------- 2 files changed, 40 insertions(+), 61 deletions(-) diff --git a/fastcore/dispatch.py b/fastcore/dispatch.py index d1fe07a1..e4bfac22 100644 --- a/fastcore/dispatch.py +++ b/fastcore/dispatch.py @@ -24,7 +24,7 @@ def lenient_issubclass(cls, types): # Cell def sorted_topologically(iterable, *, cmp=operator.lt, reverse=False): - "Return a new list containing all items from the iterable sorted topologically" + "Return a new list containing all items from `iterable` sorted topologically." l,res = L(list(iterable)),[] for _ in range(len(l)): t = l.reduce(lambda x,y: y if cmp(y,x) else x) @@ -43,7 +43,7 @@ def _chk_defaults(f, ann): # Cell def _p2_anno(f): - "Get the 1st 2 annotations of `f`, defaulting to `object`" + "Get 1st 2 annotations of `f`, defaulting to `object`" hints = type_hints(f) ann = [o for n,o in hints.items() if n!='return'] if callable(f): _chk_defaults(f, ann) @@ -52,31 +52,28 @@ def _p2_anno(f): # Cell class _TypeDict: + "Dict-like keyed by types and matched by `lenient_issubclass`" def __init__(self): self.d,self.cache = {},{} - def _reset(self): - self.d = {k:self.d[k] for k in sorted_topologically(self.d, cmp=lenient_issubclass)} + def __setitem__(self, t, v): + "Map type `t` (tuple interpreted as union) to `v`" + for t_ in L(t): self.d[t_] = v + self.d = {k:self.d[k] for k in sorted_topologically(self.d,cmp=lenient_issubclass)} self.cache = {} - def add(self, t, f): - "Add type `t` and function `f`" - if not isinstance(t,tuple): t=tuple(L(t)) - for t_ in t: self.d[t_] = f - self._reset() - - def all_matches(self, k): - "Find first matching type that is a super-class of `k`" - if k not in self.cache: - types = [f for f in self.d if lenient_issubclass(k,f)] - self.cache[k] = [self.d[o] for o in types] - return self.cache[k] + def setdefault(self, t, default): + v = self.d.get(t) + if v is None: v = self[t] = default + return v - def __getitem__(self, k): - "Find first matching type that is a super-class of `k`" - res = self.all_matches(k) - return res[0] if len(res) else None + def all_matches(self, t): + "Find all values matching types that are a super-class of `t`" + vs = self.cache.get(t) + if vs is None: vs = self.cache[t] = [v for k, v in self.d.items() if lenient_issubclass(t,k)] + return vs - def __repr__(self): return self.d.__repr__() + def __getitem__(self, t): return first(self.all_matches(t)) + def __repr__(self): return repr(self.d) def first(self): return first(self.d.values()) # Cell @@ -92,11 +89,7 @@ def add(self, f): "Add type `t` and function `f`" if isinstance(f, staticmethod): a0,a1 = _p2_anno(f.__func__) else: a0,a1 = _p2_anno(f) - t = self.funcs.d.get(a0) - if t is None: - t = _TypeDict() - self.funcs.add(a0, t) - t.add(a1, f) + self.funcs.setdefault(a0,_TypeDict())[a1] = f def first(self): "Get first function in ordered dict of type:func." diff --git a/nbs/04_dispatch.ipynb b/nbs/04_dispatch.ipynb index 2c38407d..2ee4e2bb 100644 --- a/nbs/04_dispatch.ipynb +++ b/nbs/04_dispatch.ipynb @@ -87,7 +87,7 @@ "source": [ "#export\n", "def sorted_topologically(iterable, *, cmp=operator.lt, reverse=False):\n", - " \"Return a new list containing all items from the iterable sorted topologically\"\n", + " \"Return a new list containing all items from `iterable` sorted topologically.\"\n", " l,res = L(list(iterable)),[]\n", " for _ in range(len(l)):\n", " t = l.reduce(lambda x,y: y if cmp(y,x) else x)\n", @@ -152,7 +152,7 @@ "source": [ "#export\n", "def _p2_anno(f):\n", - " \"Get the 1st 2 annotations of `f`, defaulting to `object`\"\n", + " \"Get 1st 2 annotations of `f`, defaulting to `object`\"\n", " hints = type_hints(f)\n", " ann = [o for n,o in hints.items() if n!='return']\n", " if callable(f): _chk_defaults(f, ann)\n", @@ -251,31 +251,28 @@ "source": [ "#export\n", "class _TypeDict:\n", + " \"Dict-like keyed by types and matched by `lenient_issubclass`\"\n", " def __init__(self): self.d,self.cache = {},{}\n", "\n", - " def _reset(self):\n", - " self.d = {k:self.d[k] for k in sorted_topologically(self.d, cmp=lenient_issubclass)}\n", + " def __setitem__(self, t, v):\n", + " \"Map type `t` (tuple interpreted as union) to `v`\"\n", + " for t_ in L(t): self.d[t_] = v\n", + " self.d = {k:self.d[k] for k in sorted_topologically(self.d,cmp=lenient_issubclass)}\n", " self.cache = {}\n", "\n", - " def add(self, t, f):\n", - " \"Add type `t` and function `f`\"\n", - " if not isinstance(t,tuple): t=tuple(L(t))\n", - " for t_ in t: self.d[t_] = f\n", - " self._reset()\n", - "\n", - " def all_matches(self, k):\n", - " \"Find first matching type that is a super-class of `k`\"\n", - " if k not in self.cache:\n", - " types = [f for f in self.d if lenient_issubclass(k,f)]\n", - " self.cache[k] = [self.d[o] for o in types]\n", - " return self.cache[k]\n", + " def setdefault(self, t, default):\n", + " v = self.d.get(t)\n", + " if v is None: v = self[t] = default\n", + " return v\n", "\n", - " def __getitem__(self, k):\n", - " \"Find first matching type that is a super-class of `k`\"\n", - " res = self.all_matches(k)\n", - " return res[0] if len(res) else None\n", + " def all_matches(self, t):\n", + " \"Find all values matching types that are a super-class of `t`\"\n", + " vs = self.cache.get(t)\n", + " if vs is None: vs = self.cache[t] = [v for k, v in self.d.items() if lenient_issubclass(t,k)]\n", + " return vs\n", "\n", - " def __repr__(self): return self.d.__repr__()\n", + " def __getitem__(self, t): return first(self.all_matches(t))\n", + " def __repr__(self): return repr(self.d)\n", " def first(self): return first(self.d.values())" ] }, @@ -298,11 +295,7 @@ " \"Add type `t` and function `f`\"\n", " if isinstance(f, staticmethod): a0,a1 = _p2_anno(f.__func__)\n", " else: a0,a1 = _p2_anno(f)\n", - " t = self.funcs.d.get(a0)\n", - " if t is None:\n", - " t = _TypeDict()\n", - " self.funcs.add(a0, t)\n", - " t.add(a1, f)\n", + " self.funcs.setdefault(a0,_TypeDict())[a1] = f\n", "\n", " def first(self):\n", " \"Get first function in ordered dict of type:func.\"\n", @@ -795,7 +788,7 @@ { "data": { "text/markdown": [ - "

TypeDispatch.__call__[source]

\n", + "

TypeDispatch.__call__[source]

\n", "\n", "> TypeDispatch.__call__(**\\*`args`**, **\\*\\*`kwargs`**)\n", "\n", @@ -883,7 +876,7 @@ { "data": { "text/markdown": [ - "

TypeDispatch.returns[source]

\n", + "

TypeDispatch.returns[source]

\n", "\n", "> TypeDispatch.returns(**`x`**)\n", "\n", @@ -1409,13 +1402,6 @@ "from nbdev.export import notebook2script\n", "notebook2script()" ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] } ], "metadata": {