[pallas] Explanation for scratch_shapes for TPU #22460
-
The flash attention code for TPU uses something called scratch_shapes, but it isn't documented anywhere. Does anyone know what scratch_shapes does and how to use it?
-
We are working on adding documentation for Pallas, but we don't have much yet. At the moment, the best thing we can point to is the tests.
-
https://github.com/google/jax/blob/main/docs/pallas/tpu/matmul.md now has some documentation that hints at what scratch_shapes is.
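For anyone landing here, below is a minimal sketch of how scratch_shapes is used, modeled on the blocked matmul in that doc. It assumes a recent JAX with `jax.experimental.pallas.tpu`; the names `matmul` and `matmul_kernel`, the block sizes, and the `interpret` flag plumbing are illustrative, not taken from the doc verbatim. Each entry in scratch_shapes asks Pallas to allocate a buffer in on-chip memory (here `pltpu.VMEM`) that is neither an input nor an output. The kernel receives it as an extra ref after the output refs, and its contents carry over between sequentially executed grid steps. That carry-over is what lets it act as a float32 accumulator over the `k` axis.

```python
import functools

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu


def matmul_kernel(x_ref, y_ref, z_ref, acc_ref, *, nsteps):
  # Ref order: inputs (x_ref, y_ref), then outputs (z_ref), then one ref
  # per entry in scratch_shapes (acc_ref).
  k = pl.program_id(2)

  # Zero the scratch accumulator at the start of each reduction over k.
  @pl.when(k == 0)
  def _():
    acc_ref[...] = jnp.zeros(acc_ref.shape, acc_ref.dtype)

  # Accumulate this (bm, bk) @ (bk, bn) tile product in float32.
  acc_ref[...] += jnp.dot(
      x_ref[...], y_ref[...], preferred_element_type=jnp.float32)

  # On the last k step, cast and write the accumulated tile to the output.
  @pl.when(k == nsteps - 1)
  def _():
    z_ref[...] = acc_ref[...].astype(z_ref.dtype)


@functools.partial(jax.jit, static_argnames=["bm", "bk", "bn", "interpret"])
def matmul(x, y, *, bm=128, bk=128, bn=128, interpret=False):
  m, k = x.shape
  _, n = y.shape
  # Assumes m, k and n are divisible by bm, bk and bn respectively.
  return pl.pallas_call(
      functools.partial(matmul_kernel, nsteps=k // bk),
      grid_spec=pltpu.PrefetchScalarGridSpec(
          num_scalar_prefetch=0,
          grid=(m // bm, n // bn, k // bk),
          in_specs=[
              pl.BlockSpec(block_shape=(bm, bk),
                           index_map=lambda i, j, kk: (i, kk)),
              pl.BlockSpec(block_shape=(bk, bn),
                           index_map=lambda i, j, kk: (kk, j)),
          ],
          out_specs=pl.BlockSpec(block_shape=(bm, bn),
                                 index_map=lambda i, j, kk: (i, j)),
          # One (bm, bn) float32 VMEM buffer, allocated by Pallas and passed
          # to the kernel as acc_ref; it is not an argument or a result.
          scratch_shapes=[pltpu.VMEM((bm, bn), jnp.float32)],
      ),
      out_shape=jax.ShapeDtypeStruct((m, n), x.dtype),
      # interpret=True runs the kernel through the Pallas interpreter,
      # which is handy for checking it on CPU.
      interpret=interpret,
  )(x, y)
```

Note that `acc_ref` never appears in `out_shape` or in the arguments passed to the call: Pallas allocates it, hands it to the kernel as the trailing ref, and drops it afterwards. As far as I know, other entries such as `pltpu.SMEM(...)` can be listed in the same way to request scalar scratch memory. The linked doc additionally passes TPU compiler params that mark the `k` grid axis as `"arbitrary"` (sequential); that is omitted here to keep the sketch short, so check the doc for the spelling that matches your JAX version.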