Skip to content
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 58 additions & 0 deletions backends/aoti/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,64 @@ inline bool is_tensor_contiguous(

} // extern "C"

// Utility function to convert sizes pointer to vector
inline std::vector<executorch::aten::SizesType> convert_sizes_to_vector(
int64_t ndim,
const int64_t* sizes_ptr) {
std::vector<executorch::aten::SizesType> sizes(ndim);
for (int i = 0; i < ndim; i++) {
sizes[i] = static_cast<executorch::aten::SizesType>(sizes_ptr[i]);
}
return sizes;
}

// Utility function to convert strides pointer to vector or calculate from sizes
inline std::vector<executorch::aten::StridesType> convert_strides_to_vector(
int64_t ndim,
const int64_t* sizes_ptr,
const int64_t* strides_ptr) {
std::vector<executorch::aten::StridesType> strides(ndim);

if (strides_ptr != nullptr) {
// Use provided strides.
for (int64_t i = 0; i < ndim; i++) {
strides[i] = static_cast<executorch::aten::StridesType>(strides_ptr[i]);
}
} else {
// Calculate strides from sizes.
if (ndim > 0) {
strides[ndim - 1] = static_cast<executorch::aten::StridesType>(
1); // Last dimension has stride 1
for (int64_t i = ndim - 2; i >= 0; i--) {
if (sizes_ptr[i + 1] == 0) {
strides[i] = strides[i + 1]; // Copy stride when size is 0
} else {
strides[i] = static_cast<executorch::aten::StridesType>(
static_cast<int64_t>(strides[i + 1]) * sizes_ptr[i + 1]);
}
}
}
}
return strides;
}

// Check if tensor is in contiguous memory format (NCHW for 4D tensors)
// Contiguous format means strides decrease from left to right:
// For NCHW: strides = [C*H*W, H*W, W, 1]
inline bool is_contiguous_tensor(
std::vector<executorch::aten::SizesType>& sizes,
std::vector<executorch::aten::StridesType>& strides) {
int64_t ndim = static_cast<int64_t>(strides.size());
int64_t expected_stride = 1;
for (int64_t i = ndim - 1; i >= 0; i--) {
if (strides[i] != expected_stride) {
return false;
}
expected_stride *= sizes[i];
}
return true;
}

} // namespace aoti
} // namespace backends
} // namespace executorch
5 changes: 5 additions & 0 deletions backends/apple/metal/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# Metal Backend

⚠️ **EXPERIMENTAL BACKEND**

This backend is currently in experimental development and may not be fully functional or stable. Use with caution.
172 changes: 172 additions & 0 deletions backends/apple/metal/metal_backend.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import contextlib
import os
import typing
from enum import Enum

from typing import Any, Dict, final, List, Optional, Set

import torch
from executorch.backends.apple.metal.replace_slice_copy_with_slice import (
ReplaceSliceCopyWithSlicePass,
)
from executorch.exir._serialize._named_data_store import NamedDataStore
from executorch.exir._warnings import experimental
from executorch.exir.backend.backend_details import (
BackendDetails,
ExportedProgram,
PreprocessResult,
)
from executorch.exir.backend.compile_spec_schema import CompileSpec
from torch._inductor.codegen.cpp_wrapper_cpu import CppWrapperCpu
from torch.export.passes import move_to_device_pass


# exist fallback operators in et namespace;
supported_fallback_kernels: Dict[str, Any] = {
"aoti_torch_mps_addmm_out": None,
"aoti_torch_mps_convolution": None,
"aoti_torch_mps_mm_out": None,
"at::_ops::_scaled_dot_product_attention_math_for_mps::call": None,
}

# required fallback kernels but not supported
missing_fallback_kernels: Set[str] = set()


class COMPILE_SPEC_KEYS(Enum):
METHOD_NAME = "method_name"


# context manager for non-fallback guarantee
# it will raise exception when generating fallback kernels during aoti compile
@contextlib.contextmanager
def collect_unsupported_fallback_kernels():
original_generate_c_shim_extern_kernel_call = (
CppWrapperCpu.generate_c_shim_extern_kernel_call
)

def generate_c_shim_extern_kernel_call_and_collect_unsupported_kernels(
self,
kernel: str,
args: list[str],
device: str,
*,
debug_args: Optional[list[str]] = None,
debug_handle: Optional[int] = None,
):
if kernel not in supported_fallback_kernels:
missing_fallback_kernels.add(kernel)

original_generate_c_shim_extern_kernel_call(
self, kernel, args, device, debug_args=debug_args, debug_handle=debug_handle
)

CppWrapperCpu.generate_c_shim_extern_kernel_call = (
generate_c_shim_extern_kernel_call_and_collect_unsupported_kernels
)
try:
yield
finally:
CppWrapperCpu.generate_c_shim_extern_kernel_call = (
original_generate_c_shim_extern_kernel_call
)


@final
@experimental(
"This API and all of Metal backend related functionality are experimental."
)
class MetalBackend(BackendDetails):
@staticmethod
def preprocess(
edge_program: ExportedProgram,
compile_specs: List[CompileSpec],
) -> PreprocessResult:
print("entering the lowerable parts in MetalBackend.preprocess....")
# Move the edge_program from CPU to MPS for aoti compile
mps_edge_program = move_to_device_pass(edge_program, "mps")

