Skip to content

Commit f6f8f08

Browse files
author
Matt Sokoloff
committed
fix numpy types
1 parent 9fa271a commit f6f8f08

File tree

1 file changed

+15
-8
lines changed

1 file changed

+15
-8
lines changed
Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,19 @@
11
import sys
2-
from typing import Generic, TypeVar
3-
from typing_extensions import Annotated
2+
from typing import Generic, TypeVar, Any
43

4+
from typing_extensions import Annotated
5+
from packaging import version
56
from pydantic import Field
67
from pydantic.fields import ModelField
78
import numpy as np
89

910
Cuid = Annotated[str, Field(min_length=25, max_length=25)]
10-
DType = TypeVar('DType')
1111

12+
DType = TypeVar('DType')
13+
DShape = TypeVar('DShape')
1214

13-
class TypedArray(Generic[DType]):
1415

15-
def __new__(cls, *args, **kwargs):
16-
return np.ndarray(*args, **kwargs)
16+
class _TypedArray(np.ndarray, Generic[DType, DShape]):
1717

1818
@classmethod
1919
def __get_validators__(cls):
@@ -25,12 +25,19 @@ def validate(cls, val, field: ModelField):
2525
raise TypeError(f"Expected numpy array. Found {type(val)}")
2626

2727
if sys.version_info.minor > 6:
28-
actual_dtype = field.sub_fields[0].type_.__args__[0]
28+
actual_dtype = field.sub_fields[1].type_.__args__[0]
2929
else:
30-
actual_dtype = field.sub_fields[0].type_.__values__[0]
30+
actual_dtype = field.sub_fields[1].type_.__values__[0]
3131

3232
if val.dtype != actual_dtype:
3333
raise TypeError(
3434
f"Expected numpy array have type {actual_dtype}. Found {val.dtype}"
3535
)
3636
return val
37+
38+
39+
if version.parse(np.__version__) >= version.parse('1.22.0'):
40+
from numpy.typing import _GenericAlias
41+
TypedArray = _GenericAlias(_TypedArray, (Any, DType))
42+
else:
43+
TypedArray = _TypedArray[Any, DType]

0 commit comments

Comments
 (0)