17 changes: 17 additions & 0 deletions backends/vulkan/_passes/TARGETS
@@ -118,6 +118,22 @@ runtime.python_library(
],
)

runtime.python_library(
name = "fuse_patterns",
srcs = ["fuse_patterns.py"],
visibility = [
"//executorch/backends/...",
],
deps = [
"//caffe2:torch",
"//executorch/backends/vulkan/patterns:vulkan_patterns",
"//executorch/exir:lib",
"//executorch/exir:pass_base",
"//executorch/exir/dialects:lib",
],
typing = True,
)

runtime.python_library(
name = "vulkan_passes",
srcs = [
@@ -128,6 +144,7 @@ runtime.python_library(
"//executorch/examples/...",
],
deps = [
":fuse_patterns",
":fuse_quantized_ops",
":insert_prepack_nodes",
":int4_weight_only_quantizer",
2 changes: 2 additions & 0 deletions backends/vulkan/_passes/__init__.py
@@ -6,6 +6,7 @@

# pyre-strict

from executorch.backends.vulkan._passes.fuse_patterns import FusePatternsPass
from executorch.backends.vulkan._passes.fuse_quantized_ops import (
FuseQuantizedOpsTransform,
)
@@ -29,6 +30,7 @@
from executorch.backends.vulkan._passes.tag_memory_meta_pass import TagMemoryMetaPass

__all__ = [
"FusePatternsPass",
"FuseQuantizedOpsTransform",
"insert_prepack_nodes",
"VkInt4WeightOnlyQuantizer",
30 changes: 30 additions & 0 deletions backends/vulkan/_passes/fuse_patterns.py
@@ -0,0 +1,30 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import executorch.backends.vulkan.patterns as vk_patterns

import torch

from executorch.exir import ExportedProgram
from executorch.exir.pass_base import ExportPass, PassResult


class FusePatternsPass(ExportPass):
def __init__(self, exported_program: ExportedProgram) -> None:
super().__init__()
self.program = exported_program

def call(self, graph_module: torch.fx.GraphModule):
total_replaced = vk_patterns.replace_all_fusable_subgraphs(
self.program, graph_module
)

if total_replaced > 0:
graph_module.recompile()
# Re-trace the graph
graph_module = super().call(graph_module).graph_module

return PassResult(graph_module, total_replaced > 0)
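For orientation, here is a minimal sketch of driving the new pass by hand. FusePatternsPass and the PassResult it returns come from this diff; the toy module and export harness below are assumptions for illustration, and real fusion only triggers when the graph contains a registered pattern (e.g. rotary embeddings).

# Minimal usage sketch (hypothetical harness, not part of this change)
import torch

from executorch.backends.vulkan._passes import FusePatternsPass


class Toy(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(x) + 1


ep = torch.export.export(Toy(), (torch.randn(4),))
result = FusePatternsPass(ep).call(ep.graph_module)
# result.modified is True only when at least one subgraph was replaced
print("modified:", result.modified)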
36 changes: 3 additions & 33 deletions backends/vulkan/custom_ops_lib.py
@@ -4,6 +4,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import executorch.backends.vulkan.patterns as vk_patterns
import torch.library

namespace = "et_vk"
@@ -325,42 +326,11 @@ def linear_qta8a_qga4w(
######################


# Note: this implementation is copied from executorch.examples.models.llama.rope
# to avoid introducing a dependency on the llama code.
def apply_rotary_emb_impl(
xq: torch.Tensor, xk: torch.Tensor, freqs_cos: torch.Tensor, freqs_sin: torch.Tensor
):
def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
ndim = x.ndim
freqs_cis_ndim = freqs_cis.ndim
if freqs_cis_ndim == 3:
# freqs_cis: (seq_len, n_heads, head_dim // 2)
assert freqs_cis.shape == (x.shape[-3], x.shape[-2], x.shape[-1])
shape = [
d if (i == ndim - 3 or i == ndim - 2 or i == ndim - 1) else 1
for i, d in enumerate(x.shape)
]
else:
# freqs_cis: (seq_len, head_dim // 2)
assert freqs_cis.shape == (x.shape[1], x.shape[-1])
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
return freqs_cis.view(shape)

xq_r, xq_i = xq.float().reshape(xq.shape[:-1] + (-1, 2)).unbind(-1)
xk_r, xk_i = xk.float().reshape(xk.shape[:-1] + (-1, 2)).unbind(-1)

freqs_cos = reshape_for_broadcast(freqs_cos, xq_r)
freqs_sin = reshape_for_broadcast(freqs_sin, xq_r)

xq_out_r = xq_r * freqs_cos - xq_i * freqs_sin
xq_out_i = xq_r * freqs_sin + xq_i * freqs_cos
xk_out_r = xk_r * freqs_cos - xk_i * freqs_sin
xk_out_i = xk_r * freqs_sin + xk_i * freqs_cos

xq_out = torch.stack([xq_out_r, xq_out_i], dim=-1).flatten(3)
xk_out = torch.stack([xk_out_r, xk_out_i], dim=-1).flatten(3)

return xq_out.type_as(xq), xk_out.type_as(xk)
pattern = vk_patterns.RotaryEmbeddingPattern()
return pattern.forward(xq, xk, freqs_cos, freqs_sin)


name = "apply_rotary_emb"
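For reference, the removed helper applied the standard rotary-embedding rotation: each consecutive feature pair $(x_r, x_i)$ of the query/key tensors is rotated by the per-position angle $\theta$ encoded in `freqs_cos`/`freqs_sin`:

$$x'_r = x_r \cos\theta - x_i \sin\theta, \qquad x'_i = x_r \sin\theta + x_i \cos\theta$$

The reference implementation now delegates to `vk_patterns.RotaryEmbeddingPattern`, which is expected to compute the same thing; that way the pattern module can serve as the single definition used both here and for subgraph matching.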
1 change: 1 addition & 0 deletions backends/vulkan/op_registry.py
@@ -125,6 +125,7 @@ def update_features_impl(op: OpKey):
operator.gt,
operator.ge,
operator.le,
operator.eq,
# Guard and assert ops
torch.ops.aten._assert_scalar.default,
torch.ops.aten.sym_constrain_range_for_size.default,
1 change: 1 addition & 0 deletions backends/vulkan/partitioner/TARGETS
@@ -15,6 +15,7 @@ runtime.python_library(
"//executorch/backends/vulkan:op_registry",
"//executorch/backends/vulkan:utils_lib",
"//executorch/backends/vulkan:vulkan_preprocess",
"//executorch/backends/vulkan/patterns:vulkan_patterns",
"//executorch/exir:delegate",
"//executorch/exir:lib",
"//executorch/exir/backend:partitioner",
22 changes: 21 additions & 1 deletion backends/vulkan/partitioner/vulkan_partitioner.py
@@ -9,6 +9,7 @@
import logging
from typing import Any, Callable, Dict, final, List, Mapping, Optional, Set, Tuple

import executorch.backends.vulkan.patterns as vk_patterns
import executorch.backends.vulkan.utils as utils

import torch
@@ -37,9 +38,10 @@
from executorch.exir.dialects._ops import ops as exir_ops

from torch.export.exported_program import ExportedProgram

from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
from torch.fx.passes.operator_support import OperatorSupportBase
from torch.fx.passes.utils.matcher_utils import InternalMatch

# pyre-ignore
ops_not_to_decompose = [
@@ -58,6 +60,7 @@ def __init__(
require_dynamic_shape: bool = False,
operator_blocklist: Optional[Set[OpKey]] = None,
operator_allowlist: Optional[Set[OpKey]] = None,
fusable_subgraphs: Optional[List[InternalMatch]] = None,
) -> None:
super().__init__()
self.texture_limits: utils.ImageExtents = texture_limits
@@ -67,6 +70,13 @@
operator_blocklist if operator_blocklist is not None else set()
)
self.operator_allowlist = operator_allowlist
self.fusable_subgraphs: List[InternalMatch] = (
fusable_subgraphs if fusable_subgraphs is not None else []
)
# Create a set of all nodes that are part of fusable subgraphs for quick lookup
self.fusable_nodes: Set[torch.fx.Node] = set()
for match in self.fusable_subgraphs:
self.fusable_nodes.update(match.nodes_map.values())

def op_node_is_compatible( # noqa: C901: Function is too complex
self, node: torch.fx.Node, features: Optional[OpFeatures] = None
@@ -204,6 +214,10 @@ def is_node_supported(
return r

def _is_node_supported(self, node: torch.fx.Node) -> bool:
# Check if this node is part of a fusable subgraph
if node.op == "call_function" and node in self.fusable_nodes:
return True

target = node.target
if node.target == torch.ops.higher_order.auto_functionalized:
first_arg = node.args[0]
@@ -330,6 +344,11 @@ def partition(self, exported_program: ExportedProgram) -> PartitionResult:
# subgraphs containing the nodes with the tags
partition_tags = {}

# Get all fusable subgraphs from the registered Vulkan patterns
fusable_subgraphs = vk_patterns.get_all_fusable_subgraphs(
exported_program.graph_module
)

texture_limits: utils.ImageExtents = self.options.get(
"texture_limits", utils.DEFAULT_TEXTURE_LIMITS
)
@@ -342,6 +361,7 @@
require_dynamic_shape=self.options.get("require_dynamic_shapes", False),
operator_blocklist=self.operator_blocklist,
operator_allowlist=self.operator_allowlist,
fusable_subgraphs=fusable_subgraphs,
),
allows_single_node_partition=True,
)
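The fusable-subgraph plumbing above is built on torch.fx subgraph matching. As a self-contained illustration of how `nodes_map.values()` yields the target-graph nodes covered by a match (the toy pattern below is an assumption, not one of the Vulkan patterns):

# Illustration of the torch.fx matching used by the partitioner; the toy
# pattern and model are hypothetical stand-ins.
import torch

from torch.fx.passes.utils.matcher_utils import SubgraphMatcher


def pattern(x):
    return torch.relu(x) + 1


def model(x):
    return (torch.relu(x) + 1) * 2


pattern_gm = torch.fx.symbolic_trace(pattern)
model_gm = torch.fx.symbolic_trace(model)

sm = SubgraphMatcher(pattern_gm.graph, ignore_literals=True)
matches = sm.match(model_gm.graph)

# Mirror VulkanSupportedOperators.__init__: collect every target-graph node
# covered by some match so _is_node_supported can approve it with a set lookup.
fusable_nodes = set()
for m in matches:
    fusable_nodes.update(m.nodes_map.values())
print(sorted(n.name for n in fusable_nodes))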
24 changes: 24 additions & 0 deletions backends/vulkan/patterns/TARGETS
@@ -0,0 +1,24 @@
load("@fbcode_macros//build_defs:python_unittest.bzl", "python_unittest")
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")

oncall("executorch")

runtime.python_library(
name = "vulkan_patterns",
srcs = [
"__init__.py",
"pattern_registry.py",
"rope.py",
],
visibility = [
"//executorch/backends/...",
"//executorch/examples/...",
],
deps = [
"//caffe2:torch",
"//executorch/exir:lib",
"//executorch/backends/transforms:utils",
"//executorch/backends/vulkan:utils_lib",
],
typing = True,
)
98 changes: 98 additions & 0 deletions backends/vulkan/patterns/__init__.py
@@ -0,0 +1,98 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import List

import executorch.backends.vulkan.patterns.rope # noqa

import torch

from executorch.backends.vulkan.patterns.pattern_registry import (
CreateReplacementFn,
fusable_patterns,
GetGraphFn,
register_pattern_graph,
register_pattern_replacement,
)

from executorch.backends.vulkan.patterns.rope import RotaryEmbeddingPattern

from executorch.exir import ExportedProgram

from torch.fx.passes.utils.matcher_utils import InternalMatch, SubgraphMatcher


__all__ = [
"GetGraphFn",
"CreateReplacementFn",
"RotaryEmbeddingPattern",
"fusable_patterns",
"register_pattern_graph",
"register_pattern_replacement",
]


def all_fusable_graph_patterns() -> List[torch.fx.GraphModule]:
all_patterns = []
for entry in fusable_patterns.values():
if entry.get_graphs_fn is not None:
all_patterns.extend(entry.get_graphs_fn())

return all_patterns


def get_all_fusable_subgraphs(
graph_module: torch.fx.GraphModule,
) -> List[InternalMatch]:
fusable_subgraphs = []

fuse_patterns = all_fusable_graph_patterns()
for pattern in fuse_patterns:
sm = SubgraphMatcher(pattern.graph, ignore_literals=True)
matches = list(sm.match(graph_module.graph))
fusable_subgraphs.extend(matches)

return fusable_subgraphs


def create_replacement_for_pattern(
ep: ExportedProgram,
graph_module: torch.fx.GraphModule,
patterns: List[torch.fx.GraphModule],
create_replacement_func: CreateReplacementFn,
) -> int:
total_replaced = 0

for pattern in patterns:
sm = SubgraphMatcher(pattern.graph, ignore_literals=True)
matches = list(sm.match(graph_module.graph))

for partition_to_replace in matches:
create_replacement_func(ep, graph_module, partition_to_replace)
total_replaced += 1
# Remove dead code so the replaced nodes won't be matched again
graph_module.graph.eliminate_dead_code()

return total_replaced


def replace_all_fusable_subgraphs(
ep: ExportedProgram,
graph_module: torch.fx.GraphModule,
) -> int:
total_replaced = 0

for entry in fusable_patterns.values():
if entry.get_graphs_fn is not None and entry.create_replacement_fn is not None:
total_replaced += create_replacement_for_pattern(
ep,
graph_module,
entry.get_graphs_fn(),
# pyre-ignore[6]
entry.create_replacement_fn,
)

return total_replaced
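`pattern_registry.py` itself is not shown in this diff. For orientation, the sketch below reconstructs the contract the imports imply, inferred purely from how `fusable_patterns` entries are used in `__init__.py` above; the `PatternEntry` name and dataclass form are assumptions.

# Hypothetical reconstruction of the pattern_registry contract; the actual
# implementation may differ.
from dataclasses import dataclass
from typing import Callable, Dict, List, Optional

import torch

from executorch.exir import ExportedProgram
from torch.fx.passes.utils.matcher_utils import InternalMatch

# Returns the pattern graphs to match against a model graph.
GetGraphFn = Callable[[], List[torch.fx.GraphModule]]
# Rewrites one matched subgraph of graph_module in place.
CreateReplacementFn = Callable[
    [ExportedProgram, torch.fx.GraphModule, InternalMatch], None
]


@dataclass
class PatternEntry:
    get_graphs_fn: Optional[GetGraphFn] = None
    create_replacement_fn: Optional[CreateReplacementFn] = None


# Keyed by pattern name; register_pattern_graph and register_pattern_replacement
# presumably populate the corresponding fields.
fusable_patterns: Dict[str, PatternEntry] = {}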