Skip to content

Commit 51352fa

Browse files
author
jax authors
committed
fix matrix dimension and block shape.
PiperOrigin-RevId: 624988654
1 parent 90401d5 commit 51352fa

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

docs/pallas/design.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -252,15 +252,15 @@ In this example, we compute tiles of the output by doing an unrolled accumulatio
252252

253253
```python
254254
def matmul_kernel(x_ref, y_ref, o_ref, *, activation, block_k):
255-
acc = jnp.zeros((x_ref.shape[0], x_ref.shape[1]), jnp.float32)
255+
acc = jnp.zeros((x_ref.shape[0], y_ref.shape[1]), jnp.float32)
256256
for k in range(x_ref.shape[1] // block_k):
257257
x = x_ref[:, k*block_k:(k+1)*block_k]
258258
y = y_ref[k*block_k:(k+1)*block_k, :]
259259
acc += x @ y
260260
o_ref[:, :] = activation(acc).astype(o_ref.dtype)
261261

262262
x, y = jnp.ones((512, 256)), jnp.ones((256, 1024))
263-
block_shape = 256, 256, 128
263+
block_shape = 128, 256, 128
264264

265265
@partial(jax.jit, static_argnames=["block_shape", "activation"])
266266
def matmul(x, y, *, block_shape, activation):

0 commit comments

Comments
 (0)