Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions backends/vulkan/runtime/graph/ops/impl/Arange.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>

#include <executorch/backends/vulkan/runtime/graph/ops/impl/Common.h>
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.h>

#include <executorch/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h>
Expand Down Expand Up @@ -86,11 +87,11 @@ void add_arange_node(
kernel_name.reserve(kShaderNameReserve);
add_dtype_suffix(kernel_name, graph.dtype_of(out));

graph.execute_nodes().emplace_back(new DispatchNode(
graph.execute_nodes().emplace_back(new DynamicDispatchNode(
graph,
VK_KERNEL_FROM_STR(kernel_name),
graph.create_global_wg_size(out),
graph.create_local_wg_size(out),
default_pick_global_wg_size,
default_pick_local_wg_size,
// Inputs and Outputs
{{out, vkapi::kWrite}},
// Shader params buffers
Expand Down
7 changes: 4 additions & 3 deletions backends/vulkan/runtime/graph/ops/impl/BatchNorm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>

#include <executorch/backends/vulkan/runtime/graph/ops/impl/Common.h>
#include <executorch/backends/vulkan/runtime/graph/ops/impl/Staging.h>

#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/DimUtils.h>
Expand Down Expand Up @@ -83,11 +84,11 @@ void add_native_batch_norm_node(
const int32_t num_texel_per_batch =
utils::div_up_4((dim_at<kChannel4D>(in_sizes)));

graph.execute_nodes().emplace_back(new DispatchNode(
graph.execute_nodes().emplace_back(new DynamicDispatchNode(
graph,
VK_KERNEL_FROM_STR(kernel_name),
graph.create_global_wg_size(out_ref),
graph.create_local_wg_size(out_ref),
default_pick_global_wg_size,
default_pick_local_wg_size,
{{out_ref, vkapi::kWrite},
{{in_ref, arg_weight, arg_bias, arg_mean, arg_var}, vkapi::kRead}},
{graph.logical_limits_ubo(out_ref),
Expand Down
138 changes: 116 additions & 22 deletions backends/vulkan/runtime/graph/ops/impl/Convolution.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>

#include <executorch/backends/vulkan/runtime/graph/ops/impl/Common.h>
#include <executorch/backends/vulkan/runtime/graph/ops/impl/Staging.h>

#include <executorch/backends/vulkan/runtime/graph/ops/utils/StagingUtils.h>
Expand All @@ -19,6 +20,13 @@

namespace vkcompute {

// Identifies which convolution variant a conv2d dispatch implements. The
// conv2d workgroup-size pickers in this file branch on this value.
enum class Conv2dMethod : uint8_t {
// Selected for shaders whose kernel name contains "conv2d_dw".
Depthwise,
// Selected for non-depthwise, non-transposed conv2d shaders when the weight
// tensor's sizes at indices 2 and 3 are both 1.
Pointwise,
// Selected for the remaining conv2d shaders; also the fallback when the
// kernel name matches none of the known patterns.
SlidingWindow,
// Selected for shaders whose kernel name contains "conv_transpose2d".
Transposed,
};

void resize_conv2d_node(
ComputeGraph* graph,
const std::vector<ArgGroup>& args,
Expand Down Expand Up @@ -114,13 +122,6 @@ ValueRef prepack_biases(
return v;
}

// Identifies which convolution variant a conv2d dispatch implements. The
// conv2d workgroup-size pickers in this file branch on this value.
enum class Conv2dMethod : uint8_t {
// Selected for shaders whose kernel name contains "conv2d_dw".
Depthwise,
// Selected for non-depthwise, non-transposed conv2d shaders when the weight
// tensor's sizes at indices 2 and 3 are both 1.
Pointwise,
// Selected for the remaining conv2d shaders; also the fallback when the
// kernel name matches none of the known patterns.
SlidingWindow,
// Selected for shaders whose kernel name contains "conv_transpose2d".
Transposed,
};

vkapi::ShaderInfo get_conv2d_shader(
ComputeGraph& graph,
const ValueRef out,
Expand Down Expand Up @@ -327,6 +328,108 @@ utils::uvec3 create_conv2d_global_wg_size(
}
}

// Custom global workgroup size function for conv2d
utils::uvec3 conv2d_global_wg_size(
ComputeGraph* graph,
const vkapi::ShaderInfo& shader,
const std::vector<ArgGroup>& args,
const std::vector<ValueRef>& resize_args) {
const ValueRef out = args.at(0).refs.at(0);
const ValueRef weight_data = resize_args.at(0);

// Determine method from shader name
Conv2dMethod method;
if (shader.kernel_name.find("conv2d_dw") != std::string::npos) {
method = Conv2dMethod::Depthwise;
} else if (
shader.kernel_name.find("conv2d_pw") != std::string::npos ||
(shader.kernel_name.find("conv2d") != std::string::npos &&
shader.kernel_name.find("conv_transpose2d") == std::string::npos)) {
// Check if it's pointwise by examining weight sizes
const auto& weight_sizes = graph->get_tref(weight_data)->sizes;
if (weight_sizes.at(2) == 1 && weight_sizes.at(3) == 1) {
method = Conv2dMethod::Pointwise;
} else {
method = Conv2dMethod::SlidingWindow;
}
} else if (shader.kernel_name.find("conv_transpose2d") != std::string::npos) {
method = Conv2dMethod::Transposed;
} else {
method = Conv2dMethod::SlidingWindow;
}

// Determine stride_equals_dilation from shader name
bool stride_equals_dilation =
shader.kernel_name.find("_sned") == std::string::npos;

utils::uvec3 wg_size = create_conv2d_global_wg_size(
*graph, method, out, weight_data, stride_equals_dilation);

if (method == Conv2dMethod::Depthwise || method == Conv2dMethod::Pointwise) {
wg_size = {wg_size[0] * wg_size[1], wg_size[2], 1};
}

return wg_size;
}

// Custom local workgroup size function for conv2d
utils::uvec3 conv2d_local_wg_size(
ComputeGraph* graph,
const vkapi::ShaderInfo& shader,
const utils::uvec3& global_workgroup_size,
const std::vector<ArgGroup>& args,
const std::vector<ValueRef>& resize_args) {
(void)args;
(void)resize_args;

// Determine method from shader name
Conv2dMethod method;
if (shader.kernel_name.find("conv2d_dw") != std::string::npos) {
method = Conv2dMethod::Depthwise;
} else if (
shader.kernel_name.find("conv2d_pw") != std::string::npos ||
(shader.kernel_name.find("conv2d") != std::string::npos &&
shader.kernel_name.find("conv_transpose2d") == std::string::npos)) {
method = Conv2dMethod::Pointwise;
} else {
method = Conv2dMethod::SlidingWindow;
}

if (method == Conv2dMethod::Pointwise) {
uint32_t local_wg_size_y = 1;
if (global_workgroup_size[1] % 8 == 0) {
local_wg_size_y = 8;
} else if (global_workgroup_size[1] % 4 == 0) {
local_wg_size_y = 4;
} else if (global_workgroup_size[1] % 2 == 0) {
local_wg_size_y = 2;
}
return {64 / local_wg_size_y, local_wg_size_y, 1};
} else if (method == Conv2dMethod::Depthwise) {
return {64, 1, 1};
} else {
return graph->create_local_wg_size(global_workgroup_size);
}
}

// Global workgroup size picker for conv1d dispatches.
//
// Derives the extents from the current sizes of the output tensor
// (args[0].refs[0]):
//   x = size at dim -1 (out length)
//   y = size at dim -2 (out channels)
//   z = div_up_4 of the size at dim -3 (out batches)
//
// `shader` and `resize_args` are unused.
utils::uvec3 conv1d_global_wg_size(
    ComputeGraph* graph,
    const vkapi::ShaderInfo& shader,
    const std::vector<ArgGroup>& args,
    const std::vector<ValueRef>& resize_args) {
  (void)shader;
  (void)resize_args;

  const ValueRef out_tensor = args.at(0).refs.at(0);

  const auto out_length = graph->size_at<uint32_t>(-1, out_tensor);
  const auto out_channels =
      static_cast<uint32_t>(graph->size_at<int64_t>(-2, out_tensor));
  const auto out_batches_div4 =
      utils::div_up_4(graph->size_at<uint32_t>(-3, out_tensor));

  return {out_length, out_channels, out_batches_div4};
}

void add_conv2d_node(
ComputeGraph& graph,
const ValueRef in,
Expand Down Expand Up @@ -486,11 +589,11 @@ void add_conv2d_node(
};
}

graph.execute_nodes().emplace_back(new DispatchNode(
graph.execute_nodes().emplace_back(new DynamicDispatchNode(
graph,
shader,
wg_size,
local_wg_size,
conv2d_global_wg_size,
conv2d_local_wg_size,
// Inputs and Outputs
{{out, vkapi::kWrite}, {{in, arg_weight, arg_bias}, vkapi::kRead}},
// Shader params buffers
Expand Down Expand Up @@ -560,15 +663,6 @@ void add_conv1d_node(
const int32_t out_group_size =
static_cast<int64_t>(out_channels / groups_val);

const utils::uvec3 global_size = {
// out length
graph.size_at<uint32_t>(-1, out),
// out channels
static_cast<uint32_t>(out_channels),
// out batches
utils::div_up_4(graph.size_at<uint32_t>(-3, out))};
const utils::uvec3 local_size = graph.create_local_wg_size(global_size);

Kernel1dParams kernel_params = {
kernel_size,
stride_size,
Expand All @@ -587,11 +681,11 @@ void add_conv1d_node(

add_dtype_suffix(kernel_name, graph.dtype_of(out));

graph.execute_nodes().emplace_back(new DispatchNode(
graph.execute_nodes().emplace_back(new DynamicDispatchNode(
graph,
VK_KERNEL_FROM_STR(kernel_name),
global_size,
local_size,
conv1d_global_wg_size,
default_pick_local_wg_size,
// Inputs and Outputs
{{out, vkapi::kWrite}, {{in, arg_weight, arg_bias}, vkapi::kRead}},
// Shader params buffers
Expand Down
7 changes: 4 additions & 3 deletions backends/vulkan/runtime/graph/ops/impl/Copy.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>

#include <executorch/backends/vulkan/runtime/graph/ops/impl/Common.h>
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/DimUtils.h>
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/KernelUtils.h>
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.h>
Expand Down Expand Up @@ -35,11 +36,11 @@ void add_copy_offset_node(

auto shader = VK_KERNEL_FROM_STR(kernel_name);

graph.execute_nodes().emplace_back(new DispatchNode(
graph.execute_nodes().emplace_back(new DynamicDispatchNode(
graph,
VK_KERNEL_FROM_STR(kernel_name),
graph.create_global_wg_size(out),
graph.create_local_wg_size(out),
default_pick_global_wg_size,
default_pick_local_wg_size,
// Inputs and Outputs
{
{out, vkapi::kWrite},
Expand Down
7 changes: 4 additions & 3 deletions backends/vulkan/runtime/graph/ops/impl/Embedding.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>

#include <executorch/backends/vulkan/runtime/graph/ops/impl/Common.h>
#include <executorch/backends/vulkan/runtime/graph/ops/impl/Staging.h>

#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/DimUtils.h>
Expand Down Expand Up @@ -46,11 +47,11 @@ void add_embedding_node(
kernel_name.reserve(kShaderNameReserve);
add_dtype_suffix(kernel_name, graph.dtype_of(out));

graph.execute_nodes().emplace_back(new DispatchNode(
graph.execute_nodes().emplace_back(new DynamicDispatchNode(
graph,
VK_KERNEL_FROM_STR(kernel_name),
graph.create_global_wg_size(out),
graph.create_local_wg_size(out),
default_pick_global_wg_size,
default_pick_local_wg_size,
{{out, vkapi::kWrite}, {{in, weight}, vkapi::kRead}},
{
graph.sizes_ubo(out),
Expand Down
19 changes: 16 additions & 3 deletions backends/vulkan/runtime/graph/ops/impl/Flip.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,26 @@

#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>

#include <executorch/backends/vulkan/runtime/graph/ops/impl/Common.h>
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/DimUtils.h>
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/KernelUtils.h>
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.h>
#include <executorch/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h>

namespace vkcompute {

// Global workgroup size picker for flip dispatches.
//
// Looks up the output tensor (first ref of the first argument group) and
// returns the graph's global workgroup size for it. `shader` and
// `resize_args` are unused.
utils::uvec3 flip_global_wg_size(
    ComputeGraph* graph,
    const vkapi::ShaderInfo& shader,
    const std::vector<ArgGroup>& args,
    const std::vector<ValueRef>& resize_args) {
  (void)shader;
  (void)resize_args;

  const ArgGroup& output_group = args.at(0);
  const ValueRef out_tensor = output_group.refs.at(0);

  return graph->create_global_wg_size(out_tensor);
}

void check_flip_args(
ComputeGraph& graph,
const ValueRef in,
Expand Down Expand Up @@ -59,11 +72,11 @@ void add_flip_node(
kernel_name.reserve(kShaderNameReserve);
add_dtype_suffix(kernel_name, graph.dtype_of(out));

graph.execute_nodes().emplace_back(new DispatchNode(
graph.execute_nodes().emplace_back(new DynamicDispatchNode(
graph,
VK_KERNEL_FROM_STR(kernel_name),
graph.create_global_wg_size(out),
graph.create_local_wg_size(out),
flip_global_wg_size,
default_pick_local_wg_size,
// Inputs and Outputs
{
{out, vkapi::kWrite},
Expand Down
7 changes: 4 additions & 3 deletions backends/vulkan/runtime/graph/ops/impl/Full.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>

#include <executorch/backends/vulkan/runtime/graph/ops/impl/Common.h>
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/KernelUtils.h>
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.h>

Expand Down Expand Up @@ -42,11 +43,11 @@ void add_full_node(

add_dtype_suffix(kernel_name, graph.dtype_of(out));

graph.execute_nodes().emplace_back(new DispatchNode(
graph.execute_nodes().emplace_back(new DynamicDispatchNode(
graph,
VK_KERNEL_FROM_STR(kernel_name),
graph.create_global_wg_size(out),
graph.create_local_wg_size(out),
default_pick_global_wg_size,
default_pick_local_wg_size,
// Inputs and Outputs
{{out, vkapi::kWrite}},
// Shader params buffers
Expand Down
7 changes: 4 additions & 3 deletions backends/vulkan/runtime/graph/ops/impl/GridPriors.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>

#include <executorch/backends/vulkan/runtime/graph/ops/impl/Common.h>
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/KernelUtils.h>

#include <executorch/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h>
Expand Down Expand Up @@ -46,11 +47,11 @@ void add_grid_priors_node(
add_dtype_suffix(kernel_name, graph.dtype_of(out));

const GridPriorsParam param = {stride, offset};
graph.execute_nodes().emplace_back(new DispatchNode(
graph.execute_nodes().emplace_back(new DynamicDispatchNode(
graph,
VK_KERNEL_FROM_STR(kernel_name),
graph.create_global_wg_size(out),
graph.create_local_wg_size(out),
default_pick_global_wg_size,
default_pick_local_wg_size,
// Inputs and Outputs
{
{out, vkapi::kWrite},
Expand Down
Loading
Loading