From c8ae92b987f518a3b68361b972db06e7f0a91f3b Mon Sep 17 00:00:00 2001 From: Lucy Liu Date: Thu, 3 Apr 2025 16:03:39 +1100 Subject: [PATCH 1/4] iter --- array_api_strict/_array_object.py | 2 +- array_api_strict/tests/test_array_object.py | 25 +++++++++++++++++++++ 2 files changed, 26 insertions(+), 1 deletion(-) diff --git a/array_api_strict/_array_object.py b/array_api_strict/_array_object.py index cb2dd11..483952e 100644 --- a/array_api_strict/_array_object.py +++ b/array_api_strict/_array_object.py @@ -722,7 +722,7 @@ def __getitem__( devices = {self.device} if isinstance(key, tuple): devices.update( - [subkey.device for subkey in key if hasattr(subkey, "device")] + [subkey.device for subkey in key if isinstance(subkey, Array)] ) if len(devices) > 1: raise ValueError( diff --git a/array_api_strict/tests/test_array_object.py b/array_api_strict/tests/test_array_object.py index 51f4f31..7a2618c 100644 --- a/array_api_strict/tests/test_array_object.py +++ b/array_api_strict/tests/test_array_object.py @@ -101,6 +101,31 @@ def test_validate_index(): assert_raises(IndexError, lambda: a[idx]) +@pytest.mark.parametrize("device", [None, "CPU_DEVICE", "device1", "device2"]) +@pytest.mark.parametrize( + "integer_index", + [ + 1, + np.bool(1), + np.int8(0), + np.uint8(0), + np.int16(0), + np.uint16(0), + np.int32(0), + np.uint32(0), + np.int64(0), + np.uint64(0), + 2, + ], +) +def test_indexing_ints(integer_index, device): + # Ensure indexing with different integer types works on all Devices. + device = None if device is None else Device(device) + + a = arange(5, device=device) + a[integer_index] + + @pytest.mark.parametrize("device", [None, "CPU_DEVICE", "device1", "device2"]) def test_indexing_arrays(device): # indexing with 1D integer arrays and mixes of integers and 1D integer are allowed From ae5b49f10d4b7d41929ba1127f253cf120d8690c Mon Sep 17 00:00:00 2001 From: Lucy Liu Date: Thu, 3 Apr 2025 16:04:10 +1100 Subject: [PATCH 2/4] typo --- array_api_strict/tests/test_array_object.py | 1 - 1 file changed, 1 deletion(-) diff --git a/array_api_strict/tests/test_array_object.py b/array_api_strict/tests/test_array_object.py index 7a2618c..633f804 100644 --- a/array_api_strict/tests/test_array_object.py +++ b/array_api_strict/tests/test_array_object.py @@ -115,7 +115,6 @@ def test_validate_index(): np.uint32(0), np.int64(0), np.uint64(0), - 2, ], ) def test_indexing_ints(integer_index, device): From 58db9c13d36a051dbd118028352f479563486c8a Mon Sep 17 00:00:00 2001 From: Lucy Liu Date: Thu, 3 Apr 2025 16:08:31 +1100 Subject: [PATCH 3/4] fix --- array_api_strict/tests/test_array_object.py | 1 - 1 file changed, 1 deletion(-) diff --git a/array_api_strict/tests/test_array_object.py b/array_api_strict/tests/test_array_object.py index 633f804..43dd157 100644 --- a/array_api_strict/tests/test_array_object.py +++ b/array_api_strict/tests/test_array_object.py @@ -106,7 +106,6 @@ def test_validate_index(): "integer_index", [ 1, - np.bool(1), np.int8(0), np.uint8(0), np.int16(0), From d5824c73a6760cbe32e384fcad73eb897028b9fe Mon Sep 17 00:00:00 2001 From: Lucy Liu Date: Thu, 3 Apr 2025 21:09:54 +1100 Subject: [PATCH 4/4] review --- array_api_strict/tests/test_array_object.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/array_api_strict/tests/test_array_object.py b/array_api_strict/tests/test_array_object.py index 43dd157..c7330d8 100644 --- a/array_api_strict/tests/test_array_object.py +++ b/array_api_strict/tests/test_array_object.py @@ -100,12 +100,18 @@ def test_validate_index(): assert_raises(IndexError, lambda: a[:]) assert_raises(IndexError, lambda: a[idx]) +class DummyIndex: + def __init__(self, x): + self.x = x + def __index__(self): + return self.x + @pytest.mark.parametrize("device", [None, "CPU_DEVICE", "device1", "device2"]) @pytest.mark.parametrize( "integer_index", [ - 1, + 0, np.int8(0), np.uint8(0), np.int16(0), @@ -114,6 +120,7 @@ def test_validate_index(): np.uint32(0), np.int64(0), np.uint64(0), + DummyIndex(0), ], ) def test_indexing_ints(integer_index, device): @@ -121,7 +128,7 @@ def test_indexing_ints(integer_index, device): device = None if device is None else Device(device) a = arange(5, device=device) - a[integer_index] + assert a[(integer_index,)] == a[integer_index] == a[0] @pytest.mark.parametrize("device", [None, "CPU_DEVICE", "device1", "device2"])