Skip to content

Commit ab330c8

Browse files
committed
Make the torch nonzero() raise an exception on zero-dimensional arrays
1 parent 7661820 commit ab330c8

File tree

1 file changed

+2
-0
lines changed

1 file changed

+2
-0
lines changed

array_api_compat/torch/_aliases.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -475,6 +475,8 @@ def roll(x: array, /, shift: Union[int, Tuple[int, ...]], *, axis: Optional[Unio
475475
return torch.roll(x, shift, axis, **kwargs)
476476

477477
def nonzero(x: array, /, **kwargs) -> Tuple[array, ...]:
478+
if x.ndim == 0:
479+
raise ValueError("nonzero() does not support zero-dimensional arrays")
478480
return torch.nonzero(x, as_tuple=True, **kwargs)
479481

480482
def where(condition: array, x1: array, x2: array, /) -> array:

0 commit comments

Comments
 (0)