Skip to content

Commit e0c822b

Browse files
committed
fix: correct convertation arg to tensor array
1 parent 2644bd0 commit e0c822b

File tree

1 file changed

+10
-4
lines changed

1 file changed

+10
-4
lines changed

tfjs-core/src/tensor_util_env.ts

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -133,12 +133,18 @@ export function convertToTensor<T extends Tensor>(
133133
export function convertToTensorArray<T extends Tensor>(
134134
arg: Array<T|TensorLike>, argName: string, functionName: string,
135135
parseAsDtype: DataType|'numeric'|'string_or_numeric' = 'numeric'): T[] {
136+
let tensors = arg as T[];
137+
136138
if (!Array.isArray(arg)) {
137-
throw new Error(
138-
`Argument ${argName} passed to ${functionName} must be a ` +
139-
'`Tensor[]` or `TensorLike[]`');
139+
if ((arg as T) instanceof getGlobalTensorClass()) {
140+
tensors = [arg];
141+
} else {
142+
throw new Error(
143+
`Argument ${argName} passed to ${functionName} must be a ` +
144+
'`Tensor[]` or `TensorLike[]`');
145+
}
140146
}
141-
const tensors = arg as T[];
147+
142148
return tensors.map(
143149
(t, i) =>
144150
convertToTensor(t, `${argName}[${i}]`, functionName, parseAsDtype));

0 commit comments

Comments
 (0)