Arm backend: Support ScalarType::Bool in EthosUBackend #11850

Merged
1 change: 1 addition & 0 deletions backends/arm/_passes/__init__.py
@@ -9,6 +9,7 @@
from .annotate_decomposed_matmul import AnnotateDecomposedMatmulPass # noqa
from .arm_pass import ArmPass # noqa
from .broadcast_args_pass import BroadcastArgsPass # noqa
from .cast_bool_to_int8_pass import CastBoolToInt8Pass # noqa
from .cast_int64_pass import CastInt64BuffersToInt32Pass # noqa
from .cast_to_int32_pass import CastToInt32Pass # noqa
from .conv1d_unsqueeze_pass import Conv1dUnsqueezePass # noqa
4 changes: 4 additions & 0 deletions backends/arm/_passes/arm_pass_manager.py
@@ -10,6 +10,7 @@
AnnotateChannelsLastDimOrder,
AnnotateDecomposedMatmulPass,
BroadcastArgsPass,
CastBoolToInt8Pass,
CastInt64BuffersToInt32Pass,
CastToInt32Pass,
ComputeConstantOpsAOT,
@@ -107,6 +108,7 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
if self.tosa_spec.is_U55_subset:
self.add_pass(CastToInt32Pass())

self.add_pass(CastBoolToInt8Pass())
self.add_pass(ReplaceScalarWithTensorArgPassTOSABI())
self.add_pass(AnnotateDecomposedMatmulPass())
self.add_pass(QuantizeOperatorArguments())
@@ -146,6 +148,7 @@ def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
self.add_pass(DecomposeRoundPass())
self.add_pass(DecomposeSqrtPass())
self.add_pass(ConvertIntPowToMuls())
self.add_pass(CastBoolToInt8Pass())
self.add_pass(ReplaceScalarWithTensorArgPassTOSAMI())
self.add_pass(DecomposeEmbeddingPass())
self.add_pass(FuseQuantizedActivationPass())
@@ -227,6 +230,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
self.add_pass(DecomposeEmbeddingPass())
self.add_pass(DecomposeScaledDotProductAttention())
self.add_pass(DecomposeRoundPass())
self.add_pass(CastBoolToInt8Pass())
self.add_pass(ReplaceScalarWithTensorArgPassTOSABI())
self.add_pass(ScalarsToAttributePass())
self.add_pass(DecomposeGroupNormPass())
58 changes: 58 additions & 0 deletions backends/arm/_passes/cast_bool_to_int8_pass.py
@@ -0,0 +1,58 @@
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# The TOSA BITWISE_AND, BITWISE_OR, and BITWISE_XOR operators do not accept bool inputs.
# If an input/output is bool, insert a cast to int8 before the op and a cast back to bool afterwards.

import torch

from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass


class CastBoolToInt8Pass(ExportPass):
    """Casts the input to int8 if it is not already and casts back the output to the original input dtype."""

    targeted_ops = {
        exir_ops.edge.aten.bitwise_and.Tensor,
        exir_ops.edge.aten.bitwise_or.Tensor,
        exir_ops.edge.aten.bitwise_xor.Tensor,
    }

    def call_operator(self, op, args, kwargs, meta):
        if op not in self.targeted_ops:
            return super().call_operator(op, args, kwargs, meta)

        new_args: list = []
        did_cast = False
        for arg in args:
            if arg.data.dtype == torch.bool:
                new_args.append(
                    super().call_operator(
                        exir_ops.edge.dim_order_ops._to_dim_order_copy.default,
                        (arg,),
                        {"dtype": torch.int8},
                        meta,
                    )
                )
                did_cast = True
            else:
                new_args.append(arg)

        output = super().call_operator(
            op,
            tuple(new_args),
            {},
            meta,
        )

        if did_cast:
            output = super().call_operator(
                exir_ops.edge.dim_order_ops._to_dim_order_copy.default,
                (output,),
                {"dtype": args[0].data.dtype},
                meta,
            )
        return output
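For intuition, the rewrite this pass performs is equivalent to the following standalone sketch in plain PyTorch (the helper name bitwise_and_via_int8 is illustrative only and not part of the pass API):

import torch


def bitwise_and_via_int8(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # Mirror of CastBoolToInt8Pass for a single bitwise_and: cast bool
    # operands to int8, run the op, then cast the result back to the
    # dtype of the first input.
    result_dtype = a.dtype
    if a.dtype == torch.bool:
        a = a.to(torch.int8)
    if b.dtype == torch.bool:
        b = b.to(torch.int8)
    out = torch.bitwise_and(a, b)
    return out.to(result_dtype) if result_dtype == torch.bool else out


x = torch.tensor([True, False, True])
y = torch.tensor([True, True, False])
print(bitwise_and_via_int8(x, y))  # tensor([ True, False, False])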
49 changes: 49 additions & 0 deletions backends/arm/operators/ops_binary.py
@@ -17,6 +17,7 @@
from executorch.backends.arm.operators.operator_validation_utils import (
validate_num_inputs,
validate_same_dtype,
validate_valid_dtype,
)
from executorch.backends.arm.tosa_mapping import TosaArg

@@ -40,6 +41,30 @@ def define_node(
validate_num_inputs(self.target, inputs, 2)
validate_same_dtype(self.target, [*inputs, output], ts)

if self.target in [
"aten.bitwise_and.Tensor",
"aten.bitwise_xor.Tensor",
"aten.bitwise_or.Tensor",
"aten.bitwise_left_shift.Tensor",
]:
validate_valid_dtype(
self.target,
[*inputs, output],
[ts.DType.INT8, ts.DType.INT16, ts.DType.INT32],
output.tosa_spec,
)
if self.target in [
"aten.logical_and.default",
"aten.logical_xor.default",
"aten.logical_or.default",
]:
validate_valid_dtype(
self.target,
[*inputs, output],
[ts.DType.BOOL],
output.tosa_spec,
)

tosa_graph.addOperator(
tosa_op, [inputs[0].name, inputs[1].name], [output.name]
)
@@ -66,6 +91,30 @@ def define_node(
validate_num_inputs(self.target, inputs, 2)
validate_same_dtype(self.target, [*inputs, output], ts)

if self.target in [
"aten.bitwise_and.Tensor",
"aten.bitwise_xor.Tensor",
"aten.bitwise_or.Tensor",
"aten.bitwise_left_shift.Tensor",
]:
validate_valid_dtype(
self.target,
[*inputs, output],
[ts.DType.INT8, ts.DType.INT16, ts.DType.INT32],
output.tosa_spec,
)
if self.target in [
"aten.logical_and.default",
"aten.logical_xor.default",
"aten.logical_or.default",
]:
validate_valid_dtype(
self.target,
[*inputs, output],
[ts.DType.BOOL],
output.tosa_spec,
)

tosa_graph.addOperator(
tosa_op, [inputs[0].name, inputs[1].name], [output.name]
)
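For reference, the dtype check added above is expected to behave roughly like this minimal standalone sketch (hypothetical signature and error text, not the actual helper in operator_validation_utils):

def validate_valid_dtype(target, tosa_args, valid_dtypes, tosa_spec):
    # Reject any argument whose dtype is outside the allowed set for this
    # operator/TOSA-spec combination, e.g. BOOL for the bitwise operators.
    for arg in tosa_args:
        if arg.dtype not in valid_dtypes:
            raise ValueError(
                f"{target}: dtype {arg.dtype} is not supported "
                f"(expected one of {valid_dtypes}) for TOSA spec {tosa_spec}"
            )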
44 changes: 24 additions & 20 deletions backends/arm/runtime/EthosUBackend.cpp
@@ -234,12 +234,17 @@ class EthosUBackend final : public ::executorch::runtime::BackendInterface {
supported |=
(tensor_in.scalar_type() == ScalarType::Short and
handles.inputs->io[i].elem_size == 2);
+ // bool (IOQDQ pass prepared networks)
+ supported |=
+ (tensor_in.scalar_type() == ScalarType::Bool and
+ handles.inputs->io[i].elem_size == 1);
if (!supported) {
ET_LOG(
Error,
- "Input %d expected Integer (4 byte) or Char (1 byte) integer inputs, got ScalarType id %s",
+ "Input %d expected Integer (4 byte), Char (1 byte) or Bool (1 byte) integer inputs, got ScalarType id %s size %d",
i,
- executorch::runtime::toString(tensor_in.scalar_type()));
+ executorch::runtime::toString(tensor_in.scalar_type()),
+ handles.inputs->io[i].elem_size);
return Error::InvalidProgram;
}
supported = executorch::runtime::is_contiguous_dim_order(
@@ -257,15 +262,17 @@ class EthosUBackend final : public ::executorch::runtime::BackendInterface {
bool permuted_input_shape;
ET_CHECK_OK_OR_RETURN_ERROR(check_requires_permute(
i, tensor_in, &handles.inputs->io[i], &permuted_input_shape));
- bool both_char = tensor_in.scalar_type() == ScalarType::Char and
- handles.inputs->io[i].elem_size == 1;
- bool both_int = tensor_in.scalar_type() == ScalarType::Int and
+ bool both_int = tensor_in.scalar_type() == ScalarType::Int &&
handles.inputs->io[i].elem_size == 4;
- bool both_short = tensor_in.scalar_type() == ScalarType::Short and
+ bool both_char = tensor_in.scalar_type() == ScalarType::Char &&
+ handles.inputs->io[i].elem_size == 1;
+ bool both_short = tensor_in.scalar_type() == ScalarType::Short &&
handles.inputs->io[i].elem_size == 2;
+ bool both_bool = tensor_in.scalar_type() == ScalarType::Bool &&
+ (handles.inputs->io[i].elem_size == 1);

// Select a compatible copy routine
- if (both_char && permuted_input_shape) {
+ if ((both_char || both_bool) && permuted_input_shape) {
EXECUTORCH_PROF_SCOPE(
event_tracer,
"+EthosUBackend::execute()handles.input.permute_CHW_to_HWC()");
Expand All @@ -276,7 +283,7 @@ class EthosUBackend final : public ::executorch::runtime::BackendInterface {
tensor_in.size(1),
tensor_in.size(2),
tensor_in.size(3));
- } else if (both_char || both_int || both_short) {
+ } else if (both_char || both_int || both_short || both_bool) {
EXECUTORCH_PROF_SCOPE(
event_tracer, "+EthosUBackend::execute()handles.input.memcpy()");
// Sizes match and elt size matches so memcpy
@@ -363,7 +370,9 @@ class EthosUBackend final : public ::executorch::runtime::BackendInterface {
bool permuted_output_shape;
ET_CHECK_OK_OR_RETURN_ERROR(check_requires_permute(
i, tensor_out, &handles.outputs->io[i], &permuted_output_shape));
- if (tensor_out.scalar_type() == ScalarType::Char &&

+ if ((tensor_out.scalar_type() == ScalarType::Char ||
+ tensor_out.scalar_type() == ScalarType::Bool) &&
permuted_output_shape) {
EXECUTORCH_PROF_SCOPE(
event_tracer,
Expand All @@ -379,17 +388,12 @@ class EthosUBackend final : public ::executorch::runtime::BackendInterface {
tensor_out.size(3));
} else {
EXECUTORCH_PROF_SCOPE(
- event_tracer, "+EthosUBackend::execute()handles.output.move()");
- for (int j = 0; j < tensor_out.numel(); j++) {
- if (tensor_out.scalar_type() == ScalarType::Char) {
- const char* output_address = static_cast<const char*>(output_addr);
- tensor_out.mutable_data_ptr<char>()[j] = output_address[j];
- } else {
- const int* output_address =
- reinterpret_cast<const int*>(output_addr);
- tensor_out.mutable_data_ptr<int>()[j] = output_address[j];
- }
- }
+ event_tracer, "+EthosUBackend::execute()handles.output.memcpy()");
+
+ memcpy(
+ tensor_out.mutable_data_ptr<char>(),
+ static_cast<const char*>(output_addr),
+ tensor_out.nbytes());
}
}
if (tensor_dim != io_dim) {
25 changes: 21 additions & 4 deletions backends/arm/test/ops/test_any.py
@@ -6,7 +6,6 @@

from typing import List, Tuple

- import pytest
import torch
from executorch.backends.arm.test import common
from executorch.backends.arm.test.tester.test_pipeline import (
@@ -125,14 +124,30 @@ def forward(self, x: torch.Tensor):
@common.parametrize("test_data", test_data)
def test_any_tosa_MI(test_data: input_t1):
op, test_input = test_data()
- pipeline = TosaPipelineMI[input_t1](op, test_input(), op.aten_op, op.exir_op)
+ pipeline = TosaPipelineMI[input_t1](
+ op,
+ test_input(),
+ op.aten_op,
+ op.exir_op,
+ atol=0,
+ rtol=0,
+ qtol=0,
+ )
pipeline.run()


@common.parametrize("test_data", test_data)
def test_any_tosa_BI(test_data: input_t1):
op, test_input = test_data()
- pipeline = TosaPipelineBI[input_t1](op, test_input(), op.aten_op, op.exir_op)
+ pipeline = TosaPipelineBI[input_t1](
+ op,
+ test_input(),
+ op.aten_op,
+ op.exir_op,
+ atol=0,
+ rtol=0,
+ qtol=0,
+ )
pipeline.pop_stage("quantize")
pipeline.pop_stage("check.quant_nodes")
pipeline.run()
@@ -153,7 +168,6 @@ def test_any_u55_BI(test_data: input_t1):


@common.parametrize("test_data", test_data)
- @pytest.mark.xfail(reason="MLETORCH-706: Support ScalarType::Bool in EthosUBackend.")
@common.XfailIfNoCorstone320
def test_any_u85_BI(test_data: input_t1):
op, test_input = test_data()
@@ -163,6 +177,9 @@ def test_any_u85_BI(test_data: input_t1):
op.aten_op,
op.exir_op,
run_on_fvp=True,
+ atol=0,
+ rtol=0,
+ qtol=0,
)
pipeline.pop_stage("quantize")
pipeline.pop_stage("check.quant_nodes")