@@ -700,9 +700,9 @@ def run_kernel(x_ref, y_ref):
  def kernel_body():
    # Once we enter the pl.core_map scope, we are in the body of the kernel.
    block_slice = pl.ds(lax.axis_index("x") * 128, 128)
-   o_ref[block_slice] = x_ref[block_slice] + 1
+   y_ref[block_slice] = x_ref[block_slice] + 1

- x = jnp.arange(128, jnp.float32)
+ x = jnp.arange(256, jnp.float32)
y_init = jnp.zeros_like(x)
y = run_kernel(x, y_init)
np.testing.assert_array_equal(y, x + 1)
@@ -721,12 +721,12 @@ mesh = plgpu.Mesh(grid=(2,), grid_names=("x",))
    out_shape=jax.ShapeDtypeStruct((256,), jnp.float32),
    mesh=mesh,
)
- def increment_kernel_core_map(x_ref, y_ref):
+ def run_kernel(x_ref, y_ref):
  # x_ref and y_ref are in GMEM!
  block_slice = pl.ds(lax.axis_index("x") * 128, 128)
- o_ref[block_slice] = x_ref[block_slice] + 1
+ y_ref[block_slice] = x_ref[block_slice] + 1

- x = jnp.arange(128, jnp.float32)
+ x = jnp.arange(256, jnp.float32)
y = run_kernel(x)  # No need to preallocate outputs as in pl.core_map.
np.testing.assert_array_equal(y, x + 1)
```
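For reference, here is a self-contained sketch of the corrected `plgpu.kernel` example from this hunk, assembled so it runs on its own. The import block, the `@functools.partial(plgpu.kernel, ...)` wrapper (which sits just above the lines visible in the hunk), and the `dtype=` keyword in `jnp.arange` are assumptions added here; the mesh, kernel body, and driver lines are taken from the diff.

```python
# Sketch only: reassembles the corrected example from this hunk.
# Assumed (not shown in the hunk): the imports, the functools.partial wrapper
# around plgpu.kernel, and dtype= in jnp.arange.
import functools

import jax
import jax.numpy as jnp
import numpy as np
from jax import lax
from jax.experimental import pallas as pl
from jax.experimental.pallas import mosaic_gpu as plgpu

# Two program instances along a grid axis named "x", as in the hunk header.
mesh = plgpu.Mesh(grid=(2,), grid_names=("x",))

@functools.partial(
    plgpu.kernel,
    out_shape=jax.ShapeDtypeStruct((256,), jnp.float32),
    mesh=mesh,
)
def run_kernel(x_ref, y_ref):
  # x_ref and y_ref are in GMEM!
  block_slice = pl.ds(lax.axis_index("x") * 128, 128)
  y_ref[block_slice] = x_ref[block_slice] + 1

x = jnp.arange(256, dtype=jnp.float32)
y = run_kernel(x)  # No need to preallocate outputs as in pl.core_map.
np.testing.assert_array_equal(y, x + 1)
```

Compared with the `pl.core_map` version in the previous hunk, `plgpu.kernel` allocates the output from `out_shape`, so the caller passes only `x` instead of preallocating `y_init`.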
@@ -760,7 +760,7 @@ def run_kernel(x_ref, y_ref, barrier_ref):

  @pl.when(thread_id == 0)
  def producer_thread():
-   smem_val = x_ref[...] + 1
+   x_ref[...] = x_ref[...] + 1
    plgpu.barrier_arrive(barrier_ref)  # Signal the consumer thread

  @pl.when(thread_id == 1)
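This hunk only touches the producer side of the two-thread example. A rough sketch of the surrounding producer/consumer pattern is below; the `"wg"` thread-axis name, the consumer body, and the use of `plgpu.barrier_wait` on the consumer side are assumptions for illustration, not lines from this diff, and the launch configuration that allocates `barrier_ref` is omitted.

```python
# Hedged sketch of the producer/consumer pattern around this hunk.
# Assumed (not from the diff): the "wg" thread-axis name, the consumer body,
# and barrier_wait on the consumer side; the producer lines mirror the hunk.
from jax import lax
from jax.experimental import pallas as pl
from jax.experimental.pallas import mosaic_gpu as plgpu


def run_kernel(x_ref, y_ref, barrier_ref):
  thread_id = lax.axis_index("wg")  # hypothetical thread-axis name

  @pl.when(thread_id == 0)
  def producer_thread():
    x_ref[...] = x_ref[...] + 1        # update the shared value in place
    plgpu.barrier_arrive(barrier_ref)  # Signal the consumer thread

  @pl.when(thread_id == 1)
  def consumer_thread():
    plgpu.barrier_wait(barrier_ref)  # Wait for the producer's write to land
    y_ref[...] = x_ref[...]          # copy the produced value out (assumed)
```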