Skip to content

Commit 52506c6

Browse files
committed
Disallow left of None, properly account for adding/multiplying cyclers with different types
py39 compat py38 compat flake8
1 parent 7412cb2 commit 52506c6

File tree

1 file changed

+52
-30
lines changed

1 file changed

+52
-30
lines changed

cycler/__init__.py

Lines changed: 52 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -43,21 +43,24 @@
4343

4444
from __future__ import annotations
4545

46-
from collections.abc import Hashable, Iterable
46+
from collections.abc import Hashable, Iterable, Generator
4747
import copy
4848
from functools import reduce
4949
from itertools import product, cycle
5050
from operator import mul, add
51-
from typing import TypeVar, Generic, Generator, Any, overload
51+
# Dict, List, Union required for runtime cast calls
52+
from typing import TypeVar, Generic, Callable, Union, Dict, List, Any, overload, cast
5253

5354
__version__ = "0.12.0.dev0"
5455

5556
K = TypeVar("K", bound=Hashable)
57+
L = TypeVar("L", bound=Hashable)
5658
V = TypeVar("V")
59+
U = TypeVar("U")
5760

5861

5962
def _process_keys(
60-
left: Cycler[K, V] | Iterable[dict[K, V]] | None,
63+
left: Cycler[K, V] | Iterable[dict[K, V]],
6164
right: Cycler[K, V] | Iterable[dict[K, V]] | None,
6265
) -> set[K]:
6366
"""
@@ -73,7 +76,7 @@ def _process_keys(
7376
keys : set
7477
The keys in the composition of the two cyclers.
7578
"""
76-
l_peek: dict[K, V] = next(iter(left)) if left is not None else {}
79+
l_peek: dict[K, V] = next(iter(left)) if left != [] else {}
7780
r_peek: dict[K, V] = next(iter(right)) if right is not None else {}
7881
l_key: set[K] = set(l_peek.keys())
7982
r_key: set[K] = set(r_peek.keys())
@@ -82,7 +85,7 @@ def _process_keys(
8285
return l_key | r_key
8386

8487

85-
def concat(left: Cycler[K, V], right: Cycler[K, V]) -> Cycler[K, V]:
88+
def concat(left: Cycler[K, V], right: Cycler[K, U]) -> Cycler[K, V | U]:
8689
r"""
8790
Concatenate `Cycler`\s, as if chained using `itertools.chain`.
8891
@@ -108,8 +111,8 @@ def concat(left: Cycler[K, V], right: Cycler[K, V]) -> Cycler[K, V]:
108111
both=left.keys & right.keys, just_one=left.keys ^ right.keys
109112
)
110113
)
111-
_l = left.by_key()
112-
_r = right.by_key()
114+
_l = cast(Dict[K, List[Union[V, U]]], left.by_key())
115+
_r = cast(Dict[K, List[Union[V, U]]], right.by_key())
113116
return reduce(add, (_cycler(k, _l[k] + _r[k]) for k in left.keys))
114117

115118

@@ -156,15 +159,15 @@ def __init__(
156159
Do not use this directly, use `cycler` function instead.
157160
"""
158161
if isinstance(left, Cycler):
159-
self._left: Cycler[K, V] | list[dict[K, V]] | None = Cycler(
162+
self._left: Cycler[K, V] | list[dict[K, V]] = Cycler(
160163
left._left, left._right, left._op
161164
)
162165
elif left is not None:
163166
# Need to copy the dictionary or else that will be a residual
164167
# mutable that could lead to strange errors
165168
self._left = [copy.copy(v) for v in left]
166169
else:
167-
self._left = None
170+
self._left = []
168171

169172
if isinstance(right, Cycler):
170173
self._right: Cycler[K, V] | list[dict[K, V]] | None = Cycler(
@@ -220,8 +223,6 @@ def change_key(self, old: K, new: K) -> None:
220223

221224
# self._left should always be non-None
222225
# if self._keys is non-empty.
223-
elif self._left is None:
224-
pass
225226
elif isinstance(self._left, Cycler):
226227
self._left.change_key(old, new)
227228
else:
@@ -264,10 +265,9 @@ def __getitem__(self, key: slice) -> Cycler[K, V]:
264265
raise ValueError("Can only use slices with Cycler.__getitem__")
265266

266267
def __iter__(self) -> Generator[dict[K, V], None, None]:
267-
if self._right is None or self._left is None:
268-
if self._left is not None:
269-
for left in self._left:
270-
yield dict(left)
268+
if self._right is None:
269+
for left in self._left:
270+
yield dict(left)
271271
else:
272272
if self._op is None:
273273
raise TypeError(
@@ -279,7 +279,7 @@ def __iter__(self) -> Generator[dict[K, V], None, None]:
279279
out.update(b)
280280
yield out
281281

282-
def __add__(self, other: Cycler[K, V]) -> Cycler[K, V]:
282+
def __add__(self, other: Cycler[L, U]) -> Cycler[K | L, V | U]:
283283
"""
284284
Pair-wise combine two equal length cyclers (zip).
285285
@@ -291,9 +291,21 @@ def __add__(self, other: Cycler[K, V]) -> Cycler[K, V]:
291291
raise ValueError(
292292
f"Can only add equal length cycles, not {len(self)} and {len(other)}"
293293
)
294-
return Cycler(self, other, zip)
294+
return Cycler(
295+
cast(Cycler[Union[K, L], Union[V, U]], self),
296+
cast(Cycler[Union[K, L], Union[V, U]], other),
297+
zip
298+
)
299+
300+
@overload
301+
def __mul__(self, other: Cycler[L, U]) -> Cycler[K | L, V | U]:
302+
...
303+
304+
@overload
305+
def __mul__(self, other: int) -> Cycler[K, V]:
306+
...
295307

296-
def __mul__(self, other: Cycler[K, V] | int) -> Cycler[K, V]:
308+
def __mul__(self, other):
297309
"""
298310
Outer product of two cyclers (`itertools.product`) or integer
299311
multiplication.
@@ -303,7 +315,11 @@ def __mul__(self, other: Cycler[K, V] | int) -> Cycler[K, V]:
303315
other : Cycler or int
304316
"""
305317
if isinstance(other, Cycler):
306-
return Cycler(self, other, product)
318+
return Cycler(
319+
cast(Cycler[Union[K, L], Union[V, U]], self),
320+
cast(Cycler[Union[K, L], Union[V, U]], other),
321+
product
322+
)
307323
elif isinstance(other, int):
308324
trans = self.by_key()
309325
return reduce(
@@ -312,22 +328,28 @@ def __mul__(self, other: Cycler[K, V] | int) -> Cycler[K, V]:
312328
else:
313329
return NotImplemented
314330

315-
def __rmul__(self, other: Cycler[K, V]) -> Cycler[K, V]:
331+
@overload
332+
def __rmul__(self, other: Cycler[L, U]) -> Cycler[K | L, V | U]:
333+
...
334+
335+
@overload
336+
def __rmul__(self, other: int) -> Cycler[K, V]:
337+
...
338+
339+
def __rmul__(self, other):
316340
return self * other
317341

318342
def __len__(self) -> int:
319-
op_dict = {zip: min, product: mul}
320-
if self._left is None:
321-
if self._left is None:
322-
return 0
323-
return 0
343+
op_dict: dict[Callable, Callable[[int, int], int]] = {zip: min, product: mul}
324344
if self._right is None:
325345
return len(self._left)
326346
l_len = len(self._left)
327347
r_len = len(self._right)
328-
return op_dict[self._op](l_len, r_len) # type: ignore
348+
return op_dict[self._op](l_len, r_len)
329349

330-
def __iadd__(self, other: Cycler[K, V]) -> Cycler[K, V]:
350+
# iadd and imul do not exapand the the type as the returns must be consistent with
351+
# self, thus they flag as inconsistent with add/mul
352+
def __iadd__(self, other: Cycler[K, V]) -> Cycler[K, V]: # type: ignore[misc]
331353
"""
332354
In-place pair-wise combine two equal length cyclers (zip).
333355
@@ -345,7 +367,7 @@ def __iadd__(self, other: Cycler[K, V]) -> Cycler[K, V]:
345367
self._right = Cycler(other._left, other._right, other._op)
346368
return self
347369

348-
def __imul__(self, other: Cycler[K, V] | int) -> Cycler[K, V]:
370+
def __imul__(self, other: Cycler[K, V] | int) -> Cycler[K, V]: # type: ignore[misc]
349371
"""
350372
In-place outer product of two cyclers (`itertools.product`).
351373
@@ -451,7 +473,7 @@ def simplify(self) -> Cycler[K, V]:
451473

452474

453475
@overload
454-
def cycler(args: Cycler[K, V]) -> Cycler[K, V]:
476+
def cycler(arg: Cycler[K, V]) -> Cycler[K, V]:
455477
...
456478

457479

@@ -505,7 +527,7 @@ def cycler(*args, **kwargs):
505527
"""
506528
if args and kwargs:
507529
raise TypeError(
508-
"cyl() can only accept positional OR keyword arguments -- not both."
530+
"cycler() can only accept positional OR keyword arguments -- not both."
509531
)
510532

511533
if len(args) == 1:

0 commit comments

Comments
 (0)