Skip to content

Commit e282bf5

Browse files
author
jax authors
committed
Merge pull request #20536 from jakevdp:broadcast-to
PiperOrigin-RevId: 621287464
2 parents 3c80812 + 6de6983 commit e282bf5

File tree

2 files changed

+9
-0
lines changed

2 files changed

+9
-0
lines changed

jax/_src/numpy/util.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -418,6 +418,8 @@ def _broadcast_to(arr: ArrayLike, shape: DimSize | Shape) -> Array:
418418
arr_shape = np.shape(arr)
419419
if core.definitely_equal_shape(arr_shape, shape):
420420
return arr
421+
elif len(shape) < len(arr_shape):
422+
raise ValueError(f"Cannot broadcast to shape with fewer dimensions: {arr_shape=} {shape=}")
421423
else:
422424
nlead = len(shape) - len(arr_shape)
423425
shape_tail = shape[nlead:]

tests/lax_numpy_test.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5181,6 +5181,13 @@ def testBroadcastTo(self, from_shape, to_shape):
51815181
self._CheckAgainstNumpy(np_op, jnp_op, args_maker)
51825182
self._CompileAndCheck(jnp_op, args_maker)
51835183

5184+
def testBroadcastToInvalidShape(self):
5185+
# Regression test for https://github.com/google/jax/issues/20533
5186+
x = jnp.zeros((3, 4, 5))
5187+
with self.assertRaisesRegex(
5188+
ValueError, "Cannot broadcast to shape with fewer dimensions"):
5189+
jnp.broadcast_to(x, (4, 5))
5190+
51845191
@jtu.sample_product(
51855192
[dict(shapes=shapes, broadcasted_shape=broadcasted_shape)
51865193
for shapes, broadcasted_shape in [

0 commit comments

Comments
 (0)