42 changes: 29 additions & 13 deletions backends/vulkan/_passes/tag_memory_meta_pass.py
@@ -131,7 +131,7 @@ def propose_node_storage( # noqa: C901
# pyre-ignore
features = get_op_features(node.target)
valid_storage_types = features.supported_storage_types()
-        storage = features.propose_storage_type()
+        storage = features.propose_output_storage_type()
if storage is not None:
return storage

@@ -180,7 +180,7 @@ def propose_node_layout(
# pyre-ignore
features = get_op_features(node.target)
valid_layouts = features.supported_memory_layouts(storage)
-        layout = features.propose_memory_layout(storage)
+        layout = features.propose_output_memory_layout(storage)
if layout is not None:
return layout

@@ -251,33 +251,49 @@ def set_or_transition_arg_node(
) -> bool:
assert isinstance(arg, torch.fx.Node)

-        storage = utils.get_node_storage_type(node)
-        assert storage is not None
-        layout = utils.get_node_memory_layout(node)
-        assert layout is not None
+        # Determine the desired storage and layout for this input
+        desired_storage = None
+        desired_layout = None
+
+        # Check if the operator has input-specific preferences
+        if has_impl(node.target):
+            features = get_op_features(node.target)
+            desired_storage = features.propose_input_storage_type(i)
+            if desired_storage is not None:
+                desired_layout = features.propose_input_memory_layout(
+                    i, desired_storage
+                )
+
+        # Fall back to output preferences if there are no input-specific preferences
+        if desired_storage is None:
+            desired_storage = utils.get_node_storage_type(node)
+            assert desired_storage is not None
+        if desired_layout is None:
+            desired_layout = utils.get_node_memory_layout(node)
+            assert desired_layout is not None

arg_storage = utils.get_node_storage_type(arg)
arg_layout = utils.get_node_memory_layout(arg)

if arg_storage is None:
-            utils.set_node_spec_attr(arg, "vk_storage_type", storage)
-            arg_storage = storage
+            utils.set_node_spec_attr(arg, "vk_storage_type", desired_storage)
+            arg_storage = desired_storage
if arg_layout is None:
-            utils.set_node_spec_attr(arg, "vk_memory_layout", layout)
-            arg_layout = layout
+            utils.set_node_spec_attr(arg, "vk_memory_layout", desired_layout)
+            arg_layout = desired_layout

-        if arg_storage == storage and arg_layout == layout:
+        if arg_storage == desired_storage and arg_layout == desired_layout:
return False

if not dirty:
logger.info(
f"[Vulkan Delegate] Inserting transition(s) for {node.format_node()}:"
)

-        insert_transition_node(graph_module, node, arg, storage, layout)
+        insert_transition_node(graph_module, node, arg, desired_storage, desired_layout)

logger.info(
-            f"  args {i} ({arg}): ({arg_storage}, {arg_layout}) -> ({storage}, {layout})"
+            f"  args {i} ({arg}): ({arg_storage}, {arg_layout}) -> ({desired_storage}, {desired_layout})"
)

return True
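To make the new control flow easier to review, here is a minimal, self-contained sketch of the resolution order that `set_or_transition_arg_node` now implements. The enum values and helper names below are simplified stand-ins for illustration, not the real `backends/vulkan` APIs:

```python
# Minimal sketch of the resolution order above: (1) ask the op for an
# input-specific storage proposal, (2) fall back to the storage already
# tagged on the op's output, (3) tag an untagged arg in place, or flag a
# mismatch that requires inserting a transition node.
from enum import Enum, auto
from typing import List, Optional, Tuple


class VkStorageType(Enum):
    BUFFER = auto()
    TEXTURE_3D = auto()


class OpFeatures:
    def __init__(self, input_storage: Optional[List[VkStorageType]] = None):
        self.input_storage = input_storage

    def propose_input_storage_type(self, i: int) -> Optional[VkStorageType]:
        if self.input_storage is not None and i < len(self.input_storage):
            return self.input_storage[i]
        return None


def resolve_arg_storage(
    features: OpFeatures,
    i: int,
    node_storage: VkStorageType,
    arg_storage: Optional[VkStorageType],
) -> Tuple[VkStorageType, bool]:
    """Returns (desired storage for arg i, whether a transition is needed)."""
    desired = features.propose_input_storage_type(i)
    if desired is None:
        desired = node_storage  # fall back to the output preference
    if arg_storage is None:
        return desired, False  # untagged arg: tag it in place
    return desired, arg_storage != desired  # transition iff mismatch


features = OpFeatures(input_storage=[VkStorageType.TEXTURE_3D, VkStorageType.BUFFER])
print(resolve_arg_storage(features, 1, VkStorageType.TEXTURE_3D, VkStorageType.TEXTURE_3D))
# (VkStorageType.BUFFER, True): arg 1 needs a texture-to-buffer transition
```

The real pass additionally resolves a memory layout the same way, and logs each transition it inserts.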
152 changes: 121 additions & 31 deletions backends/vulkan/op_registry.py
@@ -83,9 +83,12 @@ class OpFeatures:
# bool indicating if the operator has a resize function, which allows it to
# support dynamic shape tensors.
"resize_fn",
-        # Optimal
-        "optimal_storage",
-        "optimal_layout",
+        # Output-specific optimal storage and layout specifications
+        "optimal_output_storage",
+        "optimal_output_layout",
+        # Input-specific optimal storage and layout specifications
+        "optimal_input_storage",
+        "optimal_input_layout",
# bool indicating if the operator handles its own prepacking. If this is True,
# then the insert_prepack_nodes pass will not insert prepack nodes for the args
# of the op.
@@ -103,17 +106,25 @@ def __init__(
texture_impl: Optional[TextureImplFeatures] = None,
buffer_impl: bool = False,
resize_fn: bool = False,
-        optimal_storage: Optional[VkStorageType] = None,
-        optimal_layout: Optional[VkMemoryLayout] = None,
+        optimal_output_storage: Optional[VkStorageType] = None,
+        optimal_output_layout: Optional[VkMemoryLayout] = None,
+        optimal_input_storage: Optional[Union[VkStorageType, list]] = None,
+        optimal_input_layout: Optional[Union[VkMemoryLayout, list]] = None,
handles_own_prepacking: bool = False,
skip_limits_check: Optional[Set[int]] = None,
check_node_fn: Optional[Callable] = None,
):
self.texture_impl: Optional[TextureImplFeatures] = texture_impl
self.buffer_impl: bool = buffer_impl
self.resize_fn: bool = resize_fn
-        self.optimal_storage: Optional[VkStorageType] = optimal_storage
-        self.optimal_layout: Optional[VkMemoryLayout] = optimal_layout
+        self.optimal_output_storage: Optional[VkStorageType] = optimal_output_storage
+        self.optimal_output_layout: Optional[VkMemoryLayout] = optimal_output_layout
+        self.optimal_input_storage: Optional[Union[VkStorageType, list]] = (
+            optimal_input_storage
+        )
+        self.optimal_input_layout: Optional[Union[VkMemoryLayout, list]] = (
+            optimal_input_layout
+        )
self.handles_own_prepacking: bool = handles_own_prepacking

self.skip_limits_check: Set[int] = set()
@@ -124,7 +135,7 @@ def __init__(
if check_node_fn is not None:
self.check_node_fn = check_node_fn

-    def propose_storage_type(self) -> Optional[VkStorageType]:
+    def propose_output_storage_type(self) -> Optional[VkStorageType]:
"""
Propose a storage type that should be used for this operator. A proposal can be
made if one of the following is true:
@@ -134,8 +145,8 @@ def propose_storage_type(self) -> Optional[VkStorageType]:
If both storage types are supported and no optimal storage type is specified,
then None is returned to indicate that there is no preference in storage type.
"""
-        if self.optimal_storage is not None:
-            return self.optimal_storage
+        if self.optimal_output_storage is not None:
+            return self.optimal_output_storage

if self.texture_impl is not None and not self.buffer_impl:
return VkStorageType.TEXTURE_3D
@@ -156,7 +167,9 @@ def supported_storage_types(self) -> Set[VkStorageType]:

return storage_types

-    def propose_memory_layout(self, storage: VkStorageType) -> Optional[VkMemoryLayout]:
+    def propose_output_memory_layout(
+        self, storage: VkStorageType
+    ) -> Optional[VkMemoryLayout]:
"""
Given a storage type as a precondition, propose a memory layout that should be
used for this operator. A proposal can be made if one of the following is true:
@@ -167,8 +180,8 @@ def propose_memory_layout(self, storage: VkStorageType) -> Optional[VkMemoryLayout]:
specified then return None to indicate that the "best" memory layout for the
operator is ambiguous.
"""
-        if self.optimal_layout is not None:
-            return self.optimal_layout
+        if self.optimal_output_layout is not None:
+            return self.optimal_output_layout

if storage == VkStorageType.TEXTURE_3D:
assert self.texture_impl is not None
@@ -189,6 +202,51 @@ def supported_memory_layouts(self, storage: VkStorageType) -> Set[VkMemoryLayout]:
else:
return all_memory_layouts

+    def propose_input_storage_type(self, input_index: int) -> Optional[VkStorageType]:
+        """
+        Propose a storage type for a specific input tensor by index.
+
+        Args:
+            input_index: Index of the input tensor in the operator's input list
+
+        Returns:
+            Optimal storage type for the input, or None if no preference
+        """
+        if self.optimal_input_storage is not None:
+            if isinstance(self.optimal_input_storage, list):
+                if input_index < len(self.optimal_input_storage):
+                    return self.optimal_input_storage[input_index]
+            else:
+                # Single storage type applies to all inputs
+                return self.optimal_input_storage
+
+        # Fall back to the output preference
+        return self.propose_output_storage_type()
+
+    def propose_input_memory_layout(
+        self, input_index: int, storage: VkStorageType
+    ) -> Optional[VkMemoryLayout]:
+        """
+        Propose a memory layout for a specific input tensor by index and storage type.
+
+        Args:
+            input_index: Index of the input tensor in the operator's input list
+            storage: Storage type for the input tensor
+
+        Returns:
+            Optimal memory layout for the input, or None if no preference
+        """
+        if self.optimal_input_layout is not None:
+            if isinstance(self.optimal_input_layout, list):
+                if input_index < len(self.optimal_input_layout):
+                    return self.optimal_input_layout[input_index]
+            else:
+                # Single layout applies to all inputs
+                return self.optimal_input_layout
+
+        # Fall back to the output preference
+        return self.propose_output_memory_layout(storage)


#######################
## Operator Registry ##
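As a quick check on the scalar-versus-list semantics of the new proposal methods, the snippet below reimplements `propose_input_storage_type` as written above inside a stand-in class (`SimpleFeatures` is hypothetical; only the method body mirrors this diff):

```python
# Stand-in reimplementation of propose_input_storage_type, mirroring the
# diff above: a list is indexed per input, a scalar applies to every
# input, and anything unresolved falls back to the output proposal.
from enum import Enum, auto
from typing import Optional


class VkStorageType(Enum):
    BUFFER = auto()
    TEXTURE_3D = auto()


class SimpleFeatures:
    def __init__(self, optimal_input_storage=None, optimal_output_storage=None):
        self.optimal_input_storage = optimal_input_storage
        self.optimal_output_storage = optimal_output_storage

    def propose_input_storage_type(self, input_index: int) -> Optional[VkStorageType]:
        if self.optimal_input_storage is not None:
            if isinstance(self.optimal_input_storage, list):
                if input_index < len(self.optimal_input_storage):
                    return self.optimal_input_storage[input_index]
            else:
                return self.optimal_input_storage
        return self.optimal_output_storage  # output fallback


per_input = SimpleFeatures(
    optimal_input_storage=[VkStorageType.TEXTURE_3D, VkStorageType.BUFFER],
    optimal_output_storage=VkStorageType.TEXTURE_3D,
)
assert per_input.propose_input_storage_type(1) is VkStorageType.BUFFER
assert per_input.propose_input_storage_type(9) is VkStorageType.TEXTURE_3D  # past end of list

scalar = SimpleFeatures(optimal_input_storage=VkStorageType.BUFFER)
assert scalar.propose_input_storage_type(3) is VkStorageType.BUFFER  # applies to all inputs
```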
@@ -253,8 +311,6 @@ def register_ephemeral_op(features: OpFeatures):
exir_ops.edge.quantized_decomposed.dequantize_per_channel.default,
exir_ops.edge.quantized_decomposed.quantize_per_token.default,
exir_ops.edge.quantized_decomposed.dequantize_per_token.default,
-        exir_ops.edge.quantized_decomposed.choose_qparams.tensor,
-        exir_ops.edge.quantized_decomposed.choose_qparams_per_token_asymmetric.default,
]
)
def register_quantization_op(features: OpFeatures):
@@ -268,7 +324,37 @@ def register_quantization_op(features: OpFeatures):
)
features.buffer_impl = True
features.resize_fn = True
-    features.optimal_storage = VkStorageType.BUFFER
+    features.optimal_output_storage = VkStorageType.TEXTURE_3D
+    # Input can be TEXTURE_3D, but scales and zero_points must be BUFFER
+    # For most quantization ops: input[0] = tensor, input[1] = scale, input[2] = zero_point
+    features.optimal_input_storage = [
+        VkStorageType.TEXTURE_3D,  # input tensor
+        VkStorageType.BUFFER,  # scale
+        VkStorageType.BUFFER,  # zero_point
+    ]
return features


+@update_features(
+    [
+        exir_ops.edge.quantized_decomposed.choose_qparams.tensor,
+        exir_ops.edge.quantized_decomposed.choose_qparams_per_token_asymmetric.default,
+    ]
+)
+def register_choose_qparams_op(features: OpFeatures):
+    # choose_qparams operators only take an input tensor; there are no scale/zero_point inputs
+    # Optimal storage is TEXTURE_3D for both input and output since it is faster
+    features.texture_impl = TextureImplFeatures(
+        uses_axis_map=True,
+        valid_packed_dims={
+            PackedDim.WIDTH,
+        },
+    )
+    features.buffer_impl = True
+    features.resize_fn = True
+    features.optimal_output_storage = VkStorageType.TEXTURE_3D
+    # The only input is the input tensor, which should be TEXTURE_3D for optimal performance
+    features.optimal_input_storage = VkStorageType.TEXTURE_3D
+    return features
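To make the per-input list concrete: per the comment in `register_quantization_op` above, the positional inputs line up with the list entries as follows. This is a hypothetical trace with plain strings standing in for the enum values; the exact argument order can vary by overload:

```python
# Hypothetical mapping of positional inputs to the storage list
# registered in register_quantization_op above.
optimal_input_storage = ["TEXTURE_3D", "BUFFER", "BUFFER"]

for i, arg_name in enumerate(("input", "scale", "zero_point")):
    print(f"arg {i} ({arg_name}) -> {optimal_input_storage[i]}")
# arg 0 (input) -> TEXTURE_3D
# arg 1 (scale) -> BUFFER
# arg 2 (zero_point) -> BUFFER
```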


@@ -285,8 +371,10 @@ def register_affine_quantization_op(features: OpFeatures):
)
features.buffer_impl = True
features.resize_fn = True
-    features.optimal_storage = VkStorageType.TEXTURE_3D
-    features.optimal_layout = VkMemoryLayout.TENSOR_WIDTH_PACKED
+    features.optimal_output_storage = VkStorageType.TEXTURE_3D
+    features.optimal_output_layout = VkMemoryLayout.TENSOR_WIDTH_PACKED
+    # All inputs should be TEXTURE_3D for quantize_affine and dequantize_affine
+    features.optimal_input_storage = VkStorageType.TEXTURE_3D
features.handles_own_prepacking = True

return features
@@ -308,7 +396,9 @@ def register_choose_qparams_affine_op(features: OpFeatures):
)
features.buffer_impl = True
features.resize_fn = True
-    features.optimal_storage = VkStorageType.BUFFER
+    features.optimal_output_storage = VkStorageType.BUFFER
+    # All inputs should be BUFFER since a texture implementation is not available yet
+    features.optimal_input_storage = VkStorageType.BUFFER

return features

@@ -449,8 +539,8 @@ def register_mm_op(features: OpFeatures):
)
features.buffer_impl = True
features.resize_fn = True
-    features.optimal_storage = VkStorageType.TEXTURE_3D
-    features.optimal_layout = VkMemoryLayout.TENSOR_WIDTH_PACKED
+    features.optimal_output_storage = VkStorageType.TEXTURE_3D
+    features.optimal_output_layout = VkMemoryLayout.TENSOR_WIDTH_PACKED
features.handles_own_prepacking = True
return features

@@ -468,8 +558,8 @@ def register_int8_mm_op(features: OpFeatures):
)
features.buffer_impl = True
features.resize_fn = True
-    features.optimal_storage = VkStorageType.TEXTURE_3D
-    features.optimal_layout = VkMemoryLayout.TENSOR_WIDTH_PACKED
+    features.optimal_output_storage = VkStorageType.TEXTURE_3D
+    features.optimal_output_layout = VkMemoryLayout.TENSOR_WIDTH_PACKED
features.handles_own_prepacking = True
return features

@@ -487,8 +577,8 @@ def register_int4_mm_op(features: OpFeatures):
valid_packed_dims={PackedDim.WIDTH},
)
features.resize_fn = True
-    features.optimal_storage = VkStorageType.TEXTURE_3D
-    features.optimal_layout = VkMemoryLayout.TENSOR_WIDTH_PACKED
+    features.optimal_output_storage = VkStorageType.TEXTURE_3D
+    features.optimal_output_layout = VkMemoryLayout.TENSOR_WIDTH_PACKED
features.handles_own_prepacking = True
features.skip_limits_check = {1}
return features
@@ -562,8 +652,8 @@ def register_convolution_op(features: OpFeatures):
valid_packed_dims={PackedDim.CHANNELS},
)
features.resize_fn = True
-    features.optimal_storage = VkStorageType.TEXTURE_3D
-    features.optimal_layout = VkMemoryLayout.TENSOR_CHANNELS_PACKED
+    features.optimal_output_storage = VkStorageType.TEXTURE_3D
+    features.optimal_output_layout = VkMemoryLayout.TENSOR_CHANNELS_PACKED
features.handles_own_prepacking = True
features.skip_limits_check = {1, 2}
return features
@@ -575,8 +665,8 @@ def register_sdpa_with_kv_cache_op(features: OpFeatures):
valid_packed_dims={PackedDim.WIDTH},
)
features.resize_fn = True
-    features.optimal_storage = VkStorageType.TEXTURE_3D
-    features.optimal_layout = VkMemoryLayout.TENSOR_WIDTH_PACKED
+    features.optimal_output_storage = VkStorageType.TEXTURE_3D
+    features.optimal_output_layout = VkMemoryLayout.TENSOR_WIDTH_PACKED
features.handles_own_prepacking = True
return features

@@ -740,13 +830,13 @@ def register_native_group_norm(features: OpFeatures):
)
features.handles_own_prepacking = True

-    features.optimal_storage = [
+    features.optimal_output_storage = [
VkStorageType.TEXTURE_3D,
VkStorageType.BUFFER,
VkStorageType.BUFFER,
]

-    features.optimal_layout = [
+    features.optimal_output_layout = [
VkMemoryLayout.TENSOR_CHANNELS_PACKED,
VkMemoryLayout.TENSOR_WIDTH_PACKED,
VkMemoryLayout.TENSOR_WIDTH_PACKED,
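A closing note on this last hunk: `native_group_norm` produces three output values (the normalized output plus the per-group mean and rstd), which is why `optimal_output_storage` and `optimal_output_layout` are lists with one entry per output. Below is a minimal sketch, under the assumption that a consumer indexes the list per output value; the names are stand-ins, since the consuming pass code is outside this diff:

```python
# Sketch: resolving a per-output proposal for a multi-output op such as
# native_group_norm, whose three outputs are (out, mean, rstd). Plain
# strings stand in for the VkStorageType enum values.
from typing import List, Union


def storage_for_output(optimal: Union[str, List[str]], out_index: int) -> str:
    # A list is indexed per output value; a scalar applies to all outputs.
    if isinstance(optimal, list):
        return optimal[out_index]
    return optimal


group_norm_storage = ["TEXTURE_3D", "BUFFER", "BUFFER"]  # out, mean, rstd
assert storage_for_output(group_norm_storage, 0) == "TEXTURE_3D"
assert storage_for_output(group_norm_storage, 2) == "BUFFER"
```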