Skip to content

Commit cb4a134

Browse files
authored
Upgrade JAX to latest (#1113)
- 0.2.19 is too old for the latest version of flax. - Updated command to install the GPU version of JAX since it has changed: https://github.com/google/jax#pip-installation-gpu-cuda http://b/210130222
1 parent a75fd9c commit cb4a134

File tree

1 file changed

+2
-3
lines changed

1 file changed

+2
-3
lines changed

Dockerfile.tmpl

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -104,12 +104,11 @@ RUN pip install lightgbm==$LIGHTGBM_VERSION && \
104104
{{ end }}
105105

106106
# Install JAX
107-
ENV JAX_VERSION=0.2.19
108107
{{ if eq .Accelerator "gpu" }}
109-
RUN pip install jax[cuda$CUDA_MAJOR_VERSION$CUDA_MINOR_VERSION]==$JAX_VERSION -f https://storage.googleapis.com/jax-releases/jax_releases.html && \
108+
RUN pip install jax[cuda] -f https://storage.googleapis.com/jax-releases/jax_releases.html && \
110109
/tmp/clean-layer.sh
111110
{{ else }}
112-
RUN pip install jax[cpu]==$JAX_VERSION && \
111+
RUN pip install jax[cpu] && \
113112
/tmp/clean-layer.sh
114113
{{ end }}
115114

0 commit comments

Comments
 (0)