Skip to content

Commit 5cbbace

Browse files
committed
Review response
1 parent aa4eb5c commit 5cbbace

File tree

2 files changed

+15
-12
lines changed

2 files changed

+15
-12
lines changed

src/array_api_extra/_delegation.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -126,13 +126,13 @@ def one_hot(
126126
"""
127127
One-hot encode the given indices.
128128
129-
Each index in the input ``x`` is encoded as a vector of zeros of length
130-
``num_classes`` with the element at the given index set to one.
129+
Each index in the input `x` is encoded as a vector of zeros of length `num_classes`
130+
with the element at the given index set to one.
131131
132132
Parameters
133133
----------
134134
x : array
135-
An array with integral dtype having shape ``batch_dims``.
135+
An array with integral dtype and concrete size (``x.size`` cannot be `None`).
136136
num_classes : int
137137
Number of classes in the one-hot dimension.
138138
dtype : DType, optional
@@ -147,17 +147,20 @@ def one_hot(
147147
-------
148148
array
149149
An array having the same shape as `x` except for a new axis at the position
150-
given by `axis` having size `num_classes`.
150+
given by `axis` having size `num_classes`. If `axis` is unspecified, it
151+
defaults to -1, which appends a new axis.
151152
152153
If ``x < 0`` or ``x >= num_classes``, then the result is undefined, may raise
153154
an exception, or may even cause a bad state. `x` is not checked.
154155
155156
Examples
156157
--------
157-
>>> xp.one_hot(jnp.array([1, 2, 0]), 3)
158+
>>> import array_api_extra as xpx
159+
>>> import array-api-strict as xp
160+
>>> xpx.one_hot(xp.asarray([1, 2, 0]), 3)
158161
Array([[0., 1., 0.],
159-
[0., 0., 1.],
160-
[1., 0., 0.]], dtype=float64)
162+
[0., 0., 1.],
163+
[1., 0., 0.]], dtype=array_api_strict.float64)
161164
"""
162165
# Validate inputs.
163166
if xp is None:

src/array_api_extra/_lib/_funcs.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,7 @@
88

99
from ._at import at
1010
from ._utils import _compat, _helpers
11-
from ._utils._compat import (
12-
array_namespace,
13-
is_dask_namespace,
14-
is_jax_array,
15-
)
11+
from ._utils._compat import array_namespace, is_dask_namespace, is_jax_array
1612
from ._utils._helpers import (
1713
asarrays,
1814
capabilities,
@@ -392,6 +388,10 @@ def one_hot(
392388
"""See docstring in `array_api_extra._delegation.py`."""
393389
x_size = x.size
394390
if x_size is None: # pragma: no cover
391+
# This cannot be tested because there is no way to create an array with abstract
392+
# size today. However, it must be blocked for the sake of type-checking and
393+
# future-proofing since x.size is allowed to None according to the
394+
# specification.
395395
msg = "x must have a concrete size."
396396
raise TypeError(msg)
397397
out = xp.zeros((x.size, num_classes), dtype=dtype)

0 commit comments

Comments
 (0)