34
34
assert_array_equal ,
35
35
skip_if_array_api_compat_not_configured ,
36
36
)
37
- from sklearn .utils .fixes import _IS_32BIT , CSR_CONTAINERS
37
+ from sklearn .utils .fixes import _IS_32BIT , CSR_CONTAINERS , np_version , parse_version
38
38
39
39
40
40
@pytest .mark .parametrize ("X" , [numpy .asarray ([1 , 2 , 3 ]), [1 , 2 , 3 ]])
@@ -67,7 +67,12 @@ def test_get_namespace_ndarray_with_dispatch():
67
67
with config_context (array_api_dispatch = True ):
68
68
xp_out , is_array_api_compliant = get_namespace (X_np )
69
69
assert is_array_api_compliant
70
- assert xp_out is array_api_compat .numpy
70
+ if np_version >= parse_version ("2.0.0" ):
71
+ # NumPy 2.0+ is an array API compliant library.
72
+ assert xp_out is numpy
73
+ else :
74
+ # Older NumPy versions require the compatibility layer.
75
+ assert xp_out is array_api_compat .numpy
71
76
72
77
73
78
@skip_if_array_api_compat_not_configured
@@ -135,7 +140,7 @@ def test_asarray_with_order_ignored():
135
140
136
141
137
142
@pytest .mark .parametrize (
138
- "array_namespace, device , dtype_name" , yield_namespace_device_dtype_combinations ()
143
+ "array_namespace, device_ , dtype_name" , yield_namespace_device_dtype_combinations ()
139
144
)
140
145
@pytest .mark .parametrize (
141
146
"weights, axis, normalize, expected" ,
@@ -167,19 +172,22 @@ def test_asarray_with_order_ignored():
167
172
],
168
173
)
169
174
def test_average (
170
- array_namespace , device , dtype_name , weights , axis , normalize , expected
175
+ array_namespace , device_ , dtype_name , weights , axis , normalize , expected
171
176
):
172
- xp = _array_api_for_tests (array_namespace , device )
177
+ xp = _array_api_for_tests (array_namespace , device_ )
173
178
array_in = numpy .asarray ([[1 , 2 , 3 ], [4 , 5 , 6 ]], dtype = dtype_name )
174
- array_in = xp .asarray (array_in , device = device )
179
+ array_in = xp .asarray (array_in , device = device_ )
175
180
if weights is not None :
176
181
weights = numpy .asarray (weights , dtype = dtype_name )
177
- weights = xp .asarray (weights , device = device )
182
+ weights = xp .asarray (weights , device = device_ )
178
183
179
184
with config_context (array_api_dispatch = True ):
180
185
result = _average (array_in , axis = axis , weights = weights , normalize = normalize )
181
186
182
- assert getattr (array_in , "device" , None ) == getattr (result , "device" , None )
187
+ if np_version < parse_version ("2.0.0" ) or np_version >= parse_version ("2.1.0" ):
188
+ # NumPy 2.0 has a problem with the device attribute of scalar arrays:
189
+ # https://github.com/numpy/numpy/issues/26850
190
+ assert device (array_in ) == device (result )
183
191
184
192
result = _convert_to_numpy (result , xp )
185
193
assert_allclose (result , expected , atol = _atol_for_type (dtype_name ))
@@ -226,14 +234,15 @@ def test_average_raises_with_wrong_dtype(array_namespace, device, dtype_name):
226
234
(
227
235
0 ,
228
236
[[1 , 2 ]],
229
- TypeError ,
230
- "1D weights expected" ,
237
+ # NumPy 2 raises ValueError, NumPy 1 raises TypeError
238
+ (ValueError , TypeError ),
239
+ "weights" , # the message is different for NumPy 1 and 2...
231
240
),
232
241
(
233
242
0 ,
234
243
[1 , 2 , 3 , 4 ],
235
244
ValueError ,
236
- "Length of weights" ,
245
+ "weights" ,
237
246
),
238
247
(0 , [- 1 , 1 ], ZeroDivisionError , "Weights sum to zero, can't be normalized" ),
239
248
),
@@ -580,18 +589,18 @@ def test_get_namespace_and_device():
580
589
581
590
582
591
@pytest .mark .parametrize (
583
- "array_namespace, device , dtype_name" , yield_namespace_device_dtype_combinations ()
592
+ "array_namespace, device_ , dtype_name" , yield_namespace_device_dtype_combinations ()
584
593
)
585
594
@pytest .mark .parametrize ("csr_container" , CSR_CONTAINERS )
586
595
@pytest .mark .parametrize ("axis" , [0 , 1 , None , - 1 , - 2 ])
587
596
@pytest .mark .parametrize ("sample_weight_type" , [None , "int" , "float" ])
588
597
def test_count_nonzero (
589
- array_namespace , device , dtype_name , csr_container , axis , sample_weight_type
598
+ array_namespace , device_ , dtype_name , csr_container , axis , sample_weight_type
590
599
):
591
600
592
601
from sklearn .utils .sparsefuncs import count_nonzero as sparse_count_nonzero
593
602
594
- xp = _array_api_for_tests (array_namespace , device )
603
+ xp = _array_api_for_tests (array_namespace , device_ )
595
604
array = numpy .array ([[0 , 3 , 0 ], [2 , - 1 , 0 ], [0 , 0 , 0 ], [9 , 8 , 7 ], [4 , 0 , 5 ]])
596
605
if sample_weight_type == "int" :
597
606
sample_weight = numpy .asarray ([1 , 2 , 2 , 3 , 1 ])
@@ -602,12 +611,16 @@ def test_count_nonzero(
602
611
expected = sparse_count_nonzero (
603
612
csr_container (array ), axis = axis , sample_weight = sample_weight
604
613
)
605
- array_xp = xp .asarray (array , device = device )
614
+ array_xp = xp .asarray (array , device = device_ )
606
615
607
616
with config_context (array_api_dispatch = True ):
608
617
result = _count_nonzero (
609
- array_xp , xp = xp , device = device , axis = axis , sample_weight = sample_weight
618
+ array_xp , xp = xp , device = device_ , axis = axis , sample_weight = sample_weight
610
619
)
611
620
612
621
assert_allclose (_convert_to_numpy (result , xp = xp ), expected )
613
- assert getattr (array_xp , "device" , None ) == getattr (result , "device" , None )
622
+
623
+ if np_version < parse_version ("2.0.0" ) or np_version >= parse_version ("2.1.0" ):
624
+ # NumPy 2.0 has a problem with the device attribute of scalar arrays:
625
+ # https://github.com/numpy/numpy/issues/26850
626
+ assert device (array_xp ) == device (result )
0 commit comments