diff --git a/.github/workflows/nightly.yaml b/.github/workflows/nightly.yaml index a3c5389..64758c0 100644 --- a/.github/workflows/nightly.yaml +++ b/.github/workflows/nightly.yaml @@ -38,7 +38,7 @@ jobs: run: | python -m pip install --upgrade pip python -m pip install pytest pytest-xdist - python -m pip install -U chex jax flax grain ml_dtypes optax orbax-checkpoint orbax-export tensorflow tensorflow_datasets + python -m pip install -U chex jax flax grain ml_dtypes optax orbax-checkpoint orbax-export tf-nightly tensorflow_datasets - name: Run tests run: | pytest -n auto jax_ai_stack diff --git a/pyproject.toml b/pyproject.toml index 26d13f4..cf1eaf6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,11 +16,11 @@ keywords = [] # pip dependencies of the project dependencies = [ "chex==0.1.89", - "grain==0.2.9", - "jax==0.6.0", + "grain==0.2.10", + "jax==0.6.2", "flax==0.10.6", "ml_dtypes==0.5.1", - "optax==0.2.4", + "optax==0.2.5", "orbax-checkpoint==0.11.13", "orbax-export==0.0.6", ] @@ -40,8 +40,8 @@ dev = [ # TensorFlow datasets is an extra because it has a large install footprint. tfds = [ - "tensorflow==2.19.0", - "tensorflow_datasets==4.9.8", + "tf-nightly", + "tensorflow_datasets==4.9.9", ] # Grain is now part of the default installation; keep this for