diff --git a/array_api_compat/torch/_aliases.py b/array_api_compat/torch/_aliases.py index 91161d24..40960f45 100644 --- a/array_api_compat/torch/_aliases.py +++ b/array_api_compat/torch/_aliases.py @@ -826,7 +826,7 @@ def sign(x: Array, /) -> Array: def meshgrid(*arrays: Array, indexing: Literal['xy', 'ij'] = 'xy') -> list[Array]: # enforce the default of 'xy' # TODO: is the return type a list or a tuple - return list(torch.meshgrid(*arrays, indexing='xy')) + return list(torch.meshgrid(*arrays, indexing=indexing)) __all__ = ['asarray', 'result_type', 'can_cast', diff --git a/tests/test_dask.py b/tests/test_dask.py index fb0a84d4..4200e5b7 100644 --- a/tests/test_dask.py +++ b/tests/test_dask.py @@ -1,3 +1,4 @@ +import sys from contextlib import contextmanager import numpy as np @@ -167,6 +168,10 @@ def test_sort_argsort_chunk_size(xp, func, shape, chunks): ) +@pytest.mark.skipif( + sys.version_info.major*100 + sys.version_info.minor < 312, + reason="dask interop requires numpy >= 3.12" +) @pytest.mark.parametrize("func", ["sort", "argsort"]) def test_sort_argsort_meta(xp, func): """Test meta-namespace other than numpy""" diff --git a/tests/test_torch.py b/tests/test_torch.py index 7adb4ab3..f661a272 100644 --- a/tests/test_torch.py +++ b/tests/test_torch.py @@ -117,3 +117,16 @@ def test_meshgrid(): assert Y.shape == Y_xy.shape assert xp.all(Y == Y_xy) + + # repeat with an explicit indexing + X, Y = xp.meshgrid(x, y, indexing='ij') + + # output of torch.meshgrid(x, y, indexing='ij') + X_ij, Y_ij = xp.asarray([[1], [2]]), xp.asarray([[4], [4]]) + + assert X.shape == X_ij.shape + assert xp.all(X == X_ij) + + assert Y.shape == Y_ij.shape + assert xp.all(Y == Y_ij) +