Commit 0624775

Authored and committed by: jax authors

Merge pull request #20561 from superbobry:docs

PiperOrigin-RevId: 621608577
Parents: 2df89b2 + ea8e393

File tree: 1 file changed (+8 / -4 lines)

docs/pallas/design.md

Lines changed: 8 additions & 4 deletions
@@ -88,6 +88,7 @@ Because JAX was designed with HLO in mind, the set of JAX primitives closely mirrors
 Because Pallas was initially designed with Triton in mind, we offer a set of new primitives targeting the Triton programming model. As we’ll show later, we can lower these primitives to Mosaic as well.
 
 #### `pallas.load` and `pallas.store`
+
 `pallas.load` and `pallas.store` are primitives that allow loading from memory and storing into memory. Unlike `__getitem__` and `__setitem__` they are more flexible at the cost of being more verbose. Specifically, you can use the `pallas.dynamic_slice` (`pallas.ds` for short) construct (which should maybe be upstreamed into JAX to be used with Ref `__getitem__` and `__setitem__`).
 
 ```python
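
To make these primitives concrete, here is a minimal, hypothetical sketch (not taken from design.md) of a kernel that uses `pl.load`, `pl.store`, and `pl.ds`, including a masked load; the kernel name, shapes, and the `interpret=True` flag are illustrative assumptions, not part of the doc being edited:

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def shift_add_kernel(x_ref, o_ref):
  # pl.ds(start, size) builds a dynamic slice; this reads x[2:10].
  x = pl.load(x_ref, (pl.ds(2, 8),))
  # Masked load: where the mask is False, `other` is returned instead of
  # reading from memory (a common guard against out-of-bounds accesses).
  mask = jnp.arange(8) < 6
  y = pl.load(x_ref, (pl.ds(2, 8),), mask=mask, other=0.0)
  # Store the combined result into the first 8 elements of the output.
  pl.store(o_ref, (pl.ds(0, 8),), x + y)

out = pl.pallas_call(
    shift_add_kernel,
    out_shape=jax.ShapeDtypeStruct((8,), jnp.float32),
    interpret=True,  # interpreter mode, so no GPU/TPU is required
)(jnp.arange(16, dtype=jnp.float32))
```
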
@@ -114,6 +115,7 @@ def f(x_ref, o_ref):
 Masking is important when doing out-of-bounds loads/stores. The operational semantics of masking can be compiler-determined (if we understand the documentation properly, Triton avoids the read from/write to memory if it’s masked).
 
 #### `pallas.program_id` and `pallas.num_programs`
+
 As we’ll soon see, we’ll be executing the same Pallas kernels many times (either in parallel or in a pipeline depending on the backend). These new primitives tell us “where” we are in the execution of the kernel.
 
 `pallas.program_id` takes in an axis argument, which tells us which index in an axis of a multidimensional grid this kernel is currently executing in (analogous to `threadId` from CUDA programming or `lax.axis_index` in `jax.pmap`). Note that we are currently borrowing the “program” terminology from Triton and in the future we might want to change it to something more familiar to JAX users.
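
A small, hypothetical sketch (again not from design.md) of these two primitives: each kernel instance asks for its position along grid axis 0 and for the total number of instances, and records that in its output block. The grid, block shape, and `interpret=True` flag are made up, and `pl.BlockSpec` is called with the `(index_map, block_shape)` argument order used in this version of the doc:

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def which_block_kernel(o_ref):
  i = pl.program_id(axis=0)    # index of this instance along grid axis 0 (0..7 here)
  n = pl.num_programs(axis=0)  # total number of instances along axis 0 (8 here)
  # Each (1, 128) output block records its normalized position in the grid.
  o_ref[...] = jnp.full(o_ref.shape, i / n, dtype=o_ref.dtype)

out = pl.pallas_call(
    which_block_kernel,
    out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32),
    out_specs=pl.BlockSpec(lambda i: (i, 0), (1, 128)),  # one row-block per program
    grid=(8,),
    interpret=True,
)()
```
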
@@ -251,7 +253,7 @@ In this example, we compute tiles of the output by doing an unrolled accumulation
 ```python
 def matmul_kernel(x_ref, y_ref, o_ref, *, activation, block_k):
   acc = jnp.zeros((x_ref.shape[0], x_ref.shape[1]), jnp.float32)
-  for k in range(x_ref.shape[1] // block_k)
+  for k in range(x_ref.shape[1] // block_k):
     x = x_ref[:, k*block_k:(k+1)*block_k]
     y = y_ref[k*block_k:(k+1)*block_k, :]
     acc += x @ y
@@ -267,10 +269,12 @@ def matmul(x, y, *, block_shape, activation):
       partial(matmul_kernel, block_k=block_k, activation=activation),
       out_shape=jax.ShapeDtypeStruct((x.shape[0], y.shape[1],), jnp.float32),
       in_specs=[
-        pl.BlockSpec(lambda i, j:, (i, 0), (block_m, x.shape[1])),
-        pl.BlockSpec(lambda i, j:, (0, j), (y.shape[0], block_n))
+        pl.BlockSpec(lambda i, j: (i, 0), (block_m, x.shape[1])),
+        pl.BlockSpec(lambda i, j: (0, j), (y.shape[0], block_n))
       ],
-      out_specs=pl.BlockSpec(lambda i, j: (i, j), (block_m, block_n))
+      out_specs=pl.BlockSpec(lambda i, j: (i, j), (block_m, block_n)),
+      grid=(4, 4),
+  )
   return fused_matmul(x, y)
 
 z = matmul(x, y, block_shape=block_shape, activation=jax.nn.gelu)
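
Assembling the corrected snippets into something runnable, under a few stated assumptions: the final write to `o_ref`, the `block_m, block_n, block_k = block_shape` unpacking, the accumulator shape (`o_ref.shape` rather than a shape derived from `x_ref`), the concrete 512×512 inputs, and `interpret=True` are all filled in for illustration, since the diff only shows part of the example; `pl.BlockSpec` again uses the `(index_map, block_shape)` argument order of this doc version.

```python
from functools import partial

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def matmul_kernel(x_ref, y_ref, o_ref, *, activation, block_k):
  # Accumulate one (block_m, block_n) output tile in float32.
  acc = jnp.zeros(o_ref.shape, jnp.float32)
  for k in range(x_ref.shape[1] // block_k):  # unrolled accumulation over K
    x = x_ref[:, k*block_k:(k+1)*block_k]
    y = y_ref[k*block_k:(k+1)*block_k, :]
    acc += x @ y
  o_ref[:, :] = activation(acc).astype(o_ref.dtype)  # assumed final write

def matmul(x, y, *, block_shape, activation):
  block_m, block_n, block_k = block_shape
  fused_matmul = pl.pallas_call(
      partial(matmul_kernel, block_k=block_k, activation=activation),
      out_shape=jax.ShapeDtypeStruct((x.shape[0], y.shape[1]), jnp.float32),
      in_specs=[
          pl.BlockSpec(lambda i, j: (i, 0), (block_m, x.shape[1])),
          pl.BlockSpec(lambda i, j: (0, j), (y.shape[0], block_n)),
      ],
      out_specs=pl.BlockSpec(lambda i, j: (i, j), (block_m, block_n)),
      grid=(4, 4),
      interpret=True,  # drop this to compile for a GPU/TPU backend
  )
  return fused_matmul(x, y)

# Hypothetical shapes: 512x512 inputs split into a 4x4 grid of 128x128 output
# tiles, accumulated 32 columns of x / 32 rows of y at a time.
key_x, key_y = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(key_x, (512, 512), jnp.float32)
y = jax.random.normal(key_y, (512, 512), jnp.float32)
z = matmul(x, y, block_shape=(128, 128, 32), activation=jax.nn.gelu)
```

With these shapes, each of the 4×4 = 16 kernel instances produces one 128×128 tile of `z`, matching the `grid=(4, 4)` argument added by this change.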
