1
1
import sys
2
- import logging
3
- from typing import Generic , TypeVar
4
- from typing_extensions import Annotated
2
+ from typing import Generic , TypeVar , Any
5
3
4
+ from typing_extensions import Annotated
5
+ from packaging import version
6
6
from pydantic import Field
7
7
from pydantic .fields import ModelField
8
8
import numpy as np
9
9
10
10
Cuid = Annotated [str , Field (min_length = 25 , max_length = 25 )]
11
11
12
12
DType = TypeVar ('DType' )
13
-
14
- logger = logging .getLogger (__name__ )
13
+ DShape = TypeVar ('DShape' )
15
14
16
15
17
- class TypedArray (np .ndarray , Generic [DType ]):
16
+ class _TypedArray (np .ndarray , Generic [DType , DShape ]):
18
17
19
18
@classmethod
20
19
def __get_validators__ (cls ):
@@ -26,12 +25,19 @@ def validate(cls, val, field: ModelField):
26
25
raise TypeError (f"Expected numpy array. Found { type (val )} " )
27
26
28
27
if sys .version_info .minor > 6 :
29
- actual_dtype = field .sub_fields [0 ].type_ .__args__ [0 ]
28
+ actual_dtype = field .sub_fields [- 1 ].type_ .__args__ [0 ]
30
29
else :
31
- actual_dtype = field .sub_fields [0 ].type_ .__values__ [0 ]
30
+ actual_dtype = field .sub_fields [- 1 ].type_ .__values__ [0 ]
32
31
33
32
if val .dtype != actual_dtype :
34
33
raise TypeError (
35
34
f"Expected numpy array have type { actual_dtype } . Found { val .dtype } "
36
35
)
37
36
return val
37
+
38
+
39
+ if version .parse (np .__version__ ) >= version .parse ('1.22.0' ):
40
+ from numpy .typing import _GenericAlias
41
+ TypedArray = _GenericAlias (_TypedArray , (Any , DType ))
42
+ else :
43
+ TypedArray = _TypedArray [Any , DType ]
0 commit comments