Skip to content

Commit 87aee90

Browse files
sharadmvjax authors
authored andcommitted
Fix typo in Pallas design
PiperOrigin-RevId: 621275025
1 parent 87b869d commit 87aee90

File tree

1 file changed

+4
-3
lines changed

1 file changed

+4
-3
lines changed

docs/pallas/design.md

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -236,10 +236,11 @@ add = pl.pallas_call(
236236
add_kernel,
237237
out_shape=jax.ShapeDtypeStruct((8,), jnp.int32),
238238
in_specs=[
239-
pl.BlockSpec(lambda i:, i, (2,)),
240-
pl.BlockSpec(lambda i:, i, (2,))
239+
pl.BlockSpec(lambda i: i, (2,)),
240+
pl.BlockSpec(lambda i: i, (2,))
241241
],
242-
out_specs=pl.BlockSpec(lambda i: i, (2,))
242+
out_specs=pl.BlockSpec(lambda i: i, (2,)),
243+
grid=(4,))
243244
add(x, y)
244245
```
245246

0 commit comments

Comments
 (0)