From a856a5665d5cb8b683065166d06015c5ce1334d8 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 18 Jul 2025 13:36:27 +0300 Subject: [PATCH] tests : add non-cont K,V FA tests ggml-ci --- tests/test-backend-ops.cpp | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index a3d68fba046cf..a0ab5b9257e8c 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -4258,26 +4258,32 @@ struct test_flash_attn_ext : public test_case { const int64_t hsk_padded = GGML_PAD(hsk, ggml_blck_size(type_KV)); const int64_t hsv_padded = GGML_PAD(hsv, ggml_blck_size(type_KV)); - auto const &create_permuted = [&](ggml_type type, int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3) -> ggml_tensor * { + auto const &create_permuted = [&](ggml_type type, int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3, bool is_view) -> ggml_tensor * { int64_t ne[4] = {ne0, ne1, ne2, ne3}; int64_t ne_perm[4]; for (int i = 0; i < 4; ++i) { ne_perm[permute[i]] = ne[i]; } - ggml_tensor * t = ggml_new_tensor_4d(ctx, type, ne_perm[0], ne_perm[1], ne_perm[2], ne_perm[3]); + ggml_tensor * t; + if (is_view) { + ggml_tensor * t0 = ggml_new_tensor_4d(ctx, type, ne_perm[0], 2*ne_perm[1], ne_perm[2], ne_perm[3]); + t = ggml_view_4d(ctx, t0, ne_perm[0], ne_perm[1], ne_perm[2], ne_perm[3], t0->nb[1], t0->nb[2], t0->nb[3], 0); + } else { + t = ggml_new_tensor_4d(ctx, type, ne_perm[0], ne_perm[1], ne_perm[2], ne_perm[3]); + } if (permute != std::array{0, 1, 2, 3}) { t = ggml_permute(ctx, t, permute[0], permute[1], permute[2], permute[3]); } return t; }; - ggml_tensor * q = create_permuted(GGML_TYPE_F32, hsk_padded, nb, nh*nr23[0], nr23[1]); + ggml_tensor * q = create_permuted(GGML_TYPE_F32, hsk_padded, nb, nh*nr23[0], nr23[1], false); ggml_set_name(q, "q"); - ggml_tensor * k = create_permuted(type_KV, hsk_padded, kv, nh, nr23[1]); + ggml_tensor * k = create_permuted(type_KV, hsk_padded, kv, nh, nr23[1], true); // the K tensor is usually a view of the K cache ggml_set_name(k, "k"); - ggml_tensor * v = create_permuted(type_KV, hsv_padded, kv, nh, nr23[1]); + ggml_tensor * v = create_permuted(type_KV, hsv_padded, kv, nh, nr23[1], true); // the V tensor is usually a view of the V cache ggml_set_name(v, "v"); ggml_tensor * m = nullptr;