Skip to content

Commit e753b9a

Browse files
netrunnereveremyoudompheng
authored andcommitted
vulkan: increase LOAD_VEC_A to 8 (IQ1/IQ2) or 4 (IQ3) (llama/14485)
Commit taken from remyoudompheng's PR ggml-org/llama.cpp#12260 Co-authored-by: Rémy Oudompheng <remyoudompheng@gmail.com>
1 parent 9d0c408 commit e753b9a

File tree

2 files changed

+83
-73
lines changed

2 files changed

+83
-73
lines changed

ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp

Lines changed: 81 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -500,10 +500,9 @@ void main() {
500500
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
501501
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
502502

503-
const uint ib = idx / 128; // 2 values per idx
504-
const uint ib32 = (idx % 128) / 16; // 0..7
505-
const uint ib8 = (idx % 128) / 4;
506-
const int i8 = 2 * int(idx % 4);
503+
const uint ib = idx / 32; // 8 values per idx
504+
const uint ib32 = (idx % 32) / 4; // 0..7
505+
const uint ib8 = idx % 32;
507506

508507
const float d = float(data_a[ib].d);
509508
const uint qh = data_a[ib].qh[ib32];
@@ -512,22 +511,16 @@ void main() {
512511
const float delta = ((qh & 0x8000) != 0) ? -IQ1S_DELTA : IQ1S_DELTA;
513512
const int16_t grid = int16_t(iq1s_grid[qs | (bitfieldExtract(qh, 3 * int(ib8 & 3), 3) << 8)]);
514513

515-
const ivec2 gvec = ivec2(
516-
bitfieldExtract(grid, 2 * (i8), 2),
517-
bitfieldExtract(grid, 2 * (i8 + 1), 2)
518-
);
519-
const vec2 v = dl * (vec2(gvec) + delta);
520-
521-
buf_a[buf_idx ] = FLOAT_TYPE(v.x);
522-
buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
514+
[[unroll]] for (int k = 0; k < 8; ++k) {
515+
buf_a[buf_idx + k] = FLOAT_TYPE(dl * (bitfieldExtract(grid, 2 * k, 2) + delta));
516+
}
523517
#elif defined(DATA_A_IQ1_M)
524518
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
525519
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
526520

527-
const uint ib = idx / 128; // 2 values per idx
528-
const uint ib8 = (idx % 128) / 4;
521+
const uint ib = idx / 32; // 8 values per idx
522+
const uint ib8 = idx % 32;
529523
const uint ib16 = ib8 / 2;
530-
const int i8 = 2 * int(idx % 4);
531524

532525
const uint16_t[4] scales = data_a[ib].scales;
533526
const u16vec4 s = u16vec4(scales[0], scales[1], scales[2], scales[3]) >> 12;
@@ -538,21 +531,17 @@ void main() {
538531
const float dl = d * (2 * bitfieldExtract(sc, 3 * int(ib16 & 3), 3) + 1);
539532
const float delta = ((qh & 8) != 0) ? -IQ1M_DELTA : IQ1M_DELTA;
540533
const int16_t grid = int16_t(iq1s_grid[qs | ((qh & 7) << 8)]);
541-
const ivec2 gvec = ivec2(
542-
bitfieldExtract(grid, 2 * (i8), 2),
543-
bitfieldExtract(grid, 2 * (i8 + 1), 2)
544-
);
545-
const vec2 v = dl * (vec2(gvec) + delta);
546534

547-
buf_a[buf_idx ] = FLOAT_TYPE(v.x);
548-
buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
535+
[[unroll]] for (int k = 0; k < 8; ++k) {
536+
buf_a[buf_idx + k] = FLOAT_TYPE(dl * (bitfieldExtract(grid, 2 * k, 2) + delta));
537+
}
549538
#elif defined(DATA_A_IQ2_XXS)
550539
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
551540
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
552541

553-
const uint ib = idx / 128; // 2 values per idx
554-
const uint ib32 = (idx % 128) / 16; // 0..7
555-
const uint ib8 = (idx / 4) % 4;
542+
const uint ib = idx / 32; // 8 values per idx
543+
const uint ib32 = (idx % 32) / 4; // 0..7
544+
const uint ib8 = idx % 4;
556545

557546
const float d = float(data_a[ib].d);
558547
const uint qs = data_a[ib].qs[8 * ib32 + ib8];
@@ -562,63 +551,81 @@ void main() {
562551
data_a[ib].qs[8*ib32 + 6],
563552
data_a[ib].qs[8*ib32 + 7]
564553
));
565-
const float db = d * 0.25 * (0.5 + (signs >> 28));
554+
const FLOAT_TYPE db = FLOAT_TYPE(d * 0.25 * (0.5 + (signs >> 28)));
566555
const uint32_t sign7 = bitfieldExtract(signs, 7 * int(ib8), 7);
567-
const uint sign = (sign7 | (bitCount(sign7) << 7)) >> (2 * (idx % 4));
568-
const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(int8_t(sign << 1), int8_t(sign))));
569-
const uint grid = iq2xxs_grid[qs][(idx % 4) / 2] >> (16 * (idx & 1));
570-
const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy); // vec4 used due to #12147
571-
572-
buf_a[buf_idx ] = FLOAT_TYPE(v.x);
573-
buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
556+
const uint sign = sign7 | (bitCount(sign7) << 7);
557+
const uvec2 grid = iq2xxs_grid[qs];
558+
const vec4 grid0 = vec4(unpack8(grid.x));
559+
const vec4 grid1 = vec4(unpack8(grid.y));
560+
561+
buf_a[buf_idx ] = db * FLOAT_TYPE((sign & 1) != 0 ? -grid0.x : grid0.x);
562+
buf_a[buf_idx + 1] = db * FLOAT_TYPE((sign & 2) != 0 ? -grid0.y : grid0.y);
563+
buf_a[buf_idx + 2] = db * FLOAT_TYPE((sign & 4) != 0 ? -grid0.z : grid0.z);
564+
buf_a[buf_idx + 3] = db * FLOAT_TYPE((sign & 8) != 0 ? -grid0.w : grid0.w);
565+
buf_a[buf_idx + 4] = db * FLOAT_TYPE((sign & 16) != 0 ? -grid1.x : grid1.x);
566+
buf_a[buf_idx + 5] = db * FLOAT_TYPE((sign & 32) != 0 ? -grid1.y : grid1.y);
567+
buf_a[buf_idx + 6] = db * FLOAT_TYPE((sign & 64) != 0 ? -grid1.z : grid1.z);
568+
buf_a[buf_idx + 7] = db * FLOAT_TYPE((sign & 128) != 0 ? -grid1.w : grid1.w);
574569
#elif defined(DATA_A_IQ2_XS)
575570
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
576571
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
577572

