Skip to content

Commit ea408ab

Browse files
authored
wasi-nn: add minimum serialization on WASINNContext (#4387)
Currently this is not necessary because the context (WASINNContext) is local to an instance (wasm_module_inst_t). I plan to make a context shared among instances in a cluster when fixing #4313; this is a preparation for that direction. An obvious alternative is to tweak the module instance context APIs to allow declaring some kinds of contexts instance-local, but I feel that, in this particular case, it's more natural to make "wasi-nn handles" shared among threads within a "process". Note that, spec-wise, how wasi-nn behaves with respect to threads is not defined at all, because WASI officially doesn't have threads yet. I suppose, at this point, that how wasi-nn interacts with wasi-threads is something we need to define ourselves, especially when we are using an outdated wasi-nn version. With this change, if a thread attempts to access a context while another thread is using it, we simply make the operation fail with the "busy" error. This is intended as the minimum serialization needed to avoid problems like crashes, leaks, etc.; it is not intended to allow parallelism or the like. No functional changes are intended at this point yet. cf. #4313 #2430
1 parent 71c07f3 commit ea408ab

File tree

2 files changed

+120
-33
lines changed

2 files changed

+120
-33
lines changed

core/iwasm/libraries/wasi-nn/src/wasi_nn.c

Lines changed: 116 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -102,13 +102,16 @@ wasi_nn_ctx_destroy(WASINNContext *wasi_nn_ctx)
102102
NN_DBG_PRINTF("-> is_model_loaded: %d", wasi_nn_ctx->is_model_loaded);
103103
NN_DBG_PRINTF("-> current_encoding: %d", wasi_nn_ctx->backend);
104104

105+
bh_assert(!wasi_nn_ctx->busy);
106+
105107
/* deinit() the backend */
106108
if (wasi_nn_ctx->is_backend_ctx_initialized) {
107109
wasi_nn_error res;
108110
call_wasi_nn_func(wasi_nn_ctx->backend, deinit, res,
109111
wasi_nn_ctx->backend_ctx);
110112
}
111113

114+
os_mutex_destroy(&wasi_nn_ctx->lock);
112115
wasm_runtime_free(wasi_nn_ctx);
113116
}
114117

@@ -154,6 +157,11 @@ wasi_nn_initialize_context()
154157
}
155158

156159
memset(wasi_nn_ctx, 0, sizeof(WASINNContext));
160+
if (os_mutex_init(&wasi_nn_ctx->lock)) {
161+
NN_ERR_PRINTF("Error when initializing a lock for WASI-NN context");
162+
wasm_runtime_free(wasi_nn_ctx);
163+
return NULL;
164+
}
157165
return wasi_nn_ctx;
158166
}
159167

@@ -180,6 +188,35 @@ wasm_runtime_get_wasi_nn_ctx(wasm_module_inst_t instance)
180188
return wasi_nn_ctx;
181189
}
182190

191+
/*
 * Claim exclusive use of the WASI-NN context associated with `instance`.
 *
 * Returns the context with its `busy` flag set on success, or NULL when
 * either no context could be obtained for the instance or the context is
 * already marked busy. The two failure cases are not distinguished by the
 * return value. A successful call must be paired with unlock_ctx().
 *
 * This never blocks waiting for the current user: the mutex is held only
 * while the `busy` flag is inspected and updated, and is released before
 * returning.
 */
static WASINNContext *
192+
lock_ctx(wasm_module_inst_t instance)
193+
{
194+
WASINNContext *wasi_nn_ctx = wasm_runtime_get_wasi_nn_ctx(instance);
195+
if (wasi_nn_ctx == NULL) {
196+
return NULL;
197+
}
198+
/* Take the mutex so the test-and-set of `busy` below is done as a unit. */
os_mutex_lock(&wasi_nn_ctx->lock);
199+
if (wasi_nn_ctx->busy) {
200+
/* Already claimed: fail immediately instead of waiting. */
os_mutex_unlock(&wasi_nn_ctx->lock);
201+
return NULL;
202+
}
203+
wasi_nn_ctx->busy = true;
204+
/* Ownership is now recorded in `busy`; the mutex itself is not kept. */
os_mutex_unlock(&wasi_nn_ctx->lock);
205+
return wasi_nn_ctx;
206+
}
207+
208+
/*
 * Release a context previously claimed with lock_ctx() by clearing its
 * `busy` flag under the context mutex.
 *
 * A NULL argument is accepted and ignored, so a caller's common cleanup
 * path can call this unconditionally even when lock_ctx() returned NULL.
 */
