
Missing .device attribute inside @jax.jit #26000

@crusaderky

Description

Inside jax.jit, traced JAX arrays have no .device attribute, and calling the .devices() method instead raises a concretization error:

```python
>>> import jax, jax.numpy as jnp
>>> x = jnp.asarray(0)
>>> def f(x): return jnp.zeros(0, device=x.device)
>>> f(x)
Array([], shape=(0,), dtype=float32)
>>> jax.jit(f)(x)
AttributeError: DynamicJaxprTracer has no attribute device
>>> def g(x): return jnp.zeros(0, device=next(iter(x.devices())))
>>> jax.jit(g)(x)
ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: traced array with shape int32[]
The devices() method was called on traced array with shape int32[].
The error occurred while tracing the function g at <ipython-input-6-5e654833a7fe>:1 for jit. This concrete value was not available in Python because it depends on the value of the argument x.
```

This breaks Array API compatibility and hinders Array API-compliant libraries that use the pattern:

```python
from array_api_compat import array_namespace

def f(x):
    xp = array_namespace(x)
    return xp.asarray(123, device=x.device)
```

You can see two such use cases in array-api-extra: https://github.com/search?q=repo%3Adata-apis%2Farray-api-extra+device%3D_compat.device&type=code

Workaround

As a provisional workaround, I'm changing array-api-compat so that device(x) returns None and to_device(x, device) accepts None. However, this will produce outputs on the wrong device whenever x is not on the default device.

System info (python version, jaxlib version, accelerator, etc.)

JAX 0.4.35

Metadata

Assignees: No one assigned
Labels: bug (Something isn't working)
Type: No type
Projects: No projects
Milestone: No milestone
Relationships: None yet
Development: No branches or pull requests
