Skip to content

Commit a022db0

Browse files
authored
[ET-VK][ez] Fix partitioner logic (#12196)
Summary: ## Changes * In partitioner, check if buffer storage can be used only after it has been determined that no valid texture layouts are available ## Context Currently, the logic in the vulkan partitioner is incorrect. 1. First, it checks what texture layouts may be used to represent the tensors involved in the computation 2. If no texture layouts are available, it checks if buffer support is available and the tensors are small enough to be within Vulkan buffer limits 3. Then, it checks if all valid texture layouts are supported by the op. This introduces a bug in situations where 3 fails, but 2 would pass. However, 2 is not checked due to the way the logic is structured. The fix is to switch the order of 2 and 3. Test Plan: ## Test Plan Manual verification + CI
1 parent 1315388 commit a022db0

File tree

1 file changed

+10
-12
lines changed

1 file changed

+10
-12
lines changed

backends/vulkan/partitioner/vulkan_partitioner.py

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -116,9 +116,17 @@ def op_node_is_compatible( # noqa: C901: Function is too complex
116116
arg, self.buffer_limit
117117
)
118118

119+
op_available_layouts = features.supported_memory_layouts(
120+
VkStorageType.TEXTURE_3D
121+
)
122+
123+
can_use_texture = any(
124+
layout in op_available_layouts for layout in valid_texture_layouts
125+
)
126+
119127
# If there are no valid texture memory layouts, then buffer storage must be
120128
# supported by the operator implementation.
121-
if len(valid_texture_layouts) == 0:
129+
if not can_use_texture:
122130
if not can_use_buffers:
123131
return (
124132
False,
@@ -131,17 +139,7 @@ def op_node_is_compatible( # noqa: C901: Function is too complex
131139
reason = "op requires buffers which is not supported by op impl"
132140
return compatible, reason
133141

134-
op_available_layouts = features.supported_memory_layouts(
135-
VkStorageType.TEXTURE_3D
136-
)
137-
138-
is_compatible = any(
139-
layout in op_available_layouts for layout in valid_texture_layouts
140-
)
141-
if not is_compatible:
142-
return False, "Required texutre memory layout not supported"
143-
144-
return is_compatible, "Op is compatible"
142+
return True, "Op is compatible"
145143

146144
def node_is_compatible(
147145
self, node: torch.fx.Node, features: Optional[OpFeatures] = None

0 commit comments

Comments
 (0)