@@ -732,28 +732,42 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
732
732
733
733
#if defined(__POWER9_VECTOR__ )
734
734
const vector float v85 = vec_splats (8.5f );
735
+ const vector signed int v15 = vec_splats (15 );
735
736
for (int i = 0 ; i < nb ; i ++ ) {
736
- float amax = 0.0f ; // absolute max
737
+ float max = 0.0f ;
738
+ float min = 0.0f ;
737
739
738
740
vector float srcv [8 ];
739
- vector float asrcv [8 ];
740
- vector float amaxv [8 ];
741
+ vector float maxv [8 ];
742
+ vector float minv [8 ];
741
743
742
744
for (int l = 0 ; l < 8 ; l ++ ) srcv [l ] = * (vector float * )(x + i * 32 + 4 * l );
743
- for (int l = 0 ; l < 8 ; l ++ ) asrcv [l ] = vec_abs (srcv [l ]);
745
+ // for (int l = 0; l < 8; l++) asrcv[l] = vec_abs(srcv[l]);
744
746
745
- for (int l = 0 ; l < 4 ; l ++ ) amaxv [2 * l ] = vec_max (asrcv [2 * l ], asrcv [2 * l + 1 ]);
746
- //for (int l = 0; l < 2; l++) amaxv [4*l] = vec_max(amaxv [4*l], amaxv [4*l+2]);
747
- amaxv [0 ] = vec_max (amaxv [0 ], amaxv [2 ]);
748
- amaxv [4 ] = vec_max (amaxv [4 ], amaxv [6 ]);
749
- //for (int l = 0; l < 1; l++) amaxv [8*l] = vec_max(amaxv [8*l], amaxv [8*l+4]);
750
- amaxv [0 ] = vec_max (amaxv [0 ], amaxv [4 ]);
747
+ for (int l = 0 ; l < 4 ; l ++ ) maxv [2 * l ] = vec_max (asrcv [2 * l ], asrcv [2 * l + 1 ]);
748
+ //for (int l = 0; l < 2; l++) maxv [4*l] = vec_max(maxv [4*l], maxv [4*l+2]);
749
+ maxv [0 ] = vec_max (maxv [0 ], maxv [2 ]);
750
+ maxv [4 ] = vec_max (maxv [4 ], maxv [6 ]);
751
+ //for (int l = 0; l < 1; l++) maxv [8*l] = vec_max(maxv [8*l], maxv [8*l+4]);
752
+ maxv [0 ] = vec_max (maxv [0 ], maxv [4 ]);
751
753
752
- amax = MAX (
753
- MAX (vec_extract (amaxv [0 ], 0 ), vec_extract (amaxv [0 ], 1 )),
754
- MAX (vec_extract (amaxv [0 ], 2 ), vec_extract (amaxv [0 ], 3 )));
754
+ for (int l = 0 ; l < 4 ; l ++ ) minv [2 * l ] = vec_min (asrcv [2 * l ], asrcv [2 * l + 1 ]);
755
+ //for (int l = 0; l < 2; l++) minv[4*l] = vec_min(minv[4*l], minv[4*l+2]);
756
+ minv [0 ] = vec_min (minv [0 ], minv [2 ]);
757
+ minv [4 ] = vec_min (minv [4 ], minv [6 ]);
758
+ //for (int l = 0; l < 1; l++) minv[8*l] = vec_min(minv[8*l], minv[8*l+4]);
759
+ minv [0 ] = vec_min (minv [0 ], minv [4 ]);
755
760
756
- const float d = amax / ((1 << 3 ) - 1 );
761
+
762
+ max = MAX (
763
+ MAX (vec_extract (maxv [0 ], 0 ), vec_extract (maxv [0 ], 1 )),
764
+ MAX (vec_extract (maxv [0 ], 2 ), vec_extract (maxv [0 ], 3 )));
765
+ min = MIN (
766
+ MIN (vec_extract (minv [0 ], 0 ), vec_extract (minv [0 ], 1 )),
767
+ MIN (vec_extract (minv [0 ], 2 ), vec_extract (minv [0 ], 3 )));
768
+
769
+ const float magnitude = max >= fabsf (min ) ? max : min ;
770
+ const float d = magnitude / -8 ;
757
771
const float id = d ? 1.0 /d : 0.0 ;
758
772
759
773
y [i ].d = d ;
@@ -763,9 +777,10 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
763
777
for (int l = 0 ; l < 8 ; l ++ ) {
764
778
const vector float vf = vec_madd (srcv [l ], vid , v85 );
765
779
const vector signed int vi = vec_signed (vf );
780
+ const vector signed int vc = vec_min (vi , v15 );
766
781
767
- pb [2 * l + 0 ] = vec_extract (vi , 0 ) | (vec_extract (vi , 1 ) << 4 );
768
- pb [2 * l + 1 ] = vec_extract (vi , 2 ) | (vec_extract (vi , 3 ) << 4 );
782
+ pb [2 * l + 0 ] = vec_extract (vc , 0 ) | (vec_extract (vc , 1 ) << 4 );
783
+ pb [2 * l + 1 ] = vec_extract (vc , 2 ) | (vec_extract (vc , 3 ) << 4 );
769
784
}
770
785
}
771
786
#elif __ARM_NEON
0 commit comments