@@ -126,13 +126,13 @@ def one_hot(
126
126
"""
127
127
One-hot encode the given indices.
128
128
129
- Each index in the input ``x`` is encoded as a vector of zeros of length
130
- ``num_classes`` with the element at the given index set to one.
129
+ Each index in the input `x` is encoded as a vector of zeros of length `num_classes`
130
+ with the element at the given index set to one.
131
131
132
132
Parameters
133
133
----------
134
134
x : array
135
- An array with integral dtype having shape ``batch_dims`` .
135
+ An array with integral dtype and concrete size (``x.size`` cannot be `None`) .
136
136
num_classes : int
137
137
Number of classes in the one-hot dimension.
138
138
dtype : DType, optional
@@ -147,17 +147,20 @@ def one_hot(
147
147
-------
148
148
array
149
149
An array having the same shape as `x` except for a new axis at the position
150
- given by `axis` having size `num_classes`.
150
+ given by `axis` having size `num_classes`. If `axis` is unspecified, it
151
+ defaults to -1, which appends a new axis.
151
152
152
153
If ``x < 0`` or ``x >= num_classes``, then the result is undefined, may raise
153
154
an exception, or may even cause a bad state. `x` is not checked.
154
155
155
156
Examples
156
157
--------
157
- >>> xp.one_hot(jnp.array([1, 2, 0]), 3)
158
+ >>> import array_api_extra as xpx
159
+ >>> import array-api-strict as xp
160
+ >>> xpx.one_hot(xp.asarray([1, 2, 0]), 3)
158
161
Array([[0., 1., 0.],
159
- [0., 0., 1.],
160
- [1., 0., 0.]], dtype=float64)
162
+ [0., 0., 1.],
163
+ [1., 0., 0.]], dtype=array_api_strict. float64)
161
164
"""
162
165
# Validate inputs.
163
166
if xp is None :
0 commit comments