15
15
T = TypeVar ("T" , bound = NBitBase )
16
16
17
17
Tensor = Union ["NDArray[np.number[T]]" , "NDArray[np.bool_]" ]
18
-
18
+ FloatIntOrBool = Literal [ "float" , "int" , "bool" ]
19
19
log = logging .getLogger (__name__ )
20
20
21
21
@@ -53,7 +53,7 @@ def __init__(self, **kwargs: dict[str, str]):
53
53
self .name = 'numpy'
54
54
self .precision = kwargs .get ('precision' , '64b' )
55
55
self .dtypemap : Mapping [
56
- Literal [ 'float' , 'int' , 'bool' ] ,
56
+ FloatIntOrBool ,
57
57
DTypeLike , # Type[np.floating[T]] | Type[np.integer[T]] | Type[np.bool_],
58
58
] = {
59
59
'float' : np .float64 if self .precision == '64b' else np .float32 ,
@@ -206,7 +206,7 @@ def isfinite(self, tensor: Tensor[T]) -> NDArray[np.bool_]:
206
206
return np .isfinite (tensor )
207
207
208
208
def astensor (
209
- self , tensor_in : ArrayLike , dtype : Literal [ 'float' ] = 'float'
209
+ self , tensor_in : ArrayLike , dtype : FloatIntOrBool = 'float'
210
210
) -> ArrayLike :
211
211
"""
212
212
Convert to a NumPy array.
@@ -247,9 +247,7 @@ def product(self, tensor_in: Tensor[T], axis: Shape | None = None) -> ArrayLike:
247
247
def abs (self , tensor : Tensor [T ]) -> ArrayLike :
248
248
return np .abs (tensor )
249
249
250
- def ones (
251
- self , shape : Shape , dtype : Literal ["float" , "int" , "bool" ] = "float"
252
- ) -> ArrayLike :
250
+ def ones (self , shape : Shape , dtype : FloatIntOrBool = "float" ) -> ArrayLike :
253
251
try :
254
252
dtype_obj = self .dtypemap [dtype ]
255
253
except KeyError :
@@ -261,9 +259,7 @@ def ones(
261
259
262
260
return np .ones (shape , dtype = dtype_obj )
263
261
264
- def zeros (
265
- self , shape : Shape , dtype : Literal ["float" , "int" , "bool" ] = "float"
266
- ) -> ArrayLike :
262
+ def zeros (self , shape : Shape , dtype : FloatIntOrBool = "float" ) -> ArrayLike :
267
263
try :
268
264
dtype_obj = self .dtypemap [dtype ]
269
265
except KeyError :
0 commit comments