File tree Expand file tree Collapse file tree 1 file changed +4
-4
lines changed Expand file tree Collapse file tree 1 file changed +4
-4
lines changed Original file line number Diff line number Diff line change @@ -320,18 +320,18 @@ def semaphore_wait(sem_or_view, dec: int | jax.Array = 1):
320
320
def _semaphore_wait_abstract_eval (* avals , args_tree ):
321
321
sem_aval , sem_indexers_avals , value_aval = tree_util .tree_unflatten (args_tree , avals )
322
322
if not isinstance (sem_aval , state .AbstractRef ):
323
- raise ValueError (f"Cannot signal on a non-semaphore Ref: { sem_aval } " )
323
+ raise ValueError (f"Cannot wait on a non-semaphore Ref: { sem_aval } " )
324
324
sem_shape = sem_aval .shape
325
325
if sem_indexers_avals :
326
326
sem_shape = sem_indexers_avals [- 1 ].get_indexer_shape ()
327
327
if sem_shape :
328
- raise ValueError (f"Cannot signal on a non-()-shaped semaphore: { sem_shape } " )
328
+ raise ValueError (f"Cannot wait on a non-()-shaped semaphore: { sem_shape } " )
329
329
sem_dtype = sem_aval .dtype
330
330
if not (jnp .issubdtype (sem_dtype , tpu_core .semaphore ) or jnp .issubdtype (
331
331
sem_dtype , tpu_core .barrier_semaphore )):
332
- raise ValueError (f"Must signal a REGULAR or BARRIER semaphore: { sem_dtype } " )
332
+ raise ValueError (f"Must wait a REGULAR or BARRIER semaphore: { sem_dtype } " )
333
333
if value_aval .dtype != jnp .dtype ("int32" ):
334
- raise ValueError ("Must signal an int32 value." )
334
+ raise ValueError ("Must wait an int32 value." )
335
335
return []
336
336
337
337
def _semaphore_wait_pp_eqn (eqn : jax_core .JaxprEqn ,
You can’t perform that action at this time.
0 commit comments