Skip to content

Commit 5ff7366

Browse files
committed
Add some more tests for get_namespace
1 parent 30f4fac commit 5ff7366

File tree

1 file changed

+20
-1
lines changed

1 file changed

+20
-1
lines changed

tests/test_get_namespace.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import array_api_compat
2+
from array_api_compat import get_namespace
23
import pytest
34

45

@@ -17,5 +18,23 @@ def test_get_namespace(library):
1718
def test_get_namespace_returns_actual_namespace(array_namespace):
1819
xp = pytest.importorskip(array_namespace)
1920
X = xp.asarray([1, 2, 3])
20-
xp_ = array_api_compat.get_namespace(X)
21+
xp_ = get_namespace(X)
2122
assert xp_ is xp
23+
24+
def test_get_namespace_multiple():
25+
import numpy as np
26+
27+
x = np.asarray([1, 2])
28+
assert get_namespace(x, x) == get_namespace((x, x)) == \
29+
get_namespace((x, x), x) == array_api_compat.numpy
30+
31+
def test_get_namespace_errors():
32+
pytest.raises(TypeError, lambda: get_namespace([1]))
33+
pytest.raises(TypeError, lambda: get_namespace())
34+
35+
import numpy as np
36+
import torch
37+
x = np.asarray([1, 2])
38+
y = torch.asarray([1, 2])
39+
40+
pytest.raises(TypeError, lambda: get_namespace(x, y))

0 commit comments

Comments
 (0)