8 changes: 0 additions & 8 deletions backends/vulkan/op_registry.py
@@ -491,17 +491,9 @@ def register_view_ops():
# for both texture and buffer storage types.
@update_features(exir_ops.edge.aten.cat.default)
def register_cat_op():
def check_cat_node(node: torch.fx.Node) -> bool:
inputs = node.args[0]
if isinstance(inputs, (list, tuple)) and len(inputs) <= 3:
return True

return False

return OpFeatures(
inputs_storage=utils.ANY_STORAGE,
supports_resize=True,
are_node_inputs_supported_fn=check_cat_node,
)


9 changes: 8 additions & 1 deletion backends/vulkan/runtime/api/containers/Tensor.cpp
@@ -517,6 +517,7 @@ void vTensorStorage::transition(
vkapi::MemoryAccessFlags prev_access = last_access_.access;

const bool prev_written = (prev_access & vkapi::MemoryAccessType::WRITE) != 0;
const bool cur_written = (cur_access & vkapi::MemoryAccessType::WRITE) != 0;

VkImageLayout cur_layout = VK_IMAGE_LAYOUT_UNDEFINED;
VkImageLayout new_layout = VK_IMAGE_LAYOUT_UNDEFINED;
@@ -528,7 +529,13 @@ void vTensorStorage::transition(
layout_changed = cur_layout != new_layout;
}

if (prev_written || layout_changed) {
// RAW: need to make sure current read sees previous writes
// WAW: need to make sure the current write occurs after previous write so
// the final value is correct.
// WAR: need to make sure previous read does not read the value from the
// current write.
// RAR: no need for synchronization
if (prev_written || cur_written || layout_changed) {
VkPipelineStageFlags src_stage = vkapi::vk_stage(prev_stage);
if (0u == src_stage) {
src_stage = VK_PIPELINE_STAGE_TOP_OF_PIPE_BIT;
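The added `cur_written` flag extends the barrier condition to cover every read/write hazard pair, as the new comments spell out. A minimal sketch of the resulting decision rule, using illustrative names rather than the runtime's API:

```cpp
#include <cstdio>

// Sketch of when vTensorStorage::transition() must insert a barrier:
// RAW and WAW need the previous write ordered/visible (prev_written);
// WAR needs the current write to wait on the previous read (cur_written);
// RAR needs no synchronization; a layout change always needs a barrier.
bool needs_barrier(bool prev_written, bool cur_written, bool layout_changed) {
  return prev_written || cur_written || layout_changed;
}

int main() {
  std::printf("RAR: %d\n", needs_barrier(false, false, false)); // 0, no barrier
  std::printf("WAR: %d\n", needs_barrier(false, true, false));  // 1, barrier
  std::printf("RAW: %d\n", needs_barrier(true, false, false));  // 1, barrier
  std::printf("WAW: %d\n", needs_barrier(true, true, false));   // 1, barrier
}
```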
2 changes: 1 addition & 1 deletion backends/vulkan/runtime/graph/ops/ExecuteNode.h
@@ -43,7 +43,7 @@ class ExecuteNode {
friend class ComputeGraph;

public:
using ResizeFunction = const std::function<void(
using ResizeFunction = std::function<void(
ComputeGraph*,
const std::vector<ArgGroup>&,
const std::vector<ValueRef>&)>;
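Dropping `const` from the alias matters because a `const`-qualified `std::function` member deletes the enclosing class's assignment operators. A minimal sketch of the effect, with hypothetical struct names:

```cpp
#include <functional>

using ConstResizeFn = const std::function<void(int)>;  // old alias shape
using ResizeFn = std::function<void(int)>;             // new alias shape

struct NodeA { ConstResizeFn fn; };  // const member: NodeA is not assignable
struct NodeB { ResizeFn fn; };       // assignable and movable as usual

int main() {
  NodeB a, b;
  a = b;  // OK
  // NodeA c, d;
  // c = d;  // would not compile: assignment operators implicitly deleted
}
```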
61 changes: 44 additions & 17 deletions backends/vulkan/runtime/graph/ops/glsl/concat_buffer.glsl
@@ -20,19 +20,21 @@ layout(std430) buffer;

#include "indexing_utils.h"

${layout_declare_tensor(B, "w", "t_out", DTYPE, "buffer")}
${layout_declare_tensor(B, "rw", "t_out", DTYPE, "buffer")}

$for i in range(NUM_INPUTS):
${layout_declare_tensor(B, "r", "t_in" + str(i + 1), DTYPE, "buffer")}
${layout_declare_tensor(B, "r", "t_inp" + str(i), DTYPE, "buffer")}

${layout_declare_tensor(B, "r", "t_concat_offset", "int", "buffer")}

${layout_declare_ubo(B, "int", "concat_dim")}

${layout_declare_ubo(B, "ivec4", "out_sizes")}
${layout_declare_ubo(B, "ivec4", "out_strides")}

$for i in range(NUM_INPUTS):
${layout_declare_ubo(B, "ivec4", "in" + str(i+1) + "_sizes")}
${layout_declare_ubo(B, "ivec4", "in" + str(i+1) + "_strides")}
${layout_declare_ubo(B, "ivec4", "inp" + str(i) + "_sizes")}
${layout_declare_ubo(B, "ivec4", "inp" + str(i) + "_strides")}

${layout_declare_ubo(B, "int", "out_numel")}

@@ -42,28 +44,53 @@ const lowp ivec4 out_dim_order = unhash_dim_order(out_layout);

layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

#define NUM_INPUTS ${NUM_INPUTS}

#include "concat_utils.glslh"

/*
 * This shader template concatenates up to NUM_INPUTS input tensors to the
 * output tensor along concat_dim. Elements from the input tensors are
 * inserted along the output's concat_dim starting at concat_offset.
 */
void main() {
const int out_bufi = ivec3(gl_GlobalInvocationID).x;
if (out_bufi >= out_numel) {
const int tid = ivec3(gl_GlobalInvocationID).x;

// The 1-3 input tensors are interpreted as one concatenated tensor ("volume")
// along the concat_dim for the purposes of tensor indexing. Each thread is
// responsible for reading one item from this volume and writing it to the
// appropriate output location.
ivec4 inp_volume_sizes = out_sizes;
inp_volume_sizes[concat_dim] = total_concat_dim_numel();

// Account for zero-size input tensors
if (any(lessThanEqual(inp_volume_sizes, ivec4(0)))) {
return;
}

ivec4 inp_volume_tidx = nchwi_to_tidx(tid, inp_volume_sizes);

// bounds check
if (any(greaterThanEqual(inp_volume_tidx, inp_volume_sizes))) {
return;
}

// Convert buffer linear index to 4-D tensor index for output
const ivec4 out_tidx = bufi_to_tidx(out_bufi, out_strides, out_dim_order);
int concat_offset = t_concat_offset[0];

ivec4 out_tidx = inp_volume_tidx;
out_tidx[concat_dim] += concat_offset;

// Determine which input tensor to read from
ivec4 in_tidx = out_tidx;
const uint out_bufi = tidx_to_bufi(out_tidx, out_strides);

// Go through the list of input tensors, and find which input this output
// element should be read from.
$for i in range(NUM_INPUTS):
// Check if the index at the concat dim is within bounds of the input tensor
// If so, read from that input tensor and write to output
if (in_tidx[concat_dim] < in${i+1}_sizes[concat_dim]) {
int in_bufi = tidx_to_bufi(in_tidx, in${i+1}_strides);
t_out[out_bufi] = t_in${i+1}[in_bufi];
if (inp_volume_tidx[concat_dim] < inp${i}_sizes[concat_dim]) {
int inp_bufi = tidx_to_bufi(inp_volume_tidx, inp${i}_strides);
t_out[out_bufi] = t_inp${i}[inp_bufi];
return;
}
// otherwise, decrement the index at the concat dim
else {
in_tidx[concat_dim] -= in${i+1}_sizes[concat_dim];
inp_volume_tidx[concat_dim] -= inp${i}_sizes[concat_dim];
}
}
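For reference, the buffer variant's indexing scheme can be modeled on the host. A C++ sketch under stated assumptions: dim 0 is the fastest-moving dimension, and `linear_to_tidx`/`tidx_to_bufi` are illustrative stand-ins for the shader's indexing_utils helpers, not the runtime code:

```cpp
#include <array>
#include <cstddef>
#include <vector>

using ivec4 = std::array<int, 4>;

// Stand-in for nchwi_to_tidx: unravel a linear index, dim 0 fastest-moving.
ivec4 linear_to_tidx(int i, const ivec4& sizes) {
  ivec4 t{};
  for (int d = 0; d < 4; ++d) { t[d] = i % sizes[d]; i /= sizes[d]; }
  return t;
}

// Stand-in for tidx_to_bufi: dot product of tensor index and strides.
int tidx_to_bufi(const ivec4& t, const ivec4& strides) {
  return t[0] * strides[0] + t[1] * strides[1] + t[2] * strides[2] +
      t[3] * strides[3];
}

// One shader thread's work: copy element `tid` of the combined input volume
// to the output tensor, shifted along concat_dim by concat_offset.
void concat_element(
    int tid, int concat_dim, int concat_offset,
    const ivec4& out_sizes, const ivec4& out_strides, std::vector<float>& out,
    const std::vector<ivec4>& inp_sizes, const std::vector<ivec4>& inp_strides,
    const std::vector<const float*>& inputs) {
  // All inputs viewed as one tensor along concat_dim (the "input volume").
  ivec4 volume_sizes = out_sizes;
  int total = 0;
  for (const ivec4& s : inp_sizes) total += s[concat_dim];
  volume_sizes[concat_dim] = total;  // total_concat_dim_numel()

  ivec4 vol_tidx = linear_to_tidx(tid, volume_sizes);

  ivec4 out_tidx = vol_tidx;
  out_tidx[concat_dim] += concat_offset;
  const int out_bufi = tidx_to_bufi(out_tidx, out_strides);

  // Walk the input list to find which input owns this concat_dim index.
  for (std::size_t i = 0; i < inputs.size(); ++i) {
    if (vol_tidx[concat_dim] < inp_sizes[i][concat_dim]) {
      out[out_bufi] = inputs[i][tidx_to_bufi(vol_tidx, inp_strides[i])];
      return;
    }
    vol_tidx[concat_dim] -= inp_sizes[i][concat_dim];
  }
}
```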
193 changes: 120 additions & 73 deletions backends/vulkan/runtime/graph/ops/glsl/concat_texture.glsl
@@ -19,16 +19,18 @@ layout(std430) buffer;

#include "indexing_utils.h"

${layout_declare_tensor(B, "w", "t_out", DTYPE, "texture3d")}
${layout_declare_tensor(B, "rw", "t_out", DTYPE, "texture3d")}

$for i in range(NUM_INPUTS):
${layout_declare_tensor(B, "r", "t_in" + str(i + 1), DTYPE, "texture3d")}
${layout_declare_tensor(B, "r", "t_inp" + str(i), DTYPE, "texture3d")}

${layout_declare_tensor(B, "r", "t_concat_offset", "int", "buffer")}

${layout_declare_ubo(B, "int", "concat_dim")}

$in_metadata = ""
$for i in range(NUM_INPUTS):
$in_metadata += "ivec4 in" + str(i + 1) + "_sizes;\n"
$in_metadata += "ivec4 inp" + str(i) + "_sizes;\n"

layout(push_constant) uniform restrict Block {
ivec4 out_sizes;
@@ -40,90 +42,135 @@ const lowp ivec4 out_axis_map = unhash_axis_map(out_layout);
const lowp int out_packed_dim = unhash_packed_dim(out_layout);

$for i in range(NUM_INPUTS):
${layout_declare_spec_const(C, "int", "in" + str(i+1) + "_layout", "DEFAULT_LAYOUT")}
const lowp ivec4 in${i+1}_axis_map = unhash_axis_map(in${i+1}_layout);
const lowp int in${i+1}_packed_dim = unhash_packed_dim(in${i+1}_layout);
${layout_declare_spec_const(C, "int", "inp" + str(i) + "_layout", "DEFAULT_LAYOUT")}
const lowp ivec4 inp${i}_axis_map = unhash_axis_map(inp${i}_layout);
const lowp int inp${i}_packed_dim = unhash_packed_dim(inp${i}_layout);

layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

// Check if we can use the fast path (no texel merging required)
bool can_use_fast_path() {
// Fast path is possible when:
// 1. The concat dimension is not the packed dimension, or
// 2. The concat dimension is the packed dimension but both input tensors have dimensions
// that are multiples of 4 along the packed dimension
if (concat_dim != out_packed_dim) {
return true;
}

// Check if all input tensors have dimensions that are multiples of 4 along the packed dimension
bool all_concat_dim_size_multiple_of_4 = true;
$for i in range(NUM_INPUTS):
all_concat_dim_size_multiple_of_4 =
all_concat_dim_size_multiple_of_4 &&
(in${i+1}_sizes[concat_dim] % 4 == 0);
#define NUM_INPUTS ${NUM_INPUTS}

return all_concat_dim_size_multiple_of_4;
}
#include "concat_utils.glslh"

/*
 * This shader template concatenates up to NUM_INPUTS input tensors to the
 * output tensor along concat_dim. Elements from the input tensors are
 * inserted along the output's concat_dim starting at concat_offset.
 *
 * Each thread is responsible for writing out one output texel. The data
 * required for the output texel may be read from multiple input texels of
 * one input tensor.
 */
void main() {
const ivec3 lpos = ivec3(gl_GlobalInvocationID);
ivec4 out_tidx = lpos_to_tidx(lpos, out_sizes, out_axis_map.w, out_packed_dim);

if (any(greaterThanEqual(out_tidx, out_sizes))) {
const int tid = ivec3(gl_GlobalInvocationID).x;

// Sum of the sizes of all input tensors along the concat_dim
const int concat_numel = total_concat_dim_numel();

// The 1-3 input tensors are interpreted as one concatenated tensor ("volume")
// along the concat_dim for the purposes of tensor indexing. Each thread is
// responsible for writing out 4 elements along the packed dim of the output
// tensor by reading the source data from the input tensor(s).
ivec4 inp_volume_sizes = out_sizes;
inp_volume_sizes[concat_dim] = concat_numel;

// Reconstruct inp_volume_texel_sizes, matching the dispatch sizing in Concat.cpp
ivec4 inp_volume_texel_sizes = inp_volume_sizes;
inp_volume_texel_sizes[out_packed_dim] = DIV_UP_4(
inp_volume_texel_sizes[out_packed_dim]
) + 1;

// tensor index of the first element that will be read from the input volume
ivec4 inp_volume_start_tidx = nchwi_to_tidx(tid, inp_volume_texel_sizes);
inp_volume_start_tidx[out_packed_dim] = MUL_4(
inp_volume_start_tidx[out_packed_dim]
);

int concat_offset = t_concat_offset[0];

// tensor index of the first element that will be written to the output tensor
ivec4 out_write_start_tidx = inp_volume_start_tidx;
out_write_start_tidx[concat_dim] += concat_offset;

// To write to the desired output element, we will need to load the texel
// to which the element belongs. Calculate the tensor index of the first
// element of that texel.
ivec4 out_read_start_tidx = out_write_start_tidx;
out_read_start_tidx[out_packed_dim] = ALIGN_DOWN_4(
out_write_start_tidx[out_packed_dim]);

// bounds check
if (any(greaterThanEqual(out_read_start_tidx, out_sizes))) {
return;
}

if (can_use_fast_path()) {
// Fast path: No texel merging required
ivec4 in_tidx = out_tidx;
ivec3 out_pos = tidx_to_pos(
out_read_start_tidx,
out_sizes,
out_axis_map,
out_packed_dim
);

$for i in range(NUM_INPUTS):
// For each input tensor, check if the tensor index is within bounds. If
// so, read the texel from the input tensor and write it to the output
if (in_tidx[concat_dim] < in${i+1}_sizes[concat_dim]) {
const ivec3 in_pos = tidx_to_pos(in_tidx, in${i+1}_sizes, in${i+1}_axis_map, in${i+1}_packed_dim);
const VEC4_T in_texel = load_texel(t_in${i+1}, in_pos);
write_texel_lpos(t_out, lpos, in_texel, out_axis_map);
return;
}
// Otherwise, adjust the index along the concat dimension and try the next
// input tensor.
else {
in_tidx[concat_dim] -= in${i+1}_sizes[concat_dim];
}
}
else {
// Slow path: Texel merging required
VEC4_T out_texel = VEC4_T(0);
VEC4_T out_texel = imageLoad(t_out, out_pos);

// Process each element in the output texel individually
for (int texel_i = 0; texel_i < 4; ++texel_i) {
ivec4 curr_out_tidx = out_tidx;
curr_out_tidx[out_packed_dim] += texel_i;

// Skip if we're out of bounds
if (curr_out_tidx[out_packed_dim] >= out_sizes[out_packed_dim]) {
continue;
}
for (int comp = 0; comp < 4; ++comp) {
ivec4 out_tidx = out_read_start_tidx;
out_tidx[out_packed_dim] += comp;

ivec4 in_tidx = curr_out_tidx;
$for i in range(NUM_INPUTS):
// For each input tensor, check if the tensor index is within bounds. If
// so, read the corresponding texel element from the input tensor and
// write it to the output texel.
if (in_tidx[concat_dim] < in${i+1}_sizes[concat_dim]) {
const ivec4 in_posi = tidx_to_posi(in_tidx, in${i+1}_sizes, in${i+1}_axis_map, in${i+1}_packed_dim);
out_texel[texel_i] = load_texel(t_in${i+1}, in_posi.xyz)[in_posi.w];
continue;
}
// Otherwise, adjust the index along the concat dimension and try the
// next input tensor.
else {
in_tidx[concat_dim] -= in${i+1}_sizes[concat_dim];
}

// It's possible that the current texel element was already written as part
// of a previous input batch; if so, don't overwrite this texel element.
if (out_tidx[concat_dim] < concat_offset) {
continue;
}

write_texel_lpos(t_out, lpos, out_texel, out_axis_map);
// Calculate the tidx of the input volume that corresponds to this output
// element
ivec4 inp_volume_tidx = out_tidx;
inp_volume_tidx[concat_dim] -= concat_offset;

// go through the list of input tensors, and figure out which input this
// output element should be read from.
$for i in range(NUM_INPUTS):
if (inp_volume_tidx[concat_dim] < inp${i}_sizes[concat_dim]) {
// Fast path: if, for the first output texel element, the corresponding
// input element is at the start of the texel it belongs to, then the input
// texel can be written as-is to the output texel. This also requires that
// the entire input texel is valid and does not contain any padding
// elements.
if (comp == 0 &&
out_tidx[out_packed_dim] % 4 == 0 &&
inp_volume_tidx[inp${i}_packed_dim] % 4 == 0 &&
inp_volume_tidx[inp${i}_packed_dim] + 3 < inp${i}_sizes[inp${i}_packed_dim]) {
const ivec3 in_pos = tidx_to_pos(
inp_volume_tidx,
inp${i}_sizes,
inp${i}_axis_map,
inp${i}_packed_dim);

out_texel = texelFetch(t_inp${i}, in_pos, 0);
break;
}

// Otherwise, locate the specific input element required
const ivec4 in_posi = tidx_to_posi(
inp_volume_tidx,
inp${i}_sizes,
inp${i}_axis_map,
inp${i}_packed_dim);

out_texel[comp] = texelFetch(t_inp${i}, in_posi.xyz, 0)[in_posi.w];
continue;
}
else {
inp_volume_tidx[concat_dim] -= inp${i}_sizes[concat_dim];
}
}

imageStore(t_out, out_pos, out_texel);
}
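The texel bookkeeping above is easier to follow with concrete numbers. A sketch of the alignment math, assuming the concat dim is the packed dim and using stand-in versions of the shader's macros:

```cpp
#include <cstdio>

// Stand-ins for the shader's texel-alignment macros.
int DIV_UP_4(int v)     { return (v + 3) / 4; }
int MUL_4(int v)        { return v * 4; }
int ALIGN_DOWN_4(int v) { return v & ~3; }

int main() {
  // Say a previous input batch already filled the first 6 elements along the
  // packed/concat dim, so this dispatch runs with concat_offset = 6.
  const int concat_offset = 6;

  // The thread assigned input-volume texel 0 starts at input element 0,
  // which lands at output element 0 + 6 = 6, i.e. mid-texel.
  const int inp_start = MUL_4(0);                         // 0
  const int out_write_start = inp_start + concat_offset;  // 6

  // The output texel containing element 6 starts at its aligned base, so
  // the shader loads it before writing (hence t_out is declared "rw").
  const int out_read_start = ALIGN_DOWN_4(out_write_start);  // 4

  // Components 4 and 5 satisfy out_tidx[concat_dim] < concat_offset: they
  // hold the previous batch's data and are preserved by the
  // read-modify-write. Components 6 and 7 come from input elements 0 and 1.
  std::printf("texel base %d; preserved: 4,5; new: 6,7\n", out_read_start);

  // Because of the offset shift, N input texels can straddle N + 1 output
  // texels, which is why inp_volume_texel_sizes uses DIV_UP_4(size) + 1.
  std::printf("texels dispatched for 10 elements: %d\n", DIV_UP_4(10) + 1);
}
```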