Should jax.jit optimize away diagonal matrix multiplications? #15343
Unanswered
ryan112358 asked this question in General

From a readability standpoint it is sometimes preferable to write A @ jnp.diag(v), where A is an n x n matrix and v is a vector of length n, even though A * v is equivalent by broadcasting. With NumPy it's clear that the latter is more efficient, but one might hope that under jax.jit the former implementation would be compiled into the latter. That doesn't appear to be the case:
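A minimal timing sketch of the comparison (the size n = 2048, the timeit harness, and the helper names via_diag / via_broadcast are assumptions for illustration):

import timeit

import jax
import jax.numpy as jnp

n = 2048  # assumed size; large enough that building an n x n diagonal matrix is costly
key_a, key_v = jax.random.split(jax.random.PRNGKey(0))
A = jax.random.normal(key_a, (n, n))
v = jax.random.normal(key_v, (n,))

via_diag = jax.jit(lambda A, v: A @ jnp.diag(v))  # materializes jnp.diag(v), then a full matmul
via_broadcast = jax.jit(lambda A, v: A * v)       # scales columns of A directly via broadcasting

# Warm up so compilation time is excluded, then time steady-state execution.
via_diag(A, v).block_until_ready()
via_broadcast(A, v).block_until_ready()

print("A @ jnp.diag(v):", timeit.timeit(lambda: via_diag(A, v).block_until_ready(), number=100))
print("A * v:         ", timeit.timeit(lambda: via_broadcast(A, v).block_until_ready(), number=100))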
-
No, I don't believe the compiler has any optimization that will fuse diag plus matmul. Your best bet is probably to write the efficient version of this operation manually, as you did. Adding such an optimization would be tricky, because there is no XLA primitive for constructing a 2D diagonal matrix from a 1D vector. Consider the jaxpr for jnp.diag:

>>> jax.make_jaxpr(jnp.diag)(jnp.arange(4))
{ lambda ; a:i32[4]. let
    b:i32[4,4] = pjit[
      jaxpr={ lambda ; c:i32[4]. let
          d:i32[4] = pad[padding_config=((0, 0, 0),)] c 0
          e:i32[4,4] = iota[dimension=0 dtype=int32 shape=(4, 4)]
          f:i32[4,4] = add e 0
          g:i32[4,4] = iota[dimension=1 dtype=int32 shape=(4, 4)]
          h:bool[4,4] = eq f g
          i:i32[4] = broadcast_in_dim[broadcast_dimensions=() shape=(4,)] 0
          j:i32[4,4] = pjit[
            jaxpr={ lambda ; k:bool[4,4] l:i32[4] m:i32[4]. let
                n:i32[4,4] = broadcast_in_dim[
                  broadcast_dimensions=(1,)
                  shape=(4, 4)
                ] l
                o:i32[4,4] = broadcast_in_dim[
                  broadcast_dimensions=(1,)
                  shape=(4, 4)
                ] m
                p:i32[4,4] = select_n k o n
              in (p,) }
            name=_where
          ] h d i
        in (j,) }
      name=_diag
    ] a
  in (b,) }

The compiler would have to pattern-match this sequence of ops in order to recognize it as "create diagonal matrix" and optimize it away, and that's probably not an optimization that XLA will support.
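If you want to confirm what the compiler actually emits, you can dump the optimized HLO; a minimal sketch, assuming a JAX version that exposes the jit(...).lower(...).compile().as_text() AOT API (f, A, and v are illustrative names):

import jax
import jax.numpy as jnp

def f(A, v):
    return A @ jnp.diag(v)

A = jnp.ones((4, 4))
v = jnp.arange(4.0)

# Print the post-optimization HLO to see whether the iota/eq/select_n sequence
# that materializes the diagonal matrix, plus the full dot, survives compilation.
print(jax.jit(f).lower(A, v).compile().as_text())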
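For reference, a sketch of the manual rewrite suggested above; the helper names scale_columns and scale_rows are hypothetical. Right-multiplying by jnp.diag(v) scales the columns of A, and left-multiplying scales its rows, so both cases reduce to a broadcasted elementwise product:

import jax.numpy as jnp

def scale_columns(A, v):
    # Same result as A @ jnp.diag(v): column j of A is scaled by v[j].
    return A * v

def scale_rows(A, v):
    # Same result as jnp.diag(v) @ A: row i of A is scaled by v[i].
    return v[:, None] * A

A = jnp.arange(9.0).reshape(3, 3)
v = jnp.array([1.0, 2.0, 3.0])
assert jnp.allclose(scale_columns(A, v), A @ jnp.diag(v))
assert jnp.allclose(scale_rows(A, v), jnp.diag(v) @ A)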