Skip to content

Commit 032697b

Browse files
authored
whisper: validate get_rows support for cpu extra buffer (#3323)
1 parent a16da91 commit 032697b

File tree

1 file changed

+11
-4
lines changed

1 file changed

+11
-4
lines changed

src/whisper.cpp

Lines changed: 11 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -1438,7 +1438,8 @@ static bool weight_buft_supported(const whisper_hparams & hparams, ggml_tensor *
14381438
op_supported = true;
14391439
} else {
14401440
switch (op) {
1441-
// The current extra_buffer_type implementations only support GGML_OP_MUL_MAT
1441+
// The current extra_buffer_type implementations only support GGML_OP_MUL_MAT and GGML_OP_GET_ROWS
1442+
case GGML_OP_GET_ROWS:
14421443
case GGML_OP_MUL_MAT: {
14431444
ggml_init_params params = {
14441445
/*.mem_size =*/ 2 * ggml_tensor_overhead(),
@@ -1454,9 +1455,15 @@ static bool weight_buft_supported(const whisper_hparams & hparams, ggml_tensor *
14541455

14551456
ggml_tensor * op_tensor = nullptr;
14561457

1457-
int64_t n_ctx = hparams.n_audio_ctx;
1458-
ggml_tensor * b = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, w->ne[0], n_ctx, w->ne[2], w->ne[3]);
1459-
op_tensor = ggml_mul_mat(ctx, w, b);
1458+
if (op == GGML_OP_MUL_MAT) {
1459+
int64_t n_ctx = hparams.n_audio_ctx;
1460+
ggml_tensor * b = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, w->ne[0], n_ctx, w->ne[2], w->ne[3]);
1461+
op_tensor = ggml_mul_mat(ctx, w, b);
1462+
} else if (op == GGML_OP_GET_ROWS) {
1463+
int64_t num_indices = 8;
1464+
ggml_tensor * indices = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, num_indices);
1465+
op_tensor = ggml_get_rows(ctx, w, indices);
1466+
}
14601467

14611468
// create a temporary dummy buffer for the weight so that supports_op can check the buffer type
14621469
GGML_ASSERT(w->buffer == nullptr);

0 commit comments

Comments
 (0)