Can copies of arrays be avoided by XLA optimization? #17605
pfackeldey asked this question in Q&A
-
Hi! I don't believe that XLA will optimize away copies in the case of concatenation. The problem is that the input arrays will not generally be adjacent in memory, so the concatenated result cannot, in general, be represented using the original memory layout. The best way to proceed would be to do your entire computation in a vectorized manner, e.g. by computing your histogram bins within a […]. Hope that helps!
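One possible vectorized layout is sketched below, assuming the histograms themselves are fixed data and only the parameters change between likelihood evaluations. The bins are concatenated once, up front, and each bin carries a segment id naming its histogram, so the concatenation copy is paid a single time rather than on every evaluation. Per-histogram reductions then use `jax.ops.segment_sum`. The example data and the per-histogram `scales` parameter are hypothetical, and this is not necessarily the approach the reply had in mind.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Hypothetical example data: three histograms with different numbers of bins.
hists = [np.array([1.0, 2.0, 3.0]), np.array([4.0, 5.0]), np.array([6.0])]
NUM_HISTS = len(hists)

# Concatenate once, outside the likelihood: one copy, paid a single time.
flat = jnp.asarray(np.concatenate(hists))
# segment_ids[i] is the index of the histogram that bin i belongs to,
# here [0, 0, 0, 1, 1, 2].
segment_ids = jnp.asarray(np.repeat(np.arange(NUM_HISTS), [len(h) for h in hists]))


@jax.jit
def scaled_totals(scales, flat, segment_ids):
    # Hypothetical parameter: one multiplicative scale per histogram.
    # Gather each bin's scale through its segment id and apply it elementwise.
    scaled = flat * scales[segment_ids]
    # Reduce the bins back to one value per histogram in a single op.
    return jax.ops.segment_sum(scaled, segment_ids, num_segments=NUM_HISTS)


print(scaled_totals(jnp.array([1.0, 2.0, 0.5]), flat, segment_ids))
```

The flat array and segment ids are passed as arguments rather than closed over, so they are not embedded in the compiled program as constants. Every parameter-dependent operation then runs over all bins at once, independent of how many bins each histogram has.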
-
Dear JAX experts,
I am currently developing a fitting library for binned likelihood fits in high energy physics (dilax). These types of fits are typically used in physics analyses of particle physics experiments at CERN. In these fits I need to evaluate a likelihood function that is based on a large set of histograms (O(10^2-10^5)) and on parameters that are adjusted during minimisation (e.g. experimental uncertainties). Each histogram typically has O(1-100) bins, and the number of bins differs between histograms.
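For concreteness, here is a minimal sketch of the looping approach described in the question, assuming a Poisson likelihood per bin and a single multiplicative scale parameter per histogram. Both are illustrative assumptions; dilax's actual model is not shown here, and `poisson_nll` and `nll_loop` are hypothetical names.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import gammaln


def poisson_nll(expected, observed):
    # Poisson negative log-likelihood summed over bins:
    # sum(mu - n * log(mu) + log(n!)), using log(n!) = gammaln(n + 1).
    return jnp.sum(expected - observed * jnp.log(expected) + gammaln(observed + 1.0))


@jax.jit
def nll_loop(scales, expected_hists, observed_hists):
    # Python loop over a list of differently sized histograms. Under jit the
    # loop is unrolled at trace time, so the compiled program (and the compile
    # time) grows with the number of histograms.
    total = 0.0
    for i in range(len(expected_hists)):
        # Hypothetical model: one multiplicative scale per histogram.
        total = total + poisson_nll(scales[i] * expected_hists[i], observed_hists[i])
    return total


expected = [jnp.array([1.0, 2.0, 3.0]), jnp.array([4.0, 5.0])]
observed = [jnp.array([1.0, 3.0, 2.0]), jnp.array([5.0, 4.0])]
print(nll_loop(jnp.array([1.0, 1.1]), expected, observed))
```

No concatenation copies are made here, but each histogram becomes its own small set of operations in the traced program.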
Currently I am just looping over the histograms and applying functions of the previously mentioned parameters to each of them. However, I could in principle also concatenate multiple histograms and then apply these functions to the combined array. As far as I understand, this could benefit significantly from vectorisation, especially on GPUs; however, (afaik) `jnp.concatenate` will create copies of its inputs internally, potentially leading to memory spikes, which is not the case when I just loop over the histograms. My questions now boil down to:

1. Should I use `jnp.concatenate` for vectorisation and accept potential memory spikes due to these copies, or is it better to loop over the differently sized histograms?
2. For looping, I would use `jax.lax.scan` to avoid large compile times (which I would get with Python for-loops, as they are unrolled). Are there any other benefits of using `jax.lax.scan`, i.e. any optimizations that XLA can do compared to a regular Python for-loop?

Thank you very much in advance for your help!
Best, Peter
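Regarding the `jax.lax.scan` question: `scan` traces its body once, so compile time does not grow with the number of histograms, but every step must see arrays of the same shape. Differently sized histograms therefore have to be padded to a common length and masked. The sketch below assumes the same hypothetical Poisson model with one scale per histogram, and the helper names are illustrative; it also shows a fully vectorized `jax.vmap` version over the same padded arrays for comparison.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.scipy.special import gammaln


def pad_hists(hists):
    # Stack differently sized histograms into one (num_hists, max_bins) array.
    # The boolean mask marks real bins; padded bins are filled with zeros.
    max_bins = max(len(h) for h in hists)
    values = np.zeros((len(hists), max_bins), dtype=np.float32)
    mask = np.zeros((len(hists), max_bins), dtype=bool)
    for i, h in enumerate(hists):
        values[i, : len(h)] = h
        mask[i, : len(h)] = True
    return jnp.asarray(values), jnp.asarray(mask)


def masked_poisson_nll(expected, observed, mask):
    # Replace padded expectations by 1.0 so log() stays finite (and gradients
    # stay NaN-free), then drop the padded terms from the sum.
    safe_exp = jnp.where(mask, expected, 1.0)
    terms = safe_exp - observed * jnp.log(safe_exp) + gammaln(observed + 1.0)
    return jnp.sum(jnp.where(mask, terms, 0.0))


@jax.jit
def nll_scan(scales, expected, observed, mask):
    # One histogram per scan step; the body is traced only once.
    def body(total, xs):
        s, exp_h, obs_h, m = xs
        return total + masked_poisson_nll(s * exp_h, obs_h, m), None

    init = jnp.zeros((), dtype=expected.dtype)
    total, _ = jax.lax.scan(body, init, (scales, expected, observed, mask))
    return total


@jax.jit
def nll_vmap(scales, expected, observed, mask):
    # Fully vectorized alternative: evaluate all histograms at once.
    per_hist = jax.vmap(masked_poisson_nll)(scales[:, None] * expected, observed, mask)
    return jnp.sum(per_hist)


expected, mask = pad_hists([[1.0, 2.0, 3.0], [4.0, 5.0], [6.0]])
observed, _ = pad_hists([[1.0, 3.0, 2.0], [5.0, 4.0], [7.0]])
scales = jnp.array([1.0, 1.1, 0.9])
print(nll_scan(scales, expected, observed, mask), nll_vmap(scales, expected, observed, mask))
```

Padding costs extra memory proportional to `num_hists * max_bins`, which may matter when bin counts vary widely (O(1) to O(100) here); the segment-id layout avoids padding at the cost of gather/scatter operations.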