Skip to content

Commit 763c6ff

Browse files
bythew3ijax authors
authored andcommitted
[Pallas] Fix typo in semaphore_wait error messages.
PiperOrigin-RevId: 623321130
1 parent d967c33 commit 763c6ff

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

jax/_src/pallas/mosaic/primitives.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -320,18 +320,18 @@ def semaphore_wait(sem_or_view, dec: int | jax.Array = 1):
320320
def _semaphore_wait_abstract_eval(*avals, args_tree):
321321
sem_aval, sem_indexers_avals, value_aval = tree_util.tree_unflatten(args_tree, avals)
322322
if not isinstance(sem_aval, state.AbstractRef):
323-
raise ValueError(f"Cannot signal on a non-semaphore Ref: {sem_aval}")
323+
raise ValueError(f"Cannot wait on a non-semaphore Ref: {sem_aval}")
324324
sem_shape = sem_aval.shape
325325
if sem_indexers_avals:
326326
sem_shape = sem_indexers_avals[-1].get_indexer_shape()
327327
if sem_shape:
328-
raise ValueError(f"Cannot signal on a non-()-shaped semaphore: {sem_shape}")
328+
raise ValueError(f"Cannot wait on a non-()-shaped semaphore: {sem_shape}")
329329
sem_dtype = sem_aval.dtype
330330
if not (jnp.issubdtype(sem_dtype, tpu_core.semaphore) or jnp.issubdtype(
331331
sem_dtype, tpu_core.barrier_semaphore)):
332-
raise ValueError(f"Must signal a REGULAR or BARRIER semaphore: {sem_dtype}")
332+
raise ValueError(f"Must wait a REGULAR or BARRIER semaphore: {sem_dtype}")
333333
if value_aval.dtype != jnp.dtype("int32"):
334-
raise ValueError("Must signal an int32 value.")
334+
raise ValueError("Must wait an int32 value.")
335335
return []
336336

337337
def _semaphore_wait_pp_eqn(eqn: jax_core.JaxprEqn,

0 commit comments

Comments
 (0)