
Commit 2372a47

wasi-nn: make the host use the wasi_ephemeral_nn version of tensor_data (#4411)
The motivations:
* Make the actual input size available to the backends. (Currently the backends have to guess it from the shape/type.)
* Make the host logic look a bit more similar to wasi_ephemeral_nn.

This is a backend API/ABI change.
1 parent 23799a2 commit 2372a47
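Concretely, the change replaces the host-side tensor_data byte pointer with the wasi_ephemeral_nn-style struct that carries both the buffer and its size, and backend entry points now take a pointer to that struct. A minimal sketch assembled from the hunks below; the typedef and the GET_OUTPUT function-pointer type are taken from the diff, while the parameter names are added here only for illustration:

/* Host-side tensor_data after this change (wasi_nn_types.h): the buffer
 * pointer is paired with its size, so backends no longer have to derive
 * the byte count from the tensor's shape and element type. */
typedef struct {
    uint8_t *buf;  /* tensor payload */
    uint32_t size; /* payload size in bytes */
} tensor_data;

/* Backend ABI after this change (wasi_nn_private.h): get_output receives
 * a tensor_data * describing the caller-provided output buffer. */
typedef wasi_nn_error (*GET_OUTPUT)(void *ctx, graph_execution_context exec_ctx,
                                    uint32_t index, tensor_data *output_tensor,
                                    uint32_t *output_tensor_size);

On the host side (wasi_nn.c), wasi_nn_get_output now wraps the guest-provided buffer and length into such a struct and passes its address to the backend, instead of passing the raw pointer and an in/out length.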

File tree: 9 files changed, +43 −37 lines

core/iwasm/libraries/wasi-nn/include/wasi_nn_types.h

Lines changed: 1 addition & 1 deletion

@@ -99,7 +99,7 @@ typedef enum {
 // 4-byte f32 elements would have a data array of length 16). Naturally, this
 // representation requires some knowledge of how to lay out data in
 // memory--e.g., using row-major ordering--and could perhaps be improved.
-#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0 && defined(__wasm__)
+#if !defined(__wasm__) || WASM_ENABLE_WASI_EPHEMERAL_NN != 0
 typedef struct {
     uint8_t *buf;
     uint32_t size;
core/iwasm/libraries/wasi-nn/src/utils/wasi_nn_app_native.c

Lines changed: 11 additions & 6 deletions

@@ -99,7 +99,8 @@ graph_builder_array_app_native(wasm_module_inst_t instance,
 
 static wasi_nn_error
 tensor_data_app_native(wasm_module_inst_t instance, uint32_t total_elements,
-                       tensor_wasm *input_tensor_wasm, tensor_data *data)
+                       tensor_wasm *input_tensor_wasm, void **data,
+                       uint32_t *size)
 {
 #if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
 #define data_size input_tensor_wasm->data_size
@@ -113,8 +114,9 @@ tensor_data_app_native(wasm_module_inst_t instance, uint32_t total_elements,
         NN_ERR_PRINTF("input_tensor_wasm->data_offset is invalid");
         return invalid_argument;
     }
-    *data = (tensor_data)wasm_runtime_addr_app_to_native(
+    *data = wasm_runtime_addr_app_to_native(
         instance, (uint64)input_tensor_wasm->data_offset);
+    *size = data_size;
     return success;
 #undef data_size
 }
@@ -188,16 +190,19 @@ tensor_app_native(wasm_module_inst_t instance, tensor_wasm *input_tensor_wasm,
     NN_DBG_PRINTF("Tensor type: %d", input_tensor_wasm->type);
     NN_DBG_PRINTF("Total number of elements: %d", total_elements);
 
-    tensor_data data = NULL;
+    void *data = NULL;
+    uint32_t datasize;
     if (success
-        != (res = tensor_data_app_native(instance, total_elements,
-                                         input_tensor_wasm, &data))) {
+        != (res =
+                tensor_data_app_native(instance, total_elements,
+                                       input_tensor_wasm, &data, &datasize))) {
         wasm_runtime_free(dimensions);
         return res;
     }
 
     input_tensor->type = input_tensor_wasm->type;
     input_tensor->dimensions = dimensions;
-    input_tensor->data = data;
+    input_tensor->data.buf = data;
+    input_tensor->data.size = datasize;
     return success;
 }

core/iwasm/libraries/wasi-nn/src/wasi_nn.c

Lines changed: 10 additions & 9 deletions

@@ -720,12 +720,12 @@ wasi_nn_compute(wasm_exec_env_t exec_env, graph_execution_context ctx)
 #if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
 wasi_nn_error
 wasi_nn_get_output(wasm_exec_env_t exec_env, graph_execution_context ctx,
-                   uint32_t index, tensor_data output_tensor,
+                   uint32_t index, void *output_tensor,
                    uint32_t output_tensor_len, uint32_t *output_tensor_size)
 #else /* WASM_ENABLE_WASI_EPHEMERAL_NN == 0 */
 wasi_nn_error
 wasi_nn_get_output(wasm_exec_env_t exec_env, graph_execution_context ctx,
-                   uint32_t index, tensor_data output_tensor,
+                   uint32_t index, void *output_tensor,
                    uint32_t *output_tensor_size)
 #endif /* WASM_ENABLE_WASI_EPHEMERAL_NN != 0 */
 {
@@ -753,16 +753,17 @@ wasi_nn_get_output(wasm_exec_env_t exec_env, graph_execution_context ctx,
         goto fail;
     }
 
+    tensor_data tensor = {
+        .buf = output_tensor,
 #if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
+        .size = output_tensor_len,
+#else
+        .size = *output_tensor_size,
+#endif
+    };
     call_wasi_nn_func(wasi_nn_ctx->backend, get_output, res,
-                      wasi_nn_ctx->backend_ctx, ctx, index, output_tensor,
-                      &output_tensor_len);
-    *output_tensor_size = output_tensor_len;
-#else /* WASM_ENABLE_WASI_EPHEMERAL_NN == 0 */
-    call_wasi_nn_func(wasi_nn_ctx->backend, get_output, res,
-                      wasi_nn_ctx->backend_ctx, ctx, index, output_tensor,
+                      wasi_nn_ctx->backend_ctx, ctx, index, &tensor,
                       output_tensor_size);
-#endif /* WASM_ENABLE_WASI_EPHEMERAL_NN != 0 */
 fail:
     unlock_ctx(wasi_nn_ctx);
     return res;

core/iwasm/libraries/wasi-nn/src/wasi_nn_llamacpp.c

Lines changed: 4 additions & 4 deletions

@@ -385,7 +385,7 @@ set_input(void *ctx, graph_execution_context exec_ctx, uint32_t index,
 {
     struct LlamaContext *backend_ctx = (struct LlamaContext *)ctx;
     // tensor->data is the prompt string. ends with \0
-    char *prompt_text = (char *)wasi_nn_tensor->data;
+    char *prompt_text = (char *)wasi_nn_tensor->data.buf;
 
 #ifndef NDEBUG
     NN_DBG_PRINTF("--------------------------------------------------");
@@ -552,7 +552,7 @@ compute(void *ctx, graph_execution_context exec_ctx)
 
 __attribute__((visibility("default"))) wasi_nn_error
 get_output(void *ctx, graph_execution_context exec_ctx, uint32_t index,
-           tensor_data output_tensor, uint32_t *output_tensor_size)
+           tensor_data *output_tensor, uint32_t *output_tensor_size)
 {
     struct LlamaContext *backend_ctx = (struct LlamaContext *)ctx;
 
@@ -571,7 +571,7 @@ get_output(void *ctx, graph_execution_context exec_ctx, uint32_t index,
         printf("%s\n", output_metadata);
     }
 
-    memcpy(output_tensor, output_metadata, strlen(output_metadata));
+    memcpy(output_tensor->buf, output_metadata, strlen(output_metadata));
     *output_tensor_size = strlen(output_metadata);
     return success;
 }
@@ -591,7 +591,7 @@ get_output(void *ctx, graph_execution_context exec_ctx, uint32_t index,
            printf("%s", buf);
        }
 
-        memcpy(output_tensor + end_pos, buf, strlen(buf));
+        memcpy(output_tensor->buf + end_pos, buf, strlen(buf));
        end_pos += strlen(buf);
    }

core/iwasm/libraries/wasi-nn/src/wasi_nn_openvino.c

Lines changed: 4 additions & 4 deletions

@@ -402,7 +402,7 @@ set_input(void *ctx, graph_execution_context exec_ctx, uint32_t index,
                             shape_info);
 
     CHECK_OV_STATUS(ov_tensor_create_from_host_ptr(input_type, input_shape,
-                                                    wasi_nn_tensor->data,
+                                                    wasi_nn_tensor->data.buf,
                                                     &input_tensor),
                     ret);
 }
@@ -441,7 +441,7 @@ compute(void *ctx, graph_execution_context exec_ctx)
 
 __attribute__((visibility("default"))) wasi_nn_error
 get_output(void *ctx, graph_execution_context exec_ctx, uint32_t index,
-           tensor_data output_tensor, uint32_t *output_tensor_size)
+           tensor_data *output_tensor, uint32_t *output_tensor_size)
 {
     OpenVINOContext *ov_ctx = (OpenVINOContext *)ctx;
     struct OpenVINOExecutionContext *exec;
@@ -460,14 +460,14 @@ get_output(void *ctx, graph_execution_context exec_ctx, uint32_t index,
 
     CHECK_OV_STATUS(ov_tensor_get_byte_size(ov_tensor, &byte_size), ret);
 
-    if (byte_size > *output_tensor_size) {
+    if (byte_size > output_tensor->size) {
         ret = too_large;
         goto fail;
     }
 
     CHECK_OV_STATUS(ov_tensor_data(ov_tensor, &data), ret);
 
-    memcpy(output_tensor, data, byte_size);
+    memcpy(output_tensor->buf, data, byte_size);
 
     *output_tensor_size = (uint32_t)byte_size;

core/iwasm/libraries/wasi-nn/src/wasi_nn_openvino.h

Lines changed: 2 additions & 2 deletions

@@ -24,12 +24,12 @@ compute(void *ctx, graph_execution_context exec_ctx);
 
 __attribute__((visibility("default"))) wasi_nn_error
 get_output(void *ctx, graph_execution_context exec_ctx, uint32_t index,
-           tensor_data output_tensor, uint32_t *output_tensor_size);
+           tensor_data *output_tensor, uint32_t *output_tensor_size);
 
 __attribute__((visibility("default"))) wasi_nn_error
 init_backend(void **ctx);
 
 __attribute__((visibility("default"))) wasi_nn_error
 deinit_backend(void *ctx);
 
-#endif /* WASI_NN_OPENVINO_HPP */
+#endif /* WASI_NN_OPENVINO_HPP */

core/iwasm/libraries/wasi-nn/src/wasi_nn_private.h

Lines changed: 1 addition & 1 deletion

@@ -32,7 +32,7 @@ typedef wasi_nn_error (*SET_INPUT)(void *, graph_execution_context, uint32_t,
                                    tensor *);
 typedef wasi_nn_error (*COMPUTE)(void *, graph_execution_context);
 typedef wasi_nn_error (*GET_OUTPUT)(void *, graph_execution_context, uint32_t,
-                                    tensor_data, uint32_t *);
+                                    tensor_data *, uint32_t *);
 /* wasi-nn general APIs */
 typedef wasi_nn_error (*BACKEND_INITIALIZE)(void **);
 typedef wasi_nn_error (*BACKEND_DEINITIALIZE)(void *);

core/iwasm/libraries/wasi-nn/src/wasi_nn_tensorflowlite.cpp

Lines changed: 9 additions & 9 deletions

@@ -324,7 +324,7 @@ set_input(void *tflite_ctx, graph_execution_context ctx, uint32_t index,
                       index);
 
         int size = model_tensor_size * sizeof(float);
-        bh_memcpy_s(it, size, input_tensor->data, size);
+        bh_memcpy_s(it, size, input_tensor->data.buf, size);
     }
     else { // TODO: Assuming uint8 quantized networks.
         TfLiteAffineQuantization *quant_info =
@@ -342,7 +342,7 @@ set_input(void *tflite_ctx, graph_execution_context ctx, uint32_t index,
         NN_DBG_PRINTF("input tensor: (scale, offset) = (%f, %f)", scale,
                       zero_point);
 
-        float *input_tensor_f = (float *)input_tensor->data;
+        float *input_tensor_f = (float *)input_tensor->data.buf;
         for (uint32_t i = 0; i < model_tensor_size; ++i) {
             it[i] = (uint8_t)(input_tensor_f[i] / scale + zero_point);
         }
@@ -366,7 +366,7 @@ compute(void *tflite_ctx, graph_execution_context ctx)
 
 __attribute__((visibility("default"))) wasi_nn_error
 get_output(void *tflite_ctx, graph_execution_context ctx, uint32_t index,
-           tensor_data output_tensor, uint32_t *output_tensor_size)
+           tensor_data *output_tensor, uint32_t *output_tensor_size)
 {
     TFLiteContext *tfl_ctx = (TFLiteContext *)tflite_ctx;
 
@@ -392,7 +392,7 @@ get_output(void *tflite_ctx, graph_execution_context ctx, uint32_t index,
     if (tensor->quantization.type == kTfLiteNoQuantization) {
         NN_DBG_PRINTF("No quantization information");
 #if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
-        if (*output_tensor_size < tensor->bytes) {
+        if (output_tensor->size < tensor->bytes) {
             NN_ERR_PRINTF("Insufficient memory to copy tensor %d", index);
             return too_large;
         }
@@ -401,12 +401,12 @@ get_output(void *tflite_ctx, graph_execution_context ctx, uint32_t index,
          * for now, maintain the bug-to-bug compatibility with the old abi,
          * where the size here is the number of fp32, not bytes.
          */
-        if (*output_tensor_size < tensor->bytes / sizeof(float)) {
+        if (output_tensor->size < tensor->bytes / sizeof(float)) {
            NN_ERR_PRINTF("Insufficient memory to copy tensor %d", index);
            return too_large;
        }
 #endif
-        bh_memcpy_s(output_tensor, *output_tensor_size, tensor->data.data,
+        bh_memcpy_s(output_tensor->buf, output_tensor->size, tensor->data.data,
                     tensor->bytes);
 #if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
         *output_tensor_size = tensor->bytes;
@@ -431,7 +431,7 @@ get_output(void *tflite_ctx, graph_execution_context ctx, uint32_t index,
         model_tensor_size *= (uint32_t)tensor->dims->data[i];
 
 #if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
-    if (*output_tensor_size / sizeof(float) < model_tensor_size) {
+    if (output_tensor->size / sizeof(float) < model_tensor_size) {
         NN_ERR_PRINTF("Insufficient memory to copy tensor %d", index);
         return too_large;
     }
@@ -440,7 +440,7 @@ get_output(void *tflite_ctx, graph_execution_context ctx, uint32_t index,
      * for now, maintain the bug-to-bug compatibility with the old abi,
      * where the size here is the number of fp32, not bytes.
      */
-    if (*output_tensor_size < model_tensor_size) {
+    if (output_tensor->size < model_tensor_size) {
         NN_ERR_PRINTF("Insufficient memory to copy tensor %d", index);
         return too_large;
     }
@@ -454,7 +454,7 @@ get_output(void *tflite_ctx, graph_execution_context ctx, uint32_t index,
     NN_DBG_PRINTF("output tensor: (scale, offset) = (%f, %f)", scale,
                   zero_point);
 
-    float *output_tensor_f = (float *)output_tensor;
+    float *output_tensor_f = (float *)output_tensor->buf;
     for (uint32_t i = 0; i < model_tensor_size; ++i) {
         output_tensor_f[i] = (ot[i] - zero_point) * scale;
     }

core/iwasm/libraries/wasi-nn/src/wasi_nn_tensorflowlite.hpp

Lines changed: 1 addition & 1 deletion

@@ -32,7 +32,7 @@ compute(void *tflite_ctx, graph_execution_context ctx);
 
 __attribute__((visibility("default"))) wasi_nn_error
 get_output(void *tflite_ctx, graph_execution_context ctx, uint32_t index,
-           tensor_data output_tensor, uint32_t *output_tensor_size);
+           tensor_data *output_tensor, uint32_t *output_tensor_size);
 
 __attribute__((visibility("default"))) wasi_nn_error
 init_backend(void **tflite_ctx);
