Skip to content

Commit 78b171d

Browse files
authored
Merge pull request #380 from BCG-X-Official/dev/3.0.2
2 parents 8f16c28 + 8cbf1e9 commit 78b171d

File tree

16 files changed

+146
-55
lines changed

16 files changed

+146
-55
lines changed

.idea/pytools.iml

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,6 @@ repos:
5050
language_version: python310
5151
pass_filenames: false
5252
additional_dependencies:
53-
- numpy~=1.24
53+
- numpy~=2.0
5454
- pytest
5555
- packaging

RELEASE_NOTES.rst

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,14 @@ Release Notes
1111
*pytools* 3.0 adds support for language features introduced up to and including
1212
Python 3.10, and drops support for Python versions.
1313

14+
*pytools* 3.0.2
15+
~~~~~~~~~~~~~~~
16+
17+
- BUILD: :mod:`numpy` |nbsp| 2 is now supported
18+
- FIX: :func:`.issubclass_generic` now supports unions, tuples of types, and ``None``,
19+
and uses clearer error messages if called with invalid arguments
20+
21+
1422
*pytools* 3.0.1
1523
~~~~~~~~~~~~~~~
1624

azure-pipelines.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ stages:
8989
- script: |
9090
# package dependencies for mypy
9191
dependencies=(
92-
numpy~=1.24
92+
numpy~=2.0
9393
packaging
9494
pytest
9595
)

environment.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ dependencies:
55
# run
66
- joblib ~= 1.2
77
- matplotlib ~= 3.6
8-
- numpy ~= 1.24
8+
- numpy ~= 2.0
99
- pandas ~= 2.0
1010
- python ~= 3.10.14
1111
- scipy ~= 1.10

pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ license = "Apache Software License v2.0"
1616
requires = [
1717
"joblib ~=1.0",
1818
"matplotlib ~=3.6",
19-
"numpy >=1.23,<2a", # cannot use ~= due to conda bug
19+
"numpy >=1.23,<3a", # cannot use ~= due to conda bug
2020
"pandas >=1.5",
2121
"scipy ~=1.9",
2222
"typing_inspect ~=0.7",
@@ -80,7 +80,7 @@ typing_extensions = "~=4.0.0"
8080
# maximum requirements of gamma-pytools
8181
joblib = "~=1.3"
8282
matplotlib = "~=3.8"
83-
numpy = ">=1.26,<2a" # cannot use ~= due to conda bug
83+
numpy = ">=2,<3a" # cannot use ~= due to conda bug
8484
pandas = "~=2.0"
8585
python = ">=3.12,<4a" # cannot use ~= due to conda bug
8686
scipy = "~=1.12"

src/pytools/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,4 @@
22
A collection of Python extensions and tools used in BCG GAMMA's open-source libraries.
33
"""
44

5-
__version__ = "3.0.1"
5+
__version__ = "3.0.2"

src/pytools/data/_linkage.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
# Type variables
3333
#
3434

35-
LinkageMatrix = npt.NDArray[np.float_]
35+
LinkageMatrix = npt.NDArray[np.float64]
3636

3737

3838
#

src/pytools/data/_matrix.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
#
3333

3434
T_Matrix = TypeVar("T_Matrix", bound="Matrix[Any]")
35-
T_Number = TypeVar("T_Number", bound="np.number[npt.NBitBase]")
35+
T_Number = TypeVar("T_Number", bound="np.number[Any]")
3636

3737
#
3838
# Ensure all symbols introduced below are included in __all__
@@ -59,7 +59,7 @@ class Matrix(HasExpressionRepr, Generic[T_Number]):
5959
names: tuple[npt.NDArray[Any] | None, npt.NDArray[Any] | None]
6060

6161
#: the weights of the rows and columns
62-
weights: tuple[npt.NDArray[np.float_] | None, npt.NDArray[np.float_] | None]
62+
weights: tuple[npt.NDArray[np.float64] | None, npt.NDArray[np.float64] | None]
6363

6464
#: the labels for the row and column axes
6565
name_labels: tuple[str | None, str | None]
@@ -155,8 +155,8 @@ def _arg_to_array(
155155
else:
156156

157157
def _ensure_positive(
158-
w: npt.NDArray[np.float_] | None, axis: int
159-
) -> npt.NDArray[np.float_] | None:
158+
w: npt.NDArray[np.float64] | None, axis: int
159+
) -> npt.NDArray[np.float64] | None:
160160
if w is not None and (w < 0).any():
161161
raise ValueError(
162162
f"arg weights[{axis}] should be all positive, "
@@ -352,7 +352,7 @@ def _message(error: str) -> str:
352352

353353

354354
def _top_items_mask(
355-
weights: npt.NDArray[np.float_] | None,
355+
weights: npt.NDArray[np.float64] | None,
356356
current_size: int,
357357
target_size: tuple[int | None, float | None],
358358
) -> npt.NDArray[np.bool_]:
@@ -385,7 +385,7 @@ def _top_items_mask(
385385
# THe target weight is expressed as a ratio of total weight
386386
# (0 < target_ratio <= 1).
387387

388-
weights_sorted_cumsum: npt.NDArray[np.float_] = weights[
388+
weights_sorted_cumsum: npt.NDArray[np.float64] = weights[
389389
ix_weights_descending_stable
390390
].cumsum()
391391
mask[
@@ -401,12 +401,12 @@ def _top_items_mask(
401401

402402
def _resize_rows(
403403
values: npt.NDArray[T_Number],
404-
weights: npt.NDArray[np.float_] | None,
404+
weights: npt.NDArray[np.float64] | None,
405405
names: npt.NDArray[Any] | None,
406406
current_size: int,
407407
target_size: tuple[int | None, float | None],
408408
) -> tuple[
409-
npt.NDArray[T_Number], npt.NDArray[np.float_] | None, npt.NDArray[Any] | None
409+
npt.NDArray[T_Number], npt.NDArray[np.float64] | None, npt.NDArray[Any] | None
410410
]:
411411
mask = _top_items_mask(
412412
weights=weights, current_size=current_size, target_size=target_size

src/pytools/typing/_typing.py

Lines changed: 64 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
Sequence,
2727
ValuesView,
2828
)
29-
from types import GenericAlias
29+
from types import GenericAlias, NoneType, UnionType
3030
from typing import (
3131
AbstractSet,
3232
Any,
@@ -373,13 +373,13 @@ def get_type_arguments(obj: Any, base: type) -> list[tuple[type, ...]]:
373373
return list(map(get_args, get_generic_instance(ti.get_generic_type(obj), base)))
374374

375375

376-
def issubclass_generic(subclass: type | Never, base: type | Never) -> bool:
376+
def issubclass_generic(subclass: Any, base: Any) -> bool:
377377
"""
378378
Check if a class is a subclass of a generic instance, i.e., it is a subclass of the
379379
generic class, and has compatible type arguments.
380380
381-
:param subclass: the class to check
382-
:param base: the generic class to check against
381+
:param subclass: the (potentially generic) subclass to check
382+
:param base: the (potentially generic) base class to check against
383383
:return: ``True`` if the class is a subclass of the generic instance, ``False``
384384
otherwise
385385
"""
@@ -396,16 +396,56 @@ def issubclass_generic(subclass: type | Never, base: type | Never) -> bool:
396396
elif base is Never:
397397
return False
398398

399+
# Special case: if the subclass is a union type, check if all types in the union are
400+
# subclasses of the base class
401+
if get_origin(subclass) in (typing.Union, UnionType):
402+
return all(issubclass_generic(arg, base) for arg in get_args(subclass))
403+
404+
# Special case: if the base class is a union type, check if the subclass is a
405+
# subclass of at least one of the types in the union
406+
if get_origin(base) in (typing.Union, UnionType):
407+
return any(issubclass_generic(subclass, arg) for arg in get_args(base))
408+
409+
# Special case: if the base class is a tuple, check if the subclass is a subclass of
410+
# at least one type in the tuple
411+
if isinstance(base, tuple):
412+
try:
413+
return any(issubclass_generic(subclass, arg) for arg in base)
414+
except TypeError as e:
415+
raise TypeError(
416+
f"isinstance_generic() arg 2 must be a type, type-like, or tuple of "
417+
f"types or type-likes, but got {base!r}"
418+
) from e
419+
420+
# Typehints can contain `None` as a shorthand for `NoneType`; replace it with the
421+
# actual type
422+
if subclass is None:
423+
subclass = NoneType
424+
if base is None:
425+
base = NoneType
426+
427+
# Replace deprecated types in typing with their canonical replacements in
428+
# collections.abc
399429
subclass = _replace_deprecated_type(subclass)
400430
base = _replace_deprecated_type(base)
401431

402432
# Get the non-generic origin of the base class
403433
base_origin = get_origin(base) or base
434+
if not isinstance(base_origin, type):
435+
raise TypeError(
436+
f"isinstance_generic() arg 2 must be a type, type-like, or tuple of types "
437+
f"or type-likes, but got {base!r}"
438+
)
404439

405440
# If the non-generic origin of the subclass is not a subclass of the non-generic
406441
# origin of the base class, the subclass cannot be a subclass of the base class
407442
subclass_origin = get_origin(subclass) or subclass
408-
if not issubclass(subclass_origin, base_origin):
443+
if not isinstance(subclass_origin, type):
444+
raise TypeError(
445+
f"isinstance_generic() arg 1 must be a type or type-like, but got "
446+
f"{subclass!r}"
447+
)
448+
elif not issubclass(subclass_origin, base_origin):
409449
return False
410450

411451
# If the base class is not a generic class, there are no type arguments to check
@@ -567,7 +607,7 @@ def _get_origin_parameters(
567607
)
568608

569609

570-
def _replace_deprecated_type(tp: type) -> type:
610+
def _replace_deprecated_type(tp: T) -> T:
571611
"""
572612
Replace deprecated types in :mod:`typing` with their canonical replacements in
573613
:mod:`collections.abc`.
@@ -577,18 +617,23 @@ def _replace_deprecated_type(tp: type) -> type:
577617
deprecated
578618
"""
579619

580-
if tp.__module__ == "typing":
620+
origin: type | None = get_origin(tp)
621+
if (
581622
# Check if the same type is defined in collections.abc
582-
origin: type | None = get_origin(tp)
583-
if origin is not None and origin.__module__ == "collections.abc":
584-
log.warning(
585-
f"Type typing.{tp.__name__} is deprecated; "
586-
f"please use {origin.__module__}.{origin.__name__} instead"
587-
)
588-
args: tuple[type, ...] = get_args(tp)
589-
if args:
590-
# If the type has arguments, apply the same arguments to the replacement
591-
return cast(type, origin[args]) # type: ignore[index]
592-
else:
593-
return origin
623+
origin is not None
624+
and tp.__module__ == "typing"
625+
and origin.__module__ == "collections.abc"
626+
):
627+
log.warning(
628+
"Type typing.%s is deprecated; please use %s.%s instead",
629+
tp.__name__, # type: ignore[attr-defined]
630+
origin.__module__,
631+
origin.__name__,
632+
)
633+
args: tuple[type, ...] = get_args(tp)
634+
if args:
635+
# If the type has arguments, apply the same arguments to the replacement
636+
return cast(T, origin[args]) # type: ignore[index]
637+
else:
638+
return cast(T, origin)
594639
return tp

src/pytools/viz/_matplot.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -304,13 +304,13 @@ def color_for_value(self, z: int | float) -> RgbaColor:
304304
pass
305305

306306
@overload
307-
def color_for_value(self, z: npt.NDArray[np.float_]) -> npt.NDArray[np.float_]:
307+
def color_for_value(self, z: npt.NDArray[np.float64]) -> npt.NDArray[np.float64]:
308308
"""[overload]"""
309309
pass
310310

311311
def color_for_value(
312-
self, z: int | float | npt.NDArray[np.float_]
313-
) -> RgbaColor | npt.NDArray[np.float_]:
312+
self, z: int | float | npt.NDArray[np.float64]
313+
) -> RgbaColor | npt.NDArray[np.float64]:
314314
"""
315315
Get the color(s) associated with the given value(s), based on the color map and
316316
normalization defined for this style.

src/pytools/viz/dendrogram/_style.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -222,7 +222,7 @@ def draw_leaf_labels(
222222

223223
def _get_ytick_locations(
224224
self, *, weights: Sequence[float]
225-
) -> npt.NDArray[np.float_]:
225+
) -> npt.NDArray[np.float64]:
226226
"""
227227
Get the tick locations for the y axis.
228228
@@ -231,7 +231,7 @@ def _get_ytick_locations(
231231
"""
232232
weights_array = np.array(weights)
233233
# noinspection PyTypeChecker
234-
ytick_locations: npt.NDArray[np.float_] = -(
234+
ytick_locations: npt.NDArray[np.float64] = -(
235235
np.arange(len(weights)) * self.padding
236236
+ weights_array.cumsum()
237237
- weights_array / 2

src/pytools/viz/matrix/_matrix.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -155,25 +155,25 @@ def draw_matrix(
155155
npt.NDArray[Any] | None,
156156
],
157157
weights: tuple[
158-
npt.NDArray[np.float_] | None,
159-
npt.NDArray[np.float_] | None,
158+
npt.NDArray[np.float64] | None,
159+
npt.NDArray[np.float64] | None,
160160
],
161161
) -> None:
162162
"""[see superclass]"""
163163
ax: Axes = self.ax
164164
colors = self.colors
165165

166-
weights_rows: npt.NDArray[np.float_]
167-
weights_columns: npt.NDArray[np.float_]
166+
weights_rows: npt.NDArray[np.float64]
167+
weights_columns: npt.NDArray[np.float64]
168168
# replace undefined weights with all ones
169169
weights_rows, weights_columns = tuple(
170170
np.ones(n) if w is None else w for w, n in zip(weights, data.shape)
171171
)
172172

173173
# calculate the horizontal and vertical matrix cell bounds based on the
174174
# cumulative sums of the axis weights; default all weights to 1 if not defined
175-
column_bounds: npt.NDArray[np.float_]
176-
row_bounds: npt.NDArray[np.float_]
175+
column_bounds: npt.NDArray[np.float64]
176+
row_bounds: npt.NDArray[np.float64]
177177

178178
row_bounds = -np.array([0, *weights_rows]).cumsum()
179179
column_bounds = np.array([0, *weights_columns]).cumsum()
@@ -200,7 +200,7 @@ def draw_matrix(
200200
# draw the matrix cells
201201
for c, (x0, x1) in enumerate(zip(column_bounds, column_bounds[1:])):
202202
for r, (y1, y0) in enumerate(zip(row_bounds, row_bounds[1:])):
203-
color: npt.NDArray[np.float_] = cell_colors[r, c]
203+
color: npt.NDArray[np.float64] = cell_colors[r, c]
204204
ax.add_patch(
205205
Rectangle(
206206
(
@@ -224,7 +224,7 @@ def draw_matrix(
224224
y_tick_locations = (row_bounds[:-1] + row_bounds[1:]) / 2
225225

226226
def _set_ticks(
227-
tick_locations: npt.NDArray[np.float_],
227+
tick_locations: npt.NDArray[np.float64],
228228
tick_labels: npt.NDArray[Any],
229229
axis: Axis,
230230
tick_params: dict[str, Any],
@@ -461,15 +461,15 @@ def draw_matrix(
461461
npt.NDArray[Any] | None,
462462
],
463463
weights: tuple[
464-
npt.NDArray[np.float_] | None,
465-
npt.NDArray[np.float_] | None,
464+
npt.NDArray[np.float64] | None,
465+
npt.NDArray[np.float64] | None,
466466
],
467467
) -> None:
468468
"""[see superclass]"""
469469

470470
def _axis_marks(
471471
axis_names: npt.NDArray[Any] | None,
472-
axis_weights: npt.NDArray[np.float_] | None,
472+
axis_weights: npt.NDArray[np.float64] | None,
473473
) -> Iterable[str] | None:
474474
axis_names_iter: Iterable[Any]
475475

src/pytools/viz/matrix/base/_base.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,8 +67,8 @@ def draw_matrix(
6767
npt.NDArray[Any] | None,
6868
],
6969
weights: tuple[
70-
npt.NDArray[np.float_] | None,
71-
npt.NDArray[np.float_] | None,
70+
npt.NDArray[np.float64] | None,
71+
npt.NDArray[np.float64] | None,
7272
],
7373
) -> None:
7474
"""

0 commit comments

Comments
 (0)