
Upgrade to PyTorch 2.3. #546


Merged
merged 18 commits on Mar 27, 2024
7 changes: 3 additions & 4 deletions .github/workflows/test.yml
@@ -39,12 +39,11 @@ jobs:
# Note: We install in three steps in order to satisfy requirements
# from non-default locations first. Installing the PyTorch CPU
# wheels saves multiple minutes and a lot of bandwidth on runner setup.
pip install --index-url https://download.pytorch.org/whl/cpu \
-r core/pytorch-cpu-requirements.txt \
-r core/torchvision-requirements.txt
pip install -r core/pytorch-cpu-requirements.txt
pip install --upgrade \
-r core/requirements.txt \
-r mypy-requirements.txt
-r mypy-requirements.txt \
-r serving/requirements.txt
pip install -e core[testing] -e serving[testing]

- name: Run core tests
10 changes: 4 additions & 6 deletions .github/workflows/test_models.yml
@@ -38,12 +38,10 @@ jobs:
# Note: We install in three steps in order to satisfy requirements
# from non-default locations first. Installing the PyTorch CPU
# wheels saves multiple minutes and a lot of bandwidth on runner setup.
pip install --index-url https://download.pytorch.org/whl/cpu \
-r core/pytorch-cpu-requirements.txt \
-r core/torchvision-requirements.txt
pip install --upgrade -r core/requirements.txt
pip install -e core[testing]
pip install -e models
pip install -r core/pytorch-cpu-requirements.txt
pip install --pre --upgrade -r core/requirements.txt
pip install --pre -e core[testing]
pip install --pre -e models

- name: Show current free memory
run: |
3 changes: 1 addition & 2 deletions .github/workflows/test_sdxl.yml
@@ -31,8 +31,7 @@ jobs:
# from non-default locations first. Installing the PyTorch CPU
# wheels saves multiple minutes and a lot of bandwidth on runner setup.
pip install --index-url https://download.pytorch.org/whl/cpu \
-r core/pytorch-cpu-requirements.txt \
-r core/torchvision-requirements.txt
-r core/pytorch-cpu-requirements.txt
pip install --upgrade -r core/requirements.txt
pip install -e core[testing,torch-cpu-nightly]
pip install --upgrade -r models/requirements.txt
1 change: 0 additions & 1 deletion MANIFEST.in
@@ -1,5 +1,4 @@
include README.md
include requirements.txt
include pytorch-cpu-requirements.txt
include torchvision-requirements.txt
include version_info.json
4 changes: 1 addition & 3 deletions README.md
@@ -45,9 +45,7 @@ pip install shark-turbine
The above installs some unnecessary CUDA/cuDNN packages for CPU use. To avoid this, you
can use the pytorch-cpu requirements file and install via:
```
pip install --index-url https://download.pytorch.org/whl/cpu \
-r core/pytorch-cpu-requirements.txt \
-r core/torchvision-requirements.txt
pip install -r core/pytorch-cpu-requirements.txt
pip install shark-turbine
```

7 changes: 6 additions & 1 deletion core/examples/aot_mlp/mlp_export_dynamic.py
@@ -49,7 +49,12 @@ def main(self, x=aot.AbstractTensor(None, 97, 8, dtype=torch.float32)):
)


exported = aot.export(CompiledMLP)
batch = torch.export.Dim("batch")
exported = aot.export(
model,
args=(torch.empty([2, 97, 8], dtype=torch.float32),),
dynamic_shapes={"x": {0: batch}},
)
# Note that dynamic Torch IR is created below.
exported.print_readable()

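For comparison, the new export call above maps onto the upstream `torch.export` API that PyTorch 2.3 stabilizes. A minimal sketch, assuming a stand-in `MLP` module with the same `(batch, 97, 8)` input shape as the example (the module body here is illustrative, not the example's actual network):

```python
import torch

class MLP(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(8, 8)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.layer(x)

# Dim("batch") declares a symbolic dimension; 97 and 8 stay static.
batch = torch.export.Dim("batch")
ep = torch.export.export(
    MLP(),
    args=(torch.empty([2, 97, 8], dtype=torch.float32),),
    dynamic_shapes={"x": {0: batch}},
)
print(ep)  # The printed graph carries a symbolic batch dimension.
```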
4 changes: 2 additions & 2 deletions core/iree-requirements.txt
@@ -1,2 +1,2 @@
iree-compiler==20240311.828
iree-runtime==20240311.828
iree-compiler==20240327.844
iree-runtime==20240327.844
4 changes: 2 additions & 2 deletions core/pytorch-cpu-requirements.txt
@@ -1,3 +1,3 @@
--pre
torch==2.1.0
mpmath==1.3.0
--index-url https://download.pytorch.org/whl/test/cpu
-r pytorch-requirements.txt
3 changes: 3 additions & 0 deletions core/pytorch-requirements.txt
@@ -0,0 +1,3 @@
torch==2.3.0
torchaudio
torchvision
3 changes: 1 addition & 2 deletions core/requirements.txt
@@ -4,6 +4,5 @@
# versions, not specific).
-f https://openxla.github.io/iree/pip-release-links.html

-r pytorch-cpu-requirements.txt
-r torchvision-requirements.txt
-r pytorch-requirements.txt
-r iree-requirements.txt
27 changes: 17 additions & 10 deletions core/shark_turbine/aot/builtins/jittable.py
@@ -9,18 +9,14 @@

from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple, Union

import warnings

import torch
from torch._decomp import get_decompositions
import torch._dynamo as dynamo
from torch.export import (
Constraint,
dynamic_dim,
)
from torch.fx import (
Graph,
GraphModule,
)
from torch.fx.passes.shape_prop import TensorMetadata
from torch.utils._pytree import (
tree_flatten,
tree_unflatten,
@@ -148,7 +144,7 @@ def __init__(
*,
decompose_ops: Optional[List[Any]] = None,
decomposition_table: Optional[Dict[Any, Callable[..., Any]]] = None,
constraints: Optional[List[Constraint]] = None,
constraints: Optional[List[Any]] = None,
function_name: Optional[str] = None,
passes: Sequence[str] = DEFAULT_PASSES,
):
Expand Down Expand Up @@ -176,7 +172,7 @@ def resolve_call(
self,
proc_trace: IrTrace,
*py_args,
constraints: Optional[List[Constraint]] = None,
constraints: Optional[List[Any]] = None,
**py_kwargs,
):
type_converter = proc_trace.module_builder.native_type_converter
@@ -188,6 +184,17 @@ def resolve_call(
if self.constraints is not None:
constraints.extend(self.constraints)

export_kwargs = {}
if len(constraints) > 0:
warnings.warn(
"Compiling program with the old PyTorch constraints system "
"for dynamic shapes is deprecated and will break on PyTorch "
"nightlies after the 2.3 release cut (expect either a PyTorch "
"warning or exception to follow)",
DeprecationWarning,
)
export_kwargs["constraints"] = constraints

# Convert procedural trace values to things that Dynamo can handle.
flat_py_args, args_tree = tree_flatten((py_args, py_kwargs))
flat_pytorch_args = []
@@ -220,8 +227,8 @@ def flat_wrapped_f(*args):
transformed_f,
aten_graph=True,
decomposition_table=self.decomposition_table,
constraints=constraints,
assume_static_by_default=True,
**export_kwargs, # type: ignore
)
logger.debug("Invoking dynamo trace")
gm, guards = exported_f(*flat_pytorch_args)
@@ -315,7 +322,7 @@ def flat_wrapped_f(*args):
tree_py_results = tree_unflatten(flat_py_results, out_spec)
return tree_py_results

def _split_py_arg(self, arg, constraints: List[Constraint]) -> Tuple[Value, Any]:
def _split_py_arg(self, arg, constraints: List[Any]) -> Tuple[Value, Any]:
if isinstance(arg, IrTensor):
meta_tensor, meta_constraints = arg._to_meta_tensor()
constraints.extend(meta_constraints)
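The deprecation shim above keeps old `constraints=` callers working for now, but the forward-looking path is `torch.export.Dim` with `dynamic_shapes`. A minimal migration sketch, assuming a trivial module (all names here are illustrative):

```python
import torch

class Scale(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * 2

x = torch.randn(4, 8)

# Pre-2.3 constraints path (now triggers the DeprecationWarning above):
#   from torch.export import dynamic_dim
#   torch._dynamo.export(Scale(), constraints=[dynamic_dim(x, 0)])(x)

# 2.3 path: declare symbolic dimensions up front instead.
batch = torch.export.Dim("batch")
ep = torch.export.export(Scale(), (x,), dynamic_shapes={"x": {0: batch}})
```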
61 changes: 59 additions & 2 deletions core/shark_turbine/aot/compiled_module.py
@@ -15,6 +15,8 @@
import weakref
import sys

from torch.export import ExportedProgram

from . import builtins

from ..support.ir_imports import (
@@ -35,6 +37,8 @@
current_ir_trace,
)

from .support.procedural.exported_program import import_exported_program

from .support.ir_utils import (
ModuleBuilder,
)
@@ -130,7 +134,28 @@ def __repr__(self):
return f"<def {self.export_name}({self.signature})>"


Exportable = Union[ExportProcDef, PyOnlyDef, GlobalsDef]
class ExportedProgramDef:
def __init__(
self,
ep: ExportedProgram,
*,
export_name: Optional[str] = None,
public: bool = False,
):
self.export_name = export_name
self.exported_program = ep
self.public = public

def copy(self) -> "ExportedProgramDef":
return ExportedProgramDef(
self.exported_program, export_name=self.export_name, public=self.public
)

def __repr__(self):
return f"<exported_program {self.exported_program}>"


Exportable = Union[ExportProcDef, ExportedProgramDef, PyOnlyDef, GlobalsDef]


class CompiledModuleClassInfo:
@@ -155,6 +180,15 @@ def export_procs(self) -> Generator[Tuple[str, ExportProcDef], None, None]:
self.all_exports.items(),
) # type: ignore

@property
def exported_programs(
self,
) -> Generator[Tuple[str, ExportedProgramDef], None, None]:
return filter(
lambda kv_tuple: isinstance(kv_tuple[1], ExportedProgramDef),
self.all_exports.items(),
) # type: ignore

@property
def py_only_defs(self) -> Generator[Tuple[str, PyOnlyDef], None, None]:
return filter(
@@ -175,6 +209,12 @@ def def_attribute(self, key, value):
if isinstance(value, builtins.jittable):
value = PyOnlyDef(value)

# Promote a torch ExportedProgram to an ExportedProgramDef.
if isinstance(value, ExportedProgram):
value = ExportedProgramDef(
value, export_name=key, public=not key.startswith("_")
)

# Detect our own descriptors.
if isinstance(value, GlobalsDef):
logging.debug("DEFINE GLOBALS: %s = %r", key, value)
@@ -186,11 +226,17 @@ def def_attribute(self, key, value):
value.export_name = key
self.add_export(key, value)
return value

if isinstance(value, PyOnlyDef):
logging.debug("DEFINE PY_ONLY: %s = %r", key, value)
self.add_export(key, value)
return value
if isinstance(value, ExportedProgramDef):
if value.export_name is None:
value = value.copy()
value.export_name = key
logging.debug("DEFINE EXPORTED_PROGRAM: %r", value.export_name)
self.add_export(key, value)
return value

# Infer if it is an exported function.
if callable(value) and inspect.isfunction(value):
@@ -542,6 +588,17 @@ def __new__(
for key, py_def in info.class_info.py_only_defs:
info.shadow_dict[key] = py_def.py_value

# Instantiate exported programs.
# TODO: This should be done in two phases along with export_procs
# in order to enable dependence.
for key, ep_def in info.class_info.exported_programs:
info.shadow_dict[key] = import_exported_program(
module_builder,
ep_def.exported_program,
symbol_name=ep_def.export_name or "main",
symbol_visibility=None if ep_def.public else "private",
)

# Instantiate procs.
# TODO: This should be done in two phases, first binding the symbols
# and then defining them, enabling dependence.
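Taken together, these hooks let a plain `torch.export.ExportedProgram` be assigned as a class attribute on a `CompiledModule` subclass: `def_attribute` promotes it to an `ExportedProgramDef`, and `__new__` imports it under the attribute's name. A hypothetical usage sketch, assuming `aot.export` still accepts a `CompiledModule` subclass as in the earlier example:

```python
import torch
from shark_turbine import aot

class AddOne(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + 1

ep = torch.export.export(AddOne(), (torch.randn(4),))

class MyModule(aot.CompiledModule):
    # Promoted to ExportedProgramDef(export_name="run", public=True);
    # a leading underscore in the name would make the symbol private.
    run = ep

exported = aot.export(MyModule)
exported.print_readable()
```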