I think this pattern is fairly typical for code which is meant to work (possibly in a limited fashion) without JAX:

```python
try:
    import jax.numpy as np
except ImportError:
    import numpy as np
```

But is there a prototypical way to allow the user to disable even trying JAX, in case that leads to some other problems down the line? Something like this?

```python
import os

if os.environ.get("DISABLE_JAX", False):
    import numpy as np
else:
    try:
        import jax.numpy as np
    except ImportError:
        import numpy as np
```

That'll definitely work for our purposes, but I'm just wondering if there's an even more established pattern. Thanks!
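One detail worth noting in the snippet above: `os.environ.get("DISABLE_JAX", False)` returns a string whenever the variable is set, so any non-empty value, including `DISABLE_JAX=0`, disables JAX. A minimal sketch that parses the variable explicitly could look like this. It assumes the asker's `DISABLE_JAX` name; the helper names `jax_disabled` and `load_array_module` and the set of accepted "false" spellings are hypothetical choices, not an established convention.

```python
import importlib
import os
from types import ModuleType

# Hypothetical choice: values of the variable that mean "do not disable JAX".
_FALSY = {"", "0", "false", "no", "off"}


def jax_disabled(var: str = "DISABLE_JAX") -> bool:
    """Return True if the environment variable is set to a truthy value."""
    return os.environ.get(var, "").strip().lower() not in _FALSY


def load_array_module() -> ModuleType:
    """Return jax.numpy if allowed and importable, otherwise numpy."""
    if not jax_disabled():
        try:
            return importlib.import_module("jax.numpy")
        except ImportError:
            pass  # JAX is not installed: fall through to NumPy.
    return importlib.import_module("numpy")


# Same single-name convention as the question; see the reply below for caveats.
np = load_array_module()
```

Wrapping the choice in a function also makes it easy to test both branches without re-importing the module.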
I'm not sure that this is such a common (or recommended!) pattern, but I could imagine it working in some cases. We typically recommend importing `jax.numpy` as `jnp`, because `jnp` and `np` (plain NumPy) are often both used in JAX code, and the programming models are different enough that it seems somewhat unlikely that you'd be able to write performant code by swapping one for the other behind a single `np` name. But, like I said, I certainly believe you that there are cases where this works and is useful! I'd say that the best approach for your specific question will depend on the details of your project.

As an aside: JAX, NumPy, and some other array libraries now have support for the Array API standard, so it might be worth looking into using the Array API if you want to support multiple backends within your library in a standards-compliant way.
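As a rough illustration of the Array API approach: instead of picking a backend at import time, a function can ask its input for the matching namespace via the standard's `__array_namespace__()` method, so JAX arrays are processed with JAX and NumPy arrays with NumPy. This is a hand-rolled sketch; `get_namespace` and `normalize` are hypothetical names, and the NumPy fallback exists because NumPy's `ndarray` only implements `__array_namespace__` from NumPy 2.0 on. The `array-api-compat` package provides a more complete `array_namespace()` helper.

```python
import numpy as np


def get_namespace(x):
    """Return the Array API namespace for array x, falling back to NumPy."""
    # Arrays implementing the Array API standard expose __array_namespace__().
    if hasattr(x, "__array_namespace__"):
        return x.__array_namespace__()
    # Older NumPy (< 2.0) ndarrays lack the method, so assume NumPy.
    return np


def normalize(x):
    """Scale x to unit Euclidean norm using whichever backend x belongs to."""
    xp = get_namespace(x)
    return x / xp.sqrt(xp.sum(x * x))
```

With this style the caller decides the backend simply by the type of array passed in, so no environment variable is needed; the trade-off is that library code must restrict itself to functions defined by the standard.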