Skip to content

Commit b32721d

Browse files
Merge pull request #30447 from justinjfu:mgpu_docs_fixes
PiperOrigin-RevId: 786432145
2 parents 535829d + 98665a7 commit b32721d

File tree

1 file changed

+15
-12
lines changed

1 file changed

+15
-12
lines changed

docs/pallas/gpu/reference.md

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -689,7 +689,8 @@ helper. We recommend reviewing the [software pipelining guide](./pipelining.md).
689689

690690
```python
691691
@pl.run_state
692-
def run_kernel(x_ref, y_ref):
692+
def run_kernel(refs):
693+
x_ref, y_ref = refs
693694
# Here, we're not in the kernel yet! pl.run_state simply changes the JAX
694695
# immutable arrays into mutable GMEM (not SMEM!) references.
695696

@@ -700,11 +701,11 @@ def run_kernel(x_ref, y_ref):
700701
def kernel_body():
701702
# Once we enter the pl.core_map scope, we are in the body of the kernel.
702703
block_slice = pl.ds(lax.axis_index("x") * 128, 128)
703-
o_ref[block_slice] = x_ref[block_slice] + 1
704+
y_ref[block_slice] = x_ref[block_slice] + 1
704705

705-
x = jnp.arange(128, jnp.float32)
706+
x = jnp.arange(256, jnp.float32)
706707
y_init = jnp.zeros_like(x)
707-
y = run_kernel(x, y_init)
708+
_, y = run_kernel(x, y_init)
708709
np.testing.assert_array_equal(y, x + 1)
709710
```
710711

@@ -721,12 +722,12 @@ mesh = plgpu.Mesh(grid=(2,), grid_names=("x",))
721722
out_shape=jax.ShapeDtypeStruct((256,), jnp.float32),
722723
mesh=mesh
723724
)
724-
def increment_kernel_core_map(x_ref, y_ref):
725+
def run_kernel(x_ref, y_ref):
725726
# x_ref and y_ref are in GMEM!
726727
block_slice = pl.ds(lax.axis_index("x") * 128, 128)
727-
o_ref[block_slice] = x_ref[block_slice] + 1
728+
y_ref[block_slice] = x_ref[block_slice] + 1
728729

729-
x = jnp.arange(128, jnp.float32)
730+
x = jnp.arange(256, jnp.float32)
730731
y = run_kernel(x) # No need to preallocate outputs as in pl.core_map.
731732
np.testing.assert_array_equal(y, x + 1)
732733
```
@@ -752,23 +753,25 @@ synchronizing through a barrier and even exchanging data through SMEM.
752753

753754
```python
754755
mesh = plgpu.Mesh(num_threads=2, thread_name="pallas_thread")
756+
x = jnp.arange(128, jnp.float32)
757+
755758
@functools.partial(
756-
plgpu.kernel, out_shape=x, mesh=mesh, scratch_shapes=[plgpu.Barrier()]
759+
plgpu.kernel, out_shape=x, mesh=mesh,
760+
scratch_shapes=[plgpu.SMEM(x.shape, x.dtype), plgpu.Barrier()]
757761
)
758-
def run_kernel(x_ref, y_ref, barrier_ref):
762+
def run_kernel(x_ref, y_ref, smem_ref, barrier_ref):
759763
thread_id = jax.lax.axis_index("pallas_thread")
760764

761765
@pl.when(thread_id == 0)
762766
def producer_thread():
763-
smem_val = x_ref[...] + 1
767+
smem_ref[...] = x_ref[...] + 1
764768
plgpu.barrier_arrive(barrier_ref) # Signal the consumer thread
765769

766770
@pl.when(thread_id == 1)
767771
def consumer_thread():
768772
plgpu.barrier_wait(barrier_ref) # Wait for the producer thread
769-
out_ref[...] = x_ref[...] + 1
773+
out_ref[...] = smem_ref[...] + 1
770774

771-
x = jnp.arange(128, jnp.float32)
772775
y = run_kernel(x) # There's no need to preallocate the input anymore.
773776
np.testing.assert_array_equal(y, x + 2)
774777
```

0 commit comments

Comments
 (0)