Labels: bug
Description
JAX arrays are missing the `.device` attribute when running inside `jax.jit`, and calling the `.devices()` method raises a concretization error:
```python
>>> import jax, jax.numpy as jnp
>>> x = jnp.asarray(0)
>>> def f(x): return jnp.zeros(0, device=x.device)
>>> f(x)
Array([], shape=(0,), dtype=float32)
>>> jax.jit(f)(x)
AttributeError: DynamicJaxprTracer has no attribute device
>>> def g(x): return jnp.zeros(0, device=next(iter(x.devices())))
>>> jax.jit(g)(x)
ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: traced array with shape int32[]
The devices() method was called on traced array with shape int32[].
The error occurred while tracing the function g at <ipython-input-6-5e654833a7fe>:1 for jit. This concrete value was not available in Python because it depends on the value of the argument x.
```
This breaks Array API compatibility and hinders Array API-compliant libraries that use the following pattern:

```python
from array_api_compat import array_namespace

def f(x):
    xp = array_namespace(x)
    return xp.asarray(123, device=x.device)
```
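To make the failure mode concrete, here is a dependency-free sketch of the pattern above. `Namespace`, `ConcreteArray`, `TracerLike` and `fake_array_namespace` are hypothetical stand-ins, not JAX or array-api-compat objects; `TracerLike` only mimics the fact that a traced value inside `jax.jit` exposes no `.device` attribute.

```python
# Hypothetical stand-ins: NOT JAX or array-api-compat classes.
class Namespace:
    """Minimal namespace with an Array API-style asarray(..., device=...)."""

    def asarray(self, obj, device=None):
        return ConcreteArray(obj, device)


class ConcreteArray:
    def __init__(self, value, device):
        self.value = value
        self.device = device  # concrete arrays know their placement


class TracerLike:
    """Mimics a jit-traced value, which exposes no `.device` attribute."""

    def __init__(self, value):
        self.value = value


def fake_array_namespace(x):
    # Stand-in for array_api_compat.array_namespace.
    return Namespace()


def f(x):
    xp = fake_array_namespace(x)
    return xp.asarray(123, device=x.device)


print(f(ConcreteArray(0, "cpu:0")).device)  # cpu:0
try:
    f(TracerLike(0))
except AttributeError as err:
    print("AttributeError:", err)
```

Nothing in `f` is JAX-specific: any array type that omits `.device` under some execution mode breaks every library written this way.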
You can see two such use cases in array-api-extra: https://github.com/search?q=repo%3Adata-apis%2Farray-api-extra+device%3D_compat.device&type=code
Workaround
I'm provisionally implementing a workaround in array-api-compat that makes `device(x)` return `None` and `to_device(x, device)` accept `None`. However, this produces outputs on the wrong device whenever `x` is not on the default device.
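A minimal sketch of the workaround's semantics, assuming `device(x)` falls back to `None` when the attribute is absent and `to_device(x, None)` is a no-op. `device`, `to_device`, `asarray`, `Array`, `Tracer` and `DEFAULT_DEVICE` here are illustrative stand-ins, not array-api-compat's actual code.

```python
# Hypothetical stand-ins illustrating the workaround's semantics;
# these are NOT array-api-compat's real implementations.
DEFAULT_DEVICE = "cpu:0"  # assumed default device for this sketch


class Array:
    def __init__(self, value, device):
        self.value = value
        self.device = device


class Tracer:
    """Mimics a traced value: the device is unknown at trace time."""

    def __init__(self, value):
        self.value = value


def device(x):
    # Return x's device, or None when it cannot be determined.
    return getattr(x, "device", None)


def to_device(x, device):
    # None means "leave placement unchanged" instead of raising.
    if device is None:
        return x
    return Array(x.value, device)


def asarray(value, device=None):
    # device=None falls back to the default device.
    return Array(value, DEFAULT_DEVICE if device is None else device)


# A concrete array reports its device; a tracer yields None.
on_gpu = Array(0, "gpu:1")
print(device(on_gpu))  # gpu:1
traced = Tracer(0)
print(device(traced))  # None
# Caveat: even if the traced input really lived on gpu:1, the new array
# lands on the default device because that information was lost.
print(asarray(123, device=device(traced)).device)  # cpu:0
```

The last line reproduces the caveat: once `device(x)` degrades to `None`, placement silently falls back to the default device rather than following `x`.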
System info (python version, jaxlib version, accelerator, etc.)
JAX 0.4.35
34j