File tree Expand file tree Collapse file tree 1 file changed +20
-1
lines changed Expand file tree Collapse file tree 1 file changed +20
-1
lines changed Original file line number Diff line number Diff line change 1
1
import array_api_compat
2
+ from array_api_compat import get_namespace
2
3
import pytest
3
4
4
5
@@ -17,5 +18,23 @@ def test_get_namespace(library):
17
18
def test_get_namespace_returns_actual_namespace (array_namespace ):
18
19
xp = pytest .importorskip (array_namespace )
19
20
X = xp .asarray ([1 , 2 , 3 ])
20
- xp_ = array_api_compat . get_namespace (X )
21
+ xp_ = get_namespace (X )
21
22
assert xp_ is xp
23
+
24
+ def test_get_namespace_multiple ():
25
+ import numpy as np
26
+
27
+ x = np .asarray ([1 , 2 ])
28
+ assert get_namespace (x , x ) == get_namespace ((x , x )) == \
29
+ get_namespace ((x , x ), x ) == array_api_compat .numpy
30
+
31
+ def test_get_namespace_errors ():
32
+ pytest .raises (TypeError , lambda : get_namespace ([1 ]))
33
+ pytest .raises (TypeError , lambda : get_namespace ())
34
+
35
+ import numpy as np
36
+ import torch
37
+ x = np .asarray ([1 , 2 ])
38
+ y = torch .asarray ([1 , 2 ])
39
+
40
+ pytest .raises (TypeError , lambda : get_namespace (x , y ))
You can’t perform that action at this time.
0 commit comments