@@ -1438,7 +1438,8 @@ static bool weight_buft_supported(const whisper_hparams & hparams, ggml_tensor *
1438
1438
op_supported = true ;
1439
1439
} else {
1440
1440
switch (op) {
1441
- // The current extra_buffer_type implementations only support GGML_OP_MUL_MAT
1441
+ // The current extra_buffer_type implementations only support GGML_OP_MUL_MAT and GGML_OP_GET_ROWS
1442
+ case GGML_OP_GET_ROWS:
1442
1443
case GGML_OP_MUL_MAT: {
1443
1444
ggml_init_params params = {
1444
1445
/* .mem_size =*/ 2 * ggml_tensor_overhead (),
@@ -1454,9 +1455,15 @@ static bool weight_buft_supported(const whisper_hparams & hparams, ggml_tensor *
1454
1455
1455
1456
ggml_tensor * op_tensor = nullptr ;
1456
1457
1457
- int64_t n_ctx = hparams.n_audio_ctx ;
1458
- ggml_tensor * b = ggml_new_tensor_4d (ctx, GGML_TYPE_F32, w->ne [0 ], n_ctx, w->ne [2 ], w->ne [3 ]);
1459
- op_tensor = ggml_mul_mat (ctx, w, b);
1458
+ if (op == GGML_OP_MUL_MAT) {
1459
+ int64_t n_ctx = hparams.n_audio_ctx ;
1460
+ ggml_tensor * b = ggml_new_tensor_4d (ctx, GGML_TYPE_F32, w->ne [0 ], n_ctx, w->ne [2 ], w->ne [3 ]);
1461
+ op_tensor = ggml_mul_mat (ctx, w, b);
1462
+ } else if (op == GGML_OP_GET_ROWS) {
1463
+ int64_t num_indices = 8 ;
1464
+ ggml_tensor * indices = ggml_new_tensor_1d (ctx, GGML_TYPE_I32, num_indices);
1465
+ op_tensor = ggml_get_rows (ctx, w, indices);
1466
+ }
1460
1467
1461
1468
// create a temporary dummy buffer for the weight so that supports_op can check the buffer type
1462
1469
GGML_ASSERT (w->buffer == nullptr );
0 commit comments