From d1d42167915b74625d68f269effe018b58da039c Mon Sep 17 00:00:00 2001
From: crusaderky
Date: Thu, 27 Feb 2025 17:08:48 +0000
Subject: [PATCH] MAINT: torch: tweak imports

---
 array_api_compat/torch/_aliases.py | 20 +++++++-------------
 1 file changed, 7 insertions(+), 13 deletions(-)

diff --git a/array_api_compat/torch/_aliases.py b/array_api_compat/torch/_aliases.py
index 5b20aabc..a6e833f9 100644
--- a/array_api_compat/torch/_aliases.py
+++ b/array_api_compat/torch/_aliases.py
@@ -3,13 +3,7 @@
 from functools import wraps as _wraps
 from builtins import all as _builtin_all, any as _builtin_any
 
-from ..common._aliases import (matrix_transpose as _aliases_matrix_transpose,
-                               vecdot as _aliases_vecdot,
-                               clip as _aliases_clip,
-                               unstack as _aliases_unstack,
-                               cumulative_sum as _aliases_cumulative_sum,
-                               cumulative_prod as _aliases_cumulative_prod,
-                               )
+from ..common import _aliases
 from .._internal import get_xp
 
 from ._info import __array_namespace_info__
@@ -215,10 +209,10 @@ def min(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keep
         return torch.clone(x)
     return torch.amin(x, axis, keepdims=keepdims)
 
-clip = get_xp(torch)(_aliases_clip)
-unstack = get_xp(torch)(_aliases_unstack)
-cumulative_sum = get_xp(torch)(_aliases_cumulative_sum)
-cumulative_prod = get_xp(torch)(_aliases_cumulative_prod)
+clip = get_xp(torch)(_aliases.clip)
+unstack = get_xp(torch)(_aliases.unstack)
+cumulative_sum = get_xp(torch)(_aliases.cumulative_sum)
+cumulative_prod = get_xp(torch)(_aliases.cumulative_prod)
 
 # torch.sort also returns a tuple
 # https://github.com/pytorch/pytorch/issues/70921
@@ -710,8 +704,8 @@ def matmul(x1: array, x2: array, /, **kwargs) -> array:
     x1, x2 = _fix_promotion(x1, x2, only_scalar=False)
     return torch.matmul(x1, x2, **kwargs)
 
-matrix_transpose = get_xp(torch)(_aliases_matrix_transpose)
-_vecdot = get_xp(torch)(_aliases_vecdot)
+matrix_transpose = get_xp(torch)(_aliases.matrix_transpose)
+_vecdot = get_xp(torch)(_aliases.vecdot)
 
 def vecdot(x1: array, x2: array, /, *, axis: int = -1) -> array:
     x1, x2 = _fix_promotion(x1, x2, only_scalar=False)
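
Note for reviewers unfamiliar with the aliasing pattern touched here: importing the module (`from ..common import _aliases`) keeps the import line stable as shared aliases are added or removed, while `get_xp(torch)(...)` still binds each alias to torch at definition time. The snippet below is a minimal sketch of that binding pattern, under the assumption that the shared aliases take the array namespace as a keyword-only `xp` argument; `get_xp_sketch` and `clip_alias` are illustrative stand-ins, not the library's real helpers.

    # Illustrative sketch only -- NOT array_api_compat's actual implementation.
    # It mimics the binding pattern used in the patch, e.g.
    #     clip = get_xp(torch)(_aliases.clip)
    # where a namespace-generic alias takes the array namespace as a keyword-only
    # `xp` argument and the decorator pins that argument to a concrete backend.
    from functools import wraps

    import torch


    def get_xp_sketch(xp):
        """Return a decorator that injects `xp` into a namespace-generic alias."""
        def decorator(func):
            @wraps(func)
            def wrapper(*args, **kwargs):
                return func(*args, xp=xp, **kwargs)
            return wrapper
        return decorator


    def clip_alias(x, lo, hi, *, xp):
        """Hypothetical namespace-generic clip, written against the `xp` namespace."""
        return xp.minimum(xp.maximum(x, lo), hi)


    # Bind the generic alias to torch, mirroring the diff.
    clip = get_xp_sketch(torch)(clip_alias)
    print(clip(torch.tensor([0.5, 2.0, -1.0]), torch.tensor(0.0), torch.tensor(1.0)))
    # tensor([0.5000, 1.0000, 0.0000])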