Skip to content

Commit 1f9a2dd

Browse files
committed
ufunc: fix implements wrapper for at
1 parent a5b8ce1 commit 1f9a2dd

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

jax/_src/numpy/ufunc_api.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -258,7 +258,7 @@ def scan_fun(carry, _):
258258
_, result = jax.lax.scan(scan_fun, (0, arr[0].astype(dtype)), None, length=arr.shape[0])
259259
return _moveaxis(result, 0, axis)
260260

261-
@implements(np.ufunc.accumulate, module="numpy.ufunc")
261+
@implements(np.ufunc.at, module="numpy.ufunc")
262262
@partial(jax.jit, static_argnums=[0], static_argnames=['inplace'])
263263
def at(self, a: ArrayLike, indices: Any, b: ArrayLike | None = None, /, *,
264264
inplace: bool = True) -> Array:

0 commit comments

Comments
 (0)