578-
const uint ib = idx / 128; // 2 values per idx
579-
const uint ib32 = (idx % 128) / 16; // 0..7
580-
const uint ib8 = (idx / 4) % 4; // 0..3
573+
const uint ib = idx / 32; // 8 values per idx
574+
const uint ib32 = (idx % 32) / 4; // 0..7
575+
const uint ib8 = idx % 4; // 0..3
581576

582577
const float d = float(data_a[ib].d);
583578
const uint scale = (data_a[ib].scales[ib32] >> (2 * (ib8 & 2))) & 0xf;
584-
const float db = d * 0.25 * (0.5 + scale);
579+
const FLOAT_TYPE db = FLOAT_TYPE(d * 0.25 * (0.5 + scale));
585580
const uint qs = data_a[ib].qs[4 * ib32 + ib8];
586581
const uint sign7 = qs >> 9;
587-
const uint sign = (sign7 | (bitCount(sign7) << 7)) >> (2 * (idx % 4));
588-
const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(int8_t(sign << 1), int8_t(sign))));
589-
const uint grid = iq2xs_grid[qs & 511][(idx % 4) / 2] >> (16 * (idx & 1));
590-
const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy); // vec4 used due to #12147
591-
592-
buf_a[buf_idx ] = FLOAT_TYPE(v.x);
593-
buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
582+
const uint sign = sign7 | (bitCount(sign7) << 7);
583+
const uvec2 grid = iq2xs_grid[qs & 511];
584+
const vec4 grid0 = vec4(unpack8(grid.x));
585+
const vec4 grid1 = vec4(unpack8(grid.y));
586+
587+
buf_a[buf_idx ] = db * FLOAT_TYPE((sign & 1) != 0 ? -grid0.x : grid0.x);
588+
buf_a[buf_idx + 1] = db * FLOAT_TYPE((sign & 2) != 0 ? -grid0.y : grid0.y);
589+
buf_a[buf_idx + 2] = db * FLOAT_TYPE((sign & 4) != 0 ? -grid0.z : grid0.z);
590+
buf_a[buf_idx + 3] = db * FLOAT_TYPE((sign & 8) != 0 ? -grid0.w : grid0.w);
591+
buf_a[buf_idx + 4] = db * FLOAT_TYPE((sign & 16) != 0 ? -grid1.x : grid1.x);
592+
buf_a[buf_idx + 5] = db * FLOAT_TYPE((sign & 32) != 0 ? -grid1.y : grid1.y);
593+
buf_a[buf_idx + 6] = db * FLOAT_TYPE((sign & 64) != 0 ? -grid1.z : grid1.z);
594+
buf_a[buf_idx + 7] = db * FLOAT_TYPE((sign & 128) != 0 ? -grid1.w : grid1.w);
594595
#elif defined(DATA_A_IQ2_S)
595596
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
596597
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
597598

