Skip to content

Commit e3018db

Browse files
bythew3ijax authors
authored andcommitted
[Pallas][Mosaic] Expose semaphore read.
PiperOrigin-RevId: 623593440
1 parent ff12b2a commit e3018db

File tree

6 files changed

+85
-22
lines changed

6 files changed

+85
-22
lines changed

jax/_src/pallas/mosaic/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
from jax._src.pallas.mosaic.primitives import repeat
4040
from jax._src.pallas.mosaic.primitives import roll
4141
from jax._src.pallas.mosaic.primitives import run_scoped
42+
from jax._src.pallas.mosaic.primitives import semaphore_read
4243
from jax._src.pallas.mosaic.primitives import semaphore_signal
4344
from jax._src.pallas.mosaic.primitives import semaphore_wait
4445
from jax._src.pallas.mosaic.primitives import trace

jax/_src/pallas/mosaic/lowering.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2095,6 +2095,20 @@ def _linearize_mesh_indices(*indices):
20952095
return device_id
20962096
raise NotImplementedError(f"Unsupported device id type: {device_id_type}")
20972097

2098+
2099+
def _semaphore_read_lowering_rule(
2100+
ctx: LoweringRuleContext,
2101+
*args,
2102+
args_tree,
2103+
):
2104+
sem_aval, _ = tree_util.tree_unflatten(args_tree, ctx.avals_in)
2105+
sem, indexers = tree_util.tree_unflatten(args_tree, args)
2106+
sem, _ = _index_ref(sem, sem_aval, sem_aval.shape, indexers)
2107+
return tpu.SemaphoreReadOp(sem).result
2108+
2109+
2110+
lowering_rules[tpu_primitives.semaphore_read_p] = _semaphore_read_lowering_rule
2111+
20982112
def _semaphore_signal_lowering_rule(
20992113
ctx: LoweringRuleContext,
21002114
*args,

jax/_src/pallas/mosaic/primitives.py

Lines changed: 38 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -222,6 +222,42 @@ class DeviceIdType(enum.Enum):
222222
LOGICAL = "logical"
223223

224224

225+
def check_sem_avals(sem_aval, sem_indexers_avals, name):
226+
if not isinstance(sem_aval, state.AbstractRef):
227+
raise ValueError(f"Cannot {name} on a non-semaphore Ref: {sem_aval}")
228+
sem_shape = sem_aval.shape
229+
if sem_indexers_avals:
230+
sem_shape = sem_indexers_avals[-1].get_indexer_shape()
231+
if sem_shape:
232+
raise ValueError(f"Cannot {name} on a non-()-shaped semaphore: {sem_shape}")
233+
sem_dtype = sem_aval.dtype
234+
if not (
235+
jnp.issubdtype(sem_dtype, tpu_core.semaphore)
236+
or jnp.issubdtype(sem_dtype, tpu_core.barrier_semaphore)
237+
):
238+
raise ValueError(f"Must {name} a REGULAR or BARRIER semaphore: {sem_dtype}")
239+
240+
241+
semaphore_read_p = jax_core.Primitive("semaphore_read")
242+
semaphore_read_p.multiple_results = False
243+
244+
245+
def semaphore_read(sem_or_view):
246+
ref, indexers = _get_ref_and_indexers(sem_or_view)
247+
args = [ref, indexers]
248+
flat_args, args_tree = tree_util.tree_flatten(args)
249+
return semaphore_read_p.bind(*flat_args, args_tree=args_tree)
250+
251+
@semaphore_read_p.def_abstract_eval
252+
def _semaphore_read_abstract_eval(
253+
*avals,
254+
args_tree,
255+
):
256+
sem_aval, sem_indexers_avals = tree_util.tree_unflatten(args_tree, avals)
257+
check_sem_avals(sem_aval, sem_indexers_avals, "read")
258+
return jax_core.ShapedArray((), jnp.dtype("int32"))
259+
260+
225261
semaphore_signal_p = jax_core.Primitive('semaphore_signal')
226262
semaphore_signal_p.multiple_results = True
227263

@@ -254,17 +290,7 @@ def _semaphore_signal_abstract_eval(
254290
sem_aval, sem_indexers_avals, value_aval, device_id_avals = (
255291
tree_util.tree_unflatten(args_tree, avals)
256292
)
257-
if not isinstance(sem_aval, state.AbstractRef):
258-
raise ValueError(f"Cannot signal on a non-Ref: {sem_aval}")
259-
sem_shape = sem_aval.shape
260-
if sem_indexers_avals:
261-
sem_shape = sem_indexers_avals[-1].get_indexer_shape()
262-
if sem_shape:
263-
raise ValueError(f"Cannot signal on a non-()-shaped semaphore: {sem_shape}")
264-
sem_dtype = sem_aval.dtype
265-
if not (jnp.issubdtype(sem_dtype, tpu_core.semaphore) or jnp.issubdtype(
266-
sem_dtype, tpu_core.barrier_semaphore)):
267-
raise ValueError(f"Must signal a REGULAR or BARRIER semaphore: {sem_dtype}")
293+
check_sem_avals(sem_aval, sem_indexers_avals, "signal")
268294
if value_aval.dtype != jnp.dtype("int32"):
269295
raise ValueError("Must signal an int32 value.")
270296
if device_id_avals is not None:
@@ -319,17 +345,7 @@ def semaphore_wait(sem_or_view, dec: int | jax.Array = 1):
319345
@semaphore_wait_p.def_abstract_eval
320346
def _semaphore_wait_abstract_eval(*avals, args_tree):
321347
sem_aval, sem_indexers_avals, value_aval = tree_util.tree_unflatten(args_tree, avals)
322-
if not isinstance(sem_aval, state.AbstractRef):
323-
raise ValueError(f"Cannot wait on a non-semaphore Ref: {sem_aval}")
324-
sem_shape = sem_aval.shape
325-
if sem_indexers_avals:
326-
sem_shape = sem_indexers_avals[-1].get_indexer_shape()
327-
if sem_shape:
328-
raise ValueError(f"Cannot wait on a non-()-shaped semaphore: {sem_shape}")
329-
sem_dtype = sem_aval.dtype
330-
if not (jnp.issubdtype(sem_dtype, tpu_core.semaphore) or jnp.issubdtype(
331-
sem_dtype, tpu_core.barrier_semaphore)):
332-
raise ValueError(f"Must wait a REGULAR or BARRIER semaphore: {sem_dtype}")
348+
check_sem_avals(sem_aval, sem_indexers_avals, "wait")
333349
if value_aval.dtype != jnp.dtype("int32"):
334350
raise ValueError("Must wait an int32 value.")
335351
return []

jax/experimental/pallas/tpu.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
from jax._src.pallas.mosaic import roll
4141
from jax._src.pallas.mosaic import run_scoped
4242
from jax._src.pallas.mosaic import semaphore
43+
from jax._src.pallas.mosaic import semaphore_read
4344
from jax._src.pallas.mosaic import semaphore_signal
4445
from jax._src.pallas.mosaic import semaphore_wait
4546
from jax._src.pallas.mosaic import trace

jaxlib/mosaic/dialect/tpu/tpu.td

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -452,6 +452,12 @@ def TPU_DeviceIdOp : TPU_Op<"device_id", [Pure]> {
452452
let assemblyFormat = [{ attr-dict `:` type($result) }];
453453
}
454454

455+
def TPU_SemaphoreReadOp : TPU_Op<"sem_read"> {
456+
let arguments = (ins MemRefOf<[TPU_SemaphoreType]>:$semaphore);
457+
let results = (outs I32:$result);
458+
let assemblyFormat = [{ $semaphore attr-dict `:` type($semaphore) `->` type($result)}];
459+
}
460+
455461
def TPU_SemaphoreWaitOp : TPU_Op<"sem_wait"> {
456462
let arguments = (ins
457463
MemRefOf<[TPU_SemaphoreType]>:$semaphore,

tests/pallas/pallas_call_tpu_test.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -730,6 +730,31 @@ def body(sems):
730730
debug=True,
731731
)())
732732

733+
def test_can_read_semaphore(self):
734+
m, n = 2, 3
735+
736+
def kernel(y_ref):
737+
def body(sems):
738+
for r in range(m):
739+
for c in range(n):
740+
v = r * n + c
741+
pltpu.semaphore_signal(sems.at[r, c],v)
742+
y_ref[r, c] = pltpu.semaphore_read(sems.at[r, c])
743+
pltpu.semaphore_wait(sems.at[r, c], v)
744+
745+
pltpu.run_scoped(body, pltpu.SemaphoreType.REGULAR((m, n)))
746+
747+
y = jax.block_until_ready(
748+
pl.pallas_call(
749+
kernel,
750+
out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.SMEM),
751+
out_shape=jax.ShapeDtypeStruct((m, n), jnp.int32),
752+
)()
753+
)
754+
np.testing.assert_array_equal(
755+
y, jnp.arange(m * n).astype(jnp.int32).reshape((m, n))
756+
)
757+
733758
def test_hbm_hbm_dma(self):
734759
def kernel(x_hbm_ref, y_hbm_ref):
735760
def body(sem):

0 commit comments

Comments
 (0)