From 64207122dbd838fa9d84a6b3f54768d45bf41fc0 Mon Sep 17 00:00:00 2001 From: Manuel Candales Date: Fri, 10 Oct 2025 13:29:55 -0400 Subject: [PATCH 01/10] Update [ghstack-poisoned] --- backends/aoti/aoti_model_container.cpp | 6 ++++++ backends/aoti/aoti_model_container.h | 16 ++++++++++++++++ backends/aoti/common_shims.cpp | 5 +++++ backends/aoti/common_shims.h | 3 +++ 4 files changed, 30 insertions(+) diff --git a/backends/aoti/aoti_model_container.cpp b/backends/aoti/aoti_model_container.cpp index 03be835a0c3..d1764451ab6 100644 --- a/backends/aoti/aoti_model_container.cpp +++ b/backends/aoti/aoti_model_container.cpp @@ -25,6 +25,12 @@ AOTInductorModelContainerGetNumOutputsFunc AOTInductorModelContainerGetNumOutputs = nullptr; AOTInductorModelContainerRunFunc AOTInductorModelContainerRun = nullptr; +// Global function pointers needed by Metal backend +AOTInductorModelContainerGetInputNameFunc + AOTInductorModelContainerGetInputName = nullptr; +AOTInductorModelContainerGetNumConstantsFunc + AOTInductorModelContainerGetNumConstants = nullptr; + } // extern "C" } // namespace aoti diff --git a/backends/aoti/aoti_model_container.h b/backends/aoti/aoti_model_container.h index 9b185327172..88d936d21ba 100644 --- a/backends/aoti/aoti_model_container.h +++ b/backends/aoti/aoti_model_container.h @@ -70,6 +70,22 @@ extern AOTInductorModelContainerGetNumOutputsFunc AOTInductorModelContainerGetNumOutputs; extern AOTInductorModelContainerRunFunc AOTInductorModelContainerRun; +// Function pointer types needed by Metal backend +using AOTInductorModelContainerGetInputNameFunc = AOTIRuntimeError (*)( + AOTInductorModelContainerHandle container_handle, + size_t input_idx, + const char** input_name); + +using AOTInductorModelContainerGetNumConstantsFunc = AOTIRuntimeError (*)( + AOTInductorModelContainerHandle container_handle, + size_t* num_constants); + +// Global function pointers needed by Metal backend +extern AOTInductorModelContainerGetInputNameFunc + AOTInductorModelContainerGetInputName; +extern AOTInductorModelContainerGetNumConstantsFunc + AOTInductorModelContainerGetNumConstants; + } // extern "C" // AOTI Delegate Handle structure diff --git a/backends/aoti/common_shims.cpp b/backends/aoti/common_shims.cpp index abc83779443..7802444e97e 100644 --- a/backends/aoti/common_shims.cpp +++ b/backends/aoti/common_shims.cpp @@ -145,6 +145,11 @@ void cleanup_tensor_metadata() { internal::tensor_to_strides.clear(); } +// Needed by Metal backend +size_t aoti_torch_dtype_element_size(int32_t dtype) { + return dtype_to_element_size(dtype); +} + } // extern "C" } // namespace aoti diff --git a/backends/aoti/common_shims.h b/backends/aoti/common_shims.h index 5f54cd1c878..97fcea1085c 100644 --- a/backends/aoti/common_shims.h +++ b/backends/aoti/common_shims.h @@ -68,6 +68,9 @@ void aoti_torch_grad_mode_set_enabled(bool enabled); // Cleanup functions for clearing global state void cleanup_tensor_metadata(); +// Needed by Metal backend +size_t aoti_torch_dtype_element_size(int32_t dtype); + } // extern "C" } // namespace aoti From d036c0713348e2482aea4d21405d30a51b629f76 Mon Sep 17 00:00:00 2001 From: Manuel Candales Date: Fri, 10 Oct 2025 16:08:15 -0400 Subject: [PATCH 02/10] Update [ghstack-poisoned] --- backends/apple/metal/metal_backend.py | 173 ++++++++++++++++++ backends/apple/metal/metal_partitioner.py | 77 ++++++++ .../metal/replace_slice_copy_with_slice.py | 118 ++++++++++++ backends/apple/metal/tests/__init__.py | 6 + .../apple/metal/tests/test_metal_backend.py | 80 ++++++++ 
 .../metal/tests/test_metal_partitioner.py | 172 +++++++++++
 6 files changed, 626 insertions(+)
 create mode 100644 backends/apple/metal/metal_backend.py
 create mode 100644 backends/apple/metal/metal_partitioner.py
 create mode 100644 backends/apple/metal/replace_slice_copy_with_slice.py
 create mode 100644 backends/apple/metal/tests/__init__.py
 create mode 100644 backends/apple/metal/tests/test_metal_backend.py
 create mode 100644 backends/apple/metal/tests/test_metal_partitioner.py

diff --git a/backends/apple/metal/metal_backend.py b/backends/apple/metal/metal_backend.py
new file mode 100644
index 00000000000..782aa522084
--- /dev/null
+++ b/backends/apple/metal/metal_backend.py
@@ -0,0 +1,173 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import contextlib
+import os
+import typing
+from enum import Enum
+
+from typing import Any, Dict, final, List, Optional, Set
+
+import torch
+from executorch.backends.apple.metal.replace_slice_copy_with_slice import (
+    ReplaceSliceCopyWithSlicePass,
+)
+from executorch.exir._serialize._named_data_store import NamedDataStore
+from executorch.exir._warnings import experimental
+from executorch.exir.backend.backend_details import (
+    BackendDetails,
+    ExportedProgram,
+    PreprocessResult,
+)
+from executorch.exir.backend.compile_spec_schema import CompileSpec
+from torch._inductor.codegen.cpp_wrapper_cpu import CppWrapperCpu
+from torch.export.passes import move_to_device_pass
+
+
+# Fallback operators that already exist in the ET namespace.
+supported_fallback_kernels: Dict[str, Any] = {
+    "aoti_torch_mps_addmm_out": None,
+    "aoti_torch_mps_convolution": None,
+    "aoti_torch_mps_mm_out": None,
+    "at::_ops::_scaled_dot_product_attention_math_for_mps::call": None,
+}
+
+# Fallback kernels that are required but not yet supported.
+missing_fallback_kernels: Set[str] = set()
+
+
+class COMPILE_SPEC_KEYS(Enum):
+    METHOD_NAME = "method_name"
+
+
+# Context manager that guards against unsupported fallbacks: it records every
+# fallback kernel emitted during AOTI compile so the caller can raise on them.
+@contextlib.contextmanager
+def collect_unsupported_fallback_kernels():
+    original_generate_c_shim_extern_kernel_call = (
+        CppWrapperCpu.generate_c_shim_extern_kernel_call
+    )
+
+    def generate_c_shim_extern_kernel_call_and_collect_unsupported_kernels(
+        self,
+        kernel: str,
+        args: list[str],
+        device: str,
+        *,
+        debug_args: Optional[list[str]] = None,
+        debug_handle: Optional[int] = None,
+    ):
+        if kernel not in supported_fallback_kernels:
+            missing_fallback_kernels.add(kernel)
+
+        original_generate_c_shim_extern_kernel_call(
+            self, kernel, args, device, debug_args=debug_args, debug_handle=debug_handle
+        )
+
+    CppWrapperCpu.generate_c_shim_extern_kernel_call = (
+        generate_c_shim_extern_kernel_call_and_collect_unsupported_kernels
+    )
+    try:
+        yield
+    finally:
+        CppWrapperCpu.generate_c_shim_extern_kernel_call = (
+            original_generate_c_shim_extern_kernel_call
+        )
+
+
+@final
+@experimental(
+    "This API and all of Metal backend related functionality are experimental."
+)
+class MetalBackend(BackendDetails):
+    @staticmethod
+    def preprocess(
+        edge_program: ExportedProgram,
+        compile_specs: List[CompileSpec],
+    ) -> PreprocessResult:
+        print("entering the lowerable parts in MetalBackend.preprocess....")
+        # Move the edge_program from CPU to MPS for AOTI compile
+        mps_edge_program = move_to_device_pass(edge_program, "mps")
+
+        # Replace slice_copy with slice
+        ReplaceSliceCopyWithSlicePass()(mps_edge_program.graph_module)
+
+        edge_program_module = mps_edge_program.module()
+
+        # Grab all input placeholders from the graph
+        user_input_names = mps_edge_program.graph_signature.user_inputs
+        user_input_placeholders = []
+        for node in mps_edge_program.graph.nodes:
+            if node.op == "placeholder" and node.name in user_input_names:
+                user_input_placeholders.append(node.meta["val"])
+
+        # Base options for all devices
+        options: dict[str, typing.Any] = {
+            # Do not link against the full PyTorch/libtorch library
+            "aot_inductor.link_libtorch": False,
+            # Package model constants and other generated files directly in the shared object (.so) file
+            "aot_inductor.package_constants_in_so": True,
+            # Enable maximum automatic tuning for optimal performance
+            "max_autotune": True,
+            # "aot_inductor.debug_compile": True,
+            # "aot_inductor.force_mmap_weights": False,
+        }
+
+        with collect_unsupported_fallback_kernels():
+            so_path = torch._inductor.aot_compile(edge_program_module, tuple(user_input_placeholders), options=options)  # type: ignore[arg-type]
+            if len(missing_fallback_kernels) > 0:
+                formatted_kernels = "\n - ".join(sorted(missing_fallback_kernels))
+                raise RuntimeError(
+                    f"Missing fallback kernels ({len(missing_fallback_kernels)} total):\n - {formatted_kernels}\n"
+                    "Please add them to the AOTI backend."
+                )
+
+        # pyre-ignorep[6]: Incompatible parameter type
+        with open(so_path, "rb") as f:
+            so_data = f.read()
+
+        named_data_store = NamedDataStore()
+        method_name = MetalBackend.method_name_from_compile_specs(compile_specs)
+        named_data_store.add_named_data(
+            method_name + "_so_blob", so_data, 1, "aoti_metal_blob"
+        )
+
+        # Clean up the generated so file; it has been packaged into the NamedDataStore
+        # pyre-ignorep[6]: Incompatible parameter type
+        os.remove(so_path)
+
+        return PreprocessResult(
+            processed_bytes=b"",
+            debug_handle_map={},
+            data_store_output=named_data_store.get_named_data_store_output(),
+        )
+
+    @staticmethod
+    def generate_method_name_compile_spec(
+        method_name: str,
+    ) -> CompileSpec:
+        """
+        Returns the compile spec that records the method name being lowered,
+        so the backend can retrieve it at preprocess and runtime.
+        """
+        return CompileSpec(
+            COMPILE_SPEC_KEYS.METHOD_NAME.value,
+            method_name.encode("utf-8"),
+        )
+
+    @staticmethod
+    def method_name_from_compile_specs(
+        compile_specs: List[CompileSpec],
+    ) -> str:
+        """
+        Returns the method name from the compile specs.
+        """
+        for spec in compile_specs:
+            if spec.key == COMPILE_SPEC_KEYS.METHOD_NAME.value:
+                return spec.value.decode("utf-8")
+        raise RuntimeError(
+            f"Could not find method name in compile specs: {compile_specs}"
+        )
diff --git a/backends/apple/metal/metal_partitioner.py b/backends/apple/metal/metal_partitioner.py
new file mode 100644
index 00000000000..b103ac0f455
--- /dev/null
+++ b/backends/apple/metal/metal_partitioner.py
@@ -0,0 +1,77 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Callable, Dict, final, List, Optional, Tuple
+
+import torch
+from executorch.backends.apple.metal.metal_backend import MetalBackend  # usort: skip
+from executorch.exir._warnings import experimental
+from executorch.exir.backend.compile_spec_schema import CompileSpec
+from executorch.exir.backend.partitioner import (
+    DelegationSpec,
+    Partitioner,
+    PartitionResult,
+)
+from executorch.exir.backend.utils import tag_constant_data, tag_mutated_buffer
+from torch.export.exported_program import ExportedProgram
+
+
+@final
+@experimental(
+    "This API and all of Metal backend related functionality are experimental."
+)
+class MetalPartitioner(Partitioner):
+    """
+    Metal partitioner for AOTInductor backend integration.
+
+    This partitioner creates a single partition containing all operators from the input graph.
+    It skips core ATen decomposition, allowing the Metal backend to handle decomposition using
+    AOTInductor's MPS-specific decomposition table.
+
+    Only operators that cannot be handled by the aoti-mps library will be excluded from
+    the partition and fall back to ExecuTorch's default or custom handling.
+    """
+
+    def __init__(self, compile_spec: List[CompileSpec]) -> None:
+        self.delegation_spec = DelegationSpec(MetalBackend.__name__, compile_spec)
+
+    def partition(self, exported_program: ExportedProgram) -> PartitionResult:
+        """
+        Fully delegate the graph to AOTInductor by tagging all nodes as a single partition.
+        """
+
+        partition_tags: Dict[str, DelegationSpec] = {}
+        tag = "tag0"
+
+        for node in exported_program.graph.nodes:
+            if node.op != "call_function":
+                continue
+            node.meta["delegation_tag"] = tag
+
+        partition_tags[tag] = self.delegation_spec
+
+        tag_constant_data(exported_program)
+        tag_mutated_buffer(exported_program)
+
+        return PartitionResult(
+            tagged_exported_program=exported_program, partition_tags=partition_tags
+        )
+
+    def ops_to_not_decompose(
+        self, ep: ExportedProgram
+    ) -> Tuple[List[torch._ops.OpOverload], Optional[Callable[[torch.fx.Node], bool]]]:
+        """
+        Return the list of operations that should not be decomposed, leaving them to the AOT compiler.
+        Currently we skip ATen decomposition for all ops and let the Metal backend handle them.
+        """
+        do_not_decompose = set()
+
+        for node in ep.graph.nodes:
+            if node.op == "call_function" and isinstance(
+                node.target, torch._ops.OpOverload
+            ):
+                do_not_decompose.add(node.target)
+        return list(do_not_decompose), None
diff --git a/backends/apple/metal/replace_slice_copy_with_slice.py b/backends/apple/metal/replace_slice_copy_with_slice.py
new file mode 100644
index 00000000000..4f16759af35
--- /dev/null
+++ b/backends/apple/metal/replace_slice_copy_with_slice.py
@@ -0,0 +1,118 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+ +# pyre-strict + +from typing import Dict, Iterable, Tuple + +import torch +from executorch.exir.dialects._ops import ops +from executorch.exir.dialects.edge._ops import EdgeOpOverload +from executorch.exir.pass_base import ExportPass, PassResult +from torch import fx + + +_SLICE_COPY_TARGETS: Tuple[torch._ops.OpOverload | EdgeOpOverload] = ( + torch.ops.aten.slice_copy.Tensor, + ops.edge.aten.slice_copy.Tensor, +) + +_SLICE_TARGETS: Dict[ + torch._ops.OpOverload | EdgeOpOverload, torch._ops.OpOverload | EdgeOpOverload +] = { + torch.ops.aten.slice_copy.Tensor: torch.ops.aten.slice.Tensor, + ops.edge.aten.slice_copy.Tensor: ops.edge.aten.slice.Tensor, +} + + +class ReplaceSliceCopyWithSlicePass(ExportPass): + """Replace non-mutated ``slice_copy`` results with ``slice`` views.""" + + def call(self, graph_module: fx.GraphModule) -> PassResult: + graph_changed = False + + for node in graph_module.graph.nodes: + if node.op != "call_function" or node.target not in _SLICE_COPY_TARGETS: + continue + + if self._has_blocking_user(node, node.users.keys()): + continue + + node.target = _SLICE_TARGETS[node.target] + graph_changed = True + + if graph_changed: + graph_module.graph.lint() + graph_module.recompile() + + return PassResult(graph_module, graph_changed) + + def _has_blocking_user(self, node: fx.Node, users: Iterable[fx.Node]) -> bool: + for user in users: + if self._is_mutating_user(node, user) or self._is_view_user(node, user): + return True + return False + + def _is_mutating_user(self, node: fx.Node, user: fx.Node) -> bool: + if user.op == "call_method": + # Treat in-place tensor methods conservatively as mutations only when the + # method name ends with ``_`` which is the PyTorch convention for mutation. + return isinstance(user.target, str) and user.target.endswith("_") + + if user.op != "call_function": + return False + + target = user.target + if not hasattr(target, "_schema"): + return False + + schema = target._schema # pyre-ignore[16] + # Positional arguments + for index, arg in enumerate(user.args): + if arg is node and self._argument_mutates(schema, index): + return True + + # Keyword arguments + for name, arg in user.kwargs.items(): + if arg is node and self._argument_mutates(schema, name): + return True + + return False + + def _is_view_user(self, node: fx.Node, user: fx.Node) -> bool: + if user.op == "call_method": + # Treat tensor methods conservatively and assume they may be view-producing. + return True + + if user.op != "call_function": + return False + + target = user.target + if getattr(target, "is_view", False): + for arg in user.args: + if arg is node: + return True + for arg in user.kwargs.values(): + if arg is node: + return True + + return False + + def _argument_mutates( + self, schema: torch._C.FunctionSchema, key: int | str + ) -> bool: + arguments = schema.arguments + if isinstance(key, int): + if key >= len(arguments): + return False + argument = arguments[key] + else: + argument = next((arg for arg in arguments if arg.name == key), None) + if argument is None: + return False + + alias_info = argument.alias_info + return bool(alias_info and alias_info.is_write) diff --git a/backends/apple/metal/tests/__init__.py b/backends/apple/metal/tests/__init__.py new file mode 100644 index 00000000000..fd6404c7f7b --- /dev/null +++ b/backends/apple/metal/tests/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + diff --git a/backends/apple/metal/tests/test_metal_backend.py b/backends/apple/metal/tests/test_metal_backend.py new file mode 100644 index 00000000000..26d2281c458 --- /dev/null +++ b/backends/apple/metal/tests/test_metal_backend.py @@ -0,0 +1,80 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import unittest + +from executorch.backends.apple.metal.metal_backend import ( + COMPILE_SPEC_KEYS, + MetalBackend, +) +from executorch.exir.backend.compile_spec_schema import CompileSpec + + +class TestMetalBackend(unittest.TestCase): + """Test Metal backend utility functions.""" + + def test_generate_method_name_compile_spec(self): + """Test that compile spec is generated correctly with method name.""" + method_name = "forward" + compile_spec = MetalBackend.generate_method_name_compile_spec(method_name) + + # Verify compile spec structure + self.assertIsInstance(compile_spec, CompileSpec) + self.assertEqual(compile_spec.key, COMPILE_SPEC_KEYS.METHOD_NAME.value) + self.assertEqual(compile_spec.value, method_name.encode("utf-8")) + + def test_method_name_from_compile_specs(self): + """Test extracting method name from compile specs.""" + method_name = "forward" + compile_specs = [MetalBackend.generate_method_name_compile_spec(method_name)] + + # Extract method name + extracted_name = MetalBackend.method_name_from_compile_specs(compile_specs) + + self.assertEqual(extracted_name, method_name) + + def test_method_name_from_compile_specs_with_multiple_specs(self): + """Test extracting method name when there are multiple compile specs.""" + method_name = "forward" + compile_specs = [ + CompileSpec("other_key", b"other_value"), + MetalBackend.generate_method_name_compile_spec(method_name), + CompileSpec("another_key", b"another_value"), + ] + + # Extract method name + extracted_name = MetalBackend.method_name_from_compile_specs(compile_specs) + + self.assertEqual(extracted_name, method_name) + + def test_method_name_from_compile_specs_missing(self): + """Test that RuntimeError is raised when method name is missing.""" + compile_specs = [ + CompileSpec("other_key", b"other_value"), + ] + + # Should raise RuntimeError when method name is not found + with self.assertRaises(RuntimeError) as context: + MetalBackend.method_name_from_compile_specs(compile_specs) + + self.assertIn("Could not find method name", str(context.exception)) + + def test_compile_spec_roundtrip(self): + """Test that method name survives encode/decode roundtrip.""" + original_name = "my_custom_method" + + # Generate compile spec + compile_spec = MetalBackend.generate_method_name_compile_spec(original_name) + + # Extract from compile specs list + extracted_name = MetalBackend.method_name_from_compile_specs([compile_spec]) + + self.assertEqual(original_name, extracted_name) + + +if __name__ == "__main__": + unittest.main() + diff --git a/backends/apple/metal/tests/test_metal_partitioner.py b/backends/apple/metal/tests/test_metal_partitioner.py new file mode 100644 index 00000000000..97a073152f5 --- /dev/null +++ b/backends/apple/metal/tests/test_metal_partitioner.py @@ -0,0 +1,172 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import unittest +from typing import Tuple + +import torch +from executorch.backends.apple.metal.metal_backend import MetalBackend +from executorch.backends.apple.metal.metal_partitioner import MetalPartitioner +from executorch.exir.backend.partitioner import PartitionResult +from torch.export import export + + +class TestMetalPartitioner(unittest.TestCase): + """ + Test Metal partitioner functionality. + + After Metal partitioning, there should be exactly one partitioned graph that contains + all operators from the input graph. This means all operators should be tagged with + the same delegation tag, indicating they will all be executed by the Metal backend. + """ + + def _get_partition_result( + self, module: torch.nn.Module, inputs: Tuple[torch.Tensor, ...] + ) -> PartitionResult: + """Helper method to get partition result for a given module.""" + # Export the model + exported_program = export(module, inputs, strict=True) + + # Create partitioner with compile specs + compile_specs = [MetalBackend.generate_method_name_compile_spec("forward")] + partitioner = MetalPartitioner(compile_specs) + + # Get partition result + partition_result = partitioner.partition(exported_program) + + # Verify partition result structure + self.assertIsNotNone(partition_result) + self.assertTrue(hasattr(partition_result, "tagged_exported_program")) + self.assertTrue(hasattr(partition_result, "partition_tags")) + + return partition_result + + def _check_fully_partitioned(self, partition_result: PartitionResult) -> bool: + """Check if the graph is fully partitioned (all operators have the same tag).""" + tagged_nodes = [] + untagged_ops = [] + + for node in partition_result.tagged_exported_program.graph.nodes: + if node.op == "call_function": + if hasattr(node, "meta") and "delegation_tag" in node.meta: + tagged_nodes.append(node) + else: + untagged_ops.append(node) + + # Check if we have any tagged nodes + if not tagged_nodes: + return False + + # Check if all tagged nodes have the same tag + first_tag = tagged_nodes[0].meta["delegation_tag"] + all_same_tag = all( + node.meta.get("delegation_tag") == first_tag for node in tagged_nodes + ) + + # Should have no untagged operations for full partitioning + fully_partitioned = len(untagged_ops) == 0 and all_same_tag + + return fully_partitioned + + def test_simple_add_partition(self): + """ + Test that Metal partitioner creates exactly one partition containing all operators. + Simple element-wise addition should result in a single graph with all ops tagged identically. + """ + + class AddModule(torch.nn.Module): + def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + return x + y + + # Create test inputs + x = torch.randn(2, 3) + y = torch.randn(2, 3) + + # Get partition result + partition_result = self._get_partition_result(AddModule(), (x, y)) + + # Verify it's fully partitioned + self.assertTrue( + self._check_fully_partitioned(partition_result), + "Expected all operations to be in a single partition", + ) + + # Verify exactly one partition tag exists + self.assertEqual( + len(partition_result.partition_tags), + 1, + "Expected exactly one partition tag for fully delegated graph", + ) + + def test_linear_partition(self): + """ + Test Metal partitioner with a linear layer. + All matrix operations should be in a single partition. 
+ """ + + class LinearModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(10, 5) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.linear(x) + + # Create test input + x = torch.randn(2, 10) + + # Get partition result + partition_result = self._get_partition_result(LinearModule(), (x,)) + + # Verify it's fully partitioned + self.assertTrue( + self._check_fully_partitioned(partition_result), + "Expected all operations to be in a single partition", + ) + + def test_ops_to_not_decompose(self): + """ + Test that ops_to_not_decompose returns all call_function ops. + Metal backend should handle decomposition via AOTInductor. + """ + + class SimpleModule(torch.nn.Module): + def forward(self, x: torch.Tensor) -> torch.Tensor: + return torch.nn.functional.relu(x + 1.0) + + # Create test input + x = torch.randn(2, 3) + + # Export the model + exported_program = export(SimpleModule(), (x,), strict=True) + + # Create partitioner + compile_specs = [MetalBackend.generate_method_name_compile_spec("forward")] + partitioner = MetalPartitioner(compile_specs) + + # Get ops to not decompose + ops_to_not_decompose, _ = partitioner.ops_to_not_decompose(exported_program) + + # Verify it returns a list + self.assertIsInstance(ops_to_not_decompose, list) + + # All call_function ops should be in the list + call_function_ops = [ + node.target + for node in exported_program.graph.nodes + if node.op == "call_function" and isinstance(node.target, torch._ops.OpOverload) + ] + + self.assertEqual( + set(ops_to_not_decompose), + set(call_function_ops), + "ops_to_not_decompose should contain all call_function ops", + ) + + +if __name__ == "__main__": + unittest.main() + From 1a22c5e02f3a4d57bb3abf6b68c1279ec4bf58e4 Mon Sep 17 00:00:00 2001 From: Manuel Candales Date: Fri, 10 Oct 2025 16:21:06 -0400 Subject: [PATCH 03/10] Update [ghstack-poisoned] --- backends/apple/metal/tests/__init__.py | 1 - backends/apple/metal/tests/test_metal_backend.py | 1 - backends/apple/metal/tests/test_metal_partitioner.py | 4 ++-- 3 files changed, 2 insertions(+), 4 deletions(-) diff --git a/backends/apple/metal/tests/__init__.py b/backends/apple/metal/tests/__init__.py index fd6404c7f7b..2e41cd717f6 100644 --- a/backends/apple/metal/tests/__init__.py +++ b/backends/apple/metal/tests/__init__.py @@ -3,4 +3,3 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
- diff --git a/backends/apple/metal/tests/test_metal_backend.py b/backends/apple/metal/tests/test_metal_backend.py index 26d2281c458..5caf7a3adc6 100644 --- a/backends/apple/metal/tests/test_metal_backend.py +++ b/backends/apple/metal/tests/test_metal_backend.py @@ -77,4 +77,3 @@ def test_compile_spec_roundtrip(self): if __name__ == "__main__": unittest.main() - diff --git a/backends/apple/metal/tests/test_metal_partitioner.py b/backends/apple/metal/tests/test_metal_partitioner.py index 97a073152f5..1b29410ab6c 100644 --- a/backends/apple/metal/tests/test_metal_partitioner.py +++ b/backends/apple/metal/tests/test_metal_partitioner.py @@ -157,7 +157,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: call_function_ops = [ node.target for node in exported_program.graph.nodes - if node.op == "call_function" and isinstance(node.target, torch._ops.OpOverload) + if node.op == "call_function" + and isinstance(node.target, torch._ops.OpOverload) ] self.assertEqual( @@ -169,4 +170,3 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: if __name__ == "__main__": unittest.main() - From d6f0bc952a57a6ec14c5118fd1fe4da91d2d3194 Mon Sep 17 00:00:00 2001 From: Manuel Candales Date: Fri, 10 Oct 2025 17:01:30 -0400 Subject: [PATCH 04/10] Update [ghstack-poisoned] --- .../metal/runtime/shims/tensor_attribute.cpp | 38 ++++++++ .../metal/runtime/shims/tensor_attribute.h | 32 +++++++ backends/apple/metal/runtime/shims/types.h | 35 +++++++ backends/apple/metal/runtime/shims/utils.cpp | 93 +++++++++++++++++++ backends/apple/metal/runtime/shims/utils.h | 74 +++++++++++++++ 5 files changed, 272 insertions(+) create mode 100644 backends/apple/metal/runtime/shims/tensor_attribute.cpp create mode 100644 backends/apple/metal/runtime/shims/tensor_attribute.h create mode 100644 backends/apple/metal/runtime/shims/types.h create mode 100644 backends/apple/metal/runtime/shims/utils.cpp create mode 100644 backends/apple/metal/runtime/shims/utils.h diff --git a/backends/apple/metal/runtime/shims/tensor_attribute.cpp b/backends/apple/metal/runtime/shims/tensor_attribute.cpp new file mode 100644 index 00000000000..684e00ffe32 --- /dev/null +++ b/backends/apple/metal/runtime/shims/tensor_attribute.cpp @@ -0,0 +1,38 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include + +namespace executorch { +namespace backends { +namespace metal { + +extern "C" { + +// Metal-specific device type constant +__attribute__((__visibility__("default"))) int32_t +aoti_torch_device_type_mps() { + // Let's use 2 for MPS + return 2; +} + +// Override aoti_torch_get_device_type to return MPS device type +AOTITorchError aoti_torch_get_device_type( + AOTITensorHandle tensor, + int32_t* ret_device_type) { + *ret_device_type = aoti_torch_device_type_mps(); + return Error::Ok; +} + +} // extern "C" + +} // namespace metal +} // namespace backends +} // namespace executorch diff --git a/backends/apple/metal/runtime/shims/tensor_attribute.h b/backends/apple/metal/runtime/shims/tensor_attribute.h new file mode 100644 index 00000000000..8d2a3dde361 --- /dev/null +++ b/backends/apple/metal/runtime/shims/tensor_attribute.h @@ -0,0 +1,32 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. 
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#pragma once
+
+#include
+#include
+
+namespace executorch {
+namespace backends {
+namespace metal {
+
+extern "C" {
+
+// Metal-specific device type function
+int32_t aoti_torch_device_type_mps();
+
+// Override aoti_torch_get_device_type to return MPS device type
+AOTITorchError aoti_torch_get_device_type(
+    AOTITensorHandle tensor,
+    int32_t* ret_device_type);
+
+} // extern "C"
+
+} // namespace metal
+} // namespace backends
+} // namespace executorch
diff --git a/backends/apple/metal/runtime/shims/types.h b/backends/apple/metal/runtime/shims/types.h
new file mode 100644
index 00000000000..07d377d7499
--- /dev/null
+++ b/backends/apple/metal/runtime/shims/types.h
@@ -0,0 +1,35 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#pragma once
+
+#include
+#include
+#include
+
+namespace executorch {
+namespace backends {
+namespace metal {
+
+// Common using declarations for ExecuTorch types
+using executorch::runtime::Error;
+using executorch::runtime::etensor::Tensor;
+
+extern "C" {
+
+// Common AOTI type aliases
+// Note: AOTITensorHandle is aliased to Tensor* for ExecuTorch compatibility
+using AOTITensorHandle = Tensor*;
+using AOTIRuntimeError = Error;
+using AOTITorchError = Error;
+
+} // extern "C"
+
+} // namespace metal
+} // namespace backends
+} // namespace executorch
diff --git a/backends/apple/metal/runtime/shims/utils.cpp b/backends/apple/metal/runtime/shims/utils.cpp
new file mode 100644
index 00000000000..484158e9027
--- /dev/null
+++ b/backends/apple/metal/runtime/shims/utils.cpp
@@ -0,0 +1,93 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include
+#include
+#include
+
+namespace executorch {
+namespace backends {
+namespace metal {
+
+extern "C" {
+
+// Helper function to check if a dtype is supported in Metal backend
+bool is_dtype_supported_in_et_metal(int32_t dtype) {
+  switch (dtype) {
+    case static_cast<int32_t>(SupportedDTypes::INT64):
+    case static_cast<int32_t>(SupportedDTypes::FLOAT32):
+    case static_cast<int32_t>(SupportedDTypes::BFLOAT16):
+      return true;
+    default:
+      return false;
+  }
+}
+
+// Metal-specific dtype validation utility function
+AOTITorchError validate_dtype(int32_t dtype) {
+  if (is_dtype_supported_in_et_metal(dtype)) {
+    return Error::Ok;
+  }
+
+  ET_LOG(
+      Error,
+      "Unsupported dtype: %d. Supported dtypes: %d (int64), %d (float32), %d (bfloat16)",
+      dtype,
+      static_cast<int32_t>(SupportedDTypes::INT64),
+      static_cast<int32_t>(SupportedDTypes::FLOAT32),
+      static_cast<int32_t>(SupportedDTypes::BFLOAT16));
+  return Error::InvalidArgument;
+}
+
+} // extern "C"
+
+// Utility function to convert sizes pointer to vector
+std::vector convert_sizes_to_vector(
+    int64_t ndim,
+    const int64_t* sizes_ptr) {
+  std::vector sizes(ndim);
+  for (int i = 0; i < ndim; i++) {
+    sizes[i] = static_cast(sizes_ptr[i]);
+  }
+  return sizes;
+}
+
+// Utility function to convert strides pointer to vector or calculate from sizes
+std::vector convert_strides_to_vector(
+    int64_t ndim,
+    const int64_t* sizes_ptr,
+    const int64_t* strides_ptr) {
+  std::vector strides(ndim);
+
+  if (strides_ptr != nullptr) {
+    // Use the provided strides. They do not need to be contiguous strides,
+    // since they are only used internally by the Metal delegate.
+    for (int64_t i = 0; i < ndim; i++) {
+      strides[i] = static_cast(strides_ptr[i]);
+    }
+  } else {
+    // Calculate strides from sizes using ExecuTorch's algorithm
+    if (ndim > 0) {
+      strides[ndim - 1] = static_cast(
+          1); // Last dimension has stride 1
+      for (int64_t i = ndim - 2; i >= 0; i--) {
+        if (sizes_ptr[i + 1] == 0) {
+          strides[i] = strides[i + 1]; // Copy stride when size is 0
+        } else {
+          strides[i] = static_cast(
+              static_cast(strides[i + 1]) * sizes_ptr[i + 1]);
+        }
+      }
+    }
+  }
+  return strides;
+}
+
+} // namespace metal
+} // namespace backends
+} // namespace executorch
diff --git a/backends/apple/metal/runtime/shims/utils.h b/backends/apple/metal/runtime/shims/utils.h
new file mode 100644
index 00000000000..5b9f9c5b3bb
--- /dev/null
+++ b/backends/apple/metal/runtime/shims/utils.h
@@ -0,0 +1,74 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */ + +#pragma once + +#include +#include +#include +#include + +namespace executorch { +namespace backends { +namespace metal { + +// Enum for supported data types in et-metal backend +enum class SupportedDTypes : int32_t { + // UINT8 = 0, // PyTorch's uint8 dtype code + // INT8 = 1, // PyTorch's int8 dtype code + // INT16 = 2, // PyTorch's int16 dtype code + // INT32 = 3, // PyTorch's int32 dtype code + INT64 = 4, // PyTorch's int64 dtype code + // FLOAT16 = 5, // PyTorch's float16 dtype code + FLOAT32 = 6, // PyTorch's float32 dtype code + // FLOAT64 = 7, // PyTorch's float64 dtype code + // BOOL = 11, // PyTorch's bool dtype code + BFLOAT16 = 15 // PyTorch's bfloat16 dtype code +}; + +extern "C" { + +// Helper function to check if a dtype is supported in Metal backend +bool is_dtype_supported_in_et_metal(int32_t dtype); + +// Metal-specific dtype validation utility function +AOTITorchError validate_dtype(int32_t dtype); + +} // extern "C" + +// Utility function to convert sizes pointer to vector +std::vector convert_sizes_to_vector( + int64_t ndim, + const int64_t* sizes_ptr); + +// Utility function to convert strides pointer to vector or calculate from sizes +std::vector convert_strides_to_vector( + int64_t ndim, + const int64_t* sizes_ptr, + const int64_t* strides_ptr); + +// Check if tensor is in contiguous memory format (NCHW for 4D tensors) +// Contiguous format means strides decrease from left to right: +// For NCHW: strides = [C*H*W, H*W, W, 1] +inline bool is_contiguous_tensor( + std::vector sizes, + std::vector strides) { + int64_t ndim = static_cast(strides.size()); + int64_t expected_stride = 1; + for (int64_t i = ndim - 1; i >= 0; i--) { + if (strides[i] != expected_stride) { + return false; + } + expected_stride *= sizes[i]; + } + return true; +} + +} // namespace metal +} // namespace backends +} // namespace executorch From 7e11615aa43033df7f5988d5c1adb6d923c78297 Mon Sep 17 00:00:00 2001 From: Manuel Candales Date: Fri, 10 Oct 2025 17:01:34 -0400 Subject: [PATCH 05/10] Update [ghstack-poisoned] --- backends/apple/metal/runtime/shims/et_metal.h | 271 ++++++ .../apple/metal/runtime/shims/et_metal.mm | 872 ++++++++++++++++++ 2 files changed, 1143 insertions(+) create mode 100644 backends/apple/metal/runtime/shims/et_metal.h create mode 100644 backends/apple/metal/runtime/shims/et_metal.mm diff --git a/backends/apple/metal/runtime/shims/et_metal.h b/backends/apple/metal/runtime/shims/et_metal.h new file mode 100644 index 00000000000..c18ad513a3a --- /dev/null +++ b/backends/apple/metal/runtime/shims/et_metal.h @@ -0,0 +1,271 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#pragma once + +#ifdef __OBJC__ +#import +#import +#include +// Forward declarations for MetalPerformanceShadersGraph types +@class MPSGraph; +@class MPSCommandBuffer; +// Metal type definitions for Objective-C compilation +typedef id MTLDevice_t; +typedef id MTLCommandQueue_t; +typedef id MTLCommandBuffer_t; +typedef id MTLComputeCommandEncoder_t; +typedef id MTLComputePipelineState_t; +typedef id MTLFunction_t; +typedef id MTLLibrary_t; +typedef id MTLBuffer_t; +typedef dispatch_queue_t dispatch_queue_t; +typedef MPSGraph* MPSGraph_t; +typedef MPSCommandBuffer* MPSCommandBuffer_t; +typedef NSDictionary* NSDictionary_t; +#else +// Forward declarations for C++ compilation +typedef void* MTLDevice_t; +typedef void* MTLCommandQueue_t; +typedef void* MTLCommandBuffer_t; +typedef void* MTLComputeCommandEncoder_t; +typedef void* MTLComputePipelineState_t; +typedef void* MTLFunction_t; +typedef void* MTLLibrary_t; +typedef void* MTLBuffer_t; +typedef void* dispatch_queue_t; +typedef void* MPSGraph_t; +typedef void* MPSCommandBuffer_t; +typedef void* NSDictionary_t; +#endif + +#include +#include +#include +#include +#include + +namespace executorch::runtime::etensor { +class Tensor; +} + +namespace executorch { +namespace backends { +namespace metal { + +// Forward declarations +class ETMetalKernelFunction; +class ETMetalStream; + +// ======================= +// SyncType - Metal synchronization options +// ======================= +enum class SyncType { + NONE, // no commit to command buffer + COMMIT, // commit and flush the command buffer + COMMIT_AND_WAIT, // flush and wait for command buffer execution to finish + COMMIT_AND_CONTINUE, // commit and continue with a new underlying command + // buffer + COMMIT_ADAPTIVE, // commit adaptively based on available memory +}; + +// ======================= +// ETMetalShaderLibrary - ExecuTorch Metal shader library management +// ======================= +class ETMetalShaderLibrary { + public: + ETMetalShaderLibrary(const std::string& source); + ~ETMetalShaderLibrary(); + + std::shared_ptr getKernelFunction( + const std::string& name); + + private: + void compileLibrary(); + std::pair getLibraryPipelineState( + const std::string& functionName); + + friend class ETMetalKernelFunction; + + std::string shaderSource_; + MTLLibrary_t library_; + std::unordered_map< + std::string, + std::pair> + pipelineStates_; +}; + +// ======================= +// ETMetalKernelFunction - ExecuTorch Metal kernel function execution +// ======================= +class ETMetalKernelFunction { + public: + ETMetalKernelFunction(MTLComputePipelineState_t cps, MTLFunction_t func); + ~ETMetalKernelFunction(); + + void startEncoding(); + void setArg(unsigned idx, const executorch::runtime::etensor::Tensor& tensor); + void setArg(unsigned idx, int64_t val); + + void dispatchSingle(uint64_t length); + void dispatchSingleWithGroupSize(uint64_t length, uint64_t group_size); + void dispatchArray(const uint64_t* length, size_t length_size); + void dispatchArrayWithGroupSize( + const uint64_t* length, + size_t length_size, + const uint64_t* group_size, + size_t group_size_size); + + void runCommandBlock(std::function f); + + private: + MTLComputePipelineState_t cps_; + MTLFunction_t func_; + MTLComputeCommandEncoder_t encoder_; +}; + +// ======================= +// ETMetalStream - Metal command buffer and synchronization management +// ======================= +class ETMetalStream { + public: + ETMetalStream(); + ~ETMetalStream(); + + // Get the default stream (singleton) + static 
ETMetalStream* getDefaultStream(); + + // Device and queue access + MTLDevice_t device() const { + return device_; + } + MTLCommandQueue_t commandQueue() const { + return commandQueue_; + } + dispatch_queue_t queue() const { + return serialQueue_; + } + + // Synchronization methods + void synchronize(SyncType syncType = SyncType::COMMIT_AND_WAIT); + void synchronize(); // Overload for backward compatibility + bool isEmpty() const; + + // Command buffer management with lazy creation + MPSCommandBuffer_t commandBuffer(); + MTLComputeCommandEncoder_t commandEncoder(); + + void endKernelCoalescing(); + + // MPSGraph execution + void executeMPSGraph( + MPSGraph_t mpsGraph, + NSDictionary_t feeds, + NSDictionary_t results, + SyncType syncType = SyncType::COMMIT_ADAPTIVE); + + // Command buffer lifecycle management + void commitCommandBuffer(MTLCommandBuffer_t commandBuffer); + void flush(); + + // Memory operations + void fill( + MTLBuffer_t buffer, + uint8_t value, + size_t length, + size_t offset, + SyncType syncType = SyncType::NONE); + void copy( + MTLBuffer_t srcBuffer, + MTLBuffer_t dstBuffer, + size_t length, + size_t srcOffset, + size_t dstOffset, + SyncType syncType = SyncType::NONE); + + private: + // Private synchronization methods + void commit(); + void commitAndWait(); + void commitAndContinue(); + + private: + // Private members + MTLDevice_t device_; + MTLCommandQueue_t commandQueue_; + MPSCommandBuffer_t commandBuffer_; + MPSCommandBuffer_t prevCommandBuffer_; // For commit-and-continue pattern + MTLComputeCommandEncoder_t commandEncoder_; + dispatch_queue_t serialQueue_; // For thread safety + + // Configuration + bool enableCommitAndContinue_; + + // Singleton instance + static ETMetalStream* defaultStream_; +}; + +// ======================= +// Global storage management functions +// ======================= +void storeFunctionHandle( + ETMetalKernelFunction* raw_function, + std::shared_ptr function_shared_ptr); +void storeLibraryHandle( + ETMetalShaderLibrary* raw_library, + std::unique_ptr library); +bool removeFunctionHandle(ETMetalKernelFunction* raw_function); +bool removeLibraryHandle(ETMetalShaderLibrary* raw_library); + +// ======================= +// Global stream access functions +// ======================= +ETMetalStream* getCurrentMetalStream(); +void setCurrentMetalStream(ETMetalStream* stream); + +// ======================= +// Metal stream synchronization functions (C++ interface with exceptions) +// ======================= +void synchronize_metal_stream(); +void synchronize_metal_stream_with_type(int sync_type); + +// ======================= +// Metal helper functions (C interface) +// ======================= +#ifdef __cplusplus +extern "C" { +#endif + +// Memory management functions for Metal +void* metal_allocate_buffer(long bytes); +bool metal_is_device_pointer(void* ptr); +int metal_copy_memory( + void* dst, + const void* src, + size_t nbytes, + bool src_is_device, + bool dst_is_device); +void metal_cleanup_resources(); + +// Helper functions to access Metal objects +MTLDevice_t get_metal_device(); +MTLCommandQueue_t get_metal_command_queue(); + +#ifdef __cplusplus +} + +// C++ only - expose the Metal buffer mapping +#ifdef __OBJC__ +extern std::unordered_map ptr_to_mtl_buffer; +#endif + +#endif + +} // namespace metal +} // namespace backends +} // namespace executorch diff --git a/backends/apple/metal/runtime/shims/et_metal.mm b/backends/apple/metal/runtime/shims/et_metal.mm new file mode 100644 index 00000000000..5afcf761d56 --- /dev/null +++ 
b/backends/apple/metal/runtime/shims/et_metal.mm @@ -0,0 +1,872 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#import +#import +#import +#import +#include +#include +#include +#include +#include +#include + +namespace executorch { +namespace backends { +namespace metal { + +// ======================= +// Exception-Safe Dispatch Function (similar to PyTorch MPS) +// ======================= + +void dispatch_sync_with_rethrow(dispatch_queue_t queue, void (^block)()) { + __block std::optional block_exception; + dispatch_sync(queue, ^() { + try { + block(); + } catch (...) { + block_exception = std::current_exception(); + } + }); + if (block_exception) { + std::rethrow_exception(*block_exception); + } +} + +// ======================= +// Global Variables and Storage +// ================ + + +// Global Metal buffer mapping - accessible for MPS shim +std::unordered_map> ptr_to_mtl_buffer; + +// Global storage to keep shared_ptr alive while raw pointers are used +static std::unordered_map> function_storage; +static std::unordered_map> library_storage; + +// Static singleton instance for default stream +ETMetalStream* ETMetalStream::defaultStream_ = nullptr; + +// Thread-local current stream +static thread_local ETMetalStream* currentStream_ = nullptr; + +// ======================= +// Metal Helper Functions (C Interface) +// ======================= + +extern "C" { + +void* metal_allocate_buffer(long bytes) { + ETMetalStream* stream = getCurrentMetalStream(); + id device = stream->device(); + if (!device) { + ET_LOG(Error, "Failed to get Metal device from stream"); + return nullptr; + } + + @autoreleasepool { + id buffer = [device newBufferWithLength:bytes options:MTLResourceStorageModeShared]; + if (!buffer) { + ET_LOG(Error, "Failed to allocate %ld bytes on Metal device", bytes); + return nullptr; + } + + void* ptr = [buffer contents]; + ptr_to_mtl_buffer[ptr] = buffer; + + ET_LOG(Debug, "Allocated %ld bytes on Metal device", bytes); + return ptr; + } +} + +void metal_cleanup_resources() { + if (!ptr_to_mtl_buffer.empty()) { + @autoreleasepool { + for (auto& pair : ptr_to_mtl_buffer) { + pair.second = nil; + } + ptr_to_mtl_buffer.clear(); + } + } +} + +bool metal_is_device_pointer(void* ptr) { + return ptr_to_mtl_buffer.find(ptr) != ptr_to_mtl_buffer.end(); +} + +int metal_copy_memory(void* dst, const void* src, size_t nbytes, bool src_is_device, bool dst_is_device) { + if (!src || !dst || nbytes == 0) { + ET_LOG(Error, "Metal copy: Invalid parameters"); + return -1; + } + + @autoreleasepool { + // Case 1: Device-to-device copy - use GPU blit encoder (most efficient) + if (src_is_device && dst_is_device) { + auto src_it = ptr_to_mtl_buffer.find(const_cast(src)); + auto dst_it = ptr_to_mtl_buffer.find(dst); + + if (src_it != ptr_to_mtl_buffer.end() && dst_it != ptr_to_mtl_buffer.end()) { + id srcBuffer = src_it->second; + id dstBuffer = dst_it->second; + + // Calculate offsets relative to buffer base + size_t srcOffset = static_cast(src) - static_cast([srcBuffer contents]); + size_t dstOffset = static_cast(dst) - static_cast([dstBuffer contents]); + + // Use Metal's blit encoder for GPU-accelerated copy + ETMetalStream* stream = getCurrentMetalStream(); + stream->copy(srcBuffer, dstBuffer, nbytes, srcOffset, dstOffset, SyncType::NONE); + + ET_LOG(Debug, "Metal device-to-device copy (GPU blit): %zu bytes", nbytes); 
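+        // Note (annotation): with SyncType::NONE the blit is only encoded on the stream
+        // here; per the SyncType contract above it is not committed yet, and is expected
+        // to run once the stream's command buffer is committed (e.g. by synchronize()).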
+ return 0; + } + + ET_LOG(Error, "Metal copy: Device pointers not found in buffer map"); + return -1; + } + + // Case 2: Host-to-device or device-to-host - use memcpy with shared memory + // Since Metal uses shared storage mode, CPU and GPU access the same memory + std::memcpy(dst, src, nbytes); + + // Synchronize only if we need to ensure GPU operations complete before CPU reads + // (device-to-host case where GPU may have written data) + if (src_is_device && !dst_is_device) { + // Ensure any pending GPU writes to source complete before CPU reads + ETMetalStream* stream = getCurrentMetalStream(); + stream->synchronize(SyncType::COMMIT_AND_WAIT); + } + + ET_LOG(Debug, "Metal memory copy (memcpy): %zu bytes, src_device=%d, dst_device=%d", + nbytes, src_is_device, dst_is_device); + } + + return 0; +} + +id get_metal_device() { + // Use stream-based device access + ETMetalStream* stream = getCurrentMetalStream(); + return stream->device(); +} + +id get_metal_command_queue() { + // Use stream-based queue access + ETMetalStream* stream = getCurrentMetalStream(); + return stream->commandQueue(); +} + +} // extern "C" + +// ======================= +// ETMetalShaderLibrary Implementation +// ======================= + +ETMetalShaderLibrary::ETMetalShaderLibrary(const std::string& source) : shaderSource_(source) { + compileLibrary(); +} + +ETMetalShaderLibrary::~ETMetalShaderLibrary() { + @autoreleasepool { + if (library_) { + [library_ release]; + library_ = nil; + } + + for (auto& pair : pipelineStates_) { + [pair.second.first release]; + [pair.second.second release]; + } + pipelineStates_.clear(); + } +} + +void ETMetalShaderLibrary::compileLibrary() { + @autoreleasepool { + id device = get_metal_device(); + if (!device) { + ET_LOG(Error, "ETMetalShaderLibrary: Failed to get Metal device"); + return; + } + + NSString* sourceString = [NSString stringWithUTF8String:shaderSource_.c_str()]; + NSError* error = nil; + + library_ = [device newLibraryWithSource:sourceString options:nil error:&error]; + if (!library_ || error) { + ET_LOG(Error, "ETMetalShaderLibrary: Failed to compile shader library: %s", + error ? [[error localizedDescription] UTF8String] : "unknown error"); + return; + } + + [library_ retain]; + ET_LOG(Debug, "ETMetalShaderLibrary: Successfully compiled shader library"); + } +} + +std::pair, id> ETMetalShaderLibrary::getLibraryPipelineState(const std::string& functionName) { + auto it = pipelineStates_.find(functionName); + if (it != pipelineStates_.end()) { + return it->second; + } + + @autoreleasepool { + if (!library_) { + ET_LOG(Error, "ETMetalShaderLibrary: Library not compiled"); + return {nil, nil}; + } + + id device = get_metal_device(); + if (!device) { + ET_LOG(Error, "ETMetalShaderLibrary: Failed to get Metal device"); + return {nil, nil}; + } + + NSString* funcName = [NSString stringWithUTF8String:functionName.c_str()]; + id function = [library_ newFunctionWithName:funcName]; + if (!function) { + ET_LOG(Error, "ETMetalShaderLibrary: Failed to get function '%s'", functionName.c_str()); + return {nil, nil}; + } + + NSError* error = nil; + id pipelineState = [device newComputePipelineStateWithFunction:function error:&error]; + if (!pipelineState || error) { + ET_LOG(Error, "ETMetalShaderLibrary: Failed to create pipeline state for '%s': %s", + functionName.c_str(), error ? 
[[error localizedDescription] UTF8String] : "unknown error"); + [function release]; + return {nil, nil}; + } + + [pipelineState retain]; + [function retain]; + pipelineStates_[functionName] = {pipelineState, function}; + + ET_LOG(Debug, "ETMetalShaderLibrary: Created pipeline state for function '%s'", functionName.c_str()); + return {pipelineState, function}; + } +} + +std::shared_ptr ETMetalShaderLibrary::getKernelFunction(const std::string& name) { + auto pipelineStatePair = getLibraryPipelineState(name); + if (!pipelineStatePair.first || !pipelineStatePair.second) { + ET_LOG(Error, "ETMetalShaderLibrary::getKernelFunction: Failed to get pipeline state for '%s'", name.c_str()); + return nullptr; + } + + return std::make_shared(pipelineStatePair.first, pipelineStatePair.second); +} + +// ======================= +// ETMetalKernelFunction Implementation +// ======================= + +ETMetalKernelFunction::ETMetalKernelFunction(id cps, id func) + : cps_(cps), func_(func), encoder_(nil) { + if (cps_) [cps_ retain]; + if (func_) [func_ retain]; +} + +ETMetalKernelFunction::~ETMetalKernelFunction() { + @autoreleasepool { + // Don't release encoder_ here - the stream owns it + // Only clean up our own references + if (cps_) { + [cps_ release]; + cps_ = nil; + } + if (func_) { + [func_ release]; + func_ = nil; + } + + encoder_ = nil; // Clear reference without releasing + } +} + +void ETMetalKernelFunction::startEncoding() { + @autoreleasepool { + // Don't retain/release the encoder - just get reference from stream + ETMetalStream* stream = getCurrentMetalStream(); + encoder_ = stream->commandEncoder(); // Use stream's managed encoder + if (!encoder_) { + ET_LOG(Error, "ETMetalKernelFunction: Failed to get encoder from stream"); + return; + } + + // Don't retain - stream owns the encoder + [encoder_ setComputePipelineState:cps_]; + + ET_LOG(Debug, "ETMetalKernelFunction: Started encoding with stream-managed encoder"); + } +} + +void ETMetalKernelFunction::setArg(unsigned idx, const executorch::runtime::etensor::Tensor& tensor) { + if (!encoder_) { + ET_LOG(Error, "ETMetalKernelFunction::setArg: No active encoder"); + return; + } + + void* data_ptr = tensor.mutable_data_ptr(); + size_t totalSize = tensor.numel() * tensor.element_size(); + + auto it = ptr_to_mtl_buffer.find(data_ptr); + if (it != ptr_to_mtl_buffer.end()) { + // Use existing Metal buffer + id mtlBuffer = it->second; + [encoder_ setBuffer:mtlBuffer offset:0 atIndex:idx]; + ET_LOG(Debug, "ETMetalKernelFunction::setArg: Set Metal buffer at index %u (size: %zu)", idx, totalSize); + } else { + // Handle CPU tensor data + if (totalSize <= 4096) { + // Use setBytes for small data (more efficient) + [encoder_ setBytes:data_ptr length:totalSize atIndex:idx]; + ET_LOG(Debug, "ETMetalKernelFunction::setArg: Set CPU tensor via setBytes at index %u (size: %zu)", idx, totalSize); + } else { + // Create temporary buffer for large data (should be rare) + @autoreleasepool { + id device = get_metal_device(); + if (device) { + id tempBuffer = [device newBufferWithBytes:data_ptr + length:totalSize + options:MTLResourceStorageModeShared]; + if (tempBuffer) { + [encoder_ setBuffer:tempBuffer offset:0 atIndex:idx]; + ET_LOG(Debug, "ETMetalKernelFunction::setArg: Set large CPU tensor via temporary buffer at index %u (size: %zu)", idx, totalSize); + } else { + ET_LOG(Error, "ETMetalKernelFunction::setArg: Failed to create temporary buffer for index %u", idx); + } + } else { + ET_LOG(Error, "ETMetalKernelFunction::setArg: No Metal device available for index 
%u", idx); + } + } + } + } +} + +void ETMetalKernelFunction::setArg(unsigned idx, int64_t val) { + if (!encoder_) { + ET_LOG(Error, "ETMetalKernelFunction::setArg: No active encoder"); + return; + } + + [encoder_ setBytes:&val length:sizeof(int64_t) atIndex:idx]; + ET_LOG(Debug, "ETMetalKernelFunction::setArg: Set int64_t value %lld at index %u", val, idx); +} + +void ETMetalKernelFunction::dispatchSingle(uint64_t length) { + if (!encoder_) { + ET_LOG(Error, "ETMetalKernelFunction::dispatchSingle: No active encoder"); + return; + } + + const auto maxThreadsPerGroup = static_cast([cps_ maxTotalThreadsPerThreadgroup]); + uint64_t actualGroupSize = std::min(maxThreadsPerGroup, length); + + auto size = MTLSizeMake(length, 1, 1); + auto threadGroupSize = MTLSizeMake(actualGroupSize, 1, 1); + + [encoder_ dispatchThreads:size threadsPerThreadgroup:threadGroupSize]; + ET_LOG(Debug, "ETMetalKernelFunction::dispatchSingle: Dispatched with length %llu, group size %llu", length, actualGroupSize); + +} + +void ETMetalKernelFunction::dispatchSingleWithGroupSize(uint64_t length, uint64_t group_size) { + if (!encoder_) { + ET_LOG(Error, "ETMetalKernelFunction::dispatchSingleWithGroupSize: No active encoder"); + return; + } + + const auto maxThreadsPerGroup = static_cast([cps_ maxTotalThreadsPerThreadgroup]); + uint64_t actualGroupSize = group_size > 0 ? std::min(group_size, maxThreadsPerGroup) : std::min(maxThreadsPerGroup, length); + + auto size = MTLSizeMake(length, 1, 1); + auto threadGroupSize = MTLSizeMake(actualGroupSize, 1, 1); + + [encoder_ dispatchThreads:size threadsPerThreadgroup:threadGroupSize]; + ET_LOG(Debug, "ETMetalKernelFunction::dispatchSingleWithGroupSize: Dispatched with length %llu, group size %llu", length, actualGroupSize); + +} + +void ETMetalKernelFunction::dispatchArray(const uint64_t* length, size_t length_size) { + if (!encoder_) { + ET_LOG(Error, "ETMetalKernelFunction::dispatchArray: No active encoder"); + return; + } + + if (!length || length_size == 0) { + ET_LOG(Error, "ETMetalKernelFunction::dispatchArray: Invalid length array"); + return; + } + + const auto maxThreadsPerGroup = static_cast([cps_ maxTotalThreadsPerThreadgroup]); + + MTLSize size, threadGroupSize; + + if (length_size == 1) { + size = MTLSizeMake(length[0], 1, 1); + uint64_t actualGroupSize = std::min(maxThreadsPerGroup, length[0]); + threadGroupSize = MTLSizeMake(actualGroupSize, 1, 1); + } else if (length_size == 2) { + size = MTLSizeMake(length[0], length[1], 1); + uint64_t groupX = std::min(static_cast(32), length[0]); + uint64_t groupY = maxThreadsPerGroup / groupX; + threadGroupSize = MTLSizeMake(groupX, groupY, 1); + } else { + size = MTLSizeMake(length[0], length[1], length_size > 2 ? 
length[2] : 1); + uint64_t groupX = std::min(static_cast(8), length[0]); + uint64_t groupY = std::min(static_cast(8), length[1]); + uint64_t groupZ = maxThreadsPerGroup / (groupX * groupY); + threadGroupSize = MTLSizeMake(groupX, groupY, groupZ); + } + + [encoder_ dispatchThreads:size threadsPerThreadgroup:threadGroupSize]; + ET_LOG(Debug, "ETMetalKernelFunction::dispatchArray: Dispatched %zuD with size [%lu, %lu, %lu], group [%lu, %lu, %lu]", + length_size, size.width, size.height, size.depth, + threadGroupSize.width, threadGroupSize.height, threadGroupSize.depth); + +} + +void ETMetalKernelFunction::dispatchArrayWithGroupSize(const uint64_t* length, size_t length_size, + const uint64_t* group_size, size_t group_size_size) { + if (!encoder_) { + ET_LOG(Error, "ETMetalKernelFunction::dispatchArrayWithGroupSize: No active encoder"); + return; + } + + if (!length || length_size == 0) { + ET_LOG(Error, "ETMetalKernelFunction::dispatchArrayWithGroupSize: Invalid length array"); + return; + } + + const auto maxThreadsPerGroup = static_cast([cps_ maxTotalThreadsPerThreadgroup]); + + MTLSize size, threadGroupSize; + + if (length_size == 1) { + size = MTLSizeMake(length[0], 1, 1); + uint64_t actualGroupSize = maxThreadsPerGroup; + if (group_size && group_size_size > 0) { + actualGroupSize = std::min(maxThreadsPerGroup, group_size[0]); + } + threadGroupSize = MTLSizeMake(actualGroupSize, 1, 1); + } else if (length_size == 2) { + size = MTLSizeMake(length[0], length[1], 1); + uint64_t groupX = std::min(static_cast(32), length[0]); + uint64_t groupY = maxThreadsPerGroup / groupX; + if (group_size && group_size_size >= 2) { + groupX = std::min(static_cast(group_size[0]), length[0]); + groupY = std::min(static_cast(group_size[1]), length[1]); + } + threadGroupSize = MTLSizeMake(groupX, groupY, 1); + } else { + size = MTLSizeMake(length[0], length[1], length_size > 2 ? length[2] : 1); + uint64_t groupX = std::min(static_cast(8), length[0]); + uint64_t groupY = std::min(static_cast(8), length[1]); + uint64_t groupZ = maxThreadsPerGroup / (groupX * groupY); + if (group_size && group_size_size >= 3) { + groupX = std::min(static_cast(group_size[0]), length[0]); + groupY = std::min(static_cast(group_size[1]), length[1]); + groupZ = std::min(static_cast(group_size[2]), length_size > 2 ? length[2] : 1); + } + threadGroupSize = MTLSizeMake(groupX, groupY, groupZ); + } + + [encoder_ dispatchThreads:size threadsPerThreadgroup:threadGroupSize]; + ET_LOG(Debug, "ETMetalKernelFunction::dispatchArrayWithGroupSize: Dispatched %zuD with size [%lu, %lu, %lu], group [%lu, %lu, %lu]", + length_size, size.width, size.height, size.depth, + threadGroupSize.width, threadGroupSize.height, threadGroupSize.depth); + +} + +void ETMetalKernelFunction::runCommandBlock(std::function f) { + // Use dispatch_sync with the stream's serial queue for thread safety and synchronization + // This matches PyTorch's approach: dispatch_sync_with_rethrow(getCurrentMPSStream()->queue(), ...) 
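+  // dispatch_sync_with_rethrow is assumed to behave like PyTorch's MPS helper of
+  // the same name: it runs the block synchronously on the stream's serial queue
+  // and, if the block throws, captures the exception and rethrows it on the
+  // calling thread. A minimal sketch of that assumed behavior:
+  //   __block std::exception_ptr eptr;
+  //   dispatch_sync(queue, ^() { try { block(); } catch (...) { eptr = std::current_exception(); } });
+  //   if (eptr) std::rethrow_exception(eptr);
+  // so errors raised inside the command block propagate to the caller of runCommandBlock.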
+ ETMetalStream* stream = getCurrentMetalStream(); + dispatch_sync_with_rethrow(stream->queue(), ^() { + @autoreleasepool { + f(); + } + }); + + ET_LOG(Debug, "ETMetalKernelFunction::runCommandBlock: Executed command block with dispatch_sync"); +} + +// ======================= +// ETMetalStream Implementation +// ======================= + +ETMetalStream::ETMetalStream() + : device_(nil), commandQueue_(nil), commandBuffer_(nil), prevCommandBuffer_(nil), + commandEncoder_(nil), serialQueue_(nullptr), enableCommitAndContinue_(true) { + @autoreleasepool { + // Create device and command queue + device_ = MTLCreateSystemDefaultDevice(); + if (!device_) { + ET_LOG(Error, "ETMetalStream: Failed to create Metal device"); + return; + } + [device_ retain]; + + commandQueue_ = [device_ newCommandQueue]; + if (!commandQueue_) { + ET_LOG(Error, "ETMetalStream: Failed to create Metal command queue"); + return; + } + [commandQueue_ retain]; + + // Create serial queue for thread safety + serialQueue_ = dispatch_queue_create("metal gpu stream", nullptr); + + ET_LOG(Debug, "ETMetalStream: Created stream with device %p, queue %p", device_, commandQueue_); + } +} + +ETMetalStream::~ETMetalStream() { + @autoreleasepool { + // Synchronize before cleanup + synchronize(SyncType::COMMIT_AND_WAIT); + + // Clean up command encoder + if (commandEncoder_) { + [commandEncoder_ release]; + commandEncoder_ = nil; + } + + // Clean up command buffers + if (commandBuffer_) { + [commandBuffer_ release]; + commandBuffer_ = nil; + } + if (prevCommandBuffer_) { + [prevCommandBuffer_ release]; + prevCommandBuffer_ = nil; + } + + // Clean up command queue and device + if (commandQueue_) { + [commandQueue_ release]; + commandQueue_ = nil; + } + if (device_) { + [device_ release]; + device_ = nil; + } + + // Clean up serial queue + if (serialQueue_) { + dispatch_release(serialQueue_); + serialQueue_ = nullptr; + } + + ET_LOG(Debug, "ETMetalStream: Destroyed stream"); + } +} + +ETMetalStream* ETMetalStream::getDefaultStream() { + if (!defaultStream_) { + defaultStream_ = new ETMetalStream(); + } + return defaultStream_; +} + +// Lazy command buffer creation (use MPSCommandBuffer like PyTorch) +MPSCommandBuffer* ETMetalStream::commandBuffer() { + if (!commandBuffer_) { + if (!commandQueue_) { + ET_LOG(Error, "ETMetalStream::commandBuffer: No command queue available"); + return nil; + } + + commandBuffer_ = [MPSCommandBuffer commandBufferFromCommandQueue:commandQueue_]; + if (!commandBuffer_) { + ET_LOG(Error, "ETMetalStream::commandBuffer: Failed to create command buffer"); + return nil; + } + [commandBuffer_ retain]; + + ET_LOG(Debug, "ETMetalStream::commandBuffer: Created lazy command buffer %p", commandBuffer_); + } + + return commandBuffer_; +} + +// Lazy command encoder creation +id ETMetalStream::commandEncoder() { + if (!commandEncoder_) { + MPSCommandBuffer* cmdBuffer = commandBuffer(); + if (!cmdBuffer) { + ET_LOG(Error, "ETMetalStream::commandEncoder: Failed to get command buffer"); + return nil; + } + + commandEncoder_ = [cmdBuffer computeCommandEncoder]; + if (!commandEncoder_) { + ET_LOG(Error, "ETMetalStream::commandEncoder: Failed to create command encoder"); + return nil; + } + [commandEncoder_ retain]; + + ET_LOG(Debug, "ETMetalStream::commandEncoder: Created lazy command encoder %p", commandEncoder_); + } + + return commandEncoder_; +} + +// Synchronization with SyncType - matches PyTorch's approach (no dispatch_sync here) +void ETMetalStream::synchronize(SyncType syncType) { + endKernelCoalescing(); + + switch 
(syncType) { + case SyncType::NONE: + // Do nothing - no commit + break; + case SyncType::COMMIT: + commit(); + break; + case SyncType::COMMIT_AND_WAIT: + commitAndWait(); + break; + case SyncType::COMMIT_AND_CONTINUE: + if (enableCommitAndContinue_) { + commitAndContinue(); + } else { + ET_LOG(Error, "ETMetalStream::synchronize: CommitAndContinue requested but disabled"); + commit(); + } + break; + case SyncType::COMMIT_ADAPTIVE: + // Simple adaptive policy - could be enhanced with memory pressure detection + // TODO: Could add memory pressure detection like PyTorch does + commit(); + break; + } + + ET_LOG(Debug, "ETMetalStream::synchronize: Completed with SyncType %d", static_cast(syncType)); +} + +// Encoder coalescing management +void ETMetalStream::endKernelCoalescing() { + if (commandEncoder_) { + [commandEncoder_ endEncoding]; + [commandEncoder_ release]; + commandEncoder_ = nil; + ET_LOG(Debug, "ETMetalStream::endKernelCoalescing: Ended encoder coalescing"); + } +} + +// Commit methods +void ETMetalStream::commit() { + if (enableCommitAndContinue_ && commandBuffer_) { + // Use commit-and-continue for better performance + commitAndContinue(); + } else { + flush(); + } +} + +void ETMetalStream::commitAndWait() { + // Handle previous command buffer first + if (prevCommandBuffer_) { + [prevCommandBuffer_ waitUntilCompleted]; + [prevCommandBuffer_ release]; + prevCommandBuffer_ = nil; + } + + // Handle current command buffer + if (commandBuffer_) { + [commandBuffer_ commit]; + [commandBuffer_ waitUntilCompleted]; + [commandBuffer_ release]; + commandBuffer_ = nil; + } + + ET_LOG(Debug, "ETMetalStream::commitAndWait: Committed and waited for completion"); +} + +void ETMetalStream::commitAndContinue() { + if (!commandBuffer_) { + ET_LOG(Error, "ETMetalStream::commitAndContinue: No command buffer to commit"); + return; + } + + // Commit buffer and allow immediate reuse for better performance + [commandBuffer_ commit]; + ET_LOG(Debug, "ETMetalStream::commitAndContinue: Committed buffer %p with continue", commandBuffer_); + + // The buffer handles synchronization internally for commit-and-continue +} + +void ETMetalStream::flush() { + if (commandBuffer_) { + [commandBuffer_ commit]; + + if (!enableCommitAndContinue_) { + // Keep the command buffer for later waiting if commit-and-continue is disabled + prevCommandBuffer_ = commandBuffer_; + } else { + [commandBuffer_ release]; + } + commandBuffer_ = nil; + + ET_LOG(Debug, "ETMetalStream::flush: Flushed command buffer"); + } +} + +// Memory operations +void ETMetalStream::fill(id buffer, uint8_t value, size_t length, size_t offset, SyncType syncType) { + if (length == 0) { + return; + } + + dispatch_sync(serialQueue_, ^{ + @autoreleasepool { + endKernelCoalescing(); + id blitEncoder = [commandBuffer() blitCommandEncoder]; + + [blitEncoder fillBuffer:buffer range:NSMakeRange(offset, length) value:value]; + [blitEncoder endEncoding]; + synchronize(syncType); + + ET_LOG(Debug, "ETMetalStream::fill: Filled buffer with value %u, length %zu, offset %zu", value, length, offset); + } + }); +} + +void ETMetalStream::copy(id srcBuffer, id dstBuffer, size_t length, + size_t srcOffset, size_t dstOffset, SyncType syncType) { + dispatch_sync(serialQueue_, ^{ + @autoreleasepool { + endKernelCoalescing(); + id blitEncoder = [commandBuffer() blitCommandEncoder]; + + // Handle large copies in chunks + constexpr size_t max_copy_size = 0x80000000; // 2GB + size_t bytes_copied = 0; + size_t bytes_remaining = length; + + while (bytes_remaining > 0) { + NSUInteger 
bytes_to_copy = std::min(max_copy_size, bytes_remaining); + [blitEncoder copyFromBuffer:srcBuffer + sourceOffset:(NSUInteger)srcOffset + bytes_copied + toBuffer:dstBuffer + destinationOffset:(NSUInteger)dstOffset + bytes_copied + size:bytes_to_copy]; + bytes_copied += bytes_to_copy; + bytes_remaining -= bytes_to_copy; + } + + [blitEncoder endEncoding]; + synchronize(syncType); + + ET_LOG(Debug, "ETMetalStream::copy: Copied %zu bytes from offset %zu to offset %zu", length, srcOffset, dstOffset); + } + }); +} + + +void ETMetalStream::synchronize() { + synchronize(SyncType::COMMIT_AND_WAIT); +} + +bool ETMetalStream::isEmpty() const { + return !commandBuffer_ && !commandEncoder_; +} + +void ETMetalStream::executeMPSGraph(MPSGraph* mpsGraph, NSDictionary* feeds, NSDictionary* results, SyncType syncType) { + // Use dispatch_sync_with_rethrow exactly like PyTorch does for MPSGraph execution + dispatch_sync_with_rethrow(serialQueue_, ^() { + @autoreleasepool { + endKernelCoalescing(); + + [mpsGraph encodeToCommandBuffer:commandBuffer() + feeds:feeds + targetOperations:nil + resultsDictionary:results + executionDescriptor:nil]; + + //synchronize(syncType); + } + }); +} + +// ======================= +// Global Storage Management Functions +// ======================= + +void storeFunctionHandle(ETMetalKernelFunction* raw_function, std::shared_ptr function_shared_ptr) { + function_storage[raw_function] = function_shared_ptr; +} + +void storeLibraryHandle(ETMetalShaderLibrary* raw_library, std::unique_ptr library) { + library_storage[raw_library] = std::move(library); +} + +bool removeFunctionHandle(ETMetalKernelFunction* raw_function) { + auto it = function_storage.find(raw_function); + if (it != function_storage.end()) { + function_storage.erase(it); + return true; + } + return false; +} + +bool removeLibraryHandle(ETMetalShaderLibrary* raw_library) { + auto it = library_storage.find(raw_library); + if (it != library_storage.end()) { + library_storage.erase(it); + return true; + } + return false; +} + +// ======================= +// Global Stream Access Functions +// ======================= + +ETMetalStream* getCurrentMetalStream() { + if (!currentStream_) { + currentStream_ = ETMetalStream::getDefaultStream(); + } + return currentStream_; +} + +void setCurrentMetalStream(ETMetalStream* stream) { + currentStream_ = stream; +} + +// ======================= +// Metal Stream Synchronization Functions +// ======================= + +void synchronize_metal_stream() { + @autoreleasepool { + // Use the ETMetalStream for proper synchronization + ETMetalStream* stream = getCurrentMetalStream(); + stream->synchronize(SyncType::COMMIT_AND_WAIT); + + ET_LOG(Debug, "synchronize_metal_stream: Stream synchronized with COMMIT_AND_WAIT"); + } +} + +void synchronize_metal_stream_with_type(int sync_type) { + @autoreleasepool { + ETMetalStream* stream = getCurrentMetalStream(); + SyncType syncTypeEnum = static_cast(sync_type); + stream->synchronize(syncTypeEnum); + + ET_LOG(Debug, "synchronize_metal_stream_with_type: Stream synchronized with SyncType %d", sync_type); + } +} + +} // namespace metal +} // namespace backends +} // namespace executorch From ca5f1e52300560ba3dab33ed247afdb0ea36a30a Mon Sep 17 00:00:00 2001 From: Manuel Candales Date: Sat, 11 Oct 2025 15:47:38 -0400 Subject: [PATCH 06/10] Update [ghstack-poisoned] --- backends/apple/metal/runtime/shims/utils.cpp | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/backends/apple/metal/runtime/shims/utils.cpp 
b/backends/apple/metal/runtime/shims/utils.cpp index 484158e9027..bc8c0483e9d 100644 --- a/backends/apple/metal/runtime/shims/utils.cpp +++ b/backends/apple/metal/runtime/shims/utils.cpp @@ -65,13 +65,12 @@ std::vector convert_strides_to_vector( std::vector strides(ndim); if (strides_ptr != nullptr) { - // Use provided strides. it is ok if provided strides here is not contiguous - // strides since it will be used internally in CUDA delegate. + // Use provided strides. for (int64_t i = 0; i < ndim; i++) { strides[i] = static_cast(strides_ptr[i]); } } else { - // Calculate strides from sizes using ExecutorTorch's algorithm + // Calculate strides from sizes. if (ndim > 0) { strides[ndim - 1] = static_cast( 1); // Last dimension has stride 1 From f46adc5149a168fec69d18bfebe7fcd39be37ce5 Mon Sep 17 00:00:00 2001 From: Manuel Candales Date: Tue, 14 Oct 2025 21:36:07 -0400 Subject: [PATCH 07/10] Update [ghstack-poisoned] --- backends/cuda/runtime/shims/memory.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/backends/cuda/runtime/shims/memory.cpp b/backends/cuda/runtime/shims/memory.cpp index 6fe315ba8ee..fe8ccf07281 100644 --- a/backends/cuda/runtime/shims/memory.cpp +++ b/backends/cuda/runtime/shims/memory.cpp @@ -27,6 +27,8 @@ using executorch::backends::aoti::aoti_torch_get_device_index; using executorch::backends::aoti::aoti_torch_get_dtype; using executorch::backends::aoti::aoti_torch_get_sizes; using executorch::backends::aoti::aoti_torch_get_strides; +using executorch::backends::aoti::convert_sizes_to_vector; +using executorch::backends::aoti::convert_strides_to_vector; using executorch::backends::aoti::dtype_to_element_size; using executorch::backends::aoti::dtype_to_scalar_type; using executorch::backends::aoti::validate_storage_offset; From 71f87b691d65e222957b714e40e7d2e657075cb8 Mon Sep 17 00:00:00 2001 From: Manuel Candales Date: Wed, 15 Oct 2025 15:32:04 -0400 Subject: [PATCH 08/10] Update [ghstack-poisoned] --- backends/apple/metal/runtime/shims/et_metal.h | 107 ++++++++++++++++++ .../apple/metal/runtime/shims/et_metal.mm | 22 +++- 2 files changed, 127 insertions(+), 2 deletions(-) diff --git a/backends/apple/metal/runtime/shims/et_metal.h b/backends/apple/metal/runtime/shims/et_metal.h index c18ad513a3a..a1c8c684131 100644 --- a/backends/apple/metal/runtime/shims/et_metal.h +++ b/backends/apple/metal/runtime/shims/et_metal.h @@ -77,6 +77,35 @@ enum class SyncType { // ======================= // ETMetalShaderLibrary - ExecuTorch Metal shader library management // ======================= + +/** + * @class ETMetalShaderLibrary + * @brief Manages Metal shader library compilation and kernel function retrieval. + * + * This class provides a high-level interface for compiling Metal shading language + * source code into a Metal library and creating compute pipeline states for + * kernel functions. It handles the creation and caching of Metal compute pipeline + * states and functions, which should be reused across multiple kernel dispatches. + * + * The class automatically compiles the provided shader source code upon construction + * and maintains an internal cache of compute pipeline states for different kernel + * functions to avoid redundant compilation. 
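+ *
+ * If pipeline state creation fails for the requested kernel (or the kernel is
+ * not found in the compiled library), getKernelFunction() logs the error and
+ * returns nullptr, so callers should check the returned pointer before
+ * encoding work with it.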
+ * + * Example usage: + * @code + * std::string shaderSource = R"( + * #include + * using namespace metal; + * kernel void my_kernel(device float* data [[buffer(0)]], + * uint tid [[thread_position_in_grid]]) { + * data[tid] = data[tid] * 2.0; + * } + * )"; + * + * ETMetalShaderLibrary library(shaderSource); + * auto kernelFunction = library.getKernelFunction("my_kernel"); + * @endcode + */ class ETMetalShaderLibrary { public: ETMetalShaderLibrary(const std::string& source); @@ -103,6 +132,45 @@ class ETMetalShaderLibrary { // ======================= // ETMetalKernelFunction - ExecuTorch Metal kernel function execution // ======================= + +/** + * @class ETMetalKernelFunction + * @brief Represents a Metal compute kernel function ready for execution. + * + * This class encapsulates a Metal compute pipeline state and function, providing + * a high-level interface for setting kernel arguments and dispatching compute + * work to the GPU. It handles the encoding of compute commands and manages the + * interaction with Metal's compute command encoder. + * + * The class supports different dispatch patterns: + * - Single-dimension dispatch for linear workloads + * - Multi-dimensional dispatch for grid-based workloads + * - Custom thread group sizes for performance optimization + * + * Kernel arguments can be set using tensors (which will be mapped to Metal buffers) + * or scalar values. The class handles the encoding of these arguments + * into the compute command encoder. + * + * Example usage: + * @code + * // Get kernel function from library + * auto kernelFunction = library.getKernelFunction("vector_add"); + * + * // Start encoding commands + * kernelFunction->startEncoding(); + * + * // Set tensor arguments + * kernelFunction->setArg(0, inputTensorA); + * kernelFunction->setArg(1, inputTensorB); + * kernelFunction->setArg(2, outputTensor); + * + * // Set scalar argument + * kernelFunction->setArg(3, static_cast(numElements)); + * + * // Dispatch for linear workload + * kernelFunction->dispatchSingle(numElements); + * @endcode + */ class ETMetalKernelFunction { public: ETMetalKernelFunction(MTLComputePipelineState_t cps, MTLFunction_t func); @@ -132,6 +200,45 @@ class ETMetalKernelFunction { // ======================= // ETMetalStream - Metal command buffer and synchronization management // ======================= + +/** + * @class ETMetalStream + * @brief Manages Metal compute command streams and provides GPU synchronization. + * + * This class serves as the central management hub for Metal GPU operations, providing + * a stream-based abstraction similar to CUDA streams. It handles command buffer lifecycle, + * compute command encoder management, and various synchronization patterns required for + * efficient GPU computation. + * + * Key features: + * - Lazy command buffer and encoder creation for optimal resource usage + * - Thread-safe operations using serial dispatch queues + * - Multiple synchronization modes (COMMIT, COMMIT_AND_WAIT, COMMIT_AND_CONTINUE) + * - Kernel coalescing to batch multiple operations efficiently + * - MPSGraph integration for high-level neural network operations + * - Memory operations (copy, fill) with GPU acceleration via blit encoders + * + * The stream follows PyTorch's MPS stream design patterns, providing similar semantics + * for command buffer management and synchronization. 
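+ *
+ * synchronize(SyncType) maps the requested mode onto the underlying command
+ * buffer: NONE performs no commit, COMMIT commits the current buffer,
+ * COMMIT_AND_WAIT commits and blocks until the buffer completes,
+ * COMMIT_AND_CONTINUE commits while allowing immediate reuse (falling back to
+ * a plain commit when commit-and-continue is disabled), and COMMIT_ADAPTIVE
+ * currently behaves like COMMIT.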
+ * + * Example usage: + * @code + * // Get current stream (typically the default stream) + * ETMetalStream* stream = getCurrentMetalStream(); + * + * // Execute kernel operations (handled automatically) + * auto kernelFunction = library.getKernelFunction("my_kernel"); + * kernelFunction->startEncoding(); + * kernelFunction->setArg(0, inputTensor); + * kernelFunction->dispatchSingle(numElements); + * + * // Synchronize to ensure completion + * stream->synchronize(SyncType::COMMIT_AND_WAIT); + * + * // Copy between GPU buffers using blit encoder + * stream->copy(srcBuffer, dstBuffer, numBytes, 0, 0, SyncType::COMMIT); + * @endcode + */ class ETMetalStream { public: ETMetalStream(); diff --git a/backends/apple/metal/runtime/shims/et_metal.mm b/backends/apple/metal/runtime/shims/et_metal.mm index 5afcf761d56..f76146ab783 100644 --- a/backends/apple/metal/runtime/shims/et_metal.mm +++ b/backends/apple/metal/runtime/shims/et_metal.mm @@ -743,6 +743,26 @@ int metal_copy_memory(void* dst, const void* src, size_t nbytes, bool src_is_dev void ETMetalStream::copy(id srcBuffer, id dstBuffer, size_t length, size_t srcOffset, size_t dstOffset, SyncType syncType) { + + if (length == 0) { + return; + + // Check that offsets are within buffer bounds before copying + if (!srcBuffer || !dstBuffer) { + ET_LOG(Error, "ETMetalStream::copy: Source or destination buffer is nil"); + return; + } + NSUInteger srcBufferLength = [srcBuffer length]; + NSUInteger dstBufferLength = [dstBuffer length]; + if (srcOffset + length > srcBufferLength) { + ET_LOG(Error, "ETMetalStream::copy: Source offset (%zu) + length (%zu) exceeds source buffer size (%zu)", srcOffset, length, srcBufferLength); + return; + } + if (dstOffset + length > dstBufferLength) { + ET_LOG(Error, "ETMetalStream::copy: Destination offset (%zu) + length (%zu) exceeds destination buffer size (%zu)", dstOffset, length, dstBufferLength); + return; + } + dispatch_sync(serialQueue_, ^{ @autoreleasepool { endKernelCoalescing(); @@ -792,8 +812,6 @@ int metal_copy_memory(void* dst, const void* src, size_t nbytes, bool src_is_dev targetOperations:nil resultsDictionary:results executionDescriptor:nil]; - - //synchronize(syncType); } }); } From 95a70247fded8a432c69a08d64e9d891a2c8a2f4 Mon Sep 17 00:00:00 2001 From: Manuel Candales Date: Wed, 15 Oct 2025 15:40:46 -0400 Subject: [PATCH 09/10] Update [ghstack-poisoned] --- backends/apple/metal/runtime/shims/et_metal.h | 50 ++++++++++--------- 1 file changed, 27 insertions(+), 23 deletions(-) diff --git a/backends/apple/metal/runtime/shims/et_metal.h b/backends/apple/metal/runtime/shims/et_metal.h index a1c8c684131..75f79e5139c 100644 --- a/backends/apple/metal/runtime/shims/et_metal.h +++ b/backends/apple/metal/runtime/shims/et_metal.h @@ -80,16 +80,18 @@ enum class SyncType { /** * @class ETMetalShaderLibrary - * @brief Manages Metal shader library compilation and kernel function retrieval. + * @brief Manages Metal shader library compilation and kernel function + * retrieval. * - * This class provides a high-level interface for compiling Metal shading language - * source code into a Metal library and creating compute pipeline states for - * kernel functions. It handles the creation and caching of Metal compute pipeline - * states and functions, which should be reused across multiple kernel dispatches. + * This class provides a high-level interface for compiling Metal shading + * language source code into a Metal library and creating compute pipeline + * states for kernel functions. 
It handles the creation and caching of Metal + * compute pipeline states and functions, which should be reused across multiple + * kernel dispatches. * - * The class automatically compiles the provided shader source code upon construction - * and maintains an internal cache of compute pipeline states for different kernel - * functions to avoid redundant compilation. + * The class automatically compiles the provided shader source code upon + * construction and maintains an internal cache of compute pipeline states for + * different kernel functions to avoid redundant compilation. * * Example usage: * @code @@ -137,18 +139,18 @@ class ETMetalShaderLibrary { * @class ETMetalKernelFunction * @brief Represents a Metal compute kernel function ready for execution. * - * This class encapsulates a Metal compute pipeline state and function, providing - * a high-level interface for setting kernel arguments and dispatching compute - * work to the GPU. It handles the encoding of compute commands and manages the - * interaction with Metal's compute command encoder. + * This class encapsulates a Metal compute pipeline state and function, + * providing a high-level interface for setting kernel arguments and dispatching + * compute work to the GPU. It handles the encoding of compute commands and + * manages the interaction with Metal's compute command encoder. * * The class supports different dispatch patterns: * - Single-dimension dispatch for linear workloads * - Multi-dimensional dispatch for grid-based workloads * - Custom thread group sizes for performance optimization * - * Kernel arguments can be set using tensors (which will be mapped to Metal buffers) - * or scalar values. The class handles the encoding of these arguments + * Kernel arguments can be set using tensors (which will be mapped to Metal + * buffers) or scalar values. The class handles the encoding of these arguments * into the compute command encoder. * * Example usage: @@ -203,23 +205,25 @@ class ETMetalKernelFunction { /** * @class ETMetalStream - * @brief Manages Metal compute command streams and provides GPU synchronization. + * @brief Manages Metal compute command streams and provides GPU + * synchronization. * - * This class serves as the central management hub for Metal GPU operations, providing - * a stream-based abstraction similar to CUDA streams. It handles command buffer lifecycle, - * compute command encoder management, and various synchronization patterns required for - * efficient GPU computation. + * This class serves as the central management hub for Metal GPU operations, + * providing a stream-based abstraction similar to CUDA streams. It handles + * command buffer lifecycle, compute command encoder management, and various + * synchronization patterns required for efficient GPU computation. * * Key features: * - Lazy command buffer and encoder creation for optimal resource usage * - Thread-safe operations using serial dispatch queues - * - Multiple synchronization modes (COMMIT, COMMIT_AND_WAIT, COMMIT_AND_CONTINUE) + * - Multiple synchronization modes (COMMIT, COMMIT_AND_WAIT, + * COMMIT_AND_CONTINUE, etc.) 
* - Kernel coalescing to batch multiple operations efficiently - * - MPSGraph integration for high-level neural network operations + * - MPSGraph integration for executing fall back operations (mm, conv, sdpa) * - Memory operations (copy, fill) with GPU acceleration via blit encoders * - * The stream follows PyTorch's MPS stream design patterns, providing similar semantics - * for command buffer management and synchronization. + * The stream follows PyTorch's MPS stream design patterns, providing similar + * semantics for command buffer management and synchronization. * * Example usage: * @code From d37e7efb19229e57e5dd314c5c06345bb68f4c59 Mon Sep 17 00:00:00 2001 From: Manuel Candales Date: Wed, 15 Oct 2025 15:51:31 -0400 Subject: [PATCH 10/10] Update [ghstack-poisoned] --- backends/apple/metal/runtime/shims/et_metal.mm | 1 + 1 file changed, 1 insertion(+) diff --git a/backends/apple/metal/runtime/shims/et_metal.mm b/backends/apple/metal/runtime/shims/et_metal.mm index f76146ab783..fdca0a28cf3 100644 --- a/backends/apple/metal/runtime/shims/et_metal.mm +++ b/backends/apple/metal/runtime/shims/et_metal.mm @@ -746,6 +746,7 @@ int metal_copy_memory(void* dst, const void* src, size_t nbytes, bool src_is_dev if (length == 0) { return; + } // Check that offsets are within buffer bounds before copying if (!srcBuffer || !dstBuffer) {