# replace slice_copy with slice
ReplaceSliceCopyWithSlicePass()(mps_edge_program.graph_module)

edge_program_module = mps_edge_program.module()

# Grab all input placeholders from the graph
user_input_names = mps_edge_program.graph_signature.user_inputs
user_input_placeholders = []
for node in mps_edge_program.graph.nodes:
if node.op == "placeholder" and node.name in user_input_names:
user_input_placeholders.append(node.meta["val"])

# Base options for all devices
options: dict[str, typing.Any] = {
# Do not link against the full PyTorch/libtorch library
"aot_inductor.link_libtorch": False,
# Package model constants and other generated files directly in the shared object (.so) file
"aot_inductor.package_constants_in_so": True,
# Enable maximum automatic tuning for optimal performance
"max_autotune": True,
# "aot_inductor.debug_compile": True,
# "aot_inductor.force_mmap_weights": False,
}

with collect_unsupported_fallback_kernels():
so_path = torch._inductor.aot_compile(edge_program_module, tuple(user_input_placeholders), options=options) # type: ignore[arg-type]
if len(missing_fallback_kernels) > 0:
formatted_kernels = "\n - ".join(sorted(missing_fallback_kernels))
raise RuntimeError(
f"Missing fallback kernels ({len(missing_fallback_kernels)} total):\n - {formatted_kernels}\n"
"Please add them to the AOTI backend."
)

# pyre-ignorep[6]: Incompatible parameter type
with open(so_path, "rb") as f:
so_data = f.read()

named_data_store = NamedDataStore()
method_name = MetalBackend.method_name_from_compile_specs(compile_specs)
named_data_store.add_named_data(
method_name + "_so_blob", so_data, 1, "aoti_metal_blob"
)

# Clean up the generated so file; it has been packaged into the NamdeDataStore
# pyre-ignorep[6]: Incompatible parameter type
os.remove(so_path)

return PreprocessResult(
processed_bytes=b"",
debug_handle_map={},
data_store_output=named_data_store.get_named_data_store_output(),
)

@staticmethod
def generate_method_name_compile_spec(
method_name: str,
) -> CompileSpec:
"""
Generates a CompileSpec for the given method name.
"""
return CompileSpec(
COMPILE_SPEC_KEYS.METHOD_NAME.value,
method_name.encode("utf-8"),
)

@staticmethod
def method_name_from_compile_specs(
compile_specs: List[CompileSpec],
) -> str:
"""
Returns the method name from the compile specs.
"""
for spec in compile_specs:
if spec.key == COMPILE_SPEC_KEYS.METHOD_NAME.value:
return spec.value.decode("utf-8")
raise RuntimeError(
f"Could not find method name in compile specs: {compile_specs}"
)
77 changes: 77 additions & 0 deletions backends/apple/metal/metal_partitioner.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Callable, Dict, final, List, Optional, Tuple

import torch
from executorch.backends.apple.metal.metal_backend import MetalBackend # usort: skip
from executorch.exir._warnings import experimental
from executorch.exir.backend.compile_spec_schema import CompileSpec
from executorch.exir.backend.partitioner import (
DelegationSpec,
Partitioner,
PartitionResult,
)
from executorch.exir.backend.utils import tag_constant_data, tag_mutated_buffer
from torch.export.exported_program import ExportedProgram


@final
@experimental(
"This API and all of Metal backend related functionality are experimental."
)
class MetalPartitioner(Partitioner):
"""
Metal partitioner for AOTInductor backend integration.

This partitioner creates a single partition containing all operators from the input graph.
It skips core ATen decomposition, allowing the Metal backend to handle decomposition using
AOTInductor's MPS-specific decomposition table.

Only operators that cannot be handled by the aoti-mps library will be excluded from
the partition and fall back to ExecuTorch's default or custom handling.
"""

def __init__(self, compile_spec: List[CompileSpec]) -> None:
self.delegation_spec = DelegationSpec(MetalBackend.__name__, compile_spec)

def partition(self, exported_program: ExportedProgram) -> PartitionResult:
"""
Fully delegate the graph to AOTInductor by tagging all nodes as a single partition.
"""

partition_tags: Dict[str, DelegationSpec] = {}
tag = "tag0"

for node in exported_program.graph.nodes:
if node.op != "call_function":
continue
node.meta["delegation_tag"] = tag

partition_tags[tag] = self.delegation_spec

tag_constant_data(exported_program)
tag_mutated_buffer(exported_program)

return PartitionResult(
tagged_exported_program=exported_program, partition_tags=partition_tags
)

def ops_to_not_decompose(
self, ep: ExportedProgram
) -> Tuple[List[torch._ops.OpOverload], Optional[Callable[[torch.fx.Node], bool]]]:
"""
Return a list of operations that should not be decomposed and let the AOT compiler handle them.
Currently we skip ATen decompositon for all ops, and let the Metal backend handle them.
"""
do_not_decompose = set()

for node in ep.graph.nodes:
if node.op == "call_function" and isinstance(
node.target, torch._ops.OpOverload
):
do_not_decompose.add(node.target)
return list(do_not_decompose), None
Loading
Loading