23 changes: 14 additions & 9 deletions backends/vulkan/runtime/VulkanBackend.cpp
@@ -390,18 +390,20 @@ bool maybe_resize_input(
     const size_t input_i,
     executorch::aten::Tensor& et_tensor) {
   ValueRef in_tensor_ref = graph->inputs()[input_i].value;
-  vTensorPtr in_tensor = graph->get_tensor(in_tensor_ref);
+
+  const std::vector<int64_t> in_tensor_vk_sizes =
+      graph->sizes_of(in_tensor_ref);
 
   ET_CHECK_MSG(
-      et_tensor.dim() == in_tensor->sizes().size(),
+      et_tensor.dim() == in_tensor_vk_sizes.size(),
       "Cannot resize input tensor: old ndim %zu does not match new ndim %zu",
-      static_cast<size_t>(in_tensor->sizes().size()),
+      static_cast<size_t>(in_tensor_vk_sizes.size()),
       static_cast<size_t>(et_tensor.dim()));
 
   bool should_resize = false;
   std::vector<int64_t> new_sizes(et_tensor.dim());
   for (size_t i = 0; i < et_tensor.dim(); i++) {
-    if (in_tensor->sizes()[i] != et_tensor.sizes()[i]) {
+    if (in_tensor_vk_sizes[i] != et_tensor.sizes()[i]) {
       should_resize = true;
     }
     new_sizes.at(i) = et_tensor.sizes()[i];
@@ -411,10 +413,11 @@ bool maybe_resize_input(
     graph->resize_input(input_i, new_sizes);
   }
 
+  const size_t in_tensor_vk_numel = graph->numel_of(in_tensor_ref);
   ET_CHECK_MSG(
-      in_tensor->numel() == et_tensor.numel(),
+      in_tensor_vk_numel == et_tensor.numel(),
       "Vulkan tensor numel %zu does not match ET tensor numel %zu",
-      static_cast<size_t>(in_tensor->numel()),
+      static_cast<size_t>(in_tensor_vk_numel),
       static_cast<size_t>(et_tensor.numel()));
 
   return should_resize;
@@ -445,12 +448,14 @@ void maybe_resize_output(
     const size_t output_i,
     executorch::aten::Tensor& et_tensor) {
   ValueRef out_tensor_ref = graph->outputs()[output_i].value;
-  vTensorPtr out_tensor = graph->get_tensor(out_tensor_ref);
+
+  const std::vector<int64_t> out_tensor_vk_sizes =
+      graph->sizes_of(out_tensor_ref);
 
   executorch::aten::SizesType new_output_size[kTensorDimensionLimit];
-  size_t ndim = out_tensor->sizes().size();
+  size_t ndim = out_tensor_vk_sizes.size();
   for (int i = 0; i < ndim; ++i) {
-    new_output_size[i] = out_tensor->sizes()[i];
+    new_output_size[i] = out_tensor_vk_sizes[i];
   }
 
   executorch::aten::ArrayRef<executorch::aten::SizesType> output_size{
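The change above is the template for the rest of the PR: shape metadata is read through `ComputeGraph` accessors (`sizes_of`, `numel_of`) instead of holding a `vTensorPtr` across the check-and-resize sequence, so nothing borrowed from the graph outlives the call that uses it. A minimal sketch of the same pattern, assuming only the accessors shown in this diff; the helper name is hypothetical and not part of the PR:

```cpp
#include <executorch/backends/vulkan/runtime/graph/ComputeGraph.h>

namespace vkcompute {

// Hypothetical helper: compare an ExecuTorch tensor's shape against a graph
// value using only ComputeGraph metadata queries. The returned sizes are a
// copy, so they stay valid even if the graph is mutated afterwards.
bool et_tensor_matches_graph_input(
    ComputeGraph* graph,
    const ValueRef ref,
    const executorch::aten::Tensor& et_tensor) {
  const std::vector<int64_t> vk_sizes = graph->sizes_of(ref);
  if (vk_sizes.size() != static_cast<size_t>(et_tensor.dim())) {
    return false;
  }
  for (size_t i = 0; i < vk_sizes.size(); ++i) {
    if (vk_sizes[i] != et_tensor.sizes()[i]) {
      return false;
    }
  }
  return graph->numel_of(ref) == static_cast<size_t>(et_tensor.numel());
}

} // namespace vkcompute
```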
43 changes: 43 additions & 0 deletions backends/vulkan/runtime/graph/ComputeGraph.cpp
@@ -704,6 +704,38 @@ utils::uvec3 ComputeGraph::create_local_wg_size(const ValueRef idx) {
   return create_local_wg_size(create_global_wg_size(idx));
 }
 
+void ComputeGraph::bind_tensor_to_descriptor_set(
+    const ValueRef ref,
+    vkapi::PipelineBarrier& pipeline_barrier,
+    const vkapi::MemoryAccessFlags access_type,
+    vkapi::DescriptorSet& descriptor_set,
+    const uint32_t idx) {
+  vTensorPtr tensor = get_tensor(ref);
+  if (tensor->buffer()) {
+    vkapi::VulkanBuffer& buffer = tensor->buffer(
+        pipeline_barrier, vkapi::PipelineStage::COMPUTE, access_type);
+    descriptor_set.bind(idx, buffer);
+  } else {
+    vkapi::VulkanImage& image = tensor->image(
+        pipeline_barrier, vkapi::PipelineStage::COMPUTE, access_type);
+    descriptor_set.bind(idx, image);
+  }
+}
+
+void ComputeGraph::bind_value_to_descriptor_set(
+    const ValueRef ref,
+    vkapi::PipelineBarrier& pipeline_barrier,
+    const vkapi::MemoryAccessFlags access_type,
+    vkapi::DescriptorSet& descriptor_set,
+    const uint32_t idx) {
+  if (val_is_tensor(ref)) {
+    bind_tensor_to_descriptor_set(
+        ref, pipeline_barrier, access_type, descriptor_set, idx);
+  } else if (val_is_staging(ref)) {
+    descriptor_set.bind(idx, get_staging(ref)->buffer());
+  }
+}
+
 void ComputeGraph::copy_into_staging(
     const ValueRef idx,
     const void* data,
@@ -891,6 +923,17 @@ void ComputeGraph::execute() {
   execute_count_++;
 }
 
+void ComputeGraph::virtual_clone(const ValueRef dst, const ValueRef src) {
+  get_tensor(dst)->virtual_clone(*get_tensor(src));
+}
+
+void ComputeGraph::virtual_transpose(
+    const ValueRef tensor,
+    const int64_t dim0,
+    const int64_t dim1) {
+  get_tensor(tensor)->virtual_transpose(dim0, dim1);
+}
+
 void ComputeGraph::resize_input(
     const int64_t idx,
     const std::vector<int64_t>& new_sizes) {
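The two new `bind_*` methods move descriptor binding behind the graph interface: callers pass a `ValueRef`, and the graph picks the right resource (buffer vs. image for tensors, the staging buffer for staging values). A sketch of how an op's encode path might use this; the function is illustrative and the descriptor-set/barrier setup is assumed to follow the existing node pattern:

```cpp
#include <executorch/backends/vulkan/runtime/graph/ComputeGraph.h>

namespace vkcompute {

// Illustrative only: bind one output and one input for a dispatch via the
// new graph-level API. `idx` follows the shader layout's binding order.
void bind_io_example(
    ComputeGraph* graph,
    const ValueRef out,
    const ValueRef in,
    vkapi::PipelineBarrier& pipeline_barrier,
    vkapi::DescriptorSet& descriptor_set) {
  uint32_t idx = 0;
  // The graph dispatches on the value's type and, for tensors, on buffer
  // vs. texture storage, so the caller never touches vTensorPtr.
  graph->bind_value_to_descriptor_set(
      out,
      pipeline_barrier,
      vkapi::MemoryAccessType::WRITE,
      descriptor_set,
      idx++);
  graph->bind_value_to_descriptor_set(
      in,
      pipeline_barrier,
      vkapi::MemoryAccessType::READ,
      descriptor_set,
      idx++);
}

} // namespace vkcompute
```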
31 changes: 31 additions & 0 deletions backends/vulkan/runtime/graph/ComputeGraph.h
@@ -319,6 +319,10 @@ class ComputeGraph final {
     return values_.at(idx).toConstTensor().numel();
   }
 
+  inline size_t staging_buffer_numel_of(const ValueRef idx) const {
+    return values_.at(idx).toConstTensor().staging_buffer_numel();
+  }
+
   inline utils::StorageType storage_type_of(const ValueRef idx) const {
     return values_.at(idx).toConstTensor().storage_type();
   }
@@ -832,6 +836,20 @@ class ComputeGraph final {
    */
   utils::uvec3 create_local_wg_size(const ValueRef idx);
 
+  void bind_tensor_to_descriptor_set(
+      const ValueRef ref,
+      vkapi::PipelineBarrier& pipeline_barrier,
+      const vkapi::MemoryAccessFlags accessType,
+      vkapi::DescriptorSet& descriptor_set,
+      const uint32_t idx);
+
+  void bind_value_to_descriptor_set(
+      const ValueRef ref,
+      vkapi::PipelineBarrier& pipeline_barrier,
+      const vkapi::MemoryAccessFlags access_type,
+      vkapi::DescriptorSet& descriptor_set,
+      const uint32_t idx);
+
   //
   // Input/Output
   //
@@ -890,14 +908,27 @@ class ComputeGraph final {
 
   void execute();
 
+  //
+  // Tensor View
+  //
+
+  void virtual_clone(const ValueRef dst, const ValueRef src);
+
+  void virtual_transpose(
+      const ValueRef tensor,
+      const int64_t dim0,
+      const int64_t dim1);
 
   //
   // Dynamic Shape support
   //
 
   void resize_input(const int64_t idx, const std::vector<int64_t>& new_sizes);
 
+  void virtual_resize(
+      const ValueRef idx,
+      const std::vector<int64_t>& new_sizes);
 
   void propagate_resize();
 
   //
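With `virtual_clone`, `virtual_transpose`, and `virtual_resize` exposed on `ComputeGraph`, view-style ops can manipulate tensor metadata without ever calling `get_tensor()`. A sketch of a resize callback for a hypothetical transpose view op (the function name and op are illustrative, not part of this PR):

```cpp
#include <executorch/backends/vulkan/runtime/graph/ComputeGraph.h>

namespace vkcompute {

// Hypothetical resize callback: keep `out` as a transposed view of `in`
// using only graph-level metadata calls.
void resize_transposed_view(
    ComputeGraph* graph,
    const ValueRef out,
    const ValueRef in) {
  // Copy the input's view metadata into the output...
  graph->virtual_clone(out, in);
  // ...then swap the last two dims in place.
  const int64_t ndim = static_cast<int64_t>(graph->sizes_of(in).size());
  graph->virtual_transpose(out, ndim - 2, ndim - 1);
}

} // namespace vkcompute
```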
14 changes: 6 additions & 8 deletions backends/vulkan/runtime/graph/ops/BlitNode.cpp
@@ -26,11 +26,9 @@ BlitNode::BlitNode(
 }
 
 void BlitNode::encode(ComputeGraph* graph) {
-  auto src_tensor = graph->get_tensor(src_);
-  auto dst_tensor = graph->get_tensor(dst_);
   VK_CHECK_COND(
-      src_tensor->storage_type() != utils::kBuffer &&
-      dst_tensor->storage_type() != utils::kBuffer,
+      graph->storage_type_of(src_) != utils::kBuffer &&
+      graph->storage_type_of(dst_) != utils::kBuffer,
       "BlitNode: Only texture backed tensors are supported.");
 
   api::Context* const context = graph->context();
@@ -41,18 +39,18 @@ void BlitNode::encode(ComputeGraph* graph) {
   // Hack to get timing data for non shader op
   std::string kernel_name("Blit_");
   kernel_name.reserve(32);
-  kernel_name += vkapi::to_string(src_tensor->dtype());
+  kernel_name += vkapi::to_string(graph->dtype_of(src_));
   kernel_name += "_to_";
-  kernel_name += vkapi::to_string(dst_tensor->dtype());
+  kernel_name += vkapi::to_string(graph->dtype_of(dst_));
 
   context->report_shader_dispatch_start(
       kernel_name, utils::uvec3(), utils::WorkgroupSize(), node_id_);
 
   context->register_blit(
       pipeline_barrier,
-      src_tensor->image(
+      graph->get_tensor(src_)->image(
           pipeline_barrier, vkapi::PipelineStage::TRANSFER, vkapi::kRead),
-      dst_tensor->image(
+      graph->get_tensor(dst_)->image(
           pipeline_barrier, vkapi::PipelineStage::TRANSFER, vkapi::kWrite));
 
   context->report_shader_dispatch_end();
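Note the split this change settles on: pure metadata queries (`storage_type_of`, `dtype_of`) go through the graph, while `image()` access still goes through `get_tensor()`, since acquiring the image participates in barrier recording. For the naming half, a sketch under the same assumptions as the diff (`make_blit_name` is illustrative):

```cpp
#include <executorch/backends/vulkan/runtime/graph/ComputeGraph.h>

#include <string>

namespace vkcompute {

// Illustrative helper mirroring the kernel-name construction above; assumes
// the vkapi::to_string overload for dtypes used in this diff.
std::string make_blit_name(
    ComputeGraph* graph,
    const ValueRef src,
    const ValueRef dst) {
  std::string name("Blit_");
  name.reserve(32); // avoid reallocations for typical dtype names
  name += vkapi::to_string(graph->dtype_of(src));
  name += "_to_";
  name += vkapi::to_string(graph->dtype_of(dst));
  return name; // e.g. "Blit_float_to_half"
}

} // namespace vkcompute
```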
26 changes: 12 additions & 14 deletions backends/vulkan/runtime/graph/ops/PrepackNode.cpp
@@ -18,9 +18,8 @@ namespace vkcompute {
 
 vkapi::ShaderInfo get_noop_shader(ComputeGraph& graph, const ValueRef packed) {
   std::string noop_shader_name("no_op");
-  vTensorPtr t_packed = graph.get_tensor(packed);
-  add_dtype_suffix(noop_shader_name, *t_packed);
-  add_storage_type_suffix(noop_shader_name, *t_packed);
+  add_dtype_suffix(noop_shader_name, graph.dtype_of(packed));
+  add_storage_type_suffix(noop_shader_name, graph.storage_type_of(packed));
   return VK_KERNEL_FROM_STR(noop_shader_name);
 }
 
@@ -48,13 +47,13 @@ PrepackNode::PrepackNode(
 }
 
 api::StagingBuffer PrepackNode::create_staging_buffer(ComputeGraph* graph) {
-  vTensorPtr packed = graph->get_tensor(packed_);
-
-  // If no TensorRef is provided, create a staging buffer of zeros according to
-  // the vkapi::vTensor metadata.
+  // If no TensorRef is provided, create a staging buffer of zeros based on the
+  // Tensor metadata.
   if (graph->val_is_none(tref_)) {
-    size_t numel = utils::multiply_integers(packed->sizes());
-    api::StagingBuffer staging(graph->context(), packed->dtype(), numel);
+    const std::vector<int64_t> packed_sizes = graph->sizes_of(packed_);
+    size_t numel = utils::multiply_integers(packed_sizes);
+    api::StagingBuffer staging(
+        graph->context(), graph->dtype_of(packed_), numel);
     staging.set_staging_zeros();
     return staging;
   }
@@ -80,7 +79,6 @@ void PrepackNode::encode(ComputeGraph* graph) {
 
   context->check_device_capabilities(shader_);
 
-  vTensorPtr packed = graph->get_tensor(packed_);
   api::StagingBuffer staging = create_staging_buffer(graph);
 
   std::unique_lock<std::mutex> cmd_lock = context->dispatch_lock();
@@ -101,8 +99,8 @@ void PrepackNode::encode(ComputeGraph* graph) {
       shader_, local_workgroup_size_, spec_vars_, push_constants_offset);
 
   uint32_t idx = 0;
-  bind_tensor_to_descriptor_set(
-      *packed,
+  graph->bind_tensor_to_descriptor_set(
+      packed_,
       pipeline_barrier,
       vkapi::MemoryAccessType::WRITE,
       descriptor_set,
@@ -128,8 +126,8 @@ void PrepackNode::encode(ComputeGraph* graph) {
   vkapi::DescriptorSet descriptor_set = context->get_descriptor_set(
       noop_shader_, utils::WorkgroupSize(1, 1, 1));
 
-  bind_tensor_to_descriptor_set(
-      *packed,
+  graph->bind_tensor_to_descriptor_set(
+      packed_,
       pipeline_barrier,
       vkapi::MemoryAccessType::READ,
       descriptor_set,
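The staging path shows the same migration: sizes and dtype come from the graph, so `create_staging_buffer` no longer needs a `vTensorPtr` at all. A condensed sketch of the zero-fill branch, assuming the `StagingBuffer` API exactly as used in the diff (the free-function form is illustrative):

```cpp
#include <executorch/backends/vulkan/runtime/graph/ComputeGraph.h>

namespace vkcompute {

// Illustrative only: allocate a zero-filled staging buffer sized purely from
// graph metadata, as the updated create_staging_buffer does.
api::StagingBuffer make_zero_staging(
    ComputeGraph* graph,
    const ValueRef packed) {
  const std::vector<int64_t> packed_sizes = graph->sizes_of(packed);
  const size_t numel = utils::multiply_integers(packed_sizes);
  api::StagingBuffer staging(
      graph->context(), graph->dtype_of(packed), numel);
  staging.set_staging_zeros();
  return staging; // returned by value, as in the diff
}

} // namespace vkcompute
```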
22 changes: 10 additions & 12 deletions backends/vulkan/runtime/graph/ops/impl/Arange.cpp
@@ -20,22 +20,22 @@ void resize_arange_node(
     ComputeGraph* graph,
     const std::vector<ArgGroup>& args,
     const std::vector<ValueRef>& extra_args) {
-  vTensorPtr out = graph->get_tensor(args[0].refs[0]);
+  const ValueRef out = args.at(0).refs.at(0);
 
   int start_val = 0;
   int step_val = 1;
-  if (!graph->val_is_none(extra_args[0])) {
-    start_val = graph->extract_scalar<int64_t>(extra_args[0]);
+  if (!graph->val_is_none(extra_args.at(0))) {
+    start_val = graph->extract_scalar<int64_t>(extra_args.at(0));
   }
-  int end_val = graph->extract_scalar<int64_t>(extra_args[1]);
-  if (!graph->val_is_none(extra_args[2])) {
-    step_val = graph->extract_scalar<int64_t>(extra_args[2]);
+  const int end_val = graph->extract_scalar<int64_t>(extra_args.at(1));
+  if (!graph->val_is_none(extra_args.at(2))) {
+    step_val = graph->extract_scalar<int64_t>(extra_args.at(2));
   }
 
-  std::vector<int64_t> out_sizes = {
+  const std::vector<int64_t> out_sizes = {
       utils::div_up(end_val - start_val, step_val)};
 
-  out->virtual_resize(out_sizes);
+  graph->virtual_resize(out, out_sizes);
 }
 
 void check_arange_input(
@@ -82,11 +82,9 @@ void add_arange_node(
     }
   }
 
-  vTensorPtr t_out = graph.get_tensor(out);
-
   std::string kernel_name("arange");
   kernel_name.reserve(kShaderNameReserve);
-  add_dtype_suffix(kernel_name, *t_out);
+  add_dtype_suffix(kernel_name, graph.dtype_of(out));
 
   graph.execute_nodes().emplace_back(new DispatchNode(
       graph,
@@ -96,7 +94,7 @@ void add_arange_node(
       // Inputs and Outputs
       {{out, vkapi::kWrite}},
       // Shader params buffers
-      {t_out->sizes_ubo(),
+      {graph.sizes_ubo(out),
        graph.create_params_buffer(start_val),
        graph.create_params_buffer(step_val)},
       // Push Constants
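The output size computed above is `div_up(end_val - start_val, step_val)`, i.e. the ceiling of `(end - start) / step`, which is exactly the number of elements arange produces over `[start, end)` with stride `step`. A standalone check of that arithmetic, with `div_up` assumed to be plain ceiling division for a positive step:

```cpp
#include <cassert>
#include <cstdint>

// Assumed to match utils::div_up(end - start, step) for positive step.
int64_t arange_numel(int64_t start, int64_t end, int64_t step) {
  return (end - start + step - 1) / step;
}

int main() {
  assert(arange_numel(0, 10, 3) == 4); // {0, 3, 6, 9}
  assert(arange_numel(2, 8, 2) == 3);  // {2, 4, 6}
  return 0;
}
```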
34 changes: 16 additions & 18 deletions backends/vulkan/runtime/graph/ops/impl/BatchNorm.cpp
@@ -46,44 +46,42 @@ void add_native_batch_norm_node(
     ValueRef var_ref,
     ValueRef eps_ref,
     ValueRef out_tuple_ref) {
-  std::vector<int64_t> in_sizes = graph.get_tensor(in_ref)->sizes();
-  std::vector<int64_t> out_sizes = graph.get_tensor(in_ref)->sizes();
+  const std::vector<int64_t> in_sizes = graph.sizes_of(in_ref);
+  const std::vector<int64_t> out_sizes = graph.sizes_of(in_ref);
 
   VK_CHECK_COND(in_sizes.size() == 4, "BatchNorm only support 4d tensor");
   VK_CHECK_COND(out_sizes.size() == 4, "BatchNorm only support 4d tensor");
 
   // Only the first element of the return value is propagated. The remaining 2
   // elements are zero-size dummy tensor.
-  ValueRef out_ref = graph.get_value_list(out_tuple_ref)->at(0);
+  const ValueRef out_ref = graph.get_value_list(out_tuple_ref)->at(0);
 
-  utils::StorageType stype = graph.storage_type_of(out_ref);
+  const utils::StorageType stype = graph.storage_type_of(out_ref);
 
-  int64_t num_channels = dim_at<kChannel4D>(in_sizes);
+  const int64_t num_channels = dim_at<kChannel4D>(in_sizes);
 
-  ValueRef arg_weight =
+  const ValueRef arg_weight =
       check_and_prepack_arg(graph, weight_ref, stype, num_channels, "weight");
-  ValueRef arg_bias =
+  const ValueRef arg_bias =
       check_and_prepack_arg(graph, bias_ref, stype, num_channels, "bias");
-  ValueRef arg_mean =
+  const ValueRef arg_mean =
      check_and_prepack_arg(graph, mean_ref, stype, num_channels, "mean");
-  ValueRef arg_var =
+  const ValueRef arg_var =
      check_and_prepack_arg(graph, var_ref, stype, num_channels, "var");
-  float epsilon = graph.extract_scalar<float>(eps_ref);
-
-  vTensorPtr t_in = graph.get_tensor(in_ref);
+  const float epsilon = graph.extract_scalar<float>(eps_ref);
 
   VK_CHECK_COND(!graph.val_is_tref(out_ref), "Output should not be tref");
-  vTensorPtr t_out = graph.get_tensor(out_ref);
 
+  const std::vector<int64_t> out_tensor_sizes = graph.sizes_of(out_ref);
   VK_CHECK_COND(
-      dim_at<kChannel4D>(t_out->sizes()) == num_channels,
+      dim_at<kChannel4D>(out_tensor_sizes) == num_channels,
       "out channel must match in channel");
 
   std::string kernel_name = "batchnorm";
-  add_dtype_suffix(kernel_name, *t_out);
+  add_dtype_suffix(kernel_name, graph.dtype_of(out_ref));
 
-  int32_t num_texel_per_batch =
-      utils::div_up_4((dim_at<kChannel4D>(t_in->sizes())));
+  const int32_t num_texel_per_batch =
+      utils::div_up_4((dim_at<kChannel4D>(in_sizes)));
 
   graph.execute_nodes().emplace_back(new DispatchNode(
       graph,
@@ -92,7 +90,7 @@ void add_native_batch_norm_node(
       graph.create_local_wg_size(out_ref),
       {{out_ref, vkapi::kWrite},
        {{in_ref, arg_weight, arg_bias, arg_mean, arg_var}, vkapi::kRead}},
-      {t_out->logical_limits_ubo(),
+      {graph.logical_limits_ubo(out_ref),
        graph.create_params_buffer(epsilon),
        graph.create_params_buffer(num_texel_per_batch)},
       // Push Constants
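The `num_texel_per_batch` value passed to the shader is `div_up_4` of the channel count, which reflects a channels-packed texture layout: four channel values per texel, with the last texel padded when the channel count is not a multiple of 4. (That layout interpretation is inferred from the shader parameters, not stated in the diff.) A standalone check of the arithmetic:

```cpp
#include <cassert>
#include <cstdint>

// Assumed to match utils::div_up_4: how many 4-wide texels cover n channels.
int32_t div_up_4(int64_t n) {
  return static_cast<int32_t>((n + 3) / 4);
}

int main() {
  assert(div_up_4(8) == 2);  // 8 channels fill exactly 2 texels
  assert(div_up_4(10) == 3); // 10 channels need 3 texels; the last is padded
  return 0;
}
```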