
CoreML Partitioner is not able to lower mv3 #10451


Closed
mergennachin opened this issue Apr 24, 2025 · 6 comments · Fixed by #10534
Labels
module: coreml Issues related to Apple's Core ML delegation and code under backends/apple/coreml/ module: user experience Issues related to reducing friction for users

Comments

@mergennachin
Contributor

mergennachin commented Apr 24, 2025

🐛 Describe the bug

    import torch
    from torchvision import models

    # Imports added for a self-contained repro; module paths assume ExecuTorch 0.6.
    from executorch.backends.apple.coreml.partition import CoreMLPartitioner
    from executorch.exir import to_edge_transform_and_lower

    model = models.mobilenet_v3_small(weights="DEFAULT").eval()
    sample_inputs = (torch.randn(1, 3, 224, 224),)

    # Lower MobileNetV3-small with the CoreML partitioner.
    et_program_coreml = to_edge_transform_and_lower(
        torch.export.export(model, sample_inputs),
        partitioner=[CoreMLPartitioner()],
    ).to_executorch()

    with open("mv3_coreml_all.pte", "wb") as file:
        et_program_coreml.write_to_file(file)

Even though it is able to generate a file, it prints a large number of errors during export, and the resulting program crashes at runtime.

https://gist.github.com/mergennachin/74ca8ef593bc6c962d8d1baacaede2ed

On the other hand,

python3 -m executorch.examples.apple.coreml.scripts.export --model_name=mv3

works fine, because that script applies many layers of patches to make CoreML work.

Versions

executorch==0.6.0

cc @kimishpatel @YifanShenSZ @cymbalrush @metascroy @byjlw

@mergennachin mergennachin added module: coreml Issues related to Apple's Core ML delegation and code under backends/apple/coreml/ module: user experience Issues related to reducing friction for users labels Apr 24, 2025
@github-project-automation github-project-automation bot moved this to To triage in ExecuTorch DevX Apr 24, 2025
@digantdesai
Contributor

Do we have something close to this in CI? Like a quantizer variant perhaps?

@mergennachin
Contributor Author

@metascroy
Contributor

I took a closer look.

When dim order is enabled (now the default), this model has “executorch.exir.dialects.edge._ops.dim_order_ops._to_dim_order_copy.default” ops that return floats, and this op is not recognized by CoreML (https://github.com/apple/coremltools/blob/main/coremltools/converters/mil/frontend/torch/ops.py), so the partitioner skips them. But this results in a delegate call that passes in a bunch of floats as inputs.

coremltools wraps these floats in rank-1 tensors at compile time, but to ExecuTorch they are still floats. At runtime, ExecuTorch forwards these floats to the CoreML delegate, but the model complains that it hasn't received the wrapped tensor inputs (which results in an error).
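
A minimal sketch of the wrapping semantics described above (the wrap_scalar_as_rank1 helper is hypothetical, not the actual coremltools or ExecuTorch code):

    import torch

    # Hypothetical helper sketching the compile-time behavior: a Python float or
    # rank-0 tensor is promoted to a rank-1 tensor of shape (1,), so CoreML sees a
    # tensor input where ExecuTorch still forwards a bare scalar at runtime.
    def wrap_scalar_as_rank1(x):
        if isinstance(x, (int, float)):
            return torch.tensor([x])
        if isinstance(x, torch.Tensor) and x.dim() == 0:
            return x.reshape(1)
        return x

    print(wrap_scalar_as_rank1(3.5).shape)  # torch.Size([1])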

When dim order is disabled, we see “executorch.exir.dialects.edge._ops.aten._to_copy.default” instead in the graph and this is supported by CoreML and grabbed by the partitioner. So the delegate only has a tensor input and things work fine.

Disabling dim order for this model is a short-term fix. Longer term, we should (1) add support for _to_dim_order_copy to coremltools, and (2) handle scalars in ET CoreML's runtime in the same way coremltools handles them at compile time (i.e., wrap them in rank-1 tensors). Either fix on its own would solve the problem for this model, but we should probably do both.
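
A minimal sketch of the short-term workaround, assuming the same model and imports as the repro above and that EdgeCompileConfig's _skip_dim_order flag is the right knob (treat the exact flag name as an assumption and verify it against your ExecuTorch version):

    from executorch.exir import EdgeCompileConfig, to_edge_transform_and_lower

    # Skip dim-order ops so the exported graph keeps aten._to_copy, which the
    # CoreML partitioner recognizes. The _skip_dim_order flag is an assumption.
    et_program_coreml = to_edge_transform_and_lower(
        torch.export.export(model, sample_inputs),
        partitioner=[CoreMLPartitioner()],
        compile_config=EdgeCompileConfig(_skip_dim_order=True),
    ).to_executorch()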

cc @YifanShenSZ @cymbalrush

@metascroy
Contributor

metascroy commented Apr 25, 2025

@digantdesai @Gasoonjia

@shoumikhin had to disable dim order when exporting (https://github.com/pytorch-labs/executorch-examples/pull/23/files).

In terms of why CI did not catch this when dim order became enabled by default: it does not look like we use the partitioner in our test_models.sh script; we use the older to_backend API instead. To use the partitioner, this arg needs to be set: https://github.com/pytorch/executorch/blob/main/examples/apple/coreml/scripts/export.py#L76
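
For example, something like the following would exercise the partitioner path (the flag name is an assumption based on the argument defined at the linked line, not verified):

# --use_partitioner is an assumed flag name; check export.py for the actual arg.
python3 -m executorch.examples.apple.coreml.scripts.export --model_name=mv3 --use_partitioner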

@digantdesai
Contributor

  1. add support for _to_dim_order_copy to coremltools

Yep

@metascroy
Contributor

  1. add support for _to_dim_order_copy to coremltools

Yep

But that doesn't support the dim order op, so the partitioner will still skip it.

metascroy added a commit that referenced this issue Apr 30, 2025
coremltools wraps rank-0 tensors as rank-1 tensors in the AOT flow when
producing a PTE file. This PR:

* Applies the same change on the coreml runtime
* Changes CoreML model tests to use the partitioner

Fixes #10451
@github-project-automation github-project-automation bot moved this from To triage to Done in ExecuTorch DevX Apr 30, 2025