1
1
import sys
2
- from typing import Generic , TypeVar
3
- from typing_extensions import Annotated
2
+ from typing import Generic , TypeVar , Any
4
3
4
+ from typing_extensions import Annotated
5
+ from packaging import version
5
6
from pydantic import Field
6
7
from pydantic .fields import ModelField
7
8
import numpy as np
8
9
9
10
Cuid = Annotated [str , Field (min_length = 25 , max_length = 25 )]
10
- DType = TypeVar ('DType' )
11
11
12
+ DType = TypeVar ('DType' )
13
+ DShape = TypeVar ('DShape' )
12
14
13
- class TypedArray (Generic [DType ]):
14
15
15
- def __new__ (cls , * args , ** kwargs ):
16
- return np .ndarray (* args , ** kwargs )
16
+ class _TypedArray (np .ndarray , Generic [DType , DShape ]):
17
17
18
18
@classmethod
19
19
def __get_validators__ (cls ):
@@ -25,12 +25,19 @@ def validate(cls, val, field: ModelField):
25
25
raise TypeError (f"Expected numpy array. Found { type (val )} " )
26
26
27
27
if sys .version_info .minor > 6 :
28
- actual_dtype = field .sub_fields [0 ].type_ .__args__ [0 ]
28
+ actual_dtype = field .sub_fields [1 ].type_ .__args__ [0 ]
29
29
else :
30
- actual_dtype = field .sub_fields [0 ].type_ .__values__ [0 ]
30
+ actual_dtype = field .sub_fields [1 ].type_ .__values__ [0 ]
31
31
32
32
if val .dtype != actual_dtype :
33
33
raise TypeError (
34
34
f"Expected numpy array have type { actual_dtype } . Found { val .dtype } "
35
35
)
36
36
return val
37
+
38
+
39
+ if version .parse (np .__version__ ) >= version .parse ('1.22.0' ):
40
+ from numpy .typing import _GenericAlias
41
+ TypedArray = _GenericAlias (_TypedArray , (Any , DType ))
42
+ else :
43
+ TypedArray = _TypedArray [Any , DType ]
0 commit comments