diff --git a/array_api_compat/torch/_aliases.py b/array_api_compat/torch/_aliases.py index 5b20aabc..a6e833f9 100644 --- a/array_api_compat/torch/_aliases.py +++ b/array_api_compat/torch/_aliases.py @@ -3,13 +3,7 @@ from functools import wraps as _wraps from builtins import all as _builtin_all, any as _builtin_any -from ..common._aliases import (matrix_transpose as _aliases_matrix_transpose, - vecdot as _aliases_vecdot, - clip as _aliases_clip, - unstack as _aliases_unstack, - cumulative_sum as _aliases_cumulative_sum, - cumulative_prod as _aliases_cumulative_prod, - ) +from ..common import _aliases from .._internal import get_xp from ._info import __array_namespace_info__ @@ -215,10 +209,10 @@ def min(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keep return torch.clone(x) return torch.amin(x, axis, keepdims=keepdims) -clip = get_xp(torch)(_aliases_clip) -unstack = get_xp(torch)(_aliases_unstack) -cumulative_sum = get_xp(torch)(_aliases_cumulative_sum) -cumulative_prod = get_xp(torch)(_aliases_cumulative_prod) +clip = get_xp(torch)(_aliases.clip) +unstack = get_xp(torch)(_aliases.unstack) +cumulative_sum = get_xp(torch)(_aliases.cumulative_sum) +cumulative_prod = get_xp(torch)(_aliases.cumulative_prod) # torch.sort also returns a tuple # https://github.com/pytorch/pytorch/issues/70921 @@ -710,8 +704,8 @@ def matmul(x1: array, x2: array, /, **kwargs) -> array: x1, x2 = _fix_promotion(x1, x2, only_scalar=False) return torch.matmul(x1, x2, **kwargs) -matrix_transpose = get_xp(torch)(_aliases_matrix_transpose) -_vecdot = get_xp(torch)(_aliases_vecdot) +matrix_transpose = get_xp(torch)(_aliases.matrix_transpose) +_vecdot = get_xp(torch)(_aliases.vecdot) def vecdot(x1: array, x2: array, /, *, axis: int = -1) -> array: x1, x2 = _fix_promotion(x1, x2, only_scalar=False)