
Commit 39eba35

implement swapped variants (cpu/cuda)
1 parent e3d2b20 commit 39eba35

7 files changed, +117 −45 lines

ggml/include/ggml.h

Lines changed: 15 additions & 1 deletion
@@ -1100,23 +1100,37 @@ extern "C" {
     // gated linear unit ops
     // A: n columns, r rows,
     // result is n / 2 columns, r rows,
+    // expects gate in second half of row, unless swapped is true
     GGML_API struct ggml_tensor * ggml_glu(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
-            enum ggml_glu_op      op);
+            enum ggml_glu_op      op,
+            bool                  swapped);

     GGML_API struct ggml_tensor * ggml_reglu(
             struct ggml_context * ctx,
             struct ggml_tensor  * a);

+    GGML_API struct ggml_tensor * ggml_reglu_swapped(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
     GGML_API struct ggml_tensor * ggml_geglu(
             struct ggml_context * ctx,
             struct ggml_tensor  * a);

+    GGML_API struct ggml_tensor * ggml_geglu_swapped(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
     GGML_API struct ggml_tensor * ggml_swiglu(
             struct ggml_context * ctx,
             struct ggml_tensor  * a);

+    GGML_API struct ggml_tensor * ggml_swiglu_swapped(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
     // normalize along rows
     GGML_API struct ggml_tensor * ggml_norm(
             struct ggml_context * ctx,
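The header change above is the whole public surface of this commit: ggml_glu gains a swapped flag and each gated op gets a _swapped convenience wrapper. A minimal usage sketch follows; the function name and tensor layout are illustrative assumptions, not part of the diff.

// Hypothetical usage sketch (only the ggml_* calls come from this commit).
// Here each row of `up_gate` is assumed to hold the gate in the FIRST half.
#include "ggml.h"

struct ggml_tensor * build_glu(struct ggml_context * ctx, struct ggml_tensor * up_gate) {
    // up_gate: n columns, r rows; result: n / 2 columns, r rows
    return ggml_swiglu_swapped(ctx, up_gate);   // gate read from first half of each row
    // return ggml_swiglu(ctx, up_gate);        // default: gate in second half
}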

ggml/src/ggml-cpu/ops.cpp

Lines changed: 24 additions & 6 deletions
@@ -3214,6 +3214,8 @@ static void ggml_compute_forward_reglu_f32(
     GGML_ASSERT(dst->ne[0] == nc);
     GGML_ASSERT(ggml_nrows(dst) == nr);

+    const int32_t swapped = ggml_get_op_params_i32(dst, 1);
+
     // rows per thread
     const int dr = (nr + nth - 1)/nth;

@@ -3224,7 +3226,8 @@ static void ggml_compute_forward_reglu_f32(
     for (int i1 = ir0; i1 < ir1; i1++) {
         ggml_vec_reglu_f32(nc,
                 (float *) ((char *) dst->data  + i1*( dst->nb[1])),
-                (float *) ((char *) src0->data + i1*(src0->nb[1])));
+                (float *) ((char *) src0->data + i1*(src0->nb[1])) + (swapped ? nc : 0),
+                (float *) ((char *) src0->data + i1*(src0->nb[1])) + (swapped ? 0 : nc));

 #ifndef NDEBUG
         for (int k = 0; k < nc; k++) {
@@ -3255,6 +3258,8 @@ static void ggml_compute_forward_reglu_f16(
     GGML_ASSERT(dst->ne[0] == nc);
     GGML_ASSERT(ggml_nrows(dst) == nr);

+    const int32_t swapped = ggml_get_op_params_i32(dst, 1);
+
     // rows per thread
     const int dr = (nr + nth - 1)/nth;

@@ -3265,7 +3270,8 @@ static void ggml_compute_forward_reglu_f16(
     for (int i1 = ir0; i1 < ir1; i1++) {
         ggml_vec_reglu_f16(nc,
                 (ggml_fp16_t *) ((char *) dst->data  + i1*( dst->nb[1])),
-                (ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1])));
+                (ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1])) + (swapped ? nc : 0),
+                (ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1])) + (swapped ? 0 : nc));

 #ifndef NDEBUG
         for (int k = 0; k < nc; k++) {
@@ -3321,6 +3327,8 @@ static void ggml_compute_forward_geglu_f32(
     GGML_ASSERT(dst->ne[0] == nc);
     GGML_ASSERT(ggml_nrows(dst) == nr);

+    const int32_t swapped = ggml_get_op_params_i32(dst, 1);
+
     // rows per thread
     const int dr = (nr + nth - 1)/nth;

@@ -3331,7 +3339,8 @@ static void ggml_compute_forward_geglu_f32(
     for (int i1 = ir0; i1 < ir1; i1++) {
         ggml_vec_geglu_f32(nc,
                 (float *) ((char *) dst->data  + i1*( dst->nb[1])),
-                (float *) ((char *) src0->data + i1*(src0->nb[1])));
+                (float *) ((char *) src0->data + i1*(src0->nb[1])) + (swapped ? nc : 0),
+                (float *) ((char *) src0->data + i1*(src0->nb[1])) + (swapped ? 0 : nc));

 #ifndef NDEBUG
         for (int k = 0; k < nc; k++) {
@@ -3362,6 +3371,8 @@ static void ggml_compute_forward_geglu_f16(
     GGML_ASSERT(dst->ne[0] == nc);
     GGML_ASSERT(ggml_nrows(dst) == nr);

+    const int32_t swapped = ggml_get_op_params_i32(dst, 1);
+
     // rows per thread
     const int dr = (nr + nth - 1)/nth;

@@ -3372,7 +3383,8 @@ static void ggml_compute_forward_geglu_f16(
     for (int i1 = ir0; i1 < ir1; i1++) {
         ggml_vec_geglu_f16(nc,
                 (ggml_fp16_t *) ((char *) dst->data  + i1*( dst->nb[1])),
-                (ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1])));
+                (ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1])) + (swapped ? nc : 0),
+                (ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1])) + (swapped ? 0 : nc));

 #ifndef NDEBUG
         for (int k = 0; k < nc; k++) {
@@ -3428,6 +3440,8 @@ static void ggml_compute_forward_swiglu_f32(
     GGML_ASSERT(dst->ne[0] == nc);
     GGML_ASSERT(ggml_nrows(dst) == nr);

+    const int32_t swapped = ggml_get_op_params_i32(dst, 1);
+
     // rows per thread
     const int dr = (nr + nth - 1)/nth;

@@ -3438,7 +3452,8 @@ static void ggml_compute_forward_swiglu_f32(
     for (int i1 = ir0; i1 < ir1; i1++) {
         ggml_vec_swiglu_f32(nc,
                 (float *) ((char *) dst->data  + i1*( dst->nb[1])),
-                (float *) ((char *) src0->data + i1*(src0->nb[1])));
+                (float *) ((char *) src0->data + i1*(src0->nb[1])) + (swapped ? nc : 0),
+                (float *) ((char *) src0->data + i1*(src0->nb[1])) + (swapped ? 0 : nc));

 #ifndef NDEBUG
         for (int k = 0; k < nc; k++) {
@@ -3469,6 +3484,8 @@ static void ggml_compute_forward_swiglu_f16(
     GGML_ASSERT(dst->ne[0] == nc);
     GGML_ASSERT(ggml_nrows(dst) == nr);

+    const int32_t swapped = ggml_get_op_params_i32(dst, 1);
+
     // rows per thread
     const int dr = (nr + nth - 1)/nth;

@@ -3479,7 +3496,8 @@ static void ggml_compute_forward_swiglu_f16(
     for (int i1 = ir0; i1 < ir1; i1++) {
         ggml_vec_swiglu_f16(nc,
                 (ggml_fp16_t *) ((char *) dst->data  + i1*( dst->nb[1])),
-                (ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1])));
+                (ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1])) + (swapped ? nc : 0),
+                (ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1])) + (swapped ? 0 : nc));

 #ifndef NDEBUG
         for (int k = 0; k < nc; k++) {
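Every forward function above repeats the same two-step pattern: read the swapped flag from op-param slot 1, then offset the source row pointer so the activation half and the gate half can be passed separately. A standalone sketch of just that pointer selection; the function and parameter names are illustrative, not identifiers from the commit.

// Sketch only: `row` points at one source row of 2*nc values.
static void select_halves(const float * row, int nc, int swapped,
                          const float ** act, const float ** gate) {
    *act  = row + (swapped ? nc : 0);   // values fed through the base op
    *gate = row + (swapped ? 0 : nc);   // values multiplied in afterwards
}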

ggml/src/ggml-cpu/vec.cpp

Lines changed: 6 additions & 6 deletions
@@ -254,27 +254,27 @@ void ggml_vec_silu_f32(const int n, float * y, const float * x) {
     }
 }

-void ggml_vec_swiglu_f32(const int n, float * y, const float * x) {
+void ggml_vec_swiglu_f32(const int n, float * y, const float * x, const float * g) {
     int i = 0;
 #if defined(__AVX512F__) && defined(__AVX512DQ__)
     for (; i + 15 < n; i += 16) {
-        _mm512_storeu_ps(y + i, _mm512_mul_ps(ggml_v_silu(_mm512_loadu_ps(x + i)), _mm512_loadu_ps(x + i + n)));
+        _mm512_storeu_ps(y + i, _mm512_mul_ps(ggml_v_silu(_mm512_loadu_ps(x + i)), _mm512_loadu_ps(g + i)));
     }
 #elif defined(__AVX2__) && defined(__FMA__)
     for (; i + 7 < n; i += 8) {
-        _mm256_storeu_ps(y + i, _mm256_mul_ps(ggml_v_silu(_mm256_loadu_ps(x + i)), _mm256_loadu_ps(x + i + n)));
+        _mm256_storeu_ps(y + i, _mm256_mul_ps(ggml_v_silu(_mm256_loadu_ps(x + i)), _mm256_loadu_ps(g + i)));
     }
 #elif defined(__SSE2__)
     for (; i + 3 < n; i += 4) {
-        _mm_storeu_ps(y + i, _mm_mul_ps(ggml_v_silu(_mm_loadu_ps(x + i)), _mm_loadu_ps(x + i + n)));
+        _mm_storeu_ps(y + i, _mm_mul_ps(ggml_v_silu(_mm_loadu_ps(x + i)), _mm_loadu_ps(g + i)));
     }
 #elif defined(__ARM_NEON) && defined(__aarch64__)
     for (; i + 3 < n; i += 4) {
-        vst1q_f32(y + i, vmulq_f32(ggml_v_silu(vld1q_f32(x + i)), vld1q_f32(x + i + n)));
+        vst1q_f32(y + i, vmulq_f32(ggml_v_silu(vld1q_f32(x + i)), vld1q_f32(g + i)));
     }
 #endif
     for (; i < n; ++i) {
-        y[i] = ggml_silu_f32(x[i]) * x[i + n];
+        y[i] = ggml_silu_f32(x[i]) * g[i];
     }
 }
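After this change the function computes y[i] = silu(x[i]) * g[i], with x and g supplied as independent pointers instead of two halves of one buffer. A plain-C reference of the same math, useful for checking the SIMD branches; silu_ref and swiglu_ref are local stand-ins for ggml_silu_f32 and the vector op, not names from the commit.

#include <math.h>

static float silu_ref(float v) {
    return v / (1.0f + expf(-v));
}

// Matches the scalar tail loop of ggml_vec_swiglu_f32 above.
static void swiglu_ref(int n, float * y, const float * x, const float * g) {
    for (int i = 0; i < n; ++i) {
        y[i] = silu_ref(x[i]) * g[i];
    }
}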

ggml/src/ggml-cpu/vec.h

Lines changed: 16 additions & 16 deletions
@@ -905,57 +905,57 @@ inline static void ggml_vec_silu_backward_f16(const int n, ggml_fp16_t * dx, con
     }
 }

-inline static void ggml_vec_reglu_f32 (const int n, float * y, const float * x) {
+inline static void ggml_vec_reglu_f32 (const int n, float * y, const float * x, const float * g) {
     for (int i = 0; i < n; ++i) {
-        y[i] = (x[i] > 0.f) ? x[i] * x[i + n] : 0.f;
+        y[i] = (x[i] > 0.f) ? x[i] * g[i] : 0.f;
     }
 }

-inline static void ggml_vec_reglu_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
+inline static void ggml_vec_reglu_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x, const ggml_fp16_t * g) {
     for (int i = 0; i < n; ++i) {
         float v = GGML_FP16_TO_FP32(x[i]);
-        y[i] = GGML_FP32_TO_FP16((v > 0.f) ? v * GGML_FP16_TO_FP32(x[i + n]) : 0.f);
+        y[i] = GGML_FP32_TO_FP16((v > 0.f) ? v * GGML_FP16_TO_FP32(g[i]) : 0.f);
     }
 }

 #ifdef GGML_GELU_FP16
-inline static void ggml_vec_geglu_f32(const int n, float * y, const float * x) {
+inline static void ggml_vec_geglu_f32(const int n, float * y, const float * x, const float * g) {
     uint16_t t;
     for (int i = 0; i < n; ++i) {
         if (x[i] <= -10.0f) {
             y[i] = 0.0f;
         } else if (x[i] >= 10.0f) {
-            y[i] = x[i] * x[i + n];
+            y[i] = x[i] * g[i];
         } else {
             ggml_fp16_t fp16 = GGML_FP32_TO_FP16(x[i]);
             memcpy(&t, &fp16, sizeof(uint16_t));
-            y[i] = GGML_FP16_TO_FP32(ggml_table_gelu_f16[t]) * x[i + n];
+            y[i] = GGML_FP16_TO_FP32(ggml_table_gelu_f16[t]) * g[i];
         }
     }
 }
 #else
-inline static void ggml_vec_geglu_f32(const int n, float * y, const float * x) {
+inline static void ggml_vec_geglu_f32(const int n, float * y, const float * x, const float * g) {
     for (int i = 0; i < n; ++i) {
-        y[i] = ggml_gelu_f32(x[i]) * x[i + n];
+        y[i] = ggml_gelu_f32(x[i]) * g[i];
     }
 }
 #endif

-inline static void ggml_vec_geglu_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
+inline static void ggml_vec_geglu_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x, const ggml_fp16_t * g) {
     const uint16_t * i16 = (const uint16_t *) x;
     for (int i = 0; i < n; ++i) {
-        float g = GGML_FP16_TO_FP32(x[i + n]);
-        y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(ggml_table_gelu_f16[i16[i]]) * g);
+        float v = GGML_FP16_TO_FP32(g[i]);
+        y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(ggml_table_gelu_f16[i16[i]]) * v);
     }
 }

-void ggml_vec_swiglu_f32(const int n, float * y, const float * x);
+void ggml_vec_swiglu_f32(const int n, float * y, const float * x, const float * g);

-inline static void ggml_vec_swiglu_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
+inline static void ggml_vec_swiglu_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x, const ggml_fp16_t * g) {
     for (int i = 0; i < n; ++i) {
         float v = GGML_FP16_TO_FP32(x[i]);
-        float g = GGML_FP16_TO_FP32(x[i + n]);
-        y[i] = GGML_FP32_TO_FP16((v/(1.0f + expf(-v))) * g);
+        float w = GGML_FP16_TO_FP32(g[i]);
+        y[i] = GGML_FP32_TO_FP16((v/(1.0f + expf(-v))) * w);
     }
 }
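All of these helpers now share one contract: x and g are separate pointers of n elements each, so the caller decides which half of a source row plays which role. A small sketch of wiring both layouts through ggml_vec_reglu_f32, assuming the internal ggml-cpu/vec.h header is in scope; the buffer contents and names are made up for illustration.

// Assumes "vec.h" (internal CPU header) has been included.
enum { NC = 4 };

void reglu_both_layouts(float * out) {
    float row[2 * NC] = {
        1.0f, -2.0f, 3.0f, -4.0f,   // first half of the row
        0.5f,  0.5f, 0.5f,  0.5f,   // second half of the row
    };
    ggml_vec_reglu_f32(NC, out, row,      row + NC); // default: gate in second half
    ggml_vec_reglu_f32(NC, out, row + NC, row);      // swapped: gate in first half
}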

ggml/src/ggml-cuda/unary.cu

Lines changed: 22 additions & 6 deletions
@@ -199,7 +199,7 @@ void ggml_cuda_op_log(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
 /* gated ops */

 template <float (*op)(float), typename T>
-static __global__ void unary_gated_op_kernel(const T * x, T * dst, const int64_t k, const int64_t n, const int64_t o) {
+static __global__ void unary_gated_op_kernel(const T * x, const T * g, T * dst, const int64_t k, const int64_t n, const int64_t o) {
     const int64_t i = int64_t(blockDim.x)*blockIdx.x + threadIdx.x;

     if (i >= k) {
@@ -208,13 +208,13 @@ static __global__ void unary_gated_op_kernel(const T * x, T * dst, const int64_t

     // perform base op on first half of row and multiply with gate in second half
     const int64_t j = (i / n) * o + (i % n);
-    dst[i] = (T)(op((float)x[j]) * (float)x[j + n]);
+    dst[i] = (T)(op((float)x[j]) * (float)g[j]);
 }

 template <float (*op)(float), typename T>
-static void unary_gated_cuda(const T * x, T * dst, const int64_t k, const int64_t n, const int64_t o, cudaStream_t stream) {
+static void unary_gated_cuda(const T * x, const T * g, T * dst, const int64_t k, const int64_t n, const int64_t o, cudaStream_t stream) {
     const int64_t num_blocks = (k + CUDA_GLU_BLOCK_SIZE - 1) / CUDA_GLU_BLOCK_SIZE;
-    unary_gated_op_kernel<op><<<num_blocks, CUDA_GLU_BLOCK_SIZE, 0, stream>>>(x, dst, k, n, o);
+    unary_gated_op_kernel<op><<<num_blocks, CUDA_GLU_BLOCK_SIZE, 0, stream>>>(x, g, dst, k, n, o);
 }

 template <float (*op)(float)>
@@ -235,10 +235,26 @@ void ggml_cuda_op_unary_gated(ggml_backend_cuda_context & ctx, ggml_tensor * dst
     GGML_ASSERT(dst->ne[0] == nc);
     GGML_ASSERT(ggml_nrows(dst) == ggml_nrows(src0));

+    const int32_t swapped = ((const int32_t *) dst->op_params)[1];
+
     if (src0->type == GGML_TYPE_F16) {
-        unary_gated_cuda<op>((const half *)src0_d, (half *)dst_d, ggml_nelements(dst), nc, src0->nb[1] / sizeof(half), stream);
+        unary_gated_cuda<op>(
+            (const half *)src0_d + (swapped ? nc : 0),
+            (const half *)src0_d + (swapped ? 0 : nc),
+            (half *)dst_d,
+            ggml_nelements(dst),
+            nc,
+            src0->nb[1] / sizeof(half),
+            stream);
     } else {
-        unary_gated_cuda<op>((const float *)src0_d, (float *)dst_d, ggml_nelements(dst), nc, src0->nb[1] / sizeof(float), stream);
+        unary_gated_cuda<op>(
+            (const float *)src0_d + (swapped ? nc : 0),
+            (const float *)src0_d + (swapped ? 0 : nc),
+            (float *)dst_d,
+            ggml_nelements(dst),
+            nc,
+            src0->nb[1] / sizeof(float),
+            stream);
     }
 }
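The CUDA path handles swapped entirely on the host: the two base pointers are offset before launch, and the kernel's index math is unchanged. A host-side C sketch of that mapping, with k = output elements, n = output row width (nc), o = source row stride in elements, using SwiGLU as the example op; the function name is made up, nothing here is new API.

#include <math.h>

static void gated_rows_ref(const float * src, float * dst,
                           long k, long n, long o, int swapped) {
    const float * x = src + (swapped ? n : 0);   // base-op half of each row
    const float * g = src + (swapped ? 0 : n);   // gate half of each row
    for (long i = 0; i < k; ++i) {
        long  j = (i / n) * o + (i % n);         // same row/column as dst[i]
        float v = x[j];
        dst[i]  = (v / (1.0f + expf(-v))) * g[j];
    }
}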

ggml/src/ggml.c

Lines changed: 24 additions & 4 deletions
@@ -2647,13 +2647,15 @@ struct ggml_tensor * ggml_exp_inplace(
 struct ggml_tensor * ggml_glu(
         struct ggml_context * ctx,
         struct ggml_tensor  * a,
-        enum ggml_glu_op      op) {
+        enum ggml_glu_op      op,
+        bool                  swapped) {
     GGML_ASSERT(ggml_is_contiguous_1(a));

     int64_t ne[GGML_MAX_DIMS] = { a->ne[0] / 2 }; for (int i = 1; i < GGML_MAX_DIMS; i++) ne[i] = a->ne[i];
     struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, GGML_MAX_DIMS, ne, NULL, 0);

     ggml_set_op_params_i32(result, 0, (int32_t) op);
+    ggml_set_op_params_i32(result, 1, (int32_t) swapped);

     result->op     = GGML_OP_GLU;
     result->src[0] = a;
@@ -2666,23 +2668,41 @@ struct ggml_tensor * ggml_glu(
 struct ggml_tensor * ggml_reglu(
         struct ggml_context * ctx,
         struct ggml_tensor  * a) {
-    return ggml_glu(ctx, a, GGML_GLU_OP_REGLU);
+    return ggml_glu(ctx, a, GGML_GLU_OP_REGLU, false);
+}
+
+struct ggml_tensor * ggml_reglu_swapped(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a) {
+    return ggml_glu(ctx, a, GGML_GLU_OP_REGLU, true);
 }

 // ggml_geglu

 struct ggml_tensor * ggml_geglu(
         struct ggml_context * ctx,
         struct ggml_tensor  * a) {
-    return ggml_glu(ctx, a, GGML_GLU_OP_GEGLU);
+    return ggml_glu(ctx, a, GGML_GLU_OP_GEGLU, false);
+}
+
+struct ggml_tensor * ggml_geglu_swapped(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a) {
+    return ggml_glu(ctx, a, GGML_GLU_OP_GEGLU, true);
 }

 // ggml_swiglu

 struct ggml_tensor * ggml_swiglu(
         struct ggml_context * ctx,
         struct ggml_tensor  * a) {
-    return ggml_glu(ctx, a, GGML_GLU_OP_SWIGLU);
+    return ggml_glu(ctx, a, GGML_GLU_OP_SWIGLU, false);
+}
+
+struct ggml_tensor * ggml_swiglu_swapped(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a) {
+    return ggml_glu(ctx, a, GGML_GLU_OP_SWIGLU, true);
 }

 // ggml_norm
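The graph-building side stores both settings in op_params: slot 0 holds the ggml_glu_op, slot 1 the swapped flag, which is exactly what the CPU and CUDA backends read back above. A sketch of the reverse lookup, assuming the internal op-param helpers from ggml-impl.h are visible; the function name is illustrative.

#include <stdbool.h>

static void read_glu_params(const struct ggml_tensor * dst,
                            enum ggml_glu_op * op, bool * swapped) {
    *op      = (enum ggml_glu_op) ggml_get_op_params_i32(dst, 0);
    *swapped = ggml_get_op_params_i32(dst, 1) != 0;
}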
