Arm backend: Add initial module tests for Stable Diffusion 3.5 Medium #12242

Open: wants to merge 3 commits into base: main
23 changes: 19 additions & 4 deletions backends/arm/README.md
@@ -104,6 +104,14 @@ Then you can run the tests with
pytest -c /dev/null -v -n auto backends/arm/test
```

### Model test dependencies
Some model tests in the Arm backend require third-party libraries or packages. To run these tests, install the required dependencies by running the script `examples/arm/setup.sh` with the flag `--setup-test-dependency`.

Please note that installing model test dependencies is a standalone process. When using the `--setup-test-dependency` flag, the script will install only the necessary dependencies for model tests, skipping all other setup procedures.

List of models with specific dependencies:
- Stable Diffusion: [diffusers](https://github.com/huggingface/diffusers/tree/main)
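
A test module could guard itself against a missing optional dependency roughly as follows. This is a minimal sketch and not part of this PR: `has_module` and the test class name are hypothetical, and only the Python standard library is used.

```python
import importlib.util
import unittest


def has_module(name: str) -> bool:
    """Return True if a top-level module can be located without importing it."""
    return importlib.util.find_spec(name) is not None


# Hypothetical guard: skip the whole class when diffusers is absent.
@unittest.skipUnless(
    has_module("diffusers"),
    "diffusers not installed; run examples/arm/setup.sh --setup-test-dependency",
)
class StableDiffusionDependencyTest(unittest.TestCase):
    def test_diffusers_is_importable(self):
        self.assertTrue(has_module("diffusers"))
```

With a guard like this, a missing dependency shows up as a skipped test with a hint about the setup flag rather than as an import error.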

## Passes

With the default passes in the Arm Ethos-U backend, assuming the model lowers fully to the
@@ -189,7 +197,14 @@ Configuration of the EthosUBackend export flow is controlled by CompileSpec info
As this is in active development, see the EthosUBackend for accurate information on [compilation flags](https://github.com/pytorch/executorch/blob/29f6dc9353e90951ed3fae3c57ae416de0520067/backends/arm/arm_backend.py#L319-L324)

## Model specific and optional passes
The current TOSA version does not support int64. For LLMs, for example Llama, aten.embedding is often the first operator, and it requires int64 indices.
In order to lower this to TOSA, an int64->int32 cast needs to be injected. This pass needs to run very early in the lowering process and can be passed in to the to_edge_transform_and_lower() function call as an optional parameter. See the example in backends/arm/test/models/test_llama.py.
By doing this, aten.embedding will be decomposed into aten.index_select, which can handle int32 indices.
Note that this additional step is only needed for pure float models. With quantization this is automatically handled during annotation before the export stage.
The current TOSA version does not support int64. However, int64 is commonly used in many models. In order to lower the operators with int64 inputs and/or outputs to TOSA, a few passes have been developed to handle the int64-related issues. The main idea behind these passes is to replace the uses of int64 with int32 where feasible.
- For floating-point models, these passes need to run very early in the lowering process and can be passed in to the to_edge_transform_and_lower() function call as an optional parameter.
- For quantized models, these transformations will be automatically handled during annotation before the export stage.

List of model specific and optional passes:
- InsertCastForOpsWithInt64InputPass
    - Functionality:
        - For LLMs such as Llama, some operators like aten.embedding have int64 input. In order to lower these operators to TOSA, this pass will insert a casting node that converts the input from int64 to int32.
    - Example usage: backends/arm/test/models/test_llama.py
    - Supported Ops:
        - aten.embedding.default, aten.slice_copy.Tensor
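
Replacing int64 with int32 is only safe when every value fits in the narrower type. The sketch below illustrates that condition in plain Python. The helper names are hypothetical, and the range check the pass actually performs is not shown in this README, so treat this as an assumption about its general shape rather than the backend's logic.

```python
INT32_MIN = -(2**31)
INT32_MAX = 2**31 - 1


def fits_in_int32(values) -> bool:
    """True if every integer in `values` is representable as int32."""
    return all(INT32_MIN <= v <= INT32_MAX for v in values)


def embedding_indices_fit_int32(num_embeddings: int) -> bool:
    """Valid embedding indices are 0 .. num_embeddings - 1, so only the
    largest index needs checking."""
    return num_embeddings - 1 <= INT32_MAX


# A 1000-entry vocabulary is far below the int32 limit.
small_vocab_ok = embedding_indices_fit_int32(1000)
# 2**31 is one past INT32_MAX, so narrowing it would overflow.
overflow_ok = fits_in_int32([0, 2**31])
```

For an embedding lookup the bound follows from the table size alone, which is why such a check can be made at lowering time without seeing runtime index values.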
48 changes: 39 additions & 9 deletions backends/arm/_passes/insert_int64_input_cast_pass.py
@@ -20,8 +20,14 @@

class InsertCastForOpsWithInt64InputPass(ExportPass):

aten_ops = (torch.ops.aten.embedding.default,)
edge_ops = (exir_ops.edge.aten.embedding.default,)
aten_ops = (
torch.ops.aten.embedding.default,
torch.ops.aten.slice_copy.Tensor,
)
edge_ops = (
exir_ops.edge.aten.embedding.default,
exir_ops.edge.aten.slice_copy.Tensor,
)

def get_decomposition(self, op):
if op in self.edge_ops:
@@ -60,35 +66,59 @@ def call(self, graph_module):
continue

args = node.args
weights = args[0]
indices = args[1]

valid_for_insert = False
if node.target in (
exir_ops.edge.aten.embedding.default,
torch.ops.aten.embedding.default,
):
weights = args[0]
indices = args[1]
valid_for_insert = self._check_aten_embedding_within_int32(
weights, indices, node
)

if valid_for_insert:
if valid_for_insert:
to_copy_op = self.get_decomposition(node.target)
with graph.inserting_before(node):
cast_before = create_node(
graph,
to_copy_op,
args=(indices,),
kwargs={
"dtype": torch.int32,
"memory_format": torch.preserve_format,
},
)
node.replace_input_with(indices, cast_before)

modified_graph = True

elif node.target in (
exir_ops.edge.aten.slice_copy.Tensor,
torch.ops.aten.slice_copy.Tensor,
):
# MLETORCH-829: Add range check for slice_copy
input_tensor = args[0]
fake_tensor = input_tensor.meta["val"]
if fake_tensor.dtype != torch.int64:
continue

to_copy_op = self.get_decomposition(node.target)
with graph.inserting_before(node):
cast_before = create_node(
graph,
to_copy_op,
args=(indices,),
args=(input_tensor,),
kwargs={
"dtype": torch.int32,
"memory_format": torch.preserve_format,
},
)
node.replace_input_with(indices, cast_before)
node.replace_input_with(input_tensor, cast_before)

modified_graph = True

if modified_graph:
graph_module.recompile()
graph_module = super().call(graph_module).graph_module
return PassResult(graph_module, True)
return PassResult(graph_module, modified_graph)
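
The rewrite pattern in this pass (create a cast node in front of the consumer, then rewire only that consumer's input) can be illustrated with a toy graph. This is a simplified sketch in plain Python, not the ExecuTorch or torch.fx API: `Node`, `SUPPORTED_OPS` and `insert_int32_casts` are hypothetical names, op targets are plain strings, every int64 input of a supported op is cast (the diff casts only the indices for embedding and the input tensor for slice_copy), and no range check is performed.

```python
from dataclasses import dataclass, field

# Toy stand-ins for op targets; the real pass matches torch/edge op objects.
SUPPORTED_OPS = {"embedding", "slice_copy"}


@dataclass
class Node:
    name: str
    target: str
    inputs: list = field(default_factory=list)
    dtype: str = "float32"


def insert_int32_casts(nodes):
    """Return (new_node_list, modified).

    For each supported op, put a cast node directly before it for every
    int64 input and rewire only that consumer to read from the cast.
    """
    result = []
    modified = False
    for node in nodes:
        if node.target in SUPPORTED_OPS:
            for i, src in enumerate(node.inputs):
                if src.dtype != "int64":
                    continue
                cast = Node(f"{src.name}_to_int32", "_to_copy", [src], "int32")
                result.append(cast)      # lands right before the consumer
                node.inputs[i] = cast    # other users of `src` are untouched
                modified = True
        result.append(node)
    return result, modified


input_ids = Node("input_ids", "placeholder", dtype="int64")
weight = Node("weight", "placeholder")
emb = Node("emb", "embedding", [weight, input_ids])
out = Node("out", "relu", [emb])

graph, modified = insert_int32_casts([input_ids, weight, emb, out])
print([n.name for n in graph], modified)
```

Returning the `modified` flag alongside the graph lets a caller tell a no-op run apart from one that changed the graph.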
10 changes: 10 additions & 0 deletions backends/arm/scripts/install_models_for_test.sh
@@ -0,0 +1,10 @@
#!/usr/bin/env bash
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

set -e

# Install diffusers for Stable Diffusion model test
pip install "diffusers[torch]==0.33.1"
@@ -0,0 +1,114 @@
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
#
# Adapted from Hugging Face's diffusers library:
# https://github.com/huggingface/diffusers/blob/v0.33.1/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3.py
#
# Licensed under the Apache License, Version 2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from transformers import CLIPTextConfig, T5Config


"""
This file defines test configs used to initialize Stable Diffusion module tests.
Module tests in the same directory will import these configs.

To stay aligned with the Stable Diffusion implementation in the HuggingFace Diffusers library,
the configs here are either directly copied from corresponding test files or exported from
pre-trained models used in the Diffusers library.

Licenses:
The test parameters are from Hugging Face's diffusers library and under the Apache 2.0 License,
while the remainder of the code is under the BSD-style license found in the LICENSE file in the
root directory of this source tree.
"""


# Source: https://github.com/huggingface/diffusers/blob/v0.33.1/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3.py#L56
CLIP_text_encoder_config = CLIPTextConfig(
    bos_token_id=0,
    eos_token_id=2,
    hidden_size=32,
    intermediate_size=37,
    layer_norm_eps=1e-05,
    num_attention_heads=4,
    num_hidden_layers=5,
    pad_token_id=1,
    vocab_size=1000,
    hidden_act="gelu",
    projection_dim=32,
)


# Source: https://github.com/huggingface/diffusers/blob/v0.33.1/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3.py#L76
# Exported from: T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5").config
T5_encoder_config = T5Config(
    bos_token_id=0,
    classifier_dropout=0.0,
    d_ff=37,
    d_kv=8,
    d_model=32,
    decoder_start_token_id=0,
    dense_act_fn="relu",
    dropout_rate=0.1,
    eos_token_id=1,
    feed_forward_proj="relu",
    gradient_checkpointing=False,
    initializer_factor=0.002,
    is_encoder_decoder=True,
    is_gated_act=False,
    layer_norm_epsilon=1e-06,
    model_type="t5",
    num_decoder_layers=5,
    num_heads=4,
    num_layers=5,
    pad_token_id=0,
    relative_attention_max_distance=128,
    relative_attention_num_buckets=8,
    transformers_version="4.47.1",
    vocab_size=1000,
)


# Source: https://github.com/huggingface/diffusers/blob/v0.33.1/tests/models/transformers/test_models_transformer_sd3.py#L142
SD3Transformer2DModel_init_dict = {
    "sample_size": 32,
    "patch_size": 1,
    "in_channels": 4,
    "num_layers": 4,
    "attention_head_dim": 8,
    "num_attention_heads": 4,
    "caption_projection_dim": 32,
    "joint_attention_dim": 32,
    "pooled_projection_dim": 64,
    "out_channels": 4,
    "pos_embed_max_size": 96,
    "dual_attention_layers": (0,),
    "qk_norm": "rms_norm",
}


# Source: https://github.com/huggingface/diffusers/blob/v0.33.1/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3.py#L83
AutoencoderKL_config = {
    "sample_size": 32,
    "in_channels": 3,
    "out_channels": 3,
    "block_out_channels": (4,),
    "layers_per_block": 1,
    "latent_channels": 4,
    "norm_num_groups": 1,
    "use_quant_conv": False,
    "use_post_quant_conv": False,
    "shift_factor": 0.0609,
    "scaling_factor": 1.5035,
}
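
The tiny dimensions in these configs are mutually consistent, which a few lines of arithmetic can confirm. The sketch below copies a subset of the values into plain dicts so that it runs without transformers or diffusers. The relationships it checks (head size as hidden size divided by head count, and transformer width as head count times head size) are stated here as assumptions about how the libraries derive those quantities, not something this file asserts.

```python
# Subsets of the values defined in the configs above.
clip = {"hidden_size": 32, "num_attention_heads": 4}
t5 = {"d_model": 32, "d_kv": 8, "num_heads": 4}
sd3 = {
    "attention_head_dim": 8,
    "num_attention_heads": 4,
    "caption_projection_dim": 32,
}

# CLIP: the hidden size should split evenly across the attention heads.
clip_head_dim, clip_remainder = divmod(
    clip["hidden_size"], clip["num_attention_heads"]
)

# T5: per-head size times head count gives the attention inner width.
t5_inner_dim = t5["d_kv"] * t5["num_heads"]

# SD3 transformer: same product, compared with the caption projection width.
sd3_inner_dim = sd3["attention_head_dim"] * sd3["num_attention_heads"]

print(clip_head_dim, clip_remainder, t5_inner_dim, sd3_inner_dim)
```

All three modules end up 32 wide with 4 heads of size 8, which keeps the module tests small while still exercising multi-head attention.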
@@ -0,0 +1,103 @@
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


import unittest

import torch
from executorch.backends.arm._passes import InsertCastForOpsWithInt64InputPass

from executorch.backends.arm.test import common
from executorch.backends.arm.test.models.stable_diffusion.stable_diffusion_module_test_configs import (
    CLIP_text_encoder_config,
)
from executorch.backends.arm.test.tester.arm_tester import ArmTester
from transformers import CLIPTextModelWithProjection


class TestCLIPTextModelWithProjection(unittest.TestCase):
    """
    Test class of CLIPTextModelWithProjection.
    CLIPTextModelWithProjection is one of the text encoders used by Stable Diffusion 3.5 Medium.
    """

    # Adjust the numbers below as we increase op support. Note: most of the
    # delegate calls are directly consecutive to each other in the .pte. The
    # reason for that is that some assert ops are removed by passes in the
    # .to_executorch step, i.e. after the Arm partitioner.
    ops_after_partitioner = {
        "executorch_exir_dialects_edge__ops_aten__to_copy_default": 3,
        "executorch_exir_dialects_edge__ops_aten_argmax_default": 1,
        "executorch_exir_dialects_edge__ops_aten_index_Tensor": 1,
        "executorch_exir_dialects_edge__ops_aten_lt_Tensor": 1,
        "executorch_exir_dialects_edge__ops_aten_view_copy_default": 2,
        "executorch_exir_dialects_edge__ops_dim_order_ops__to_dim_order_copy_default": 1,
        "torch.ops.higher_order.executorch_call_delegate": 3,
    }

    def _prepare_inputs(
        self,
        batch_size=12,
        seq_length=7,
        vocab_size=1000,
    ):
        input_ids = torch.randint(
            low=0,
            high=vocab_size,
            size=(batch_size, seq_length),
            dtype=torch.long,
        )
        return (input_ids,)

    def prepare_model_and_inputs(self):
        clip_text_encoder_config = CLIP_text_encoder_config

        text_encoder_model = CLIPTextModelWithProjection(clip_text_encoder_config)
        text_encoder_model.eval()
        text_encoder_model_inputs = self._prepare_inputs()

        return text_encoder_model, text_encoder_model_inputs

    def test_CLIPTextModelWithProjection_tosa_MI(self):
        text_encoder_model, text_encoder_model_inputs = self.prepare_model_and_inputs()
        with torch.no_grad():
            (
                ArmTester(
                    text_encoder_model,
                    example_inputs=text_encoder_model_inputs,
                    compile_spec=common.get_tosa_compile_spec(tosa_spec="TOSA-1.0+FP"),
                    transform_passes=[InsertCastForOpsWithInt64InputPass()],
                )
                .export()
                .to_edge_transform_and_lower()
                .dump_operator_distribution()
                .check_count(self.ops_after_partitioner)
                .to_executorch()
                .run_method_and_compare_outputs(
                    inputs=text_encoder_model_inputs,
                )
            )

    # MLETORCH-867, MLETORCH-1059
    # Failures: "Fatal Python error: Aborted", dependency cycles, KeyError in CastInt64BuffersToInt32Pass
    @unittest.expectedFailure
    def test_CLIPTextModelWithProjection_tosa_BI(self):
        text_encoder_model, text_encoder_model_inputs = self.prepare_model_and_inputs()
        with torch.no_grad():
            (
                ArmTester(
                    text_encoder_model,
                    example_inputs=text_encoder_model_inputs,
                    compile_spec=common.get_tosa_compile_spec(tosa_spec="TOSA-1.0+INT"),
                )
                .quantize()
                .export()
                .to_edge_transform_and_lower()
                .dump_operator_distribution()
                .to_executorch()
                .run_method_and_compare_outputs(
                    inputs=text_encoder_model_inputs,
                )
            )
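
The `ops_after_partitioner` table pins how many of each operator remain outside the delegates after partitioning. The idea behind such a count check can be sketched with `collections.Counter`. This is an illustration only and not ArmTester's implementation: `count_mismatches` is a hypothetical helper, the op names are shortened, and ignoring ops that are not listed in the expected table is an assumption of this sketch.

```python
from collections import Counter


def count_mismatches(expected: dict, observed_targets: list) -> dict:
    """Map each op whose count differs to (expected, observed).

    Ops that appear in the graph but not in `expected` are ignored.
    """
    observed = Counter(observed_targets)
    return {
        op: (want, observed[op])
        for op, want in expected.items()
        if observed[op] != want
    }


expected = {"aten_argmax_default": 1, "executorch_call_delegate": 3}
observed_targets = [
    "executorch_call_delegate",
    "aten_argmax_default",
    "executorch_call_delegate",
    "aten_view_copy_default",
]

mismatches = count_mismatches(expected, observed_targets)
print(mismatches)
```

An empty result means the graph matches the table. A non-empty one points at the operators whose delegation changed, which is the signal for updating the expected counts as op support grows.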