What's the reason behind significant latency between CPU and TPU data transmission (way higher than "data size" / "available bandwidth")? #29480
JianmingTONG asked this question in General
We are currently working on a project involving both a custom JAX C++ kernel (running on CPU) and a JAX kernel on TPU. In our setup, the CPU kernel performs data preprocessing, and the processed results are transferred to the TPU for further computation using jax.device_put(). The corresponding code snippet is shown in Figure 1 below.
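The actual snippet is in Figure 1; since the image is not reproduced here, the following is a minimal sketch of the pattern described, with illustrative names (`preprocess_on_cpu` stands in for the custom C++ kernel, and the sketch falls back to the default device when no TPU is available):

```python
import numpy as np
import jax
import jax.numpy as jnp

# Illustrative stand-in for the custom CPU-side C++ preprocessing kernel.
def preprocess_on_cpu(batch: np.ndarray) -> np.ndarray:
    return (batch - batch.mean()) / (batch.std() + 1e-8)

# Use a TPU if one is attached; otherwise fall back to the default device.
try:
    target = jax.devices("tpu")[0]
except RuntimeError:
    target = jax.devices()[0]

@jax.jit
def compute(x):
    # Illustrative TPU-side computation.
    return jnp.sum(x * x)

batch = np.random.rand(1024, 1024).astype(np.float32)
host_result = preprocess_on_cpu(batch)              # runs on the host CPU
device_array = jax.device_put(host_result, target)  # host -> device copy
out = compute(device_array)
```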
To better understand the performance characteristics of this pipeline, we collected TPU and CPU execution traces. Please find attached two screenshots illustrating four iterations of the CPU–TPU processing flow:
Figure 2 shows a high-level trace:
Figure 3 provides a zoomed-in view of the region marked by the black oval in Figure 2:
Our main question concerns the overhead observed immediately after jax.device_put(): the end-to-end transfer latency is far higher than data size divided by available bandwidth would predict.
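One detail worth ruling out when interpreting such traces: JAX dispatches operations (including jax.device_put) asynchronously, so a wall-clock measurement that does not block can attribute the copy cost to whatever op is timed next. A minimal timing sketch, with the array size and device choice purely illustrative:

```python
import time
import numpy as np
import jax

x = np.random.rand(2048, 2048).astype(np.float32)  # ~16 MiB, illustrative
dev = jax.devices()[0]  # would be a TPU device in the actual setup

start = time.perf_counter()
y = jax.device_put(x, dev)
y.block_until_ready()  # device_put is async; wait for the copy to finish
elapsed = time.perf_counter() - start
```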
For context, we’ve attached the relevant codebase and a README with reproduction instructions: https://drive.google.com/file/d/1yTJVLM-PPaKcpCjDXlCbof3DUErH9c2C/view?usp=sharing
We would greatly appreciate your insights on optimizing this data transfer and minimizing preprocessing bottlenecks!
Best,
Jianming and Jingtian
Figure 1 (code snippet; image not reproduced here)
Figure 2 (high-level trace; image not reproduced here)
Figure 3 (zoomed-in view; image not reproduced here)