@@ -1251,7 +1251,6 @@ def full(shape: Shape, fill_value: ArrayLike, dtype: DTypeLike | None = None, *,
1251
1251
# if needed?
1252
1252
if (sharding is not None and not isinstance (sharding , PmapSharding ) and
1253
1253
isinstance (fill_value , array .ArrayImpl )):
1254
-
1255
1254
broadcast_shape = sharding .shard_shape (shape )
1256
1255
shard = broadcast (fill_value , broadcast_shape )
1257
1256
return array .make_array_from_callback (shape , sharding , lambda _ : shard )
@@ -1415,17 +1414,15 @@ def full_like(x: ArrayLike | DuckTypedArray,
1415
1414
if dtypes .issubdtype (dtype , dtypes .extended ):
1416
1415
return dtype ._rules .full (fill_shape , fill_value , dtype ) # type: ignore[union-attr]
1417
1416
1417
+ # If `x` has a sharding but no `_committed` attribute
1418
+ # (in case of ShapeDtypeStruct), default it to True.
1418
1419
use_x_sharding = (
1419
1420
sharding is None and
1420
- isinstance (x , array .ArrayImpl ) and
1421
- not weak_type and x ._committed and
1422
- # NB: consider reusng x.sharding for mismatched shapes
1423
- # if x is replicated or single device.
1424
- fill_shape == x .shape )
1421
+ hasattr (x , 'sharding' ) and getattr (x , '_committed' , True ) and
1422
+ not weak_type and fill_shape == x .shape ) # type: ignore
1425
1423
if use_x_sharding :
1426
- assert isinstance (x , array .ArrayImpl ) # makes pytype happy.
1427
1424
# TODO(yashkatariya): Use shard_alike in tracing_mode once it is supported.
1428
- sharding = x .sharding
1425
+ sharding = x .sharding # type: ignore
1429
1426
val = full (fill_shape , _convert_element_type (fill_value , dtype , weak_type ),
1430
1427
sharding = sharding )
1431
1428
return val
0 commit comments