@@ -7509,20 +7509,23 @@ static void ggml_compute_forward_rope_f32(
7509
7509
// row index used to determine which thread to use
7510
7510
int ir = 0 ;
7511
7511
7512
- const float theta_scale = powf (10000.0 , (( float ) -2 ) /n_dims );
7512
+ const float theta_scale = powf (10000.0 , -2.0f /n_dims );
7513
7513
7514
7514
for (int64_t i3 = 0 ; i3 < ne3 ; i3 ++ ) {
7515
7515
for (int64_t i2 = (mode == 0 ? 0 : n_past ); i2 < ne2 ; i2 ++ ) {
7516
7516
const int p = (mode == 0 ? n_past + i2 : i2 );
7517
7517
for (int64_t i1 = 0 ; i1 < ne1 ; i1 ++ ) {
7518
7518
if (ir ++ < ir0 ) continue ;
7519
7519
if (ir > ir1 ) break ;
7520
+
7520
7521
float theta = (float )p ;
7522
+
7521
7523
for (int i0 = 0 ; i0 < n_dims ; i0 += 2 ) {
7522
7524
const float cos_theta = cosf (theta );
7523
7525
const float sin_theta = sinf (theta );
7524
7526
7525
7527
theta *= theta_scale ;
7528
+
7526
7529
const float * const src = (float * )((char * ) src0 -> data + i3 * nb3 + i2 * nb2 + i1 * nb1 + i0 * nb0 );
7527
7530
float * dst_data = (float * )((char * ) dst -> data + i3 * nb3 + i2 * nb2 + i1 * nb1 + i0 * nb0 );
7528
7531
@@ -7583,20 +7586,23 @@ static void ggml_compute_forward_rope_f16(
7583
7586
// row index used to determine which thread to use
7584
7587
int ir = 0 ;
7585
7588
7586
- const float theta_scale = powf (10000.0 , (( float ) -2 ) /n_dims );
7589
+ const float theta_scale = powf (10000.0 , -2.0f /n_dims );
7587
7590
7588
7591
for (int64_t i3 = 0 ; i3 < ne3 ; i3 ++ ) {
7589
7592
for (int64_t i2 = (mode == 0 ? 0 : n_past ); i2 < ne2 ; i2 ++ ) {
7590
7593
const int p = (mode == 0 ? n_past + i2 : i2 );
7591
7594
for (int64_t i1 = 0 ; i1 < ne1 ; i1 ++ ) {
7592
7595
if (ir ++ < ir0 ) continue ;
7593
7596
if (ir > ir1 ) break ;
7597
+
7594
7598
float theta = (float )p ;
7599
+
7595
7600
for (int i0 = 0 ; i0 < n_dims ; i0 += 2 ) {
7596
7601
const float cos_theta = cosf (theta );
7597
7602
const float sin_theta = sinf (theta );
7598
7603
7599
7604
theta *= theta_scale ;
7605
+
7600
7606
const ggml_fp16_t * const src = (ggml_fp16_t * )((char * ) src0 -> data + i3 * nb3 + i2 * nb2 + i1 * nb1 + i0 * nb0 );
7601
7607
ggml_fp16_t * dst_data = (ggml_fp16_t * )((char * ) dst -> data + i3 * nb3 + i2 * nb2 + i1 * nb1 + i0 * nb0 );
7602
7608
0 commit comments