Data dependency heuristics for in-place operations are too conservative #19165
-
Consider the following MWE.

```python
import time
from functools import partial

import jax.numpy as jnp
from jax import Array, jit, random


@partial(jit, donate_argnums=0)
def f(x: Array) -> tuple[Array, Array]:
    x = x.at[0, 0].add(1)
    y = x[0, 0]
    return x, y


if __name__ == "__main__":
    n = 10**4
    rng = random.key(0)
    rng, subkey = random.split(rng)
    x = random.uniform(subkey, shape=(n, n))
    f(jnp.copy(x))[0].block_until_ready()  # warm-up call to compile
    start = time.time()
    x = f(x)[0].block_until_ready()
    print(f"{time.time() - start:.3e}")
```

As expected, the update happens in place and runs quickly. However, if the assignment of `y` is moved before the update,

```python
@partial(jit, donate_argnums=0)
def f(x: Array) -> tuple[Array, Array]:
    y = x[0, 0]
    x = x.at[0, 0].add(1)
    return x, y
```

the update happens out of place and is very slow. My intuition for what's happening is that in the first example the only read of `x[0, 0]` comes after the update, so the donated buffer can be reused directly, whereas in the second example `y` depends on the *original* value of `x[0, 0]`, and XLA conservatively copies the whole array rather than recognizing that only a single scalar needs to be preserved.

Is there any way to declare that the requisite data from the original array has already been read, so that the update can still happen in place? (Based on what I've read so far I suspect the answer is "no", but I'd be happy to be proven wrong!)

If the MWE seems contrived, the exact code I'm trying to optimize is as follows.

```python
@partial(jit, donate_argnums=(0, 1))
def f(x: Array, y: Array, i: int, k: int) -> tuple[Array, Array]:
    v = x[k, i]
    x = x.at[k, i].set(0.0)
    x = x.at[:, i].add(-(x @ x[k]))
    x = x.at[:, i].divide(jnp.sqrt(v + x[k, i]))
    y = y.at[:].add(-jnp.square(x[:, i]))
    return x, y
```

Ideally all of these updates to `x` and `y` would happen in place.

Possibly related to #17845, #17640, and #10197 but much simpler (scan and autograd are not involved).
Replies: 1 comment
-
This MWE, at least, is fixed by the XLA flag `--xla_cpu_copy_insertion_use_region_analysis=true`. See #25399.
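For anyone trying this, a minimal sketch of how such a flag is typically applied (assuming the standard `XLA_FLAGS` environment-variable mechanism): it must be set before JAX is first imported, since XLA reads it at initialization.

```python
import os

# Append the flag to any existing XLA_FLAGS rather than clobbering them.
os.environ["XLA_FLAGS"] = (
    os.environ.get("XLA_FLAGS", "")
    + " --xla_cpu_copy_insertion_use_region_analysis=true"
).strip()

import jax  # noqa: E402  -- must come *after* setting XLA_FLAGS
```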