@@ -40,6 +40,17 @@ struct ggml_kleidiai_context {
     ggml_kleidiai_kernels * kernels;
 } static ctx = { CPU_FEATURE_NONE, NULL };
 
+static const char * cpu_feature_to_string(cpu_feature f) {
+    switch (f) {
+        case CPU_FEATURE_NONE:    return "NONE";
+        case CPU_FEATURE_DOTPROD: return "DOTPROD";
+        case CPU_FEATURE_I8MM:    return "I8MM";
+        case CPU_FEATURE_SVE:     return "SVE";
+        case CPU_FEATURE_SME:     return "SME";
+        default:                  return "UNKNOWN";
+    }
+}
+
 static void init_kleidiai_context(void) {
 
     ggml_critical_section_start();
@@ -62,6 +73,11 @@ static void init_kleidiai_context(void) {
             ctx.features |= ggml_cpu_has_sme() ? CPU_FEATURE_SME : CPU_FEATURE_NONE;
         }
         ctx.kernels = ggml_kleidiai_select_kernels_q4_0(ctx.features);
+#ifndef NDEBUG
+        if (ctx.kernels) {
+            GGML_LOG_DEBUG("kleidiai: using kernel with CPU feature %s\n", cpu_feature_to_string(ctx.kernels->required_cpu));
+        }
+#endif
     }
     ggml_critical_section_end();
 }
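ctx.features is accumulated as a bitmask with |=, while cpu_feature_to_string decodes exactly one enumerator and falls back to "UNKNOWN" for anything else, including combined flags. The value logged here is the selected kernel's required_cpu; if that field can ever carry more than one flag, a bitmask-aware variant would be needed. A minimal sketch, not part of the patch, assuming the CPU_FEATURE_* constants from this file are distinct power-of-two flags:

    #include <string>

    static std::string cpu_features_to_string(int f) { // hypothetical helper
        std::string s;
        if (f & CPU_FEATURE_DOTPROD) s += "DOTPROD|";
        if (f & CPU_FEATURE_I8MM)    s += "I8MM|";
        if (f & CPU_FEATURE_SVE)     s += "SVE|";
        if (f & CPU_FEATURE_SME)     s += "SME|";
        if (s.empty()) {
            return "NONE";
        }
        s.pop_back(); // drop the trailing '|'
        return s;
    }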
@@ -102,6 +118,9 @@ static void transpose_f32kxn_f16nxk(size_t n, size_t k, float * dst, const uint1
 
 class tensor_traits : public ggml::cpu::tensor_traits {
     bool work_size(int /* n_threads */, const struct ggml_tensor * op, size_t & size) override {
+        if (op->op != GGML_OP_MUL_MAT) {
+            return false;
+        }
         ggml_kleidiai_kernels *kernels = ggml_kleidiai_select_kernels(ctx.features, op);
         GGML_ASSERT(kernels);
         kernel_info * kernel = op->src[1]->ne[1] == 1 ? &kernels->gemv : &kernels->gemm;
@@ -135,6 +154,10 @@ class tensor_traits : public ggml::cpu::tensor_traits {
             } else if (dst->src[0]->type == GGML_TYPE_F16) {
                 return compute_forward_kv_cache(params, dst);
             }
+        } else if (dst->op == GGML_OP_GET_ROWS) {
+            if (dst->src[0]->type == GGML_TYPE_Q4_0) {
+                return compute_forward_get_rows(params, dst);
+            }
         }
         return false;
     }
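work_size now declines anything that is not a matmul, since the new GET_ROWS path needs no scratch buffer, and compute_forward routes Q4_0 GET_ROWS to the kernel added below. The op being served is ggml's row gather, as used for embedding lookups; a hedged usage sketch, where ctx0, weight_q4_0, and ids_i32 are hypothetical names:

    // Gather the rows of weight_q4_0 selected by the I32 tensor ids_i32.
    // For quantized sources, ggml_get_rows() produces F32 output rows.
    struct ggml_tensor * rows = ggml_get_rows(ctx0, weight_q4_0, ids_i32);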
@@ -270,6 +293,8 @@ class tensor_traits : public ggml::cpu::tensor_traits {
     }
 
     bool compute_forward_q4_0(struct ggml_compute_params * params, struct ggml_tensor * dst) {
+        GGML_ASSERT(dst->src[0]->type == GGML_TYPE_Q4_0);
+
         const ggml_tensor * src0 = dst->src[0];
         const ggml_tensor * src1 = dst->src[1];
 
@@ -342,26 +367,62 @@ class tensor_traits : public ggml::cpu::tensor_traits {
         return true;
     }
 
+    bool compute_forward_get_rows(struct ggml_compute_params * params, struct ggml_tensor * dst) {
+        GGML_ASSERT(dst->src[0]->type == GGML_TYPE_Q4_0);
+        GGML_ASSERT(ctx.kernels);
+
+        const ggml_tensor * src0 = dst->src[0];
+        const ggml_tensor * src1 = dst->src[1];
+
+        GGML_TENSOR_BINARY_OP_LOCALS
+
+        rhs_packing_info * rhs_info = &ctx.kernels->rhs_info;
+        kernel_info * kernel = &ctx.kernels->gemm;
+
+        const int64_t nc = ne00;
+        const int64_t nr = ggml_nelements(src1);
+
+        const size_t block_rows = kernel->get_nr();
+        const size_t kr = kernel->get_kr();
+
+        const size_t num_bytes_multiplier = sizeof(uint16_t);
+        const size_t packed_stride = rhs_info->packed_stride(nc, block_rows, kr, QK4_0);
+
+        const int ith = params->ith;
+        const int nth = params->nth;
+
+        const int dr = (nr + nth - 1) / nth;
+        const int ir0 = dr * ith;
+        const int ir1 = MIN(ir0 + dr, nr);
+
+        for (int64_t i = ir0; i < ir1; ++i) {
+            GGML_ASSERT(src1->type == GGML_TYPE_I32);
+            int64_t row_idx = ((const int32_t *)src1->data)[i];
+            GGML_ASSERT(row_idx >= 0 && row_idx < src0->ne[1]);
+
+            float *out = (float *)((char *)dst->data + i * nb1);
+            rhs_info->to_float(src0->data, row_idx, nc, out, block_rows, packed_stride, kr, QK4_0, num_bytes_multiplier);
+        }
+
+        return true;
+    }
+
 public:
     int repack(struct ggml_tensor * tensor, const void * data, size_t data_size) {
+        GGML_ASSERT(tensor->type == GGML_TYPE_Q4_0);
         GGML_ASSERT(ctx.kernels);
         const size_t n = tensor->ne[1];
         const size_t k = tensor->ne[0];
         size_t nr = ctx.kernels->gemm.get_nr();
         size_t kr = ctx.kernels->gemm.get_kr();
         size_t sr = ctx.kernels->gemm.get_sr();
 
-#ifndef NDEBUG
-        const size_t repacked_size = variant_call<size_t>(ctx.kernels->rhs_info.packed_size, n, k, nr, kr, QK4_0);
-        GGML_ASSERT(repacked_size <= data_size && "repacked size larger than the packed size!");
-#endif
         struct kai_rhs_pack_qs4cxs1s0_param params;
         params.lhs_zero_point = 1;
         params.rhs_zero_point = 8;
         variant_call<void>(ctx.kernels->rhs_info.pack_func, 1, n, k, nr, kr, sr, QK4_0, (const uint8_t *)data, nullptr, tensor->data, 0, &params);
 
         return 0;
-
         GGML_UNUSED(data_size);
     }
 };
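compute_forward_get_rows splits the nr requested rows across nth threads by ceiling division, clamping the final chunk, and each thread dequantizes its rows directly out of the packed RHS layout via rhs_info->to_float. A standalone sketch of just the partitioning arithmetic (nothing below comes from the patch):

    #include <algorithm>
    #include <cstdio>

    int main() {
        const int nr = 8, nth = 3;           // e.g. 8 row indices over 3 threads
        const int dr = (nr + nth - 1) / nth; // ceil(8/3) = 3 rows per thread
        for (int ith = 0; ith < nth; ++ith) {
            const int ir0 = dr * ith;
            const int ir1 = std::min(ir0 + dr, nr);
            std::printf("thread %d: rows [%d, %d)\n", ith, ir0, ir1);
        }
        return 0; // prints [0,3), [3,6), [6,8): the last thread takes the remainder
    }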
@@ -375,8 +436,8 @@ static ggml::cpu::tensor_traits * get_tensor_traits(ggml_backend_buffer_t, struc
 static enum ggml_status ggml_backend_cpu_kleidiai_buffer_init_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) {
     tensor->extra = (void *) ggml::cpu::kleidiai::get_tensor_traits(buffer, tensor);
 
-    GGML_UNUSED(buffer);
     return GGML_STATUS_SUCCESS;
+    GGML_UNUSED(buffer);
 }
 
 static void ggml_backend_cpu_kleidiai_buffer_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor,
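Moving GGML_UNUSED(buffer) below the return matches the style used elsewhere in the file (repack and the new get_alloc_size do the same). Placement after the return is harmless because the macro is a compile-time cast to void, in ggml's headers defined roughly as:

    #define GGML_UNUSED(x) (void)(x)

so it suppresses the unused-parameter warning without generating any code.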
@@ -418,18 +479,35 @@ static size_t ggml_backend_cpu_kleidiai_buffer_type_get_alignment(ggml_backend_b
     GGML_UNUSED(buft);
 }
 
+static size_t ggml_backend_cpu_kleidiai_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const struct ggml_tensor * tensor) {
+    GGML_ASSERT(tensor->type == GGML_TYPE_Q4_0);
+    GGML_ASSERT(ctx.kernels);
+
+    const size_t n = tensor->ne[1];
+    const size_t k = tensor->ne[0];
+    const size_t nr = ctx.kernels->gemm.get_nr();
+    const size_t kr = ctx.kernels->gemm.get_kr();
+
+    return variant_call<size_t>(ctx.kernels->rhs_info.packed_size, n, k, nr, kr, QK4_0);
+
+    GGML_UNUSED(buft);
+}
+
 namespace ggml::cpu::kleidiai {
 class extra_buffer_type : ggml::cpu::extra_buffer_type {
     bool supports_op(ggml_backend_dev_t, const struct ggml_tensor * op) override {
-        if (op->op == GGML_OP_MUL_MAT &&
+        if ((op->op == GGML_OP_MUL_MAT || op->op == GGML_OP_GET_ROWS) &&
             op->src[0]->type == GGML_TYPE_Q4_0 &&
             op->src[0]->buffer &&
             (ggml_n_dims(op->src[0]) == 2) &&
             op->src[0]->buffer->buft == ggml_backend_cpu_kleidiai_buffer_type() && ctx.kernels) {
+            if (op->op == GGML_OP_GET_ROWS && op->src[1]->ne[0] != 8) {
+                return false;
+            }
             if (op->src[1]->buffer && !ggml_backend_buft_is_host(op->src[1]->buffer->buft)) {
                 return false;
             }
-            if (op->src[1]->type == GGML_TYPE_F32 &&
+            if ((op->src[1]->type == GGML_TYPE_F32 || op->src[1]->type == GGML_TYPE_I32) &&
                 ggml_ne(op->src[1], 2) == 1 && ggml_ne(op->src[1], 3) == 1) {
                 return true;
             }
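The GET_ROWS case is claimed deliberately narrowly: the weight must be a 2-D Q4_0 tensor living in the KleidiAI buffer type, src1 must be host-resident with no extra batch dimensions, the index type must be I32 (F32 remains the matmul activation type), and GET_ROWS is additionally restricted to exactly 8 indices per call. Anything outside this envelope falls back to the generic CPU implementation.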
@@ -438,7 +516,7 @@ class extra_buffer_type : ggml::cpu::extra_buffer_type {
     }
 
     ggml::cpu::tensor_traits * get_tensor_traits(const struct ggml_tensor * op) override {
-        if (op->op == GGML_OP_MUL_MAT) {
+        if (op->op == GGML_OP_MUL_MAT || op->op == GGML_OP_GET_ROWS) {
             if (op->src[0]->buffer && op->src[0]->buffer->buft == ggml_backend_cpu_kleidiai_buffer_type()) {
                 return (ggml::cpu::tensor_traits *) op->src[0]->extra;
             }
@@ -469,7 +547,7 @@ ggml_backend_buffer_type_t ggml_backend_cpu_kleidiai_buffer_type(void) {
             /* .alloc_buffer   = */ ggml_backend_cpu_kleidiai_buffer_type_alloc_buffer,
             /* .get_alignment  = */ ggml_backend_cpu_kleidiai_buffer_type_get_alignment,
             /* .get_max_size   = */ nullptr, // defaults to SIZE_MAX
-            /* .get_alloc_size = */ nullptr, // defaults to ggml_nbytes
+            /* .get_alloc_size = */ ggml_backend_cpu_kleidiai_buffer_type_get_alloc_size,
            /* .is_host        = */ nullptr,
         },
         /* .device  = */ ggml_backend_reg_dev_get(ggml_backend_cpu_reg(), 0),
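Wiring get_alloc_size into the buffer-type vtable means the allocator sizes Q4_0 tensors for the packed RHS layout up front, which is presumably why the NDEBUG size check could be dropped from repack: the destination buffer is now guaranteed to hold at least the packed size. A hedged usage sketch, where ggml_backend_buft_get_alloc_size is the public accessor that dispatches to this callback and weight is a hypothetical Q4_0 tensor:

    ggml_backend_buffer_type_t buft = ggml_backend_cpu_kleidiai_buffer_type();
    size_t alloc_size = ggml_backend_buft_get_alloc_size(buft, weight);
    // alloc_size is the KleidiAI packed size, not ggml_nbytes(weight)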