Skip to content

Commit a6c2db0

Browse files
lucyleeowogrisel
andauthored
TST Fix array API test_fill_or_add_to_diagonal (scikit-learn#31439)
Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org>
1 parent 4493f86 commit a6c2db0

File tree

1 file changed

+7
-3
lines changed

1 file changed

+7
-3
lines changed

sklearn/utils/tests/test_array_api.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -581,10 +581,14 @@ def test_count_nonzero(
581581
@pytest.mark.parametrize("wrap", [True, False])
582582
def test_fill_or_add_to_diagonal(array_namespace, device_, dtype_name, wrap):
583583
xp = _array_api_for_tests(array_namespace, device_)
584-
array_np = numpy.zeros((5, 4), dtype=numpy.int64)
585-
array_xp = xp.asarray(array_np)
586-
_fill_or_add_to_diagonal(array_xp, value=1, xp=xp, add_value=False, wrap=wrap)
584+
585+
array_np = numpy.zeros((5, 4), dtype=dtype_name)
586+
array_xp = xp.asarray(array_np.copy(), device=device_)
587+
587588
numpy.fill_diagonal(array_np, val=1, wrap=wrap)
589+
with config_context(array_api_dispatch=True):
590+
_fill_or_add_to_diagonal(array_xp, value=1, xp=xp, add_value=False, wrap=wrap)
591+
588592
assert_array_equal(_convert_to_numpy(array_xp, xp=xp), array_np)
589593

590594

0 commit comments

Comments
 (0)