From c12057f981bec474e0d19a995499b1c43d67929e Mon Sep 17 00:00:00 2001 From: Stella Laurenzo Date: Thu, 21 Mar 2024 13:55:33 -0700 Subject: [PATCH 01/18] Update requirements to torch 2.3.0 --- README.md | 4 +--- core/pytorch-cpu-requirements.txt | 6 ++++-- core/requirements.txt | 1 - core/torchvision-requirements.txt | 2 -- 4 files changed, 5 insertions(+), 8 deletions(-) delete mode 100644 core/torchvision-requirements.txt diff --git a/README.md b/README.md index ef982d26a..555cdaee9 100644 --- a/README.md +++ b/README.md @@ -45,9 +45,7 @@ pip install shark-turbine The above does install some unecessary cuda/cudnn packages for cpu use. To avoid this you can specify pytorch-cpu and install via: ``` -pip install --index-url https://download.pytorch.org/whl/cpu \ - -r core/pytorch-cpu-requirements.txt \ - -r core/torchvision-requirements.txt +pip install -r core/pytorch-cpu-requirements.txt pip install shark-turbine ``` diff --git a/core/pytorch-cpu-requirements.txt b/core/pytorch-cpu-requirements.txt index 92e78464b..c7f8e8ce0 100644 --- a/core/pytorch-cpu-requirements.txt +++ b/core/pytorch-cpu-requirements.txt @@ -1,3 +1,5 @@ --pre -torch==2.1.0 -mpmath==1.3.0 +--index-url https://download.pytorch.org/whl/test/cpu +torch==2.3.0 +torchaudio +torchvision diff --git a/core/requirements.txt b/core/requirements.txt index 128012cb7..09ee8e38e 100644 --- a/core/requirements.txt +++ b/core/requirements.txt @@ -5,5 +5,4 @@ -f https://openxla.github.io/iree/pip-release-links.html -r pytorch-cpu-requirements.txt --r torchvision-requirements.txt -r iree-requirements.txt diff --git a/core/torchvision-requirements.txt b/core/torchvision-requirements.txt deleted file mode 100644 index e38d8d008..000000000 --- a/core/torchvision-requirements.txt +++ /dev/null @@ -1,2 +0,0 @@ ---pre -torchvision From 286e3c3c08d5df423e3fc770df6d9de3dd4c1b8d Mon Sep 17 00:00:00 2001 From: Stella Laurenzo Date: Thu, 21 Mar 2024 19:24:45 -0700 Subject: [PATCH 02/18] Adapt ExportedProgram to CompiledModule. --- core/pytorch-cpu-requirements.txt | 2 +- core/shark_turbine/aot/compiled_module.py | 61 +++++- .../support/procedural/exported_program.py | 188 ++++++++++++++++++ core/shark_turbine/dynamo/passes.py | 1 - core/shark_turbine/dynamo/utils.py | 99 --------- core/tests/aot/torch_export_test.py | 56 ++++++ 6 files changed, 304 insertions(+), 103 deletions(-) create mode 100644 core/shark_turbine/aot/support/procedural/exported_program.py delete mode 100644 core/shark_turbine/dynamo/utils.py create mode 100644 core/tests/aot/torch_export_test.py diff --git a/core/pytorch-cpu-requirements.txt b/core/pytorch-cpu-requirements.txt index c7f8e8ce0..20aa5da87 100644 --- a/core/pytorch-cpu-requirements.txt +++ b/core/pytorch-cpu-requirements.txt @@ -1,5 +1,5 @@ --pre --index-url https://download.pytorch.org/whl/test/cpu -torch==2.3.0 +torch>=2.3.0 torchaudio torchvision diff --git a/core/shark_turbine/aot/compiled_module.py b/core/shark_turbine/aot/compiled_module.py index f9b01e255..5653356b9 100644 --- a/core/shark_turbine/aot/compiled_module.py +++ b/core/shark_turbine/aot/compiled_module.py @@ -15,6 +15,8 @@ import weakref import sys +from torch.export import ExportedProgram + from . import builtins from ..support.ir_imports import ( @@ -35,6 +37,8 @@ current_ir_trace, ) +from .support.procedural.exported_program import import_exported_program + from .support.ir_utils import ( ModuleBuilder, ) @@ -130,7 +134,28 @@ def __repr__(self): return f"" -Exportable = Union[ExportProcDef, PyOnlyDef, GlobalsDef] +class ExportedProgramDef: + def __init__( + self, + ep: ExportedProgram, + *, + export_name: Optional[str] = None, + public: bool = False, + ): + self.export_name = export_name + self.exported_program = ep + self.public = public + + def copy(self) -> "ExportedProgramDef": + return ExportedProgramDef( + self.exported_program, export_name=self.export_name, public=self.public + ) + + def __repr__(self): + return f"" + + +Exportable = Union[ExportProcDef, ExportedProgramDef, PyOnlyDef, GlobalsDef] class CompiledModuleClassInfo: @@ -155,6 +180,15 @@ def export_procs(self) -> Generator[Tuple[str, ExportProcDef], None, None]: self.all_exports.items(), ) # type: ignore + @property + def exported_programs( + self, + ) -> Generator[Tuple[str, ExportedProgramDef], None, None]: + return filter( + lambda kv_tuple: isinstance(kv_tuple[1], ExportedProgramDef), + self.all_exports.items(), + ) # type: ignore + @property def py_only_defs(self) -> Generator[Tuple[str, PyOnlyDef], None, None]: return filter( @@ -175,6 +209,12 @@ def def_attribute(self, key, value): if isinstance(value, builtins.jittable): value = PyOnlyDef(value) + # Promote a torch ExportedProgram to an ExportedProgramDef. + if isinstance(value, ExportedProgram): + value = ExportedProgramDef( + value, export_name=key, public=not key.startswith("_") + ) + # Detect our own descriptors. if isinstance(value, GlobalsDef): logging.debug("DEFINE GLOBALS: %s = %r", key, value) @@ -186,11 +226,17 @@ def def_attribute(self, key, value): value.export_name = key self.add_export(key, value) return value - if isinstance(value, PyOnlyDef): logging.debug("DEFINE PY_ONLY: %s = %r", key, value) self.add_export(key, value) return value + if isinstance(value, ExportedProgramDef): + if value.export_name is None: + value = value.copy() + value.export_name = key + logging.debug("DEFINE EXPORTED_PROGRAM: %r", value.export_name) + self.add_export(key, value) + return value # Infer if it is an exported function. if callable(value) and inspect.isfunction(value): @@ -542,6 +588,17 @@ def __new__( for key, py_def in info.class_info.py_only_defs: info.shadow_dict[key] = py_def.py_value + # Instantiate exported programs. + # TODO: This should be done in two phases along with export_procs + # in order to enable dependence. + for key, ep_def in info.class_info.exported_programs: + info.shadow_dict[key] = import_exported_program( + module_builder, + ep_def.exported_program, + symbol_name=ep_def.export_name, + symbol_visibility="public" if ep_def.public else "private", + ) + # Instantiate procs. # TODO: This should be done in two phases, first binding the symbols # and then defining them, enabling dependence. diff --git a/core/shark_turbine/aot/support/procedural/exported_program.py b/core/shark_turbine/aot/support/procedural/exported_program.py new file mode 100644 index 000000000..4ec1081cb --- /dev/null +++ b/core/shark_turbine/aot/support/procedural/exported_program.py @@ -0,0 +1,188 @@ +# Copyright 2024 Advanced Micro Devices, Inc +# Portions Copyright 2022 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from typing import Dict, List, Optional + +import torch + +from torch.utils._pytree import ( + tree_flatten, + tree_unflatten, + treespec_pprint, +) + +from iree.compiler.extras.fx_importer import ( + FxImporter, +) + +from ....support.ir_imports import ( + func_d, + FlatSymbolRefAttr, + FunctionType, + IrType, + Operation, + StringAttr, + TypeAttr, + Value, +) + +from ..ir_utils import ( + ModuleBuilder, +) + +from .base import ( + CallableIntrinsic, +) + +from .primitives import ( + IrImmediateTensor, + IrTensor, +) + +from .tracer import ( + IrTrace, +) + + +class ExportedProgramIntrinsic(CallableIntrinsic): + def __init__( + self, + entry_func_op: Operation, + entry_sig: torch.export.ModuleCallSignature, + user_output_dtypes: List[Optional[torch.dtype]], + ): + self.entry_func_op = entry_func_op + self.entry_sig = entry_sig + self.user_output_dtypes = user_output_dtypes + + @property + def function_type(self) -> FunctionType: + return TypeAttr(self.entry_func_op.attributes["function_type"]).value + + @property + def function_symbol(self) -> StringAttr: + return StringAttr(self.entry_func_op.attributes["sym_name"]) + + @property + def function_visibility(self) -> StringAttr: + return StringAttr(self.entry_func_op.attributes["sym_visibility"]) + + def resolve_call( + self, + proc_trace: IrTrace, + *py_args, + **py_kwargs, + ): + visibility = self.function_visibility + if visibility.value != "private": + raise ValueError( + f"Currently, only private ExportedPrograms can be called: " + f"{self.function_symbol} is {visibility}" + ) + + # Flatten and convert py args to torch IR values. + flat_py_args, args_tree = tree_flatten(((list(py_args),), py_kwargs)) + if args_tree != self.entry_sig.in_spec: + print(dir(args_tree)) + raise ValueError( + f"Mismatched arguments to exported program. \n" + f" Got: {treespec_pprint(args_tree)}\n" + f" Expected: {treespec_pprint(self.entry_sig.in_spec)} " + ) + function_type = self.function_type + flat_ir_args = [ + self._py_to_torch_ir(proc_trace, py_arg, torch_type) + for py_arg, torch_type in zip(flat_py_args, function_type.inputs) + ] + + # Call. + with proc_trace.ip, proc_trace.loc: + flat_ir_results = func_d.CallOp( + function_type.results, + FlatSymbolRefAttr.get(self.function_symbol.value), + flat_ir_args, + ).results + + # Convert torch IR values to python. + flat_py_results = [ + self._torch_ir_to_py(proc_trace, ir_value, dtype) + for ir_value, dtype in zip(flat_ir_results, self.user_output_dtypes) + ] + + return tree_unflatten(flat_py_results, self.entry_sig.out_spec) + + def _py_to_torch_ir( + self, proc_trace: IrTrace, py_value, torch_type: IrType + ) -> Value: + type_converter = proc_trace.module_builder.native_type_converter + if isinstance(py_value, IrTensor): + # TODO: Allow certain static info casts. + return type_converter.materialize_native_to_torch( + py_value.ir_value, torch_type + ) + else: + raise ValueError( + f"Unsupported type in arguments of call to ExportedProgram: " + f"{type(py_value)}: {py_value}" + ) + + def _torch_ir_to_py( + self, proc_trace: IrTrace, ir_value: Value, dtype: Optional[torch.dtype] + ): + type_converter = proc_trace.module_builder.native_type_converter + native_ir_value = type_converter.materialize_torch_to_native(ir_value) + if dtype is not None: + return IrImmediateTensor(native_ir_value, dtype) + else: + raise TypeError( + f"Unknown PyTorch->IREE value mapping for ExportedProgram output: " + f"{native_ir_value}" + ) + + +def import_exported_program( + module_builder: ModuleBuilder, + exported_program: torch.export.ExportedProgram, + symbol_name: str, + symbol_visibility: str, +) -> ExportedProgramIntrinsic: + fx_importer = FxImporter( + module_op=module_builder.module_op, + config_check=False, + py_attr_tracker=module_builder.fx_py_attr_tracker, + ) + entry_func_op = fx_importer.import_program( + exported_program, func_name=symbol_name, func_visibility=symbol_visibility + ) + + module_call_graph = exported_program.module_call_graph + assert len(module_call_graph) >= 1, "Expected at least one module call signature" + entry_module_call_entry = module_call_graph[0] + assert ( + entry_module_call_entry.fqn == "" + ), "Expected first module call entry to be unnamed" + + # We want additional torch-level metadata about any user outputs. + # This will help us create a true python fake without loss of information. + user_output_dtypes: list[Optional[torch.dtype]] = [] + node_map: Dict[str, torch.fx.Node] = { + n.name: n for n in exported_program.graph.nodes + } + for user_output in exported_program.graph_signature.user_outputs: + output_node = node_map[user_output] + tensor_meta = output_node.meta.get("tensor_meta") + fake_val = output_node.meta.get("val") + dtype = None + if tensor_meta is not None: + dtype = tensor_meta.dtype + elif fake_val is not None: + dtype = fake_val.dtype + user_output_dtypes.append(dtype) + + return ExportedProgramIntrinsic( + entry_func_op, entry_module_call_entry.signature, user_output_dtypes + ) diff --git a/core/shark_turbine/dynamo/passes.py b/core/shark_turbine/dynamo/passes.py index 5a9a7d16b..18220910f 100644 --- a/core/shark_turbine/dynamo/passes.py +++ b/core/shark_turbine/dynamo/passes.py @@ -1,7 +1,6 @@ import torch from torch.fx.experimental.proxy_tensor import make_fx from torch._decomp import get_decompositions -from shark_turbine.dynamo import utils from torch.func import functionalize from typing import List, Optional diff --git a/core/shark_turbine/dynamo/utils.py b/core/shark_turbine/dynamo/utils.py deleted file mode 100644 index 05035e803..000000000 --- a/core/shark_turbine/dynamo/utils.py +++ /dev/null @@ -1,99 +0,0 @@ -import torch -from torch._prims_common.wrappers import out_wrapper -from torch._prims_common import ( - DeviceLikeType, - TensorLikeType, -) -import torch._refs as _refs -from torch._decomp import get_decompositions, register_decomposition -from torch import Tensor -from typing import Dict, List, Tuple, Optional - - -if torch.__version__ < "2.2.0": - # Torch versions prior to 2.2.0 lacked some decompositions, which we - # add manually. - @register_decomposition(torch.ops.aten._scaled_dot_product_flash_attention.default) - def scaled_dot_product_flash_attention( - query, - key, - value, - dropout_p: float = 0.0, - is_causal: bool = False, - return_debug_mask: bool = False, - *, - scale: Optional[float] = None, - ) -> Tuple[Tensor, Tensor, Tensor, Tensor, int, int, Tensor, Tensor, Tensor]: - dtype = query.dtype - batchSize, num_head, qSize, headSize = ( - query.shape[0], - query.shape[1], - query.shape[2], - query.shape[3], - ) - - logsumexp = torch.empty( - [batchSize, qSize, num_head, headSize], dtype=torch.float - ) - cum_seq_q, cum_seq_k = torch.empty([], dtype=torch.long), torch.empty( - [], dtype=torch.long - ) - max_q, max_k = 0, 0 - philox_seed, philox_offset = torch.empty([], dtype=torch.long), torch.empty( - [], dtype=torch.long - ) - debug_attn_mask = torch.empty( - [], - dtype=query.dtype, - device="cpu", - requires_grad=query.requires_grad, - ) - output, _ = torch.ops.aten._scaled_dot_product_attention_math.default( - query, key, value, None, dropout_p, is_causal, None, scale=scale - ) - output = output.transpose(1, 2).contiguous( - memory_format=torch.contiguous_format - ) - return ( - output.transpose(1, 2), - logsumexp, - cum_seq_q, - cum_seq_k, - max_q, - max_k, - philox_seed, - philox_offset, - debug_attn_mask, - ) - - -# manually add decomposition to bypass the error that comes -# from VAE encode(inp).latent_dist.sample() failing to symbolically -# trace from torch fx. -# Expected Torch stable version: > 2.1.0 -# diffusers side issue: https://github.com/huggingface/diffusers/issues/6239 -# temporary Torch fix: https://github.com/pytorch/pytorch/issues/107170 -@register_decomposition(torch.ops.aten.randn.generator) -@out_wrapper() -def randn_generator( - *shape, - generator: Optional[torch.Generator] = None, - dtype: Optional[torch.dtype] = None, - device: Optional[DeviceLikeType] = None, - layout: Optional[torch.layout] = None, - requires_grad: bool = False, - pin_memory: bool = False, -) -> TensorLikeType: - # We should eventually support the generator overload. - # However, if someone passes in a None generator explicitly, - # we can jut fall back to randn.default - if generator is None: - return _refs.randn( - *shape, - dtype=dtype, - device=device, - layout=layout, - requires_grad=requires_grad, - pin_memory=pin_memory, - ) - return NotImplemented diff --git a/core/tests/aot/torch_export_test.py b/core/tests/aot/torch_export_test.py new file mode 100644 index 000000000..bfaa5b6a8 --- /dev/null +++ b/core/tests/aot/torch_export_test.py @@ -0,0 +1,56 @@ +# Copyright 2023 Nod Labs, Inc +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import logging +import unittest + +import torch + +from iree.compiler.ir import ( + Context, +) + +from shark_turbine.aot import * +from shark_turbine.aot.builtins import * + + +class TorchExportTests(unittest.TestCase): + def testImportPhases(self): + class MyModule(torch.nn.Module): + def forward(self): + ... + + fxb = FxProgramsBuilder(MyModule()) + + @fxb.export_program( + args=([torch.empty([3, 2]), torch.empty([1, 2])],), + kwargs={"foobar": torch.empty([3, 1])}, + ) + def compute(module, inputs, *, foobar): + t1 = inputs[0] + t2 = inputs[1] + t3 = t1 + t2 + foobar + return [t3 * t3, foobar] + + class ExportedProcModule(CompiledModule): + _compute = compute + + def foobar( + self, + t1=AbstractTensor(3, 2), + t2=AbstractTensor(1, 2), + t3=AbstractTensor(3, 1), + ): + return self._compute(t1, t2, foobar=t3) + + inst = ExportedProcModule(context=Context(), import_to="import") + module_str = str(CompiledModule.get_mlir_module(inst)) + print(module_str) + + +if __name__ == "__main__": + logging.basicConfig(level=logging.DEBUG) + unittest.main() From 9cd675f89a6cfee9309e27dbdc30235f5a946f70 Mon Sep 17 00:00:00 2001 From: Stella Laurenzo Date: Mon, 25 Mar 2024 18:22:04 -0700 Subject: [PATCH 03/18] Adapt all APIs. --- core/examples/aot_mlp/mlp_export_dynamic.py | 7 +- core/shark_turbine/aot/builtins/jittable.py | 27 ++-- core/shark_turbine/aot/compiled_module.py | 2 +- core/shark_turbine/aot/exporter.py | 96 ++++++++---- .../support/procedural/exported_program.py | 97 +++++++++++- core/tests/aot/api_test.py | 50 +++++++ .../aot/compiled_exported_program_test.py | 138 ++++++++++++++++++ core/tests/aot/jittable_test.py | 4 +- core/tests/aot/torch_export_test.py | 56 ------- 9 files changed, 370 insertions(+), 107 deletions(-) create mode 100644 core/tests/aot/compiled_exported_program_test.py delete mode 100644 core/tests/aot/torch_export_test.py diff --git a/core/examples/aot_mlp/mlp_export_dynamic.py b/core/examples/aot_mlp/mlp_export_dynamic.py index 66ca38554..cd8636554 100644 --- a/core/examples/aot_mlp/mlp_export_dynamic.py +++ b/core/examples/aot_mlp/mlp_export_dynamic.py @@ -49,7 +49,12 @@ def main(self, x=aot.AbstractTensor(None, 97, 8, dtype=torch.float32)): ) -exported = aot.export(CompiledMLP) +batch = torch.export.Dim("batch") +exported = aot.export( + model, + args=(torch.empty([2, 97, 8], dtype=torch.float32),), + dynamic_shapes={"x": {0: batch}}, +) # Note that dynamic Torch IR is created below. exported.print_readable() diff --git a/core/shark_turbine/aot/builtins/jittable.py b/core/shark_turbine/aot/builtins/jittable.py index 12942b22c..380339acd 100644 --- a/core/shark_turbine/aot/builtins/jittable.py +++ b/core/shark_turbine/aot/builtins/jittable.py @@ -9,18 +9,14 @@ from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple, Union +import warnings + import torch from torch._decomp import get_decompositions import torch._dynamo as dynamo -from torch.export import ( - Constraint, - dynamic_dim, -) from torch.fx import ( - Graph, GraphModule, ) -from torch.fx.passes.shape_prop import TensorMetadata from torch.utils._pytree import ( tree_flatten, tree_unflatten, @@ -148,7 +144,7 @@ def __init__( *, decompose_ops: Optional[List[Any]] = None, decomposition_table: Optional[Dict[Any, Callable[..., Any]]] = None, - constraints: Optional[List[Constraint]] = None, + constraints: Optional[List[Any]] = None, function_name: Optional[str] = None, passes: Sequence[str] = DEFAULT_PASSES, ): @@ -176,7 +172,7 @@ def resolve_call( self, proc_trace: IrTrace, *py_args, - constraints: Optional[List[Constraint]] = None, + constraints: Optional[List[Any]] = None, **py_kwargs, ): type_converter = proc_trace.module_builder.native_type_converter @@ -188,6 +184,17 @@ def resolve_call( if self.constraints is not None: constraints.extend(self.constraints) + export_kwargs = {} + if len(constraints) > 0: + warnings.warn( + "Compiling program with the old PyTorch constraints system " + "for dynamic shapes is deprecated and will break on PyTorch " + "nightlies after the 2.3 release cut (expect either a PyTorch " + "warning or excpetion to follow)", + DeprecationWarning, + ) + export_kwargs["constraints"] = constraints + # Convert procedural trace values to things that Dynamo can handle. flat_py_args, args_tree = tree_flatten((py_args, py_kwargs)) flat_pytorch_args = [] @@ -220,8 +227,8 @@ def flat_wrapped_f(*args): transformed_f, aten_graph=True, decomposition_table=self.decomposition_table, - constraints=constraints, assume_static_by_default=True, + **export_kwargs, ) logger.debug("Invoking dynamo trace") gm, guards = exported_f(*flat_pytorch_args) @@ -315,7 +322,7 @@ def flat_wrapped_f(*args): tree_py_results = tree_unflatten(flat_py_results, out_spec) return tree_py_results - def _split_py_arg(self, arg, constraints: List[Constraint]) -> Tuple[Value, Any]: + def _split_py_arg(self, arg, constraints: List[Any]) -> Tuple[Value, Any]: if isinstance(arg, IrTensor): meta_tensor, meta_constraints = arg._to_meta_tensor() constraints.extend(meta_constraints) diff --git a/core/shark_turbine/aot/compiled_module.py b/core/shark_turbine/aot/compiled_module.py index 5653356b9..72520a0a6 100644 --- a/core/shark_turbine/aot/compiled_module.py +++ b/core/shark_turbine/aot/compiled_module.py @@ -596,7 +596,7 @@ def __new__( module_builder, ep_def.exported_program, symbol_name=ep_def.export_name, - symbol_visibility="public" if ep_def.public else "private", + symbol_visibility=None if ep_def.public else "private", ) # Instantiate procs. diff --git a/core/shark_turbine/aot/exporter.py b/core/shark_turbine/aot/exporter.py index 2bb746df2..8e8025126 100644 --- a/core/shark_turbine/aot/exporter.py +++ b/core/shark_turbine/aot/exporter.py @@ -4,8 +4,7 @@ # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -from typing import Any, Optional, Sequence, Union -import functools +from typing import Any, Dict, List, Optional, Sequence, Tuple, Union import io from pathlib import Path import platform @@ -27,18 +26,14 @@ from .compiled_module import ( CompiledModule, CompiledModuleMeta, - ExportProcDef, ImportPhase, ) -from .support.procedural import ( - AbstractTypedef, -) _is_windows = platform.system() == "Windows" -ModuleLike = Union[torch.nn.Module, CompiledModuleMeta] +ModuleLike = Union[torch.nn.Module, CompiledModuleMeta, torch.export.ExportedProgram] SaveableTarget = Union[str, Path, None, Output] @@ -150,48 +145,89 @@ def compile( return None -# Decorator which explicitly exports a function. -# TODO: Make this a public API on CompiledModule. -# See https://github.com/nod-ai/SHARK-Turbine/issues/126 -def export_proc(f=None, *, signature: Sequence[AbstractTypedef]) -> Any: - if f is None: - return functools.partial(export_proc, signature=signature) - return ExportProcDef(f.__name__, f, signature=signature) - +def export( + mdl: ModuleLike, + *example_args: torch.Tensor, + args: Optional[tuple] = None, + kwargs: Optional[Dict[str, Any]] = None, + dynamic_shapes: Dict[str, Any] | Tuple[Any] | List[Any] | None = None, + external_params: bool = False, +) -> ExportOutput: + """One shot export of an nn.Module or CompiledModule. -def export(mdl: ModuleLike, *example_args: torch.Tensor) -> ExportOutput: - """One shot export of an nn.Module. + This function behaves differently based on the type of the `mdl` argument: - This is a very restrictive API vs the lower level `CompiledModule` - facility. It is suitable for one-shot modules, with a single - entrypoint and static example arguments where no additional - configuration is needed for mutable parameters/buffers or state - management. Dynamic shape constraints are also not presently - exposed via this API, but we expect to allow this in the future. + * nn.Module: The module is traced with torch.export.export passing it + `args`, `kwargs`, and `dynamic_shapes`. + * CompiledModule: The module is imported to IR. Additional arguments are + illegal in this case. + * torch.export.ExportedProgram: A pre-exported program can be passed and + it will be used to construct a single-entrypoint module. Args: mdl: The nn.Module to export. *example_args: Example tensors. + args: Example arguments to torch.export (if present, then *example_args + must be empty. + kwargs: Example keyword arguments. + dynamic_shapes: Dynamic shape specs to pass to torch.export. + external_params: Whether to declare parameters as external vs inlining + contents. Returns: An ExportOutput object that wraps the compilation and provides easy access. """ TransformedModule: Any - if isinstance(mdl, torch.nn.Module): + if isinstance(mdl, torch.export.ExportedProgram): + if ( + len(example_args) > 0 + or args is not None + or kwargs is not None + or dynamic_shapes is not None + ): + raise ValueError( + "If passing an ExportedProgram to aot.export, cannot also pass " + "args, example_args, kwargs, or dynamic_dims" + ) + + class Exported(CompiledModule, export_name=mdl.graph_module._get_name()): + params = export_global_tree( + dict(list(mdl.named_parameters())), external=external_params + ) + main = mdl + + TransformedModule = Exported + elif isinstance(mdl, torch.nn.Module): + # Normalize arguments for torch.export. + if args is None: + args = example_args + elif len(example_args) > 0: + raise ValueError( + "Cannot pass args= and positional example_args at the same time" + ) nn_module = mdl - signature = [abstractify(t) for t in example_args] + exported_program = torch.export.export( + nn_module, args=args, kwargs=kwargs, dynamic_shapes=dynamic_shapes + ) class Exported(CompiledModule, export_name=nn_module._get_name()): - params = export_parameters(nn_module) - - @export_proc(signature=signature) - def main(self, *args): - return jittable(nn_module.forward)(*args) + params = export_parameters(nn_module, external=external_params) + main = exported_program TransformedModule = Exported else: assert isinstance(mdl, CompiledModuleMeta) + if ( + len(example_args) > 0 + or args is not None + or kwargs is not None + or dynamic_shapes is not None + ): + raise ValueError( + "If passing a CompiledModule to aot.export, cannot also pass " + "args, example_args, kwargs, or dynamic_dims" + ) TransformedModule = mdl session = Session() diff --git a/core/shark_turbine/aot/support/procedural/exported_program.py b/core/shark_turbine/aot/support/procedural/exported_program.py index 4ec1081cb..12e325fd6 100644 --- a/core/shark_turbine/aot/support/procedural/exported_program.py +++ b/core/shark_turbine/aot/support/procedural/exported_program.py @@ -5,7 +5,9 @@ # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -from typing import Dict, List, Optional +from typing import Any, Dict, List, Optional + +import inspect import torch @@ -17,10 +19,15 @@ from iree.compiler.extras.fx_importer import ( FxImporter, + FxImporterHooks, + GraphNodeImporter, ) +from ....support.logging import aot_logger as logger + from ....support.ir_imports import ( func_d, + util_d, FlatSymbolRefAttr, FunctionType, IrType, @@ -148,13 +155,9 @@ def import_exported_program( module_builder: ModuleBuilder, exported_program: torch.export.ExportedProgram, symbol_name: str, - symbol_visibility: str, + symbol_visibility: Optional[str], ) -> ExportedProgramIntrinsic: - fx_importer = FxImporter( - module_op=module_builder.module_op, - config_check=False, - py_attr_tracker=module_builder.fx_py_attr_tracker, - ) + fx_importer = _create_fx_importer(module_builder) entry_func_op = fx_importer.import_program( exported_program, func_name=symbol_name, func_visibility=symbol_visibility ) @@ -186,3 +189,83 @@ def import_exported_program( return ExportedProgramIntrinsic( entry_func_op, entry_module_call_entry.signature, user_output_dtypes ) + + +class _Hooks(FxImporterHooks): + def __init__(self, module_builder: ModuleBuilder): + self.module_builder = module_builder + + def resolve_literal(self, gni: GraphNodeImporter, literal: Any) -> Optional[Value]: + module_builder = self.module_builder + + # We support resolution of tracked reference types. Currently this + # only includes Tensors. All others we let the importer do what it + # is going to do. + if not isinstance(literal, torch.Tensor): + return None + + # See if we know about it. + mapping = module_builder.global_ref_tracker.track(literal) + if mapping.is_empty: + # If it is unknown, just let the default importer take it on. + return None + + # Already materialized. + logger.debug("Resolved defined global for literal %r", mapping) + materialized_global: MaterializedGlobal = mapping.value # type: ignore + + # Emit a global load and conversion. + vtensor_type = gni._cc.tensor_to_vtensor_type(literal) + loaded_value = util_d.GlobalLoadOp( + materialized_global.ir_type, materialized_global.symbol_name + ).result + converted_value = Operation.create( + "torch_c.from_builtin_tensor", + results=[vtensor_type], + operands=[loaded_value], + ).result + return converted_value + + +# In https://github.com/llvm/torch-mlir/pull/3046, the FxImporter was +# extended to accept a "module_op" as an Operation (vs a Module). Switch for +# compatibility. +_fx_importer_accepts_module_op = ( + "module_op" in inspect.getfullargspec(FxImporter).kwonlyargs +) + + +def _create_fx_importer(module_builder: ModuleBuilder) -> FxImporter: + hooks = _Hooks(module_builder) + if _fx_importer_accepts_module_op: + # New path. + return FxImporter( + module_op=module_builder.module_op, + config_check=False, + py_attr_tracker=module_builder.fx_py_attr_tracker, + hooks=hooks, + ) + else: + # Legacy path. + class FakeModule: + def __init__(self, op): + self._op = module_builder.module_op + + @property + def context(self): + return self._op.context + + @property + def operation(self): + return self._op + + @property + def body(self): + return self._op.regions[0].blocks[0] + + return FxImporter( + module=FakeModule(module_builder.module_op), + config_check=False, + py_attr_tracker=module_builder.fx_py_attr_tracker, + hooks=hooks, + ) diff --git a/core/tests/aot/api_test.py b/core/tests/aot/api_test.py index ef13738ac..2bf6afabd 100644 --- a/core/tests/aot/api_test.py +++ b/core/tests/aot/api_test.py @@ -14,6 +14,7 @@ from shark_turbine.aot import * import torch +import torch.nn as nn class GeneralAPI(unittest.TestCase): @@ -71,6 +72,55 @@ def foobar(self): print(module_str) +class ExportAPI(unittest.TestCase): + def testStaticNNModule(self): + mdl = SimpleParams() + exported = export(mdl, args=(torch.empty([128, 20]),)) + exported.print_readable() + asm = str(exported.mlir_module) + self.assertIn("dense_resource", asm) + + def testDynamicNNModule(self): + mdl = SimpleParams() + batch = torch.export.Dim("batch") + exported = export( + mdl, args=(torch.empty([128, 20]),), dynamic_shapes={"x": {0: batch}} + ) + exported.print_readable() + asm = str(exported.mlir_module) + self.assertIn( + "func.func @main(%arg0: !torch.vtensor<[?,20],f32>) -> !torch.vtensor<[?,30],f32>", + asm, + ) + + def testExternalParamsNNModule(self): + mdl = SimpleParams() + exported = export(mdl, args=(torch.empty([128, 20]),), external_params=True) + exported.print_readable() + asm = str(exported.mlir_module) + self.assertNotIn("dense_resource", asm) + self.assertIn("util.global.load", asm) + + def testTorchExportedProgram(self): + mdl = SimpleParams() + prg = torch.export.export(mdl, args=(torch.empty([128, 20]),)) + exported = export(prg, external_params=True) + exported.print_readable() + asm = str(exported.mlir_module) + self.assertNotIn("dense_resource", asm) + self.assertIn("util.global private @_params.classifier.weight", asm) + self.assertIn("util.global private @_params.classifier.bias", asm) + + +class SimpleParams(nn.Module): + def __init__(self): + super().__init__() + self.classifier = nn.Linear(20, 30) + + def forward(self, x): + return self.classifier(x) + + if __name__ == "__main__": logging.basicConfig(level=logging.DEBUG) unittest.main() diff --git a/core/tests/aot/compiled_exported_program_test.py b/core/tests/aot/compiled_exported_program_test.py new file mode 100644 index 000000000..b8a4d1ac3 --- /dev/null +++ b/core/tests/aot/compiled_exported_program_test.py @@ -0,0 +1,138 @@ +# Copyright 2023 Nod Labs, Inc +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import logging +import unittest + +import torch +import torch.nn as nn + +from iree.compiler.ir import ( + Context, +) + +from shark_turbine.aot import * +from shark_turbine.aot.builtins import * + + +class TorchExportTests(unittest.TestCase): + def testImportPhases(self): + class MyModule(torch.nn.Module): + def forward(self): + ... + + fxb = FxProgramsBuilder(MyModule()) + + @fxb.export_program( + args=([torch.empty([3, 2]), torch.empty([1, 2])],), + kwargs={"foobar": torch.empty([3, 1])}, + ) + def compute(module, inputs, *, foobar): + t1 = inputs[0] + t2 = inputs[1] + t3 = t1 + t2 + foobar + return [t3 * t3, foobar] + + class ExportedProcModule(CompiledModule): + _compute = compute + + def foobar( + self, + t1=AbstractTensor(3, 2), + t2=AbstractTensor(1, 2), + t3=AbstractTensor(3, 1), + ): + return self._compute(t1, t2, foobar=t3) + + inst = ExportedProcModule(context=Context(), import_to="import") + module_str = str(CompiledModule.get_mlir_module(inst)) + print(module_str) + self.assertIn("func.func private @_compute", module_str) + self.assertIn("func.func @foobar", module_str) + + def testMultiPublic(self): + class MyModule(torch.nn.Module): + def forward(self): + ... + + fxb = FxProgramsBuilder(MyModule()) + + @fxb.export_program( + args=([torch.empty([3, 2]), torch.empty([1, 2])],), + kwargs={"foobar": torch.empty([3, 1])}, + ) + def _compute1(module, inputs, *, foobar): + t1 = inputs[0] + t2 = inputs[1] + t3 = t1 + t2 + foobar + return [t3 * t3, foobar] + + @fxb.export_program( + args=([torch.empty([5]), torch.empty([5])],), + kwargs={"foobar": torch.empty([5])}, + ) + def _compute2(module, inputs, *, foobar): + t1 = inputs[0] + t2 = inputs[1] + t3 = t1 + t2 + foobar + return [t3 * t3, foobar] + + class ExportedPublicModule(CompiledModule): + compute1 = _compute1 + compute2 = _compute2 + + inst = ExportedPublicModule(context=Context(), import_to="import") + module_str = str(CompiledModule.get_mlir_module(inst)) + print(module_str) + self.assertIn("func.func @compute1", module_str) + self.assertIn("func.func @compute2", module_str) + + def testParametersAsGlobals(self): + fxb = FxProgramsBuilder(SimpleParams()) + + @fxb.export_program( + args=(torch.empty([128, 20]),), + ) + def _compute1(module, x): + return module.forward(x) + + class ParamsAsGlobalsModule(CompiledModule): + params = export_parameters(fxb.root_module) + compute1 = _compute1 + compute2 = _compute1 + + inst = ParamsAsGlobalsModule(context=Context(), import_to="import") + module_str = str(CompiledModule.get_mlir_module(inst)) + print(module_str) + self.assertIn( + "util.global private @_params.classifier.weight {noinline}", module_str + ) + self.assertIn( + "util.global private @_params.classifier.bias {noinline}", module_str + ) + # Should only be two. + self.assertEqual(2, module_str.count("util.global private")) + # And two loads each loads. + self.assertEqual( + 2, module_str.count("util.global.load @_params.classifier.weight") + ) + self.assertEqual( + 2, module_str.count("util.global.load @_params.classifier.bias") + ) + + +class SimpleParams(nn.Module): + def __init__(self): + super().__init__() + self.classifier = nn.Linear(20, 30) + + def forward(self, x): + return self.classifier(x) + + +if __name__ == "__main__": + logging.basicConfig(level=logging.DEBUG) + unittest.main() diff --git a/core/tests/aot/jittable_test.py b/core/tests/aot/jittable_test.py index 0b3cabfa8..6419c0bd4 100644 --- a/core/tests/aot/jittable_test.py +++ b/core/tests/aot/jittable_test.py @@ -73,7 +73,7 @@ def compute(*, a, b): print(module_str) def testDynamicDims(self): - class ProcArgsModule(CompiledModule): + class DynamicDimsModule(CompiledModule): def dynamic_dim(self, a=AbstractTensor(None, 2), b=AbstractTensor(None, 1)): return self.compute( a, @@ -87,7 +87,7 @@ def dynamic_dim(self, a=AbstractTensor(None, 2), b=AbstractTensor(None, 1)): def compute(a, b): return a * b - inst = ProcArgsModule(context=Context(), import_to=None) + inst = DynamicDimsModule(context=Context(), import_to=None) module_str = str(CompiledModule.get_mlir_module(inst)) print(module_str) diff --git a/core/tests/aot/torch_export_test.py b/core/tests/aot/torch_export_test.py deleted file mode 100644 index bfaa5b6a8..000000000 --- a/core/tests/aot/torch_export_test.py +++ /dev/null @@ -1,56 +0,0 @@ -# Copyright 2023 Nod Labs, Inc -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -import logging -import unittest - -import torch - -from iree.compiler.ir import ( - Context, -) - -from shark_turbine.aot import * -from shark_turbine.aot.builtins import * - - -class TorchExportTests(unittest.TestCase): - def testImportPhases(self): - class MyModule(torch.nn.Module): - def forward(self): - ... - - fxb = FxProgramsBuilder(MyModule()) - - @fxb.export_program( - args=([torch.empty([3, 2]), torch.empty([1, 2])],), - kwargs={"foobar": torch.empty([3, 1])}, - ) - def compute(module, inputs, *, foobar): - t1 = inputs[0] - t2 = inputs[1] - t3 = t1 + t2 + foobar - return [t3 * t3, foobar] - - class ExportedProcModule(CompiledModule): - _compute = compute - - def foobar( - self, - t1=AbstractTensor(3, 2), - t2=AbstractTensor(1, 2), - t3=AbstractTensor(3, 1), - ): - return self._compute(t1, t2, foobar=t3) - - inst = ExportedProcModule(context=Context(), import_to="import") - module_str = str(CompiledModule.get_mlir_module(inst)) - print(module_str) - - -if __name__ == "__main__": - logging.basicConfig(level=logging.DEBUG) - unittest.main() From 2bbb15a3fd74faeabcfa30bec6e1f46cb58076fe Mon Sep 17 00:00:00 2001 From: Stella Laurenzo Date: Mon, 25 Mar 2024 18:24:49 -0700 Subject: [PATCH 04/18] Remove refs to torchvision --- .github/workflows/test.yml | 3 +-- .github/workflows/test_models.yml | 3 +-- .github/workflows/test_sdxl.yml | 3 +-- MANIFEST.in | 1 - models/turbine_models/custom_models/README.md | 5 ++--- 5 files changed, 5 insertions(+), 10 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index cf550a1b1..413ac77d2 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -40,8 +40,7 @@ jobs: # from non default locations first. Installing the PyTorch CPU # wheels saves multiple minutes and a lot of bandwidth on runner setup. pip install --index-url https://download.pytorch.org/whl/cpu \ - -r core/pytorch-cpu-requirements.txt \ - -r core/torchvision-requirements.txt + -r core/pytorch-cpu-requirements.txt pip install --upgrade \ -r core/requirements.txt \ -r mypy-requirements.txt diff --git a/.github/workflows/test_models.yml b/.github/workflows/test_models.yml index 18ba9ac73..08eec7033 100644 --- a/.github/workflows/test_models.yml +++ b/.github/workflows/test_models.yml @@ -39,8 +39,7 @@ jobs: # from non default locations first. Installing the PyTorch CPU # wheels saves multiple minutes and a lot of bandwidth on runner setup. pip install --index-url https://download.pytorch.org/whl/cpu \ - -r core/pytorch-cpu-requirements.txt \ - -r core/torchvision-requirements.txt + -r core/pytorch-cpu-requirements.txt pip install --upgrade -r core/requirements.txt pip install -e core[testing] pip install -e models diff --git a/.github/workflows/test_sdxl.yml b/.github/workflows/test_sdxl.yml index 5babfcbe1..5b60acc07 100644 --- a/.github/workflows/test_sdxl.yml +++ b/.github/workflows/test_sdxl.yml @@ -31,8 +31,7 @@ jobs: # from non default locations first. Installing the PyTorch CPU # wheels saves multiple minutes and a lot of bandwidth on runner setup. pip install --index-url https://download.pytorch.org/whl/cpu \ - -r core/pytorch-cpu-requirements.txt \ - -r core/torchvision-requirements.txt + -r core/pytorch-cpu-requirements.txt pip install --upgrade -r core/requirements.txt pip install -e core[testing,torch-cpu-nightly] pip install --upgrade -r models/requirements.txt diff --git a/MANIFEST.in b/MANIFEST.in index faa55e3f7..1ea7c0669 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,5 +1,4 @@ include README.md include requirements.txt include pytorch-cpu-requirements.txt -include torchvision-requirements.txt include version_info.json diff --git a/models/turbine_models/custom_models/README.md b/models/turbine_models/custom_models/README.md index 98aa347b1..d56214257 100644 --- a/models/turbine_models/custom_models/README.md +++ b/models/turbine_models/custom_models/README.md @@ -7,8 +7,7 @@ cd SHARK-Turbine python -m venv turbine_venv && source turbine_venv/bin/activate pip install --index-url https://download.pytorch.org/whl/cpu \ - -r core/pytorch-cpu-requirements.txt \ - -r core/torchvision-requirements.txt + -r core/pytorch-cpu-requirements.txt pip install --upgrade -r core/requirements.txt pip install -e core pip install -e models @@ -39,4 +38,4 @@ python models/turbine_models/custom_models/llm_runner.py --vmfb_path=/path/to/Ll 2) Interactive CLI chat mode. (just add a --chat_mode flag) ``` python models/turbine_models/custom_models/llm_runner.py --vmfb_path=/path/to/Llama_2_7b_chat_hf.vmfb --external_weight_path=Llama_2_7b_chat_hf_f16_int4.safetensors --device=vulkan hf_auth_token=your_hf_token --chat_mode -``` \ No newline at end of file +``` From 2eae1593fc982e59cebe9b681605a62201f706cc Mon Sep 17 00:00:00 2001 From: Stella Laurenzo Date: Mon, 25 Mar 2024 18:27:25 -0700 Subject: [PATCH 05/18] Make mypy happy. --- core/shark_turbine/aot/builtins/jittable.py | 2 +- core/shark_turbine/aot/compiled_module.py | 2 +- core/shark_turbine/aot/exporter.py | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/core/shark_turbine/aot/builtins/jittable.py b/core/shark_turbine/aot/builtins/jittable.py index 380339acd..5fcd3fcb2 100644 --- a/core/shark_turbine/aot/builtins/jittable.py +++ b/core/shark_turbine/aot/builtins/jittable.py @@ -228,7 +228,7 @@ def flat_wrapped_f(*args): aten_graph=True, decomposition_table=self.decomposition_table, assume_static_by_default=True, - **export_kwargs, + **export_kwargs, # type: ignore ) logger.debug("Invoking dynamo trace") gm, guards = exported_f(*flat_pytorch_args) diff --git a/core/shark_turbine/aot/compiled_module.py b/core/shark_turbine/aot/compiled_module.py index 72520a0a6..aa8e687c4 100644 --- a/core/shark_turbine/aot/compiled_module.py +++ b/core/shark_turbine/aot/compiled_module.py @@ -595,7 +595,7 @@ def __new__( info.shadow_dict[key] = import_exported_program( module_builder, ep_def.exported_program, - symbol_name=ep_def.export_name, + symbol_name=ep_def.export_name or "main", symbol_visibility=None if ep_def.public else "private", ) diff --git a/core/shark_turbine/aot/exporter.py b/core/shark_turbine/aot/exporter.py index 8e8025126..509f584a1 100644 --- a/core/shark_turbine/aot/exporter.py +++ b/core/shark_turbine/aot/exporter.py @@ -191,13 +191,13 @@ def export( "args, example_args, kwargs, or dynamic_dims" ) - class Exported(CompiledModule, export_name=mdl.graph_module._get_name()): + class EpExported(CompiledModule, export_name=mdl.graph_module._get_name()): params = export_global_tree( dict(list(mdl.named_parameters())), external=external_params ) main = mdl - TransformedModule = Exported + TransformedModule = EpExported elif isinstance(mdl, torch.nn.Module): # Normalize arguments for torch.export. if args is None: From f6d549ee2c65972cdc13cc3288139149250ed9bb Mon Sep 17 00:00:00 2001 From: Stella Laurenzo Date: Mon, 25 Mar 2024 18:35:47 -0700 Subject: [PATCH 06/18] Try to fix requirements --- .github/workflows/test.yml | 3 +-- core/pytorch-cpu-requirements.txt | 4 +--- core/pytorch-requirements.txt | 3 +++ core/requirements.txt | 2 +- .../aot/support/procedural/exported_program.py | 6 ++++++ 5 files changed, 12 insertions(+), 6 deletions(-) create mode 100644 core/pytorch-requirements.txt diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 413ac77d2..aa8d7a152 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -39,8 +39,7 @@ jobs: # Note: We install in three steps in order to satisfy requirements # from non default locations first. Installing the PyTorch CPU # wheels saves multiple minutes and a lot of bandwidth on runner setup. - pip install --index-url https://download.pytorch.org/whl/cpu \ - -r core/pytorch-cpu-requirements.txt + pip install -r core/pytorch-cpu-requirements.txt pip install --upgrade \ -r core/requirements.txt \ -r mypy-requirements.txt diff --git a/core/pytorch-cpu-requirements.txt b/core/pytorch-cpu-requirements.txt index 20aa5da87..e4fa5c795 100644 --- a/core/pytorch-cpu-requirements.txt +++ b/core/pytorch-cpu-requirements.txt @@ -1,5 +1,3 @@ --pre --index-url https://download.pytorch.org/whl/test/cpu -torch>=2.3.0 -torchaudio -torchvision +-r pytorch-requirements.txt diff --git a/core/pytorch-requirements.txt b/core/pytorch-requirements.txt new file mode 100644 index 000000000..78ad2577d --- /dev/null +++ b/core/pytorch-requirements.txt @@ -0,0 +1,3 @@ +torch>=2.3.0 +torchaudio +torchvision diff --git a/core/requirements.txt b/core/requirements.txt index 09ee8e38e..3265a2b99 100644 --- a/core/requirements.txt +++ b/core/requirements.txt @@ -4,5 +4,5 @@ # versions, not specific). -f https://openxla.github.io/iree/pip-release-links.html --r pytorch-cpu-requirements.txt +-r pytorch-requirements.txt -r iree-requirements.txt diff --git a/core/shark_turbine/aot/support/procedural/exported_program.py b/core/shark_turbine/aot/support/procedural/exported_program.py index 12e325fd6..9d74e2c26 100644 --- a/core/shark_turbine/aot/support/procedural/exported_program.py +++ b/core/shark_turbine/aot/support/procedural/exported_program.py @@ -17,6 +17,12 @@ treespec_pprint, ) +try: + from torch.utils._pytree import treespec_pprint +except ImportError: + # torch < 2.3 does not include this. + treespec_pprint = lambda x: repr(x) # type: ignore + from iree.compiler.extras.fx_importer import ( FxImporter, FxImporterHooks, From 5374ba1cc2f36d45125e45e1ac4f7e1d821c156f Mon Sep 17 00:00:00 2001 From: Stella Laurenzo Date: Mon, 25 Mar 2024 18:39:16 -0700 Subject: [PATCH 07/18] Fix again --- .github/workflows/test.yml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index aa8d7a152..b38c210e7 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -42,7 +42,8 @@ jobs: pip install -r core/pytorch-cpu-requirements.txt pip install --upgrade \ -r core/requirements.txt \ - -r mypy-requirements.txt + -r mypy-requirements.txt \ + -r serving/requirements.txt pip install -e core[testing] -e serving[testing] - name: Run core tests From 0e36850e10d50ae65988034feddd51f823e57f13 Mon Sep 17 00:00:00 2001 From: Stella Laurenzo Date: Mon, 25 Mar 2024 18:41:55 -0700 Subject: [PATCH 08/18] Add requests --- serving/requirements.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/serving/requirements.txt b/serving/requirements.txt index 3c9503df4..3cb469b2c 100644 --- a/serving/requirements.txt +++ b/serving/requirements.txt @@ -1,2 +1,3 @@ fastapi>=0.109.2 uvicorn>=0.27.0 +requests From a9b06ab6a87f3a0ff0fdce1d834cfb593aa1dcbb Mon Sep 17 00:00:00 2001 From: Stella Laurenzo Date: Mon, 25 Mar 2024 18:49:13 -0700 Subject: [PATCH 09/18] More fixing --- core/shark_turbine/aot/support/procedural/exported_program.py | 1 - 1 file changed, 1 deletion(-) diff --git a/core/shark_turbine/aot/support/procedural/exported_program.py b/core/shark_turbine/aot/support/procedural/exported_program.py index 9d74e2c26..b26f20d2c 100644 --- a/core/shark_turbine/aot/support/procedural/exported_program.py +++ b/core/shark_turbine/aot/support/procedural/exported_program.py @@ -14,7 +14,6 @@ from torch.utils._pytree import ( tree_flatten, tree_unflatten, - treespec_pprint, ) try: From 1be25550192753901cf464f921cf26281525fe7c Mon Sep 17 00:00:00 2001 From: Stella Laurenzo Date: Tue, 26 Mar 2024 10:46:56 -0700 Subject: [PATCH 10/18] Bump iree to 20240326.843 --- core/iree-requirements.txt | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/core/iree-requirements.txt b/core/iree-requirements.txt index 23866262c..7cb60d71c 100644 --- a/core/iree-requirements.txt +++ b/core/iree-requirements.txt @@ -1,2 +1,2 @@ -iree-compiler==20240311.828 -iree-runtime==20240311.828 +iree-compiler==20240326.843 +iree-runtime==20240326.843 From 11dca55c5658eee99ed722bf98254098d180e594 Mon Sep 17 00:00:00 2001 From: Stella Laurenzo Date: Tue, 26 Mar 2024 10:56:04 -0700 Subject: [PATCH 11/18] Address comments --- core/shark_turbine/aot/builtins/jittable.py | 2 +- .../aot/support/procedural/exported_program.py | 8 ++++++-- core/tests/aot/compiled_exported_program_test.py | 2 +- 3 files changed, 8 insertions(+), 4 deletions(-) diff --git a/core/shark_turbine/aot/builtins/jittable.py b/core/shark_turbine/aot/builtins/jittable.py index 5fcd3fcb2..06d26dd35 100644 --- a/core/shark_turbine/aot/builtins/jittable.py +++ b/core/shark_turbine/aot/builtins/jittable.py @@ -190,7 +190,7 @@ def resolve_call( "Compiling program with the old PyTorch constraints system " "for dynamic shapes is deprecated and will break on PyTorch " "nightlies after the 2.3 release cut (expect either a PyTorch " - "warning or excpetion to follow)", + "warning or exception to follow)", DeprecationWarning, ) export_kwargs["constraints"] = constraints diff --git a/core/shark_turbine/aot/support/procedural/exported_program.py b/core/shark_turbine/aot/support/procedural/exported_program.py index b26f20d2c..4fd0c166c 100644 --- a/core/shark_turbine/aot/support/procedural/exported_program.py +++ b/core/shark_turbine/aot/support/procedural/exported_program.py @@ -96,10 +96,11 @@ def resolve_call( f"{self.function_symbol} is {visibility}" ) - # Flatten and convert py args to torch IR values. + # Flatten and convert py args to torch IR values by converting to + # the canonical tree structure for args + # (tuple of list of args, dict of kwargs). flat_py_args, args_tree = tree_flatten(((list(py_args),), py_kwargs)) if args_tree != self.entry_sig.in_spec: - print(dir(args_tree)) raise ValueError( f"Mismatched arguments to exported program. \n" f" Got: {treespec_pprint(args_tree)}\n" @@ -176,6 +177,9 @@ def import_exported_program( # We want additional torch-level metadata about any user outputs. # This will help us create a true python fake without loss of information. + # TODO: It is unclear how much switchiness is actually needed here as + # modern use is pretty constrained. Potentially streamline the body of + # the for loop once done with full test cases available. user_output_dtypes: list[Optional[torch.dtype]] = [] node_map: Dict[str, torch.fx.Node] = { n.name: n for n in exported_program.graph.nodes diff --git a/core/tests/aot/compiled_exported_program_test.py b/core/tests/aot/compiled_exported_program_test.py index b8a4d1ac3..0f79111c8 100644 --- a/core/tests/aot/compiled_exported_program_test.py +++ b/core/tests/aot/compiled_exported_program_test.py @@ -1,4 +1,4 @@ -# Copyright 2023 Nod Labs, Inc +# Copyright 2024 Advanced Micro Devices, Inc. # # Licensed under the Apache License v2.0 with LLVM Exceptions. # See https://llvm.org/LICENSE.txt for license information. From eac18580343d285d1cafcd443c26d8f71ef6b8ce Mon Sep 17 00:00:00 2001 From: Stella Laurenzo Date: Tue, 26 Mar 2024 17:18:56 -0700 Subject: [PATCH 12/18] Update to the new inplace support in torch and support functionalization of trace. --- core/shark_turbine/aot/exporter.py | 1 + core/shark_turbine/ops/iree.py | 24 ++++--------------- core/shark_turbine/runtime/op_reg/base.py | 19 ++++++++++++--- .../transforms/general/custom_op_expansion.py | 9 ++++++- core/tests/runtime/op_reg/kernel_aot_test.py | 1 + 5 files changed, 31 insertions(+), 23 deletions(-) diff --git a/core/shark_turbine/aot/exporter.py b/core/shark_turbine/aot/exporter.py index 509f584a1..92415075a 100644 --- a/core/shark_turbine/aot/exporter.py +++ b/core/shark_turbine/aot/exporter.py @@ -210,6 +210,7 @@ class EpExported(CompiledModule, export_name=mdl.graph_module._get_name()): exported_program = torch.export.export( nn_module, args=args, kwargs=kwargs, dynamic_shapes=dynamic_shapes ) + print("EP:", exported_program) class Exported(CompiledModule, export_name=nn_module._get_name()): params = export_parameters(nn_module, external=external_params) diff --git a/core/shark_turbine/ops/iree.py b/core/shark_turbine/ops/iree.py index 093c6c77e..0acba5c41 100644 --- a/core/shark_turbine/ops/iree.py +++ b/core/shark_turbine/ops/iree.py @@ -50,29 +50,15 @@ def _emit_tensor_trace(kb: KernelBuilder, key: str, ts: list[Value]): @CustomOp.register(library=IREE_LIBRARY) class trace_tensor(CustomOp): - signature = "trace_tensor(str trace_key, Tensor tensor) -> ()" + signature = "trace_tensor(str trace_key, Tensor(a!) tensor) -> ()" def select(self, ksel: KernelSelection): ksel.attr_str(0) - ksel.arg_tensor(1) + ksel.arg_tensor(1, inplace_tied=True) + print("TRACE_TENSOR_SELECT:", ksel) def generate(self, ksel: KernelSelection, kb: KernelBuilder): + print("TRACE_TENSOR_GENERATE:", ksel) key = cast(AttrArg, ksel.arg_descs[0]) _emit_tensor_trace(kb, cast(str, key.v), [kb.arg_bindings[1]]) - kb.yield_results() - - -@CustomOp.register(library=IREE_LIBRARY) -class trace_tensors(CustomOp): - signature = "trace_tensors(str trace_key, Tensor[] tensors) -> ()" - - def select(self, ksel: KernelSelection): - ksel.attr_str(0) - ksel.arg_tensor_list(1) - - def generate(self, ksel: KernelSelection, kb: KernelBuilder): - key = cast(AttrArg, ksel.arg_descs[0]) - ts = kb.arg_bindings[1] - if len(ts) >= 1: - _emit_tensor_trace(kb, cast(str, key.v), ts) - kb.yield_results() + kb.yield_results(kb.arg_bindings[1]) diff --git a/core/shark_turbine/runtime/op_reg/base.py b/core/shark_turbine/runtime/op_reg/base.py index e7fc20338..3e2b84992 100644 --- a/core/shark_turbine/runtime/op_reg/base.py +++ b/core/shark_turbine/runtime/op_reg/base.py @@ -239,6 +239,7 @@ class KernelSelection(ABC): __slots__ = [ "arg_descs", + "inplace_tied_arg_descs", "op", "result_descs", "variant", @@ -247,6 +248,7 @@ class KernelSelection(ABC): def __init__(self, op: CustomOp, arg_arity: int): self.op = op self.arg_descs = cast(list[Optional[ArgDescriptor]], arg_arity * [None]) + self.inplace_tied_arg_descs: list[ArgDescriptor] = [] self.result_descs: list[ArgDescriptor] = [] self.variant: str = "default" @@ -295,12 +297,16 @@ def spec_key(self) -> str: ) from e @abstractmethod - def arg_tensor(self, arg: int) -> "TensorArg": + def arg_tensor(self, arg: int, *, inplace_tied: bool = False) -> "TensorArg": """Declares an argument to allow any ranked tensor and to specialize for each rank and dtype. Returns the argument descriptor, which can be used to further inspect or constrain the selection. It will default to allowing all dimensions to be dynamic. + + If inplace_tied is True, then this argument participates in in-place + semantics. The kernel must yield the result-mutated after all normal + results in the order declared. """ ... @@ -354,7 +360,7 @@ def __init__(self, op: CustomOp, args: list[Any]): super().__init__(op, len(args)) self.args = args - def arg_tensor(self, arg: int) -> "TensorArg": + def arg_tensor(self, arg: int, *, inplace_tied: bool = False) -> "TensorArg": arg_descs = self.arg_descs arg_value = self.args[arg] assert arg_descs[arg] is None, f"Already constrained argument {arg}" @@ -362,6 +368,8 @@ def arg_tensor(self, arg: int) -> "TensorArg": arg_value, Tensor ), f"Argument type mismatch from Torch for {arg}: Expected tensor, got {type(arg_value)}" arg_descs[arg] = desc = TensorArg(arg_value) + if inplace_tied: + self.inplace_tied_arg_descs.append(desc) return desc def arg_tensor_list(self, arg: int) -> "TensorListArg": @@ -676,7 +684,7 @@ def __init__( # Assemble result types. result_types = [] - for d in ksel.result_descs: + for d in (*ksel.result_descs, *ksel.inplace_tied_arg_descs): if not d.is_list: if d.ir_arity == 1: result_types.append(IrType.parse(d.mlir_type_asm)) @@ -744,6 +752,11 @@ def create_module( def yield_results(self, *results: Value): """Yields results of the kernel computation.""" assert not self.yielded, "yield_results has already been called" + ksel = self.ksel + expected_count = len(ksel.result_descs) + len(ksel.inplace_tied_arg_descs) + assert ( + len(results) == expected_count + ), f"Mismatched yielded results and declared+inplace: Expected={expected_count}, Got={len(results)}" with self.ip, Location.unknown(): func_d.ReturnOp(results) self.yielded = True diff --git a/core/shark_turbine/transforms/general/custom_op_expansion.py b/core/shark_turbine/transforms/general/custom_op_expansion.py index dae04d905..0a191dc2a 100644 --- a/core/shark_turbine/transforms/general/custom_op_expansion.py +++ b/core/shark_turbine/transforms/general/custom_op_expansion.py @@ -124,7 +124,7 @@ def __init__( self.results = results self.type_converter = type_converter - def arg_tensor(self, arg: int) -> TensorArg: + def arg_tensor(self, arg: int, *, inplace_tied: bool = False) -> TensorArg: # This is annoying: We have to go from the Torch MLIR type system to the # original torch.tensor Python type system. We do this by way of the native # type converter because it has the mapping pathway we need. This is one of the @@ -154,6 +154,8 @@ def arg_tensor(self, arg: int) -> TensorArg: ) t = torch.empty(rtt.shape, dtype=dtype, device="meta") arg_descs[arg] = desc = TensorArg(t) + if inplace_tied: + self.inplace_tied_arg_descs.append(desc) return desc def arg_tensor_list(self, arg: int) -> TensorListArg: @@ -235,6 +237,11 @@ def __init__( def yield_results(self, *results: Value): """Yields results of the kernel computation.""" assert not self.yielded, "yield_results has already been called" + ksel = self.ksel + expected_count = len(ksel.result_descs) + len(ksel.inplace_tied_arg_descs) + assert ( + len(results) == expected_count + ), f"Mismatched yielded results and declared+inplace: Expected={expected_count}, Got={len(results)}" with self.ip, self.location: torch_op_results: list[Value] = list(self.torch_op.results) assert len(results) == len( diff --git a/core/tests/runtime/op_reg/kernel_aot_test.py b/core/tests/runtime/op_reg/kernel_aot_test.py index 48c7f59f1..0d31edbd3 100644 --- a/core/tests/runtime/op_reg/kernel_aot_test.py +++ b/core/tests/runtime/op_reg/kernel_aot_test.py @@ -49,6 +49,7 @@ def testTrace(self): print("CUSTOM OP CONVERTED:") module_asm = str(prog.mlir_module) + print(module_asm) self.assertIn('flow.tensor.trace "LAYER0"', module_asm) self.assertIn('flow.tensor.trace "LAYER1"', module_asm) self.assertIn('flow.tensor.trace "LAYER3"', module_asm) From f52b334f55bfc44063a10535591f41757f0b5b99 Mon Sep 17 00:00:00 2001 From: Stella Laurenzo Date: Tue, 26 Mar 2024 17:20:47 -0700 Subject: [PATCH 13/18] Try to fix requirements --- .github/workflows/test_models.yml | 9 ++++----- core/pytorch-requirements.txt | 2 +- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/.github/workflows/test_models.yml b/.github/workflows/test_models.yml index 08eec7033..abdf8f17b 100644 --- a/.github/workflows/test_models.yml +++ b/.github/workflows/test_models.yml @@ -38,11 +38,10 @@ jobs: # Note: We install in three steps in order to satisfy requirements # from non default locations first. Installing the PyTorch CPU # wheels saves multiple minutes and a lot of bandwidth on runner setup. - pip install --index-url https://download.pytorch.org/whl/cpu \ - -r core/pytorch-cpu-requirements.txt - pip install --upgrade -r core/requirements.txt - pip install -e core[testing] - pip install -e models + pip install -r core/pytorch-cpu-requirements.txt + pip install --pre --upgrade -r core/requirements.txt + pip install --pre -e core[testing] + pip install --pre -e models - name: Show current free memory run: | diff --git a/core/pytorch-requirements.txt b/core/pytorch-requirements.txt index 78ad2577d..63fc21602 100644 --- a/core/pytorch-requirements.txt +++ b/core/pytorch-requirements.txt @@ -1,3 +1,3 @@ -torch>=2.3.0 +torch==2.3.0 torchaudio torchvision From 917a712a58d4c7e7ed54c4448af311d60c2d5901 Mon Sep 17 00:00:00 2001 From: Stella Laurenzo Date: Tue, 26 Mar 2024 17:22:28 -0700 Subject: [PATCH 14/18] Remove stray print --- core/shark_turbine/aot/exporter.py | 1 - 1 file changed, 1 deletion(-) diff --git a/core/shark_turbine/aot/exporter.py b/core/shark_turbine/aot/exporter.py index 92415075a..509f584a1 100644 --- a/core/shark_turbine/aot/exporter.py +++ b/core/shark_turbine/aot/exporter.py @@ -210,7 +210,6 @@ class EpExported(CompiledModule, export_name=mdl.graph_module._get_name()): exported_program = torch.export.export( nn_module, args=args, kwargs=kwargs, dynamic_shapes=dynamic_shapes ) - print("EP:", exported_program) class Exported(CompiledModule, export_name=nn_module._get_name()): params = export_parameters(nn_module, external=external_params) From 7942ca7e81648ef6195ec5828d18d08c117c474c Mon Sep 17 00:00:00 2001 From: Stella Laurenzo Date: Tue, 26 Mar 2024 17:23:04 -0700 Subject: [PATCH 15/18] Remove prints --- core/shark_turbine/ops/iree.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/core/shark_turbine/ops/iree.py b/core/shark_turbine/ops/iree.py index 0acba5c41..e28826db8 100644 --- a/core/shark_turbine/ops/iree.py +++ b/core/shark_turbine/ops/iree.py @@ -55,10 +55,8 @@ class trace_tensor(CustomOp): def select(self, ksel: KernelSelection): ksel.attr_str(0) ksel.arg_tensor(1, inplace_tied=True) - print("TRACE_TENSOR_SELECT:", ksel) def generate(self, ksel: KernelSelection, kb: KernelBuilder): - print("TRACE_TENSOR_GENERATE:", ksel) key = cast(AttrArg, ksel.arg_descs[0]) _emit_tensor_trace(kb, cast(str, key.v), [kb.arg_bindings[1]]) kb.yield_results(kb.arg_bindings[1]) From 66ff40f0ca5471e2e5e17571921eb9764c429868 Mon Sep 17 00:00:00 2001 From: Stella Laurenzo Date: Wed, 27 Mar 2024 11:01:52 -0700 Subject: [PATCH 16/18] Drop test for ops.iree.trace_tensors (plural - not supported) --- core/tests/ops/iree_test.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/core/tests/ops/iree_test.py b/core/tests/ops/iree_test.py index f10643026..b41647d65 100644 --- a/core/tests/ops/iree_test.py +++ b/core/tests/ops/iree_test.py @@ -17,12 +17,6 @@ def testTrace(self): t = torch.randn(3, 4) ops.iree.trace_tensor("TEST", t) - def testTraceList(self): - t1 = torch.randn(3, 4) - t2 = torch.randn(1, 8) - ops.iree.trace_tensors("TEST 2", [t1, t2]) - ops.iree.trace_tensors("TEST 1", [t1]) - if __name__ == "__main__": logging.basicConfig(level=logging.DEBUG) From aa63268a7c346cb8485e659ebe403855efcaf67a Mon Sep 17 00:00:00 2001 From: Stella Laurenzo Date: Wed, 27 Mar 2024 11:03:30 -0700 Subject: [PATCH 17/18] Bump IREE to 20240327.844 --- core/iree-requirements.txt | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/core/iree-requirements.txt b/core/iree-requirements.txt index 7cb60d71c..9d22d2559 100644 --- a/core/iree-requirements.txt +++ b/core/iree-requirements.txt @@ -1,2 +1,2 @@ -iree-compiler==20240326.843 -iree-runtime==20240326.843 +iree-compiler==20240327.844 +iree-runtime==20240327.844 From 2546a0773a2b10d2357e55dd0c53ba24c2e0483e Mon Sep 17 00:00:00 2001 From: Stella Laurenzo Date: Wed, 27 Mar 2024 11:09:05 -0700 Subject: [PATCH 18/18] Mark test as expectedFailure (#560) --- models/turbine_models/tests/stateless_llama_test.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/models/turbine_models/tests/stateless_llama_test.py b/models/turbine_models/tests/stateless_llama_test.py index be975cd8c..83bbbafc8 100644 --- a/models/turbine_models/tests/stateless_llama_test.py +++ b/models/turbine_models/tests/stateless_llama_test.py @@ -183,6 +183,9 @@ def test_streaming_vmfb_comparison(self): ) check_output_string(torch_str, turbine_str) + # See: https://github.com/nod-ai/SHARK-Turbine/issues/560 + # Developed issues related to the pytorch 2.3 upgrade. + @unittest.expectedFailure def test_rerotated_torch_comparison(self): torch_str = llm_runner.run_torch_llm( "Trelis/Llama-2-7b-chat-hf-function-calling-v2",