
Commit 7d0b38f

fix rank0 coreml
1 parent c991de4 commit 7d0b38f

File tree

1 file changed: +13 -2 lines changed


backends/apple/coreml/runtime/delegate/coreml_backend_delegate.mm

Lines changed: 13 additions & 2 deletions
@@ -88,9 +88,17 @@
         ET_LOG(Error, "%s: DataType=%d is not supported", ETCoreMLStrings.delegateIdentifier.UTF8String, (int)tensor.scalar_type());
         return std::nullopt;
     }
-
+
     std::vector<ssize_t> strides(tensor.strides().begin(), tensor.strides().end());
     std::vector<size_t> shape(tensor.sizes().begin(), tensor.sizes().end());
+
+    // If tensor is rank 0, wrap in rank 1
+    // See https://github.com/apple/coremltools/blob/8.2/coremltools/converters/mil/frontend/torch/exir_utils.py#L73
+    if (strides.size() == 0) {
+        shape.push_back(1);
+        strides.push_back(1);
+    }
+
     MultiArray::MemoryLayout layout(dataType.value(), std::move(shape), std::move(strides));
     switch (argType) {
         case ArgType::Input: {
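
Read in isolation, the added branch is a small shape/stride normalization: a rank-0 (scalar) tensor has empty sizes() and strides(), and the coremltools exporter linked in the comment represents such tensors as rank-1 tensors with a single element, so the delegate does the same before building the MultiArray layout. Below is a minimal standalone sketch of that idea; the helper name and free-function form are illustrative, not the delegate's actual API.

    #include <sys/types.h>  // ssize_t
    #include <vector>

    // Illustrative helper, not part of the delegate: promote a rank-0 tensor's
    // empty shape/strides to a rank-1, single-element layout ([] -> [1]).
    static void promote_rank0_to_rank1(std::vector<size_t>& shape,
                                       std::vector<ssize_t>& strides) {
        if (strides.empty()) {       // rank 0: no dimensions recorded
            shape.push_back(1);      // one element
            strides.push_back(1);    // contiguous
        }
    }
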
@@ -195,7 +203,10 @@ ModelLoggingOptions get_logging_options(BackendExecutionContext& context) {
     size_t nInputs = nArgs.first;
     size_t nOutputs = nArgs.second;
     delegate_args.reserve(nInputs + nOutputs);
-
+
+    // Container to hold wrapped scalar input args
+    std::vector<executorch::extension::TensorPtr> wrapped_scalars;
+
     // inputs
     for (size_t i = 0; i < nInputs; i++) {
         auto multi_array = get_multi_array(args[i], ArgType::Input);
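
The second hunk only declares the container, which is presumably filled in as the inputs are processed. The reason to declare wrapped_scalars outside the input loop is lifetime: a rank-1 tensor created to stand in for a rank-0 input must stay alive until the model has finished executing. The sketch below illustrates that ownership pattern with placeholder types; it is not the delegate's code, and the struct merely stands in for executorch::extension::TensorPtr.

    #include <memory>
    #include <vector>

    // Placeholder stand-in for executorch::extension::TensorPtr: something
    // that owns the storage backing a wrapped scalar input.
    struct WrappedScalar {
        std::shared_ptr<float> storage;
    };

    void execute_with_scalars(const std::vector<float>& scalar_inputs) {
        // Declared before the loop so the wrapped storage outlives input
        // setup and is released only after execution has completed.
        std::vector<WrappedScalar> wrapped_scalars;
        for (float s : scalar_inputs) {
            wrapped_scalars.push_back({std::make_shared<float>(s)});
            // ... pass wrapped_scalars.back().storage.get() to the backend ...
        }
        // ... run the model here; wrapped_scalars goes out of scope afterwards ...
    }
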
