Skip to content

Commit b5940d7

Browse files
ggerganov authored and qnixsynapse committed
cuda : fix rope with partial rotation and non-cont src (ggml-org#14580)
* cuda : fix rope non-cont (ggml-ci)
* cont : fix multi-rope + add test (ggml-ci)
* sycl : try fix (ggml-ci)
* cont : fix sycl + clean-up cuda (ggml-ci)
1 parent 2a755e0 commit b5940d7

File tree

3 files changed

+50
-51
lines changed

3 files changed

+50
-51
lines changed

ggml/src/ggml-cuda/rope.cu

Lines changed: 21 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -50,21 +50,19 @@ static __global__ void rope_norm(
5050

5151
const int row_dst = blockDim.x*blockIdx.x + threadIdx.x;
5252

53-
if (i0 >= n_dims) {
54-
const int i = row_dst*ne0 + i0;
55-
56-
dst[i + 0] = x[i + 0];
57-
dst[i + 1] = x[i + 1];
58-
59-
return;
60-
}
61-
6253
const int row_x = row_dst % ne1;
6354
const int channel_x = row_dst / ne1;
6455

6556
const int idst = row_dst*ne0 + i0;
6657
const int ix = channel_x*s2 + row_x*s1 + i0;
6758

59+
if (i0 >= n_dims) {
60+
dst[idst + 0] = x[ix + 0];
61+
dst[idst + 1] = x[ix + 1];
62+
63+
return;
64+
}
65+
6866
const float theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f);
6967

7068
const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
@@ -94,21 +92,19 @@ static __global__ void rope_neox(
9492

9593
const int row_dst = blockDim.x*blockIdx.x + threadIdx.x;
9694

97-
if (i0 >= n_dims) {
98-
const int i = row_dst*ne0 + i0;
99-
100-
dst[i + 0] = x[i + 0];
101-
dst[i + 1] = x[i + 1];
102-
103-
return;
104-
}
105-
10695
const int row_x = row_dst % ne1;
10796
const int channel_x = row_dst / ne1;
10897

10998
const int idst = row_dst*ne0 + i0/2;
11099
const int ix = channel_x*s2 + row_x*s1 + i0/2;
111100

101+
if (i0 >= n_dims) {
102+
dst[idst + i0/2 + 0] = x[ix + i0/2 + 0];
103+
dst[idst + i0/2 + 1] = x[ix + i0/2 + 1];
104+
105+
return;
106+
}
107+
112108
const float theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f);
113109

114110
const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
@@ -138,21 +134,19 @@ static __global__ void rope_multi(
138134

139135
const int row_dst = blockDim.x*blockIdx.x + threadIdx.x;
140136

141-
if (i0 >= n_dims) {
142-
const int i = row_dst*ne0 + i0;
143-
144-
dst[i + 0] = x[i + 0];
145-
dst[i + 1] = x[i + 1];
146-
147-
return;
148-
}
149-
150137
const int row_x = row_dst % ne1;
151138
const int channel_x = row_dst / ne1;
152139

153140
const int idst = row_dst*ne0 + i0/2;
154141
const int ix = channel_x*s2 + row_x*s1 + i0/2;
155142

143+
if (i0 >= n_dims) {
144+
dst[idst + i0/2 + 0] = x[ix + i0/2 + 0];
145+
dst[idst + i0/2 + 1] = x[ix + i0/2 + 1];
146+
147+
return;
148+
}
149+
156150
const int sect_dims = sections.v[0] + sections.v[1] + sections.v[2] + sections.v[3];
157151
const int sec_w = sections.v[1] + sections.v[0];
158152
const int sector = (i0 / 2) % sect_dims;

ggml/src/ggml-sycl/rope.cpp

Lines changed: 15 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -47,18 +47,17 @@ static void rope_norm(const T * x, T * dst, const int ne0, const int ne1, const
4747

4848
const int row = item_ct1.get_local_range(2) * item_ct1.get_group(2) + item_ct1.get_local_id(2);
4949

50-
if (i0 >= n_dims) {
51-
const int i = row * ne0 + i0;
52-
*reinterpret_cast<sycl::vec<T, 2> *>(dst + i) = *reinterpret_cast<const sycl::vec<T, 2> *>(x + i);
53-
return;
54-
}
55-
5650
const int row0 = row % ne1;
5751
const int channel0 = row / ne1;
5852

5953
const int i = row * ne0 + i0;
6054
const int i2 = channel0 * s2 + row0 * s1 + i0;
6155

56+
if (i0 >= n_dims) {
57+
*reinterpret_cast<sycl::vec<T, 2> *>(dst + i) = *reinterpret_cast<const sycl::vec<T, 2> *>(x + i2);
58+
return;
59+
}
60+
6261
const float theta_base = pos[channel0] * sycl::pow(theta_scale, i0 / 2.0f);
6362

6463
const float freq_factor = has_ff ? freq_factors[i0 / 2] : 1.0f;
@@ -88,18 +87,17 @@ static void rope_neox(const T * x, T * dst, const int ne0, const int ne1, const
8887

8988
const int row = item_ct1.get_local_range(2) * item_ct1.get_group(2) + item_ct1.get_local_id(2);
9089

91-
if (i0 >= n_dims) {
92-
const int i = row * ne0 + i0;
93-
*reinterpret_cast<sycl::vec<T, 2> *>(dst + i) = *reinterpret_cast<const sycl::vec<T, 2> *>(x + i);
94-
return;
95-
}
96-
9790
const int row0 = row % ne1;
9891
const int channel0 = row / ne1;
9992

10093
const int i = row * ne0 + i0 / 2;
10194
const int i2 = channel0 * s2 + row0 * s1 + i0 / 2;
10295

96+
if (i0 >= n_dims) {
97+
*reinterpret_cast<sycl::vec<T, 2> *>(dst + i + i0 / 2) = *reinterpret_cast<const sycl::vec<T, 2> *>(x + i2 + i0 / 2);
98+
return;
99+
}
100+
103101
const float theta_base = pos[channel0] * sycl::pow(theta_scale, i0 / 2.0f);
104102

105103
const float freq_factor = has_ff ? freq_factors[i0 / 2] : 1.0f;
@@ -129,17 +127,16 @@ static void rope_multi(const T * x, T * dst, const int ne0, const int ne1, const
129127
}
130128
const int row_dst = (item_ct1.get_group(2) * item_ct1.get_local_range(2)) + item_ct1.get_local_id(2);
131129

132-
if (i0 >= n_dims) {
133-
const int i = row_dst*ne0 + i0;
134-
*reinterpret_cast<sycl::vec<T, 2> *>(dst + i) = *reinterpret_cast<const sycl::vec<T, 2> *>(x + i);
135-
return;
136-
}
137-
138130
const int row_x = row_dst % ne1;
139131
const int channel_x = row_dst / ne1;
140132
const int idst = (row_dst * ne0) + (i0 / 2);
141133
const size_t ix = ((size_t) channel_x * s2) + ((size_t) row_x * s1) + (i0 / 2);
142134

135+
if (i0 >= n_dims) {
136+
*reinterpret_cast<sycl::vec<T, 2> *>(dst + idst + i0 / 2) = *reinterpret_cast<const sycl::vec<T, 2> *>(x + i0 / 2 + ix);
137+
return;
138+
}
139+
143140
const int sect_dims = sections.v[0] + sections.v[1] + sections.v[2] + sections.v[3];
144141
const int sec_w = sections.v[1] + sections.v[0];
145142
const int sector = (i0 / 2) % sect_dims;

tests/test-backend-ops.cpp

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5323,12 +5323,12 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
53235323
for (bool fw : {true, false}) { // fw == forward
53245324
bool all = true;
53255325

5326-
for (float v : { 0, 1 }) {
5327-
for (float fs : { 1.0f, 1.4245f }) {
5328-
for (float ef : { 0.0f, 0.7465f }) {
5329-
for (float af : { 1.0f, 1.4245f }) {
5330-
for (ggml_type type : {GGML_TYPE_F32, GGML_TYPE_F16}) {
5331-
for (bool ff : {false, true}) { // freq_factors
5326+
for (float fs : { 1.0f, 1.4245f }) {
5327+
for (float ef : { 0.0f, 0.7465f }) {
5328+
for (float af : { 1.0f, 1.4245f }) {
5329+
for (ggml_type type : {GGML_TYPE_F32, GGML_TYPE_F16}) {
5330+
for (bool ff : {false, true}) { // freq_factors
5331+
for (float v : { 0, 1 }) {
53325332
test_cases.emplace_back(new test_rope(type, {128, 32, 2, 1}, 128, 0, 512, fs, ef, af, ff, v, fw)); // llama 7B
53335333

53345334
if (all) {
@@ -5341,13 +5341,21 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
53415341
test_cases.emplace_back(new test_rope(type, { 64, 1, 2, 1}, 64, 2, 512, fs, ef, af, ff, v, fw)); // neox (falcon 7B)
53425342
test_cases.emplace_back(new test_rope(type, { 64, 71, 2, 1}, 64, 2, 512, fs, ef, af, ff, v, fw)); // neox (falcon 7B)
53435343
test_cases.emplace_back(new test_rope(type, { 64, 8, 2, 1}, 64, 2, 512, fs, ef, af, ff, v, fw)); // neox (falcon 40B)
5344+
5345+
test_cases.emplace_back(new test_rope(type, { 80, 32, 2, 1}, 20, 0, 512, fs, ef, af, ff, v, fw));
5346+
test_cases.emplace_back(new test_rope(type, { 80, 32, 2, 1}, 32, 0, 512, fs, ef, af, ff, v, fw));
5347+
test_cases.emplace_back(new test_rope(type, { 80, 32, 4, 1}, 32, 0, 512, fs, ef, af, ff, v, fw));
5348+
53445349
test_cases.emplace_back(new test_rope(type, { 80, 32, 2, 1}, 20, 2, 512, fs, ef, af, ff, v, fw)); // neox (stablelm)
53455350
test_cases.emplace_back(new test_rope(type, { 80, 32, 2, 1}, 32, 2, 512, fs, ef, af, ff, v, fw)); // neox (phi-2)
5351+
test_cases.emplace_back(new test_rope(type, { 80, 32, 4, 1}, 32, 2, 512, fs, ef, af, ff, v, fw)); // neox (phi-2)
53465352
}
53475353

53485354
if (all) {
53495355
test_cases.emplace_back(new test_rope(type, {128, 12, 2, 1}, 128, GGML_ROPE_TYPE_MROPE, 512, fs, ef, af, ff, v, fw)); // rope_multi,m-rope (qwen2vl 2B)
53505356
test_cases.emplace_back(new test_rope(type, {128, 28, 2, 1}, 128, GGML_ROPE_TYPE_MROPE, 512, fs, ef, af, ff, v, fw)); // rope_multi,m-rope (qwen2vl 7B)
5357+
test_cases.emplace_back(new test_rope(type, {128, 12, 2, 1}, 20, GGML_ROPE_TYPE_MROPE, 512, fs, ef, af, ff, v, fw));
5358+
test_cases.emplace_back(new test_rope(type, {128, 28, 2, 1}, 32, GGML_ROPE_TYPE_MROPE, 512, fs, ef, af, ff, v, fw));
53515359
test_cases.emplace_back(new test_rope(type, { 80, 16, 2, 1}, 80, GGML_ROPE_TYPE_VISION, 512, fs, ef, af, ff, v, fw)); // rope_multi,m-rope (qwen2vl ViT)
53525360
}
53535361

0 commit comments

Comments
 (0)