
Is it worth using jnp.linalg.multi_dot rather than reduce jnp.dot? #18299

Answered by jakevdp
nasyxx asked this question in Q&A

`multi_dot` is not much help here: its main purpose is to find an optimal ordering (parenthesization) of a chain of dot products given the shapes of the inputs. In your case, all inputs have the same shape, so every ordering costs the same amount of work and the order of operations doesn't matter.
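To make this concrete, here is a minimal sketch, assuming the inputs are a list of square matrices that all share one shape (the sizes and random data below are hypothetical, not taken from the question). It checks that a left-to-right `functools.reduce(jnp.dot, ...)` and `jnp.linalg.multi_dot` produce the same product; with identical square shapes, every parenthesization needs the same number of FLOPs, so `multi_dot` has nothing to optimize.

```python
import functools

import jax
import jax.numpy as jnp

# Hypothetical inputs: five 4x4 matrices of identical shape, scaled down
# so the chained product stays at a moderate magnitude.
key = jax.random.PRNGKey(0)
mats = list(jax.random.normal(key, (5, 4, 4)) / 2.0)

# Left-to-right chain: ((((m0 @ m1) @ m2) @ m3) @ m4).
reduced = functools.reduce(jnp.dot, mats)

# multi_dot chooses an evaluation order from the input shapes; with equal
# square shapes every order has the same cost, so only rounding can differ.
chained = jnp.linalg.multi_dot(mats)

print(jnp.allclose(reduced, chained, rtol=1e-3, atol=1e-4))
```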

The slow compilation times here come from the fact that, in either case, your function relies on Python control flow to dispatch a large number of JAX operations. Python loops are unrolled (flattened) during tracing, which leads to large programs, and large programs compile slowly (compilation time grows roughly as the square of the number of operations). In situations like this, you can often do better by re-expressing your loop in terms of a control flow primitive like `lax.scan`.
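A minimal sketch of the `lax.scan` approach, assuming the matrices can be stacked into a single array of shape `(k, n, n)`. The function names `chain_scan` and `chain_unrolled` and the sizes are hypothetical, chosen for illustration. The scan body is traced once, so the program handed to the compiler stays the same size for any `k`, whereas the `functools.reduce` version is unrolled into `k - 1` separate dot operations; counting the equations in each jaxpr makes the difference visible.

```python
import functools

import jax
import jax.numpy as jnp
from jax import lax


def chain_scan(stacked):
    """Left-to-right product stacked[0] @ stacked[1] @ ... via lax.scan.

    Assumes `stacked` has shape (k, n, n): k square matrices of one size.
    The loop body is traced once, so program size does not grow with k.
    """
    def step(carry, mat):
        # carry is the running product; multiply the next matrix on the right.
        return carry @ mat, None

    product, _ = lax.scan(step, stacked[0], stacked[1:])
    return product


def chain_unrolled(stacked):
    """Same product via a Python loop, which tracing unrolls into
    k - 1 separate dot operations."""
    return functools.reduce(jnp.dot, list(stacked))


# Hypothetical input: 50 random 8x8 matrices, scaled to keep values moderate.
key = jax.random.PRNGKey(0)
stacked = jax.random.normal(key, (50, 8, 8)) / jnp.sqrt(8.0)

# Compare the size of the traced programs (number of jaxpr equations).
n_scan = len(jax.make_jaxpr(chain_scan)(stacked).jaxpr.eqns)
n_unrolled = len(jax.make_jaxpr(chain_unrolled)(stacked).jaxpr.eqns)
print("equations:", n_scan, "(scan) vs", n_unrolled, "(unrolled)")

# Both versions compute the same left-to-right product.
result_scan = jax.jit(chain_scan)(stacked)
result_unrolled = jax.jit(chain_unrolled)(stacked)
print(jnp.allclose(result_scan, result_unrolled, rtol=1e-3, atol=1e-4))
```

Because each step depends on the previous running product, both versions do the same sequential work at runtime; the scan form mainly saves tracing and compilation time as `k` grows.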

Replies: 1 comment 1 reply

Answer selected by jakevdp
Category: Q&A
Labels: None yet
2 participants