Commit b73c5c3
Upgrade to PyTorch 2.3. (#546)
As discussed on Discord, this is a significant upgrade because it is the first stable release that has a fully functional `torch.export.export` with the preferred dynamic shapes support. It is also just prior to the nightlies that completely remove support for the old constraints-based API, so it is a good point to stop for a moment and support both styles.

This patch makes a number of API changes:

* Issues a deprecation warning if the `constraints=` keyword for `jittable` is used, and otherwise does not pass it to PyTorch. This should keep `jittable` from becoming immediately incompatible with later nightlies unless that feature is used.
* Adds the ability for a `CompiledModule` to directly have an attribute of a `torch.export.ExportedProgram`, allowing the user to pre-export with Torch and then construct a compiled module from that (vs the `jittable` approach, where the `CompiledModule` API was directly invoking Torch internals to do so). Such an attribute defaults to being exported as `public` if given a name not starting with an underscore and private otherwise. Private ExportedPrograms can be called from procedures just as with `jittable`.
* `shark_turbine.aot.export()` now accepts either a `CompiledModule`, an `nn.Module`, or a `torch.export.ExportedProgram`. For the last two, a new `external_params=` bool is available to control whether parameters are inlined or externalized. For an `nn.Module`, arguments corresponding to `torch.export.export` are added; internally it simply calls `torch.export.export`. `jittable` is no longer used internally.

Some attempt has been made to be backwards compatible with Torch 2.1.0. New features will not work, but we should be able to support a short buffer window where older pinned systems are not completely broken. The repository prior to this patch will be branched to `torch_2.1`.

Breaking changes:

* `ops.iree.trace_tensors` (plural) had to be removed because the PyTorch auto-functionalization machinery has a TODO around lists of tensors. We can add a wrapper that takes a list and invokes `trace_tensor` multiple times and/or add a `functional_trace_tensors` which works a bit better with the infra.
* `stateless_llama_test.py::test_rerotated_torch_comparison` is marked as expectedFailure. Filed #560.
1 parent b785714 commit b73c5c3
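For orientation, the following is a minimal sketch of the export paths described above. It assumes the `shark_turbine.aot` namespace from this repository; `ToyModel`, the example shapes, and the `main = prog` attribute spelling are hypothetical illustrations of the described behavior, not code from the patch.

```python
# Minimal sketch of the export paths described in this commit (assumptions
# noted above; not taken verbatim from the patch).
import torch
import shark_turbine.aot as aot


class ToyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(8, 4)

    def forward(self, x):
        return self.linear(x)


model = ToyModel()
example_args = (torch.empty(2, 8, dtype=torch.float32),)

# 1. Export an nn.Module directly. Per the commit text, this calls
#    torch.export.export internally; external_params= controls whether
#    parameters are inlined or externalized.
output = aot.export(model, args=example_args, external_params=False)

# 2. Pre-export with torch.export and hand the ExportedProgram to aot.export.
prog = torch.export.export(model, example_args)
output_from_prog = aot.export(prog)

# 3. Attach an ExportedProgram as an attribute of a CompiledModule; per the
#    commit text, a name not starting with an underscore is exported as
#    public. (The exact attribute spelling here is illustrative.)
class ToyCompiledModule(aot.CompiledModule):
    main = prog


output_from_module = aot.export(ToyCompiledModule)
```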

28 files changed (+669 / -201 lines)

.github/workflows/test.yml

Lines changed: 3 additions & 4 deletions
@@ -39,12 +39,11 @@ jobs:
   # Note: We install in three steps in order to satisfy requirements
   # from non default locations first. Installing the PyTorch CPU
   # wheels saves multiple minutes and a lot of bandwidth on runner setup.
-  pip install --index-url https://download.pytorch.org/whl/cpu \
-    -r core/pytorch-cpu-requirements.txt \
-    -r core/torchvision-requirements.txt
+  pip install -r core/pytorch-cpu-requirements.txt
   pip install --upgrade \
     -r core/requirements.txt \
-    -r mypy-requirements.txt
+    -r mypy-requirements.txt \
+    -r serving/requirements.txt
   pip install -e core[testing] -e serving[testing]

 - name: Run core tests

.github/workflows/test_models.yml

Lines changed: 4 additions & 6 deletions
@@ -38,12 +38,10 @@ jobs:
   # Note: We install in three steps in order to satisfy requirements
   # from non default locations first. Installing the PyTorch CPU
   # wheels saves multiple minutes and a lot of bandwidth on runner setup.
-  pip install --index-url https://download.pytorch.org/whl/cpu \
-    -r core/pytorch-cpu-requirements.txt \
-    -r core/torchvision-requirements.txt
-  pip install --upgrade -r core/requirements.txt
-  pip install -e core[testing]
-  pip install -e models
+  pip install -r core/pytorch-cpu-requirements.txt
+  pip install --pre --upgrade -r core/requirements.txt
+  pip install --pre -e core[testing]
+  pip install --pre -e models

 - name: Show current free memory
   run: |

.github/workflows/test_sdxl.yml

Lines changed: 1 addition & 2 deletions
@@ -31,8 +31,7 @@ jobs:
   # from non default locations first. Installing the PyTorch CPU
   # wheels saves multiple minutes and a lot of bandwidth on runner setup.
   pip install --index-url https://download.pytorch.org/whl/cpu \
-    -r core/pytorch-cpu-requirements.txt \
-    -r core/torchvision-requirements.txt
+    -r core/pytorch-cpu-requirements.txt
   pip install --upgrade -r core/requirements.txt
   pip install -e core[testing,torch-cpu-nightly]
   pip install --upgrade -r models/requirements.txt

MANIFEST.in

Lines changed: 0 additions & 1 deletion
@@ -1,5 +1,4 @@
 include README.md
 include requirements.txt
 include pytorch-cpu-requirements.txt
-include torchvision-requirements.txt
 include version_info.json

README.md

Lines changed: 1 addition & 3 deletions
@@ -45,9 +45,7 @@ pip install shark-turbine
 The above does install some unecessary cuda/cudnn packages for cpu use. To avoid this you
 can specify pytorch-cpu and install via:
 ```
-pip install --index-url https://download.pytorch.org/whl/cpu \
-  -r core/pytorch-cpu-requirements.txt \
-  -r core/torchvision-requirements.txt
+pip install -r core/pytorch-cpu-requirements.txt
 pip install shark-turbine
 ```

core/examples/aot_mlp/mlp_export_dynamic.py

Lines changed: 6 additions & 1 deletion
@@ -49,7 +49,12 @@ def main(self, x=aot.AbstractTensor(None, 97, 8, dtype=torch.float32)):
        )


-exported = aot.export(CompiledMLP)
+batch = torch.export.Dim("batch")
+exported = aot.export(
+    model,
+    args=(torch.empty([2, 97, 8], dtype=torch.float32),),
+    dynamic_shapes={"x": {0: batch}},
+)
 # Note that dynamic Torch IR is created below.
 exported.print_readable()
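As a cross-check of the commit note that `aot.export` on an `nn.Module` simply calls `torch.export.export`, here is a standalone sketch of the equivalent pure-Torch export with a dynamic batch dimension. `ToyMLP` is a stand-in model defined only for this illustration; the real example's model lives in `mlp_export_dynamic.py`.

```python
# Standalone torch.export equivalent of the dynamic-batch export shown in the
# diff above. ToyMLP is an illustrative stand-in, not the example's real model.
import torch


class ToyMLP(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(8, 8)

    def forward(self, x):
        return self.layer(x)


batch = torch.export.Dim("batch")
prog = torch.export.export(
    ToyMLP(),
    args=(torch.empty([2, 97, 8], dtype=torch.float32),),
    dynamic_shapes={"x": {0: batch}},
)
# The resulting ExportedProgram carries a symbolic batch dimension.
print(prog)
```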

core/iree-requirements.txt

Lines changed: 2 additions & 2 deletions
@@ -1,2 +1,2 @@
-iree-compiler==20240311.828
-iree-runtime==20240311.828
+iree-compiler==20240327.844
+iree-runtime==20240327.844

core/pytorch-cpu-requirements.txt

Lines changed: 2 additions & 2 deletions
@@ -1,3 +1,3 @@
 --pre
-torch==2.1.0
-mpmath==1.3.0
+--index-url https://download.pytorch.org/whl/test/cpu
+-r pytorch-requirements.txt

core/pytorch-requirements.txt

Lines changed: 3 additions & 0 deletions
@@ -0,0 +1,3 @@
+torch==2.3.0
+torchaudio
+torchvision

core/requirements.txt

Lines changed: 1 addition & 2 deletions
@@ -4,6 +4,5 @@
 # versions, not specific).
 -f https://openxla.github.io/iree/pip-release-links.html

--r pytorch-cpu-requirements.txt
--r torchvision-requirements.txt
+-r pytorch-requirements.txt
 -r iree-requirements.txt
