Skip to content

Commit 48f1061

Browse files
committed
metal : opt add
ggml-ci
1 parent 1881217 commit 48f1061

File tree

2 files changed

+15
-8
lines changed

2 files changed

+15
-8
lines changed

ggml/src/ggml-metal/ggml-metal.m

Lines changed: 8 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -2095,6 +2095,9 @@ static int ggml_metal_encode_node(
 GGML_ASSERT(src0t == GGML_TYPE_F32);
 GGML_ASSERT(src1t == GGML_TYPE_F32);
 
+GGML_ASSERT(ggml_is_contiguous_rows(src0));
+GGML_ASSERT(ggml_is_contiguous_rows(src1));
+
 const size_t offs = 0;
 
 bool bcast_row = false;
@@ -2234,7 +2237,11 @@ static int ggml_metal_encode_node(
 
 [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
 } else {
-const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0);
+int nth = 32;
+
+while (16*nth < ne0 && nth < (int) pipeline.maxTotalThreadsPerThreadgroup) {
+nth *= 2;
+}
 
 [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
 }

ggml/src/ggml-metal/ggml-metal.metal

Lines changed: 7 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -849,25 +849,25 @@ kernel void kernel_add_fuse_impl(
 const int i12 = i02%args.ne12;
 const int i11 = i01%args.ne11;
 
-device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs;
-device char * dst_ptr = dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 + args.offs;
+device const float * src0_ptr = (device const float *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs);
+device float * dst_ptr = (device float *) (dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 + args.offs);
 
-device const char * src1_ptr[F];
+device const float * src1_ptr[F];
 for (short j = 0; j < F; ++j) {
-src1_ptr[j] = src1 + args.o1[j] + i13*args.nb13 + i12*args.nb12 + i11*args.nb11;
+src1_ptr[j] = (device const float *) (src1 + args.o1[j] + i13*args.nb13 + i12*args.nb12 + i11*args.nb11);
 }
 
 for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
 const int i10 = i0%args.ne10;
 
-float res = *((device float *)(src0_ptr + i0*args.nb00));
+float res = src0_ptr[i0];
 
 #pragma unroll
 for (short j = 0; j < F; ++j) {
-res += *((device float *)(src1_ptr[j] + i10*args.nb10));
+res += src1_ptr[j][i10];
 }
 
-*((device float *)(dst_ptr + i0*args.nb0)) = res;
+dst_ptr[i0] = res;
 }
 }
 
0 commit comments

Comments
 (0)