File tree Expand file tree Collapse file tree 3 files changed +7
-6
lines changed Expand file tree Collapse file tree 3 files changed +7
-6
lines changed Original file line number Diff line number Diff line change @@ -454,6 +454,8 @@ typedef struct {
454
454
int64_t ne00 ;
455
455
int64_t ne01 ;
456
456
int64_t ne02 ;
457
+ uint64_t nb11 ;
458
+ uint64_t nb12 ;
457
459
float scale ;
458
460
float max_bias ;
459
461
float m0 ;
Original file line number Diff line number Diff line change @@ -2562,10 +2562,7 @@ static bool ggml_metal_encode_node(
2562
2562
memcpy (&scale, ((const int32_t *) dst->op_params ) + 0 , sizeof (scale));
2563
2563
memcpy (&max_bias, ((const int32_t *) dst->op_params ) + 1 , sizeof (max_bias));
2564
2564
2565
- const int64_t nrows_x = ggml_nrows (src0);
2566
- const int64_t nrows_y = src0->ne [1 ];
2567
-
2568
- const uint32_t n_head = nrows_x/nrows_y;
2565
+ const uint32_t n_head = src0->ne [2 ];
2569
2566
const uint32_t n_head_log2 = 1u << (uint32_t ) floorf (log2f ((float ) n_head));
2570
2567
2571
2568
const float m0 = powf (2 .0f , -(max_bias ) / n_head_log2);
@@ -2625,6 +2622,8 @@ static bool ggml_metal_encode_node(
2625
2622
/* .ne00 =*/ ne00,
2626
2623
/* .ne01 =*/ ne01,
2627
2624
/* .ne02 =*/ ne02,
2625
+ /* .nb11 =*/ nb11,
2626
+ /* .nb12 =*/ nb12,
2628
2627
/* .scale =*/ scale,
2629
2628
/* .max_bias =*/ max_bias,
2630
2629
/* .m0 =*/ m0,
Original file line number Diff line number Diff line change @@ -1263,7 +1263,7 @@ kernel void kernel_soft_max(
1263
1263
const int64_t i01 = (tgpig - i03*args.ne02 *args.ne01 - i02*args.ne01 );
1264
1264
1265
1265
device const float * psrc0 = (device const float *) src0 + (i03*args.ne02 *args.ne01 *args.ne00 + i02*args.ne01 *args.ne00 + i01*args.ne00 );
1266
- device const T * pmask = src1 != src0 ? (device const T *) src1 + i01 *args.ne00 : nullptr ;
1266
+ device const T * pmask = src1 != src0 ? (device const T *) ( src1 + i01*args. nb11 + i03 *args.nb12 ) : nullptr ;
1267
1267
device float * pdst = (device float *) dst + (i03*args.ne02 *args.ne01 *args.ne00 + i02*args.ne01 *args.ne00 + i01*args.ne00 );
1268
1268
1269
1269
float slope = 1 .0f ;
@@ -1359,7 +1359,7 @@ kernel void kernel_soft_max_4(
1359
1359
const int64_t i01 = (tgpig - i03*args.ne02 *args.ne01 - i02*args.ne01 );
1360
1360
1361
1361
device const float4 * psrc4 = (device const float4 *) src0 + (i03*args.ne02 *args.ne01 *args.ne00 + i02*args.ne01 *args.ne00 + i01*args.ne00 )/4 ;
1362
- device const T * pmask = src1 != src0 ? (device const T *) src1 + i01 *args.ne00 / 4 : nullptr ;
1362
+ device const T * pmask = src1 != src0 ? (device const T *) ( src1 + i01*args. nb11 + i03 *args.nb12 ) : nullptr ;
1363
1363
device float4 * pdst4 = (device float4 *) dst + (i03*args.ne02 *args.ne01 *args.ne00 + i02*args.ne01 *args.ne00 + i01*args.ne00 )/4 ;
1364
1364
1365
1365
float slope = 1 .0f ;
You can’t perform that action at this time.
0 commit comments