Replies: 2 comments
- I'm also having the same issue, waiting for a reply.
- I have just started working on tuning the JAX model, and I have the same questions about this. I hope this question can be answered. Thank you very much.
Question
While trying to optimize the inference time of AlphaFold3 on GPU under the JAX framework,
I found that the sum of the timings of the individual module functions inside the inference
function is not consistent with the time measured for the inference function as a whole. Is this due to the JAX framework?
We added timers to all modules; each module's time consumption is shown in the table below. Clearly, the total inference
time does not match the sum of the per-module timings, so we suspect this may be caused by the JAX framework.
Thanks for your answer!
Testing the total inference time:
`result = self._model(rng_key, featurised_example)`
will call the `_model()` function. It mainly executes the forward function
`model.Model(self._model_config)(batch)`
of class `Model`.
Inference process: the call function of class `Model`.
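One plausible cause (an assumption, not confirmed in this thread) is JAX's asynchronous dispatch: a jitted call returns as soon as the work is enqueued on the device, so timing an individual module without synchronizing measures only dispatch overhead, while timing the whole inference happens to include the device wait. A minimal sketch of the difference, using a hypothetical `forward` function as a stand-in for a model module:

```python
import time

import jax
import jax.numpy as jnp


@jax.jit
def forward(x):
    # Hypothetical stand-in for one module of the model.
    return jnp.tanh(x @ x.T)


x = jnp.ones((512, 512))
# Warm-up call so compilation time is excluded from the measurements.
forward(x).block_until_ready()

# Naive timing: JAX dispatches the computation asynchronously, so this
# may measure only the time to enqueue the work, not to execute it.
t0 = time.perf_counter()
y = forward(x)
dispatch_time = time.perf_counter() - t0

# Synchronized timing: block until the device has actually produced
# the result before stopping the clock.
t0 = time.perf_counter()
y = forward(x).block_until_ready()
compute_time = time.perf_counter() - t0

print(f"dispatch-only: {dispatch_time:.6f}s, synchronized: {compute_time:.6f}s")
```

If the per-module timers in the inference function were taken without `block_until_ready()` (or without `jax.block_until_ready(result)` on the outputs), their sum would not be expected to match the end-to-end time.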