Skip to content

Commit 7c6487b

Browse files
committed
metal : extend ggml_soft_max_ext() to support n_seq dim
1 parent 401c13e commit 7c6487b

File tree

3 files changed

+7
-6
lines changed

3 files changed

+7
-6
lines changed

ggml/src/ggml-metal/ggml-metal-impl.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -454,6 +454,8 @@ typedef struct {
454454
int64_t ne00;
455455
int64_t ne01;
456456
int64_t ne02;
457+
uint64_t nb11;
458+
uint64_t nb12;
457459
float scale;
458460
float max_bias;
459461
float m0;

ggml/src/ggml-metal/ggml-metal.m

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2562,10 +2562,7 @@ static bool ggml_metal_encode_node(
25622562
memcpy(&scale, ((const int32_t *) dst->op_params) + 0, sizeof(scale));
25632563
memcpy(&max_bias, ((const int32_t *) dst->op_params) + 1, sizeof(max_bias));
25642564

2565-
const int64_t nrows_x = ggml_nrows(src0);
2566-
const int64_t nrows_y = src0->ne[1];
2567-
2568-
const uint32_t n_head = nrows_x/nrows_y;
2565+
const uint32_t n_head = src0->ne[2];
25692566
const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
25702567

25712568
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
@@ -2625,6 +2622,8 @@ static bool ggml_metal_encode_node(
26252622
/*.ne00 =*/ ne00,
26262623
/*.ne01 =*/ ne01,
26272624
/*.ne02 =*/ ne02,
2625+
/*.nb11 =*/ nb11,
2626+
/*.nb12 =*/ nb12,
26282627
/*.scale =*/ scale,
26292628
/*.max_bias =*/ max_bias,
26302629
/*.m0 =*/ m0,

ggml/src/ggml-metal/ggml-metal.metal

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1263,7 +1263,7 @@ kernel void kernel_soft_max(
12631263
const int64_t i01 = (tgpig - i03*args.ne02*args.ne01 - i02*args.ne01);
12641264

12651265
device const float * psrc0 = (device const float *) src0 + (i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00);
1266-
device const T * pmask = src1 != src0 ? (device const T *) src1 + i01*args.ne00 : nullptr;
1266+
device const T * pmask = src1 != src0 ? (device const T *) (src1 + i01*args.nb11 + i03*args.nb12) : nullptr;
12671267
device float * pdst = (device float *) dst + (i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00);
12681268

12691269
float slope = 1.0f;
@@ -1359,7 +1359,7 @@ kernel void kernel_soft_max_4(
13591359
const int64_t i01 = (tgpig - i03*args.ne02*args.ne01 - i02*args.ne01);
13601360

13611361
device const float4 * psrc4 = (device const float4 *) src0 + (i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00)/4;
1362-
device const T * pmask = src1 != src0 ? (device const T *) src1 + i01*args.ne00/4 : nullptr;
1362+
device const T * pmask = src1 != src0 ? (device const T *) (src1 + i01*args.nb11 + i03*args.nb12) : nullptr;
13631363
device float4 * pdst4 = (device float4 *) dst + (i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00)/4;
13641364

13651365
float slope = 1.0f;

0 commit comments

Comments
 (0)