Skip to content

Commit 07f58d7

Browse files
authored
Make grain part of the default installation (#202)
1 parent 161732a commit 07f58d7

File tree

5 files changed

+17
-27
lines changed

5 files changed

+17
-27
lines changed

.github/workflows/nightly.yaml

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ jobs:
6262
- name: Install dependencies with jax nightly
6363
run: |
6464
python -m pip install --upgrade pip
65-
python -m pip install .[dev,tfds,grain]
65+
python -m pip install .[dev,tfds]
6666
python -m pip install --upgrade --pre jax jaxlib -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
6767
- name: Run tests
6868
run: |
@@ -87,7 +87,7 @@ jobs:
8787
- name: Install dependencies with flax nightly
8888
run: |
8989
python -m pip install --upgrade pip
90-
python -m pip install .[dev,tfds,grain]
90+
python -m pip install .[dev,tfds]
9191
python -m pip install --upgrade git+https://github.com/google/flax.git
9292
- name: Run tests
9393
run: |
@@ -112,7 +112,7 @@ jobs:
112112
- name: Install dependencies with optax nightly
113113
run: |
114114
python -m pip install --upgrade pip
115-
python -m pip install .[dev,tfds,grain]
115+
python -m pip install .[dev,tfds]
116116
python -m pip install --upgrade git+https://github.com/google-deepmind/optax.git
117117
- name: Run tests
118118
run: |
@@ -137,7 +137,7 @@ jobs:
137137
- name: Install dependencies with orbax-checkpoint and orbax-export nightly
138138
run: |
139139
python -m pip install --upgrade pip
140-
python -m pip install .[dev,tfds,grain]
140+
python -m pip install .[dev,tfds]
141141
python -m pip install --upgrade 'git+https://github.com/google/orbax/#subdirectory=checkpoint' 'git+https://github.com/google/orbax/#subdirectory=export'
142142
- name: Run tests
143143
run: |
@@ -162,7 +162,7 @@ jobs:
162162
- name: Install dependencies with chex nightly
163163
run: |
164164
python -m pip install --upgrade pip
165-
python -m pip install .[dev,tfds,grain]
165+
python -m pip install .[dev,tfds]
166166
python -m pip install --upgrade 'git+https://github.com/google-deepmind/chex/'
167167
- name: Run tests
168168
run: |

.github/workflows/test.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ jobs:
3737
- name: Install dependencies
3838
run: |
3939
python -m pip install --upgrade pip
40-
pip install .[dev,tfds,grain]
40+
pip install .[dev,tfds]
4141
- name: Run tests
4242
run: |
4343
pytest -n auto jax_ai_stack

README.md

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -44,20 +44,16 @@ together via the integration tests in this repository. Packages include:
4444
- [optax](https://github.com/google-deepmind/optax): gradient processing and optimization in JAX.
4545
- [orbax](https://github.com/google/orbax): checkpointing and persistence utilities for JAX.
4646
- [chex](https://github.com/google-deepmind/chex): utilities for writing reliable JAX code.
47+
- [grain](https://github.com/google/grain): data loading.
4748

4849
### Optional packages
4950

5051
Additionally, there are optional packages you can install with `pip` extras.
51-
The following command:
52-
```
53-
pip install jax-ai-stack[grain]
54-
```
55-
will install a compatible version of the [grain](https://github.com/google/grain) data
56-
loader (currently mac and linux-only).
5752

58-
Similarly, the following command:
53+
The following command:
5954
```
6055
pip install jax-ai-stack[tfds]
6156
```
62-
will install a compatible version of [tensorflow](https://github.com/tensorflow/tensorflow)
57+
will install a compatible version of
58+
[tensorflow](https://github.com/tensorflow/tensorflow)
6359
and [tensorflow-datasets](https://github.com/tensorflow/datasets).

docs/source/install.md

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,18 +14,12 @@ together via the integration tests in this repository. Packages include:
1414
- [ml_dtypes](http://github.com/jax-ml/ml_dtypes): NumPy dtype extensions for machine learning.
1515
- [optax](https://github.com/google-deepmind/optax): gradient processing and optimization in JAX.
1616
- [orbax](https://github.com/google/orbax): checkpointing and persistence utilities for JAX.
17+
- [chex](https://github.com/google-deepmind/chex): utilities for writing reliable JAX code.
18+
- [grain](https://github.com/google/grain): data loading.
1719

1820
# Optional packages
1921

20-
Additionally, there are optional packages you can install with `pip` extras.
21-
The following command:
22-
```
23-
pip install jax-ai-stack[grain]
24-
```
25-
will install a compatible version of the [grain](https://github.com/google/grain) data
26-
loader (currently linux-only).
27-
28-
Similarly, the following command:
22+
Additionally, there are optional packages you can install with `pip` extras. The following command:
2923
```
3024
pip install jax-ai-stack[tfds]
3125
```

pyproject.toml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ keywords = []
1616
# pip dependencies of the project
1717
dependencies = [
1818
"chex==0.1.89",
19+
"grain==0.2.8",
1920
"jax==0.6.0",
2021
"flax==0.10.6",
2122
"ml_dtypes==0.5.1",
@@ -43,10 +44,9 @@ tfds = [
4344
"tensorflow_datasets==4.9.8",
4445
]
4546

46-
# Grain is an extra because as of v0.2.0 it has no OSX wheels.
47-
grain = [
48-
"grain==0.2.8",
49-
]
47+
# Grain is now part of the default installation; keep this for
48+
# backward compatibility.
49+
grain = []
5050

5151
[tool.pyink]
5252
# Formatting configuration to follow Google style-guide

0 commit comments

Comments
 (0)