diff --git a/examples/aot_mlp/mlp_export_simple.py b/examples/aot_mlp/mlp_export_simple.py
index fed4795d4..5f63251fe 100644
--- a/examples/aot_mlp/mlp_export_simple.py
+++ b/examples/aot_mlp/mlp_export_simple.py
@@ -35,14 +35,14 @@ def forward(self, x: torch.Tensor):
 example_x = torch.empty(97, 8, dtype=torch.float32)
 exported = aot.export(model, example_x)
 exported.print_readable()
-compiled_binary = exported.compile(save_to=None)
+compiled_binary = exported.compile(save_to=None, target_backends=("rocm",))
 
 
 def infer():
     import numpy as np
     import iree.runtime as rt
 
-    config = rt.Config("local-task")
+    config = rt.Config("rocm")
     vmm = rt.load_vm_module(
         rt.VmModule.wrap_buffer(config.vm_instance, compiled_binary.map_memory()),
         config,
diff --git a/examples/eager_mlp/mlp_eager_simple.py b/examples/eager_mlp/mlp_eager_simple.py
index 947d0da6e..578c5ef40 100644
--- a/examples/eager_mlp/mlp_eager_simple.py
+++ b/examples/eager_mlp/mlp_eager_simple.py
@@ -78,9 +78,10 @@ def infer():
     custom_data_loader = MNISTDataLoader(config["batch_size"])
     test_loader = custom_data_loader.get_test_loader()
     model = MLP()
-    test_opt = torch.compile(infer_iteration, backend="turbine_cpu")
+    test_opt = torch.compile(infer_iteration, backend="turbine_rocm")
     for i, (images, labels) in enumerate(test_loader):
-        test_opt(model, images)
+        outputs = test_opt(model, images)
+        print(f"Iter {i}: {outputs}")
 
 
 class ModelTests(unittest.TestCase):
diff --git a/python/shark_turbine/dynamo/backends/cpu.py b/python/shark_turbine/dynamo/backends/cpu.py
index 941aed94a..6fcc98f75 100644
--- a/python/shark_turbine/dynamo/backends/cpu.py
+++ b/python/shark_turbine/dynamo/backends/cpu.py
@@ -86,14 +86,7 @@ def _base_backend(gm: torch.fx.GraphModule, example_inputs):
 
     # Set up for runtime.
     device_state = _get_device_state()
-    # TODO: Switch to wrap_buffer once https://github.com/openxla/iree/issues/14926
-    # is fixed.
-    # vmfb_module = VmModule.wrap_buffer(
-    #     device_state.instance,
-    #     output.map_memory(),
-    #     destroy_callback=output.close,
-    # )
-    vmfb_module = VmModule.copy_buffer(
+    vmfb_module = VmModule.wrap_buffer(
         device_state.instance,
         output.map_memory(),
     )
diff --git a/python/shark_turbine/dynamo/backends/rocm.py b/python/shark_turbine/dynamo/backends/rocm.py
new file mode 100644
index 000000000..9f26cd9ca
--- /dev/null
+++ b/python/shark_turbine/dynamo/backends/rocm.py
@@ -0,0 +1,108 @@
+# Copyright 2023 Nod Labs, Inc
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+import functools
+import sys
+
+from ..device import (
+    DeviceState,
+)
+
+from ..executor import (
+    SpecializedExecutable,
+)
+
+from iree.compiler.api import (
+    _initializeGlobalCL,
+    Invocation,
+    Session,
+    Source,
+    Output,
+)
+
+from iree.compiler.ir import (
+    Context,
+)
+from iree.compiler.passmanager import (
+    PassManager,
+)
+
+from iree.runtime import (
+    VmModule,
+)
+
+from ..importer import FxImporter
+
+import torch
+from torch._dynamo.backends.common import aot_autograd
+from ..passes import turbine_cpu_pass_pipeline
+
+_initializeGlobalCL("dynamo", "--iree-rocm-target-chip=gfx1100", "--iree-rocm-link-bc")
+
+DEFAULT_COMPILER_FLAGS = (
+    # Enable asynchronous calling convention.
+    # TODO: Enable async execution mode.
+    # "--iree-execution-model=async-external",
+    "--iree-input-type=tm_tensor",
+)
+
+
+def _base_backend(gm: torch.fx.GraphModule, example_inputs):
+    # Set up the session, context and invocation.
+    # Note that we do this on one in-memory module in a few phases:
+    #   1. Build it from the FX graph.
+    #   2. Run torch MLIR passes to lower it to a suitable form for
+    #      input.
+    #   3. Run IREE's main compiler.
+    #   4. Output to an mmap buffer.
+    session = Session()
+    session.set_flags(*DEFAULT_COMPILER_FLAGS)
+    session.set_flags("--iree-hal-target-backends=rocm")
+    context = session.context
+    importer = FxImporter(context=context)
+    module = importer.module
+    inv = session.invocation()
+    # TODO: Should capture diagnostics.
+    inv.enable_console_diagnostics()
+    inv.import_module(module.operation)
+
+    # Apply decompositions.
+    gm = turbine_cpu_pass_pipeline(gm, example_inputs)
+
+    # Import phase.
+    importer.import_graph_module(gm)
+    print(module, file=sys.stderr)
+    with context:
+        pm = PassManager.parse("builtin.module(torch-to-iree)")
+        pm.run(module.operation)
+    print(module, file=sys.stderr)
+
+    # IREE compilation phase.
+    inv.execute()
+
+    # Output phase.
+    output = Output.open_membuffer()
+    inv.output_vm_bytecode(output)
+
+    # Set up for runtime.
+    device_state = _get_device_state()
+    vmfb_module = VmModule.wrap_buffer(
+        device_state.instance,
+        output.map_memory(),
+    )
+    output.close()
+
+    return SpecializedExecutable(vmfb_module, device_state)
+
+
+backend = aot_autograd(fw_compiler=_base_backend)
+
+
+# IREE runtime globals. For ROCm right now, there is no device selection,
+# so it is easy.
+@functools.lru_cache(maxsize=None)
+def _get_device_state() -> DeviceState:
+    return DeviceState(driver="rocm")
diff --git a/python/shark_turbine/dynamo/device.py b/python/shark_turbine/dynamo/device.py
index 07181e6a3..dcdc728eb 100644
--- a/python/shark_turbine/dynamo/device.py
+++ b/python/shark_turbine/dynamo/device.py
@@ -9,6 +9,7 @@
 from threading import local, Lock
 
 from iree.runtime import (
+    _binding,
     asdevicearray,
     create_hal_module,
     HalBufferView,
@@ -38,6 +39,9 @@ def get_vm_instance() -> VmInstance:
     if not _GLOBAL_VM_INSTANCE:
         with _CONFIG_LOCK:
             if not _GLOBAL_VM_INSTANCE:
+                # Using Dynamo in eager mode creates global garbage that is not
+                # freed before we unload extensions. Disable leak spew.
+                _binding.disable_leak_checker()
                 _GLOBAL_VM_INSTANCE = VmInstance()
     return _GLOBAL_VM_INSTANCE
 
diff --git a/setup.py b/setup.py
index a5c2df138..c38c23404 100644
--- a/setup.py
+++ b/setup.py
@@ -99,6 +99,7 @@ def initialize_options(self):
     entry_points={
         "torch_dynamo_backends": [
             "turbine_cpu = shark_turbine.dynamo.backends.cpu:backend",
+            "turbine_rocm = shark_turbine.dynamo.backends.rocm:backend",
         ],
     },
     install_requires=[