diff --git a/backends/arm/README.md b/backends/arm/README.md index 94713ec3b3c..6bf46d3f3ae 100644 --- a/backends/arm/README.md +++ b/backends/arm/README.md @@ -104,6 +104,14 @@ The you can run the tests with pytest -c /dev/null -v -n auto backends/arm/test ``` +### Model test dependencies +Some model tests in Arm backend require third-party libraries or packages. To run these tests, you need to install the required dependencies by running the script `examples/arm/setup.sh` with the flag `--setup-test-dependency`. + +Please note that installing model test dependencies is a standalone process. When using the `--setup-test-dependency` flag, the script will install only the necessary dependencies for model tests, skipping all other setup procedures. + +List of models with specific dependencies: +- Stable Diffusion: [diffusers](https://github.com/huggingface/diffusers/tree/main) + ## Passes With the default passes in the Arm Ethos-U backend, assuming the model lowers fully to the @@ -189,7 +197,14 @@ Configuration of the EthosUBackend export flow is controlled by CompileSpec info As this is in active development see the EthosUBackend for accurate information on [compilation flags](https://github.com/pytorch/executorch/blob/29f6dc9353e90951ed3fae3c57ae416de0520067/backends/arm/arm_backend.py#L319-L324) ## Model specific and optional passes -The current TOSA version does not support int64. For LLMs for example LLama, often aten.emedding is the first operator and it requires int64 indicies. -In order to lower this to TOSA and int64->int32 cast need to be injected. This pass need to run very early in the lowering process and can be passed in to the to_edge_transform_and_lower() function call as an optional parameter. See example in: backends/arm/test/models/test_llama.py. -By doing this aten.embedding will be decomposed into to aten.index_select which can handle int32 indices. -Note that this additional step is only needed for pure float models. 
With quantization this is automatically handled during annotation before the export stage. +The current TOSA version does not support int64. However, int64 is commonly used in many models. In order to lower the operators with int64 inputs and/or outputs to TOSA, a few passes have been developed to handle the int64-related issues. The main idea behind these passes is to replace the uses of int64 with int32 where feasible. +- For floating-point models, these passes need to run very early in the lowering process and can be passed in to the to_edge_transform_and_lower() function call as an optional parameter. +- For quantized models, these transformations will be automatically handled during annotation before the export stage. + +List of model specific and optional passes: +- InsertCastForOpsWithInt64InputPass + - Functionality: + - For LLMs such as LLama, some operators like aten.embedding have int64 input. In order to lower these operators to TOSA, this pass will insert a casting node that converts the input from int64 to int32. 
+ - Example usage: backends/arm/test/models/test_llama.py + - Supported Ops: + - aten.embedding.default, aten.slice_copy.Tensor diff --git a/backends/arm/_passes/insert_int64_input_cast_pass.py b/backends/arm/_passes/insert_int64_input_cast_pass.py index c1681320a54..5aa5d9807d6 100644 --- a/backends/arm/_passes/insert_int64_input_cast_pass.py +++ b/backends/arm/_passes/insert_int64_input_cast_pass.py @@ -20,8 +20,14 @@ class InsertCastForOpsWithInt64InputPass(ExportPass): - aten_ops = (torch.ops.aten.embedding.default,) - edge_ops = (exir_ops.edge.aten.embedding.default,) + aten_ops = ( + torch.ops.aten.embedding.default, + torch.ops.aten.slice_copy.Tensor, + ) + edge_ops = ( + exir_ops.edge.aten.embedding.default, + exir_ops.edge.aten.slice_copy.Tensor, + ) def get_decomposition(self, op): if op in self.edge_ops: @@ -60,35 +66,59 @@ def call(self, graph_module): continue args = node.args - weights = args[0] - indices = args[1] - valid_for_insert = False if node.target in ( exir_ops.edge.aten.embedding.default, torch.ops.aten.embedding.default, ): + weights = args[0] + indices = args[1] valid_for_insert = self._check_aten_embedding_within_int32( weights, indices, node ) - if valid_for_insert: + if valid_for_insert: + to_copy_op = self.get_decomposition(node.target) + with graph.inserting_before(node): + cast_before = create_node( + graph, + to_copy_op, + args=(indices,), + kwargs={ + "dtype": torch.int32, + "memory_format": torch.preserve_format, + }, + ) + node.replace_input_with(indices, cast_before) + + modified_graph = True + + elif node.target in ( + exir_ops.edge.aten.slice_copy.Tensor, + torch.ops.aten.slice_copy.Tensor, + ): + # MLETORCH-829: Add range check for slice_copy + input_tensor = args[0] + fake_tensor = input_tensor.meta["val"] + if fake_tensor.dtype != torch.int64: + continue + to_copy_op = self.get_decomposition(node.target) with graph.inserting_before(node): cast_before = create_node( graph, to_copy_op, - args=(indices,), + 
args=(input_tensor,), kwargs={ "dtype": torch.int32, "memory_format": torch.preserve_format, }, ) - node.replace_input_with(indices, cast_before) + node.replace_input_with(input_tensor, cast_before) modified_graph = True if modified_graph: graph_module.recompile() graph_module = super().call(graph_module).graph_module - return PassResult(graph_module, True) + return PassResult(graph_module, modified_graph) diff --git a/backends/arm/scripts/install_models_for_test.sh b/backends/arm/scripts/install_models_for_test.sh new file mode 100644 index 00000000000..9c8b034909e --- /dev/null +++ b/backends/arm/scripts/install_models_for_test.sh @@ -0,0 +1,10 @@ +#!/usr/bin/env bash +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +set -e + +# Install diffusers for Stable Diffusion model test +pip install "diffusers[torch]==0.33.1" diff --git a/backends/arm/test/models/stable_diffusion/stable_diffusion_module_test_configs.py b/backends/arm/test/models/stable_diffusion/stable_diffusion_module_test_configs.py new file mode 100644 index 00000000000..86e945311c7 --- /dev/null +++ b/backends/arm/test/models/stable_diffusion/stable_diffusion_module_test_configs.py @@ -0,0 +1,114 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+# +# Adapted from Hugging Face's diffusers library: +# https://github.com/huggingface/diffusers/blob/v0.33.1/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3.py +# +# Licensed under the Apache License, Version 2.0 +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from transformers import CLIPTextConfig, T5Config + + +""" +This file defines test configs used to initialize Stable Diffusion module tests. +Module tests in the same directory will import these configs. + +To stay aligned with the Stable Diffusion implementation in the HuggingFace Diffusers library, +the configs here are either directly copied from corresponding test files or exported from +pre-trained models used in the Diffusers library. + +Licenses: +The test parameters are from Hugging Face's diffusers library and under the Apache 2.0 License, +while the remainder of the code is under the BSD-style license found in the LICENSE file in the +root directory of this source tree. 
+""" + + +# Source: https://github.com/huggingface/diffusers/blob/v0.33.1/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3.py#L56 +CLIP_text_encoder_config = CLIPTextConfig( + bos_token_id=0, + eos_token_id=2, + hidden_size=32, + intermediate_size=37, + layer_norm_eps=1e-05, + num_attention_heads=4, + num_hidden_layers=5, + pad_token_id=1, + vocab_size=1000, + hidden_act="gelu", + projection_dim=32, +) + + +# Source: https://github.com/huggingface/diffusers/blob/v0.33.1/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3.py#L76 +# Exported from: T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5").config +T5_encoder_config = T5Config( + bos_token_id=0, + classifier_dropout=0.0, + d_ff=37, + d_kv=8, + d_model=32, + decoder_start_token_id=0, + dense_act_fn="relu", + dropout_rate=0.1, + eos_token_id=1, + feed_forward_proj="relu", + gradient_checkpointing=False, + initializer_factor=0.002, + is_encoder_decoder=True, + is_gated_act=False, + layer_norm_epsilon=1e-06, + model_type="t5", + num_decoder_layers=5, + num_heads=4, + num_layers=5, + pad_token_id=0, + relative_attention_max_distance=128, + relative_attention_num_buckets=8, + transformers_version="4.47.1", + vocab_size=1000, +) + + +# Source: https://github.com/huggingface/diffusers/blob/v0.33.1/tests/models/transformers/test_models_transformer_sd3.py#L142 +SD3Transformer2DModel_init_dict = { + "sample_size": 32, + "patch_size": 1, + "in_channels": 4, + "num_layers": 4, + "attention_head_dim": 8, + "num_attention_heads": 4, + "caption_projection_dim": 32, + "joint_attention_dim": 32, + "pooled_projection_dim": 64, + "out_channels": 4, + "pos_embed_max_size": 96, + "dual_attention_layers": (0,), + "qk_norm": "rms_norm", +} + + +# Source: https://github.com/huggingface/diffusers/blob/v0.33.1/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3.py#L83 +AutoencoderKL_config = { + "sample_size": 32, + "in_channels": 3, + "out_channels": 3, + 
"block_out_channels": (4,), + "layers_per_block": 1, + "latent_channels": 4, + "norm_num_groups": 1, + "use_quant_conv": False, + "use_post_quant_conv": False, + "shift_factor": 0.0609, + "scaling_factor": 1.5035, +} diff --git a/backends/arm/test/models/stable_diffusion/test_CLIPTextModelWithProjection.py b/backends/arm/test/models/stable_diffusion/test_CLIPTextModelWithProjection.py new file mode 100644 index 00000000000..72e23d506c5 --- /dev/null +++ b/backends/arm/test/models/stable_diffusion/test_CLIPTextModelWithProjection.py @@ -0,0 +1,103 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +import unittest + +import torch +from executorch.backends.arm._passes import InsertCastForOpsWithInt64InputPass + +from executorch.backends.arm.test import common +from executorch.backends.arm.test.models.stable_diffusion.stable_diffusion_module_test_configs import ( + CLIP_text_encoder_config, +) +from executorch.backends.arm.test.tester.arm_tester import ArmTester +from transformers import CLIPTextModelWithProjection + + +class TestCLIPTextModelWithProjection(unittest.TestCase): + """ + Test class of CLIPTextModelWithProjection. + CLIPTextModelWithProjection is one of the text_encoder used by Stable Diffusion 3.5 Medium + """ + + # Adjust nbr below as we increase op support. Note: most of the delegates + # calls are directly consecutive to each other in the .pte. The reason + # for that is some assert ops are removed by passes in the + # .to_executorch step, i.e. after Arm partitioner. 
+ ops_after_partitioner = { + "executorch_exir_dialects_edge__ops_aten__to_copy_default": 3, + "executorch_exir_dialects_edge__ops_aten_argmax_default": 1, + "executorch_exir_dialects_edge__ops_aten_index_Tensor": 1, + "executorch_exir_dialects_edge__ops_aten_lt_Tensor": 1, + "executorch_exir_dialects_edge__ops_aten_view_copy_default": 2, + "executorch_exir_dialects_edge__ops_dim_order_ops__to_dim_order_copy_default": 1, + "torch.ops.higher_order.executorch_call_delegate": 3, + } + + def _prepare_inputs( + self, + batch_size=12, + seq_length=7, + vocab_size=1000, + ): + input_ids = torch.randint( + low=0, + high=vocab_size, + size=(batch_size, seq_length), + dtype=torch.long, + ) + return (input_ids,) + + def prepare_model_and_inputs(self): + clip_text_encoder_config = CLIP_text_encoder_config + + text_encoder_model = CLIPTextModelWithProjection(clip_text_encoder_config) + text_encoder_model.eval() + text_encoder_model_inputs = self._prepare_inputs() + + return text_encoder_model, text_encoder_model_inputs + + def test_CLIPTextModelWithProjection_tosa_MI(self): + text_encoder_model, text_encoder_model_inputs = self.prepare_model_and_inputs() + with torch.no_grad(): + ( + ArmTester( + text_encoder_model, + example_inputs=text_encoder_model_inputs, + compile_spec=common.get_tosa_compile_spec(tosa_spec="TOSA-1.0+FP"), + transform_passes=[InsertCastForOpsWithInt64InputPass()], + ) + .export() + .to_edge_transform_and_lower() + .dump_operator_distribution() + .check_count(self.ops_after_partitioner) + .to_executorch() + .run_method_and_compare_outputs( + inputs=text_encoder_model_inputs, + ) + ) + + # MLETORCH-867, MLETORCH-1059 + # Failures: "Fatal Python error: Aborted, Dependency cycles, KeyError in CastInt64BuffersToInt32Pass") + @unittest.expectedFailure + def test_CLIPTextModelWithProjection_tosa_BI(self): + text_encoder_model, text_encoder_model_inputs = self.prepare_model_and_inputs() + with torch.no_grad(): + ( + ArmTester( + text_encoder_model, + 
example_inputs=text_encoder_model_inputs, + compile_spec=common.get_tosa_compile_spec(tosa_spec="TOSA-1.0+INT"), + ) + .quantize() + .export() + .to_edge_transform_and_lower() + .dump_operator_distribution() + .to_executorch() + .run_method_and_compare_outputs( + inputs=text_encoder_model_inputs, + ) + ) diff --git a/backends/arm/test/models/stable_diffusion/test_SD3Transformer2DModel.py b/backends/arm/test/models/stable_diffusion/test_SD3Transformer2DModel.py new file mode 100644 index 00000000000..fc8ab9b484b --- /dev/null +++ b/backends/arm/test/models/stable_diffusion/test_SD3Transformer2DModel.py @@ -0,0 +1,136 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +import unittest + +import torch +from diffusers.models.transformers import SD3Transformer2DModel + +from executorch.backends.arm.test import common +from executorch.backends.arm.test.models.stable_diffusion.stable_diffusion_module_test_configs import ( + SD3Transformer2DModel_init_dict, +) +from executorch.backends.arm.test.tester.arm_tester import ArmTester + + +class TestSD3Transformer2DModel(unittest.TestCase): + """ + Test class of SD3Transformer2DModel. + SD3Transformer2DModel is the transformer model used by Stable Diffusion 3.5 Medium + """ + + # Adjust nbr below as we increase op support. Note: most of the delegates + # calls are directly consecutive to each other in the .pte. The reason + # for that is some assert ops are removed by passes in the + # .to_executorch step, i.e. after Arm partitioner. 
+ ops_after_partitioner = { + "executorch_exir_dialects_edge__ops_aten_permute_copy_default": 1, + "executorch_exir_dialects_edge__ops_aten_unsqueeze_copy_default": 1, + "executorch_exir_dialects_edge__ops_aten_view_copy_default": 2, + "executorch_exir_dialects_edge__ops_dim_order_ops__to_dim_order_copy_default": 1, + "torch.ops.higher_order.executorch_call_delegate": 1, + } + + def _prepare_inputs( + self, + batch_size=2, + num_channels=4, + height=32, + width=32, + embedding_dim=32, + sequence_length=154, + max_timestep=1000, + ): + hidden_states = torch.randn( + ( + batch_size, + num_channels, + height, + width, + ) + ) + encoder_hidden_states = torch.randn( + ( + batch_size, + sequence_length, + embedding_dim, + ) + ) + pooled_prompt_embeds = torch.randn( + ( + batch_size, + embedding_dim * 2, + ) + ) + timestep = torch.randint(low=0, high=max_timestep, size=(batch_size,)) + + input_dict = { + "hidden_states": hidden_states, + "encoder_hidden_states": encoder_hidden_states, + "pooled_projections": pooled_prompt_embeds, + "timestep": timestep, + } + + return tuple(input_dict.values()) + + def prepare_model_and_inputs(self): + + class SD3Transformer2DModelWrapper(SD3Transformer2DModel): + def forward(self, *args, **kwargs): + return super().forward(*args, **kwargs).sample + + init_dict = SD3Transformer2DModel_init_dict + + sd35_transformer2D_model = SD3Transformer2DModelWrapper(**init_dict) + sd35_transformer2D_model_inputs = self._prepare_inputs() + + return sd35_transformer2D_model, sd35_transformer2D_model_inputs + + def test_SD3Transformer2DModel_tosa_MI(self): + sd35_transformer2D_model, sd35_transformer2D_model_inputs = ( + self.prepare_model_and_inputs() + ) + with torch.no_grad(): + ( + ArmTester( + sd35_transformer2D_model, + example_inputs=sd35_transformer2D_model_inputs, + compile_spec=common.get_tosa_compile_spec(tosa_spec="TOSA-1.0+FP"), + ) + .export() + .to_edge_transform_and_lower() + .check_count(self.ops_after_partitioner) + .to_executorch() + 
.run_method_and_compare_outputs( + inputs=sd35_transformer2D_model_inputs, + rtol=1.0, # TODO: MLETORCH-875: Reduce tolerance of SD3Transformer2DModel with MI and BI + atol=4.0, + ) + ) + + def test_SD3Transformer2DModel_tosa_BI(self): + sd35_transformer2D_model, sd35_transformer2D_model_inputs = ( + self.prepare_model_and_inputs() + ) + with torch.no_grad(): + ( + ArmTester( + sd35_transformer2D_model, + example_inputs=sd35_transformer2D_model_inputs, + compile_spec=common.get_tosa_compile_spec(tosa_spec="TOSA-1.0+INT"), + ) + .quantize() + .export() + .to_edge_transform_and_lower() + .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) + .to_executorch() + .run_method_and_compare_outputs( + inputs=sd35_transformer2D_model_inputs, + qtol=1.0, # TODO: MLETORCH-875: Reduce tolerance of SD3Transformer2DModel with MI and BI + rtol=1.0, + atol=4.0, + ) + ) diff --git a/backends/arm/test/models/stable_diffusion/test_T5EncoderModel.py b/backends/arm/test/models/stable_diffusion/test_T5EncoderModel.py new file mode 100644 index 00000000000..565db22492c --- /dev/null +++ b/backends/arm/test/models/stable_diffusion/test_T5EncoderModel.py @@ -0,0 +1,106 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +import unittest + +import torch +from executorch.backends.arm._passes import InsertCastForOpsWithInt64InputPass + +from executorch.backends.arm.test import common +from executorch.backends.arm.test.models.stable_diffusion.stable_diffusion_module_test_configs import ( + T5_encoder_config, +) +from executorch.backends.arm.test.tester.arm_tester import ArmTester +from transformers import T5EncoderModel + + +class TestT5EncoderModel(unittest.TestCase): + """ + Test class of T5EncoderModel. 
+ T5EncoderModel is one of the text_encoder used by Stable Diffusion 3.5 Medium + """ + + # Adjust nbr below as we increase op support. Note: most of the delegates + # calls are directly consecutive to each other in the .pte. The reason + # for that is some assert ops are removed by passes in the + # .to_executorch step, i.e. after Arm partitioner. + ops_after_partitioner = { + "executorch_exir_dialects_edge__ops_aten__to_copy_default": 2, + "executorch_exir_dialects_edge__ops_aten_abs_default": 1, + "executorch_exir_dialects_edge__ops_aten_add_Tensor": 3, + "executorch_exir_dialects_edge__ops_aten_full_like_default": 1, + "executorch_exir_dialects_edge__ops_aten_gt_Scalar": 1, + "executorch_exir_dialects_edge__ops_aten_lt_Scalar": 1, + "executorch_exir_dialects_edge__ops_aten_minimum_default": 1, + "executorch_exir_dialects_edge__ops_aten_mul_Tensor": 1, + "executorch_exir_dialects_edge__ops_aten_sub_Tensor": 1, + "executorch_exir_dialects_edge__ops_aten_view_copy_default": 1, + "executorch_exir_dialects_edge__ops_aten_where_self": 1, + "executorch_exir_dialects_edge__ops_dim_order_ops__to_dim_order_copy_default": 3, + "torch.ops.higher_order.executorch_call_delegate": 3, + } + + def _prepare_inputs( + self, + batch_size=12, + seq_length=7, + vocab_size=1000, + ): + input_ids = torch.randint( + low=0, + high=vocab_size, + size=(batch_size, seq_length), + dtype=torch.long, + ) + return (input_ids,) + + def prepare_model_and_inputs(self): + t5_encoder_config = T5_encoder_config + + t5_encoder_model = T5EncoderModel(t5_encoder_config) + t5_encoder_model.eval() + t5_encoder_model_inputs = self._prepare_inputs() + + return t5_encoder_model, t5_encoder_model_inputs + + def test_T5EncoderModel_tosa_MI(self): + t5_encoder_model, t5_encoder_model_inputs = self.prepare_model_and_inputs() + with torch.no_grad(): + ( + ArmTester( + t5_encoder_model, + example_inputs=t5_encoder_model_inputs, + compile_spec=common.get_tosa_compile_spec(tosa_spec="TOSA-1.0+FP"), + 
transform_passes=[InsertCastForOpsWithInt64InputPass()], + ) + .export() + .to_edge_transform_and_lower() + .dump_operator_distribution() + .check_count(self.ops_after_partitioner) + .to_executorch() + .run_method_and_compare_outputs( + inputs=t5_encoder_model_inputs, + ) + ) + + def test_T5EncoderModel_tosa_BI(self): + t5_encoder_model, t5_encoder_model_inputs = self.prepare_model_and_inputs() + with torch.no_grad(): + ( + ArmTester( + t5_encoder_model, + example_inputs=t5_encoder_model_inputs, + compile_spec=common.get_tosa_compile_spec(tosa_spec="TOSA-1.0+INT"), + ) + .quantize() + .export() + .to_edge_transform_and_lower() + .dump_operator_distribution() + .to_executorch() + .run_method_and_compare_outputs( + inputs=t5_encoder_model_inputs, + ) + ) diff --git a/backends/arm/test/models/stable_diffusion/test_vae_AutoencoderKL.py b/backends/arm/test/models/stable_diffusion/test_vae_AutoencoderKL.py new file mode 100644 index 00000000000..d2c48e2adba --- /dev/null +++ b/backends/arm/test/models/stable_diffusion/test_vae_AutoencoderKL.py @@ -0,0 +1,80 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +import unittest + +import torch +from diffusers.models.autoencoders import AutoencoderKL +from diffusers.utils.testing_utils import floats_tensor + +from executorch.backends.arm.test import common +from executorch.backends.arm.test.models.stable_diffusion.stable_diffusion_module_test_configs import ( + AutoencoderKL_config, +) +from executorch.backends.arm.test.tester.arm_tester import ArmTester + + +class TestAutoencoderKL(unittest.TestCase): + """ + Test class of AutoencoderKL. 
+ AutoencoderKL is the encoder/decoder used by Stable Diffusion 3.5 Medium + """ + + def _prepare_inputs(self, batch_size=4, num_channels=3, sizes=(32, 32)): + image = floats_tensor((batch_size, num_channels) + sizes) + return (image,) + + def prepare_model_and_inputs(self): + + class AutoencoderWrapper(AutoencoderKL): + def forward(self, *args, **kwargs): + return super().forward(*args, **kwargs).sample + + vae_config = AutoencoderKL_config + + auto_encoder_model = AutoencoderWrapper(**vae_config) + + auto_encoder_model_inputs = self._prepare_inputs() + + return auto_encoder_model, auto_encoder_model_inputs + + def test_AutoencoderKL_tosa_MI(self): + auto_encoder_model, auto_encoder_model_inputs = self.prepare_model_and_inputs() + with torch.no_grad(): + ( + ArmTester( + auto_encoder_model, + example_inputs=auto_encoder_model_inputs, + compile_spec=common.get_tosa_compile_spec(tosa_spec="TOSA-1.0+FP"), + ) + .export() + .to_edge_transform_and_lower() + .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) + .to_executorch() + .run_method_and_compare_outputs( + inputs=auto_encoder_model_inputs, + ) + ) + + def test_AutoencoderKL_tosa_BI(self): + auto_encoder_model, auto_encoder_model_inputs = self.prepare_model_and_inputs() + with torch.no_grad(): + ( + ArmTester( + auto_encoder_model, + example_inputs=auto_encoder_model_inputs, + compile_spec=common.get_tosa_compile_spec(tosa_spec="TOSA-1.0+INT"), + ) + .quantize() + .export() + .to_edge_transform_and_lower() + .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) + .to_executorch() + .run_method_and_compare_outputs( + inputs=auto_encoder_model_inputs, + atol=1.0, # TODO: MLETORCH-990 Reduce tolerance of vae(AutoencoderKL) with BI + ) + ) diff --git a/backends/arm/test/test_arm_baremetal.sh b/backends/arm/test/test_arm_baremetal.sh index 89deba5e65b..678f1a9c4d2 100755 --- a/backends/arm/test/test_arm_baremetal.sh +++ b/backends/arm/test/test_arm_baremetal.sh @@ -97,6 +97,9 @@ 
test_pytest_models() { # Test ops and other things # Prepare for pytest backends/arm/scripts/build_executorch.sh + # Install model dependencies for pytest + source backends/arm/scripts/install_models_for_test.sh + # Run arm baremetal pytest tests without FVP pytest --verbose --color=yes --durations=0 backends/arm/test/models echo "${TEST_SUITE_NAME}: PASS" @@ -136,6 +139,9 @@ test_pytest_models_ethosu_fvp() { # Same as test_pytest but also sometime verify # arm_test/arm_semihosting_executor_runner_corstone-320 backends/arm/test/setup_testing.sh + # Install model dependencies for pytest + source backends/arm/scripts/install_models_for_test.sh + # Run arm baremetal pytest tests with FVP pytest --verbose --color=yes --durations=0 backends/arm/test/models echo "${TEST_SUITE_NAME}: PASS" diff --git a/examples/arm/setup.sh b/examples/arm/setup.sh index 80a7f5ad721..ee82e43a75d 100755 --- a/examples/arm/setup.sh +++ b/examples/arm/setup.sh @@ -112,6 +112,11 @@ function check_options() { skip_vela_setup=1 shift ;; + --setup-test-dependency) + echo "Installing test dependency..." + source $et_dir/backends/arm/scripts/install_models_for_test.sh + exit 0 + ;; --help) print_usage "$@" exit 0