@@ -689,7 +689,8 @@ helper. We recommend reviewing the [software pipelining guide](./pipelining.md).
 ```python
 @pl.run_state
-def run_kernel(x_ref, y_ref):
+def run_kernel(refs):
+  x_ref, y_ref = refs
   # Here, we're not in the kernel yet! pl.run_state simply changes the JAX
   # immutable arrays into mutable GMEM (not SMEM!) references.
@@ -700,11 +701,11 @@ def run_kernel(x_ref, y_ref):
   def kernel_body():
     # Once we enter the pl.core_map scope, we are in the body of the kernel.
     block_slice = pl.ds(lax.axis_index("x") * 128, 128)
-    o_ref[block_slice] = x_ref[block_slice] + 1
+    y_ref[block_slice] = x_ref[block_slice] + 1
 
-x = jnp.arange(128, dtype=jnp.float32)
+x = jnp.arange(256, dtype=jnp.float32)
 y_init = jnp.zeros_like(x)
-y = run_kernel(x, y_init)
+_, y = run_kernel(x, y_init)
 np.testing.assert_array_equal(y, x + 1)
 ```
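For reference, here is a sketch of the first example as it reads after this change. The imports and the `@pl.core_map(mesh)` decorator fall in lines the diff elides, so they are assumptions here, as is the `mesh` definition (taken from the next hunk's context):

```python
# A sketch of the assembled example. Imports and the pl.core_map decorator
# are not shown by the diff and are assumed here.
import jax
from jax import lax
import jax.numpy as jnp
import numpy as np
from jax.experimental import pallas as pl
from jax.experimental.pallas import mosaic_gpu as plgpu

# Assumed from the later hunk's context: a 1D grid of 2 blocks named "x".
mesh = plgpu.Mesh(grid=(2,), grid_names=("x",))

@pl.run_state
def run_kernel(refs):
  x_ref, y_ref = refs
  # Not in the kernel yet: pl.run_state only turns the immutable JAX arrays
  # into mutable GMEM (not SMEM!) references.

  @pl.core_map(mesh)  # Assumed: this decorator sits in the elided lines.
  def kernel_body():
    # Inside pl.core_map we are in the kernel body; each block handles the
    # 128-element slice selected by its "x" coordinate.
    block_slice = pl.ds(lax.axis_index("x") * 128, 128)
    y_ref[block_slice] = x_ref[block_slice] + 1

x = jnp.arange(256, dtype=jnp.float32)
y_init = jnp.zeros_like(x)
_, y = run_kernel(x, y_init)  # Returns the final value of every ref.
np.testing.assert_array_equal(y, x + 1)
```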
@@ -721,12 +722,12 @@ mesh = plgpu.Mesh(grid=(2,), grid_names=("x",))
     out_shape=jax.ShapeDtypeStruct((256,), jnp.float32),
     mesh=mesh
 )
-def increment_kernel_core_map(x_ref, y_ref):
+def run_kernel(x_ref, y_ref):
   # x_ref and y_ref are in GMEM!
   block_slice = pl.ds(lax.axis_index("x") * 128, 128)
-  o_ref[block_slice] = x_ref[block_slice] + 1
+  y_ref[block_slice] = x_ref[block_slice] + 1
 
-x = jnp.arange(128, dtype=jnp.float32)
+x = jnp.arange(256, dtype=jnp.float32)
 y = run_kernel(x)  # No need to preallocate outputs as in pl.core_map.
 np.testing.assert_array_equal(y, x + 1)
 ```
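The `plgpu.kernel` variant, assembled into a single sketch. The opening of the decorator is elided by this hunk's context, so the `@functools.partial(plgpu.kernel, ...)` form is assumed from the pattern in the next hunk:

```python
# A sketch of the plgpu.kernel variant after this change. Imports carry over
# from the previous sketch; the decorator's opening line is assumed.
import functools

mesh = plgpu.Mesh(grid=(2,), grid_names=("x",))

@functools.partial(
    plgpu.kernel,  # Assumed: this opening sits in the elided lines.
    out_shape=jax.ShapeDtypeStruct((256,), jnp.float32),
    mesh=mesh,
)
def run_kernel(x_ref, y_ref):
  # x_ref and y_ref are in GMEM!
  block_slice = pl.ds(lax.axis_index("x") * 128, 128)
  y_ref[block_slice] = x_ref[block_slice] + 1

x = jnp.arange(256, dtype=jnp.float32)
y = run_kernel(x)  # No need to preallocate outputs as in pl.core_map.
np.testing.assert_array_equal(y, x + 1)
```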
@@ -752,23 +753,25 @@ synchronizing through a barrier and even exchanging data through SMEM.
 ```python
 mesh = plgpu.Mesh(num_threads=2, thread_name="pallas_thread")
+x = jnp.arange(128, dtype=jnp.float32)
+
 @functools.partial(
-    plgpu.kernel, out_shape=x, mesh=mesh, scratch_shapes=[plgpu.Barrier()]
+    plgpu.kernel, out_shape=x, mesh=mesh,
+    scratch_shapes=[plgpu.SMEM(x.shape, x.dtype), plgpu.Barrier()]
 )
-def run_kernel(x_ref, y_ref, barrier_ref):
+def run_kernel(x_ref, y_ref, smem_ref, barrier_ref):
   thread_id = jax.lax.axis_index("pallas_thread")
 
   @pl.when(thread_id == 0)
   def producer_thread():
-    smem_val = x_ref[...] + 1
+    smem_ref[...] = x_ref[...] + 1
     plgpu.barrier_arrive(barrier_ref)  # Signal the consumer thread
 
   @pl.when(thread_id == 1)
   def consumer_thread():
     plgpu.barrier_wait(barrier_ref)  # Wait for the producer thread
-    out_ref[...] = x_ref[...] + 1
+    y_ref[...] = smem_ref[...] + 1
 
-x = jnp.arange(128, dtype=jnp.float32)
 y = run_kernel(x)  # There's no need to preallocate the output anymore.
 np.testing.assert_array_equal(y, x + 2)
 ```
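And the two-thread example assembled end to end, assuming the same imports as in the sketches above:

```python
# A runnable assembly of the two-thread example after this change, with the
# consumer writing through y_ref, the output parameter.
mesh = plgpu.Mesh(num_threads=2, thread_name="pallas_thread")
x = jnp.arange(128, dtype=jnp.float32)

@functools.partial(
    plgpu.kernel, out_shape=x, mesh=mesh,
    scratch_shapes=[plgpu.SMEM(x.shape, x.dtype), plgpu.Barrier()],
)
def run_kernel(x_ref, y_ref, smem_ref, barrier_ref):
  thread_id = jax.lax.axis_index("pallas_thread")

  @pl.when(thread_id == 0)
  def producer_thread():
    smem_ref[...] = x_ref[...] + 1     # Stage x + 1 into shared memory.
    plgpu.barrier_arrive(barrier_ref)  # Signal the consumer thread.

  @pl.when(thread_id == 1)
  def consumer_thread():
    plgpu.barrier_wait(barrier_ref)    # Wait for the producer thread.
    y_ref[...] = smem_ref[...] + 1     # Read the staged value, add 1 more.

y = run_kernel(x)
np.testing.assert_array_equal(y, x + 2)  # Two increments in total.
```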