Skip to content

Commit da6019f

Browse files
authored
wasi_nn_llamacpp.c: reject invalid graph and execution context (#4422)
* Return a valid graph and execution context instead of using stack garbage (always 0 for now, because this backend does not implement multiple graphs/contexts).
* Validate user-given graph and execution-context values; reject invalid ones.
1 parent ebf1404 commit da6019f

File tree

1 file changed

+33
-0
lines changed

1 file changed

+33
-0
lines changed

core/iwasm/libraries/wasi-nn/src/wasi_nn_llamacpp.c

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -305,6 +305,11 @@ __load_by_name_with_configuration(void *ctx, const char *filename, graph *g)
305305
{
306306
struct LlamaContext *backend_ctx = (struct LlamaContext *)ctx;
307307

308+
if (backend_ctx->model != NULL) {
309+
// we only implement a single graph
310+
return unsupported_operation;
311+
}
312+
308313
// make sure backend_ctx->config is initialized
309314

310315
struct llama_model_params model_params =
@@ -323,6 +328,7 @@ __load_by_name_with_configuration(void *ctx, const char *filename, graph *g)
323328
#endif
324329

325330
backend_ctx->model = model;
331+
*g = 0;
326332

327333
return success;
328334
}
@@ -363,6 +369,16 @@ init_execution_context(void *ctx, graph g, graph_execution_context *exec_ctx)
363369
{
364370
struct LlamaContext *backend_ctx = (struct LlamaContext *)ctx;
365371

372+
if (g != 0 || backend_ctx->model == NULL) {
373+
// we only implement a single graph
374+
return runtime_error;
375+
}
376+
377+
if (backend_ctx->ctx != NULL) {
378+
// we only implement a single context
379+
return unsupported_operation;
380+
}
381+
366382
struct llama_context_params ctx_params =
367383
llama_context_params_from_wasi_nn_llama_config(&backend_ctx->config);
368384
struct llama_context *llama_ctx =
@@ -373,6 +389,7 @@ init_execution_context(void *ctx, graph g, graph_execution_context *exec_ctx)
373389
}
374390

375391
backend_ctx->ctx = llama_ctx;
392+
*exec_ctx = 0;
376393

377394
NN_INFO_PRINTF("n_predict = %d, n_ctx = %d", backend_ctx->config.n_predict,
378395
llama_n_ctx(backend_ctx->ctx));
@@ -384,6 +401,12 @@ set_input(void *ctx, graph_execution_context exec_ctx, uint32_t index,
384401
tensor *wasi_nn_tensor)
385402
{
386403
struct LlamaContext *backend_ctx = (struct LlamaContext *)ctx;
404+
405+
if (exec_ctx != 0 || backend_ctx->ctx == NULL) {
406+
// we only implement a single context
407+
return runtime_error;
408+
}
409+
387410
// tensor->data is the prompt string.
388411
char *prompt_text = (char *)wasi_nn_tensor->data.buf;
389412
uint32_t prompt_text_len = wasi_nn_tensor->data.size;
@@ -433,6 +456,11 @@ compute(void *ctx, graph_execution_context exec_ctx)
433456
struct LlamaContext *backend_ctx = (struct LlamaContext *)ctx;
434457
wasi_nn_error ret = runtime_error;
435458

459+
if (exec_ctx != 0 || backend_ctx->ctx == NULL) {
460+
// we only implement a single context
461+
return runtime_error;
462+
}
463+
436464
// reset the generation buffer
437465
if (backend_ctx->generation == NULL) {
438466
backend_ctx->generation =
@@ -554,6 +582,11 @@ get_output(void *ctx, graph_execution_context exec_ctx, uint32_t index,
554582
{
555583
struct LlamaContext *backend_ctx = (struct LlamaContext *)ctx;
556584

585+
if (exec_ctx != 0 || backend_ctx->ctx == NULL) {
586+
// we only implement a single context
587+
return runtime_error;
588+
}
589+
557590
// Compatibility with WasmEdge
558591
if (index > 1) {
559592
NN_ERR_PRINTF("Invalid output index %d", index);

0 commit comments

Comments
 (0)