From 720b483e0dc105b07780b1c78d73fc77e53b47c6 Mon Sep 17 00:00:00 2001 From: Ervin Tasnadi Date: Sat, 21 Jun 2025 11:43:47 +0200 Subject: [PATCH 1/9] ggml/ggml-vulkan/test-backend-ops: adds CONV_2D for Vulkan * ggml-vulkan: adds f32 scalar shader to compute 2D convolution directly with gemm (no need for im2col), * test-backend-ops: adds test_case_ref to check the validity/performance of ops against reference implementations having different graphs, adds tests --- ggml/include/ggml-backend.h | 2 + ggml/src/ggml-backend.cpp | 49 ++ ggml/src/ggml-vulkan/ggml-vulkan.cpp | 181 +++++- .../ggml-vulkan/vulkan-shaders/conv2d_mm.comp | 244 ++++++++ .../vulkan-shaders/vulkan-shaders-gen.cpp | 2 + tests/test-backend-ops.cpp | 530 +++++++++++++++++- 6 files changed, 1004 insertions(+), 4 deletions(-) create mode 100644 ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp diff --git a/ggml/include/ggml-backend.h b/ggml/include/ggml-backend.h index a2977ea2e56d9..1bd7520178bce 100644 --- a/ggml/include/ggml-backend.h +++ b/ggml/include/ggml-backend.h @@ -340,6 +340,8 @@ extern "C" { // Compare the output of two backends GGML_API bool ggml_backend_compare_graph_backend(ggml_backend_t backend1, ggml_backend_t backend2, struct ggml_cgraph * graph, ggml_backend_eval_callback callback, void * user_data, struct ggml_tensor * test_node); + // Compare the output of two backends, graphs can be different and only the selected nodes will be compared + GGML_API bool ggml_backend_compare_graph_backend_node(ggml_backend_t backend1, ggml_backend_t backend2, struct ggml_cgraph * graph1, struct ggml_cgraph * graph2, ggml_backend_eval_callback callback, void * user_data, char* op_name_out_1, char* op_name_out_2); // Tensor initialization GGML_API enum ggml_status ggml_backend_tensor_alloc(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, void * addr); diff --git a/ggml/src/ggml-backend.cpp b/ggml/src/ggml-backend.cpp index 788861a365fab..b3d42975b6fc9 100644 --- a/ggml/src/ggml-backend.cpp +++ b/ggml/src/ggml-backend.cpp @@ -1882,6 +1882,55 @@ bool ggml_backend_compare_graph_backend(ggml_backend_t backend1, ggml_backend_t return true; } +bool ggml_backend_compare_graph_backend_node( + ggml_backend_t backend1, + ggml_backend_t backend2, + struct ggml_cgraph * graph1, + struct ggml_cgraph * graph2, + ggml_backend_eval_callback callback, void * user_data, char* op_name_out_1, char* op_name_out_2) { + + ggml_tensor * out1 = NULL; + ggml_tensor * out2 = NULL; + + struct ggml_cgraph * g1 = graph1; + struct ggml_cgraph * g2 = graph2; + + for (int i = 0; i < g1->n_nodes; i++) { + struct ggml_tensor * t1 = g1->nodes[i]; + struct ggml_cgraph g1v = ggml_graph_view(g1, i, i + 1); + ggml_backend_graph_compute(backend1, &g1v); + if (ggml_is_view_op(t1->op)) { + continue; + } + if(strcmp(t1 -> name, op_name_out_1) == 0){ + out1 = t1; + } + } + + for (int i = 0; i < g2->n_nodes; i++) { + struct ggml_tensor * t2 = g2->nodes[i]; + struct ggml_cgraph g2v = ggml_graph_view(g2, i, i + 1); + ggml_backend_graph_compute(backend2, &g2v); + if (ggml_is_view_op(t2->op)) { + continue; + } + if(strcmp(t2 -> name, op_name_out_2) == 0){ + out2 = t2; + } + } + + assert(out1 != NULL); + assert(out2 != NULL); + assert(ggml_are_same_layout(out1, out2)); + + // compare results, calculate rms etc + if (!callback(0, out1, out2, user_data)) { + return false; + } + + return true; +} + // CPU backend - buffer static void * ggml_backend_cpu_buffer_get_base(ggml_backend_buffer_t buffer) { diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 416ee3bd3f70a..7cc5217e5a253 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -482,6 +482,7 @@ struct vk_device_struct { vk_pipeline pipeline_rwkv_wkv6_f32; vk_pipeline pipeline_rwkv_wkv7_f32; vk_pipeline pipeline_opt_step_adamw_f32; + vk_pipeline pipeline_conv2d_f32; vk_pipeline pipeline_conv2d_dw_whcn_f32; vk_pipeline pipeline_conv2d_dw_cwhn_f32; @@ -875,6 +876,38 @@ struct vk_op_rwkv_wkv7_push_constants { uint32_t H; }; +struct vk_op_conv2d_push_constants { + uint32_t Cout; + uint32_t Cin; + uint32_t N; + + uint32_t KW; + uint32_t KH; + uint32_t W; + uint32_t H; + uint32_t OW; + uint32_t OH; + + uint32_t s0; + uint32_t s1; + uint32_t p0; + uint32_t p1; + uint32_t d0; + uint32_t d1; + + uint32_t nb01; + uint32_t nb02; + uint32_t nb03; + + uint32_t nb11; + uint32_t nb12; + uint32_t nb13; + + uint32_t nb1; + uint32_t nb2; + uint32_t nb3; +}; + struct vk_op_conv2d_dw_push_constants { uint32_t ne; uint32_t batches; @@ -976,16 +1009,33 @@ class vk_memory_logger { class vk_perf_logger { public: void print_timings() { + if(timings.empty()){ + return; + } std::cerr << "----------------\nVulkan Timings:" << std::endl; for (const auto& t : timings) { uint64_t total = 0; for (const auto& time : t.second) { total += time; } - std::cerr << t.first << ": " << t.second.size() << " x " << (total / t.second.size() / 1000.0) << " us" << std::endl; + std::cerr << t.first << ": " << t.second.size() << " x " << (total / t.second.size() / 1000.0) << " us"; + + // If we have as many flops entries as timing entries for the op, then compute and log the flops/S. + auto it = flops.find(t.first); + if(it != flops.end() && (it->second).size() == t.second.size()){ + uint64_t total_nflops = 0; + for(const auto& elem : it->second){ + total_nflops += elem; + } + std::cout << " (" << (double(total_nflops)/(1000.0*1000.0*1000.0)) / (double(total)/(1000.0*1000.0*1000.0)) << " GFLOPS/s)"; + } + + + std::cerr << std::endl; } timings.clear(); + flops.clear(); } void log_timing(const ggml_tensor * node, uint64_t time) { @@ -1004,12 +1054,33 @@ class vk_perf_logger { name += " m=" + std::to_string(m) + " n=" + std::to_string(n) + " k=" + std::to_string(k); } timings[name].push_back(time); + flops[name].push_back( m*n*(k+(k-1)) ); return; } + if(node->op == GGML_OP_CONV_2D){ + std::string name = ggml_op_name(node->op); + ggml_tensor * knl = node->src[0]; + uint64_t OW = node->ne[0]; + uint64_t OH = node->ne[1]; + uint64_t N = node->ne[3]; + uint64_t Cout = node->ne[2]; + uint64_t KW = knl->ne[0]; + uint64_t KH = knl->ne[1]; + uint64_t Cin = knl->ne[2]; + // KxCRS @ CRSxNPQ = KxNPQ -> M=K, K=CRS, N=NPQ + uint64_t size_M = Cout; + uint64_t size_K = Cin*KW*KH; + uint64_t size_N = N*OW*OH; + uint64_t n_flops = size_M*size_N*(size_K+(size_K-1)); + flops[name].push_back(n_flops); + timings[name].push_back(time); + return; + } timings[ggml_op_name(node->op)].push_back(time); } private: std::map> timings; + std::map> flops; }; struct ggml_backend_vk_context { @@ -2955,6 +3026,8 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_opt_step_adamw_f32, "opt_step_adamw_f32", opt_step_adamw_f32_len, opt_step_adamw_f32_data, "main", 5, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_conv2d_f32, "conv2d_f32", conv2d_f32_len, conv2d_f32_data, "main", 3, sizeof(vk_op_conv2d_push_constants), {128 /* equal to BS_K in the shader */, 128 /* equal to BS_NPQ in the shader */, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_whcn_f32, "conv2d_dw_whcn_f32", conv2d_dw_whcn_f32_len, conv2d_dw_whcn_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_cwhn_f32, "conv2d_dw_cwhn_f32", conv2d_dw_cwhn_f32_len, conv2d_dw_cwhn_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1); @@ -6803,6 +6876,16 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const return ctx->device->pipeline_leaky_relu_f32; } return nullptr; + case GGML_OP_CONV_2D: + if (src0->type == GGML_TYPE_F32 && + src1->type == GGML_TYPE_F32 && + dst->type == GGML_TYPE_F32 && + ggml_is_contiguous(src0) && + ggml_is_contiguous(src1) && + ggml_is_contiguous(dst)) { + return ctx->device->pipeline_conv2d_f32; + } + return nullptr; case GGML_OP_CONV_2D_DW: if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { if (ggml_is_contiguous(src1)) { @@ -7125,6 +7208,30 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co const uint32_t OW = dst->ne[0]; elements = { N * OC * OH * OW, 1, 1}; } break; + case GGML_OP_CONV_2D: + { + // src0 - kernel: [KW, KH, Cin, Cout] + // src1 - input: [W, H, Cin, N] + // dst - result: [OW, OH, Cout, N] + + // Copied from ggml.c: int64_t ggml_calc_conv_output_size(int64_t ins, int64_t ks, int s, int p, int d) + auto calc_conv_output_size = [](int64_t ins, int64_t ks, int s, int p, int d) -> int64_t { + return (ins + 2 * p - d * (ks - 1) - 1) / s + 1; + }; + // parallelize in {OW/BS_K, OH/BS_NPQ, 1} + int64_t W = src1->ne[0]; + int64_t H = src1->ne[1]; + int64_t KW = src0->ne[0]; + int64_t KH = src0->ne[1]; + int64_t Cout = src0->ne[3]; + int64_t N = src1->ne[3]; + int64_t OH = calc_conv_output_size(H, KH, dst->op_params[1], dst->op_params[3], dst->op_params[5]); + int64_t OW = calc_conv_output_size(W, KW, dst->op_params[0], dst->op_params[2], dst->op_params[4]); + int64_t NPQ = N*OW*OH; + + // Tile output matrix to (K/NB_K, NPQ/NB_NPQ, 1) workgroups + elements = {static_cast(Cout), static_cast(NPQ), 1}; + } break; case GGML_OP_ADD: case GGML_OP_SUB: case GGML_OP_DIV: @@ -7991,6 +8098,55 @@ static void ggml_vk_pool_2d(ggml_backend_vk_context * ctx, vk_context& subctx, c }, dryrun); } +static void ggml_vk_conv_2d(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + GGML_TENSOR_BINARY_OP_LOCALS + + GGML_ASSERT(nb00 == sizeof(float)); + GGML_ASSERT(nb10 == sizeof(float)); + GGML_ASSERT(nb0 == sizeof(float)); + + vk_op_conv2d_push_constants p{}; + p.Cout = static_cast(ne03); + p.Cin = static_cast(ne02); + p.N = static_cast(ne13); + + p.KW = static_cast(ne00); + p.KH = static_cast(ne01); + p.W = static_cast(ne10); + p.H = static_cast(ne11); + p.OW = static_cast(ne0); + p.OH = static_cast(ne1); + + p.s0 = static_cast(dst->op_params[0]); + p.s1 = static_cast(dst->op_params[1]); + p.p0 = static_cast(dst->op_params[2]); + p.p1 = static_cast(dst->op_params[3]); + p.d0 = static_cast(dst->op_params[4]); + p.d1 = static_cast(dst->op_params[5]); + + p.nb01 = static_cast(nb01/nb00); + p.nb02 = static_cast(nb02/nb00); + p.nb03 = static_cast(nb03/nb00); + + p.nb11 = static_cast(nb11/nb10); + p.nb12 = static_cast(nb12/nb10); + p.nb13 = static_cast(nb13/nb10); + + p.nb1 = static_cast(nb1 / nb0); + p.nb2 = static_cast(nb2 / nb0); + p.nb3 = static_cast(nb3 / nb0); + + GGML_ASSERT(ne03 == ne2); + GGML_ASSERT(ne02 == ne12); + + ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_CONV_2D, std::move(p), dryrun); + +} + static void ggml_vk_conv_2d_dw(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { vk_op_conv2d_dw_push_constants p{}; p.ne = ggml_nelements(dst); @@ -9053,6 +9209,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr case GGML_OP_TIMESTEP_EMBEDDING: case GGML_OP_CONV_TRANSPOSE_1D: case GGML_OP_POOL_2D: + case GGML_OP_CONV_2D: case GGML_OP_CONV_2D_DW: case GGML_OP_RWKV_WKV6: case GGML_OP_RWKV_WKV7: @@ -9120,6 +9277,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr case GGML_OP_TIMESTEP_EMBEDDING: case GGML_OP_CONV_TRANSPOSE_1D: case GGML_OP_POOL_2D: + case GGML_OP_CONV_2D: case GGML_OP_CONV_2D_DW: case GGML_OP_LEAKY_RELU: { @@ -9326,6 +9484,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr case GGML_OP_POOL_2D: ggml_vk_pool_2d(ctx, compute_ctx, src0, node, dryrun); + break; + case GGML_OP_CONV_2D: + ggml_vk_conv_2d(ctx, compute_ctx, src0, src1, node, dryrun); + break; case GGML_OP_CONV_2D_DW: ggml_vk_conv_2d_dw(ctx, compute_ctx, src0, src1, node, dryrun); @@ -9456,6 +9618,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph * case GGML_OP_TIMESTEP_EMBEDDING: case GGML_OP_CONV_TRANSPOSE_1D: case GGML_OP_POOL_2D: + case GGML_OP_CONV_2D: case GGML_OP_CONV_2D_DW: case GGML_OP_RWKV_WKV6: case GGML_OP_RWKV_WKV7: @@ -10617,6 +10780,14 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm return true; case GGML_OP_CONV_TRANSPOSE_1D: return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32; + case GGML_OP_CONV_2D: + // Channel-contiguous format is not supported yet. + return (op->src[0]->type == GGML_TYPE_F32 && + op->src[1]->type == GGML_TYPE_F32 && + op->type == GGML_TYPE_F32 && + ggml_is_contiguous(op->src[0]) && + ggml_is_contiguous(op->src[1]) && + ggml_is_contiguous(op)); default: return false; } @@ -11175,6 +11346,14 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph * const int32_t p1 = tensor->op_params[6]; tensor_clone = ggml_pool_2d(ggml_ctx, src_clone[0], op, k0, k1, s0, s1, p0, p1); + } else if (tensor->op == GGML_OP_CONV_2D) { + const int32_t s0 = tensor->op_params[0]; + const int32_t s1 = tensor->op_params[1]; + const int32_t p0 = tensor->op_params[2]; + const int32_t p1 = tensor->op_params[3]; + const int32_t d0 = tensor->op_params[4]; + const int32_t d1 = tensor->op_params[5]; + tensor_clone = ggml_conv_2d(ggml_ctx, src_clone[0], src_clone[1], s0, s1, p0, p1, d0, d1); } else if (tensor->op == GGML_OP_LEAKY_RELU) { const float * op_params = (const float *)tensor->op_params; tensor_clone = ggml_leaky_relu(ggml_ctx, src_clone[0], op_params[0], false); diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp b/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp new file mode 100644 index 0000000000000..0ff942d7e2993 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp @@ -0,0 +1,244 @@ +#version 450 + +#extension GL_EXT_control_flow_attributes : enable + +#include "types.comp" + +// shape notation: [dim(N), ..., dim(0)] -- stride(dim(j)) >= stride(dim(i)) if i > j +layout (binding = 0) readonly buffer A {A_TYPE knl_data[];}; // src0 - kernel: [KW, KH, Cin, Cout] +layout (binding = 1) readonly buffer B {B_TYPE src_data[];}; // src1 - input: [W, H, Cin, N] -- channel_first format +layout (binding = 2) writeonly buffer D {D_TYPE dst_data[];}; // dst - result: [OW, OH, Cout, N] + +layout (push_constant) uniform parameter { + // I/O channels, batch size + uint32_t Cout; + uint32_t Cin; + uint32_t N; + + // Tensor spatial sizes: kernel, input, output + uint32_t KW; + uint32_t KH; + uint32_t W; + uint32_t H; + uint32_t OW; + uint32_t OH; + + // Parameters: stride, padding, dilation - 0=y, 1=x + uint32_t s0; + uint32_t s1; + uint32_t p0; + uint32_t p1; + uint32_t d0; + uint32_t d1; + + // Strides in elements + uint32_t nb01; + uint32_t nb02; + uint32_t nb03; + + uint32_t nb11; + uint32_t nb12; + uint32_t nb13; + + uint32_t nb1; + uint32_t nb2; + uint32_t nb3; +} p; + +#define WG_SIZE 256 + +layout(local_size_x = WG_SIZE, local_size_y = 1, local_size_z = 1) in; + +uint32_t tid = gl_LocalInvocationID.x; +const uint32_t bs = gl_WorkGroupSize.x; + +uint splitWork(uint work_size, uint block_size){ + return (block_size + work_size -1) / block_size; +} + +uint32_t K = p.Cout; +uint32_t CRS = p.Cin*p.KH*p.KW; +uint32_t NPQ = p.N*p.OH*p.OW; + +uint32_t n_elems_out = K*NPQ; + +// Blocktile sizes +const uint32_t BS_K = 128; +const uint32_t BS_CRS = 16; +const uint32_t BS_NPQ = 128; + +// Number of blocktiles per input +uint32_t NB_CRS = splitWork(CRS, BS_CRS); + +const uint32_t Ash_stride = BS_CRS+1; +const uint32_t Bsh_stride = BS_NPQ+1; + +const uint32_t Ash_numel = BS_K*BS_CRS; +const uint32_t Bsh_numel = BS_CRS*BS_NPQ; + +const uint32_t Ash_len = BS_K*Ash_stride; +const uint32_t Bsh_len = BS_CRS*Bsh_stride; + +shared float Ash[Ash_len]; // K x CRS +shared float Bsh[Bsh_len]; // CRS x NPQ + +// Threadtile sizes +const uint32_t TS_K = 16; +const uint32_t TS_NPQ = BS_K*BS_NPQ / WG_SIZE / TS_K; + +// Number of threadtiles per blocktile +const uint32_t NT_K = BS_K / TS_K; +const uint32_t NT_NPQ = BS_NPQ / TS_NPQ; + +float regA[TS_K]; +float regB[TS_NPQ]; +float regC[TS_K][TS_NPQ]; + +/* +Compute +KxCRS @ CRSxNPQ = K x NPQ +K=Cout +C=Cin +R,S=KH,KW +P,Q=OH,OW +*/ + +uint32_t B_idx_K = gl_WorkGroupID.x; +uint32_t B_idx_NPQ = gl_WorkGroupID.y; + +uint32_t T_y = tid / NT_NPQ; +uint32_t T_x = tid % NT_NPQ; + +uint32_t Ar = tid / BS_CRS; +uint32_t Ac = tid % BS_CRS; +uint32_t ArpWg = WG_SIZE / BS_CRS; + +uint32_t Br = tid / BS_NPQ; +uint32_t Bc = tid % BS_NPQ; +uint32_t BrpWg = WG_SIZE / BS_NPQ; + +void initReg(){ + for(uint32_t T_ly = 0; T_ly < TS_K; T_ly++){ + for(uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++){ + regC[T_ly][T_lx] = 0.0; + } + } +} + +void outProdReg(){ + for(uint32_t CRS_lidx = 0; CRS_lidx < BS_CRS; CRS_lidx++){ + for(uint32_t T_ly = 0; T_ly < TS_K; T_ly++){ + regA[T_ly] = Ash[(T_y*TS_K + T_ly)*Ash_stride + CRS_lidx]; + } + for(uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++){ + regB[T_lx] = Bsh[CRS_lidx*Bsh_stride + T_x*TS_NPQ+T_lx]; + } + for(uint32_t T_ly = 0; T_ly < TS_K; T_ly++){ + for(uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++){ + regC[T_ly][T_lx] += regA[T_ly] * regB[T_lx]; + } + } + } +} + +// Generate different functions for computing the sides. + +#define NOOP() + +#define DEF_BOUNDARY_CONDITION_A_IF()\ +if(K_idx < K && CRS_idx < CRS){ + +#define DEF_BOUNDARY_CONDITION_A_ELSE()\ +}else{\ + Ash[B_ly * Ash_stride + B_lx] = 0.0;\ +} + +#define DEF_BOUNDARY_CONDITION_B_IF()\ +if(CRS_idx < CRS && NPQ_idx < NPQ){ + +#define DEF_BOUNDARY_CONDITION_B_ELSE()\ +}else{\ + Bsh[B_ly * Bsh_stride + B_lx] = 0.0;\ +} + +#define MAIN_LOOP(FUNC_NAME_SUFFIX, BOUNDARY_CONDITION_A_IF, BOUNDARY_CONDITION_A_ELSE, BOUNDARY_CONDITION_B_IF, BOUNDARY_CONDITION_B_ELSE)\ +void mainLoop ## FUNC_NAME_SUFFIX(){\ + initReg();\ + /* Advance block in CRS dim */\ + for(uint32_t B_idx_CRS = 0; B_idx_CRS < NB_CRS; B_idx_CRS++){\ + /* Load kernel to A_block: (BS_K x BS_CRS)*/\ + for(uint32_t r_offset = 0; r_offset < BS_K; r_offset += ArpWg){\ + uint32_t B_ly = r_offset + Ar;\ + uint32_t B_lx = Ac;\ + uint32_t K_idx = B_idx_K*BS_K + B_ly; /* Global K_idx (row index of A)*/\ + uint32_t CRS_idx = B_idx_CRS*BS_CRS + B_lx; /* Global CRS_idx (column index of A)*/\ + BOUNDARY_CONDITION_A_IF()\ + uint32_t Cin_idx = CRS_idx / (p.KW*p.KH);\ + uint32_t KH_idx = (CRS_idx - Cin_idx*p.KW*p.KH) / p.KW;\ + uint32_t KW_idx = CRS_idx - Cin_idx*p.KW*p.KH - KH_idx*p.KW;\ + uint32_t knl_idx = KW_idx + KH_idx*p.nb01 + Cin_idx*p.nb02 + K_idx*p.nb03;\ + Ash[B_ly * Ash_stride + B_lx] = knl_data[knl_idx];\ + BOUNDARY_CONDITION_A_ELSE()\ + }\ + barrier();\ + /* Load input to B_block: (BS_CRS x BS_NPQ) */\ + for(uint32_t r_offset = 0; r_offset < BS_CRS; r_offset += BrpWg){\ + uint32_t B_ly = r_offset + Br; /* Row index of B block */\ + uint32_t B_lx = Bc; /* Column index of B block */\ + uint32_t CRS_idx = B_idx_CRS*BS_CRS + B_ly; /* Global CRS index (row index of B) */\ + uint32_t NPQ_idx = B_idx_NPQ*BS_NPQ + B_lx; /* Global NPQ index (column index of B) */\ + BOUNDARY_CONDITION_B_IF()\ + uint32_t Cin_idx = CRS_idx / (p.KW*p.KH);\ + uint32_t KH_idx = (CRS_idx - Cin_idx*p.KW*p.KH) / p.KW;\ + uint32_t KW_idx = CRS_idx - Cin_idx*p.KW*p.KH - KH_idx*p.KW;\ + uint32_t N_idx = NPQ_idx / (p.OH*p.OW);\ + uint32_t OH_idx = (NPQ_idx - N_idx*p.OH*p.OW) / p.OW;\ + uint32_t OW_idx = NPQ_idx - N_idx*p.OH*p.OW - OH_idx*p.OW;\ + uint32_t H_idx = OH_idx*p.s1 + KH_idx*p.d1 - p.p1;\ + uint32_t W_idx = OW_idx*p.s0 + KW_idx*p.d0 - p.p0;\ + if(H_idx >= 0 && H_idx < p.H && W_idx >= 0 && W_idx < p.W){\ + uint32_t src_idx = W_idx + H_idx*p.nb11 + Cin_idx*p.nb12 + N_idx*p.nb13;\ + Bsh[B_ly * Bsh_stride + B_lx] = src_data[src_idx];\ + }else{\ + Bsh[B_ly * Bsh_stride + B_lx] = 0.0;\ + }\ + BOUNDARY_CONDITION_B_ELSE()\ + }\ + barrier();\ + outProdReg();\ + barrier();\ + }\ + /* Save C* */\ + for(uint32_t T_ly = 0; T_ly < TS_K; T_ly++){\ + for(uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++){\ + uint32_t K_idx = B_idx_K * BS_K + T_y * TS_K + T_ly;\ + uint32_t NPQ_idx = B_idx_NPQ * BS_NPQ + T_x * TS_NPQ + T_lx;\ + if(K_idx < K && NPQ_idx < NPQ){\ + uint32_t N_idx = NPQ_idx / (p.OH*p.OW);\ + uint32_t OH_idx = (NPQ_idx - N_idx*p.OH*p.OW) / p.OW;\ + uint32_t OW_idx = NPQ_idx - N_idx*p.OH*p.OW - OH_idx*p.OW;\ + uint32_t dst_idx = OW_idx + OH_idx*p.nb1 + K_idx*p.nb2 + N_idx*p.nb3;\ + dst_data[dst_idx] = regC[T_ly][T_lx];\ + }\ + }\ + }\ +} + +// Generates mainLoopBoundaryCheck +MAIN_LOOP(BoundaryCheck, + DEF_BOUNDARY_CONDITION_A_IF, + DEF_BOUNDARY_CONDITION_A_ELSE, + DEF_BOUNDARY_CONDITION_B_IF, + DEF_BOUNDARY_CONDITION_B_ELSE) + +// Generates mainLoopNoBoundaryCheck +MAIN_LOOP(NoBoundaryCheck, + NOOP, NOOP, NOOP, NOOP) + +void main(){ + if(gl_WorkGroupID.x == gl_NumWorkGroups.x-1 || gl_WorkGroupID.y == gl_NumWorkGroups.y-1){ + mainLoopBoundaryCheck(); + }else{ + mainLoopNoBoundaryCheck(); + } +} \ No newline at end of file diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index d4a4e4c5290d8..e6432c356fb80 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -650,6 +650,8 @@ void process_shaders() { string_to_spv("opt_step_adamw_f32", "opt_step_adamw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}})); + string_to_spv("conv2d_f32", "conv2d_mm.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("conv2d_dw_whcn_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"WHCN", "1"}})); string_to_spv("conv2d_dw_cwhn_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"CWHN", "1"}})); diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 81fe90b99323d..57ab18ad660e5 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -1020,7 +1020,7 @@ struct test_case { return t; } - bool eval(ggml_backend_t backend1, ggml_backend_t backend2, const char * op_name, printer * output_printer) { + virtual bool eval(ggml_backend_t backend1, ggml_backend_t backend2, const char * op_name, printer * output_printer) { mode = MODE_TEST; ggml_init_params params = { @@ -1240,25 +1240,56 @@ struct test_case { // determine number of runs int n_runs; bool is_cpu = ggml_backend_dev_type(ggml_backend_get_device(backend)) == GGML_BACKEND_DEVICE_TYPE_CPU; + + // how many nodes are added by each op + uint32_t nodes_per_op = 1; + if(op_desc(out) == "CONV_2D_INDIRECT_IMPL"){ + nodes_per_op = 8; + } + if (op_flops(out) > 0) { // based on flops const uint64_t GFLOP = 1000 * 1000 * 1000; const uint64_t target_flops_cpu = 8ULL * GFLOP; const uint64_t target_flops_gpu = 100ULL * GFLOP; uint64_t target_flops = is_cpu ? target_flops_cpu : target_flops_gpu; - n_runs = std::min(ggml_graph_size(gf) - ggml_graph_n_nodes(gf), target_flops / op_flops(out)) + 1; + n_runs = std::min((ggml_graph_size(gf) - ggml_graph_n_nodes(gf))/nodes_per_op, target_flops / op_flops(out)) + 1; } else { // based on memory size const size_t GB = 1ULL << 30; const size_t target_size_cpu = 8 * GB; const size_t target_size_gpu = 32 * GB; size_t target_size = is_cpu ? target_size_cpu : target_size_gpu; - n_runs = std::min(ggml_graph_size(gf) - ggml_graph_n_nodes(gf), target_size / op_size(out)) + 1; + n_runs = std::min((ggml_graph_size(gf) - ggml_graph_n_nodes(gf))/nodes_per_op, target_size / op_size(out)) + 1; } // duplicate the op for (int i = 1; i < n_runs; i++) { ggml_graph_add_node(gf, out); + + if(op_desc(out) == "CONV_2D_INDIRECT_IMPL"){ + /* + TODO: add a permanent solution! E.g. return the list of tensors + needed to add for computing the op in build_graph(). + + Adds the full ggml_conv_2d() computation graph, not just the output! + * cont (out) + * cont (out->src[0]) + * permute (out->src[0]->...) + * reshape + * mul_mat + * reshape + * im2col + * reshape + */ + ggml_graph_add_node(gf, out->src[0]); // cont + ggml_graph_add_node(gf, out->src[0]->src[0]); // permute + ggml_graph_add_node(gf, out->src[0]->src[0]->src[0]); // reshape + ggml_graph_add_node(gf, out->src[0]->src[0]->src[0]->src[0]); // mul_mat + ggml_graph_add_node(gf, out->src[0]->src[0]->src[0]->src[0]->src[0]); // reshape + ggml_graph_add_node(gf, out->src[0]->src[0]->src[0]->src[0]->src[0]->src[0]); // im2col + ggml_graph_add_node(gf, out->src[0]->src[0]->src[0]->src[0]->src[1]); // reshape + } } // calculate memory @@ -1599,6 +1630,262 @@ struct test_case { } }; +// This can be useful to compare the output/performance of +// different graphs implementing the same op. +// Possible use cases: +// * no CPU implementation exists for the op, but the op +// can be built by combining elementary ops already having implementation +// and the user wants to compare the results. +// * comparing the performance of different implementations +// of the op: graph revwriting/operation fusion. E.g. basic attention +// compared to flash attention or conv compared with im2col->matmul. +struct test_case_ref : public test_case { +public: + ggml_cgraph * gf_ref = nullptr; + + // Output tensor names to compare + const char* output_node_name_ref; + const char* output_node_name; + + // Input tensor names in (actual graph, reference graph) + std::vector> input_names = { + {"input", "input"}, + {"kernel", "kernel"} + }; + + // Copies the inputs of the graph built using build_graph() to the reference graph + virtual void copy_data_to_ref(ggml_context * ctx, ggml_context * ctx_ref){ + std::map inputs; + std::map inputs_ref; + + for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != nullptr; t = ggml_get_next_tensor(ctx, t)) { + for(auto e : input_names){ + if(e.first == t->name){ + inputs[e.first] = t; + } + } + } + + for (ggml_tensor * t = ggml_get_first_tensor(ctx_ref); t != nullptr; t = ggml_get_next_tensor(ctx_ref, t)) { + for(auto e : input_names){ + if(e.second == t->name){ + inputs_ref[e.second] = t; + } + } + } + + for(auto e : input_names){ + GGML_ASSERT(inputs.count(e.first) == 1); + GGML_ASSERT(inputs_ref.count(e.second) == 1); + std::vector buf(ggml_nbytes(inputs[e.first])); + ggml_backend_tensor_get(inputs[e.first], buf.data(), 0, ggml_nbytes(inputs[e.first])); + ggml_backend_tensor_set(inputs_ref[e.second], buf.data(), 0, buf.size()); + } + } + + // Graph of the reference op implementation + virtual ggml_tensor * build_graph_ref(ggml_context * ctx) = 0; + + // Compares the output of the actual graph to the output of the reference + bool eval(ggml_backend_t backend1, ggml_backend_t backend2, const char * op_name, printer * output_printer) override { + mode = MODE_TEST; + + ggml_init_params params = { + /* .mem_size = */ ggml_tensor_overhead()*128 + ggml_graph_overhead(), + /* .mem_base = */ NULL, + /* .no_alloc = */ true, + }; + ggml_context * ctx = ggml_init(params); + ggml_context * ctx_ref = ggml_init(params); + GGML_ASSERT(ctx); + GGML_ASSERT(ctx_ref); + + gf = ggml_new_graph(ctx); + gf_ref = ggml_new_graph(ctx_ref); + + // pre-graph sentinel + add_sentinel(ctx); + add_sentinel(ctx_ref); + + ggml_tensor * out = build_graph(ctx); + ggml_tensor * out_ref = build_graph_ref(ctx_ref); + + std::string current_op_name = op_desc(out); + + if (op_name != nullptr && op_desc(out) != op_name) { + //printf(" %s: skipping\n", op_desc(out).c_str()); + ggml_free(ctx); + return true; + } + + // check if the backends support the ops + bool supported = true; + ggml_backend* backend_tested = nullptr; + for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) { + if (!ggml_backend_supports_op(backend1, t)) { + supported = false; + backend_tested = backend1; + break; + } + } + + if(supported){ + for (ggml_tensor * t = ggml_get_first_tensor(ctx_ref); t != NULL; t = ggml_get_next_tensor(ctx_ref, t)) { + if (!ggml_backend_supports_op(backend2, t)) { + supported = false; + backend_tested = backend2; + break; + } + } + } + + if (!supported) { + // Create test result for unsupported operation + test_result result(ggml_backend_name(backend_tested), current_op_name, vars(), "test", + false, false, "not supported"); + if (output_printer) { + output_printer->print_test_result(result); + } + + ggml_free(ctx); + return true; + } + + // post-graph sentinel + add_sentinel(ctx); + add_sentinel(ctx_ref); + + // allocate + ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors(ctx, backend1); + + if (buf == NULL) { + printf("failed to allocate tensors [%s] ", ggml_backend_name(backend1)); + ggml_free(ctx); + return false; + } + + ggml_backend_buffer_t buf_ref = ggml_backend_alloc_ctx_tensors(ctx_ref, backend2); + if (buf_ref == NULL) { + printf("failed to allocate tensors [%s] ", ggml_backend_name(backend2)); + ggml_free(ctx_ref); + return false; + } + + // build graph + ggml_build_forward_expand(gf, out); + ggml_build_forward_expand(gf_ref, out_ref); + + // add sentinels as graph nodes so that they are checked in the callback + for (ggml_tensor * sentinel : sentinels) { + ggml_graph_add_node(gf, sentinel); + ggml_graph_add_node(gf_ref, sentinel); + } + + // randomize tensors + initialize_tensors(ctx); + copy_data_to_ref(ctx, ctx_ref); + + // compare + struct callback_userdata { + bool ok; + double max_err; + ggml_backend_t backend1; + ggml_backend_t backend2; + }; + + callback_userdata ud { + true, + max_nmse_err(), + backend1, + backend2 + }; + + auto callback = [](int index, ggml_tensor * t1, ggml_tensor * t2, void * user_data) -> bool { + callback_userdata * ud = (callback_userdata *) user_data; + const char * bn1 = ggml_backend_name(ud->backend1); + const char * bn2 = ggml_backend_name(ud->backend2); + + if (t1->op == GGML_OP_NONE) { + // sentinels must be unchanged + std::vector t1_data(ggml_nbytes(t1)); + std::vector t2_data(ggml_nbytes(t2)); + ggml_backend_tensor_get(t1, t1_data.data(), 0, ggml_nbytes(t1)); + ggml_backend_tensor_get(t2, t2_data.data(), 0, ggml_nbytes(t2)); + + if (memcmp(t1_data.data(), t2_data.data(), ggml_nbytes(t1)) != 0) { + printf("sentinel mismatch: %s ", t1->name); + ud->ok = false; + return true; + } + } + + std::vector f1 = tensor_to_float(t1); + std::vector f2 = tensor_to_float(t2); + + for (size_t i = 0; i < f1.size(); i++) { + // check for nans + if (std::isnan(f1[i]) || std::isnan(f2[i])) { + printf("[%s] NaN at index %zu (%s=%f %s=%f) ", ggml_op_desc(t1), i, bn1, f1[i], bn2, f2[i]); + ud->ok = false; + return true; + } + // check for infs: both must be inf of the same sign, or both must be finite + if (isinf_or_max(f1[i]) || isinf_or_max(f2[i])) { + if (isinf_or_max(f1[i]) && isinf_or_max(f2[i])) { + if (std::signbit(f1[i]) != std::signbit(f2[i])) { + printf("[%s] inf sign mismatch: %s=%f %s=%f ", ggml_op_desc(t1), bn1, f1[i], bn2, f2[i]); + ud->ok = false; + return true; + } + } else { + printf("[%s] inf mismatch: %s=%f %s=%f ", ggml_op_desc(t1), bn1, f1[i], bn2, f2[i]); + ud->ok = false; + return true; + } + } + } + + double err = nmse(f1.data(), f2.data(), f1.size()); + if (err > ud->max_err) { + printf("[%s] NMSE = %.9f > %.9f ", ggml_op_desc(t1), err, ud->max_err); + //for (int i = 0; i < (int) f1.size(); i++) { + // printf("%5d %9.6f %9.6f, diff = %9.6f\n", i, f1[i], f2[i], f1[i] - f2[i]); + //} + //printf("\n"); + //exit(1); + ud->ok = false; + } + return true; + + GGML_UNUSED(index); + }; + + + const bool cmp_ok = ggml_backend_compare_graph_backend_node(backend1, backend2, gf, gf_ref, callback, &ud, const_cast(output_node_name), const_cast(output_node_name_ref)); + + if (!cmp_ok) { + printf("compare failed "); + } + + ggml_backend_buffer_free(buf); + ggml_backend_buffer_free(buf_ref); + + ggml_free(ctx); + ggml_free(ctx_ref); + + // Create test result + bool test_passed = ud.ok && cmp_ok; + std::string error_msg = test_passed ? "" : (!cmp_ok ? "compare failed" : "test failed"); + test_result result(ggml_backend_name(backend1), current_op_name, vars(), "test", supported, test_passed, + error_msg); + + if (output_printer) { + output_printer->print_test_result(result); + } + + return test_passed; + } +}; // ################################### // ## Section 2: GGML Op Defintions ## @@ -3682,6 +3969,141 @@ struct test_im2col : public test_case { } }; +// Tests CONV_2D by comparing it to the IM2COL -> MUL_MM +// reference implementation. +struct test_conv_2d : public test_case_ref { + const std::array ne_input; + const std::array ne_kernel; + const int stride0; + const int stride1; + const int padding0; + const int padding1; + const int dilation0; + const int dilation1; + // Whether the inputs are contiguous in the channel dim or the width dim + const bool cwhn; + // If true, the direct CONV_2D will be used in the graph, otherwise it + // uses ggml_conv_2d: + // * if the program is called with -o CONV_2D_DIRECT_IMPL, the + // CONV_2D graph will be built, while + // * if the program is called with -o CONV_2D_INDIRECT_IMPL, the + // IM2COL -> MUL_MM graph will be built. + const bool direct_impl; + + virtual std::string op_desc(ggml_tensor * t) { + (void) t; + if(direct_impl){ + return std::string("CONV_2D_DIRECT_IMPL"); + }else{ + return std::string("CONV_2D_INDIRECT_IMPL"); + } + } + + std::string vars() override { + return VARS_TO_STR9(ne_input, ne_kernel, stride0, stride1, padding0, padding1, dilation0, dilation1, cwhn); + } + + uint64_t op_flops(ggml_tensor * t) override { + GGML_UNUSED(t); + // Just counting matmul costs: + // KxCRS @ CRSxNPQ = KxNPQ --> KxNPQx(CRS+CRS-1) flops + + // Copied from ggml.c: int64_t ggml_calc_conv_output_size(int64_t ins, int64_t ks, int s, int p, int d) + auto calc_conv_output_size = [](int64_t ins, int64_t ks, int s, int p, int d) -> int64_t { + return (ins + 2 * p - d * (ks - 1) - 1) / s + 1; + }; + + int64_t W = ne_input[0]; + int64_t H = ne_input[1]; + int64_t KW = ne_kernel[0]; + int64_t KH = ne_kernel[1]; + int64_t Cin = ne_kernel[2]; + int64_t Cout = ne_kernel[3]; + int64_t N = ne_input[3]; + int64_t OH = calc_conv_output_size(H, KH, stride0, padding0, dilation0); + int64_t OW = calc_conv_output_size(W, KW, stride0, padding0, dilation0); + + int64_t K = Cout; + int64_t CRS = Cin*KH*KW; + int64_t NPQ = N*OH*OW; + + return K*NPQ*(2*CRS-1); + } + + + test_conv_2d(std::array ne_input = {64, 64, 16, 1}, + std::array ne_kernel = {3, 3, 1, 16}, + int stride0 = 1, int stride1 = 1, int padding0 = 0, int padding1 = 0, int dilation0 = 1, int dilation1 = 1, bool cwhn = false, bool direct_impl = true) + : ne_input(ne_input), ne_kernel(ne_kernel), stride0(stride0), stride1(stride1), padding0(padding0), padding1(padding1), dilation0(dilation0), dilation1(dilation1), cwhn(cwhn), direct_impl(direct_impl) { + output_node_name_ref = "out"; + output_node_name = "out"; + } + + ggml_tensor * build_graph_indirect(ggml_context * ctx) { + ggml_tensor * input = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne_input.data()); + ggml_set_name(input, "input"); + + ggml_tensor * kernel = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne_kernel.data()); + ggml_set_name(kernel, "kernel"); + + GGML_ASSERT(cwhn==false); + + if (cwhn) { + // change memory layout to channel-most-contiguous (CWHN), + // then permute it back so NE matches the original input + input = ggml_cont(ctx, ggml_permute(ctx, input, 1, 2, 0, 3)); + input = ggml_permute(ctx, input, 2, 0, 1, 3); + kernel = ggml_cont(ctx, ggml_permute(ctx, kernel, 2, 3, 1, 0)); + kernel = ggml_permute(ctx, kernel, 3, 2, 0, 1); + } + + ggml_tensor * conv2d_out = ggml_conv_2d( + ctx, kernel, input, + stride0, stride1, padding0, padding1, dilation0, dilation1); + + ggml_tensor * out = ggml_cont(ctx, conv2d_out); + + ggml_set_name(out, "out"); + return out; + } + + ggml_tensor * build_graph_direct(ggml_context * ctx) { + ggml_tensor * input = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne_input.data()); + ggml_set_name(input, "input"); + + ggml_tensor * kernel = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne_kernel.data()); + ggml_set_name(kernel, "kernel"); + + if (cwhn) { + // change memory layout to channel-most-contiguous (CWHN), + // then permute it back so NE matches the original input + input = ggml_cont(ctx, ggml_permute(ctx, input, 1, 2, 0, 3)); + input = ggml_permute(ctx, input, 2, 0, 1, 3); + kernel = ggml_cont(ctx, ggml_permute(ctx, kernel, 2, 3, 1, 0)); + kernel = ggml_permute(ctx, kernel, 3, 2, 0, 1); + } + + ggml_tensor * out = ggml_conv_2d_direct( + ctx, kernel, input, + stride0, stride1, padding0, padding1, dilation0, dilation1); + ggml_set_name(out, "out"); + return out; + } + + ggml_tensor * build_graph(ggml_context * ctx) override { + if(direct_impl){ + return build_graph_direct(ctx); + }else{ + return build_graph_indirect(ctx); + } + } + + // Reference always uses the indirect impl. + ggml_tensor * build_graph_ref(ggml_context * ctx) override { + return build_graph_indirect(ctx); + } +}; + // GGML_OP_CONV_2D_DW struct test_conv_2d_dw : public test_case { const std::array ne_input; @@ -4990,6 +5412,71 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {12, 12, 1, 2560}, {3, 3, 1, 2560}, 1, 1, 1, 1, 1, 1, true)); test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {12, 12, 2, 2560}, {3, 3, 2, 2560}, 1, 1, 1, 1, 1, 1, true)); + // CONV_2D: + auto calc_conv_output_size = [](int64_t ins, int64_t ks, int s, int p, int d) -> int64_t { + return (ins + 2 * p - d * (ks - 1) - 1) / s + 1; + }; + + //uint32_t s0 = 3; + uint32_t s1 = 5; + uint32_t p0 = 5; + //uint32_t p1 = 2; + uint32_t d0 = 2; + uint32_t d1 = 4; + + for(uint32_t s0: {1, 3}){ + for(uint32_t p1: {2, 5}){ + for(uint32_t Cin : {1, 25}){ + for(uint32_t Cout : {1, 12}){ + for(uint32_t KH : {1, 2, 3, 11}){ + for(uint32_t KW : {1, 2, 3, 11}){ + for(uint32_t H : {1, 133}){ + for(uint32_t W : {1, 258}){ + if(calc_conv_output_size(W, KW, s0, p0, d0) > 0 && calc_conv_output_size(H, KH, s1, p1, d1) > 0){ + test_cases.emplace_back(new test_conv_2d({W, H, Cin, 2}, {KW, KH, Cin, Cout}, s0, s1, p0, p1, d0, d1, false, true)); + } + } + } + } + } + } + } + } + } + + uint32_t iwh_idx = 0; + uint32_t kwh_idx = 1; + uint32_t Cout_idx = 2; + uint32_t Cin_idx = 3; + uint32_t B_idx = 4; + std::vector> cases = { + //{IWH, KWH, Cout, Cin, B} + // K=CRS=NPQ=4096 conv2d matmul performance + {19, 4, 4096, 256, 16}, // --> fails + // K=128, CRS=128, NPQ=4096 + {19, 4, 128, 8, 16}, + // K=130, CRS=128, NPQ=4096 + {19, 4, 130, 8, 16}, + // Edge case: K x CRS is small + {19, 2, 4, 4, 16}, + // A ConvNet's first layer + {224, 3, 8, 3, 1}, + // A ConvNet's first layer with 2x2 convolution, and 1 channel + {224, 2, 8, 1, 1}, + // A ConvNet's first layer with 2x2 convolution, and 1 channel, several images in the batch + {224, 2, 8, 1, 8}, + // A middle layer of a ConvNet + {58, 3, 64, 32, 1}, + // A middle layer of a ConvNet, several images in the batch + {58, 3, 64, 32, 8}, + // A deep layer of a ConvNet, several images in the batch + {16, 3, 256, 128, 8} + }; + + for(auto act_case : cases){ + test_cases.emplace_back(new test_conv_2d({act_case[iwh_idx], act_case[iwh_idx], act_case[Cin_idx], act_case[B_idx]}, {act_case[kwh_idx], act_case[kwh_idx], act_case[Cin_idx], act_case[Cout_idx]}, 1, 1, 0, 0, 1, 1, false, true)); + } + // sycl backend will limit task global_range < MAX_INT // test cases for 2D im2col with large input W and H (occurs in stable-diffusion) // however these cases need to alloc more memory which may fail in some devices (Intel Arc770, etc.) @@ -5584,6 +6071,43 @@ static std::vector> make_test_cases_eval() { static std::vector> make_test_cases_perf() { std::vector> test_cases; + // Conv2d: K=CRS=NPQ=4096 matmul performance + uint32_t iwh_idx = 0; + uint32_t kwh_idx = 1; + uint32_t Cout_idx = 2; + uint32_t Cin_idx = 3; + uint32_t B_idx = 4; + std::vector> cases = { + //{IWH, KWH, Cout, Cin, B} + // K=CRS=NPQ=4096 conv2d matmul performance + {19, 4, 4096, 256, 16}, + // K=128, CRS=128, NPQ=4096 + {19, 4, 128, 8, 16}, + // K=130, CRS=128, NPQ=4096 + {19, 4, 130, 8, 16}, + // Edge case: K x CRS is small + {19, 2, 4, 4, 16}, + // A ConvNet's first layer + {224, 3, 8, 3, 1}, + // A ConvNet's first layer with 2x2 convolution, and 1 channel + {224, 2, 8, 1, 1}, + // A ConvNet's first layer with 2x2 convolution, and 1 channel, several images in the batch + {224, 2, 8, 1, 8}, + // A middle layer of a ConvNet + {58, 3, 64, 32, 1}, + // A middle layer of a ConvNet, several images in the batch + {58, 3, 64, 32, 8}, + // A deep layer of a ConvNet, several images in the batch + {16, 3, 512, 128, 8}, + }; + + for(auto act_case : cases){ + // Direct CONV_2D + test_cases.emplace_back(new test_conv_2d({act_case[iwh_idx], act_case[iwh_idx], act_case[Cin_idx], act_case[B_idx]}, {act_case[kwh_idx], act_case[kwh_idx], act_case[Cin_idx], act_case[Cout_idx]}, 1, 1, 0, 0, 1, 1, false, true)); + // Indirect CONV_2D (uses im2col + sgemm) + test_cases.emplace_back(new test_conv_2d({act_case[iwh_idx], act_case[iwh_idx], act_case[Cin_idx], act_case[B_idx]}, {act_case[kwh_idx], act_case[kwh_idx], act_case[Cin_idx], act_case[Cout_idx]}, 1, 1, 0, 0, 1, 1, false, false)); + } + test_cases.emplace_back(new test_bin_bcast(ggml_add, GGML_TYPE_F32, {4096, 1, 1, 1}, {1, 1, 1, 1})); test_cases.emplace_back(new test_bin_bcast(ggml_add, GGML_TYPE_F32, {4096, 1, 1, 1}, {1, 512, 1, 1})); From 0715985edfba0b40dde04254b5d80c8d00afb2a0 Mon Sep 17 00:00:00 2001 From: Ervin Tasnadi Date: Sat, 12 Jul 2025 03:13:51 +0200 Subject: [PATCH 2/9] * Performance fixes: minimized branch divergence, uses collectives to eliminate redundant calculation, macros removed. * Kernel shared memory size check * Updates test-backend-ops to support graphs for performance measurement. --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 39 ++- .../ggml-vulkan/vulkan-shaders/conv2d_mm.comp | 236 ++++++++---------- 2 files changed, 140 insertions(+), 135 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 7cc5217e5a253..d54e8b569c0a4 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -1011,29 +1011,35 @@ class vk_perf_logger { void print_timings() { if(timings.empty()){ return; - } + } + uint64_t total_all_op_times = 0; std::cerr << "----------------\nVulkan Timings:" << std::endl; for (const auto& t : timings) { - uint64_t total = 0; + uint64_t total_op_times = 0; for (const auto& time : t.second) { - total += time; + total_op_times += time; } - std::cerr << t.first << ": " << t.second.size() << " x " << (total / t.second.size() / 1000.0) << " us"; + std::cerr << t.first << ": " << t.second.size() << " x " << (total_op_times / t.second.size() / 1000.0) << " us"; // If we have as many flops entries as timing entries for the op, then compute and log the flops/S. auto it = flops.find(t.first); if(it != flops.end() && (it->second).size() == t.second.size()){ - uint64_t total_nflops = 0; + uint64_t total_op_flops = 0; for(const auto& elem : it->second){ - total_nflops += elem; + total_op_flops += elem; } - std::cout << " (" << (double(total_nflops)/(1000.0*1000.0*1000.0)) / (double(total)/(1000.0*1000.0*1000.0)) << " GFLOPS/s)"; + std::cerr << " (" << (double(total_op_flops)/(1000.0*1000.0*1000.0)) / (double(total_op_times)/(1000.0*1000.0*1000.0)) << " GFLOPS/s)"; } + total_all_op_times += total_op_times; std::cerr << std::endl; } + if(timings.size() > 0){ + std::cerr << "Total time: " << total_all_op_times/1000.0 << " us." << std::endl; + } + timings.clear(); flops.clear(); } @@ -1072,6 +1078,7 @@ class vk_perf_logger { uint64_t size_K = Cin*KW*KH; uint64_t size_N = N*OW*OH; uint64_t n_flops = size_M*size_N*(size_K+(size_K-1)); + name += " M=Cout=" + std::to_string(size_M) + ", K=Cin*KW*KH=" + std::to_string(size_K) + ", N=N*OW*OH=" + std::to_string(size_N); flops[name].push_back(n_flops); timings[name].push_back(time); return; @@ -3026,7 +3033,18 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_opt_step_adamw_f32, "opt_step_adamw_f32", opt_step_adamw_f32_len, opt_step_adamw_f32_data, "main", 5, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_conv2d_f32, "conv2d_f32", conv2d_f32_len, conv2d_f32_data, "main", 3, sizeof(vk_op_conv2d_push_constants), {128 /* equal to BS_K in the shader */, 128 /* equal to BS_NPQ in the shader */, 1}, {}, 1); + // conv2d + uint32_t conv2d_WG_SIZE = 256; + uint32_t conv2d_BS_K = 128; + uint32_t conv2d_BS_CRS = 16; + uint32_t conv2d_BS_NPQ = 128; + uint32_t conv2d_TS_K = 8; + uint32_t conv2d_shmem_req = (conv2d_BS_K*(conv2d_BS_CRS+1) + conv2d_BS_CRS*(conv2d_BS_NPQ+1))*sizeof(float); + if(device->properties.limits.maxComputeSharedMemorySize < conv2d_shmem_req){ + conv2d_BS_CRS = 8; + conv2d_TS_K = 8; + } + ggml_vk_create_pipeline(device, device->pipeline_conv2d_f32, "conv2d_f32", conv2d_f32_len, conv2d_f32_data, "main", 3, sizeof(vk_op_conv2d_push_constants), {conv2d_BS_K, conv2d_BS_NPQ, 1}, {conv2d_WG_SIZE, conv2d_BS_K, conv2d_BS_CRS, conv2d_BS_NPQ, conv2d_TS_K}, 1); ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_whcn_f32, "conv2d_dw_whcn_f32", conv2d_dw_whcn_f32_len, conv2d_dw_whcn_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_cwhn_f32, "conv2d_dw_cwhn_f32", conv2d_dw_cwhn_f32_len, conv2d_dw_cwhn_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1); @@ -10200,6 +10218,11 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg ggml_vk_build_graph(ctx, cgraph, i, nullptr, 0, true, false, false, false); if (cgraph->nodes[i]->op == GGML_OP_MUL_MAT || cgraph->nodes[i]->op == GGML_OP_MUL_MAT_ID) { total_mat_mul_bytes += ggml_nbytes(cgraph->nodes[i]->src[0]); + }else if(cgraph->nodes[i]->op == GGML_OP_CONV_2D){ + // Return CRSxNPQxsizeof(*) to account as many bytes as mul_mat has in im2col->mul_mat mode. + auto CRS_size = cgraph->nodes[i]->src[0]->ne[0]*cgraph->nodes[i]->src[0]->ne[1]*cgraph->nodes[i]->src[0]->ne[2]; + auto NPQ_size = cgraph->nodes[i]->ne[0]*cgraph->nodes[i]->ne[1]*cgraph->nodes[i]->ne[3]; + total_mat_mul_bytes += NPQ_size*CRS_size*ggml_type_size(cgraph->nodes[i]->type); } i += ctx->num_additional_fused_ops; ctx->num_additional_fused_ops = 0; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp b/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp index 0ff942d7e2993..3980040a3f2cb 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp @@ -1,9 +1,16 @@ #version 450 -#extension GL_EXT_control_flow_attributes : enable +#define USE_COLLECTIVES + +#ifdef USE_COLLECTIVES +#extension GL_KHR_shader_subgroup_shuffle: enable +#endif #include "types.comp" +// Make spec constant +#define SHMEM_PAD 0 + // shape notation: [dim(N), ..., dim(0)] -- stride(dim(j)) >= stride(dim(i)) if i > j layout (binding = 0) readonly buffer A {A_TYPE knl_data[];}; // src0 - kernel: [KW, KH, Cin, Cout] layout (binding = 1) readonly buffer B {B_TYPE src_data[];}; // src1 - input: [W, H, Cin, N] -- channel_first format @@ -45,12 +52,16 @@ layout (push_constant) uniform parameter { uint32_t nb3; } p; -#define WG_SIZE 256 - -layout(local_size_x = WG_SIZE, local_size_y = 1, local_size_z = 1) in; +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; +// Blocktile sizes +layout(constant_id = 1) const uint BS_K = 128; +layout(constant_id = 2) const uint BS_CRS = 16; +layout(constant_id = 3) const uint BS_NPQ = 128; +// Thread-tile sizes +layout(constant_id = 4) const uint TS_K = 8; uint32_t tid = gl_LocalInvocationID.x; -const uint32_t bs = gl_WorkGroupSize.x; +const uint32_t WG_SIZE = gl_WorkGroupSize.x; uint splitWork(uint work_size, uint block_size){ return (block_size + work_size -1) / block_size; @@ -62,16 +73,11 @@ uint32_t NPQ = p.N*p.OH*p.OW; uint32_t n_elems_out = K*NPQ; -// Blocktile sizes -const uint32_t BS_K = 128; -const uint32_t BS_CRS = 16; -const uint32_t BS_NPQ = 128; - // Number of blocktiles per input uint32_t NB_CRS = splitWork(CRS, BS_CRS); -const uint32_t Ash_stride = BS_CRS+1; -const uint32_t Bsh_stride = BS_NPQ+1; +const uint32_t Ash_stride = BS_CRS+SHMEM_PAD; +const uint32_t Bsh_stride = BS_NPQ+SHMEM_PAD; const uint32_t Ash_numel = BS_K*BS_CRS; const uint32_t Bsh_numel = BS_CRS*BS_NPQ; @@ -83,7 +89,6 @@ shared float Ash[Ash_len]; // K x CRS shared float Bsh[Bsh_len]; // CRS x NPQ // Threadtile sizes -const uint32_t TS_K = 16; const uint32_t TS_NPQ = BS_K*BS_NPQ / WG_SIZE / TS_K; // Number of threadtiles per blocktile @@ -111,134 +116,111 @@ uint32_t T_x = tid % NT_NPQ; uint32_t Ar = tid / BS_CRS; uint32_t Ac = tid % BS_CRS; -uint32_t ArpWg = WG_SIZE / BS_CRS; +const uint32_t ArpWg = WG_SIZE / BS_CRS; uint32_t Br = tid / BS_NPQ; uint32_t Bc = tid % BS_NPQ; -uint32_t BrpWg = WG_SIZE / BS_NPQ; +const uint32_t BrpWg = WG_SIZE / BS_NPQ; -void initReg(){ +void main(){\ for(uint32_t T_ly = 0; T_ly < TS_K; T_ly++){ for(uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++){ regC[T_ly][T_lx] = 0.0; } } -} - -void outProdReg(){ - for(uint32_t CRS_lidx = 0; CRS_lidx < BS_CRS; CRS_lidx++){ - for(uint32_t T_ly = 0; T_ly < TS_K; T_ly++){ - regA[T_ly] = Ash[(T_y*TS_K + T_ly)*Ash_stride + CRS_lidx]; + /* Advance block in CRS dim */\ + for(uint32_t B_idx_CRS = 0; B_idx_CRS < NB_CRS; B_idx_CRS++){ + #ifdef USE_COLLECTIVES + uint32_t cached_CRS_idx = B_idx_CRS*BS_CRS + gl_SubgroupInvocationID; + uint32_t cached_Cin_idx = cached_CRS_idx / (p.KW*p.KH); + uint32_t cached_CRS_remainder = (cached_CRS_idx - cached_Cin_idx*p.KW*p.KH); + uint32_t cached_KH_idx = cached_CRS_remainder / p.KW; + uint32_t cached_KW_idx = cached_CRS_remainder - cached_KH_idx*p.KW; + + uint32_t CRS_idx_a = subgroupShuffle(cached_CRS_idx, Ac); + uint32_t Cin_idx_a = subgroupShuffle(cached_Cin_idx, Ac); + uint32_t KH_idx_a = subgroupShuffle(cached_KH_idx, Ac); + uint32_t KW_idx_a = subgroupShuffle(cached_KW_idx, Ac); + #else + uint32_t CRS_idx_a = B_idx_CRS*BS_CRS + Ac; // Global CRS_idx_a (column index of A) + uint32_t Cin_idx_a = CRS_idx_a / (p.KW*p.KH); + uint32_t CRS_remainder = CRS_idx_a - Cin_idx_a*p.KW*p.KH; + uint32_t KH_idx_a = CRS_remainder / p.KW; + uint32_t KW_idx_a = CRS_remainder - KH_idx_a*p.KW; + #endif + + /* Load kernel to A_block: (BS_K x BS_CRS)*/ + for(uint32_t r_offset = 0; r_offset < BS_K; r_offset += ArpWg){ + uint32_t B_ly = r_offset + Ar; + uint32_t B_lx = Ac; + uint32_t K_idx = B_idx_K*BS_K + B_ly; /* Global K_idx (row index of A)*/ + uint32_t knl_idx = min(KW_idx_a + KH_idx_a*p.nb01 + Cin_idx_a*p.nb02 + K_idx*p.nb03, K*CRS-1); + float val = knl_data[knl_idx]; + if(K_idx >= K || CRS_idx_a >= CRS){ + val = 0.0; + } + Ash[B_ly * Ash_stride + B_lx] = val; } - for(uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++){ - regB[T_lx] = Bsh[CRS_lidx*Bsh_stride + T_x*TS_NPQ+T_lx]; + /* Load input to B_block: (BS_CRS x BS_NPQ) */ + for(uint32_t r_offset = 0; r_offset < BS_CRS; r_offset += BrpWg){ + uint32_t B_ly = r_offset + Br; /* Row index of B block */ + uint32_t B_lx = Bc; + uint32_t NPQ_idx = B_idx_NPQ*BS_NPQ + B_lx; /* Global NPQ index (column index of B) */ + uint32_t N_idx = NPQ_idx / (p.OH*p.OW); + uint32_t NPQ_remainder = NPQ_idx - N_idx*p.OH*p.OW; + uint32_t OH_idx = NPQ_remainder / p.OW; + uint32_t OW_idx = NPQ_remainder - OH_idx*p.OW; + + #ifdef USE_COLLECTIVES + uint32_t CRS_idx_b = subgroupShuffle(cached_CRS_idx, r_offset + Br); + uint32_t Cin_idx_b = subgroupShuffle(cached_Cin_idx, r_offset + Br); + uint32_t KH_idx_b = subgroupShuffle(cached_KH_idx, r_offset + Br); + uint32_t KW_idx_b = subgroupShuffle(cached_KW_idx, r_offset + Br); + #else + uint32_t CRS_idx_b = B_idx_CRS*BS_CRS + B_ly; /* Global CRS index (row index of B) */ + uint32_t Cin_idx_b = CRS_idx_b / (p.KW*p.KH); + uint32_t CRS_remainder = CRS_idx_b - Cin_idx_b*p.KW*p.KH; + uint32_t KH_idx_b = CRS_remainder / p.KW; + uint32_t KW_idx_b = CRS_remainder - KH_idx_b*p.KW; + #endif + + uint32_t H_idx = OH_idx*p.s1 + KH_idx_b*p.d1 - p.p1; + uint32_t W_idx = OW_idx*p.s0 + KW_idx_b*p.d0 - p.p0; + uint32_t src_idx = min(max(W_idx + H_idx*p.nb11 + Cin_idx_b*p.nb12 + N_idx*p.nb13, 0), p.Cin*p.N*p.W*p.H-1); + float val = src_data[src_idx]; + if(CRS_idx_b >= CRS || NPQ_idx >= NPQ || H_idx < 0 || H_idx >= p.H || W_idx < 0 || W_idx >= p.W){ + val = 0.0; + } + Bsh[B_ly * Bsh_stride + B_lx] = val; } - for(uint32_t T_ly = 0; T_ly < TS_K; T_ly++){ + barrier(); + for(uint32_t CRS_lidx = 0; CRS_lidx < BS_CRS; CRS_lidx++){ + for(uint32_t T_ly = 0; T_ly < TS_K; T_ly++){ + regA[T_ly] = Ash[(T_y*TS_K + T_ly)*Ash_stride + CRS_lidx]; + } for(uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++){ - regC[T_ly][T_lx] += regA[T_ly] * regB[T_lx]; + regB[T_lx] = Bsh[CRS_lidx*Bsh_stride + T_x*TS_NPQ+T_lx]; + } + for(uint32_t T_ly = 0; T_ly < TS_K; T_ly++){ + for(uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++){ + regC[T_ly][T_lx] = fma(regA[T_ly], regB[T_lx], regC[T_ly][T_lx]); + } } } + barrier(); } -} - -// Generate different functions for computing the sides. - -#define NOOP() - -#define DEF_BOUNDARY_CONDITION_A_IF()\ -if(K_idx < K && CRS_idx < CRS){ - -#define DEF_BOUNDARY_CONDITION_A_ELSE()\ -}else{\ - Ash[B_ly * Ash_stride + B_lx] = 0.0;\ -} - -#define DEF_BOUNDARY_CONDITION_B_IF()\ -if(CRS_idx < CRS && NPQ_idx < NPQ){ - -#define DEF_BOUNDARY_CONDITION_B_ELSE()\ -}else{\ - Bsh[B_ly * Bsh_stride + B_lx] = 0.0;\ -} - -#define MAIN_LOOP(FUNC_NAME_SUFFIX, BOUNDARY_CONDITION_A_IF, BOUNDARY_CONDITION_A_ELSE, BOUNDARY_CONDITION_B_IF, BOUNDARY_CONDITION_B_ELSE)\ -void mainLoop ## FUNC_NAME_SUFFIX(){\ - initReg();\ - /* Advance block in CRS dim */\ - for(uint32_t B_idx_CRS = 0; B_idx_CRS < NB_CRS; B_idx_CRS++){\ - /* Load kernel to A_block: (BS_K x BS_CRS)*/\ - for(uint32_t r_offset = 0; r_offset < BS_K; r_offset += ArpWg){\ - uint32_t B_ly = r_offset + Ar;\ - uint32_t B_lx = Ac;\ - uint32_t K_idx = B_idx_K*BS_K + B_ly; /* Global K_idx (row index of A)*/\ - uint32_t CRS_idx = B_idx_CRS*BS_CRS + B_lx; /* Global CRS_idx (column index of A)*/\ - BOUNDARY_CONDITION_A_IF()\ - uint32_t Cin_idx = CRS_idx / (p.KW*p.KH);\ - uint32_t KH_idx = (CRS_idx - Cin_idx*p.KW*p.KH) / p.KW;\ - uint32_t KW_idx = CRS_idx - Cin_idx*p.KW*p.KH - KH_idx*p.KW;\ - uint32_t knl_idx = KW_idx + KH_idx*p.nb01 + Cin_idx*p.nb02 + K_idx*p.nb03;\ - Ash[B_ly * Ash_stride + B_lx] = knl_data[knl_idx];\ - BOUNDARY_CONDITION_A_ELSE()\ - }\ - barrier();\ - /* Load input to B_block: (BS_CRS x BS_NPQ) */\ - for(uint32_t r_offset = 0; r_offset < BS_CRS; r_offset += BrpWg){\ - uint32_t B_ly = r_offset + Br; /* Row index of B block */\ - uint32_t B_lx = Bc; /* Column index of B block */\ - uint32_t CRS_idx = B_idx_CRS*BS_CRS + B_ly; /* Global CRS index (row index of B) */\ - uint32_t NPQ_idx = B_idx_NPQ*BS_NPQ + B_lx; /* Global NPQ index (column index of B) */\ - BOUNDARY_CONDITION_B_IF()\ - uint32_t Cin_idx = CRS_idx / (p.KW*p.KH);\ - uint32_t KH_idx = (CRS_idx - Cin_idx*p.KW*p.KH) / p.KW;\ - uint32_t KW_idx = CRS_idx - Cin_idx*p.KW*p.KH - KH_idx*p.KW;\ - uint32_t N_idx = NPQ_idx / (p.OH*p.OW);\ - uint32_t OH_idx = (NPQ_idx - N_idx*p.OH*p.OW) / p.OW;\ - uint32_t OW_idx = NPQ_idx - N_idx*p.OH*p.OW - OH_idx*p.OW;\ - uint32_t H_idx = OH_idx*p.s1 + KH_idx*p.d1 - p.p1;\ - uint32_t W_idx = OW_idx*p.s0 + KW_idx*p.d0 - p.p0;\ - if(H_idx >= 0 && H_idx < p.H && W_idx >= 0 && W_idx < p.W){\ - uint32_t src_idx = W_idx + H_idx*p.nb11 + Cin_idx*p.nb12 + N_idx*p.nb13;\ - Bsh[B_ly * Bsh_stride + B_lx] = src_data[src_idx];\ - }else{\ - Bsh[B_ly * Bsh_stride + B_lx] = 0.0;\ - }\ - BOUNDARY_CONDITION_B_ELSE()\ - }\ - barrier();\ - outProdReg();\ - barrier();\ - }\ - /* Save C* */\ - for(uint32_t T_ly = 0; T_ly < TS_K; T_ly++){\ - for(uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++){\ - uint32_t K_idx = B_idx_K * BS_K + T_y * TS_K + T_ly;\ - uint32_t NPQ_idx = B_idx_NPQ * BS_NPQ + T_x * TS_NPQ + T_lx;\ - if(K_idx < K && NPQ_idx < NPQ){\ - uint32_t N_idx = NPQ_idx / (p.OH*p.OW);\ - uint32_t OH_idx = (NPQ_idx - N_idx*p.OH*p.OW) / p.OW;\ - uint32_t OW_idx = NPQ_idx - N_idx*p.OH*p.OW - OH_idx*p.OW;\ - uint32_t dst_idx = OW_idx + OH_idx*p.nb1 + K_idx*p.nb2 + N_idx*p.nb3;\ - dst_data[dst_idx] = regC[T_ly][T_lx];\ - }\ - }\ - }\ -} - -// Generates mainLoopBoundaryCheck -MAIN_LOOP(BoundaryCheck, - DEF_BOUNDARY_CONDITION_A_IF, - DEF_BOUNDARY_CONDITION_A_ELSE, - DEF_BOUNDARY_CONDITION_B_IF, - DEF_BOUNDARY_CONDITION_B_ELSE) - -// Generates mainLoopNoBoundaryCheck -MAIN_LOOP(NoBoundaryCheck, - NOOP, NOOP, NOOP, NOOP) - -void main(){ - if(gl_WorkGroupID.x == gl_NumWorkGroups.x-1 || gl_WorkGroupID.y == gl_NumWorkGroups.y-1){ - mainLoopBoundaryCheck(); - }else{ - mainLoopNoBoundaryCheck(); + /* Save C* */ + for(uint32_t T_ly = 0; T_ly < TS_K; T_ly++){ + for(uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++){ + uint32_t K_idx = B_idx_K * BS_K + T_y * TS_K + T_ly; + uint32_t NPQ_idx = B_idx_NPQ * BS_NPQ + T_x * TS_NPQ + T_lx; + uint32_t N_idx = NPQ_idx / (p.OH*p.OW); + uint32_t OH_idx = (NPQ_idx - N_idx*p.OH*p.OW) / p.OW; + uint32_t OW_idx = NPQ_idx - N_idx*p.OH*p.OW - OH_idx*p.OW; + uint32_t dst_idx = OW_idx + OH_idx*p.nb1 + K_idx*p.nb2 + N_idx*p.nb3; + if(K_idx < K && NPQ_idx < NPQ){ + dst_data[dst_idx] = regC[T_ly][T_lx]; + } + } } } \ No newline at end of file From 7f9b6591c2fe890fd0c2588663421ff91a959420 Mon Sep 17 00:00:00 2001 From: Ervin Tasnadi Date: Mon, 14 Jul 2025 00:12:43 +0200 Subject: [PATCH 3/9] * Apple/Win32 compile errors fixed * Subgroup size used to determine tile size -> fixes llvmpipe errors. --- ggml/src/ggml-backend.cpp | 10 +- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 58 +++++++----- .../ggml-vulkan/vulkan-shaders/conv2d_mm.comp | 92 ++++++++++++------- .../vulkan-shaders/vulkan-shaders-gen.cpp | 2 +- tests/test-backend-ops.cpp | 35 +++---- 5 files changed, 120 insertions(+), 77 deletions(-) diff --git a/ggml/src/ggml-backend.cpp b/ggml/src/ggml-backend.cpp index b3d42975b6fc9..901da3711c9b6 100644 --- a/ggml/src/ggml-backend.cpp +++ b/ggml/src/ggml-backend.cpp @@ -1883,12 +1883,12 @@ bool ggml_backend_compare_graph_backend(ggml_backend_t backend1, ggml_backend_t } bool ggml_backend_compare_graph_backend_node( - ggml_backend_t backend1, - ggml_backend_t backend2, - struct ggml_cgraph * graph1, - struct ggml_cgraph * graph2, + ggml_backend_t backend1, + ggml_backend_t backend2, + struct ggml_cgraph * graph1, + struct ggml_cgraph * graph2, ggml_backend_eval_callback callback, void * user_data, char* op_name_out_1, char* op_name_out_2) { - + ggml_tensor * out1 = NULL; ggml_tensor * out2 = NULL; diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index d54e8b569c0a4..8896664630e80 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -880,7 +880,7 @@ struct vk_op_conv2d_push_constants { uint32_t Cout; uint32_t Cin; uint32_t N; - + uint32_t KW; uint32_t KH; uint32_t W; @@ -1041,7 +1041,7 @@ class vk_perf_logger { } timings.clear(); - flops.clear(); + flops.clear(); } void log_timing(const ggml_tensor * node, uint64_t time) { @@ -1082,7 +1082,7 @@ class vk_perf_logger { flops[name].push_back(n_flops); timings[name].push_back(time); return; - } + } timings[ggml_op_name(node->op)].push_back(time); } private: @@ -2190,6 +2190,7 @@ static void ggml_vk_load_shaders(vk_device& device) { } compile_count++; } + compiles.push_back(std::async(ggml_vk_create_pipeline_func, std::ref(device), std::ref(pipeline), spv_size, spv_data, entrypoint, parameter_count, wg_denoms, specialization_constants, disable_robustness, require_full_subgroups, required_subgroup_size)); }; @@ -3037,14 +3038,27 @@ static void ggml_vk_load_shaders(vk_device& device) { uint32_t conv2d_WG_SIZE = 256; uint32_t conv2d_BS_K = 128; uint32_t conv2d_BS_CRS = 16; + // Enables subgroup ops for preventing the re-calculation of indices. + uint32_t use_collectives = 0; + // CRS block size should be capped at sugroup size for correctness when shuffle is used. + if(device->subgroup_shuffle){ + use_collectives = 1; + conv2d_BS_CRS = std::min(device->subgroup_size, conv2d_BS_CRS); + } uint32_t conv2d_BS_NPQ = 128; uint32_t conv2d_TS_K = 8; uint32_t conv2d_shmem_req = (conv2d_BS_K*(conv2d_BS_CRS+1) + conv2d_BS_CRS*(conv2d_BS_NPQ+1))*sizeof(float); if(device->properties.limits.maxComputeSharedMemorySize < conv2d_shmem_req){ conv2d_BS_CRS = 8; - conv2d_TS_K = 8; - } - ggml_vk_create_pipeline(device, device->pipeline_conv2d_f32, "conv2d_f32", conv2d_f32_len, conv2d_f32_data, "main", 3, sizeof(vk_op_conv2d_push_constants), {conv2d_BS_K, conv2d_BS_NPQ, 1}, {conv2d_WG_SIZE, conv2d_BS_K, conv2d_BS_CRS, conv2d_BS_NPQ, conv2d_TS_K}, 1); + if(device->subgroup_shuffle){ + conv2d_BS_CRS = std::min(device->subgroup_size, conv2d_BS_CRS); + } + } + if(device->subgroup_shuffle){ + ggml_vk_create_pipeline(device, device->pipeline_conv2d_f32, "conv2d_f32", conv2d_f32_len, conv2d_f32_data, "main", 3, sizeof(vk_op_conv2d_push_constants), {conv2d_BS_K, conv2d_BS_NPQ, 1}, {conv2d_WG_SIZE, conv2d_BS_K, conv2d_BS_CRS, conv2d_BS_NPQ, conv2d_TS_K, use_collectives}, 1, true, true); + }else{ + ggml_vk_create_pipeline(device, device->pipeline_conv2d_f32, "conv2d_f32", conv2d_f32_len, conv2d_f32_data, "main", 3, sizeof(vk_op_conv2d_push_constants), {conv2d_BS_K, conv2d_BS_NPQ, 1}, {conv2d_WG_SIZE, conv2d_BS_K, conv2d_BS_CRS, conv2d_BS_NPQ, conv2d_TS_K, use_collectives}, 1, true); + } ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_whcn_f32, "conv2d_dw_whcn_f32", conv2d_dw_whcn_f32_len, conv2d_dw_whcn_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_cwhn_f32, "conv2d_dw_cwhn_f32", conv2d_dw_cwhn_f32_len, conv2d_dw_cwhn_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1); @@ -6895,11 +6909,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const } return nullptr; case GGML_OP_CONV_2D: - if (src0->type == GGML_TYPE_F32 && - src1->type == GGML_TYPE_F32 && - dst->type == GGML_TYPE_F32 && - ggml_is_contiguous(src0) && - ggml_is_contiguous(src1) && + if (src0->type == GGML_TYPE_F32 && + src1->type == GGML_TYPE_F32 && + dst->type == GGML_TYPE_F32 && + ggml_is_contiguous(src0) && + ggml_is_contiguous(src1) && ggml_is_contiguous(dst)) { return ctx->device->pipeline_conv2d_f32; } @@ -7231,7 +7245,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co // src0 - kernel: [KW, KH, Cin, Cout] // src1 - input: [W, H, Cin, N] // dst - result: [OW, OH, Cout, N] - + // Copied from ggml.c: int64_t ggml_calc_conv_output_size(int64_t ins, int64_t ks, int s, int p, int d) auto calc_conv_output_size = [](int64_t ins, int64_t ks, int s, int p, int d) -> int64_t { return (ins + 2 * p - d * (ks - 1) - 1) / s + 1; @@ -7246,9 +7260,9 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co int64_t OH = calc_conv_output_size(H, KH, dst->op_params[1], dst->op_params[3], dst->op_params[5]); int64_t OW = calc_conv_output_size(W, KW, dst->op_params[0], dst->op_params[2], dst->op_params[4]); int64_t NPQ = N*OW*OH; - + // Tile output matrix to (K/NB_K, NPQ/NB_NPQ, 1) workgroups - elements = {static_cast(Cout), static_cast(NPQ), 1}; + elements = {static_cast(Cout), static_cast(NPQ), 1}; } break; case GGML_OP_ADD: case GGML_OP_SUB: @@ -8131,14 +8145,14 @@ static void ggml_vk_conv_2d(ggml_backend_vk_context * ctx, vk_context& subctx, c p.Cout = static_cast(ne03); p.Cin = static_cast(ne02); p.N = static_cast(ne13); - + p.KW = static_cast(ne00); p.KH = static_cast(ne01); p.W = static_cast(ne10); p.H = static_cast(ne11); p.OW = static_cast(ne0); p.OH = static_cast(ne1); - + p.s0 = static_cast(dst->op_params[0]); p.s1 = static_cast(dst->op_params[1]); p.p0 = static_cast(dst->op_params[2]); @@ -8162,7 +8176,7 @@ static void ggml_vk_conv_2d(ggml_backend_vk_context * ctx, vk_context& subctx, c GGML_ASSERT(ne02 == ne12); ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_CONV_2D, std::move(p), dryrun); - + } static void ggml_vk_conv_2d_dw(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { @@ -10805,11 +10819,11 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32; case GGML_OP_CONV_2D: // Channel-contiguous format is not supported yet. - return (op->src[0]->type == GGML_TYPE_F32 && - op->src[1]->type == GGML_TYPE_F32 && - op->type == GGML_TYPE_F32 && - ggml_is_contiguous(op->src[0]) && - ggml_is_contiguous(op->src[1]) && + return (op->src[0]->type == GGML_TYPE_F32 && + op->src[1]->type == GGML_TYPE_F32 && + op->type == GGML_TYPE_F32 && + ggml_is_contiguous(op->src[0]) && + ggml_is_contiguous(op->src[1]) && ggml_is_contiguous(op)); default: return false; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp b/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp index 3980040a3f2cb..2e2fb069410e5 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp @@ -1,7 +1,5 @@ #version 450 -#define USE_COLLECTIVES - #ifdef USE_COLLECTIVES #extension GL_KHR_shader_subgroup_shuffle: enable #endif @@ -12,7 +10,7 @@ #define SHMEM_PAD 0 // shape notation: [dim(N), ..., dim(0)] -- stride(dim(j)) >= stride(dim(i)) if i > j -layout (binding = 0) readonly buffer A {A_TYPE knl_data[];}; // src0 - kernel: [KW, KH, Cin, Cout] +layout (binding = 0) readonly buffer A {A_TYPE knl_data[];}; // src0 - kernel: [KW, KH, Cin, Cout] layout (binding = 1) readonly buffer B {B_TYPE src_data[];}; // src1 - input: [W, H, Cin, N] -- channel_first format layout (binding = 2) writeonly buffer D {D_TYPE dst_data[];}; // dst - result: [OW, OH, Cout, N] @@ -21,7 +19,7 @@ layout (push_constant) uniform parameter { uint32_t Cout; uint32_t Cin; uint32_t N; - + // Tensor spatial sizes: kernel, input, output uint32_t KW; uint32_t KH; @@ -59,6 +57,7 @@ layout(constant_id = 2) const uint BS_CRS = 16; layout(constant_id = 3) const uint BS_NPQ = 128; // Thread-tile sizes layout(constant_id = 4) const uint TS_K = 8; +layout(constant_id = 5) const uint use_collectives = 1; uint32_t tid = gl_LocalInvocationID.x; const uint32_t WG_SIZE = gl_WorkGroupSize.x; @@ -122,31 +121,48 @@ uint32_t Br = tid / BS_NPQ; uint32_t Bc = tid % BS_NPQ; const uint32_t BrpWg = WG_SIZE / BS_NPQ; -void main(){\ +void main(){ for(uint32_t T_ly = 0; T_ly < TS_K; T_ly++){ for(uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++){ regC[T_ly][T_lx] = 0.0; } } - /* Advance block in CRS dim */\ + /* Advance block in CRS dim */ for(uint32_t B_idx_CRS = 0; B_idx_CRS < NB_CRS; B_idx_CRS++){ + uint32_t CRS_idx_a; + uint32_t Cin_idx_a; + uint32_t KH_idx_a; + uint32_t KW_idx_a; + #ifdef USE_COLLECTIVES - uint32_t cached_CRS_idx = B_idx_CRS*BS_CRS + gl_SubgroupInvocationID; - uint32_t cached_Cin_idx = cached_CRS_idx / (p.KW*p.KH); - uint32_t cached_CRS_remainder = (cached_CRS_idx - cached_Cin_idx*p.KW*p.KH); - uint32_t cached_KH_idx = cached_CRS_remainder / p.KW; - uint32_t cached_KW_idx = cached_CRS_remainder - cached_KH_idx*p.KW; - - uint32_t CRS_idx_a = subgroupShuffle(cached_CRS_idx, Ac); - uint32_t Cin_idx_a = subgroupShuffle(cached_Cin_idx, Ac); - uint32_t KH_idx_a = subgroupShuffle(cached_KH_idx, Ac); - uint32_t KW_idx_a = subgroupShuffle(cached_KW_idx, Ac); + uint32_t cached_CRS_idx; + uint32_t cached_Cin_idx; + uint32_t cached_KH_idx; + uint32_t cached_KW_idx; + if(use_collectives == 1){ + cached_CRS_idx = B_idx_CRS*BS_CRS + gl_SubgroupInvocationID; + cached_Cin_idx = cached_CRS_idx / (p.KW*p.KH); + uint32_t cached_CRS_remainder = (cached_CRS_idx - cached_Cin_idx*p.KW*p.KH); + cached_KH_idx = cached_CRS_remainder / p.KW; + cached_KW_idx = cached_CRS_remainder - cached_KH_idx*p.KW; + + CRS_idx_a = subgroupShuffle(cached_CRS_idx, Ac); + Cin_idx_a = subgroupShuffle(cached_Cin_idx, Ac); + KH_idx_a = subgroupShuffle(cached_KH_idx, Ac); + KW_idx_a = subgroupShuffle(cached_KW_idx, Ac); + }else{ + CRS_idx_a = B_idx_CRS*BS_CRS + Ac; // Global CRS_idx_a (column index of A) + Cin_idx_a = CRS_idx_a / (p.KW*p.KH); + uint32_t CRS_remainder = CRS_idx_a - Cin_idx_a*p.KW*p.KH; + KH_idx_a = CRS_remainder / p.KW; + KW_idx_a = CRS_remainder - KH_idx_a*p.KW; + } #else - uint32_t CRS_idx_a = B_idx_CRS*BS_CRS + Ac; // Global CRS_idx_a (column index of A) - uint32_t Cin_idx_a = CRS_idx_a / (p.KW*p.KH); - uint32_t CRS_remainder = CRS_idx_a - Cin_idx_a*p.KW*p.KH; - uint32_t KH_idx_a = CRS_remainder / p.KW; - uint32_t KW_idx_a = CRS_remainder - KH_idx_a*p.KW; + CRS_idx_a = B_idx_CRS*BS_CRS + Ac; // Global CRS_idx_a (column index of A) + Cin_idx_a = CRS_idx_a / (p.KW*p.KH); + CRS_remainder = CRS_idx_a - Cin_idx_a*p.KW*p.KH; + KH_idx_a = CRS_remainder / p.KW; + KW_idx_a = CRS_remainder - KH_idx_a*p.KW; #endif /* Load kernel to A_block: (BS_K x BS_CRS)*/ @@ -170,20 +186,32 @@ void main(){\ uint32_t NPQ_remainder = NPQ_idx - N_idx*p.OH*p.OW; uint32_t OH_idx = NPQ_remainder / p.OW; uint32_t OW_idx = NPQ_remainder - OH_idx*p.OW; - + + uint32_t CRS_idx_b; + uint32_t Cin_idx_b; + uint32_t KH_idx_b; + uint32_t KW_idx_b; #ifdef USE_COLLECTIVES - uint32_t CRS_idx_b = subgroupShuffle(cached_CRS_idx, r_offset + Br); - uint32_t Cin_idx_b = subgroupShuffle(cached_Cin_idx, r_offset + Br); - uint32_t KH_idx_b = subgroupShuffle(cached_KH_idx, r_offset + Br); - uint32_t KW_idx_b = subgroupShuffle(cached_KW_idx, r_offset + Br); + if(use_collectives == 1){ + CRS_idx_b = subgroupShuffle(cached_CRS_idx, r_offset + Br); + Cin_idx_b = subgroupShuffle(cached_Cin_idx, r_offset + Br); + KH_idx_b = subgroupShuffle(cached_KH_idx, r_offset + Br); + KW_idx_b = subgroupShuffle(cached_KW_idx, r_offset + Br); + }else{ + CRS_idx_b = B_idx_CRS*BS_CRS + B_ly; /* Global CRS index (row index of B) */ + Cin_idx_b = CRS_idx_b / (p.KW*p.KH); + uint32_t CRS_remainder = CRS_idx_b - Cin_idx_b*p.KW*p.KH; + KH_idx_b = CRS_remainder / p.KW; + KW_idx_b = CRS_remainder - KH_idx_b*p.KW; + } #else - uint32_t CRS_idx_b = B_idx_CRS*BS_CRS + B_ly; /* Global CRS index (row index of B) */ - uint32_t Cin_idx_b = CRS_idx_b / (p.KW*p.KH); + CRS_idx_b = B_idx_CRS*BS_CRS + B_ly; /* Global CRS index (row index of B) */ + Cin_idx_b = CRS_idx_b / (p.KW*p.KH); uint32_t CRS_remainder = CRS_idx_b - Cin_idx_b*p.KW*p.KH; - uint32_t KH_idx_b = CRS_remainder / p.KW; - uint32_t KW_idx_b = CRS_remainder - KH_idx_b*p.KW; + KH_idx_b = CRS_remainder / p.KW; + KW_idx_b = CRS_remainder - KH_idx_b*p.KW; #endif - + uint32_t H_idx = OH_idx*p.s1 + KH_idx_b*p.d1 - p.p1; uint32_t W_idx = OW_idx*p.s0 + KW_idx_b*p.d0 - p.p0; uint32_t src_idx = min(max(W_idx + H_idx*p.nb11 + Cin_idx_b*p.nb12 + N_idx*p.nb13, 0), p.Cin*p.N*p.W*p.H-1); @@ -223,4 +251,4 @@ void main(){\ } } } -} \ No newline at end of file +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index e6432c356fb80..30942fd00ebb4 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -650,7 +650,7 @@ void process_shaders() { string_to_spv("opt_step_adamw_f32", "opt_step_adamw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}})); - string_to_spv("conv2d_f32", "conv2d_mm.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("conv2d_f32", "conv2d_mm.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"USE_COLLECTIVES", "1"}}); string_to_spv("conv2d_dw_whcn_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"WHCN", "1"}})); string_to_spv("conv2d_dw_cwhn_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"CWHN", "1"}})); diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 57ab18ad660e5..c35ce3e7ef161 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -37,6 +37,7 @@ #include #include #include +#include static void init_tensor_uniform(ggml_tensor * tensor, float min = -1.0f, float max = 1.0f) { size_t nels = ggml_nelements(tensor); @@ -1266,12 +1267,12 @@ struct test_case { // duplicate the op for (int i = 1; i < n_runs; i++) { ggml_graph_add_node(gf, out); - + if(op_desc(out) == "CONV_2D_INDIRECT_IMPL"){ /* - TODO: add a permanent solution! E.g. return the list of tensors + TODO: add a permanent solution! E.g. return the list of tensors needed to add for computing the op in build_graph(). - + Adds the full ggml_conv_2d() computation graph, not just the output! * cont (out) * cont (out->src[0]) @@ -1630,13 +1631,13 @@ struct test_case { } }; -// This can be useful to compare the output/performance of +// This can be useful to compare the output/performance of // different graphs implementing the same op. // Possible use cases: -// * no CPU implementation exists for the op, but the op +// * no CPU implementation exists for the op, but the op // can be built by combining elementary ops already having implementation // and the user wants to compare the results. -// * comparing the performance of different implementations +// * comparing the performance of different implementations // of the op: graph revwriting/operation fusion. E.g. basic attention // compared to flash attention or conv compared with im2col->matmul. struct test_case_ref : public test_case { @@ -1652,7 +1653,7 @@ struct test_case_ref : public test_case { {"input", "input"}, {"kernel", "kernel"} }; - + // Copies the inputs of the graph built using build_graph() to the reference graph virtual void copy_data_to_ref(ggml_context * ctx, ggml_context * ctx_ref){ std::map inputs; @@ -1665,7 +1666,7 @@ struct test_case_ref : public test_case { } } } - + for (ggml_tensor * t = ggml_get_first_tensor(ctx_ref); t != nullptr; t = ggml_get_next_tensor(ctx_ref, t)) { for(auto e : input_names){ if(e.second == t->name){ @@ -1679,7 +1680,7 @@ struct test_case_ref : public test_case { GGML_ASSERT(inputs_ref.count(e.second) == 1); std::vector buf(ggml_nbytes(inputs[e.first])); ggml_backend_tensor_get(inputs[e.first], buf.data(), 0, ggml_nbytes(inputs[e.first])); - ggml_backend_tensor_set(inputs_ref[e.second], buf.data(), 0, buf.size()); + ggml_backend_tensor_set(inputs_ref[e.second], buf.data(), 0, buf.size()); } } @@ -1860,9 +1861,9 @@ struct test_case_ref : public test_case { GGML_UNUSED(index); }; - + const bool cmp_ok = ggml_backend_compare_graph_backend_node(backend1, backend2, gf, gf_ref, callback, &ud, const_cast(output_node_name), const_cast(output_node_name_ref)); - + if (!cmp_ok) { printf("compare failed "); } @@ -1884,7 +1885,7 @@ struct test_case_ref : public test_case { } return test_passed; - } + } }; // ################################### @@ -3984,13 +3985,13 @@ struct test_conv_2d : public test_case_ref { const bool cwhn; // If true, the direct CONV_2D will be used in the graph, otherwise it // uses ggml_conv_2d: - // * if the program is called with -o CONV_2D_DIRECT_IMPL, the + // * if the program is called with -o CONV_2D_DIRECT_IMPL, the // CONV_2D graph will be built, while // * if the program is called with -o CONV_2D_INDIRECT_IMPL, the // IM2COL -> MUL_MM graph will be built. const bool direct_impl; - virtual std::string op_desc(ggml_tensor * t) { + virtual std::string op_desc(ggml_tensor * t) override { (void) t; if(direct_impl){ return std::string("CONV_2D_DIRECT_IMPL"); @@ -4026,7 +4027,7 @@ struct test_conv_2d : public test_case_ref { int64_t K = Cout; int64_t CRS = Cin*KH*KW; int64_t NPQ = N*OH*OW; - + return K*NPQ*(2*CRS-1); } @@ -4062,7 +4063,7 @@ struct test_conv_2d : public test_case_ref { stride0, stride1, padding0, padding1, dilation0, dilation1); ggml_tensor * out = ggml_cont(ctx, conv2d_out); - + ggml_set_name(out, "out"); return out; } @@ -5475,7 +5476,7 @@ static std::vector> make_test_cases_eval() { for(auto act_case : cases){ test_cases.emplace_back(new test_conv_2d({act_case[iwh_idx], act_case[iwh_idx], act_case[Cin_idx], act_case[B_idx]}, {act_case[kwh_idx], act_case[kwh_idx], act_case[Cin_idx], act_case[Cout_idx]}, 1, 1, 0, 0, 1, 1, false, true)); - } + } // sycl backend will limit task global_range < MAX_INT // test cases for 2D im2col with large input W and H (occurs in stable-diffusion) From a09e8f59848d4d2fe09ecf18cf55602258a10b78 Mon Sep 17 00:00:00 2001 From: Ervin Tasnadi Date: Mon, 14 Jul 2025 13:22:32 +0200 Subject: [PATCH 4/9] Collectives disabled by default. --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 8896664630e80..2eb7415c585e9 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -3041,7 +3041,7 @@ static void ggml_vk_load_shaders(vk_device& device) { // Enables subgroup ops for preventing the re-calculation of indices. uint32_t use_collectives = 0; // CRS block size should be capped at sugroup size for correctness when shuffle is used. - if(device->subgroup_shuffle){ + if(getenv("GGML_VK_USE_COLLECTIVES") != nullptr && device->subgroup_shuffle){ use_collectives = 1; conv2d_BS_CRS = std::min(device->subgroup_size, conv2d_BS_CRS); } @@ -3050,10 +3050,13 @@ static void ggml_vk_load_shaders(vk_device& device) { uint32_t conv2d_shmem_req = (conv2d_BS_K*(conv2d_BS_CRS+1) + conv2d_BS_CRS*(conv2d_BS_NPQ+1))*sizeof(float); if(device->properties.limits.maxComputeSharedMemorySize < conv2d_shmem_req){ conv2d_BS_CRS = 8; - if(device->subgroup_shuffle){ + if(getenv("GGML_VK_USE_COLLECTIVES") != nullptr && device->subgroup_shuffle){ conv2d_BS_CRS = std::min(device->subgroup_size, conv2d_BS_CRS); } } + + std::cerr << " --> BS_CRS=" << conv2d_BS_CRS << " use_collectives=" << use_collectives << std::endl; + if(device->subgroup_shuffle){ ggml_vk_create_pipeline(device, device->pipeline_conv2d_f32, "conv2d_f32", conv2d_f32_len, conv2d_f32_data, "main", 3, sizeof(vk_op_conv2d_push_constants), {conv2d_BS_K, conv2d_BS_NPQ, 1}, {conv2d_WG_SIZE, conv2d_BS_K, conv2d_BS_CRS, conv2d_BS_NPQ, conv2d_TS_K, use_collectives}, 1, true, true); }else{ From aa472583cd70c8c3c647672cf3ede656c768481c Mon Sep 17 00:00:00 2001 From: Ervin Tasnadi Date: Tue, 15 Jul 2025 14:49:46 +0200 Subject: [PATCH 5/9] Intel support is disabled as the performance is poor. --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 38 +++++++++++++++------------- 1 file changed, 20 insertions(+), 18 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 2eb7415c585e9..e354708442564 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -3038,29 +3038,25 @@ static void ggml_vk_load_shaders(vk_device& device) { uint32_t conv2d_WG_SIZE = 256; uint32_t conv2d_BS_K = 128; uint32_t conv2d_BS_CRS = 16; - // Enables subgroup ops for preventing the re-calculation of indices. - uint32_t use_collectives = 0; - // CRS block size should be capped at sugroup size for correctness when shuffle is used. - if(getenv("GGML_VK_USE_COLLECTIVES") != nullptr && device->subgroup_shuffle){ + uint32_t use_collectives = 0; // Enables subgroup ops for preventing the re-calculation of indices. + if(device->subgroup_shuffle){ use_collectives = 1; - conv2d_BS_CRS = std::min(device->subgroup_size, conv2d_BS_CRS); + conv2d_BS_CRS = std::min(device->subgroup_size, conv2d_BS_CRS); // CRS block size should be capped at sugroup size for correctness when shuffle is used. } uint32_t conv2d_BS_NPQ = 128; uint32_t conv2d_TS_K = 8; uint32_t conv2d_shmem_req = (conv2d_BS_K*(conv2d_BS_CRS+1) + conv2d_BS_CRS*(conv2d_BS_NPQ+1))*sizeof(float); if(device->properties.limits.maxComputeSharedMemorySize < conv2d_shmem_req){ conv2d_BS_CRS = 8; - if(getenv("GGML_VK_USE_COLLECTIVES") != nullptr && device->subgroup_shuffle){ + if(device->subgroup_shuffle){ conv2d_BS_CRS = std::min(device->subgroup_size, conv2d_BS_CRS); } } - - std::cerr << " --> BS_CRS=" << conv2d_BS_CRS << " use_collectives=" << use_collectives << std::endl; - if(device->subgroup_shuffle){ + if(use_collectives){ ggml_vk_create_pipeline(device, device->pipeline_conv2d_f32, "conv2d_f32", conv2d_f32_len, conv2d_f32_data, "main", 3, sizeof(vk_op_conv2d_push_constants), {conv2d_BS_K, conv2d_BS_NPQ, 1}, {conv2d_WG_SIZE, conv2d_BS_K, conv2d_BS_CRS, conv2d_BS_NPQ, conv2d_TS_K, use_collectives}, 1, true, true); }else{ - ggml_vk_create_pipeline(device, device->pipeline_conv2d_f32, "conv2d_f32", conv2d_f32_len, conv2d_f32_data, "main", 3, sizeof(vk_op_conv2d_push_constants), {conv2d_BS_K, conv2d_BS_NPQ, 1}, {conv2d_WG_SIZE, conv2d_BS_K, conv2d_BS_CRS, conv2d_BS_NPQ, conv2d_TS_K, use_collectives}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_conv2d_f32, "conv2d_f32", conv2d_f32_len, conv2d_f32_data, "main", 3, sizeof(vk_op_conv2d_push_constants), {conv2d_BS_K, conv2d_BS_NPQ, 1}, {conv2d_WG_SIZE, conv2d_BS_K, conv2d_BS_CRS, conv2d_BS_NPQ, conv2d_TS_K, use_collectives}, 1, true, false); } ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_whcn_f32, "conv2d_dw_whcn_f32", conv2d_dw_whcn_f32_len, conv2d_dw_whcn_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1); @@ -10820,14 +10816,20 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm return true; case GGML_OP_CONV_TRANSPOSE_1D: return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32; - case GGML_OP_CONV_2D: - // Channel-contiguous format is not supported yet. - return (op->src[0]->type == GGML_TYPE_F32 && - op->src[1]->type == GGML_TYPE_F32 && - op->type == GGML_TYPE_F32 && - ggml_is_contiguous(op->src[0]) && - ggml_is_contiguous(op->src[1]) && - ggml_is_contiguous(op)); + case GGML_OP_CONV_2D: + { + // Op is disabled for Intel + ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; + const vk_device& device = ggml_vk_get_device(ctx->device); + bool is_Intel = ggml_vk_get_device(ctx->device)->vendor_id == VK_VENDOR_ID_INTEL; + // Channel-contiguous format is not supported yet. + return (op->src[0]->type == GGML_TYPE_F32 && + op->src[1]->type == GGML_TYPE_F32 && + op->type == GGML_TYPE_F32 && + ggml_is_contiguous(op->src[0]) && + ggml_is_contiguous(op->src[1]) && + ggml_is_contiguous(op)) && !is_Intel; + } default: return false; } From a672803dba6769142f182498b1352d430af98e1f Mon Sep 17 00:00:00 2001 From: Ervin Tasnadi Date: Wed, 16 Jul 2025 12:05:44 +0200 Subject: [PATCH 6/9] Conv2d enabled for Intel with disabled collectives, disabled for Apple --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index e354708442564..3d808434affd5 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -3039,7 +3039,7 @@ static void ggml_vk_load_shaders(vk_device& device) { uint32_t conv2d_BS_K = 128; uint32_t conv2d_BS_CRS = 16; uint32_t use_collectives = 0; // Enables subgroup ops for preventing the re-calculation of indices. - if(device->subgroup_shuffle){ + if(device->subgroup_shuffle && device->vendor_id != VK_VENDOR_ID_INTEL){ // Do not enable collectives on Intel, see PR 14316 use_collectives = 1; conv2d_BS_CRS = std::min(device->subgroup_size, conv2d_BS_CRS); // CRS block size should be capped at sugroup size for correctness when shuffle is used. } @@ -3048,7 +3048,7 @@ static void ggml_vk_load_shaders(vk_device& device) { uint32_t conv2d_shmem_req = (conv2d_BS_K*(conv2d_BS_CRS+1) + conv2d_BS_CRS*(conv2d_BS_NPQ+1))*sizeof(float); if(device->properties.limits.maxComputeSharedMemorySize < conv2d_shmem_req){ conv2d_BS_CRS = 8; - if(device->subgroup_shuffle){ + if(use_collectives){ conv2d_BS_CRS = std::min(device->subgroup_size, conv2d_BS_CRS); } } @@ -10816,19 +10816,19 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm return true; case GGML_OP_CONV_TRANSPOSE_1D: return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32; - case GGML_OP_CONV_2D: + case GGML_OP_CONV_2D: { - // Op is disabled for Intel + // Op is disabled for Apple because it segfaults at pipeline create time on MoltenVK ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; - const vk_device& device = ggml_vk_get_device(ctx->device); - bool is_Intel = ggml_vk_get_device(ctx->device)->vendor_id == VK_VENDOR_ID_INTEL; + const vk_device& device = ggml_vk_get_device(ctx->device); + bool is_Apple = ggml_vk_get_device(ctx->device)->vendor_id == VK_VENDOR_ID_APPLE; // Channel-contiguous format is not supported yet. return (op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32 && ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1]) && - ggml_is_contiguous(op)) && !is_Intel; + ggml_is_contiguous(op)) && !is_Apple; } default: return false; From e50711f1d6e798ade602e6aea345d13d643dc802 Mon Sep 17 00:00:00 2001 From: Ervin Tasnadi Date: Wed, 16 Jul 2025 22:05:12 +0200 Subject: [PATCH 7/9] test-backend-ops modifications are reverted --- ggml/include/ggml-backend.h | 2 - ggml/src/ggml-backend.cpp | 49 ---- tests/test-backend-ops.cpp | 438 ++++-------------------------------- 3 files changed, 50 insertions(+), 439 deletions(-) diff --git a/ggml/include/ggml-backend.h b/ggml/include/ggml-backend.h index 1bd7520178bce..a2977ea2e56d9 100644 --- a/ggml/include/ggml-backend.h +++ b/ggml/include/ggml-backend.h @@ -340,8 +340,6 @@ extern "C" { // Compare the output of two backends GGML_API bool ggml_backend_compare_graph_backend(ggml_backend_t backend1, ggml_backend_t backend2, struct ggml_cgraph * graph, ggml_backend_eval_callback callback, void * user_data, struct ggml_tensor * test_node); - // Compare the output of two backends, graphs can be different and only the selected nodes will be compared - GGML_API bool ggml_backend_compare_graph_backend_node(ggml_backend_t backend1, ggml_backend_t backend2, struct ggml_cgraph * graph1, struct ggml_cgraph * graph2, ggml_backend_eval_callback callback, void * user_data, char* op_name_out_1, char* op_name_out_2); // Tensor initialization GGML_API enum ggml_status ggml_backend_tensor_alloc(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, void * addr); diff --git a/ggml/src/ggml-backend.cpp b/ggml/src/ggml-backend.cpp index 901da3711c9b6..788861a365fab 100644 --- a/ggml/src/ggml-backend.cpp +++ b/ggml/src/ggml-backend.cpp @@ -1882,55 +1882,6 @@ bool ggml_backend_compare_graph_backend(ggml_backend_t backend1, ggml_backend_t return true; } -bool ggml_backend_compare_graph_backend_node( - ggml_backend_t backend1, - ggml_backend_t backend2, - struct ggml_cgraph * graph1, - struct ggml_cgraph * graph2, - ggml_backend_eval_callback callback, void * user_data, char* op_name_out_1, char* op_name_out_2) { - - ggml_tensor * out1 = NULL; - ggml_tensor * out2 = NULL; - - struct ggml_cgraph * g1 = graph1; - struct ggml_cgraph * g2 = graph2; - - for (int i = 0; i < g1->n_nodes; i++) { - struct ggml_tensor * t1 = g1->nodes[i]; - struct ggml_cgraph g1v = ggml_graph_view(g1, i, i + 1); - ggml_backend_graph_compute(backend1, &g1v); - if (ggml_is_view_op(t1->op)) { - continue; - } - if(strcmp(t1 -> name, op_name_out_1) == 0){ - out1 = t1; - } - } - - for (int i = 0; i < g2->n_nodes; i++) { - struct ggml_tensor * t2 = g2->nodes[i]; - struct ggml_cgraph g2v = ggml_graph_view(g2, i, i + 1); - ggml_backend_graph_compute(backend2, &g2v); - if (ggml_is_view_op(t2->op)) { - continue; - } - if(strcmp(t2 -> name, op_name_out_2) == 0){ - out2 = t2; - } - } - - assert(out1 != NULL); - assert(out2 != NULL); - assert(ggml_are_same_layout(out1, out2)); - - // compare results, calculate rms etc - if (!callback(0, out1, out2, user_data)) { - return false; - } - - return true; -} - // CPU backend - buffer static void * ggml_backend_cpu_buffer_get_base(ggml_backend_buffer_t buffer) { diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index c35ce3e7ef161..0853e482031e9 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -37,7 +37,6 @@ #include #include #include -#include static void init_tensor_uniform(ggml_tensor * tensor, float min = -1.0f, float max = 1.0f) { size_t nels = ggml_nelements(tensor); @@ -1021,7 +1020,7 @@ struct test_case { return t; } - virtual bool eval(ggml_backend_t backend1, ggml_backend_t backend2, const char * op_name, printer * output_printer) { + bool eval(ggml_backend_t backend1, ggml_backend_t backend2, const char * op_name, printer * output_printer) { mode = MODE_TEST; ggml_init_params params = { @@ -1241,56 +1240,25 @@ struct test_case { // determine number of runs int n_runs; bool is_cpu = ggml_backend_dev_type(ggml_backend_get_device(backend)) == GGML_BACKEND_DEVICE_TYPE_CPU; - - // how many nodes are added by each op - uint32_t nodes_per_op = 1; - if(op_desc(out) == "CONV_2D_INDIRECT_IMPL"){ - nodes_per_op = 8; - } - if (op_flops(out) > 0) { // based on flops const uint64_t GFLOP = 1000 * 1000 * 1000; const uint64_t target_flops_cpu = 8ULL * GFLOP; const uint64_t target_flops_gpu = 100ULL * GFLOP; uint64_t target_flops = is_cpu ? target_flops_cpu : target_flops_gpu; - n_runs = std::min((ggml_graph_size(gf) - ggml_graph_n_nodes(gf))/nodes_per_op, target_flops / op_flops(out)) + 1; + n_runs = std::min(ggml_graph_size(gf) - ggml_graph_n_nodes(gf), target_flops / op_flops(out)) + 1; } else { // based on memory size const size_t GB = 1ULL << 30; const size_t target_size_cpu = 8 * GB; const size_t target_size_gpu = 32 * GB; size_t target_size = is_cpu ? target_size_cpu : target_size_gpu; - n_runs = std::min((ggml_graph_size(gf) - ggml_graph_n_nodes(gf))/nodes_per_op, target_size / op_size(out)) + 1; + n_runs = std::min(ggml_graph_size(gf) - ggml_graph_n_nodes(gf), target_size / op_size(out)) + 1; } // duplicate the op for (int i = 1; i < n_runs; i++) { ggml_graph_add_node(gf, out); - - if(op_desc(out) == "CONV_2D_INDIRECT_IMPL"){ - /* - TODO: add a permanent solution! E.g. return the list of tensors - needed to add for computing the op in build_graph(). - - Adds the full ggml_conv_2d() computation graph, not just the output! - * cont (out) - * cont (out->src[0]) - * permute (out->src[0]->...) - * reshape - * mul_mat - * reshape - * im2col - * reshape - */ - ggml_graph_add_node(gf, out->src[0]); // cont - ggml_graph_add_node(gf, out->src[0]->src[0]); // permute - ggml_graph_add_node(gf, out->src[0]->src[0]->src[0]); // reshape - ggml_graph_add_node(gf, out->src[0]->src[0]->src[0]->src[0]); // mul_mat - ggml_graph_add_node(gf, out->src[0]->src[0]->src[0]->src[0]->src[0]); // reshape - ggml_graph_add_node(gf, out->src[0]->src[0]->src[0]->src[0]->src[0]->src[0]); // im2col - ggml_graph_add_node(gf, out->src[0]->src[0]->src[0]->src[0]->src[1]); // reshape - } } // calculate memory @@ -1631,262 +1599,6 @@ struct test_case { } }; -// This can be useful to compare the output/performance of -// different graphs implementing the same op. -// Possible use cases: -// * no CPU implementation exists for the op, but the op -// can be built by combining elementary ops already having implementation -// and the user wants to compare the results. -// * comparing the performance of different implementations -// of the op: graph revwriting/operation fusion. E.g. basic attention -// compared to flash attention or conv compared with im2col->matmul. -struct test_case_ref : public test_case { -public: - ggml_cgraph * gf_ref = nullptr; - - // Output tensor names to compare - const char* output_node_name_ref; - const char* output_node_name; - - // Input tensor names in (actual graph, reference graph) - std::vector> input_names = { - {"input", "input"}, - {"kernel", "kernel"} - }; - - // Copies the inputs of the graph built using build_graph() to the reference graph - virtual void copy_data_to_ref(ggml_context * ctx, ggml_context * ctx_ref){ - std::map inputs; - std::map inputs_ref; - - for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != nullptr; t = ggml_get_next_tensor(ctx, t)) { - for(auto e : input_names){ - if(e.first == t->name){ - inputs[e.first] = t; - } - } - } - - for (ggml_tensor * t = ggml_get_first_tensor(ctx_ref); t != nullptr; t = ggml_get_next_tensor(ctx_ref, t)) { - for(auto e : input_names){ - if(e.second == t->name){ - inputs_ref[e.second] = t; - } - } - } - - for(auto e : input_names){ - GGML_ASSERT(inputs.count(e.first) == 1); - GGML_ASSERT(inputs_ref.count(e.second) == 1); - std::vector buf(ggml_nbytes(inputs[e.first])); - ggml_backend_tensor_get(inputs[e.first], buf.data(), 0, ggml_nbytes(inputs[e.first])); - ggml_backend_tensor_set(inputs_ref[e.second], buf.data(), 0, buf.size()); - } - } - - // Graph of the reference op implementation - virtual ggml_tensor * build_graph_ref(ggml_context * ctx) = 0; - - // Compares the output of the actual graph to the output of the reference - bool eval(ggml_backend_t backend1, ggml_backend_t backend2, const char * op_name, printer * output_printer) override { - mode = MODE_TEST; - - ggml_init_params params = { - /* .mem_size = */ ggml_tensor_overhead()*128 + ggml_graph_overhead(), - /* .mem_base = */ NULL, - /* .no_alloc = */ true, - }; - ggml_context * ctx = ggml_init(params); - ggml_context * ctx_ref = ggml_init(params); - GGML_ASSERT(ctx); - GGML_ASSERT(ctx_ref); - - gf = ggml_new_graph(ctx); - gf_ref = ggml_new_graph(ctx_ref); - - // pre-graph sentinel - add_sentinel(ctx); - add_sentinel(ctx_ref); - - ggml_tensor * out = build_graph(ctx); - ggml_tensor * out_ref = build_graph_ref(ctx_ref); - - std::string current_op_name = op_desc(out); - - if (op_name != nullptr && op_desc(out) != op_name) { - //printf(" %s: skipping\n", op_desc(out).c_str()); - ggml_free(ctx); - return true; - } - - // check if the backends support the ops - bool supported = true; - ggml_backend* backend_tested = nullptr; - for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) { - if (!ggml_backend_supports_op(backend1, t)) { - supported = false; - backend_tested = backend1; - break; - } - } - - if(supported){ - for (ggml_tensor * t = ggml_get_first_tensor(ctx_ref); t != NULL; t = ggml_get_next_tensor(ctx_ref, t)) { - if (!ggml_backend_supports_op(backend2, t)) { - supported = false; - backend_tested = backend2; - break; - } - } - } - - if (!supported) { - // Create test result for unsupported operation - test_result result(ggml_backend_name(backend_tested), current_op_name, vars(), "test", - false, false, "not supported"); - if (output_printer) { - output_printer->print_test_result(result); - } - - ggml_free(ctx); - return true; - } - - // post-graph sentinel - add_sentinel(ctx); - add_sentinel(ctx_ref); - - // allocate - ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors(ctx, backend1); - - if (buf == NULL) { - printf("failed to allocate tensors [%s] ", ggml_backend_name(backend1)); - ggml_free(ctx); - return false; - } - - ggml_backend_buffer_t buf_ref = ggml_backend_alloc_ctx_tensors(ctx_ref, backend2); - if (buf_ref == NULL) { - printf("failed to allocate tensors [%s] ", ggml_backend_name(backend2)); - ggml_free(ctx_ref); - return false; - } - - // build graph - ggml_build_forward_expand(gf, out); - ggml_build_forward_expand(gf_ref, out_ref); - - // add sentinels as graph nodes so that they are checked in the callback - for (ggml_tensor * sentinel : sentinels) { - ggml_graph_add_node(gf, sentinel); - ggml_graph_add_node(gf_ref, sentinel); - } - - // randomize tensors - initialize_tensors(ctx); - copy_data_to_ref(ctx, ctx_ref); - - // compare - struct callback_userdata { - bool ok; - double max_err; - ggml_backend_t backend1; - ggml_backend_t backend2; - }; - - callback_userdata ud { - true, - max_nmse_err(), - backend1, - backend2 - }; - - auto callback = [](int index, ggml_tensor * t1, ggml_tensor * t2, void * user_data) -> bool { - callback_userdata * ud = (callback_userdata *) user_data; - const char * bn1 = ggml_backend_name(ud->backend1); - const char * bn2 = ggml_backend_name(ud->backend2); - - if (t1->op == GGML_OP_NONE) { - // sentinels must be unchanged - std::vector t1_data(ggml_nbytes(t1)); - std::vector t2_data(ggml_nbytes(t2)); - ggml_backend_tensor_get(t1, t1_data.data(), 0, ggml_nbytes(t1)); - ggml_backend_tensor_get(t2, t2_data.data(), 0, ggml_nbytes(t2)); - - if (memcmp(t1_data.data(), t2_data.data(), ggml_nbytes(t1)) != 0) { - printf("sentinel mismatch: %s ", t1->name); - ud->ok = false; - return true; - } - } - - std::vector f1 = tensor_to_float(t1); - std::vector f2 = tensor_to_float(t2); - - for (size_t i = 0; i < f1.size(); i++) { - // check for nans - if (std::isnan(f1[i]) || std::isnan(f2[i])) { - printf("[%s] NaN at index %zu (%s=%f %s=%f) ", ggml_op_desc(t1), i, bn1, f1[i], bn2, f2[i]); - ud->ok = false; - return true; - } - // check for infs: both must be inf of the same sign, or both must be finite - if (isinf_or_max(f1[i]) || isinf_or_max(f2[i])) { - if (isinf_or_max(f1[i]) && isinf_or_max(f2[i])) { - if (std::signbit(f1[i]) != std::signbit(f2[i])) { - printf("[%s] inf sign mismatch: %s=%f %s=%f ", ggml_op_desc(t1), bn1, f1[i], bn2, f2[i]); - ud->ok = false; - return true; - } - } else { - printf("[%s] inf mismatch: %s=%f %s=%f ", ggml_op_desc(t1), bn1, f1[i], bn2, f2[i]); - ud->ok = false; - return true; - } - } - } - - double err = nmse(f1.data(), f2.data(), f1.size()); - if (err > ud->max_err) { - printf("[%s] NMSE = %.9f > %.9f ", ggml_op_desc(t1), err, ud->max_err); - //for (int i = 0; i < (int) f1.size(); i++) { - // printf("%5d %9.6f %9.6f, diff = %9.6f\n", i, f1[i], f2[i], f1[i] - f2[i]); - //} - //printf("\n"); - //exit(1); - ud->ok = false; - } - return true; - - GGML_UNUSED(index); - }; - - - const bool cmp_ok = ggml_backend_compare_graph_backend_node(backend1, backend2, gf, gf_ref, callback, &ud, const_cast(output_node_name), const_cast(output_node_name_ref)); - - if (!cmp_ok) { - printf("compare failed "); - } - - ggml_backend_buffer_free(buf); - ggml_backend_buffer_free(buf_ref); - - ggml_free(ctx); - ggml_free(ctx_ref); - - // Create test result - bool test_passed = ud.ok && cmp_ok; - std::string error_msg = test_passed ? "" : (!cmp_ok ? "compare failed" : "test failed"); - test_result result(ggml_backend_name(backend1), current_op_name, vars(), "test", supported, test_passed, - error_msg); - - if (output_printer) { - output_printer->print_test_result(result); - } - - return test_passed; - } -}; // ################################### // ## Section 2: GGML Op Defintions ## @@ -3970,9 +3682,8 @@ struct test_im2col : public test_case { } }; -// Tests CONV_2D by comparing it to the IM2COL -> MUL_MM -// reference implementation. -struct test_conv_2d : public test_case_ref { +// CONV_2D +struct test_conv_2d : public test_case { const std::array ne_input; const std::array ne_kernel; const int stride0; @@ -3989,16 +3700,6 @@ struct test_conv_2d : public test_case_ref { // CONV_2D graph will be built, while // * if the program is called with -o CONV_2D_INDIRECT_IMPL, the // IM2COL -> MUL_MM graph will be built. - const bool direct_impl; - - virtual std::string op_desc(ggml_tensor * t) override { - (void) t; - if(direct_impl){ - return std::string("CONV_2D_DIRECT_IMPL"); - }else{ - return std::string("CONV_2D_INDIRECT_IMPL"); - } - } std::string vars() override { return VARS_TO_STR9(ne_input, ne_kernel, stride0, stride1, padding0, padding1, dilation0, dilation1, cwhn); @@ -4034,41 +3735,11 @@ struct test_conv_2d : public test_case_ref { test_conv_2d(std::array ne_input = {64, 64, 16, 1}, std::array ne_kernel = {3, 3, 1, 16}, - int stride0 = 1, int stride1 = 1, int padding0 = 0, int padding1 = 0, int dilation0 = 1, int dilation1 = 1, bool cwhn = false, bool direct_impl = true) - : ne_input(ne_input), ne_kernel(ne_kernel), stride0(stride0), stride1(stride1), padding0(padding0), padding1(padding1), dilation0(dilation0), dilation1(dilation1), cwhn(cwhn), direct_impl(direct_impl) { - output_node_name_ref = "out"; - output_node_name = "out"; - } - - ggml_tensor * build_graph_indirect(ggml_context * ctx) { - ggml_tensor * input = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne_input.data()); - ggml_set_name(input, "input"); - - ggml_tensor * kernel = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne_kernel.data()); - ggml_set_name(kernel, "kernel"); - - GGML_ASSERT(cwhn==false); - - if (cwhn) { - // change memory layout to channel-most-contiguous (CWHN), - // then permute it back so NE matches the original input - input = ggml_cont(ctx, ggml_permute(ctx, input, 1, 2, 0, 3)); - input = ggml_permute(ctx, input, 2, 0, 1, 3); - kernel = ggml_cont(ctx, ggml_permute(ctx, kernel, 2, 3, 1, 0)); - kernel = ggml_permute(ctx, kernel, 3, 2, 0, 1); + int stride0 = 1, int stride1 = 1, int padding0 = 0, int padding1 = 0, int dilation0 = 1, int dilation1 = 1, bool cwhn = false) + : ne_input(ne_input), ne_kernel(ne_kernel), stride0(stride0), stride1(stride1), padding0(padding0), padding1(padding1), dilation0(dilation0), dilation1(dilation1), cwhn(cwhn) { } - ggml_tensor * conv2d_out = ggml_conv_2d( - ctx, kernel, input, - stride0, stride1, padding0, padding1, dilation0, dilation1); - - ggml_tensor * out = ggml_cont(ctx, conv2d_out); - - ggml_set_name(out, "out"); - return out; - } - - ggml_tensor * build_graph_direct(ggml_context * ctx) { + ggml_tensor * build_graph(ggml_context * ctx) { ggml_tensor * input = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne_input.data()); ggml_set_name(input, "input"); @@ -4090,19 +3761,6 @@ struct test_conv_2d : public test_case_ref { ggml_set_name(out, "out"); return out; } - - ggml_tensor * build_graph(ggml_context * ctx) override { - if(direct_impl){ - return build_graph_direct(ctx); - }else{ - return build_graph_indirect(ctx); - } - } - - // Reference always uses the indirect impl. - ggml_tensor * build_graph_ref(ggml_context * ctx) override { - return build_graph_indirect(ctx); - } }; // GGML_OP_CONV_2D_DW @@ -5413,6 +5071,44 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {12, 12, 1, 2560}, {3, 3, 1, 2560}, 1, 1, 1, 1, 1, 1, true)); test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {12, 12, 2, 2560}, {3, 3, 2, 2560}, 1, 1, 1, 1, 1, 1, true)); + // Conv_2D test cases + #ifdef DETAILED_TESTS + // Probably we do not have enough time to execute these in the pipeline. + uint32_t iwh_idx = 0; + uint32_t kwh_idx = 1; + uint32_t Cout_idx = 2; + uint32_t Cin_idx = 3; + uint32_t B_idx = 4; + + std::vector> cases = { + //{IWH, KWH, Cout, Cin, B} + // K=CRS=NPQ=4096 conv2d matmul performance + {19, 4, 4096, 256, 16}, + // K=128, CRS=128, NPQ=4096 + {19, 4, 128, 8, 16}, + // K=130, CRS=128, NPQ=4096 + {19, 4, 130, 8, 16}, + // Edge case: K x CRS is small + {19, 2, 4, 4, 16}, + // A ConvNet's first layer + {224, 3, 8, 3, 1}, + // A ConvNet's first layer with 2x2 convolution, and 1 channel + {224, 2, 8, 1, 1}, + // A ConvNet's first layer with 2x2 convolution, and 1 channel, several images in the batch + {224, 2, 8, 1, 8}, + // A middle layer of a ConvNet + {58, 3, 64, 32, 1}, + // A middle layer of a ConvNet, several images in the batch + {58, 3, 64, 32, 8}, + // A deep layer of a ConvNet, several images in the batch + {16, 3, 256, 128, 8} + }; + + for(auto act_case : cases){ + test_cases.emplace_back(new test_conv_2d({act_case[iwh_idx], act_case[iwh_idx], act_case[Cin_idx], act_case[B_idx]}, {act_case[kwh_idx], act_case[kwh_idx], act_case[Cin_idx], act_case[Cout_idx]}, 1, 1, 0, 0, 1, 1, false)); + } + #endif + // CONV_2D: auto calc_conv_output_size = [](int64_t ins, int64_t ks, int s, int p, int d) -> int64_t { return (ins + 2 * p - d * (ks - 1) - 1) / s + 1; @@ -5432,9 +5128,9 @@ static std::vector> make_test_cases_eval() { for(uint32_t KH : {1, 2, 3, 11}){ for(uint32_t KW : {1, 2, 3, 11}){ for(uint32_t H : {1, 133}){ - for(uint32_t W : {1, 258}){ + for(uint32_t W : {1, 141}){ if(calc_conv_output_size(W, KW, s0, p0, d0) > 0 && calc_conv_output_size(H, KH, s1, p1, d1) > 0){ - test_cases.emplace_back(new test_conv_2d({W, H, Cin, 2}, {KW, KH, Cin, Cout}, s0, s1, p0, p1, d0, d1, false, true)); + test_cases.emplace_back(new test_conv_2d({W, H, Cin, 2}, {KW, KH, Cin, Cout}, s0, s1, p0, p1, d0, d1, false)); } } } @@ -5445,39 +5141,6 @@ static std::vector> make_test_cases_eval() { } } - uint32_t iwh_idx = 0; - uint32_t kwh_idx = 1; - uint32_t Cout_idx = 2; - uint32_t Cin_idx = 3; - uint32_t B_idx = 4; - std::vector> cases = { - //{IWH, KWH, Cout, Cin, B} - // K=CRS=NPQ=4096 conv2d matmul performance - {19, 4, 4096, 256, 16}, // --> fails - // K=128, CRS=128, NPQ=4096 - {19, 4, 128, 8, 16}, - // K=130, CRS=128, NPQ=4096 - {19, 4, 130, 8, 16}, - // Edge case: K x CRS is small - {19, 2, 4, 4, 16}, - // A ConvNet's first layer - {224, 3, 8, 3, 1}, - // A ConvNet's first layer with 2x2 convolution, and 1 channel - {224, 2, 8, 1, 1}, - // A ConvNet's first layer with 2x2 convolution, and 1 channel, several images in the batch - {224, 2, 8, 1, 8}, - // A middle layer of a ConvNet - {58, 3, 64, 32, 1}, - // A middle layer of a ConvNet, several images in the batch - {58, 3, 64, 32, 8}, - // A deep layer of a ConvNet, several images in the batch - {16, 3, 256, 128, 8} - }; - - for(auto act_case : cases){ - test_cases.emplace_back(new test_conv_2d({act_case[iwh_idx], act_case[iwh_idx], act_case[Cin_idx], act_case[B_idx]}, {act_case[kwh_idx], act_case[kwh_idx], act_case[Cin_idx], act_case[Cout_idx]}, 1, 1, 0, 0, 1, 1, false, true)); - } - // sycl backend will limit task global_range < MAX_INT // test cases for 2D im2col with large input W and H (occurs in stable-diffusion) // however these cases need to alloc more memory which may fail in some devices (Intel Arc770, etc.) @@ -6072,7 +5735,7 @@ static std::vector> make_test_cases_eval() { static std::vector> make_test_cases_perf() { std::vector> test_cases; - // Conv2d: K=CRS=NPQ=4096 matmul performance +// Conv2d: K=CRS=NPQ=4096 matmul performance uint32_t iwh_idx = 0; uint32_t kwh_idx = 1; uint32_t Cout_idx = 2; @@ -6104,9 +5767,8 @@ static std::vector> make_test_cases_perf() { for(auto act_case : cases){ // Direct CONV_2D - test_cases.emplace_back(new test_conv_2d({act_case[iwh_idx], act_case[iwh_idx], act_case[Cin_idx], act_case[B_idx]}, {act_case[kwh_idx], act_case[kwh_idx], act_case[Cin_idx], act_case[Cout_idx]}, 1, 1, 0, 0, 1, 1, false, true)); + test_cases.emplace_back(new test_conv_2d({act_case[iwh_idx], act_case[iwh_idx], act_case[Cin_idx], act_case[B_idx]}, {act_case[kwh_idx], act_case[kwh_idx], act_case[Cin_idx], act_case[Cout_idx]}, 1, 1, 0, 0, 1, 1, false)); // Indirect CONV_2D (uses im2col + sgemm) - test_cases.emplace_back(new test_conv_2d({act_case[iwh_idx], act_case[iwh_idx], act_case[Cin_idx], act_case[B_idx]}, {act_case[kwh_idx], act_case[kwh_idx], act_case[Cin_idx], act_case[Cout_idx]}, 1, 1, 0, 0, 1, 1, false, false)); } test_cases.emplace_back(new test_bin_bcast(ggml_add, GGML_TYPE_F32, {4096, 1, 1, 1}, {1, 1, 1, 1})); From f8295bfc76bb3774083cc55dd3a529a574fc68a3 Mon Sep 17 00:00:00 2001 From: Ervin Tasnadi Date: Wed, 16 Jul 2025 22:33:49 +0200 Subject: [PATCH 8/9] Trailing spaces and missing override fixed. --- tests/test-backend-ops.cpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 0853e482031e9..8c10ce7281cb7 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -3739,7 +3739,7 @@ struct test_conv_2d : public test_case { : ne_input(ne_input), ne_kernel(ne_kernel), stride0(stride0), stride1(stride1), padding0(padding0), padding1(padding1), dilation0(dilation0), dilation1(dilation1), cwhn(cwhn) { } - ggml_tensor * build_graph(ggml_context * ctx) { + ggml_tensor * build_graph(ggml_context * ctx) override { ggml_tensor * input = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne_input.data()); ggml_set_name(input, "input"); @@ -5078,7 +5078,7 @@ static std::vector> make_test_cases_eval() { uint32_t kwh_idx = 1; uint32_t Cout_idx = 2; uint32_t Cin_idx = 3; - uint32_t B_idx = 4; + uint32_t B_idx = 4; std::vector> cases = { //{IWH, KWH, Cout, Cin, B} @@ -5103,10 +5103,10 @@ static std::vector> make_test_cases_eval() { // A deep layer of a ConvNet, several images in the batch {16, 3, 256, 128, 8} }; - + for(auto act_case : cases){ test_cases.emplace_back(new test_conv_2d({act_case[iwh_idx], act_case[iwh_idx], act_case[Cin_idx], act_case[B_idx]}, {act_case[kwh_idx], act_case[kwh_idx], act_case[Cin_idx], act_case[Cout_idx]}, 1, 1, 0, 0, 1, 1, false)); - } + } #endif // CONV_2D: From 8ddaf2d5e65470f72c6231aa2209a1807f7adaf9 Mon Sep 17 00:00:00 2001 From: Ervin Tasnadi Date: Fri, 18 Jul 2025 02:50:58 +0200 Subject: [PATCH 9/9] Triggering pipeline relaunch. --- tests/test-backend-ops.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 8c10ce7281cb7..832888ee8fae0 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -5082,7 +5082,7 @@ static std::vector> make_test_cases_eval() { std::vector> cases = { //{IWH, KWH, Cout, Cin, B} - // K=CRS=NPQ=4096 conv2d matmul performance + // K=CRS=NPQ=4096 conv_2d matmul performance {19, 4, 4096, 256, 16}, // K=128, CRS=128, NPQ=4096 {19, 4, 128, 8, 16},