Skip to content

Commit 3c6a60f

Browse files
committed
[Mosaic GPU] Fix some typos in docs
1 parent fd8a79f commit 3c6a60f

File tree

1 file changed

+6
-6
lines changed

1 file changed

+6
-6
lines changed

docs/pallas/gpu/reference.md

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -700,9 +700,9 @@ def run_kernel(x_ref, y_ref):
700700
def kernel_body():
701701
# Once we enter the pl.core_map scope, we are in the body of the kernel.
702702
block_slice = pl.ds(lax.axis_index("x") * 128, 128)
703-
o_ref[block_slice] = x_ref[block_slice] + 1
703+
y_ref[block_slice] = x_ref[block_slice] + 1
704704

705-
x = jnp.arange(128, jnp.float32)
705+
x = jnp.arange(256, jnp.float32)
706706
y_init = jnp.zeros_like(x)
707707
y = run_kernel(x, y_init)
708708
np.testing.assert_array_equal(y, x + 1)
@@ -721,12 +721,12 @@ mesh = plgpu.Mesh(grid=(2,), grid_names=("x",))
721721
out_shape=jax.ShapeDtypeStruct((256,), jnp.float32),
722722
mesh=mesh
723723
)
724-
def increment_kernel_core_map(x_ref, y_ref):
724+
def run_kernel(x_ref, y_ref):
725725
# x_ref and y_ref are in GMEM!
726726
block_slice = pl.ds(lax.axis_index("x") * 128, 128)
727-
o_ref[block_slice] = x_ref[block_slice] + 1
727+
y_ref[block_slice] = x_ref[block_slice] + 1
728728

729-
x = jnp.arange(128, jnp.float32)
729+
x = jnp.arange(256, jnp.float32)
730730
y = run_kernel(x) # No need to preallocate outputs as in pl.core_map.
731731
np.testing.assert_array_equal(y, x + 1)
732732
```
@@ -760,7 +760,7 @@ def run_kernel(x_ref, y_ref, barrier_ref):
760760

761761
@pl.when(thread_id == 0)
762762
def producer_thread():
763-
smem_val = x_ref[...] + 1
763+
x_ref[...] = x_ref[...] + 1
764764
plgpu.barrier_arrive(barrier_ref) # Signal the consumer thread
765765

766766
@pl.when(thread_id == 1)

0 commit comments

Comments
 (0)