@@ -102,13 +102,16 @@ wasi_nn_ctx_destroy(WASINNContext *wasi_nn_ctx)
102
102
NN_DBG_PRINTF ("-> is_model_loaded: %d" , wasi_nn_ctx -> is_model_loaded );
103
103
NN_DBG_PRINTF ("-> current_encoding: %d" , wasi_nn_ctx -> backend );
104
104
105
+ bh_assert (!wasi_nn_ctx -> busy );
106
+
105
107
/* deinit() the backend */
106
108
if (wasi_nn_ctx -> is_backend_ctx_initialized ) {
107
109
wasi_nn_error res ;
108
110
call_wasi_nn_func (wasi_nn_ctx -> backend , deinit , res ,
109
111
wasi_nn_ctx -> backend_ctx );
110
112
}
111
113
114
+ os_mutex_destroy (& wasi_nn_ctx -> lock );
112
115
wasm_runtime_free (wasi_nn_ctx );
113
116
}
114
117
@@ -154,6 +157,11 @@ wasi_nn_initialize_context()
154
157
}
155
158
156
159
memset (wasi_nn_ctx , 0 , sizeof (WASINNContext ));
160
+ if (os_mutex_init (& wasi_nn_ctx -> lock )) {
161
+ NN_ERR_PRINTF ("Error when initializing a lock for WASI-NN context" );
162
+ wasm_runtime_free (wasi_nn_ctx );
163
+ return NULL ;
164
+ }
157
165
return wasi_nn_ctx ;
158
166
}
159
167
@@ -180,6 +188,35 @@ wasm_runtime_get_wasi_nn_ctx(wasm_module_inst_t instance)
180
188
return wasi_nn_ctx ;
181
189
}
182
190
191
+ static WASINNContext *
192
+ lock_ctx (wasm_module_inst_t instance )
193
+ {
194
+ WASINNContext * wasi_nn_ctx = wasm_runtime_get_wasi_nn_ctx (instance );
195
+ if (wasi_nn_ctx == NULL ) {
196
+ return NULL ;
197
+ }
198
+ os_mutex_lock (& wasi_nn_ctx -> lock );
199
+ if (wasi_nn_ctx -> busy ) {
200
+ os_mutex_unlock (& wasi_nn_ctx -> lock );
201
+ return NULL ;
202
+ }
203
+ wasi_nn_ctx -> busy = true;
204
+ os_mutex_unlock (& wasi_nn_ctx -> lock );
205
+ return wasi_nn_ctx ;
206
+ }
207
+
208
+ static void
209
+ unlock_ctx (WASINNContext * wasi_nn_ctx )
210
+ {
211
+ if (wasi_nn_ctx == NULL ) {
212
+ return ;
213
+ }
214
+ os_mutex_lock (& wasi_nn_ctx -> lock );
215
+ bh_assert (wasi_nn_ctx -> busy );
216
+ wasi_nn_ctx -> busy = false;
217
+ os_mutex_unlock (& wasi_nn_ctx -> lock );
218
+ }
219
+
183
220
void
184
221
wasi_nn_destroy ()
185
222
{
@@ -405,7 +442,7 @@ detect_and_load_backend(graph_encoding backend_hint,
405
442
406
443
static wasi_nn_error
407
444
ensure_backend (wasm_module_inst_t instance , graph_encoding encoding ,
408
- WASINNContext * * wasi_nn_ctx_ptr )
445
+ WASINNContext * wasi_nn_ctx )
409
446
{
410
447
wasi_nn_error res ;
411
448
@@ -416,7 +453,6 @@ ensure_backend(wasm_module_inst_t instance, graph_encoding encoding,
416
453
goto fail ;
417
454
}
418
455
419
- WASINNContext * wasi_nn_ctx = wasm_runtime_get_wasi_nn_ctx (instance );
420
456
if (wasi_nn_ctx -> is_backend_ctx_initialized ) {
421
457
if (wasi_nn_ctx -> backend != loaded_backend ) {
422
458
res = unsupported_operation ;
@@ -434,7 +470,6 @@ ensure_backend(wasm_module_inst_t instance, graph_encoding encoding,
434
470
435
471
wasi_nn_ctx -> is_backend_ctx_initialized = true;
436
472
}
437
- * wasi_nn_ctx_ptr = wasi_nn_ctx ;
438
473
return success ;
439
474
fail :
440
475
return res ;
@@ -462,17 +497,23 @@ wasi_nn_load(wasm_exec_env_t exec_env, graph_builder_array_wasm *builder,
462
497
if (!instance )
463
498
return runtime_error ;
464
499
500
+ WASINNContext * wasi_nn_ctx = lock_ctx (instance );
501
+ if (wasi_nn_ctx == NULL ) {
502
+ res = busy ;
503
+ goto fail ;
504
+ }
505
+
465
506
graph_builder_array builder_native = { 0 };
466
507
#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
467
508
if (success
468
509
!= (res = graph_builder_array_app_native (
469
510
instance , builder , builder_wasm_size , & builder_native )))
470
- return res ;
511
+ goto fail ;
471
512
#else /* WASM_ENABLE_WASI_EPHEMERAL_NN == 0 */
472
513
if (success
473
514
!= (res = graph_builder_array_app_native (instance , builder ,
474
515
& builder_native )))
475
- return res ;
516
+ goto fail ;
476
517
#endif /* WASM_ENABLE_WASI_EPHEMERAL_NN != 0 */
477
518
478
519
if (!wasm_runtime_validate_native_addr (instance , g ,
@@ -482,8 +523,7 @@ wasi_nn_load(wasm_exec_env_t exec_env, graph_builder_array_wasm *builder,
482
523
goto fail ;
483
524
}
484
525
485
- WASINNContext * wasi_nn_ctx ;
486
- res = ensure_backend (instance , encoding , & wasi_nn_ctx );
526
+ res = ensure_backend (instance , encoding , wasi_nn_ctx );
487
527
if (res != success )
488
528
goto fail ;
489
529
@@ -498,6 +538,7 @@ wasi_nn_load(wasm_exec_env_t exec_env, graph_builder_array_wasm *builder,
498
538
// XXX: Free intermediate structure pointers
499
539
if (builder_native .buf )
500
540
wasm_runtime_free (builder_native .buf );
541
+ unlock_ctx (wasi_nn_ctx );
501
542
502
543
return res ;
503
544
}
@@ -531,18 +572,26 @@ wasi_nn_load_by_name(wasm_exec_env_t exec_env, char *name, uint32_t name_len,
531
572
532
573
NN_DBG_PRINTF ("[WASI NN] LOAD_BY_NAME %s..." , name );
533
574
534
- WASINNContext * wasi_nn_ctx ;
535
- res = ensure_backend (instance , autodetect , & wasi_nn_ctx );
575
+ WASINNContext * wasi_nn_ctx = lock_ctx (instance );
576
+ if (wasi_nn_ctx == NULL ) {
577
+ res = busy ;
578
+ goto fail ;
579
+ }
580
+
581
+ res = ensure_backend (instance , autodetect , wasi_nn_ctx );
536
582
if (res != success )
537
- return res ;
583
+ goto fail ;
538
584
539
585
call_wasi_nn_func (wasi_nn_ctx -> backend , load_by_name , res ,
540
586
wasi_nn_ctx -> backend_ctx , name , name_len , g );
541
587
if (res != success )
542
- return res ;
588
+ goto fail ;
543
589
544
590
wasi_nn_ctx -> is_model_loaded = true;
545
- return success ;
591
+ res = success ;
592
+ fail :
593
+ unlock_ctx (wasi_nn_ctx );
594
+ return res ;
546
595
}
547
596
548
597
wasi_nn_error
@@ -580,19 +629,28 @@ wasi_nn_load_by_name_with_config(wasm_exec_env_t exec_env, char *name,
580
629
581
630
NN_DBG_PRINTF ("[WASI NN] LOAD_BY_NAME_WITH_CONFIG %s %s..." , name , config );
582
631
583
- WASINNContext * wasi_nn_ctx ;
584
- res = ensure_backend (instance , autodetect , & wasi_nn_ctx );
632
+ WASINNContext * wasi_nn_ctx = lock_ctx (instance );
633
+ if (wasi_nn_ctx == NULL ) {
634
+ res = busy ;
635
+ goto fail ;
636
+ }
637
+
638
+ res = ensure_backend (instance , autodetect , wasi_nn_ctx );
585
639
if (res != success )
586
- return res ;
640
+ goto fail ;
641
+ ;
587
642
588
643
call_wasi_nn_func (wasi_nn_ctx -> backend , load_by_name_with_config , res ,
589
644
wasi_nn_ctx -> backend_ctx , name , name_len , config ,
590
645
config_len , g );
591
646
if (res != success )
592
- return res ;
647
+ goto fail ;
593
648
594
649
wasi_nn_ctx -> is_model_loaded = true;
595
- return success ;
650
+ res = success ;
651
+ fail :
652
+ unlock_ctx (wasi_nn_ctx );
653
+ return res ;
596
654
}
597
655
598
656
wasi_nn_error
@@ -606,20 +664,27 @@ wasi_nn_init_execution_context(wasm_exec_env_t exec_env, graph g,
606
664
return runtime_error ;
607
665
}
608
666
609
- WASINNContext * wasi_nn_ctx = wasm_runtime_get_wasi_nn_ctx (instance );
610
-
611
667
wasi_nn_error res ;
668
+ WASINNContext * wasi_nn_ctx = lock_ctx (instance );
669
+ if (wasi_nn_ctx == NULL ) {
670
+ res = busy ;
671
+ goto fail ;
672
+ }
673
+
612
674
if (success != (res = is_model_initialized (wasi_nn_ctx )))
613
- return res ;
675
+ goto fail ;
614
676
615
677
if (!wasm_runtime_validate_native_addr (
616
678
instance , ctx , (uint64 )sizeof (graph_execution_context ))) {
617
679
NN_ERR_PRINTF ("ctx is invalid" );
618
- return invalid_argument ;
680
+ res = invalid_argument ;
681
+ goto fail ;
619
682
}
620
683
621
684
call_wasi_nn_func (wasi_nn_ctx -> backend , init_execution_context , res ,
622
685
wasi_nn_ctx -> backend_ctx , g , ctx );
686
+ fail :
687
+ unlock_ctx (wasi_nn_ctx );
623
688
return res ;
624
689
}
625
690
@@ -634,25 +699,30 @@ wasi_nn_set_input(wasm_exec_env_t exec_env, graph_execution_context ctx,
634
699
return runtime_error ;
635
700
}
636
701
637
- WASINNContext * wasi_nn_ctx = wasm_runtime_get_wasi_nn_ctx (instance );
638
-
639
702
wasi_nn_error res ;
703
+ WASINNContext * wasi_nn_ctx = lock_ctx (instance );
704
+ if (wasi_nn_ctx == NULL ) {
705
+ res = busy ;
706
+ goto fail ;
707
+ }
708
+
640
709
if (success != (res = is_model_initialized (wasi_nn_ctx )))
641
- return res ;
710
+ goto fail ;
642
711
643
712
tensor input_tensor_native = { 0 };
644
713
if (success
645
714
!= (res = tensor_app_native (instance , input_tensor ,
646
715
& input_tensor_native )))
647
- return res ;
716
+ goto fail ;
648
717
649
718
call_wasi_nn_func (wasi_nn_ctx -> backend , set_input , res ,
650
719
wasi_nn_ctx -> backend_ctx , ctx , index ,
651
720
& input_tensor_native );
652
721
// XXX: Free intermediate structure pointers
653
722
if (input_tensor_native .dimensions )
654
723
wasm_runtime_free (input_tensor_native .dimensions );
655
-
724
+ fail :
725
+ unlock_ctx (wasi_nn_ctx );
656
726
return res ;
657
727
}
658
728
@@ -666,14 +736,20 @@ wasi_nn_compute(wasm_exec_env_t exec_env, graph_execution_context ctx)
666
736
return runtime_error ;
667
737
}
668
738
669
- WASINNContext * wasi_nn_ctx = wasm_runtime_get_wasi_nn_ctx (instance );
670
-
671
739
wasi_nn_error res ;
740
+ WASINNContext * wasi_nn_ctx = lock_ctx (instance );
741
+ if (wasi_nn_ctx == NULL ) {
742
+ res = busy ;
743
+ goto fail ;
744
+ }
745
+
672
746
if (success != (res = is_model_initialized (wasi_nn_ctx )))
673
- return res ;
747
+ goto fail ;
674
748
675
749
call_wasi_nn_func (wasi_nn_ctx -> backend , compute , res ,
676
750
wasi_nn_ctx -> backend_ctx , ctx );
751
+ fail :
752
+ unlock_ctx (wasi_nn_ctx );
677
753
return res ;
678
754
}
679
755
@@ -696,16 +772,21 @@ wasi_nn_get_output(wasm_exec_env_t exec_env, graph_execution_context ctx,
696
772
return runtime_error ;
697
773
}
698
774
699
- WASINNContext * wasi_nn_ctx = wasm_runtime_get_wasi_nn_ctx (instance );
700
-
701
775
wasi_nn_error res ;
776
+ WASINNContext * wasi_nn_ctx = lock_ctx (instance );
777
+ if (wasi_nn_ctx == NULL ) {
778
+ res = busy ;
779
+ goto fail ;
780
+ }
781
+
702
782
if (success != (res = is_model_initialized (wasi_nn_ctx )))
703
- return res ;
783
+ goto fail ;
704
784
705
785
if (!wasm_runtime_validate_native_addr (instance , output_tensor_size ,
706
786
(uint64 )sizeof (uint32_t ))) {
707
787
NN_ERR_PRINTF ("output_tensor_size is invalid" );
708
- return invalid_argument ;
788
+ res = invalid_argument ;
789
+ goto fail ;
709
790
}
710
791
711
792
#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
@@ -718,6 +799,8 @@ wasi_nn_get_output(wasm_exec_env_t exec_env, graph_execution_context ctx,
718
799
wasi_nn_ctx -> backend_ctx , ctx , index , output_tensor ,
719
800
output_tensor_size );
720
801
#endif /* WASM_ENABLE_WASI_EPHEMERAL_NN != 0 */
802
+ fail :
803
+ unlock_ctx (wasi_nn_ctx );
721
804
return res ;
722
805
}
723
806
0 commit comments