20 changes: 18 additions & 2 deletions backends/vulkan/partitioner/vulkan_partitioner.py
@@ -51,11 +51,15 @@

 class VulkanSupportedOperators(OperatorSupportBase):
     def __init__(
-        self, texture_limits: utils.ImageExtents, require_dynamic_shape: bool = False
+        self,
+        texture_limits: utils.ImageExtents,
+        buffer_limit: int,
+        require_dynamic_shape: bool = False,
     ) -> None:
         super().__init__()
-        self.require_dynamic_shapes = require_dynamic_shape
         self.texture_limits: utils.ImageExtents = texture_limits
+        self.buffer_limit = buffer_limit
+        self.require_dynamic_shapes = require_dynamic_shape

     def op_node_is_compatible(
         self, node: torch.fx.Node, features: Optional[OpFeatures] = None
@@ -83,6 +87,7 @@ def op_node_is_compatible(
                 node, self.texture_limits
             )

+        can_use_buffers = utils.within_buffer_limit(node, self.buffer_limit)
         for i, arg in enumerate(node.args):
             if (
                 isinstance(arg, torch.fx.Node)
@@ -95,10 +100,19 @@ def op_node_is_compatible(
                 valid_texture_layouts = valid_texture_layouts.intersection(
                     arg_texture_layouts
                 )
+                can_use_buffers = can_use_buffers and utils.within_buffer_limit(
+                    arg, self.buffer_limit
+                )

         # If there are no valid texture memory layouts, then buffer storage must be
         # supported by the operator implementation.
         if len(valid_texture_layouts) == 0:
+            if not can_use_buffers:
+                return (
+                    False,
+                    f"op requires buffers that exceed the buffer limit ({self.buffer_limit})",
+                )
+
             compatible = VkStorageType.BUFFER in features.supported_storage_types()
             reason = "op is compatible"
             if not compatible:
@@ -309,10 +323,12 @@ def partition(self, exported_program: ExportedProgram) -> PartitionResult:
         texture_limits: utils.ImageExtents = self.options.get(
             "texture_limits", utils.DEFAULT_TEXTURE_LIMITS
         )
+        buffer_limit: int = self.options.get("buffer_limit", utils.DEFAULT_BUFFER_LIMIT)
         capability_partitioner = CapabilityBasedPartitioner(
             exported_program.graph_module,
             VulkanSupportedOperators(
                 texture_limits,
+                buffer_limit,
                 require_dynamic_shape=self.options.get("require_dynamic_shapes", False),
             ),
             allows_single_node_partition=True,
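For context, a minimal usage sketch of the new option (not part of this diff): it assumes the partitioner accepts its options dict through its constructor and exposes it as self.options, which is how partition() reads "texture_limits", "buffer_limit", and "require_dynamic_shapes" above; the import path is inferred from the file path in this diff.

# Hypothetical sketch: lowering the buffer limit for a device with a smaller
# GPU buffer capacity. The constructor/options plumbing is assumed, not shown
# in this diff.
from executorch.backends.vulkan.partitioner.vulkan_partitioner import VulkanPartitioner

partitioner = VulkanPartitioner(
    {
        "buffer_limit": 1 << 26,  # 64 Mi elements instead of the 134217728 default
        "require_dynamic_shapes": False,
    }
)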
17 changes: 17 additions & 0 deletions backends/vulkan/utils.py
@@ -87,6 +87,7 @@ def is_tensor_node(node: torch.fx.Node) -> bool:
 ImageExtents = Tuple[int, int, int]

 DEFAULT_TEXTURE_LIMITS = (16384, 16384, 2048)
+DEFAULT_BUFFER_LIMIT = 134217728


 class PackedDim(IntEnum):
@@ -113,6 +114,22 @@ class PackedDim(IntEnum):
 }


+def within_buffer_limit(node: torch.fx.Node, buffer_limit: int) -> bool:
+    """
+    Checks whether the tensors produced by the given node can fit within the device's
+    GPU buffer limit, which represents the maximum number of elements that can be stored
+    in a GPU buffer.
+    """
+    assert is_tensor_node(node)
+
+    if isinstance(node.meta["val"], FakeTensor):
+        return node.meta["val"].numel() < buffer_limit
+    elif isinstance(node.meta["val"], (list, tuple)):
+        return all(x.numel() < buffer_limit for x in node.meta["val"])
+    else:
+        raise RuntimeError(f"Cannot get numel for val of type {type(node.meta['val'])}")
+
+
 def required_image_extents(sizes: torch.Size, layout: VkMemoryLayout) -> ImageExtents:
     """
     Calculate the image extents that will be used to represent a tensor with the given sizes
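As a standalone illustration of the check within_buffer_limit performs (a sketch, not part of the diff): the limit is counted in elements rather than bytes, and the default of 134217728 is 2**27 elements. Meta tensors stand in here for the FakeTensor values the helper reads from node.meta["val"].

import torch

DEFAULT_BUFFER_LIMIT = 134217728  # 2**27 elements, matching the default added above

# Meta tensors carry shape metadata only, so numel() works without allocating memory.
small = torch.empty(4096, 4096, device="meta")      # 16,777,216 elements
large = torch.empty(16384, 16384, device="meta")    # 268,435,456 elements

print(small.numel() < DEFAULT_BUFFER_LIMIT)   # True  -> buffer storage is allowed
print(large.numel() < DEFAULT_BUFFER_LIMIT)   # False -> exceeds the buffer limit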