7
7
from typing import TYPE_CHECKING
8
8
if TYPE_CHECKING :
9
9
from typing import Optional , Tuple , Union
10
- from numpy import ndarray , dtype
10
+ from . _typing import ndarray , Device , Dtype , NestedSequence , SupportsBufferProtocol
11
11
12
12
from typing import NamedTuple
13
13
@@ -107,7 +107,7 @@ def unique_values(x: ndarray, /) -> ndarray:
107
107
equal_nan = False ,
108
108
)
109
109
110
- def astype (x : ndarray , dtype : dtype , / , * , copy : bool = True ) -> ndarray :
110
+ def astype (x : ndarray , dtype : Dtype , / , * , copy : bool = True ) -> ndarray :
111
111
if not copy and dtype == x .dtype :
112
112
return x
113
113
return x .astype (dtype = dtype , copy = copy )
@@ -138,6 +138,136 @@ def var(
138
138
def permute_dims (x : ndarray , / , axes : Tuple [int , ...]) -> ndarray :
139
139
return np .transpose (x , axes )
140
140
141
+ # Creation functions add the device keyword (which does nothing for NumPy)
142
+
143
+ def _check_device (device ):
144
+ if device not in ["cpu" , None ]:
145
+ raise ValueError (f"Unsupported device { device !r} " )
146
+
147
+ def asarray (
148
+ obj : Union [
149
+ ndarray ,
150
+ bool ,
151
+ int ,
152
+ float ,
153
+ NestedSequence [bool | int | float ],
154
+ SupportsBufferProtocol ,
155
+ ],
156
+ / ,
157
+ * ,
158
+ dtype : Optional [Dtype ] = None ,
159
+ device : Optional [Device ] = None ,
160
+ copy : Optional [Union [bool , np ._CopyMode ]] = None ,
161
+ ) -> ndarray :
162
+ _check_device (device )
163
+ if copy in (False , np ._CopyMode .IF_NEEDED ):
164
+ # copy=False is not yet implemented in np.asarray
165
+ raise NotImplementedError ("copy=False is not yet implemented" )
166
+ return np .asarray (obj , dtype = dtype )
167
+
168
+ def arange (
169
+ start : Union [int , float ],
170
+ / ,
171
+ stop : Optional [Union [int , float ]] = None ,
172
+ step : Union [int , float ] = 1 ,
173
+ * ,
174
+ dtype : Optional [Dtype ] = None ,
175
+ device : Optional [Device ] = None ,
176
+ ) -> ndarray :
177
+ _check_device (device )
178
+ return np .arange (start , stop = stop , step = step , dtype = dtype )
179
+
180
+ def empty (
181
+ shape : Union [int , Tuple [int , ...]],
182
+ * ,
183
+ dtype : Optional [Dtype ] = None ,
184
+ device : Optional [Device ] = None ,
185
+ ) -> ndarray :
186
+ _check_device (device )
187
+ return np .empty (shape , dtype = dtype )
188
+
189
+ def empty_like (
190
+ x : ndarray , / , * , dtype : Optional [Dtype ] = None , device : Optional [Device ] = None
191
+ ) -> ndarray :
192
+ _check_device (device )
193
+ return np .empty_like (x , dtype = dtype )
194
+
195
+ def eye (
196
+ n_rows : int ,
197
+ n_cols : Optional [int ] = None ,
198
+ / ,
199
+ * ,
200
+ k : int = 0 ,
201
+ dtype : Optional [Dtype ] = None ,
202
+ device : Optional [Device ] = None ,
203
+ ) -> ndarray :
204
+ _check_device (device )
205
+ return np .eye (n_rows , M = n_cols , k = k , dtype = dtype )
206
+
207
+ def full (
208
+ shape : Union [int , Tuple [int , ...]],
209
+ fill_value : Union [int , float ],
210
+ * ,
211
+ dtype : Optional [Dtype ] = None ,
212
+ device : Optional [Device ] = None ,
213
+ ) -> ndarray :
214
+ _check_device (device )
215
+ return np .full (shape , fill_value , dtype = dtype )
216
+
217
+ def full_like (
218
+ x : ndarray ,
219
+ / ,
220
+ fill_value : Union [int , float ],
221
+ * ,
222
+ dtype : Optional [Dtype ] = None ,
223
+ device : Optional [Device ] = None ,
224
+ ) -> ndarray :
225
+ _check_device (device )
226
+ return np .full_like (x , fill_value , dtype = dtype )
227
+
228
+ def linspace (
229
+ start : Union [int , float ],
230
+ stop : Union [int , float ],
231
+ / ,
232
+ num : int ,
233
+ * ,
234
+ dtype : Optional [Dtype ] = None ,
235
+ device : Optional [Device ] = None ,
236
+ endpoint : bool = True ,
237
+ ) -> ndarray :
238
+ _check_device (device )
239
+ return np .linspace (start , stop , num , dtype = dtype , endpoint = endpoint )
240
+
241
+ def ones (
242
+ shape : Union [int , Tuple [int , ...]],
243
+ * ,
244
+ dtype : Optional [Dtype ] = None ,
245
+ device : Optional [Device ] = None ,
246
+ ) -> ndarray :
247
+ _check_device (device )
248
+ return np .ones (shape , dtype = dtype )
249
+
250
+ def ones_like (
251
+ x : ndarray , / , * , dtype : Optional [Dtype ] = None , device : Optional [Device ] = None
252
+ ) -> ndarray :
253
+ _check_device (device )
254
+ return np .ones_like (x , dtype = dtype )
255
+
256
+ def zeros (
257
+ shape : Union [int , Tuple [int , ...]],
258
+ * ,
259
+ dtype : Optional [Dtype ] = None ,
260
+ device : Optional [Device ] = None ,
261
+ ) -> ndarray :
262
+ _check_device (device )
263
+ return np .zeros (shape , dtype = dtype )
264
+
265
+ def zeros_like (
266
+ x : ndarray , / , * , dtype : Optional [Dtype ] = None , device : Optional [Device ] = None
267
+ ) -> ndarray :
268
+ _check_device (device )
269
+ return np .zeros_like (x , dtype = dtype )
270
+
141
271
# from numpy import * doesn't overwrite these builtin names
142
272
from numpy import abs , max , min , round
143
273
@@ -146,4 +276,6 @@ def permute_dims(x: ndarray, /, axes: Tuple[int, ...]) -> ndarray:
146
276
'bool' , 'concat' , 'pow' , 'UniqueAllResult' , 'UniqueCountsResult' ,
147
277
'UniqueInverseResult' , 'unique_all' , 'unique_counts' ,
148
278
'unique_inverse' , 'unique_values' , 'astype' , 'abs' , 'max' , 'min' ,
149
- 'round' , 'std' , 'var' , 'permute_dims' ]
279
+ 'round' , 'std' , 'var' , 'permute_dims' , 'asarray' , 'arange' ,
280
+ 'empty' , 'empty_like' , 'eye' , 'full' , 'full_like' , 'linspace' ,
281
+ 'ones' , 'ones_like' , 'zeros' , 'zeros_like' ]
0 commit comments