@@ -2287,9 +2287,6 @@ def empty_like(prototype: ArrayLike | DuckTypedArray,
2287
2287
return zeros_like (prototype , dtype = dtype , shape = shape , device = device )
2288
2288
2289
2289
2290
- def _maybe_device_put (arr : Array , device : xc .Device | Sharding | None ) -> Array :
2291
- return arr if device is None else jax .device_put (arr , device )
2292
-
2293
2290
def _normalize_to_sharding (device : xc .Device | Sharding | None ) -> Sharding | None :
2294
2291
if isinstance (device , xc .Device ):
2295
2292
return SingleDeviceSharding (device )
@@ -2308,7 +2305,8 @@ def full(shape: Any, fill_value: ArrayLike,
2308
2305
shape = canonicalize_shape (shape )
2309
2306
return lax .full (shape , fill_value , dtype , sharding = _normalize_to_sharding (device ))
2310
2307
else :
2311
- return _maybe_device_put (broadcast_to (asarray (fill_value , dtype = dtype ), shape ), device )
2308
+ return jax .device_put (
2309
+ broadcast_to (asarray (fill_value , dtype = dtype ), shape ), device )
2312
2310
2313
2311
2314
2312
@util .implements (np .full_like )
@@ -2328,7 +2326,8 @@ def full_like(a: ArrayLike | DuckTypedArray,
2328
2326
else :
2329
2327
shape = np .shape (a ) if shape is None else shape # type: ignore[arg-type]
2330
2328
dtype = result_type (a ) if dtype is None else dtype # type: ignore[arg-type]
2331
- return _maybe_device_put (broadcast_to (asarray (fill_value , dtype = dtype ), shape ), device )
2329
+ return jax .device_put (
2330
+ broadcast_to (asarray (fill_value , dtype = dtype ), shape ), device )
2332
2331
2333
2332
2334
2333
@util .implements (np .zeros )
0 commit comments