
Commit 830e554

Merge branch 'master' into compilade/mamba2
2 parents: f8c7cae + edc4a29

17 files changed: +1158 −486 lines

ggml/src/ggml-metal/ggml-metal.m

Lines changed: 28 additions & 5 deletions
@@ -499,6 +499,7 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
     GGML_METAL_KERNEL_TYPE_COS,
     GGML_METAL_KERNEL_TYPE_NEG,
     GGML_METAL_KERNEL_TYPE_SUM_ROWS,
+    GGML_METAL_KERNEL_TYPE_MEAN,
     GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32,
     GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32,
     GGML_METAL_KERNEL_TYPE_ARGMAX,
@@ -1456,6 +1457,7 @@ @implementation GGMLMetalClass
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_COS,      cos,      true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NEG,      neg,      true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MEAN,     mean,     true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGMAX,   argmax,   true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32, pool_2d_avg_f32, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32, pool_2d_max_f32, true);
@@ -1655,6 +1657,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
        case GGML_OP_LOG:
            return false; // TODO: implement
        case GGML_OP_SUM_ROWS:
+       case GGML_OP_MEAN:
        case GGML_OP_SOFT_MAX:
        case GGML_OP_GROUP_NORM:
            return has_simdgroup_reduction && ggml_is_contiguous(op->src[0]);
@@ -2402,11 +2405,30 @@ static bool ggml_metal_encode_node(
                [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
            } break;
        case GGML_OP_SUM_ROWS:
+       case GGML_OP_MEAN:
            {
                GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));

-               id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUM_ROWS].pipeline;
+               id<MTLComputePipelineState> pipeline = nil;
+
+               switch (dst->op) {
+                   case GGML_OP_SUM_ROWS:
+                       pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUM_ROWS].pipeline;
+                       break;
+                   case GGML_OP_MEAN:
+                       pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MEAN].pipeline;
+                       break;
+                   default:
+                       GGML_ABORT("fatal error");
+               }
+
+               int nth = 32; // SIMD width
+
+               while (nth < ne00 && nth < (int) pipeline.maxTotalThreadsPerThreadgroup) {
+                   nth *= 2;
+               }

+               nth = MIN(nth, ne00);

                ggml_metal_kargs_sum_rows args = {
                    /*.ne00 =*/ ne00,
@@ -2436,11 +2458,12 @@ static bool ggml_metal_encode_node(
                };

                [encoder setComputePipelineState:pipeline];
-               [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-               [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
-               [encoder setBytes:&args length:sizeof(args) atIndex:2];
+               [encoder setBytes:&args length:sizeof(args) atIndex:0];
+               [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
+               [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
+               [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];

-               [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+               [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
            } break;
        case GGML_OP_SOFT_MAX:
            {
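For context on what the new kernel backs: GGML_OP_MEAN reduces along dim 0 exactly like GGML_OP_SUM_ROWS and then divides by ne00, so one templated Metal kernel can serve both. A minimal sketch using the public ggml C API (not part of this commit; the buffer size and example function are illustrative, and backend setup/graph execution are omitted):

#include "ggml.h"

// Sketch only: builds a tiny graph containing the two ops served by the new
// templated Metal kernel. Sizes and the function name are illustrative.
static void mean_vs_sum_rows_example(void) {
    struct ggml_init_params params = {
        /*.mem_size   =*/ 16*1024*1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(params);

    // 4 rows of 8 elements (ne00 = 8, ne01 = 4)
    struct ggml_tensor * a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 8, 4);

    struct ggml_tensor * s = ggml_sum_rows(ctx, a); // per-row sum,  shape [1, 4]
    struct ggml_tensor * m = ggml_mean(ctx, a);     // per-row mean, shape [1, 4]

    struct ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, s);
    ggml_build_forward_expand(gf, m);

    // on the Metal backend both nodes now dispatch kernel_sum_rows<norm>,
    // with norm = true selecting the kernel_mean instantiation
    ggml_free(ctx);
}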

ggml/src/ggml-metal/ggml-metal.metal

Lines changed: 39 additions & 9 deletions
@@ -993,31 +993,61 @@ kernel void kernel_neg(
     dst[tpig] = -src0[tpig];
 }

+template <bool norm>
 kernel void kernel_sum_rows(
+        constant ggml_metal_kargs_sum_rows & args,
         device const float * src0,
         device       float * dst,
-        constant ggml_metal_kargs_sum_rows & args,
-        uint3 tpig[[thread_position_in_grid]]) {
-    int64_t i3 = tpig.z;
-    int64_t i2 = tpig.y;
-    int64_t i1 = tpig.x;
+        threadgroup  float * shmem_f32 [[threadgroup(0)]],
+        uint3   tgpig[[threadgroup_position_in_grid]],
+        ushort3 tpitg[[thread_position_in_threadgroup]],
+        ushort  sgitg[[simdgroup_index_in_threadgroup]],
+        ushort  tiisg[[thread_index_in_simdgroup]],
+        ushort3   ntg[[threads_per_threadgroup]]) {
+    int64_t i3 = tgpig.z;
+    int64_t i2 = tgpig.y;
+    int64_t i1 = tgpig.x;

     if (i3 >= args.ne03 || i2 >= args.ne02 || i1 >= args.ne01) {
         return;
     }

+    if (sgitg == 0) {
+        shmem_f32[tiisg] = 0.0f;
+    }
+
     device const float * src_row = (device const float *) ((device const char *) src0 + i1*args.nb01 + i2*args.nb02 + i3*args.nb03);
     device       float * dst_row = (device       float *) ((device       char *) dst  + i1*args.nb1  + i2*args.nb2  + i3*args.nb3);

-    float row_sum = 0;
+    float sumf = 0;

-    for (int64_t i0 = 0; i0 < args.ne00; i0++) {
-        row_sum += src_row[i0];
+    for (int64_t i0 = tpitg.x; i0 < args.ne00; i0 += ntg.x) {
+        sumf += src_row[i0];
     }

-    dst_row[0] = row_sum;
+    sumf = simd_sum(sumf);
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    if (tiisg == 0) {
+        shmem_f32[sgitg] = sumf;
+    }
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    sumf = shmem_f32[tiisg];
+    sumf = simd_sum(sumf);
+
+    if (tpitg.x == 0) {
+        dst_row[0] = norm ? sumf / args.ne00 : sumf;
+    }
 }

+typedef decltype(kernel_sum_rows<false>) kernel_sum_rows_t;
+
+template [[host_name("kernel_sum_rows")]] kernel kernel_sum_rows_t kernel_sum_rows<false>;
+template [[host_name("kernel_mean")]]     kernel kernel_sum_rows_t kernel_sum_rows<true>;
+
 template<typename T>
 kernel void kernel_soft_max(
         device const char * src0,
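The rewritten kernel is a threadgroup-parallel version of a straightforward per-row reduction: each thread strides over the row (i0 += ntg.x), partial sums are combined with simd_sum, and the per-simdgroup results are folded through the 32-float threadgroup buffer, which is why the host side now reserves 32*sizeof(float) of threadgroup memory and dispatches nth threads per row. As a plain C++ sketch of the per-row result each instantiation must produce (reference only, not part of the diff):

#include <cstdint>

// Reference semantics of kernel_sum_rows<norm> for one row of ne00 floats:
// sum the elements, and divide by ne00 when norm is true (the kernel_mean case).
template <bool norm>
float row_reduce_reference(const float * src_row, int64_t ne00) {
    float sumf = 0.0f;
    for (int64_t i0 = 0; i0 < ne00; ++i0) {
        sumf += src_row[i0];
    }
    return norm ? sumf / ne00 : sumf;
}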

src/CMakeLists.txt

Lines changed: 2 additions & 1 deletion
@@ -22,8 +22,9 @@ add_library(llama
             llama-io.cpp
             llama-kv-cache-unified.cpp
             llama-kv-cache-unified-iswa.cpp
-            llama-kv-cache-recurrent.cpp
             llama-memory.cpp
+            llama-memory-hybrid.cpp
+            llama-memory-recurrent.cpp
             llama-mmap.cpp
             llama-model-loader.cpp
             llama-model-saver.cpp

src/llama-arch.cpp

Lines changed: 24 additions & 0 deletions
@@ -148,6 +148,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_ATTENTION_SCALE,            "%s.attention.scale"            },
     { LLM_KV_ATTENTION_KEY_LENGTH_MLA,   "%s.attention.key_length_mla"   },
     { LLM_KV_ATTENTION_VALUE_LENGTH_MLA, "%s.attention.value_length_mla" },
+    { LLM_KV_ATTENTION_LAYER_INDICES,    "%s.attention.layer_indices"    },

     { LLM_KV_ROPE_DIMENSION_COUNT,       "%s.rope.dimension_count"       },
     { LLM_KV_ROPE_DIMENSION_SECTIONS,    "%s.rope.dimension_sections"    },
@@ -1835,3 +1836,26 @@ llm_arch llm_arch_from_string(const std::string & name) {
 const llm_tensor_info & llm_tensor_info_for(llm_tensor tensor) {
     return LLM_TENSOR_INFOS.at(tensor);
 }
+
+bool llm_arch_is_recurrent(const llm_arch & arch) {
+    switch (arch) {
+        case LLM_ARCH_MAMBA:
+        case LLM_ARCH_MAMBA2:
+        case LLM_ARCH_RWKV6:
+        case LLM_ARCH_RWKV6QWEN2:
+        case LLM_ARCH_RWKV7:
+        case LLM_ARCH_ARWKV7:
+            return true;
+        default:
+            return false;
+    }
+}
+
+bool llm_arch_is_hybrid(const llm_arch & arch) {
+    // TODO: There are currently no hybrid models! Once there are, this will be
+    // the place to identify them
+    switch (arch) {
+        default:
+            return false;
+    }
+}
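These two predicates give callers a single place to ask whether an architecture keeps recurrent state instead of enumerating architectures at every call site. A hypothetical usage sketch (the memory_kind enum and pick_memory_kind helper are illustrative, not from this diff):

#include "llama-arch.h"

// Illustrative only: shows how model-setup code could branch on the new
// predicates rather than switching on individual architectures.
enum class memory_kind { attention_kv, recurrent_state, hybrid };

static memory_kind pick_memory_kind(llm_arch arch) {
    if (llm_arch_is_hybrid(arch)) {
        return memory_kind::hybrid;          // attention KV cache + recurrent state
    }
    if (llm_arch_is_recurrent(arch)) {
        return memory_kind::recurrent_state; // e.g. Mamba/RWKV-style state
    }
    return memory_kind::attention_kv;        // regular KV cache
}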

src/llama-arch.h

Lines changed: 4 additions & 0 deletions
@@ -152,6 +152,7 @@ enum llm_kv {
     LLM_KV_ATTENTION_SCALE,
     LLM_KV_ATTENTION_KEY_LENGTH_MLA,
     LLM_KV_ATTENTION_VALUE_LENGTH_MLA,
+    LLM_KV_ATTENTION_LAYER_INDICES,

     LLM_KV_ROPE_DIMENSION_COUNT,
     LLM_KV_ROPE_DIMENSION_SECTIONS,
@@ -442,3 +443,6 @@ const char * llm_arch_name(llm_arch arch);
 llm_arch llm_arch_from_string(const std::string & name);

 const llm_tensor_info & llm_tensor_info_for(llm_tensor tensor);
+
+bool llm_arch_is_recurrent(const llm_arch & arch);
+bool llm_arch_is_hybrid   (const llm_arch & arch);