598-
const uint ib = idx / 128; // 2 values per idx
599-
const uint ib8 = (idx % 128) / 4; // 0..31
600-
const uint ib32 = ib8 / 4; // 0..7
599+
const uint ib = idx / 32; // 8 values per idx
600+
const uint ib8 = idx % 32; // 0..31
601+
const uint ib32 = ib8 / 4; // 0..7
601602

602603
const uint scale = (data_a[ib].scales[ib32] >> (2 * (ib8 & 2))) & 0xf;
603604
const uint qs = data_a[ib].qs[ib8];
604605
const uint qh = data_a[ib].qh[ib32];
605606
const uint qhshift = 2 * (ib8 % 4);
606-
const uint sign = data_a[ib].qs[QUANT_K / 8 + ib8] >> (2 * (idx % 4));
607+
const uint sign = data_a[ib].qs[QUANT_K / 8 + ib8];
607608

608609
const float d = float(data_a[ib].d);
609-
const float db = d * 0.25 * (0.5 + scale);
610-
const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(int8_t(sign << 1), int8_t(sign))));
611-
const uint16_t grid = unpack16(iq2s_grid[qs | ((qh << (8 - qhshift)) & 0x300)][(idx & 2) >> 1])[idx & 1];
612-
const vec2 v = db * vec2(sign01) * vec2(unpack8(uint32_t(grid)).xy); // vec4 used due to #12147
613-
614-
buf_a[buf_idx ] = FLOAT_TYPE(v.x);
615-
buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
610+
const FLOAT_TYPE db = FLOAT_TYPE(d * 0.25 * (0.5 + scale));
611+
const uvec2 grid = iq2s_grid[qs | ((qh << (8 - qhshift)) & 0x300)];
612+
const vec4 grid0 = vec4(unpack8(grid.x));
613+
const vec4 grid1 = vec4(unpack8(grid.y));
614+
615+
buf_a[buf_idx ] = db * FLOAT_TYPE((sign & 1) != 0 ? -grid0.x : grid0.x);
616+
buf_a[buf_idx + 1] = db * FLOAT_TYPE((sign & 2) != 0 ? -grid0.y : grid0.y);
617+
buf_a[buf_idx + 2] = db * FLOAT_TYPE((sign & 4) != 0 ? -grid0.z : grid0.z);
618+
buf_a[buf_idx + 3] = db * FLOAT_TYPE((sign & 8) != 0 ? -grid0.w : grid0.w);
619+
buf_a[buf_idx + 4] = db * FLOAT_TYPE((sign & 16) != 0 ? -grid1.x : grid1.x);
620+
buf_a[buf_idx + 5] = db * FLOAT_TYPE((sign & 32) != 0 ? -grid1.y : grid1.y);
621+
buf_a[buf_idx + 6] = db * FLOAT_TYPE((sign & 64) != 0 ? -grid1.z : grid1.z);
622+
buf_a[buf_idx + 7] = db * FLOAT_TYPE((sign & 128) != 0 ? -grid1.w : grid1.w);
616623
#elif defined(DATA_A_IQ3_XXS)
617624
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
618625
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
619626

620-
const uint ib = idx / 128; // 2 values per idx
621-
const uint iqs = (idx % 128) / 2; // 0..63
627+
const uint ib = idx / 64; // 4 values per idx
628+
const uint iqs = idx % 64; // 0..63
622629
const uint is = QUANT_K / 4 + 4 * (iqs / 8); // 8 values
623630

