From e0c822ba355308c5faf9687bb09b238ebb7bd540 Mon Sep 17 00:00:00 2001 From: Alex Plex Date: Tue, 28 Jan 2025 11:21:09 +0500 Subject: [PATCH] fix: correct convertation arg to tensor array --- tfjs-core/src/tensor_util_env.ts | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/tfjs-core/src/tensor_util_env.ts b/tfjs-core/src/tensor_util_env.ts index 60fd1272b1e..15377f2a8bf 100644 --- a/tfjs-core/src/tensor_util_env.ts +++ b/tfjs-core/src/tensor_util_env.ts @@ -133,12 +133,18 @@ export function convertToTensor( export function convertToTensorArray( arg: Array, argName: string, functionName: string, parseAsDtype: DataType|'numeric'|'string_or_numeric' = 'numeric'): T[] { + let tensors = arg as T[]; + if (!Array.isArray(arg)) { - throw new Error( - `Argument ${argName} passed to ${functionName} must be a ` + - '`Tensor[]` or `TensorLike[]`'); + if ((arg as T) instanceof getGlobalTensorClass()) { + tensors = [arg]; + } else { + throw new Error( + `Argument ${argName} passed to ${functionName} must be a ` + + '`Tensor[]` or `TensorLike[]`'); + } } - const tensors = arg as T[]; + return tensors.map( (t, i) => convertToTensor(t, `${argName}[${i}]`, functionName, parseAsDtype));