Skip to content

Commit b8eeb87

Browse files
authored
vulkan : fix rope with partial rotation and non-cont src (#14582)
1 parent 17a1f0d commit b8eeb87

File tree

3 files changed

+21
-27
lines changed

3 files changed

+21
-27
lines changed

ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,21 +14,19 @@ void main() {
1414

1515
const uint row_dst = gl_GlobalInvocationID.x;
1616

17-
if (i0 >= p.n_dims) {
18-
const uint i = row_dst*ne0 + i0;
19-
20-
data_d[i + 0] = data_a[i + 0];
21-
data_d[i + 1] = data_a[i + 1];
22-
23-
return;
24-
}
25-
2617
const uint row_x = row_dst % ne1;
2718
const uint channel_x = row_dst / ne1;
2819

2920
const uint idst = row_dst*ne0 + i0/2;
3021
const uint ix = channel_x*p.s2 + row_x*p.s1 + i0/2;
3122

23+
if (i0 >= p.n_dims) {
24+
data_d[idst + i0/2 + 0] = data_a[ix + i0/2 + 0];
25+
data_d[idst + i0/2 + 1] = data_a[ix + i0/2 + 1];
26+
27+
return;
28+
}
29+
3230
const int sect_dims = p.sections[0] + p.sections[1] + p.sections[2] + p.sections[3];
3331
const int sec_w = p.sections[1] + p.sections[0];
3432
const uint sector = (i0 / 2) % sect_dims;

ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,21 +13,19 @@ void main() {
1313

1414
const uint row_dst = gl_GlobalInvocationID.x;
1515

16-
if (i0 >= p.n_dims) {
17-
const uint i = row_dst*ne0 + i0;
18-
19-
data_d[i + 0] = data_a[i + 0];
20-
data_d[i + 1] = data_a[i + 1];
21-
22-
return;
23-
}
24-
2516
const uint row_x = row_dst % ne1;
2617
const uint channel_x = row_dst / ne1;
2718

2819
const uint idst = row_dst*ne0 + i0/2;
2920
const uint ix = channel_x*p.s2 + row_x*p.s1 + i0/2;
3021

22+
if (i0 >= p.n_dims) {
23+
data_d[idst + i0/2 + 0] = data_a[ix + i0/2 + 0];
24+
data_d[idst + i0/2 + 1] = data_a[ix + i0/2 + 1];
25+
26+
return;
27+
}
28+
3129
const float theta_base = data_pos[channel_x] * pow(p.theta_scale, i0/2.0f);
3230

3331
const float freq_factor = p.has_ff != 0 ? data_ff[i0/2] : 1.0f;

ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,21 +13,19 @@ void main() {
1313

1414
const uint row_dst = gl_GlobalInvocationID.x;
1515

16-
if (i0 >= p.n_dims) {
17-
const uint i = row_dst*ne0 + i0;
18-
19-
data_d[i + 0] = data_a[i + 0];
20-
data_d[i + 1] = data_a[i + 1];
21-
22-
return;
23-
}
24-
2516
const uint row_x = row_dst % ne1;
2617
const uint channel_x = row_dst / ne1;
2718

2819
const uint idst = row_dst*ne0 + i0;
2920
const uint ix = channel_x*p.s2 + row_x*p.s1 + i0;
3021

22+
if (i0 >= p.n_dims) {
23+
data_d[idst + 0] = data_a[ix + 0];
24+
data_d[idst + 1] = data_a[ix + 1];
25+
26+
return;
27+
}
28+
3129
const float theta_base = data_pos[channel_x] * pow(p.theta_scale, i0/2.0f);
3230

3331
const float freq_factor = p.has_ff != 0 ? data_ff[i0/2] : 1.0f;

0 commit comments

Comments
 (0)