From 3f9e82a411cb709414fdd5ee89befedfa256f5ee Mon Sep 17 00:00:00 2001 From: Zingo Andersen Date: Tue, 13 May 2025 09:58:46 +0200 Subject: [PATCH] Arm backend: Support ScalarType::Bool in EthosUBackend Signed-off-by: Zingo Andersen Change-Id: I6b90132728c83dce84f7c98941ab0d9540b97794 --- backends/arm/_passes/__init__.py | 1 + backends/arm/_passes/arm_pass_manager.py | 4 + .../arm/_passes/cast_bool_to_int8_pass.py | 58 ++++++++++ backends/arm/operators/ops_binary.py | 49 ++++++++ backends/arm/runtime/EthosUBackend.cpp | 44 +++---- backends/arm/test/ops/test_any.py | 25 +++- backends/arm/test/ops/test_bitwise.py | 93 +++++++++++++-- backends/arm/test/ops/test_logical.py | 109 +++++++++++++++--- .../arm/test/tester/analyze_output_utils.py | 39 ++++--- backends/arm/test/tester/arm_tester.py | 5 +- .../executor_runner/arm_executor_runner.cpp | 64 ++++++---- 11 files changed, 402 insertions(+), 89 deletions(-) create mode 100644 backends/arm/_passes/cast_bool_to_int8_pass.py diff --git a/backends/arm/_passes/__init__.py b/backends/arm/_passes/__init__.py index 9d1e7f2e01f..237b3a06dc3 100644 --- a/backends/arm/_passes/__init__.py +++ b/backends/arm/_passes/__init__.py @@ -9,6 +9,7 @@ from .annotate_decomposed_matmul import AnnotateDecomposedMatmulPass # noqa from .arm_pass import ArmPass # noqa from .broadcast_args_pass import BroadcastArgsPass # noqa +from .cast_bool_to_int8_pass import CastBoolToInt8Pass # noqa from .cast_int64_pass import CastInt64BuffersToInt32Pass # noqa from .cast_to_int32_pass import CastToInt32Pass # noqa from .conv1d_unsqueeze_pass import Conv1dUnsqueezePass # noqa diff --git a/backends/arm/_passes/arm_pass_manager.py b/backends/arm/_passes/arm_pass_manager.py index fee4fda9789..e92e8da4fc7 100644 --- a/backends/arm/_passes/arm_pass_manager.py +++ b/backends/arm/_passes/arm_pass_manager.py @@ -10,6 +10,7 @@ AnnotateChannelsLastDimOrder, AnnotateDecomposedMatmulPass, BroadcastArgsPass, + CastBoolToInt8Pass, CastInt64BuffersToInt32Pass, CastToInt32Pass, ComputeConstantOpsAOT, @@ -107,6 +108,7 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModul if self.tosa_spec.is_U55_subset: self.add_pass(CastToInt32Pass()) + self.add_pass(CastBoolToInt8Pass()) self.add_pass(ReplaceScalarWithTensorArgPassTOSABI()) self.add_pass(AnnotateDecomposedMatmulPass()) self.add_pass(QuantizeOperatorArguments()) @@ -146,6 +148,7 @@ def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModul self.add_pass(DecomposeRoundPass()) self.add_pass(DecomposeSqrtPass()) self.add_pass(ConvertIntPowToMuls()) + self.add_pass(CastBoolToInt8Pass()) self.add_pass(ReplaceScalarWithTensorArgPassTOSAMI()) self.add_pass(DecomposeEmbeddingPass()) self.add_pass(FuseQuantizedActivationPass()) @@ -227,6 +230,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule): self.add_pass(DecomposeEmbeddingPass()) self.add_pass(DecomposeScaledDotProductAttention()) self.add_pass(DecomposeRoundPass()) + self.add_pass(CastBoolToInt8Pass()) self.add_pass(ReplaceScalarWithTensorArgPassTOSABI()) self.add_pass(ScalarsToAttributePass()) self.add_pass(DecomposeGroupNormPass()) diff --git a/backends/arm/_passes/cast_bool_to_int8_pass.py b/backends/arm/_passes/cast_bool_to_int8_pass.py new file mode 100644 index 00000000000..1352671b01e --- /dev/null +++ b/backends/arm/_passes/cast_bool_to_int8_pass.py @@ -0,0 +1,58 @@ +# Copyright 2025 Arm Limited and/or its affiliates. 
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# The TOSA BITWISE_AND, BITWISE_OR, and BITWISE_XOR operators do not accept bool
+# inputs/outputs, so cast bool to int8 before the op and back to bool after it.
+
+import torch
+
+from executorch.exir.dialects._ops import ops as exir_ops
+from executorch.exir.pass_base import ExportPass
+
+
+class CastBoolToInt8Pass(ExportPass):
+    """Casts bool inputs to int8 and casts the output back to the original input dtype."""
+
+    targeted_ops = {
+        exir_ops.edge.aten.bitwise_and.Tensor,
+        exir_ops.edge.aten.bitwise_or.Tensor,
+        exir_ops.edge.aten.bitwise_xor.Tensor,
+    }
+
+    def call_operator(self, op, args, kwargs, meta):
+        if op not in self.targeted_ops:
+            return super().call_operator(op, args, kwargs, meta)
+
+        new_args: list = []
+        did_cast = False
+        for arg in args:
+            if arg.data.dtype == torch.bool:
+                new_args.append(
+                    super().call_operator(
+                        exir_ops.edge.dim_order_ops._to_dim_order_copy.default,
+                        (arg,),
+                        {"dtype": torch.int8},
+                        meta,
+                    )
+                )
+                did_cast = True
+            else:
+                new_args.append(arg)
+
+        output = super().call_operator(
+            op,
+            tuple(new_args),
+            {},
+            meta,
+        )
+
+        if did_cast:
+            output = super().call_operator(
+                exir_ops.edge.dim_order_ops._to_dim_order_copy.default,
+                (output,),
+                {"dtype": args[0].data.dtype},
+                meta,
+            )
+        return output
diff --git a/backends/arm/operators/ops_binary.py b/backends/arm/operators/ops_binary.py
index 81a2946c8fb..9c0c15364fc 100644
--- a/backends/arm/operators/ops_binary.py
+++ b/backends/arm/operators/ops_binary.py
@@ -17,6 +17,7 @@
 from executorch.backends.arm.operators.operator_validation_utils import (
     validate_num_inputs,
     validate_same_dtype,
+    validate_valid_dtype,
 )
 
 from executorch.backends.arm.tosa_mapping import TosaArg
@@ -40,6 +41,30 @@ def define_node(
         validate_num_inputs(self.target, inputs, 2)
         validate_same_dtype(self.target, [*inputs, output], ts)
 
+        if self.target in [
+            "aten.bitwise_and.Tensor",
+            "aten.bitwise_xor.Tensor",
+            "aten.bitwise_or.Tensor",
+            "aten.bitwise_left_shift.Tensor",
+        ]:
+            validate_valid_dtype(
+                self.target,
+                [*inputs, output],
+                [ts.DType.INT8, ts.DType.INT16, ts.DType.INT32],
+                output.tosa_spec,
+            )
+        if self.target in [
+            "aten.logical_and.default",
+            "aten.logical_xor.default",
+            "aten.logical_or.default",
+        ]:
+            validate_valid_dtype(
+                self.target,
+                [*inputs, output],
+                [ts.DType.BOOL],
+                output.tosa_spec,
+            )
+
         tosa_graph.addOperator(
             tosa_op, [inputs[0].name, inputs[1].name], [output.name]
         )
@@ -66,6 +91,30 @@ def define_node(
         validate_num_inputs(self.target, inputs, 2)
         validate_same_dtype(self.target, [*inputs, output], ts)
 
+        if self.target in [
+            "aten.bitwise_and.Tensor",
+            "aten.bitwise_xor.Tensor",
+            "aten.bitwise_or.Tensor",
+            "aten.bitwise_left_shift.Tensor",
+        ]:
+            validate_valid_dtype(
+                self.target,
+                [*inputs, output],
+                [ts.DType.INT8, ts.DType.INT16, ts.DType.INT32],
+                output.tosa_spec,
+            )
+        if self.target in [
+            "aten.logical_and.default",
+            "aten.logical_xor.default",
+            "aten.logical_or.default",
+        ]:
+            validate_valid_dtype(
+                self.target,
+                [*inputs, output],
+                [ts.DType.BOOL],
+                output.tosa_spec,
+            )
+
         tosa_graph.addOperator(
             tosa_op, [inputs[0].name, inputs[1].name], [output.name]
         )
diff --git a/backends/arm/runtime/EthosUBackend.cpp b/backends/arm/runtime/EthosUBackend.cpp
index b5575e21b61..d29c32b02f3 100644
--- a/backends/arm/runtime/EthosUBackend.cpp
+++ b/backends/arm/runtime/EthosUBackend.cpp
@@ -234,12 +234,17 @@ 
class EthosUBackend final : public ::executorch::runtime::BackendInterface {
       supported |=
           (tensor_in.scalar_type() == ScalarType::Short and
            handles.inputs->io[i].elem_size == 2);
+      // bool (IOQDQ pass prepared networks)
+      supported |=
+          (tensor_in.scalar_type() == ScalarType::Bool and
+           handles.inputs->io[i].elem_size == 1);
       if (!supported) {
         ET_LOG(
             Error,
-            "Input %d expected Integer (4 byte) or Char (1 byte) integer inputs, got ScalarType id %s",
+            "Input %d expected Integer (4 byte), Char (1 byte) or Bool (1 byte) integer inputs, got ScalarType id %s size %d",
             i,
-            executorch::runtime::toString(tensor_in.scalar_type()));
+            executorch::runtime::toString(tensor_in.scalar_type()),
+            handles.inputs->io[i].elem_size);
         return Error::InvalidProgram;
       }
       supported = executorch::runtime::is_contiguous_dim_order(
@@ -257,15 +262,17 @@ class EthosUBackend final : public ::executorch::runtime::BackendInterface {
       bool permuted_input_shape;
       ET_CHECK_OK_OR_RETURN_ERROR(check_requires_permute(
           i, tensor_in, &handles.inputs->io[i], &permuted_input_shape));
-      bool both_char = tensor_in.scalar_type() == ScalarType::Char and
-          handles.inputs->io[i].elem_size == 1;
-      bool both_int = tensor_in.scalar_type() == ScalarType::Int and
+      bool both_int = tensor_in.scalar_type() == ScalarType::Int &&
           handles.inputs->io[i].elem_size == 4;
-      bool both_short = tensor_in.scalar_type() == ScalarType::Short and
+      bool both_char = tensor_in.scalar_type() == ScalarType::Char &&
+          handles.inputs->io[i].elem_size == 1;
+      bool both_short = tensor_in.scalar_type() == ScalarType::Short &&
           handles.inputs->io[i].elem_size == 2;
+      bool both_bool = tensor_in.scalar_type() == ScalarType::Bool &&
+          handles.inputs->io[i].elem_size == 1;
 
       // Select a compatible copy routine
-      if (both_char && permuted_input_shape) {
+      if ((both_char || both_bool) && permuted_input_shape) {
         EXECUTORCH_PROF_SCOPE(
             event_tracer,
             "+EthosUBackend::execute()handles.input.permute_CHW_to_HWC()");
@@ -276,7 +283,7 @@ class EthosUBackend final : public ::executorch::runtime::BackendInterface {
             tensor_in.size(1),
             tensor_in.size(2),
             tensor_in.size(3));
-      } else if (both_char || both_int || both_short) {
+      } else if (both_char || both_int || both_short || both_bool) {
         EXECUTORCH_PROF_SCOPE(
             event_tracer, "+EthosUBackend::execute()handles.input.memcpy()");
         // Sizes match and elt size matches so memcpy
@@ -363,7 +370,9 @@ class EthosUBackend final : public ::executorch::runtime::BackendInterface {
       bool permuted_output_shape;
       ET_CHECK_OK_OR_RETURN_ERROR(check_requires_permute(
           i, tensor_out, &handles.outputs->io[i], &permuted_output_shape));
-      if (tensor_out.scalar_type() == ScalarType::Char &&
+
+      if ((tensor_out.scalar_type() == ScalarType::Char ||
+           tensor_out.scalar_type() == ScalarType::Bool) &&
           permuted_output_shape) {
         EXECUTORCH_PROF_SCOPE(
             event_tracer,
@@ -379,17 +388,12 @@ class EthosUBackend final : public ::executorch::runtime::BackendInterface {
             tensor_out.size(3));
       } else {
         EXECUTORCH_PROF_SCOPE(
-            event_tracer, "+EthosUBackend::execute()handles.output.move()");
-        for (int j = 0; j < tensor_out.numel(); j++) {
-          if (tensor_out.scalar_type() == ScalarType::Char) {
-            const char* output_address =
-                static_cast<const char*>(output_addr);
-            tensor_out.mutable_data_ptr<char>()[j] = output_address[j];
-          } else {
-            const int* output_address =
-                reinterpret_cast<const int*>(output_addr);
-            tensor_out.mutable_data_ptr<int>()[j] = output_address[j];
-          }
-        }
+            event_tracer, "+EthosUBackend::execute()handles.output.memcpy()");
+
+        memcpy(
+            tensor_out.mutable_data_ptr<char>(),
+            static_cast<const char*>(output_addr),
+            
tensor_out.nbytes()); } } if (tensor_dim != io_dim) { diff --git a/backends/arm/test/ops/test_any.py b/backends/arm/test/ops/test_any.py index 6ddef1ad0b5..338c5f05cc6 100644 --- a/backends/arm/test/ops/test_any.py +++ b/backends/arm/test/ops/test_any.py @@ -6,7 +6,6 @@ from typing import List, Tuple -import pytest import torch from executorch.backends.arm.test import common from executorch.backends.arm.test.tester.test_pipeline import ( @@ -125,14 +124,30 @@ def forward(self, x: torch.Tensor): @common.parametrize("test_data", test_data) def test_any_tosa_MI(test_data: input_t1): op, test_input = test_data() - pipeline = TosaPipelineMI[input_t1](op, test_input(), op.aten_op, op.exir_op) + pipeline = TosaPipelineMI[input_t1]( + op, + test_input(), + op.aten_op, + op.exir_op, + atol=0, + rtol=0, + qtol=0, + ) pipeline.run() @common.parametrize("test_data", test_data) def test_any_tosa_BI(test_data: input_t1): op, test_input = test_data() - pipeline = TosaPipelineBI[input_t1](op, test_input(), op.aten_op, op.exir_op) + pipeline = TosaPipelineBI[input_t1]( + op, + test_input(), + op.aten_op, + op.exir_op, + atol=0, + rtol=0, + qtol=0, + ) pipeline.pop_stage("quantize") pipeline.pop_stage("check.quant_nodes") pipeline.run() @@ -153,7 +168,6 @@ def test_any_u55_BI(test_data: input_t1): @common.parametrize("test_data", test_data) -@pytest.mark.xfail(reason="MLETORCH-706: Support ScalarType::Bool in EthosUBackend.") @common.XfailIfNoCorstone320 def test_any_u85_BI(test_data: input_t1): op, test_input = test_data() @@ -163,6 +177,9 @@ def test_any_u85_BI(test_data: input_t1): op.aten_op, op.exir_op, run_on_fvp=True, + atol=0, + rtol=0, + qtol=0, ) pipeline.pop_stage("quantize") pipeline.pop_stage("check.quant_nodes") diff --git a/backends/arm/test/ops/test_bitwise.py b/backends/arm/test/ops/test_bitwise.py index 8be8ba35b4e..032639b8607 100644 --- a/backends/arm/test/ops/test_bitwise.py +++ b/backends/arm/test/ops/test_bitwise.py @@ -6,7 +6,6 @@ from typing import Tuple -import pytest import torch from executorch.backends.arm.test import common from executorch.backends.arm.test.tester.test_pipeline import ( @@ -30,6 +29,22 @@ class BitwiseBinary(torch.nn.Module): torch.ones(10, 10, 10, dtype=torch.int8), torch.ones(10, 10, 10, dtype=torch.int8), ), + "pattern_int8": lambda: ( + 0xAA * torch.ones(1, 2, 2, 2, dtype=torch.int8), + 0xCC * torch.ones(1, 2, 2, 2, dtype=torch.int8), + ), + "pattern_int16": lambda: ( + 0xAAAA * torch.ones(1, 2, 2, 2, dtype=torch.int16), + 0xCCCC * torch.ones(1, 2, 2, 2, dtype=torch.int16), + ), + "pattern_int32": lambda: ( + 0xAAAAAAAA * torch.ones(1, 2, 2, 2, dtype=torch.int32), + 0xCCCCCCCC * torch.ones(1, 2, 2, 2, dtype=torch.int32), + ), + "pattern_bool": lambda: ( + torch.tensor([True, False, True], dtype=torch.bool), + torch.tensor([True, True, False], dtype=torch.bool), + ), "rand_rank2": lambda: ( torch.randint(-128, 127, (10, 10), dtype=torch.int8), torch.randint(-128, 127, (10, 10), dtype=torch.int8), @@ -68,7 +83,13 @@ def forward(self, tensor1: torch.Tensor, tensor2: torch.Tensor): @common.parametrize("test_data", And().test_data) def test_bitwise_and_tensor_tosa_MI(test_data: input_t2): pipeline = TosaPipelineMI[input_t2]( - And(), test_data(), And().aten_op, And().exir_op + And(), + test_data(), + And().aten_op, + And().exir_op, + atol=0, + rtol=0, + qtol=0, ) pipeline.run() @@ -76,7 +97,13 @@ def test_bitwise_and_tensor_tosa_MI(test_data: input_t2): @common.parametrize("test_data", And().test_data) def test_bitwise_and_tensor_tosa_BI(test_data: input_t2): 
pipeline = TosaPipelineBI[input_t2]( - And(), test_data(), And().aten_op, And().exir_op + And(), + test_data(), + And().aten_op, + And().exir_op, + atol=0, + rtol=0, + qtol=0, ) pipeline.pop_stage("quantize") pipeline.pop_stage("check.quant_nodes") @@ -97,11 +124,17 @@ def test_bitwise_and_tensor_u55_BI(test_data: input_t2): @common.parametrize("test_data", And().test_data) -@pytest.mark.xfail(reason="MLETORCH-706: Support ScalarType::Bool in EthosUBackend.") @common.XfailIfNoCorstone320 def test_bitwise_and_tensor_u85_BI(test_data: input_t2): pipeline = EthosU85PipelineBI[input_t2]( - And(), test_data(), And().aten_op, And().exir_op, run_on_fvp=True + And(), + test_data(), + And().aten_op, + And().exir_op, + run_on_fvp=True, + atol=0, + rtol=0, + qtol=0, ) pipeline.pop_stage("quantize") pipeline.pop_stage("check.quant_nodes") @@ -111,7 +144,13 @@ def test_bitwise_and_tensor_u85_BI(test_data: input_t2): @common.parametrize("test_data", Xor().test_data) def test_bitwise_xor_tensor_tosa_MI(test_data: input_t2): pipeline = TosaPipelineMI[input_t2]( - Xor(), test_data(), Xor().aten_op, Xor().exir_op + Xor(), + test_data(), + Xor().aten_op, + Xor().exir_op, + atol=0, + rtol=0, + qtol=0, ) pipeline.run() @@ -119,7 +158,13 @@ def test_bitwise_xor_tensor_tosa_MI(test_data: input_t2): @common.parametrize("test_data", Xor().test_data) def test_bitwise_xor_tensor_tosa_BI(test_data: input_t2): pipeline = TosaPipelineBI[input_t2]( - Xor(), test_data(), Xor().aten_op, Xor().exir_op + Xor(), + test_data(), + Xor().aten_op, + Xor().exir_op, + atol=0, + rtol=0, + qtol=0, ) pipeline.pop_stage("quantize") pipeline.pop_stage("check.quant_nodes") @@ -140,11 +185,17 @@ def test_bitwise_xor_tensor_u55_BI(test_data: input_t2): @common.parametrize("test_data", Xor().test_data) -@pytest.mark.xfail(reason="MLETORCH-706: Support ScalarType::Bool in EthosUBackend.") @common.XfailIfNoCorstone320 def test_bitwise_xor_tensor_u85_BI(test_data: input_t2): pipeline = EthosU85PipelineBI[input_t2]( - Xor(), test_data(), Xor().aten_op, Xor().exir_op, run_on_fvp=True + Xor(), + test_data(), + Xor().aten_op, + Xor().exir_op, + run_on_fvp=True, + atol=0, + rtol=0, + qtol=0, ) pipeline.pop_stage("quantize") pipeline.pop_stage("check.quant_nodes") @@ -153,13 +204,29 @@ def test_bitwise_xor_tensor_u85_BI(test_data: input_t2): @common.parametrize("test_data", Or().test_data) def test_bitwise_or_tensor_tosa_MI(test_data: input_t2): - pipeline = TosaPipelineMI[input_t2](Or(), test_data(), Or().aten_op, Or().exir_op) + pipeline = TosaPipelineMI[input_t2]( + Or(), + test_data(), + Or().aten_op, + Or().exir_op, + atol=0, + rtol=0, + qtol=0, + ) pipeline.run() @common.parametrize("test_data", Or().test_data) def test_bitwise_or_tensor_tosa_BI(test_data: input_t2): - pipeline = TosaPipelineBI[input_t2](Or(), test_data(), Or().aten_op, Or().exir_op) + pipeline = TosaPipelineBI[input_t2]( + Or(), + test_data(), + Or().aten_op, + Or().exir_op, + atol=0, + rtol=0, + qtol=0, + ) pipeline.pop_stage("quantize") pipeline.pop_stage("check.quant_nodes") pipeline.run() @@ -179,7 +246,6 @@ def test_bitwise_or_tensor_u55_BI(test_data: input_t2): @common.parametrize("test_data", Or().test_data) -@pytest.mark.xfail(reason="MLETORCH-706: Support ScalarType::Bool in EthosUBackend.") @common.XfailIfNoCorstone320 def test_bitwise_or_tensor_u85_BI(test_data: input_t2): pipeline = EthosU85PipelineBI[input_t2]( @@ -188,6 +254,9 @@ def test_bitwise_or_tensor_u85_BI(test_data: input_t2): Or().aten_op, Or().exir_op, run_on_fvp=True, + atol=0, + rtol=0, + qtol=0, ) 
pipeline.pop_stage("quantize") pipeline.pop_stage("check.quant_nodes") diff --git a/backends/arm/test/ops/test_logical.py b/backends/arm/test/ops/test_logical.py index 139653eea97..1a056e31b3c 100644 --- a/backends/arm/test/ops/test_logical.py +++ b/backends/arm/test/ops/test_logical.py @@ -6,7 +6,6 @@ from typing import Tuple -import pytest import torch from executorch.backends.arm.test import common from executorch.backends.arm.test.tester.test_pipeline import ( @@ -84,7 +83,13 @@ def forward(self, tensor: torch.Tensor): @common.parametrize("test_data", And().test_data) def test_logical_and_tosa_MI(test_data: input_t2): pipeline = TosaPipelineMI[input_t2]( - And(), test_data(), And().aten_op, And().exir_op + And(), + test_data(), + And().aten_op, + And().exir_op, + atol=0, + rtol=0, + qtol=0, ) pipeline.run() @@ -92,7 +97,13 @@ def test_logical_and_tosa_MI(test_data: input_t2): @common.parametrize("test_data", And().test_data) def test_logical_and_tosa_BI(test_data: input_t2): pipeline = TosaPipelineBI[input_t2]( - And(), test_data(), And().aten_op, And().exir_op + And(), + test_data(), + And().aten_op, + And().exir_op, + atol=0, + rtol=0, + qtol=0, ) pipeline.pop_stage("quantize") pipeline.pop_stage("check.quant_nodes") @@ -113,11 +124,17 @@ def test_logical_and_u55_BI_not_delegated(test_data: input_t2): @common.parametrize("test_data", And().test_data) -@pytest.mark.xfail(reason="MLETORCH-706: Support ScalarType::Bool in EthosUBackend.") @common.XfailIfNoCorstone320 def test_logical_and_u85_BI(test_data: input_t2): pipeline = EthosU85PipelineBI[input_t2]( - And(), test_data(), And().aten_op, And().exir_op, run_on_fvp=True + And(), + test_data(), + And().aten_op, + And().exir_op, + run_on_fvp=True, + atol=0, + rtol=0, + qtol=0, ) pipeline.pop_stage("quantize") pipeline.pop_stage("check.quant_nodes") @@ -127,7 +144,13 @@ def test_logical_and_u85_BI(test_data: input_t2): @common.parametrize("test_data", Xor().test_data) def test_logical_xor_tosa_MI(test_data: input_t2): pipeline = TosaPipelineMI[input_t2]( - Xor(), test_data(), Xor().aten_op, Xor().exir_op + Xor(), + test_data(), + Xor().aten_op, + Xor().exir_op, + atol=0, + rtol=0, + qtol=0, ) pipeline.run() @@ -135,7 +158,13 @@ def test_logical_xor_tosa_MI(test_data: input_t2): @common.parametrize("test_data", Xor().test_data) def test_logical_xor_tosa_BI(test_data: input_t2): pipeline = TosaPipelineBI[input_t2]( - Xor(), test_data(), Xor().aten_op, Xor().exir_op + Xor(), + test_data(), + Xor().aten_op, + Xor().exir_op, + atol=0, + rtol=0, + qtol=0, ) pipeline.pop_stage("quantize") pipeline.pop_stage("check.quant_nodes") @@ -156,11 +185,17 @@ def test_logical_xor_u55_BI_not_delegated(test_data: input_t2): @common.parametrize("test_data", Xor().test_data) -@pytest.mark.xfail(reason="MLETORCH-706: Support ScalarType::Bool in EthosUBackend.") @common.XfailIfNoCorstone320 def test_logical_xor_u85_BI(test_data: input_t2): pipeline = EthosU85PipelineBI[input_t2]( - Xor(), test_data(), Xor().aten_op, Xor().exir_op, run_on_fvp=True + Xor(), + test_data(), + Xor().aten_op, + Xor().exir_op, + run_on_fvp=True, + atol=0, + rtol=0, + qtol=0, ) pipeline.pop_stage("quantize") pipeline.pop_stage("check.quant_nodes") @@ -169,13 +204,29 @@ def test_logical_xor_u85_BI(test_data: input_t2): @common.parametrize("test_data", Or().test_data) def test_logical_or_tosa_MI(test_data: input_t2): - pipeline = TosaPipelineMI[input_t2](Or(), test_data(), Or().aten_op, Or().exir_op) + pipeline = TosaPipelineMI[input_t2]( + Or(), + test_data(), + Or().aten_op, + 
Or().exir_op, + atol=0, + rtol=0, + qtol=0, + ) pipeline.run() @common.parametrize("test_data", Or().test_data) def test_logical_or_tosa_BI(test_data: input_t2): - pipeline = TosaPipelineBI[input_t2](Or(), test_data(), Or().aten_op, Or().exir_op) + pipeline = TosaPipelineBI[input_t2]( + Or(), + test_data(), + Or().aten_op, + Or().exir_op, + atol=0, + rtol=0, + qtol=0, + ) pipeline.pop_stage("quantize") pipeline.pop_stage("check.quant_nodes") pipeline.run() @@ -195,11 +246,17 @@ def test_logical_or_u55_BI_not_delegated(test_data: input_t2): @common.parametrize("test_data", Or().test_data) -@pytest.mark.xfail(reason="MLETORCH-706: Support ScalarType::Bool in EthosUBackend.") @common.XfailIfNoCorstone320 def test_logical_or_u85_BI(test_data: input_t2): pipeline = EthosU85PipelineBI[input_t2]( - Or(), test_data(), Or().aten_op, Or().exir_op, run_on_fvp=True + Or(), + test_data(), + Or().aten_op, + Or().exir_op, + run_on_fvp=True, + atol=0, + rtol=0, + qtol=0, ) pipeline.pop_stage("quantize") pipeline.pop_stage("check.quant_nodes") @@ -209,7 +266,13 @@ def test_logical_or_u85_BI(test_data: input_t2): @common.parametrize("test_data", Not().test_data) def test_logical_not_tosa_MI(test_data: input_t2): pipeline = TosaPipelineMI[input_t2]( - Not(), test_data(), Not().aten_op, Not().exir_op + Not(), + test_data(), + Not().aten_op, + Not().exir_op, + atol=0, + rtol=0, + qtol=0, ) pipeline.run() @@ -217,7 +280,13 @@ def test_logical_not_tosa_MI(test_data: input_t2): @common.parametrize("test_data", Not().test_data) def test_logical_not_tosa_BI(test_data: input_t2): pipeline = TosaPipelineBI[input_t2]( - Not(), test_data(), Not().aten_op, Not().exir_op + Not(), + test_data(), + Not().aten_op, + Not().exir_op, + atol=0, + rtol=0, + qtol=0, ) pipeline.pop_stage("quantize") pipeline.pop_stage("check.quant_nodes") @@ -238,11 +307,17 @@ def test_logical_not_u55_BI_not_delegated(test_data: input_t2): @common.parametrize("test_data", Not().test_data) -@pytest.mark.xfail(reason="MLETORCH-706: Support ScalarType::Bool in EthosUBackend.") @common.XfailIfNoCorstone320 def test_logical_not_u85_BI(test_data: input_t2): pipeline = EthosU85PipelineBI[input_t2]( - Not(), test_data(), Not().aten_op, Not().exir_op, run_on_fvp=True + Not(), + test_data(), + Not().aten_op, + Not().exir_op, + run_on_fvp=True, + atol=0, + rtol=0, + qtol=0, ) pipeline.pop_stage("quantize") pipeline.pop_stage("check.quant_nodes") diff --git a/backends/arm/test/tester/analyze_output_utils.py b/backends/arm/test/tester/analyze_output_utils.py index 140a9bcc418..96060b7b563 100644 --- a/backends/arm/test/tester/analyze_output_utils.py +++ b/backends/arm/test/tester/analyze_output_utils.py @@ -22,22 +22,36 @@ def _print_channels(result, reference, channels_close, C, H, W, rtol, atol): output_str = "" + booldata = False + if reference.dtype == torch.bool or result.dtype == torch.bool: + booldata = True + for c in range(C): if channels_close[c]: continue - - max_diff = torch.max(torch.abs(reference - result)) - exp = f"{max_diff:2e}"[-3:] - output_str += f"channel {c} (e{exp})\n" + if not booldata: + max_diff = torch.max(torch.abs(reference - result)) + exp = f"{max_diff:2e}"[-3:] + output_str += f"channel {c} (e{exp})\n" + else: + max_diff = torch.max(reference ^ result) + output_str += f"channel {c} (bool)\n" for y in range(H): res = "[" for x in range(W): if torch.allclose(reference[c, y, x], result[c, y, x], rtol, atol): - res += " . " + if not booldata: + res += " . " + else: + res += " . 
" else: - diff = (reference[c, y, x] - result[c, y, x]) / 10 ** (int(exp)) - res += f"{diff: .2f} " + if not booldata: + diff = (reference[c, y, x] - result[c, y, x]) / 10 ** (int(exp)) + res += f"{diff: .2f} " + else: + diff = reference[c, y, x] ^ result[c, y, x] + res += " X " # Break early for large widths if x == 16: @@ -157,12 +171,6 @@ def print_error_diffs( result_batch = result[n, :, :, :] reference_batch = reference[n, :, :, :] - if reference_batch.dtype == torch.bool or result_batch.dtype == torch.bool: - mismatches = (reference_batch != result_batch).sum().item() - total = reference_batch.numel() - output_str += f"(BOOLEAN tensor) {mismatches} / {total} elements differ ({mismatches / total:.2%})\n" - continue - is_close = torch.allclose(result_batch, reference_batch, rtol, atol) if is_close: output_str += ".\n" @@ -189,6 +197,11 @@ def print_error_diffs( output_str += _print_elements( result[n, :, :, :], reference[n, :, :, :], C, H, W, rtol, atol ) + if reference_batch.dtype == torch.bool or result_batch.dtype == torch.bool: + mismatches = (reference_batch != result_batch).sum().item() + total = reference_batch.numel() + output_str += f"(BOOLEAN tensor) {mismatches} / {total} elements differ ({mismatches / total:.2%})\n" + # Only compute numeric error metrics if tensor is not boolean if reference.dtype != torch.bool and result.dtype != torch.bool: reference_range = torch.max(reference) - torch.min(reference) diff --git a/backends/arm/test/tester/arm_tester.py b/backends/arm/test/tester/arm_tester.py index 56a89d9b589..04034521f9b 100644 --- a/backends/arm/test/tester/arm_tester.py +++ b/backends/arm/test/tester/arm_tester.py @@ -496,7 +496,6 @@ def run_method_and_compare_outputs( reference_outputs, _ = pytree.tree_flatten( reference_stage.run_artifact(reference_input) ) - if run_eager_mode: # Run exported module directly test_outputs, _ = pytree.tree_flatten( @@ -510,6 +509,10 @@ def run_method_and_compare_outputs( test_stage.run_artifact(reference_input) ) + logger.info(f"\n Input: {reference_input}") + logger.info(f"\n Ref output: {reference_outputs}") + logger.info(f"\nTest output: {test_outputs}") + for reference_output, test_output, quantization_scale in zip( reference_outputs, test_outputs, quantization_scales ): diff --git a/examples/arm/executor_runner/arm_executor_runner.cpp b/examples/arm/executor_runner/arm_executor_runner.cpp index 9046df25a47..5944a1f081c 100644 --- a/examples/arm/executor_runner/arm_executor_runner.cpp +++ b/examples/arm/executor_runner/arm_executor_runner.cpp @@ -649,29 +649,41 @@ int main(int argc, const char* argv[]) { ET_CHECK(status == Error::Ok); for (int i = 0; i < inputs.size(); ++i) { - Tensor t = inputs[i].toTensor(); - // The output might be collected and parsed so printf() is used instead - // of ET_LOG() here - for (int j = 0; j < inputs[i].toTensor().numel(); ++j) { - if (t.scalar_type() == ScalarType::Int) { - printf( - "Input[%d][%d]: (int) %d\n", - i, - j, - inputs[i].toTensor().const_data_ptr()[j]); - } else if (t.scalar_type() == ScalarType::Float) { - printf( - "Input[%d][%d]: (float) %f\n", - i, - j, - inputs[i].toTensor().const_data_ptr()[j]); - } else if (t.scalar_type() == ScalarType::Char) { - printf( - "Input[%d][%d]: (char) %d\n", - i, - j, - inputs[i].toTensor().const_data_ptr()[j]); + if (inputs[i].isTensor()) { + Tensor t = inputs[i].toTensor(); + // The output might be collected and parsed so printf() is used instead + // of ET_LOG() here + for (int j = 0; j < inputs[i].toTensor().numel(); ++j) { + if (t.scalar_type() 
== ScalarType::Int) {
+          printf(
+              "Input[%d][%d]: (int) %d\n",
+              i,
+              j,
+              inputs[i].toTensor().const_data_ptr<int>()[j]);
+        } else if (t.scalar_type() == ScalarType::Float) {
+          printf(
+              "Input[%d][%d]: (float) %f\n",
+              i,
+              j,
+              inputs[i].toTensor().const_data_ptr<float>()[j]);
+        } else if (t.scalar_type() == ScalarType::Char) {
+          printf(
+              "Input[%d][%d]: (char) %d\n",
+              i,
+              j,
+              inputs[i].toTensor().const_data_ptr<int8_t>()[j]);
+        } else if (t.scalar_type() == ScalarType::Bool) {
+          printf(
+              "Input[%d][%d]: (bool) %s (0x%x)\n",
+              i,
+              j,
+              inputs[i].toTensor().const_data_ptr<bool>()[j] ? "true"
+                                                             : "false",
+              inputs[i].toTensor().const_data_ptr<bool>()[j]);
+        }
       }
+    } else {
+      printf("Input[%d]: Not Tensor\n", i);
     }
   }
 }
@@ -766,6 +778,14 @@ int main(int argc, const char* argv[]) {
             i,
             j,
             outputs[i].toTensor().const_data_ptr<int8_t>()[j]);
+      } else if (t.scalar_type() == ScalarType::Bool) {
+        printf(
+            "Output[%d][%d]: (bool) %s (0x%x)\n",
+            i,
+            j,
+            outputs[i].toTensor().const_data_ptr<bool>()[j] ? "true "
+                                                            : "false",
+            outputs[i].toTensor().const_data_ptr<bool>()[j]);
       }
     }
 #endif
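
Note (illustrative, not part of the patch): in plain PyTorch terms, CastBoolToInt8Pass rewrites a bool bitwise op into cast-to-int8, bitwise op, cast-back-to-bool, since TOSA's BITWISE_AND/OR/XOR do not accept BOOL operands. A minimal sketch of that rewrite, assuming only that bool tensors cast to 0/1 int8 values; the helper name is hypothetical:

    import torch

    def bitwise_and_via_int8(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
        # Mirror of the pass: cast bool operands to int8, run the bitwise op,
        # then cast the result back to the original dtype.
        orig_dtype = a.dtype
        if orig_dtype == torch.bool:
            a = a.to(torch.int8)
            b = b.to(torch.int8)
        out = torch.bitwise_and(a, b)
        return out.to(orig_dtype) if orig_dtype == torch.bool else out

    a = torch.tensor([True, False, True])
    b = torch.tensor([True, True, False])
    assert torch.equal(bitwise_and_via_int8(a, b), a & b)

For 0/1-valued int8 operands the bitwise result round-trips to bool without loss, which is why the int8 detour is safe for logical data.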