Skip to content

Commit 84f7332

Browse files
authored
allow scalar axes for Unsqueeze for WebGPU (microsoft#22054)
### Description Align with CPU behavior. https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/core/providers/cpu/tensor/unsqueeze.cc#L60-L62
1 parent 951b1b7 commit 84f7332

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

onnxruntime/core/providers/js/operators/unsqueeze.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,9 @@ class Unsqueeze final : public JsKernel, public UnsqueezeBase {
2626
if (num_inputs == 2) { // axes is an input
2727
const Tensor* axes_tensor = context->Input<Tensor>(1);
2828
ORT_ENFORCE(axes_tensor != nullptr, "Axes input is null");
29-
ORT_ENFORCE(axes_tensor->Shape().NumDimensions() == 1,
30-
"An axes tensor must be a vector tensor.");
29+
ORT_ENFORCE(axes_tensor->Shape().NumDimensions() == 0 ||
30+
axes_tensor->Shape().NumDimensions() == 1,
31+
"An axes tensor must be a scalar or a vector tensor.");
3132
auto nDims = static_cast<size_t>(axes_tensor->Shape()[0]);
3233
const auto* data = axes_tensor->Data<int64_t>();
3334
axes.assign(data, data + nDims);

0 commit comments

Comments
 (0)