jax-ai-stack
packages:
jax==0.5.3
chex==0.1.89
flax==0.10.5
ml_dtypes==0.5.1
optax==0.2.4
orbax-checkpoint==0.11.12
orbax-export==0.0.6
jax-ai-stack[tfds]
packages:
tensorflow==2.19.0
tensorflow_datasets==4.9.8
jax-ai-stack[grain]
packages:
grain==0.2.7
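To confirm that an environment matches the versions listed for jax-ai-stack and its [tfds] and [grain] extras, you could compare the pins against what is installed. The sketch below does that with the standard library's importlib.metadata. The helpers `parse_pins` and `check_pins` and the `PINS` string are hypothetical illustrations, not part of jax-ai-stack. It assumes Python 3.8+ and reports a package as missing when it is not installed, which is expected for any extra you did not install.

```python
from __future__ import annotations

from importlib.metadata import PackageNotFoundError, version

# Hypothetical pin list copied from the listing: core stack plus [tfds] and [grain] extras.
PINS = """
jax==0.5.3
chex==0.1.89
flax==0.10.5
ml_dtypes==0.5.1
optax==0.2.4
orbax-checkpoint==0.11.12
orbax-export==0.0.6
tensorflow==2.19.0
tensorflow_datasets==4.9.8
grain==0.2.7
"""


def parse_pins(text: str) -> dict[str, str]:
    """Turn 'name==version' lines into a {name: version} mapping, skipping blanks."""
    pins = {}
    for line in text.splitlines():
        line = line.strip()
        if not line:
            continue
        name, _, ver = line.partition("==")
        pins[name] = ver
    return pins


def check_pins(pins: dict[str, str]) -> dict[str, tuple[str, str | None]]:
    """Return {name: (pinned, installed)} for every package whose version differs."""
    mismatches = {}
    for name, pinned in pins.items():
        try:
            installed = version(name)
        except PackageNotFoundError:
            installed = None  # package is not installed in this environment
        if installed != pinned:
            mismatches[name] = (pinned, installed)
    return mismatches


if __name__ == "__main__":
    for name, (pinned, installed) in check_pins(parse_pins(PINS)).items():
        print(f"{name}: pinned {pinned}, installed {installed or 'missing'}")
```

An empty result from `check_pins` means every listed package is installed at exactly the listed version.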