Skip to content

Commit d7e5dde

Browse files
yashk2810jax authors
authored andcommitted
Remove _maybe_device_put because jax.device_put accepts None on the device parameter
PiperOrigin-RevId: 618223250
1 parent 5f467b9 commit d7e5dde

File tree

1 file changed

+4
-5
lines changed

1 file changed

+4
-5
lines changed

jax/_src/numpy/lax_numpy.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2287,9 +2287,6 @@ def empty_like(prototype: ArrayLike | DuckTypedArray,
22872287
return zeros_like(prototype, dtype=dtype, shape=shape, device=device)
22882288

22892289

2290-
def _maybe_device_put(arr: Array, device: xc.Device | Sharding | None) -> Array:
2291-
return arr if device is None else jax.device_put(arr, device)
2292-
22932290
def _normalize_to_sharding(device: xc.Device | Sharding | None) -> Sharding | None:
22942291
if isinstance(device, xc.Device):
22952292
return SingleDeviceSharding(device)
@@ -2308,7 +2305,8 @@ def full(shape: Any, fill_value: ArrayLike,
23082305
shape = canonicalize_shape(shape)
23092306
return lax.full(shape, fill_value, dtype, sharding=_normalize_to_sharding(device))
23102307
else:
2311-
return _maybe_device_put(broadcast_to(asarray(fill_value, dtype=dtype), shape), device)
2308+
return jax.device_put(
2309+
broadcast_to(asarray(fill_value, dtype=dtype), shape), device)
23122310

23132311

23142312
@util.implements(np.full_like)
@@ -2328,7 +2326,8 @@ def full_like(a: ArrayLike | DuckTypedArray,
23282326
else:
23292327
shape = np.shape(a) if shape is None else shape # type: ignore[arg-type]
23302328
dtype = result_type(a) if dtype is None else dtype # type: ignore[arg-type]
2331-
return _maybe_device_put(broadcast_to(asarray(fill_value, dtype=dtype), shape), device)
2329+
return jax.device_put(
2330+
broadcast_to(asarray(fill_value, dtype=dtype), shape), device)
23322331

23332332

23342333
@util.implements(np.zeros)

0 commit comments

Comments
 (0)