diff --git a/tfjs-core/src/tensor_util_env.ts b/tfjs-core/src/tensor_util_env.ts index 60fd1272b1..15377f2a8b 100644 --- a/tfjs-core/src/tensor_util_env.ts +++ b/tfjs-core/src/tensor_util_env.ts @@ -133,12 +133,18 @@ export function convertToTensor( export function convertToTensorArray( arg: Array, argName: string, functionName: string, parseAsDtype: DataType|'numeric'|'string_or_numeric' = 'numeric'): T[] { + let tensors = arg as T[]; + if (!Array.isArray(arg)) { - throw new Error( - `Argument ${argName} passed to ${functionName} must be a ` + - '`Tensor[]` or `TensorLike[]`'); + if ((arg as T) instanceof getGlobalTensorClass()) { + tensors = [arg]; + } else { + throw new Error( + `Argument ${argName} passed to ${functionName} must be a ` + + '`Tensor[]` or `TensorLike[]`'); + } } - const tensors = arg as T[]; + return tensors.map( (t, i) => convertToTensor(t, `${argName}[${i}]`, functionName, parseAsDtype));