You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Copy file name to clipboardExpand all lines: docs/pallas/design.md
+8-4Lines changed: 8 additions & 4 deletions
Original file line number
Diff line number
Diff line change
@@ -88,6 +88,7 @@ Because JAX was designed with HLO in mind, the set of JAX primitives closely mir
88
88
Because Pallas was initially designed with Triton in mind, we offer a set of new primitives targeting the Triton programming model. As we’ll show later, we can lower these primitives to Mosaic as well.
89
89
90
90
#### `pallas.load` and `pallas.store`
91
+
91
92
`pallas.load` and `pallas.store` are primitives that allow loading from memory and storing into memory. Unlike `__getitem__` and `__setitem__` they are more flexible at the cost of being more verbose. Specifically, you can use the `pallas.dynamic_slice` (`pallas.ds` for short) construct (which should maybe be upstreamed into JAX to be used with Ref `__getitem__` and `__setitem__`).
92
93
93
94
```python
@@ -114,6 +115,7 @@ def f(x_ref, o_ref):
114
115
Masking is important when doing out-of-bounds loads/stores. The operational semantics of masking can be compiler-determined (if we understand the documentation properly, Triton avoids the read from/write to memory if it’s masked).
115
116
116
117
#### `pallas.program_id` and `pallas.num_programs`
118
+
117
119
As we’ll soon see, we’ll be executing the same Pallas kernels many times (either in parallel or in a pipeline depending on the backend). These new primitives tell us “where” we are in the execution of the kernel.
118
120
119
121
`pallas.program_id` takes in an axis argument, which tells us which index in an axis of a multidimensional grid this kernel is currently executing in (analogous to `threadId` from CUDA programming or `lax.axis_index` in `jax.pmap`). Note that we are currently borrowing the “program” terminology from Triton and in the future we might want to change it to something more familiar to JAX users.
@@ -251,7 +253,7 @@ In this example, we compute tiles of the output by doing an unrolled accumulatio
0 commit comments