Skip to content

Commit 4f6a2ad

Browse files
committed
Merge branch 'main' into apply
2 parents 8bb4ba0 + ec890f1 commit 4f6a2ad

File tree

6 files changed

+62
-1
lines changed

6 files changed

+62
-1
lines changed

docs/api-reference.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
create_diagonal
1414
expand_dims
1515
kron
16+
nunique
1617
setdiff1d
1718
sinc
1819
```

pixi.lock

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -300,6 +300,7 @@ messages_control.disable = [
300300
"line-too-long",
301301
"missing-module-docstring",
302302
"missing-function-docstring",
303+
"too-many-lines",
303304
"wrong-import-position",
304305
]
305306

src/array_api_extra/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
create_diagonal,
99
expand_dims,
1010
kron,
11+
nunique,
1112
pad,
1213
setdiff1d,
1314
sinc,
@@ -25,6 +26,7 @@
2526
"create_diagonal",
2627
"expand_dims",
2728
"kron",
29+
"nunique",
2830
"pad",
2931
"setdiff1d",
3032
"sinc",

src/array_api_extra/_funcs.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
# https://github.com/scikit-learn/scikit-learn/pull/27910#issuecomment-2568023972
44
from __future__ import annotations
55

6+
import math
67
import operator
78
import warnings
89
from collections.abc import Callable
@@ -25,6 +26,7 @@
2526
"create_diagonal",
2627
"expand_dims",
2728
"kron",
29+
"nunique",
2830
"pad",
2931
"setdiff1d",
3032
"sinc",
@@ -638,6 +640,42 @@ def pad(
638640
return at(padded, tuple(slices)).set(x)
639641

640642

643+
def nunique(x: Array, /, *, xp: ModuleType | None = None) -> Array:
644+
"""
645+
Count the number of unique elements in an array.
646+
647+
Compatible with JAX and Dask, whose laziness would be otherwise
648+
problematic.
649+
650+
Parameters
651+
----------
652+
x : Array
653+
Input array.
654+
xp : array_namespace, optional
655+
The standard-compatible namespace for `x`. Default: infer.
656+
657+
Returns
658+
-------
659+
array: 0-dimensional integer array
660+
The number of unique elements in `x`. It can be lazy.
661+
"""
662+
if xp is None:
663+
xp = array_namespace(x)
664+
665+
if is_jax_array(x):
666+
# size= is JAX-specific
667+
# https://github.com/data-apis/array-api/issues/883
668+
_, counts = xp.unique_counts(x, size=_compat.size(x))
669+
return xp.astype(counts, xp.bool).sum()
670+
671+
_, counts = xp.unique_counts(x)
672+
n = _compat.size(counts)
673+
# FIXME https://github.com/data-apis/array-api-compat/pull/231
674+
if n is None or math.isnan(n): # e.g. Dask, ndonnx
675+
return xp.astype(counts, xp.bool).sum()
676+
return xp.asarray(n, device=_compat.device(x))
677+
678+
641679
class _AtOp(Enum):
642680
"""Operations for use in `xpx.at`."""
643681

tests/test_funcs.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
create_diagonal,
1212
expand_dims,
1313
kron,
14+
nunique,
1415
pad,
1516
setdiff1d,
1617
sinc,
@@ -448,3 +449,21 @@ def test_list_of_tuples_width(self, xp: ModuleType):
448449

449450
padded = pad(a, [(1, 0), (0, 0)])
450451
assert padded.shape == (4, 4)
452+
453+
454+
class TestNUnique:
455+
def test_simple(self, xp: ModuleType):
456+
a = xp.asarray([[1, 1], [0, 2], [2, 2]])
457+
xp_assert_equal(nunique(a), xp.asarray(3))
458+
459+
def test_empty(self, xp: ModuleType):
460+
a = xp.asarray([])
461+
xp_assert_equal(nunique(a), xp.asarray(0))
462+
463+
def test_device(self, xp: ModuleType, device: Device):
464+
a = xp.asarray(0.0, device=device)
465+
assert get_device(nunique(a)) == device
466+
467+
def test_xp(self, xp: ModuleType):
468+
a = xp.asarray([[1, 1], [0, 2], [2, 2]])
469+
xp_assert_equal(nunique(a, xp=xp), xp.asarray(3))

0 commit comments

Comments
 (0)