624631
const float d = float(data_a[ib].d);
@@ -631,33 +638,36 @@ void main() {
631638
));
632639
const float db = d * 0.5 * (0.5 + (signs >> 28));
633640
const uint32_t sign7 = bitfieldExtract(signs, 7 * (int(iqs / 2) % 4), 7);
634-
const uint sign = (sign7 | (bitCount(sign7) << 7)) >> (2 * (idx % 4));
635-
const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(int8_t(sign << 1), int8_t(sign))));
636-
const uint grid = iq3xxs_grid[qs] >> (16 * (idx & 1));
637-
const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy); // vec4 used due to #12147
638-
639-
buf_a[buf_idx ] = FLOAT_TYPE(v.x);
640-
buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
641+
const uint sign = (sign7 | (bitCount(sign7) << 7)) >> (4 * (idx % 2));
642+
const uint grid = iq3xxs_grid[qs];
643+
const vec4 v = db * vec4(unpack8(grid));
644+
645+
buf_a[buf_idx ] = FLOAT_TYPE((sign & 1) != 0 ? -v.x : v.x);
646+
buf_a[buf_idx + 1] = FLOAT_TYPE((sign & 2) != 0 ? -v.y : v.y);
647+
buf_a[buf_idx + 2] = FLOAT_TYPE((sign & 4) != 0 ? -v.z : v.z);
648+
buf_a[buf_idx + 3] = FLOAT_TYPE((sign & 8) != 0 ? -v.w : v.w);
641649
#elif defined(DATA_A_IQ3_S)
642650
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
643651
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
644652

645-
const uint ib = idx / 128; // 2 values per idx
646-
const uint iqs = (idx % 128) / 2; // 0..63
653+
const uint ib = idx / 64; // 4 values per idx
654+
const uint iqs = idx % 64; // 0..63
647655
const uint iqh = iqs / 8;
648656

649657
const float d = float(data_a[ib].d);
650658
const uint qs = data_a[ib].qs[iqs];
651659
const uint qh = data_a[ib].qh[iqh];
652-
const int8_t sign = int8_t(data_a[ib].signs[iqs / 2] >> (2 * (idx % 4)));
660+
const int8_t sign = int8_t(data_a[ib].signs[iqs / 2] >> (4 * (idx % 2)));
653661
const uint scale = data_a[ib].scales[iqs / 16];
654662
const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(sign << 1, sign)));
655663
const float db = d * (1 + 2 * ((scale >> (4 * (iqh & 1))) & 0xf));
656-
const uint32_t grid = iq3s_grid[qs | ((qh << (8 - (iqs % 8))) & 256)] >> (16 * (idx % 2));
657-
const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy); // vec4 used due to #12147
664+
const uint32_t grid = iq3s_grid[qs | ((qh << (8 - (iqs % 8))) & 256)];
665+
const vec4 v = db * vec4(unpack8(grid));
658666

659-
buf_a[buf_idx ] = FLOAT_TYPE(v.x);
660-
buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
667+
buf_a[buf_idx ] = FLOAT_TYPE((sign & 1) != 0 ? -v.x : v.x);
668+
buf_a[buf_idx + 1] = FLOAT_TYPE((sign & 2) != 0 ? -v.y : v.y);
669+
buf_a[buf_idx + 2] = FLOAT_TYPE((sign & 4) != 0 ? -v.z : v.z);
670+
buf_a[buf_idx + 3] = FLOAT_TYPE((sign & 8) != 0 ? -v.w : v.w);
661671
#elif defined(DATA_A_IQ4_XS)
662672
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
663673
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;

ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -360,9 +360,9 @@ void matmul_shaders(bool fp16, bool matmul_id, bool coopmat, bool coopmat2, bool
360360

361361
for (const auto& tname : type_names) {
362362
std::string load_vec_quant = "2";
363-
if ((tname == "q4_0") || (tname == "q4_1"))
363+
if ((tname == "q4_0") || (tname == "q4_1") || (tname == "iq1_s") || (tname == "iq1_m") || (tname == "iq2_xxs") || (tname == "iq2_xs") || (tname == "iq2_s"))
364364
load_vec_quant = "8";
365-
else if ((tname == "q5_0") || (tname == "q5_1") || (tname == "q8_0") || (tname == "iq4_nl"))
365+
else if ((tname == "q5_0") || (tname == "q5_1") || (tname == "q8_0") || (tname == "iq3_xxs") || (tname == "iq3_s") || (tname == "iq4_nl"))
366366
load_vec_quant = "4";
367367

368368
if (tname == "bf16") {

0 commit comments

Comments
 (0)