193 changes: 155 additions & 38 deletions backends/vulkan/op_registry.py
@@ -8,44 +8,69 @@

import operator

-from typing import Callable, Dict, List, Optional, Union
+from typing import Callable, Dict, Optional, Set, Union

import executorch.backends.vulkan.custom_ops_lib # noqa

import torch

-from executorch.backends.vulkan.serialization.vulkan_graph_schema import VkMemoryLayout
+from executorch.backends.vulkan.serialization.vulkan_graph_schema import (
+    VkMemoryLayout,
+    VkStorageType,
+)

+from executorch.backends.vulkan.utils import (
+    all_memory_layouts,
+    all_packed_dims,
+    PackedDim,
+)
from executorch.exir.dialects._ops import ops as exir_ops

from executorch.exir.dialects.edge._ops import EdgeOpOverload
from torch._subclasses.fake_tensor import FakeTensor

######################
## OpFeatures class ##
######################


def allow_node(node: torch.fx.Node) -> bool:
return True


class TextureImplFeatures:
__slots__ = [
-        # Indicates if the compute shader is agnostic to the packed dimension
-        "uses_packed_dim",
-        # Indicates if the compute shader is agnostic to the texture axis mapping
+        "valid_packed_dims",
        "uses_axis_map",
-        # Specifies a specific set of memory layouts that the shader supports. If it is
-        # and empty list, then the supported memory layouts can be inferred from the
-        # `uses_packed_dim` and `uses_axis_map` flags.
-        "supported_layouts",
    ]

def __init__(
self,
-        uses_packed_dim: bool = False,
        uses_axis_map: bool = False,
-        supported_layouts: Optional[List[VkMemoryLayout]] = None,
+        valid_packed_dims: Optional[Set[PackedDim]] = None,
    ):
-        self.uses_packed_dim: bool = uses_packed_dim
        self.uses_axis_map: bool = uses_axis_map
-        self.supported_layouts: Optional[List[VkMemoryLayout]] = supported_layouts
+        self.valid_packed_dims = set()
+        if valid_packed_dims is not None:
+            self.valid_packed_dims = valid_packed_dims

+    def valid_memory_layouts(self) -> Set[VkMemoryLayout]:
+        """
+        Derive the set of memory layouts supported by the texture implementation based
+        on the valid packed dimensions.
+        """
+        layouts = set()
+
+        if PackedDim.WIDTH in self.valid_packed_dims:
+            layouts.add(VkMemoryLayout.TENSOR_WIDTH_PACKED)
+
+        if PackedDim.HEIGHT in self.valid_packed_dims:
+            layouts.add(VkMemoryLayout.TENSOR_HEIGHT_PACKED)
+
+        if PackedDim.CHANNELS in self.valid_packed_dims:
+            layouts.add(VkMemoryLayout.TENSOR_CHANNELS_PACKED)
+
+        return layouts
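
To make the new derivation concrete, here is a minimal sketch (not part of the diff) of the `valid_packed_dims` to memory-layout mapping, using the class exactly as defined above:

```python
from executorch.backends.vulkan.utils import PackedDim

# A texture impl that supports width- and channels-packing.
features = TextureImplFeatures(
    uses_axis_map=True,
    valid_packed_dims={PackedDim.WIDTH, PackedDim.CHANNELS},
)

# Yields {TENSOR_WIDTH_PACKED, TENSOR_CHANNELS_PACKED}; HEIGHT_PACKED is
# absent because PackedDim.HEIGHT was not declared valid.
layouts = features.valid_memory_layouts()
```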


class OpFeatures:
@@ -58,6 +83,9 @@ class OpFeatures:
# bool indicating if the operator has a resize function, which allows it to
# support dynamic shape tensors.
"resize_fn",
+        # Optimal storage type and memory layout for this operator, if any
+        "optimal_storage",
+        "optimal_layout",
# bool indicating if the operator handles its own prepacking. If this is True,
# then the insert_prepack_nodes pass will not insert prepack nodes for the args
# of the op.
@@ -72,17 +100,90 @@ def __init__(
texture_impl: Optional[TextureImplFeatures] = None,
buffer_impl: bool = False,
resize_fn: bool = False,
+        optimal_storage: Optional[VkStorageType] = None,
+        optimal_layout: Optional[VkMemoryLayout] = None,
handles_own_prepacking: bool = False,
check_node_fn: Optional[Callable] = None,
):
self.texture_impl: Optional[TextureImplFeatures] = texture_impl
self.buffer_impl: bool = buffer_impl
self.resize_fn: bool = resize_fn
+        self.optimal_storage: Optional[VkStorageType] = optimal_storage
+        self.optimal_layout: Optional[VkMemoryLayout] = optimal_layout
self.handles_own_prepacking: bool = handles_own_prepacking
self.check_node_fn: Callable = allow_node
if check_node_fn is not None:
self.check_node_fn = check_node_fn

+    def propose_storage_type(self) -> Optional[VkStorageType]:
+        """
+        Propose a storage type that should be used for this operator. A proposal can be
+        made if one of the following is true:
+        1. The operator specifies an optimal storage type
+        2. Only one storage type is supported.
+
+        If both storage types are supported and no optimal storage type is specified,
+        then None is returned to indicate that there is no preference in storage type.
+        """
+        if self.optimal_storage is not None:
+            return self.optimal_storage
+
+        if self.texture_impl is not None and not self.buffer_impl:
+            return VkStorageType.TEXTURE_3D
+        elif self.buffer_impl and self.texture_impl is None:
+            return VkStorageType.BUFFER
+
+        return None

+    def supported_storage_types(self) -> Set[VkStorageType]:
+        """
+        Return the set of storage types supported by this operator.
+        """
+        storage_types = set()
+        if self.texture_impl is not None:
+            storage_types.add(VkStorageType.TEXTURE_3D)
+        if self.buffer_impl:
+            storage_types.add(VkStorageType.BUFFER)
+
+        return storage_types

+    def propose_memory_layout(self, storage: VkStorageType) -> Optional[VkMemoryLayout]:
+        """
+        Given a storage type as a precondition, propose a memory layout that should be
+        used for this operator. A proposal can be made if one of the following is true:
+        1. The operator specifies an optimal memory layout
+        2. Only one memory layout is supported.
+
+        If multiple memory layouts are supported and no optimal memory layout is
+        specified then return None to indicate that the "best" memory layout for the
+        operator is ambiguous.
+        """
+        if self.optimal_layout is not None:
+            return self.optimal_layout
+
+        if storage == VkStorageType.TEXTURE_3D:
+            assert self.texture_impl is not None
+            possible_layouts = self.texture_impl.valid_memory_layouts()
+            if len(possible_layouts) == 1:
+                return next(iter(possible_layouts))
+
+        return None

+    def supported_memory_layouts(self, storage: VkStorageType) -> Set[VkMemoryLayout]:
+        """
+        Return the set of memory layouts supported by this operator for a given storage
+        type.
+        """
+        if storage == VkStorageType.TEXTURE_3D:
+            assert self.texture_impl is not None
+            return self.texture_impl.valid_memory_layouts()
+        else:
+            return all_memory_layouts
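
Taken together, the four new methods support a resolve-storage-then-layout flow. A hedged sketch of a caller (the fallback policy below is an assumption for illustration, not taken from this PR):

```python
def resolve_storage_and_layout(features: OpFeatures):
    storage = features.propose_storage_type()
    if storage is None:
        # No preference declared: default to texture storage (assumed policy).
        storage = VkStorageType.TEXTURE_3D

    layout = features.propose_memory_layout(storage)
    if layout is None:
        # Ambiguous: fall back to an arbitrary supported layout (assumed policy).
        layout = next(iter(features.supported_memory_layouts(storage)))

    return storage, layout
```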


#######################
## Operator Registry ##
#######################

OpKey = Union[str, torch._ops.OpOverload, EdgeOpOverload]

@@ -122,8 +223,8 @@ def update_features_impl(op: OpKey):
)
def register_ephemeral_op(features: OpFeatures):
features.texture_impl = TextureImplFeatures(
-        uses_packed_dim=True,
        uses_axis_map=True,
+        valid_packed_dims=all_packed_dims,
)
features.buffer_impl = True
features.resize_fn = True
@@ -143,8 +244,8 @@ def register_ephemeral_op(features: OpFeatures):
)
def register_binary_op(features: OpFeatures):
features.texture_impl = TextureImplFeatures(
-        uses_packed_dim=True,
        uses_axis_map=True,
+        valid_packed_dims=all_packed_dims,
)
features.resize_fn = True
return features
@@ -170,8 +271,8 @@ def register_binary_op(features: OpFeatures):
)
def register_unary_op(features: OpFeatures):
features.texture_impl = TextureImplFeatures(
-        uses_packed_dim=True,
        uses_axis_map=True,
+        valid_packed_dims=all_packed_dims,
)
features.buffer_impl = True
features.resize_fn = True
Expand All @@ -181,8 +282,8 @@ def register_unary_op(features: OpFeatures):
@update_features(exir_ops.edge.aten._to_copy.default)
def register_to_copy_op(features: OpFeatures):
features.texture_impl = TextureImplFeatures(
-        uses_packed_dim=True,
        uses_axis_map=True,
+        valid_packed_dims=all_packed_dims,
)
features.resize_fn = True

@@ -220,40 +321,43 @@ def check_to_copy_node(node: torch.fx.Node) -> bool:
)
def register_mm_op(features: OpFeatures):
features.texture_impl = TextureImplFeatures(
-        uses_packed_dim=False,
        uses_axis_map=True,
-        supported_layouts=[
-            VkMemoryLayout.TENSOR_WIDTH_PACKED,
-            VkMemoryLayout.TENSOR_CHANNELS_PACKED,
-        ],
+        valid_packed_dims={
+            PackedDim.WIDTH,
+            PackedDim.CHANNELS,
+        },
)
features.buffer_impl = True
features.resize_fn = True
+    features.optimal_storage = VkStorageType.TEXTURE_3D
+    features.optimal_layout = VkMemoryLayout.TENSOR_WIDTH_PACKED
features.handles_own_prepacking = True
return features
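
With this registration both proposals are unambiguous. A sketch of the expected behavior (the exact op key is an assumption, since the decorated op list is collapsed above):

```python
features = get_op_features(exir_ops.edge.aten.mm.default)  # assumed key
assert features.propose_storage_type() == VkStorageType.TEXTURE_3D
assert (
    features.propose_memory_layout(VkStorageType.TEXTURE_3D)
    == VkMemoryLayout.TENSOR_WIDTH_PACKED
)
```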


@update_features(exir_ops.edge.aten._weight_int8pack_mm.default)
def register_int8_mm_op(features: OpFeatures):
features.texture_impl = TextureImplFeatures(
-        uses_packed_dim=False,
        uses_axis_map=False,
-        supported_layouts=[VkMemoryLayout.TENSOR_WIDTH_PACKED],
+        valid_packed_dims={PackedDim.WIDTH},
)
features.buffer_impl = True
features.resize_fn = True
+    features.optimal_storage = VkStorageType.TEXTURE_3D
+    features.optimal_layout = VkMemoryLayout.TENSOR_WIDTH_PACKED
features.handles_own_prepacking = True
return features


@update_features(exir_ops.edge.et_vk.linear_weight_int4.default)
def register_int4_mm_op(features: OpFeatures):
features.texture_impl = TextureImplFeatures(
-        uses_packed_dim=False,
        uses_axis_map=False,
-        supported_layouts=[VkMemoryLayout.TENSOR_WIDTH_PACKED],
+        valid_packed_dims={PackedDim.WIDTH},
)
features.resize_fn = True
+    features.optimal_storage = VkStorageType.TEXTURE_3D
+    features.optimal_layout = VkMemoryLayout.TENSOR_WIDTH_PACKED
features.handles_own_prepacking = True
return features

@@ -266,7 +370,7 @@ def register_int4_mm_op(features: OpFeatures):
)
def register_softmax_op(features: OpFeatures):
features.texture_impl = TextureImplFeatures(
-        uses_packed_dim=True,
+        valid_packed_dims=all_packed_dims,
)
features.resize_fn = True
return features
@@ -282,7 +386,7 @@ def register_softmax_op(features: OpFeatures):
)
def register_reduce_op(features: OpFeatures):
features.texture_impl = TextureImplFeatures(
-        uses_packed_dim=True,
+        valid_packed_dims=all_packed_dims,
)
features.resize_fn = True

@@ -309,7 +413,7 @@ def check_reduce_node(node: torch.fx.Node) -> bool:
)
def register_2d_pool_op(features: OpFeatures):
features.texture_impl = TextureImplFeatures(
-        supported_layouts=[VkMemoryLayout.TENSOR_CHANNELS_PACKED],
+        valid_packed_dims={PackedDim.CHANNELS},
)
features.resize_fn = True
return features
@@ -323,27 +427,31 @@ def register_2d_pool_op(features: OpFeatures):
)
def register_convolution_op(features: OpFeatures):
features.texture_impl = TextureImplFeatures(
-        supported_layouts=[VkMemoryLayout.TENSOR_CHANNELS_PACKED],
+        valid_packed_dims={PackedDim.CHANNELS},
)
features.resize_fn = True
+    features.optimal_storage = VkStorageType.TEXTURE_3D
+    features.optimal_layout = VkMemoryLayout.TENSOR_CHANNELS_PACKED
features.handles_own_prepacking = True
return features


@update_features("llama::sdpa_with_kv_cache")
def register_sdpa_op(features: OpFeatures):
features.texture_impl = TextureImplFeatures(
-        supported_layouts=[VkMemoryLayout.TENSOR_WIDTH_PACKED],
+        valid_packed_dims={PackedDim.WIDTH},
)
features.resize_fn = True
+    features.optimal_storage = VkStorageType.TEXTURE_3D
+    features.optimal_layout = VkMemoryLayout.TENSOR_WIDTH_PACKED
features.handles_own_prepacking = True
return features


@update_features(exir_ops.edge.et_vk.apply_rotary_emb.default)
def register_rotary_emb_op(features: OpFeatures):
features.texture_impl = TextureImplFeatures(
-        supported_layouts=[VkMemoryLayout.TENSOR_WIDTH_PACKED],
+        valid_packed_dims={PackedDim.WIDTH},
)
features.resize_fn = True
return features
@@ -352,7 +460,7 @@ def register_rotary_emb_op(features: OpFeatures):
@update_features(exir_ops.edge.aten.view_copy.default)
def register_view_op(features: OpFeatures):
features.texture_impl = TextureImplFeatures(
-        uses_packed_dim=True,
+        valid_packed_dims=all_packed_dims,
)
features.resize_fn = True
return features
@@ -393,7 +501,7 @@ def register_view_op(features: OpFeatures):
)
def register_ported_op(features: OpFeatures):
features.texture_impl = TextureImplFeatures(
-        supported_layouts=[VkMemoryLayout.TENSOR_CHANNELS_PACKED],
+        valid_packed_dims={PackedDim.CHANNELS},
)
return features

@@ -408,15 +516,24 @@ def register_ported_op(features: OpFeatures):
)
def register_ported_ops_with_prepacking(features: OpFeatures):
features.texture_impl = TextureImplFeatures(
-        supported_layouts=[VkMemoryLayout.TENSOR_CHANNELS_PACKED],
+        valid_packed_dims={PackedDim.CHANNELS},
)
features.handles_own_prepacking = True
return features
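
For comparison, a hypothetical registration showing how the same pattern composes with a custom `check_node_fn` (the op key, rank constraint, and helper below are invented for illustration):

```python
def check_my_node(node: torch.fx.Node) -> bool:
    # Gate partitioning on tensor rank; hypothetical constraint.
    val = node.meta.get("val", None)
    return val is not None and len(val.shape) <= 4


@update_features("custom::my_op")  # hypothetical op key
def register_my_op(features: OpFeatures):
    features.texture_impl = TextureImplFeatures(
        valid_packed_dims={PackedDim.CHANNELS},
    )
    features.resize_fn = True
    features.check_node_fn = check_my_node
    return features
```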


-##
-## Utility Functions
-##
+#######################
+## Utility functions ##
+#######################


def has_impl(target: OpKey) -> bool:
if not isinstance(target, str):
if target not in vulkan_supported_ops:
return target.name() in vulkan_supported_ops
return target in vulkan_supported_ops
else:
return target in vulkan_supported_ops
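
Both lookup paths can be exercised with ops registered above:

```python
# String key (custom op) and EdgeOpOverload key both resolve.
assert has_impl("llama::sdpa_with_kv_cache")
assert has_impl(exir_ops.edge.aten.view_copy.default)
```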


def get_op_features(target: OpKey) -> OpFeatures:
1 change: 1 addition & 0 deletions backends/vulkan/partitioner/TARGETS
@@ -13,6 +13,7 @@ runtime.python_library(
],
deps = [
"//executorch/backends/vulkan:op_registry",
"//executorch/backends/vulkan:utils_lib",
"//executorch/backends/vulkan:vulkan_preprocess",
"//executorch/exir:delegate",
"//executorch/exir:lib",