@@ -305,6 +305,11 @@ __load_by_name_with_configuration(void *ctx, const char *filename, graph *g)
305
305
{
306
306
struct LlamaContext * backend_ctx = (struct LlamaContext * )ctx ;
307
307
308
+ if (backend_ctx -> model != NULL ) {
309
+ // we only implement a single graph
310
+ return unsupported_operation ;
311
+ }
312
+
308
313
// make sure backend_ctx->config is initialized
309
314
310
315
struct llama_model_params model_params =
@@ -323,6 +328,7 @@ __load_by_name_with_configuration(void *ctx, const char *filename, graph *g)
323
328
#endif
324
329
325
330
backend_ctx -> model = model ;
331
+ * g = 0 ;
326
332
327
333
return success ;
328
334
}
@@ -363,6 +369,16 @@ init_execution_context(void *ctx, graph g, graph_execution_context *exec_ctx)
363
369
{
364
370
struct LlamaContext * backend_ctx = (struct LlamaContext * )ctx ;
365
371
372
+ if (g != 0 || backend_ctx -> model == NULL ) {
373
+ // we only implement a single graph (handle 0), and it must already be loaded
374
+ return runtime_error ;
375
+ }
376
+
377
+ if (backend_ctx -> ctx != NULL ) {
378
+ // we only implement a single context
379
+ return unsupported_operation ;
380
+ }
381
+
366
382
struct llama_context_params ctx_params =
367
383
llama_context_params_from_wasi_nn_llama_config (& backend_ctx -> config );
368
384
struct llama_context * llama_ctx =
@@ -373,6 +389,7 @@ init_execution_context(void *ctx, graph g, graph_execution_context *exec_ctx)
373
389
}
374
390
375
391
backend_ctx -> ctx = llama_ctx ;
392
+ * exec_ctx = 0 ;
376
393
377
394
NN_INFO_PRINTF ("n_predict = %d, n_ctx = %d" , backend_ctx -> config .n_predict ,
378
395
llama_n_ctx (backend_ctx -> ctx ));
@@ -384,6 +401,12 @@ set_input(void *ctx, graph_execution_context exec_ctx, uint32_t index,
384
401
tensor * wasi_nn_tensor )
385
402
{
386
403
struct LlamaContext * backend_ctx = (struct LlamaContext * )ctx ;
404
+
405
+ if (exec_ctx != 0 || backend_ctx -> ctx == NULL ) {
406
+ // we only implement a single context (handle 0), and it must already be initialized
407
+ return runtime_error ;
408
+ }
409
+
387
410
// tensor->data is the prompt string.
388
411
char * prompt_text = (char * )wasi_nn_tensor -> data .buf ;
389
412
uint32_t prompt_text_len = wasi_nn_tensor -> data .size ;
@@ -433,6 +456,11 @@ compute(void *ctx, graph_execution_context exec_ctx)
433
456
struct LlamaContext * backend_ctx = (struct LlamaContext * )ctx ;
434
457
wasi_nn_error ret = runtime_error ;
435
458
459
+ if (exec_ctx != 0 || backend_ctx -> ctx == NULL ) {
460
+ // we only implement a single context (handle 0), and it must already be initialized
461
+ return runtime_error ;
462
+ }
463
+
436
464
// reset the generation buffer
437
465
if (backend_ctx -> generation == NULL ) {
438
466
backend_ctx -> generation =
@@ -554,6 +582,11 @@ get_output(void *ctx, graph_execution_context exec_ctx, uint32_t index,
554
582
{
555
583
struct LlamaContext * backend_ctx = (struct LlamaContext * )ctx ;
556
584
585
+ if (exec_ctx != 0 || backend_ctx -> ctx == NULL ) {
586
+ // we only implement a single context (handle 0), and it must already be initialized
587
+ return runtime_error ;
588
+ }
589
+
557
590
// Compatibility with WasmEdge
558
591
if (index > 1 ) {
559
592
NN_ERR_PRINTF ("Invalid output index %d" , index );
0 commit comments