Skip to content

Commit 19dc410

Browse files
committed
Add torch wrappers for ones(), zeros(), and empty()
This allows shape to be passed as a keyword argument. Fixes #22.
1 parent e2203df commit 19dc410

File tree

1 file changed

+25
-3
lines changed

1 file changed

+25
-3
lines changed

array_api_compat/torch/_aliases.py

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -519,6 +519,28 @@ def full(shape: Union[int, Tuple[int, ...]],
519519

520520
return torch.full(shape, fill_value, dtype=dtype, device=device, **kwargs)
521521

522+
# ones, zeros, and empty do not accept shape as a keyword argument
523+
def ones(shape: Union[int, Tuple[int, ...]],
524+
*,
525+
dtype: Optional[Dtype] = None,
526+
device: Optional[Device] = None,
527+
**kwargs) -> array:
528+
return torch.ones(shape, dtype=dtype, device=device, **kwargs)
529+
530+
def zeros(shape: Union[int, Tuple[int, ...]],
531+
*,
532+
dtype: Optional[Dtype] = None,
533+
device: Optional[Device] = None,
534+
**kwargs) -> array:
535+
return torch.zeros(shape, dtype=dtype, device=device, **kwargs)
536+
537+
def empty(shape: Union[int, Tuple[int, ...]],
538+
*,
539+
dtype: Optional[Dtype] = None,
540+
device: Optional[Device] = None,
541+
**kwargs) -> array:
542+
return torch.empty(shape, dtype=dtype, device=device, **kwargs)
543+
522544
# Functions that aren't in torch https://github.com/pytorch/pytorch/issues/58742
523545
def expand_dims(x: array, /, *, axis: int = 0) -> array:
524546
return torch.unsqueeze(x, axis)
@@ -585,7 +607,7 @@ def tensordot(x1: array, x2: array, /, *, axes: Union[int, Tuple[Sequence[int],
585607
'logaddexp', 'multiply', 'not_equal', 'pow', 'remainder',
586608
'subtract', 'max', 'min', 'sort', 'prod', 'sum', 'any', 'all',
587609
'mean', 'std', 'var', 'concat', 'squeeze', 'flip', 'roll',
588-
'nonzero', 'where', 'arange', 'eye', 'linspace', 'full',
589-
'expand_dims', 'astype', 'broadcast_arrays', 'unique_all',
590-
'unique_counts', 'unique_inverse', 'unique_values',
610+
'nonzero', 'where', 'arange', 'eye', 'linspace', 'full', 'ones',
611+
'zeros', 'empty', 'expand_dims', 'astype', 'broadcast_arrays',
612+
'unique_all', 'unique_counts', 'unique_inverse', 'unique_values',
591613
'matmul', 'matrix_transpose', 'vecdot', 'tensordot']

0 commit comments

Comments
 (0)