@@ -500,10 +500,9 @@ void main() {
500
500
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
501
501
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
502
502
503
- const uint ib = idx / 128; // 2 values per idx
504
- const uint ib32 = (idx % 128) / 16; // 0..7
505
- const uint ib8 = (idx % 128) / 4;
506
- const int i8 = 2 * int(idx % 4);
503
+ const uint ib = idx / 32; // 8 values per idx
504
+ const uint ib32 = (idx % 32) / 4; // 0..7
505
+ const uint ib8 = idx % 32;
507
506
508
507
const float d = float(data_a[ib].d);
509
508
const uint qh = data_a[ib].qh[ib32];
@@ -512,22 +511,16 @@ void main() {
512
511
const float delta = ((qh & 0x8000) != 0) ? -IQ1S_DELTA : IQ1S_DELTA;
513
512
const int16_t grid = int16_t(iq1s_grid[qs | (bitfieldExtract(qh, 3 * int(ib8 & 3), 3) << 8)]);
514
513
515
- const ivec2 gvec = ivec2(
516
- bitfieldExtract(grid, 2 * (i8), 2),
517
- bitfieldExtract(grid, 2 * (i8 + 1), 2)
518
- );
519
- const vec2 v = dl * (vec2(gvec) + delta);
520
-
521
- buf_a[buf_idx ] = FLOAT_TYPE(v.x);
522
- buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
514
+ [[unroll]] for (int k = 0; k < 8; ++k) {
515
+ buf_a[buf_idx + k] = FLOAT_TYPE(dl * (bitfieldExtract(grid, 2 * k, 2) + delta));
516
+ }
523
517
#elif defined(DATA_A_IQ1_M)
524
518
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
525
519
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
526
520
527
- const uint ib = idx / 128 ; // 2 values per idx
528
- const uint ib8 = ( idx % 128) / 4 ;
521
+ const uint ib = idx / 32 ; // 8 values per idx
522
+ const uint ib8 = idx % 32 ;
529
523
const uint ib16 = ib8 / 2;
530
- const int i8 = 2 * int(idx % 4);
531
524
532
525
const uint16_t[4] scales = data_a[ib].scales;
533
526
const u16vec4 s = u16vec4(scales[0], scales[1], scales[2], scales[3]) >> 12;
@@ -538,21 +531,17 @@ void main() {
538
531
const float dl = d * (2 * bitfieldExtract(sc, 3 * int(ib16 & 3), 3) + 1);
539
532
const float delta = ((qh & 8) != 0) ? -IQ1M_DELTA : IQ1M_DELTA;
540
533
const int16_t grid = int16_t(iq1s_grid[qs | ((qh & 7) << 8)]);
541
- const ivec2 gvec = ivec2(
542
- bitfieldExtract(grid, 2 * (i8), 2),
543
- bitfieldExtract(grid, 2 * (i8 + 1), 2)
544
- );
545
- const vec2 v = dl * (vec2(gvec) + delta);
546
534
547
- buf_a[buf_idx ] = FLOAT_TYPE(v.x);
548
- buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
535
+ [[unroll]] for (int k = 0; k < 8; ++k) {
536
+ buf_a[buf_idx + k] = FLOAT_TYPE(dl * (bitfieldExtract(grid, 2 * k, 2) + delta));
537
+ }
549
538
#elif defined(DATA_A_IQ2_XXS)
550
539
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
551
540
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
552
541
553
- const uint ib = idx / 128 ; // 2 values per idx
554
- const uint ib32 = (idx % 128 ) / 16 ; // 0..7
555
- const uint ib8 = ( idx / 4) % 4;
542
+ const uint ib = idx / 32 ; // 8 values per idx
543
+ const uint ib32 = (idx % 32 ) / 4 ; // 0..7
544
+ const uint ib8 = idx % 4;
556
545
557
546
const float d = float(data_a[ib].d);
558
547
const uint qs = data_a[ib].qs[8 * ib32 + ib8];
@@ -562,63 +551,81 @@ void main() {
562
551
data_a[ib].qs[8*ib32 + 6],
563
552
data_a[ib].qs[8*ib32 + 7]
564
553
));
565
- const float db = d * 0.25 * (0.5 + (signs >> 28));
554
+ const FLOAT_TYPE db = FLOAT_TYPE( d * 0.25 * (0.5 + (signs >> 28) ));
566
555
const uint32_t sign7 = bitfieldExtract(signs, 7 * int(ib8), 7);
567
- const uint sign = (sign7 | (bitCount(sign7) << 7)) >> (2 * (idx % 4));
568
- const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(int8_t(sign << 1), int8_t(sign))));
569
- const uint grid = iq2xxs_grid[qs][(idx % 4) / 2] >> (16 * (idx & 1));
570
- const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy); // vec4 used due to #12147
571
-
572
- buf_a[buf_idx ] = FLOAT_TYPE(v.x);
573
- buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
556
+ const uint sign = sign7 | (bitCount(sign7) << 7);
557
+ const uvec2 grid = iq2xxs_grid[qs];
558
+ const vec4 grid0 = vec4(unpack8(grid.x));
559
+ const vec4 grid1 = vec4(unpack8(grid.y));
560
+
561
+ buf_a[buf_idx ] = db * FLOAT_TYPE((sign & 1) != 0 ? -grid0.x : grid0.x);
562
+ buf_a[buf_idx + 1] = db * FLOAT_TYPE((sign & 2) != 0 ? -grid0.y : grid0.y);
563
+ buf_a[buf_idx + 2] = db * FLOAT_TYPE((sign & 4) != 0 ? -grid0.z : grid0.z);
564
+ buf_a[buf_idx + 3] = db * FLOAT_TYPE((sign & 8) != 0 ? -grid0.w : grid0.w);
565
+ buf_a[buf_idx + 4] = db * FLOAT_TYPE((sign & 16) != 0 ? -grid1.x : grid1.x);
566
+ buf_a[buf_idx + 5] = db * FLOAT_TYPE((sign & 32) != 0 ? -grid1.y : grid1.y);
567
+ buf_a[buf_idx + 6] = db * FLOAT_TYPE((sign & 64) != 0 ? -grid1.z : grid1.z);
568
+ buf_a[buf_idx + 7] = db * FLOAT_TYPE((sign & 128) != 0 ? -grid1.w : grid1.w);
574
569
#elif defined(DATA_A_IQ2_XS)
575
570
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
576
571
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
577
572
578
- const uint ib = idx / 128 ; // 2 values per idx
579
- const uint ib32 = (idx % 128 ) / 16; // 0..7
580
- const uint ib8 = ( idx / 4) % 4; // 0..3
573
+ const uint ib = idx / 32 ; // 8 values per idx
574
+ const uint ib32 = (idx % 32 ) / 4; // 0..7
575
+ const uint ib8 = idx % 4; // 0..3
581
576
582
577
const float d = float(data_a[ib].d);
583
578
const uint scale = (data_a[ib].scales[ib32] >> (2 * (ib8 & 2))) & 0xf;
584
- const float db = d * 0.25 * (0.5 + scale);
579
+ const FLOAT_TYPE db = FLOAT_TYPE( d * 0.25 * (0.5 + scale) );
585
580
const uint qs = data_a[ib].qs[4 * ib32 + ib8];
586
581
const uint sign7 = qs >> 9;
587
- const uint sign = (sign7 | (bitCount(sign7) << 7)) >> (2 * (idx % 4));
588
- const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(int8_t(sign << 1), int8_t(sign))));
589
- const uint grid = iq2xs_grid[qs & 511][(idx % 4) / 2] >> (16 * (idx & 1));
590
- const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy); // vec4 used due to #12147
591
-
592
- buf_a[buf_idx ] = FLOAT_TYPE(v.x);
593
- buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
582
+ const uint sign = sign7 | (bitCount(sign7) << 7);
583
+ const uvec2 grid = iq2xs_grid[qs & 511];
584
+ const vec4 grid0 = vec4(unpack8(grid.x));
585
+ const vec4 grid1 = vec4(unpack8(grid.y));
586
+
587
+ buf_a[buf_idx ] = db * FLOAT_TYPE((sign & 1) != 0 ? -grid0.x : grid0.x);
588
+ buf_a[buf_idx + 1] = db * FLOAT_TYPE((sign & 2) != 0 ? -grid0.y : grid0.y);
589
+ buf_a[buf_idx + 2] = db * FLOAT_TYPE((sign & 4) != 0 ? -grid0.z : grid0.z);
590
+ buf_a[buf_idx + 3] = db * FLOAT_TYPE((sign & 8) != 0 ? -grid0.w : grid0.w);
591
+ buf_a[buf_idx + 4] = db * FLOAT_TYPE((sign & 16) != 0 ? -grid1.x : grid1.x);
592
+ buf_a[buf_idx + 5] = db * FLOAT_TYPE((sign & 32) != 0 ? -grid1.y : grid1.y);
593
+ buf_a[buf_idx + 6] = db * FLOAT_TYPE((sign & 64) != 0 ? -grid1.z : grid1.z);
594
+ buf_a[buf_idx + 7] = db * FLOAT_TYPE((sign & 128) != 0 ? -grid1.w : grid1.w);
594
595
#elif defined(DATA_A_IQ2_S)
595
596
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
596
597
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
597
598
598
- const uint ib = idx / 128 ; // 2 values per idx
599
- const uint ib8 = ( idx % 128) / 4 ; // 0..31
600
- const uint ib32 = ib8 / 4; // 0..7
599
+ const uint ib = idx / 32 ; // 8 values per idx
600
+ const uint ib8 = idx % 32 ; // 0..31
601
+ const uint ib32 = ib8 / 4; // 0..7
601
602
602
603
const uint scale = (data_a[ib].scales[ib32] >> (2 * (ib8 & 2))) & 0xf;
603
604
const uint qs = data_a[ib].qs[ib8];
604
605
const uint qh = data_a[ib].qh[ib32];
605
606
const uint qhshift = 2 * (ib8 % 4);
606
- const uint sign = data_a[ib].qs[QUANT_K / 8 + ib8] >> (2 * (idx % 4)) ;
607
+ const uint sign = data_a[ib].qs[QUANT_K / 8 + ib8];
607
608
608
609
const float d = float(data_a[ib].d);
609
- const float db = d * 0.25 * (0.5 + scale);
610
- const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(int8_t(sign << 1), int8_t(sign))));
611
- const uint16_t grid = unpack16(iq2s_grid[qs | ((qh << (8 - qhshift)) & 0x300)][(idx & 2) >> 1])[idx & 1];
612
- const vec2 v = db * vec2(sign01) * vec2(unpack8(uint32_t(grid)).xy); // vec4 used due to #12147
613
-
614
- buf_a[buf_idx ] = FLOAT_TYPE(v.x);
615
- buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
610
+ const FLOAT_TYPE db = FLOAT_TYPE(d * 0.25 * (0.5 + scale));
611
+ const uvec2 grid = iq2s_grid[qs | ((qh << (8 - qhshift)) & 0x300)];
612
+ const vec4 grid0 = vec4(unpack8(grid.x));
613
+ const vec4 grid1 = vec4(unpack8(grid.y));
614
+
615
+ buf_a[buf_idx ] = db * FLOAT_TYPE((sign & 1) != 0 ? -grid0.x : grid0.x);
616
+ buf_a[buf_idx + 1] = db * FLOAT_TYPE((sign & 2) != 0 ? -grid0.y : grid0.y);
617
+ buf_a[buf_idx + 2] = db * FLOAT_TYPE((sign & 4) != 0 ? -grid0.z : grid0.z);
618
+ buf_a[buf_idx + 3] = db * FLOAT_TYPE((sign & 8) != 0 ? -grid0.w : grid0.w);
619
+ buf_a[buf_idx + 4] = db * FLOAT_TYPE((sign & 16) != 0 ? -grid1.x : grid1.x);
620
+ buf_a[buf_idx + 5] = db * FLOAT_TYPE((sign & 32) != 0 ? -grid1.y : grid1.y);
621
+ buf_a[buf_idx + 6] = db * FLOAT_TYPE((sign & 64) != 0 ? -grid1.z : grid1.z);
622
+ buf_a[buf_idx + 7] = db * FLOAT_TYPE((sign & 128) != 0 ? -grid1.w : grid1.w);
616
623
#elif defined(DATA_A_IQ3_XXS)
617
624
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
618
625
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
619
626
620
- const uint ib = idx / 128 ; // 2 values per idx
621
- const uint iqs = ( idx % 128) / 2 ; // 0..63
627
+ const uint ib = idx / 64 ; // 4 values per idx
628
+ const uint iqs = idx % 64 ; // 0..63
622
629
const uint is = QUANT_K / 4 + 4 * (iqs / 8); // 8 values
623
630
624
631
const float d = float(data_a[ib].d);
@@ -631,33 +638,36 @@ void main() {
631
638
));
632
639
const float db = d * 0.5 * (0.5 + (signs >> 28));
633
640
const uint32_t sign7 = bitfieldExtract(signs, 7 * (int(iqs / 2) % 4), 7);
634
- const uint sign = (sign7 | (bitCount(sign7) << 7)) >> (2 * (idx % 4));
635
- const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(int8_t(sign << 1), int8_t(sign))));
636
- const uint grid = iq3xxs_grid[qs] >> (16 * (idx & 1));
637
- const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy); // vec4 used due to #12147
638
-
639
- buf_a[buf_idx ] = FLOAT_TYPE(v.x);
640
- buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
641
+ const uint sign = (sign7 | (bitCount(sign7) << 7)) >> (4 * (idx % 2));
642
+ const uint grid = iq3xxs_grid[qs];
643
+ const vec4 v = db * vec4(unpack8(grid));
644
+
645
+ buf_a[buf_idx ] = FLOAT_TYPE((sign & 1) != 0 ? -v.x : v.x);
646
+ buf_a[buf_idx + 1] = FLOAT_TYPE((sign & 2) != 0 ? -v.y : v.y);
647
+ buf_a[buf_idx + 2] = FLOAT_TYPE((sign & 4) != 0 ? -v.z : v.z);
648
+ buf_a[buf_idx + 3] = FLOAT_TYPE((sign & 8) != 0 ? -v.w : v.w);
641
649
#elif defined(DATA_A_IQ3_S)
642
650
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
643
651
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
644
652
645
- const uint ib = idx / 128 ; // 2 values per idx
646
- const uint iqs = ( idx % 128) / 2 ; // 0..63
653
+ const uint ib = idx / 64 ; // 4 values per idx
654
+ const uint iqs = idx % 64 ; // 0..63
647
655
const uint iqh = iqs / 8;
648
656
649
657
const float d = float(data_a[ib].d);
650
658
const uint qs = data_a[ib].qs[iqs];
651
659
const uint qh = data_a[ib].qh[iqh];
652
- const int8_t sign = int8_t(data_a[ib].signs[iqs / 2] >> (2 * (idx % 4 )));
660
+ const int8_t sign = int8_t(data_a[ib].signs[iqs / 2] >> (4 * (idx % 2 )));
653
661
const uint scale = data_a[ib].scales[iqs / 16];
654
662
const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(sign << 1, sign)));
655
663
const float db = d * (1 + 2 * ((scale >> (4 * (iqh & 1))) & 0xf));
656
- const uint32_t grid = iq3s_grid[qs | ((qh << (8 - (iqs % 8))) & 256)] >> (16 * (idx % 2)) ;
657
- const vec2 v = db * vec2(sign01) * vec2( unpack8(grid).xy); // vec4 used due to #12147
664
+ const uint32_t grid = iq3s_grid[qs | ((qh << (8 - (iqs % 8))) & 256)];
665
+ const vec4 v = db * vec4( unpack8(grid));
658
666
659
- buf_a[buf_idx ] = FLOAT_TYPE(v.x);
660
- buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
667
+ buf_a[buf_idx ] = FLOAT_TYPE((sign & 1) != 0 ? -v.x : v.x);
668
+ buf_a[buf_idx + 1] = FLOAT_TYPE((sign & 2) != 0 ? -v.y : v.y);
669
+ buf_a[buf_idx + 2] = FLOAT_TYPE((sign & 4) != 0 ? -v.z : v.z);
670
+ buf_a[buf_idx + 3] = FLOAT_TYPE((sign & 8) != 0 ? -v.w : v.w);
661
671
#elif defined(DATA_A_IQ4_XS)
662
672
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
663
673
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
0 commit comments