Commit 9220426

kleidiai: add support for get_rows (ggml-org#14676)
* kleidiai: add support for get_rows
* apply fixes based on code review
* apply more fixes based on code review
Parent: 2ba1333

4 files changed: +202 -24 lines

ggml/src/ggml-cpu/CMakeLists.txt

Lines changed: 2 additions & 2 deletions
@@ -494,9 +494,9 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
 
         # Fetch KleidiAI sources:
         include(FetchContent)
-        set(KLEIDIAI_COMMIT_TAG "v1.9.0")
+        set(KLEIDIAI_COMMIT_TAG "v1.11.0")
         set(KLEIDIAI_DOWNLOAD_URL "https://github.com/ARM-software/kleidiai/archive/refs/tags/${KLEIDIAI_COMMIT_TAG}.tar.gz")
-        set(KLEIDIAI_ARCHIVE_MD5 "2a8e1bb55d201557553545536489a017")
+        set(KLEIDIAI_ARCHIVE_MD5 "3fe9e5ab964c375c53839296eb71eaa2")
 
         if (POLICY CMP0135)
             cmake_policy(SET CMP0135 NEW)

ggml/src/ggml-cpu/kleidiai/kernels.cpp

Lines changed: 109 additions & 12 deletions
@@ -22,9 +22,94 @@
 
 #include "kai_common.h"
 
+#include "simd-mappings.h"
+
 #include "kernels.h"
 
 #define NELEMS(x) sizeof(x) / sizeof(*x)
+
+static const size_t INT4_PER_BYTE   = 2;
+static const size_t INT4_BITS       = 4;
+static const int    Q4_0_ZERO_POINT = 8;
+const size_t INT4_PER_UINT16 = 4;
+
+static void dequantize_row_qsi4c32pscalef16(
+    const void *packed_data,
+    int32_t row_idx,
+    int64_t nc,
+    float *out,
+    size_t nr_pack,
+    size_t packed_row_stride,
+    size_t kr,
+    size_t bl,
+    size_t num_bytes_multiplier
+) {
+    size_t group_idx = row_idx / nr_pack;
+    size_t row_in_group = row_idx % nr_pack;
+    const uint8_t *packed_group = (const uint8_t *)packed_data + group_idx * packed_row_stride;
+    size_t num_blocks = nc / bl;
+    const uint8_t *block_ptr = packed_group;
+
+    for (size_t b = 0; b < num_blocks; ++b) {
+        uint16_t scale_f16 = *((const uint16_t *)(block_ptr + row_in_group * num_bytes_multiplier));
+        float scale = GGML_CPU_FP16_TO_FP32(scale_f16);
+
+        const uint8_t *segment_ptr = block_ptr + nr_pack * num_bytes_multiplier;
+        size_t num_segments = bl / kr;
+        size_t num_bytes_per_segment = kr / INT4_PER_BYTE;
+
+        for (size_t s = 0; s < num_segments; ++s) {
+            const uint8_t *seg_base = segment_ptr + s * nr_pack * num_bytes_per_segment;
+            const uint8_t *qbytes = seg_base + row_in_group * num_bytes_per_segment;
+            for (size_t k = 0; k < num_bytes_per_segment; ++k) {
+                uint8_t byte = qbytes[k] ^ 0x88;
+                int x0 = (byte & 0x0F) - Q4_0_ZERO_POINT;
+                int x1 = (byte >> INT4_BITS) - Q4_0_ZERO_POINT;
+                out[b * bl + s * num_bytes_per_segment + k] = x0 * scale;
+                out[b * bl + s * num_bytes_per_segment + k + bl/2] = x1 * scale;
+            }
+        }
+        block_ptr += nr_pack * num_bytes_multiplier + num_segments * nr_pack * num_bytes_per_segment;
+    }
+}
+
+static void dequantize_row_qsi4c32ps1s0scalef16(
+    const void *packed_data,
+    int32_t row_idx,
+    int64_t k,
+    float *out,
+    size_t nr,
+    size_t packed_row_stride,
+    size_t kr,
+    size_t bl,
+    size_t num_bytes_multiplier
+) {
+    const size_t num_blocks = k / bl;
+    const size_t bl4 = bl / INT4_PER_UINT16;
+
+    size_t group_idx = row_idx / nr;
+    size_t row_in_group = row_idx % nr;
+
+    const uint8_t *packed_group = (const uint8_t *)packed_data + group_idx * packed_row_stride;
+    const uint16_t *qdata = (const uint16_t *)packed_group;
+    const uint16_t *scales = (const uint16_t *)(packed_group + packed_row_stride - (nr * num_blocks * num_bytes_multiplier));
+
+    for (size_t block_idx = 0; block_idx < num_blocks; ++block_idx) {
+        uint16_t scale_f16 = scales[row_in_group + block_idx * nr];
+        float scale = GGML_CPU_FP16_TO_FP32(scale_f16);
+
+        for (size_t bl4_idx = 0; bl4_idx < bl4; ++bl4_idx) {
+            uint16_t q = qdata[(block_idx * bl4 + bl4_idx) * nr + row_in_group];
+
+            for (size_t qidx = 0; qidx < INT4_PER_UINT16; ++qidx) {
+                int v = ((q >> (qidx * 4)) & 0xF) - Q4_0_ZERO_POINT;
+                out[block_idx * bl + bl4_idx * INT4_BITS + qidx] = v * scale;
+            }
+        }
+    }
+    GGML_UNUSED(kr);
+}
+
 static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
 #if defined(__ARM_FEATURE_SME)
     {
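Commentary (not part of the diff): in the scalar decode above, `qbytes[k] ^ 0x88` toggles bit 3 of both nibbles at once, mapping each stored two's-complement int4 to its offset-binary form, so that subtracting Q4_0_ZERO_POINT (8) recovers the signed value; the low and high nibbles land `bl/2` apart in the output because the packing keeps the two halves of each block in separate nibble planes. A minimal standalone sketch of the byte-level round trip (the low-nibble-first layout is an assumption mirroring the code above):

    // Sketch of the nibble decode: one packed byte holds two two's-complement
    // int4 values, low nibble first.
    #include <cassert>
    #include <cstdint>

    int main() {
        // pack the signed values -3 (low nibble) and 5 (high nibble)
        const uint8_t packed = uint8_t(((5 & 0xF) << 4) | ((-3) & 0xF));
        const uint8_t byte = packed ^ 0x88; // flip bit 3 of both nibbles: two's complement -> offset binary
        const int x0 = (byte & 0x0F) - 8;   // low nibble minus the Q4_0 zero point
        const int x1 = (byte >> 4) - 8;     // high nibble minus the Q4_0 zero point
        assert(x0 == -3 && x1 == 5);
        return 0;
    }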
@@ -63,8 +148,10 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
             /* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32_neon,
         },
         /* .rhs_info = */ {
-            /* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon,
-            /* .pack_func   = */ kai_run_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon,
+            /* .packed_size   = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon,
+            /* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon,
+            /* .pack_func     = */ kai_run_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon,
+            /* .to_float      = */ dequantize_row_qsi4c32ps1s0scalef16,
         },
         /* .required_cpu = */ CPU_FEATURE_SME,
         /* .lhs_type = */ GGML_TYPE_F32,
@@ -107,8 +194,10 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
             /* .pack_func = */ kai_run_lhs_pack_bf16p2vlx2_f32_sme,
         },
         /* .rhs_info = */ {
-            /* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme,
-            /* .pack_func   = */ kai_run_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme,
+            /* .packed_size   = */ kai_get_rhs_packed_size_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme,
+            /* .packed_stride = */ NULL,
+            /* .pack_func     = */ kai_run_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme,
+            /* .to_float      = */ NULL,
         },
         /* .required_cpu = */ CPU_FEATURE_SME,
         /* .lhs_type = */ GGML_TYPE_F32,
@@ -154,8 +243,10 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
             /* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32,
         },
         /* .rhs_info = */ {
-            /* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
-            /* .pack_func   = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
+            /* .packed_size   = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
+            /* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
+            /* .pack_func     = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
+            /* .to_float      = */ dequantize_row_qsi4c32pscalef16,
         },
         /* .required_cpu = */ CPU_FEATURE_DOTPROD,
         /* .lhs_type = */ GGML_TYPE_F32,
@@ -200,8 +291,10 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
             /* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32,
         },
         /* .rhs_info = */ {
-            /* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
-            /* .pack_func   = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
+            /* .packed_size   = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
+            /* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
+            /* .pack_func     = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
+            /* .to_float      = */ dequantize_row_qsi4c32pscalef16,
         },
         /* .required_cpu = */ CPU_FEATURE_DOTPROD | CPU_FEATURE_I8MM,
         /* .lhs_type = */ GGML_TYPE_F32,
@@ -247,8 +340,10 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
             /* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32,
         },
         /* .rhs_info = */ {
-            /* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
-            /* .pack_func   = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
+            /* .packed_size   = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
+            /* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
+            /* .pack_func     = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
+            /* .to_float      = */ dequantize_row_qsi4c32pscalef16,
         },
         /* .required_cpu = */ CPU_FEATURE_DOTPROD | CPU_FEATURE_I8MM,
         /* .lhs_type = */ GGML_TYPE_F32,
@@ -293,8 +388,10 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
             /* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32,
         },
         /* .rhs_info = */ {
-            /* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
-            /* .pack_func   = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
+            /* .packed_size   = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
+            /* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
+            /* .pack_func     = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
+            /* .to_float      = */ dequantize_row_qsi4c32pscalef16,
        },
         /* .required_cpu = */ CPU_FEATURE_DOTPROD,
         /* .lhs_type = */ GGML_TYPE_F32,

ggml/src/ggml-cpu/kleidiai/kernels.h

Lines changed: 3 additions & 0 deletions
@@ -71,12 +71,15 @@ struct rhs_packing_info {
         std::function<size_t(size_t n, size_t k, size_t nr, size_t kr, size_t bl)>,
         std::function<size_t(size_t n, size_t k)>
     > packed_size;
+    size_t (*packed_stride)(size_t k, size_t nr, size_t kr, size_t bl);
     std::variant<
         std::function<void(size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t bl, const uint8_t* rhs,
                            const float* bias, void* rhs_packed, size_t extra_bytes, const struct kai_rhs_pack_qs4cxs1s0_param* params)>,
         std::function<void(size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t rhs_stride, const void* rhs,
                            const void* bias, const void* scale, void* rhs_packed, size_t extra_bytes, const void* params)>
     > pack_func;
+    void (*to_float)(const void *packed_data, int32_t row_idx, int64_t nc, float *out, size_t nr_pack, size_t packed_row_stride,
+                     size_t kr, size_t bl, size_t num_bytes_multiplier);
 };
 
 struct ggml_kleidiai_kernels {
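Commentary (not part of the diff): `packed_size` and `pack_func` are `std::variant`s of `std::function` because the kai kernels expose two different signatures, while the new `packed_stride` and `to_float` members are plain function pointers with a single signature each, left NULL for kernels that have no Q4_0 dequantize path (the bf16 SME entry above sets both to NULL). Callers therefore have to gate on them. A minimal compilable sketch of that pattern, with simplified stand-in types rather than the real struct:

    #include <cstddef>
    #include <cstdint>
    #include <cstdio>

    // Simplified stand-in for rhs_packing_info: the two new members are
    // optional plain function pointers.
    struct rhs_info_like {
        size_t (*packed_stride)(size_t k, size_t nr, size_t kr, size_t bl);
        void   (*to_float)(const void *packed, int32_t row, int64_t nc, float *out,
                           size_t nr, size_t stride, size_t kr, size_t bl, size_t mul);
    };

    static bool supports_dequant(const rhs_info_like &info) {
        // entries without a dequantize path (e.g. the bf16 one) leave both NULL
        return info.packed_stride != nullptr && info.to_float != nullptr;
    }

    int main() {
        const rhs_info_like bf16_entry = { nullptr, nullptr };
        std::printf("dequant supported: %d\n", supports_dequant(bf16_entry)); // prints 0
        return 0;
    }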

ggml/src/ggml-cpu/kleidiai/kleidiai.cpp

Lines changed: 88 additions & 10 deletions
@@ -40,6 +40,17 @@ struct ggml_kleidiai_context {
     ggml_kleidiai_kernels * kernels;
 } static ctx = { CPU_FEATURE_NONE, NULL };
 
+static const char* cpu_feature_to_string(cpu_feature f) {
+    switch (f) {
+        case CPU_FEATURE_NONE:    return "NONE";
+        case CPU_FEATURE_DOTPROD: return "DOTPROD";
+        case CPU_FEATURE_I8MM:    return "I8MM";
+        case CPU_FEATURE_SVE:     return "SVE";
+        case CPU_FEATURE_SME:     return "SME";
+        default:                  return "UNKNOWN";
+    }
+}
+
 static void init_kleidiai_context(void) {
 
     ggml_critical_section_start();
@@ -62,6 +73,11 @@ static void init_kleidiai_context(void) {
             ctx.features |= ggml_cpu_has_sme() ? CPU_FEATURE_SME : CPU_FEATURE_NONE;
         }
         ctx.kernels = ggml_kleidiai_select_kernels_q4_0(ctx.features);
+#ifndef NDEBUG
+        if (ctx.kernels) {
+            GGML_LOG_DEBUG("kleidiai: using kernel with CPU feature %s\n", cpu_feature_to_string(ctx.kernels->required_cpu));
+        }
+#endif
     }
     ggml_critical_section_end();
 }
@@ -102,6 +118,9 @@ static void transpose_f32kxn_f16nxk(size_t n, size_t k, float * dst, const uint1
 
 class tensor_traits : public ggml::cpu::tensor_traits {
     bool work_size(int /* n_threads */, const struct ggml_tensor * op, size_t & size) override {
+        if (op->op != GGML_OP_MUL_MAT) {
+            return false;
+        }
         ggml_kleidiai_kernels *kernels = ggml_kleidiai_select_kernels(ctx.features, op);
         GGML_ASSERT(kernels);
         kernel_info * kernel = op->src[1]->ne[1] == 1 ? &kernels->gemv : &kernels->gemm;
@@ -135,6 +154,10 @@ class tensor_traits : public ggml::cpu::tensor_traits {
             } else if (dst->src[0]->type == GGML_TYPE_F16) {
                 return compute_forward_kv_cache(params, dst);
             }
+        } else if (dst->op == GGML_OP_GET_ROWS) {
+            if (dst->src[0]->type == GGML_TYPE_Q4_0) {
+                return compute_forward_get_rows(params, dst);
+            }
         }
         return false;
     }
@@ -270,6 +293,8 @@ class tensor_traits : public ggml::cpu::tensor_traits {
     }
 
     bool compute_forward_q4_0(struct ggml_compute_params * params, struct ggml_tensor * dst) {
+        GGML_ASSERT(dst->src[0]->type == GGML_TYPE_Q4_0);
+
         const ggml_tensor * src0 = dst->src[0];
         const ggml_tensor * src1 = dst->src[1];
 
342367
return true;
343368
}
344369

370+
bool compute_forward_get_rows(struct ggml_compute_params * params, struct ggml_tensor * dst) {
371+
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_Q4_0);
372+
GGML_ASSERT(ctx.kernels);
373+
374+
const ggml_tensor * src0 = dst->src[0];
375+
const ggml_tensor * src1 = dst->src[1];
376+
377+
GGML_TENSOR_BINARY_OP_LOCALS
378+
379+
rhs_packing_info * rhs_info = &ctx.kernels->rhs_info;
380+
kernel_info * kernel = &ctx.kernels->gemm;
381+
382+
const int64_t nc = ne00;
383+
const int64_t nr = ggml_nelements(src1);
384+
385+
const size_t block_rows = kernel->get_nr();
386+
const size_t kr = kernel->get_kr();
387+
388+
const size_t num_bytes_multiplier = sizeof(uint16_t);
389+
const size_t packed_stride = rhs_info->packed_stride(nc, block_rows, kr, QK4_0);
390+
391+
const int ith = params->ith;
392+
const int nth = params->nth;
393+
394+
const int dr = (nr + nth - 1) / nth;
395+
const int ir0 = dr * ith;
396+
const int ir1 = MIN(ir0 + dr, nr);
397+
398+
for (int64_t i = ir0; i < ir1; ++i) {
399+
GGML_ASSERT(src1->type == GGML_TYPE_I32);
400+
int64_t row_idx = ((const int32_t *)src1->data)[i];
401+
GGML_ASSERT(row_idx >= 0 && row_idx < src0->ne[1]);
402+
403+
float *out = (float *)((char *)dst->data + i * nb1);
404+
rhs_info->to_float(src0->data, row_idx, nc, out, block_rows, packed_stride, kr, QK4_0, num_bytes_multiplier);
405+
}
406+
407+
return true;
408+
}
409+
345410
public:
346411
int repack(struct ggml_tensor * tensor, const void * data, size_t data_size) {
412+
GGML_ASSERT(tensor->type == GGML_TYPE_Q4_0);
347413
GGML_ASSERT(ctx.kernels);
348414
const size_t n = tensor->ne[1];
349415
const size_t k = tensor->ne[0];
350416
size_t nr = ctx.kernels->gemm.get_nr();
351417
size_t kr = ctx.kernels->gemm.get_kr();
352418
size_t sr = ctx.kernels->gemm.get_sr();
353419

354-
#ifndef NDEBUG
355-
const size_t repacked_size = variant_call<size_t>(ctx.kernels->rhs_info.packed_size, n, k, nr, kr, QK4_0);
356-
GGML_ASSERT(repacked_size <= data_size && "repacked size larger than the packed size!");
357-
#endif
358420
struct kai_rhs_pack_qs4cxs1s0_param params;
359421
params.lhs_zero_point = 1;
360422
params.rhs_zero_point = 8;
361423
variant_call<void>(ctx.kernels->rhs_info.pack_func, 1, n, k, nr, kr, sr, QK4_0, (const uint8_t*)data, nullptr, tensor->data, 0, &params);
362424

363425
return 0;
364-
365426
GGML_UNUSED(data_size);
366427
}
367428
};
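Commentary (not part of the diff): `compute_forward_get_rows` splits the `nr` requested rows across the `nth` threads with ggml's usual ceil-divide pattern, so each thread dequantizes an independent contiguous range and no synchronization is needed. A worked sketch of the arithmetic, e.g. 8 index rows on 3 threads giving the ranges [0,3), [3,6), [6,8):

    #include <cstdio>

    int main() {
        const int nr = 8, nth = 3;                 // 8 rows to gather, 3 threads
        for (int ith = 0; ith < nth; ++ith) {
            const int dr  = (nr + nth - 1) / nth;  // ceil(nr / nth) = 3 rows per thread
            const int ir0 = dr * ith;              // first row for this thread
            const int ir1 = ir0 + dr < nr ? ir0 + dr : nr; // clamp the last range to nr
            std::printf("thread %d: rows [%d, %d)\n", ith, ir0, ir1);
        }
        return 0;
    }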
@@ -375,8 +436,8 @@ static ggml::cpu::tensor_traits * get_tensor_traits(ggml_backend_buffer_t, struc
 static enum ggml_status ggml_backend_cpu_kleidiai_buffer_init_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) {
     tensor->extra = (void *) ggml::cpu::kleidiai::get_tensor_traits(buffer, tensor);
 
-    GGML_UNUSED(buffer);
     return GGML_STATUS_SUCCESS;
+    GGML_UNUSED(buffer);
 }
 
 static void ggml_backend_cpu_kleidiai_buffer_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor,
@@ -418,18 +479,35 @@ static size_t ggml_backend_cpu_kleidiai_buffer_type_get_alignment(ggml_backend_b
     GGML_UNUSED(buft);
 }
 
+static size_t ggml_backend_cpu_kleidiai_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const struct ggml_tensor * tensor) {
+    GGML_ASSERT(tensor->type == GGML_TYPE_Q4_0);
+    GGML_ASSERT(ctx.kernels);
+
+    const size_t n = tensor->ne[1];
+    const size_t k = tensor->ne[0];
+    const size_t nr = ctx.kernels->gemm.get_nr();
+    const size_t kr = ctx.kernels->gemm.get_kr();
+
+    return variant_call<size_t>(ctx.kernels->rhs_info.packed_size, n, k, nr, kr, QK4_0);
+
+    GGML_UNUSED(buft);
+}
+
 namespace ggml::cpu::kleidiai {
 class extra_buffer_type : ggml::cpu::extra_buffer_type {
     bool supports_op(ggml_backend_dev_t, const struct ggml_tensor * op) override {
-        if (op->op == GGML_OP_MUL_MAT &&
+        if ((op->op == GGML_OP_MUL_MAT || op->op == GGML_OP_GET_ROWS) &&
             op->src[0]->type == GGML_TYPE_Q4_0 &&
             op->src[0]->buffer &&
             (ggml_n_dims(op->src[0]) == 2) &&
             op->src[0]->buffer->buft == ggml_backend_cpu_kleidiai_buffer_type() && ctx.kernels) {
+            if (op->op == GGML_OP_GET_ROWS && op->src[1]->ne[0] != 8) {
+                return false;
+            }
             if (op->src[1]->buffer && !ggml_backend_buft_is_host(op->src[1]->buffer->buft)) {
                 return false;
             }
-            if (op->src[1]->type == GGML_TYPE_F32 &&
+            if ((op->src[1]->type == GGML_TYPE_F32 || op->src[1]->type == GGML_TYPE_I32) &&
                 ggml_ne(op->src[1], 2) == 1 && ggml_ne(op->src[1], 3) == 1) {
                 return true;
             }
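Commentary (not part of the diff): with this gate, GET_ROWS is only claimed for a 2-D Q4_0 weight resident in the KleidiAI buffer type, with host-resident I32 indices and, notably, exactly 8 of them (`op->src[1]->ne[0] != 8` rejects everything else). A hypothetical graph-construction sketch of an op shape that passes the gate; the helper name is invented here, and the buffer placement and repack of the weight are elided:

    #include "ggml.h"

    // Hypothetical helper, not from the commit: builds a GET_ROWS node in the
    // shape the gate above accepts. `w_q4_0` must be a 2-D GGML_TYPE_Q4_0
    // tensor already placed in the KleidiAI buffer type.
    static struct ggml_tensor * build_get_rows(struct ggml_context * ctx,
                                               struct ggml_tensor * w_q4_0) {
        // exactly 8 I32 row indices -- other counts fall back to the generic CPU path
        struct ggml_tensor * ids = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 8);
        // result: one dequantized F32 row per index, produced via rhs_info.to_float
        return ggml_get_rows(ctx, w_q4_0, ids);
    }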
@@ -438,7 +516,7 @@ class extra_buffer_type : ggml::cpu::extra_buffer_type {
     }
 
     ggml::cpu::tensor_traits * get_tensor_traits(const struct ggml_tensor * op) override {
-        if (op->op == GGML_OP_MUL_MAT) {
+        if (op->op == GGML_OP_MUL_MAT || op->op == GGML_OP_GET_ROWS) {
             if (op->src[0]->buffer && op->src[0]->buffer->buft == ggml_backend_cpu_kleidiai_buffer_type()) {
                 return (ggml::cpu::tensor_traits *) op->src[0]->extra;
             }
@@ -469,7 +547,7 @@ ggml_backend_buffer_type_t ggml_backend_cpu_kleidiai_buffer_type(void) {
             /* .alloc_buffer   = */ ggml_backend_cpu_kleidiai_buffer_type_alloc_buffer,
             /* .get_alignment  = */ ggml_backend_cpu_kleidiai_buffer_type_get_alignment,
             /* .get_max_size   = */ nullptr, // defaults to SIZE_MAX
-            /* .get_alloc_size = */ nullptr, // defaults to ggml_nbytes
+            /* .get_alloc_size = */ ggml_backend_cpu_kleidiai_buffer_type_get_alloc_size,
             /* .is_host        = */ nullptr,
         },
         /* .device = */ ggml_backend_reg_dev_get(ggml_backend_cpu_reg(), 0),
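Commentary (not part of the diff): wiring up `get_alloc_size` means the buffer type now reports the KleidiAI packed footprint for a Q4_0 weight instead of the default `ggml_nbytes`, which is presumably why the `repacked_size <= data_size` assertion in `repack` could be dropped. A hypothetical fragment showing how the reported size would be queried; it assumes a 2-D Q4_0 `weight` tensor already exists and that the buffer-type accessor is visible from the call site:

    // Hypothetical sketch, not from the commit.
    ggml_backend_buffer_type_t buft = ggml_backend_cpu_kleidiai_buffer_type();
    const size_t packed_bytes = ggml_backend_buft_get_alloc_size(buft, weight); // kai packed size
    const size_t raw_bytes    = ggml_nbytes(weight);                            // plain Q4_0 size
    // the packed layout groups rows and adds padding, so packed_bytes need not
    // match (and may exceed) raw_bytes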