static void
209+
unlock_ctx(WASINNContext *wasi_nn_ctx)
210+
{
211+
if (wasi_nn_ctx == NULL) {
212+
return;
213+
}
214+
os_mutex_lock(&wasi_nn_ctx->lock);
215+
/* Releasing a context that is not marked busy is a caller bug. */
bh_assert(wasi_nn_ctx->busy);
216+
wasi_nn_ctx->busy = false;
217+
os_mutex_unlock(&wasi_nn_ctx->lock);
218+
}
219+
183220
void
184221
wasi_nn_destroy()
185222
{
@@ -405,7 +442,7 @@ detect_and_load_backend(graph_encoding backend_hint,
405442

406443
static wasi_nn_error
407444
ensure_backend(wasm_module_inst_t instance, graph_encoding encoding,
408-
WASINNContext **wasi_nn_ctx_ptr)
445+
WASINNContext *wasi_nn_ctx)
409446
{
410447
wasi_nn_error res;
411448

@@ -416,7 +453,6 @@ ensure_backend(wasm_module_inst_t instance, graph_encoding encoding,
416453
goto fail;
417454
}
418455

419-
WASINNContext *wasi_nn_ctx = wasm_runtime_get_wasi_nn_ctx(instance);
420456
if (wasi_nn_ctx->is_backend_ctx_initialized) {
421457
if (wasi_nn_ctx->backend != loaded_backend) {
422458
res = unsupported_operation;
@@ -434,7 +470,6 @@ ensure_backend(wasm_module_inst_t instance, graph_encoding encoding,
434470

435471
wasi_nn_ctx->is_backend_ctx_initialized = true;
436472
}
437-
*wasi_nn_ctx_ptr = wasi_nn_ctx;
438473
return success;
439474
fail:
440475
return res;
@@ -462,17 +497,23 @@ wasi_nn_load(wasm_exec_env_t exec_env, graph_builder_array_wasm *builder,
462497
if (!instance)
463498
return runtime_error;
464499

500+
WASINNContext *wasi_nn_ctx = lock_ctx(instance);
501+
if (wasi_nn_ctx == NULL) {
502+
res = busy;
503+
goto fail;
504+
}
505+
465506
graph_builder_array builder_native = { 0 };
466507
#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
467508
if (success
468509
!= (res = graph_builder_array_app_native(
469510
instance, builder, builder_wasm_size, &builder_native)))
470-
return res;
511+
goto fail;
471512
#else /* WASM_ENABLE_WASI_EPHEMERAL_NN == 0 */
472513
if (success
473514
!= (res = graph_builder_array_app_native(instance, builder,
474515
&builder_native)))
475-
return res;
516+
goto fail;
476517
#endif /* WASM_ENABLE_WASI_EPHEMERAL_NN != 0 */
477518

478519
if (!wasm_runtime_validate_native_addr(instance, g,
@@ -482,8 +523,7 @@ wasi_nn_load(wasm_exec_env_t exec_env, graph_builder_array_wasm *builder,
482523
goto fail;
483524
}
484525

485-
WASINNContext *wasi_nn_ctx;
486-
res = ensure_backend(instance, encoding, &wasi_nn_ctx);
526+
res = ensure_backend(instance, encoding, wasi_nn_ctx);
487527
if (res != success)
488528
goto fail;
489529

@@ -498,6 +538,7 @@ wasi_nn_load(wasm_exec_env_t exec_env, graph_builder_array_wasm *builder,
498538
// XXX: Free intermediate structure pointers
499539
if (builder_native.buf)
500540
wasm_runtime_free(builder_native.buf);
541+
unlock_ctx(wasi_nn_ctx);
501542

502543
return res;
503544
}
@@ -531,18 +572,26 @@ wasi_nn_load_by_name(wasm_exec_env_t exec_env, char *name, uint32_t name_len,
531572

532573
NN_DBG_PRINTF("[WASI NN] LOAD_BY_NAME %s...", name);
533574

534-
WASINNContext *wasi_nn_ctx;
535-
res = ensure_backend(instance, autodetect, &wasi_nn_ctx);
575+
WASINNContext *wasi_nn_ctx = lock_ctx(instance);
576+
if (wasi_nn_ctx == NULL) {
577+
res = busy;
578+
goto fail;
579+
}
580+
581+
res = ensure_backend(instance, autodetect, wasi_nn_ctx);
536582
if (res != success)
537-
return res;
583+
goto fail;
538584

539585
call_wasi_nn_func(wasi_nn_ctx->backend, load_by_name, res,
540586
wasi_nn_ctx->backend_ctx, name, name_len, g);
541587
if (res != success)
542-
return res;
588+
goto fail;
543589

544590
wasi_nn_ctx->is_model_loaded = true;
545-
return success;
591+
res = success;
592+
fail:
593+
unlock_ctx(wasi_nn_ctx);
594+
return res;
546595
}
547596

548597
wasi_nn_error
@@ -580,19 +629,28 @@ wasi_nn_load_by_name_with_config(wasm_exec_env_t exec_env, char *name,
580629

581630
NN_DBG_PRINTF("[WASI NN] LOAD_BY_NAME_WITH_CONFIG %s %s...", name, config);
582631

583-
WASINNContext *wasi_nn_ctx;
584-
res = ensure_backend(instance, autodetect, &wasi_nn_ctx);
632+
WASINNContext *wasi_nn_ctx = lock_ctx(instance);
633+
if (wasi_nn_ctx == NULL) {
634+
res = busy;
635+
goto fail;
636+
}
637+
638+
res = ensure_backend(instance, autodetect, wasi_nn_ctx);
585639
if (res != success)
586-
return res;
640+
goto fail;
641+
;
587642

588643
call_wasi_nn_func(wasi_nn_ctx->backend, load_by_name_with_config, res,
589644
wasi_nn_ctx->backend_ctx, name, name_len, config,
590645
config_len, g);
591646
if (res != success)
592-
return res;
647+
goto fail;
593648

594649
wasi_nn_ctx->is_model_loaded = true;
595-
return success;
650+
res = success;
651+
fail:
652+
unlock_ctx(wasi_nn_ctx);
653+
return res;
596654
}
597655

598656
wasi_nn_error
@@ -606,20 +664,27 @@ wasi_nn_init_execution_context(wasm_exec_env_t exec_env, graph g,
606664
return runtime_error;
607665
}
608666

609-
WASINNContext *wasi_nn_ctx = wasm_runtime_get_wasi_nn_ctx(instance);
610-
611667
wasi_nn_error res;
668+
WASINNContext *wasi_nn_ctx = lock_ctx(instance);
669+
if (wasi_nn_ctx == NULL) {
670+
res = busy;
671+
goto fail;
672+
}
673+
612674
if (success != (res = is_model_initialized(wasi_nn_ctx)))
613-
return res;
675+
goto fail;
614676

615677
if (!wasm_runtime_validate_native_addr(
616678
instance, ctx, (uint64)sizeof(graph_execution_context))) {
617679
NN_ERR_PRINTF("ctx is invalid");
618-
return invalid_argument;
680+
res = invalid_argument;
681+
goto fail;
619682
}
620683

621684
call_wasi_nn_func(wasi_nn_ctx->backend, init_execution_context, res,
622685
wasi_nn_ctx->backend_ctx, g, ctx);
686+
fail:
687+
unlock_ctx(wasi_nn_ctx);
623688
return res;
624689
}
625690

@@ -634,25 +699,30 @@ wasi_nn_set_input(wasm_exec_env_t exec_env, graph_execution_context ctx,
634699
return runtime_error;
635700
}
636701

637-
WASINNContext *wasi_nn_ctx = wasm_runtime_get_wasi_nn_ctx(instance);
638-
639702
wasi_nn_error res;
703+
WASINNContext *wasi_nn_ctx = lock_ctx(instance);
704+
if (wasi_nn_ctx == NULL) {
705+
res = busy;
706+
goto fail;
707+
}
708+
640709
if (success != (res = is_model_initialized(wasi_nn_ctx)))
641-
return res;
710+
goto fail;
642711

643712
tensor input_tensor_native = { 0 };
644713
if (success
645714
!= (res = tensor_app_native(instance, input_tensor,
646715
&input_tensor_native)))
647-
return res;
716+
goto fail;
648717

649718
call_wasi_nn_func(wasi_nn_ctx->backend, set_input, res,
650719
wasi_nn_ctx->backend_ctx, ctx, index,
651720
&input_tensor_native);
652721
// XXX: Free intermediate structure pointers
653722
if (input_tensor_native.dimensions)
654723
wasm_runtime_free(input_tensor_native.dimensions);
655-
724+
fail:
725+
unlock_ctx(wasi_nn_ctx);
656726
return res;
657727
}
658728

@@ -666,14 +736,20 @@ wasi_nn_compute(wasm_exec_env_t exec_env, graph_execution_context ctx)
666736
return runtime_error;
667737
}
668738

669-
WASINNContext *wasi_nn_ctx = wasm_runtime_get_wasi_nn_ctx(instance);
670-
671739
wasi_nn_error res;
740+
WASINNContext *wasi_nn_ctx = lock_ctx(instance);
741+
if (wasi_nn_ctx == NULL) {
742+
res = busy;
743+
goto fail;
744+
}
745+
672746
if (success != (res = is_model_initialized(wasi_nn_ctx)))
673-
return res;
747+
goto fail;
674748

675749
call_wasi_nn_func(wasi_nn_ctx->backend, compute, res,
676750
wasi_nn_ctx->backend_ctx, ctx);
751+
fail:
752+
unlock_ctx(wasi_nn_ctx);
677753
return res;
678754
}
679755

@@ -696,16 +772,21 @@ wasi_nn_get_output(wasm_exec_env_t exec_env, graph_execution_context ctx,
696772
return runtime_error;
697773
}
698774

699-
WASINNContext *wasi_nn_ctx = wasm_runtime_get_wasi_nn_ctx(instance);
700-
701775
wasi_nn_error res;
776+
WASINNContext *wasi_nn_ctx = lock_ctx(instance);
777+
if (wasi_nn_ctx == NULL) {
778+
res = busy;
779+
goto fail;
780+
}
781+
702782
if (success != (res = is_model_initialized(wasi_nn_ctx)))
703-
return res;
783+
goto fail;
704784

705785
if (!wasm_runtime_validate_native_addr(instance, output_tensor_size,
706786
(uint64)sizeof(uint32_t))) {
707787
NN_ERR_PRINTF("output_tensor_size is invalid");
708-
return invalid_argument;
788+
res = invalid_argument;
789+
goto fail;
709790
}
710791

711792
#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
@@ -718,6 +799,8 @@ wasi_nn_get_output(wasm_exec_env_t exec_env, graph_execution_context ctx,
718799
wasi_nn_ctx->backend_ctx, ctx, index, output_tensor,
719800
output_tensor_size);
720801
#endif /* WASM_ENABLE_WASI_EPHEMERAL_NN != 0 */
802+
fail:
803+
unlock_ctx(wasi_nn_ctx);
721804
return res;
722805
}
723806

core/iwasm/libraries/wasi-nn/src/wasi_nn_private.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,11 @@
99
#include "wasi_nn_types.h"
1010
#include "wasm_export.h"
1111

12+
#include "bh_platform.h"
13+
1214
typedef struct {
15+
korp_mutex lock;
16+
bool busy;
1317
bool is_backend_ctx_initialized;
1418
bool is_model_loaded;
1519
graph_encoding backend;

0 commit comments

Comments
 (0)