@@ -519,6 +519,28 @@ def full(shape: Union[int, Tuple[int, ...]],
519
519
520
520
return torch .full (shape , fill_value , dtype = dtype , device = device , ** kwargs )
521
521
522
+ # ones, zeros, and empty do not accept shape as a keyword argument
523
+ def ones (shape : Union [int , Tuple [int , ...]],
524
+ * ,
525
+ dtype : Optional [Dtype ] = None ,
526
+ device : Optional [Device ] = None ,
527
+ ** kwargs ) -> array :
528
+ return torch .ones (shape , dtype = dtype , device = device , ** kwargs )
529
+
530
+ def zeros (shape : Union [int , Tuple [int , ...]],
531
+ * ,
532
+ dtype : Optional [Dtype ] = None ,
533
+ device : Optional [Device ] = None ,
534
+ ** kwargs ) -> array :
535
+ return torch .zeros (shape , dtype = dtype , device = device , ** kwargs )
536
+
537
+ def empty (shape : Union [int , Tuple [int , ...]],
538
+ * ,
539
+ dtype : Optional [Dtype ] = None ,
540
+ device : Optional [Device ] = None ,
541
+ ** kwargs ) -> array :
542
+ return torch .empty (shape , dtype = dtype , device = device , ** kwargs )
543
+
522
544
# Functions that aren't in torch https://github.com/pytorch/pytorch/issues/58742
523
545
def expand_dims (x : array , / , * , axis : int = 0 ) -> array :
524
546
return torch .unsqueeze (x , axis )
@@ -585,7 +607,7 @@ def tensordot(x1: array, x2: array, /, *, axes: Union[int, Tuple[Sequence[int],
585
607
'logaddexp' , 'multiply' , 'not_equal' , 'pow' , 'remainder' ,
586
608
'subtract' , 'max' , 'min' , 'sort' , 'prod' , 'sum' , 'any' , 'all' ,
587
609
'mean' , 'std' , 'var' , 'concat' , 'squeeze' , 'flip' , 'roll' ,
588
- 'nonzero' , 'where' , 'arange' , 'eye' , 'linspace' , 'full' ,
589
- 'expand_dims ' , 'astype ' , 'broadcast_arrays ' , 'unique_all ' ,
590
- 'unique_counts' , 'unique_inverse' , 'unique_values' ,
610
+ 'nonzero' , 'where' , 'arange' , 'eye' , 'linspace' , 'full' , 'ones' ,
611
+ 'zeros ' , 'empty ' , 'expand_dims ' , 'astype' , 'broadcast_arrays ' ,
612
+ 'unique_all' , ' unique_counts' , 'unique_inverse' , 'unique_values' ,
591
613
'matmul' , 'matrix_transpose' , 'vecdot' , 'tensordot' ]
0 commit comments