Unable to import jax in TPU #7216

Answered by jakevdp
paramjeet2021 asked this question in General

Your system may have a version of libtpu that does not match your jax version. Try installing the version pinned in setup.py: https://github.com/google/jax/blob/38884b02f72950a8d187f81420a270e46afe8889/setup.py#L27

I believe the easiest way to install the correct version is to run

$ pip install -U jax
$ pip install "jax[tpu]" -f https://storage.googleapis.com/jax-releases/jax_releases.html
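
To check which versions ended up installed after running the commands, a small script along these lines may help. This is only a sketch using the Python standard library: the helper `installed_versions` is a hypothetical name, and the distribution name `libtpu-nightly` is an assumption about how libtpu is packaged in your environment (it may be published under a different name).

```python
from importlib import metadata

# Assumption: these are the distribution names to look for. The name
# "libtpu-nightly" in particular is a guess and may differ on your system.
DEFAULT_PACKAGES = ("jax", "jaxlib", "libtpu-nightly")


def installed_versions(packages=DEFAULT_PACKAGES):
    """Return a dict mapping each package name to its installed version.

    Packages that are not installed map to None.
    """
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    for name, version in installed_versions().items():
        print(f"{name}: {version or 'not installed'}")
```

You can then compare the reported libtpu version against the one pinned in setup.py at the link above.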

Answer selected by paramjeet2021