Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 28 additions & 13 deletions backends/vulkan/runtime/graph/ops/glsl/matmul_naive_buffer.glsl
Original file line number Diff line number Diff line change
Expand Up @@ -16,21 +16,23 @@ ${define_required_extensions(DTYPE)}

layout(std430) buffer;

${layout_declare_tensor(0, "w", "t_out", DTYPE, "buffer")}
${layout_declare_tensor(1, "r", "t_mat1", DTYPE, "buffer")}
${layout_declare_tensor(2, "r", "t_mat2", DTYPE, "buffer")}
${layout_declare_ubo(3, "ivec4", "out_sizes")}
${layout_declare_ubo(4, "ivec4", "out_strides")}
${layout_declare_ubo(5, "ivec4", "mat1_sizes")}
${layout_declare_ubo(6, "ivec4", "mat1_strides")}
${layout_declare_ubo(7, "ivec4", "mat2_sizes")}
${layout_declare_ubo(8, "ivec4", "mat2_strides")}
${layout_declare_ubo(9, "int", "out_numel")}
${layout_declare_tensor(B, "w", "t_out", DTYPE, "buffer")}
${layout_declare_tensor(B, "r", "t_mat1", DTYPE, "buffer")}
${layout_declare_tensor(B, "r", "t_mat2", DTYPE, "buffer")}
${layout_declare_ubo(B, "ivec4", "out_sizes")}
${layout_declare_ubo(B, "ivec4", "out_strides")}
${layout_declare_ubo(B, "ivec4", "mat1_sizes")}
${layout_declare_ubo(B, "ivec4", "mat1_strides")}
${layout_declare_ubo(B, "ivec4", "mat2_sizes")}
${layout_declare_ubo(B, "ivec4", "mat2_strides")}
${layout_declare_ubo(B, "int", "out_numel")}

#include "indexing_utils.h"

layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

${layout_declare_spec_const(C, "int", "mat2_is_transposed", "0")}

void main() {
const ivec4 out_bufix = ivec4(
gl_GlobalInvocationID.x,
Expand All @@ -44,15 +46,28 @@ void main() {

int mat1_bufi = tidx_to_bufi(
ivec4(0, out_bufix.y, out_bufix.z, out_bufix.w), mat1_strides);
int mat2_bufi = tidx_to_bufi(
ivec4(out_bufix.x, 0, out_bufix.z, out_bufix.w), mat2_strides);
int mat2_bufi;
if (mat2_is_transposed > 0) {
mat2_bufi = tidx_to_bufi(
ivec4(0, out_bufix.x, 0, 0), mat2_strides);
} else {
mat2_bufi = tidx_to_bufi(
ivec4(out_bufix.x, 0, out_bufix.z, out_bufix.w), mat2_strides);
}

int mat2_stride;
if (mat2_is_transposed > 0) {
mat2_stride = mat2_strides.x;
} else {
mat2_stride = mat2_strides.y;
}

T sum = T(0.0);
for (int i = 0; i < mat1_sizes.x; ++i) {
sum += t_mat1[mat1_bufi] * t_mat2[mat2_bufi];

mat1_bufi += mat1_strides.x;
mat2_bufi += mat2_strides.y;
mat2_bufi += mat2_stride;
}

const int out_bufi = tidx_to_bufi(out_bufix, out_strides);
Expand Down
3 changes: 3 additions & 0 deletions backends/vulkan/runtime/graph/ops/impl/Linear.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -279,9 +279,12 @@ void linear(ComputeGraph& graph, const std::vector<ValueRef>& args) {
ValueRef weight = prepack_standard(
graph, weight_data, graph.storage_type_of(out), utils::kWidthPacked);
ValueRef mat2_is_transposed = graph.add_scalar(true);

if (graph.val_is_none(bias)) {
return add_matmul_node(graph, input, weight, out, mat2_is_transposed);
} else {
// Buffer implementation does not yet support biases
VK_CHECK_COND(!graph.is_buffer_storage(out));
return add_addmm_node(
graph,
bias,
Expand Down
7 changes: 6 additions & 1 deletion backends/vulkan/runtime/graph/ops/impl/MatMul.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,11 @@ void add_matmul_naive_buffer_node(
graph.size_at<uint32_t>(-2, out),
graph.size_at<uint32_t>(-3, out) * graph.size_at<uint32_t>(-4, out)};

int mat2_is_transposed_val = (mat2_is_transposed != kDummyValueRef &&
graph.get_bool(mat2_is_transposed))
? 1
: 0;

graph.execute_nodes().emplace_back(new DispatchNode(
graph,
VK_KERNEL_FROM_STR(kernel_name),
Expand All @@ -96,7 +101,7 @@ void add_matmul_naive_buffer_node(
graph.numel_ubo(out),
},
// Specialization Constants
{},
{mat2_is_transposed_val},
// Resizing Logic
resize_matmul_node,
{mat2_is_transposed}));
Expand Down
26 changes: 24 additions & 2 deletions backends/vulkan/test/op_tests/cases.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,8 +126,7 @@ def get_addmm_inputs():
]


@register_test_suite("aten.linear.default")
def get_linear_inputs():
def get_linear_texture_inputs():
MKN_list = common_MKN_list

inputs_list = [((M, K), (N, K), None) for M, K, N in MKN_list]
Expand All @@ -142,9 +141,32 @@ def get_linear_inputs():
"utils::kWidthPacked",
"utils::kChannelsPacked",
]
test_suite.test_name_suffix = "texture"
return test_suite


def get_linear_buffer_inputs():
MKN_list = common_MKN_list

inputs_list = [((M, K), (N, K), None) for M, K, N in MKN_list]
inputs_list += [((3, M, K), (N, K), None) for M, K, N in MKN_list]

test_suite = VkTestSuite(inputs_list)
test_suite.dtypes = ["at::kFloat"]
test_suite.layouts = [
"utils::kWidthPacked",
"utils::kChannelsPacked",
]
test_suite.storage_types = ["utils::kBuffer"]
test_suite.test_name_suffix = "buffer"
return test_suite


@register_test_suite("aten.linear.default")
def get_linear_test_suites():
return [get_linear_texture_inputs(), get_linear_buffer_inputs()]


@register_test_suite("aten._weight_int8pack_mm.default")
def get_weight_int8pack_mm_inputs():
MKN_list = common_MKN_list
Expand Down
Loading