Skip to content

Commit c63dd02

Browse files
authored
[WebNN EP] Use opSupportLimits to dynamically check data type support (microsoft#22025)
- Remove hard-coded data type checks and use WebNN's opSupportLimits instead - Add HasSupportedOutputsImpl for output data type validation - Get preferred layout info from opSupportLimits - Move Not op to logical_op_builder.cc because it should be there. This avoids the inconsistent input names in `unary_op_builder.cc`.
1 parent a89bddd commit c63dd02

32 files changed

+281
-635
lines changed

onnxruntime/core/providers/webnn/builders/helper.cc

Lines changed: 53 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -45,12 +45,12 @@ bool GetShape(const NodeArg& node_arg, std::vector<int64_t>& shape, const loggin
4545
return true;
4646
}
4747

48-
bool IsNodeSupported(const Node& node, const GraphViewer& graph_viewer,
49-
const WebnnDeviceType device_type, const logging::Logger& logger) {
48+
bool IsNodeSupported(const Node& node, const GraphViewer& graph_viewer, const WebnnDeviceType device_type,
49+
const emscripten::val& wnn_limits, const logging::Logger& logger) {
5050
const auto& op_builders = GetOpBuilders();
5151
if (Contains(op_builders, node.OpType())) {
5252
const auto* op_builder = op_builders.at(node.OpType());
53-
return op_builder->IsOpSupported(graph_viewer.GetAllInitializedTensors(), node, device_type, logger);
53+
return op_builder->IsOpSupported(graph_viewer.GetAllInitializedTensors(), node, device_type, wnn_limits, logger);
5454
} else {
5555
return false;
5656
}
@@ -86,6 +86,7 @@ bool IsInputSupported(const NodeArg& input, const std::string& parent_name, cons
8686
std::vector<std::vector<NodeIndex>> GetSupportedNodes(const GraphViewer& graph_viewer,
8787
const emscripten::val& wnn_builder,
8888
const WebnnDeviceType device_type,
89+
const emscripten::val& wnn_limits,
8990
const logging::Logger& logger) {
9091
std::vector<std::vector<size_t>> supported_node_groups;
9192

@@ -105,7 +106,7 @@ std::vector<std::vector<NodeIndex>> GetSupportedNodes(const GraphViewer& graph_v
105106
// Firstly check if platform supports the WebNN op.
106107
if (CheckSingleOp(node->OpType(), wnn_builder, device_type)) {
107108
LOGS(logger, VERBOSE) << "Operator type: [" << node->OpType() << "] is supported by browser";
108-
supported = IsNodeSupported(*node, graph_viewer, device_type, logger);
109+
supported = IsNodeSupported(*node, graph_viewer, device_type, wnn_limits, logger);
109110
}
110111

111112
LOGS(logger, VERBOSE) << "Operator type: [" << node->OpType()
@@ -130,10 +131,54 @@ std::vector<std::vector<NodeIndex>> GetSupportedNodes(const GraphViewer& graph_v
130131
return supported_node_groups;
131132
}
132133

133-
bool IsSupportedDataType(const int32_t data_type,
134-
const std::unordered_set<ONNX_NAMESPACE::TensorProto_DataType>& supported_data_types) {
135-
return std::find(supported_data_types.begin(), supported_data_types.end(), data_type) !=
136-
supported_data_types.end();
134+
bool AreInputDataTypesSame(const std::string& op_type,
135+
gsl::span<const int32_t> input_types,
136+
const logging::Logger& logger) {
137+
for (size_t i = 1; i < input_types.size(); i++) {
138+
if (input_types[0] != input_types[i]) {
139+
LOGS(logger, VERBOSE) << "[" << op_type
140+
<< "] Input data types should be the same, but ["
141+
<< input_types[0] << "] does not match ["
142+
<< input_types[i] << "].";
143+
return false;
144+
}
145+
}
146+
return true;
147+
}
148+
149+
bool IsSupportedDataType(const int32_t onnx_data_type, const emscripten::val& webnn_supported_data_types) {
150+
auto it = onnx_to_webnn_data_type_map.find(static_cast<ONNX_NAMESPACE::TensorProto_DataType>(onnx_data_type));
151+
if (it == onnx_to_webnn_data_type_map.end())
152+
return false;
153+
154+
std::string webnn_data_type = it->second;
155+
156+
// Check if WebNN supports the data type.
157+
emscripten::val is_supported = webnn_supported_data_types.call<emscripten::val>("includes",
158+
emscripten::val(webnn_data_type));
159+
return is_supported.as<bool>();
160+
}
161+
162+
// Check if the input or output data type of ONNX node is supported by the WebNN operator.
163+
bool IsDataTypeSupportedByOp(const std::string& onnx_op_type,
164+
const int32_t onnx_data_type,
165+
const emscripten::val& wnn_limits,
166+
const std::string& webnn_input_output_name,
167+
const std::string& onnx_input_output_name,
168+
const logging::Logger& logger) {
169+
std::string webnn_op_type;
170+
if (!GetWebNNOpType(onnx_op_type, webnn_op_type))
171+
return false;
172+
173+
if (!IsSupportedDataType(onnx_data_type, wnn_limits[webnn_op_type][webnn_input_output_name]["dataTypes"])) {
174+
LOGS(logger, VERBOSE) << "[" << onnx_op_type
175+
<< "] " << onnx_input_output_name
176+
<< " type: [" << onnx_data_type
177+
<< "] is not supported for now";
178+
return false;
179+
}
180+
181+
return true;
137182
}
138183

139184
bool GetBidirectionalBroadcastShape(std::vector<int64_t>& shape_a,

onnxruntime/core/providers/webnn/builders/helper.h

Lines changed: 31 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,7 @@ bool IsInputSupported(const NodeArg& node_arg, const std::string& parent_name, c
148148
std::vector<std::vector<NodeIndex>> GetSupportedNodes(const GraphViewer& graph_viewer,
149149
const emscripten::val& wnn_builder,
150150
const WebnnDeviceType device_type,
151+
const emscripten::val& wnn_limits,
151152
const logging::Logger& logger);
152153
static const InlinedHashMap<std::string, std::string> op_map = {
153154
{"Abs", "abs"},
@@ -250,20 +251,38 @@ inline bool CheckSingleOp(const std::string& op_type, const emscripten::val& wnn
250251
return true;
251252
}
252253

253-
static const std::unordered_set<ONNX_NAMESPACE::TensorProto_DataType> webnn_supported_data_types = {
254-
ONNX_NAMESPACE::TensorProto_DataType_BOOL,
255-
ONNX_NAMESPACE::TensorProto_DataType_INT8,
256-
ONNX_NAMESPACE::TensorProto_DataType_UINT8,
257-
ONNX_NAMESPACE::TensorProto_DataType_FLOAT16,
258-
ONNX_NAMESPACE::TensorProto_DataType_FLOAT,
259-
ONNX_NAMESPACE::TensorProto_DataType_INT32,
260-
ONNX_NAMESPACE::TensorProto_DataType_INT64,
261-
ONNX_NAMESPACE::TensorProto_DataType_UINT32,
262-
ONNX_NAMESPACE::TensorProto_DataType_UINT64,
254+
inline bool GetWebNNOpType(const std::string& op_type, std::string& webnn_op_type) {
255+
auto it = op_map.find(op_type);
256+
// Returns false if the op_type is not listed in the op_map.
257+
if (it == op_map.end()) {
258+
return false;
259+
}
260+
webnn_op_type = it->second;
261+
return true;
262+
}
263+
264+
static const InlinedHashMap<ONNX_NAMESPACE::TensorProto_DataType, std::string> onnx_to_webnn_data_type_map = {
265+
{ONNX_NAMESPACE::TensorProto_DataType_BOOL, "uint8"},
266+
{ONNX_NAMESPACE::TensorProto_DataType_INT8, "int8"},
267+
{ONNX_NAMESPACE::TensorProto_DataType_UINT8, "uint8"},
268+
{ONNX_NAMESPACE::TensorProto_DataType_FLOAT16, "float16"},
269+
{ONNX_NAMESPACE::TensorProto_DataType_FLOAT, "float32"},
270+
{ONNX_NAMESPACE::TensorProto_DataType_INT32, "int32"},
271+
{ONNX_NAMESPACE::TensorProto_DataType_INT64, "int64"},
272+
{ONNX_NAMESPACE::TensorProto_DataType_UINT32, "uint32"},
273+
{ONNX_NAMESPACE::TensorProto_DataType_UINT64, "uint64"},
263274
};
264275

265-
bool IsSupportedDataType(const int32_t data_type,
266-
const std::unordered_set<ONNX_NAMESPACE::TensorProto_DataType>& supported_data_types);
276+
bool AreInputDataTypesSame(const std::string& op_type,
277+
gsl::span<const int32_t> input_types,
278+
const logging::Logger& logger);
279+
bool IsSupportedDataType(const int32_t onnx_data_type, const emscripten::val& webnn_supported_data_types);
280+
bool IsDataTypeSupportedByOp(const std::string& onnx_op_type,
281+
const int32_t onnx_data_type,
282+
const emscripten::val& wnn_limits,
283+
const std::string& webnn_input_output_name,
284+
const std::string& onnx_input_output_name,
285+
const logging::Logger& logger);
267286

268287
bool GetBidirectionalBroadcastShape(std::vector<int64_t>& shape_a,
269288
std::vector<int64_t>& shape_b,

onnxruntime/core/providers/webnn/builders/impl/activation_op_builder.cc

Lines changed: 0 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,6 @@ class ActivationOpBuilder : public BaseOpBuilder {
2121
// Operator support related.
2222
bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node,
2323
WebnnDeviceType device_type, const logging::Logger& logger) const override;
24-
bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type,
25-
const logging::Logger& logger) const override;
2624
};
2725

2826
// Add operator related.
@@ -94,44 +92,6 @@ bool ActivationOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initi
9492
return true;
9593
}
9694

97-
bool ActivationOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type,
98-
const logging::Logger& logger) const {
99-
const auto& input = *node.InputDefs()[0];
100-
const auto& op_type = node.OpType();
101-
int32_t input_type;
102-
if (!GetType(input, input_type, logger))
103-
return false;
104-
105-
std::unordered_set<ONNX_NAMESPACE::TensorProto_DataType> supported_data_types;
106-
// WebNN relu op supports float32, float16, int32, int8 input data types.
107-
if (op_type == "Relu") {
108-
supported_data_types = {
109-
ONNX_NAMESPACE::TensorProto_DataType_FLOAT,
110-
ONNX_NAMESPACE::TensorProto_DataType_FLOAT16,
111-
ONNX_NAMESPACE::TensorProto_DataType_INT32,
112-
ONNX_NAMESPACE::TensorProto_DataType_INT8,
113-
};
114-
// WebNN CPU backend does not support int32 data type for relu.
115-
if (device_type == WebnnDeviceType::CPU) {
116-
supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_INT32);
117-
}
118-
} else { // Others only support float32 and float16.
119-
supported_data_types = {
120-
ONNX_NAMESPACE::TensorProto_DataType_FLOAT,
121-
ONNX_NAMESPACE::TensorProto_DataType_FLOAT16,
122-
};
123-
}
124-
125-
if (!IsSupportedDataType(input_type, supported_data_types)) {
126-
LOGS(logger, VERBOSE) << "[" << op_type
127-
<< "] Input type: [" << input_type
128-
<< "] is not supported for now";
129-
return false;
130-
}
131-
132-
return true;
133-
}
134-
13595
void CreateActivationOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {
13696
if (op_registrations.op_builder_map.count(op_type) > 0)
13797
return;

onnxruntime/core/providers/webnn/builders/impl/argmax_min_op_builder.cc

Lines changed: 0 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,6 @@ class ArgMaxMinOpBuilder : public BaseOpBuilder {
2222
// Operator support related.
2323
bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node,
2424
WebnnDeviceType device_type, const logging::Logger& logger) const override;
25-
bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type,
26-
const logging::Logger& logger) const override;
2725
};
2826

2927
// Add operator related.
@@ -77,31 +75,6 @@ bool ArgMaxMinOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initia
7775
return true;
7876
}
7977

80-
bool ArgMaxMinOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type,
81-
const logging::Logger& logger) const {
82-
const auto& input = *node.InputDefs()[0];
83-
const auto& op_type = node.OpType();
84-
int32_t input_type;
85-
if (!GetType(input, input_type, logger))
86-
return false;
87-
88-
std::unordered_set<ONNX_NAMESPACE::TensorProto_DataType> supported_data_types = webnn_supported_data_types;
89-
// WebNN CPU backend doesn't support int64, uint64 input data types for argMax and argMin.
90-
if (device_type == WebnnDeviceType::CPU) {
91-
supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_INT64);
92-
supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_UINT64);
93-
}
94-
95-
if (!IsSupportedDataType(input_type, supported_data_types)) {
96-
LOGS(logger, VERBOSE) << "[" << op_type
97-
<< "] Input type: [" << input_type
98-
<< "] is not supported for now";
99-
return false;
100-
}
101-
102-
return true;
103-
}
104-
10578
void CreateArgMaxMinOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {
10679
if (op_registrations.op_builder_map.count(op_type) > 0)
10780
return;

onnxruntime/core/providers/webnn/builders/impl/base_op_builder.cc

Lines changed: 25 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -38,9 +38,9 @@ bool HasExternalInitializer(const InitializedTensorSet& initializers, const Node
3838
Status BaseOpBuilder::AddToModelBuilder(ModelBuilder& model_builder, const Node& node,
3939
const logging::Logger& logger) const {
4040
ORT_RETURN_IF_NOT(
41-
IsOpSupported(model_builder.GetInitializerTensors(), node, model_builder.GetWebnnDeviceType(), logger),
42-
"Unsupported operator ",
43-
node.OpType());
41+
IsOpSupported(model_builder.GetInitializerTensors(), node, model_builder.GetWebnnDeviceType(),
42+
model_builder.GetOpSupportLimits(), logger),
43+
"Unsupported operator ", node.OpType());
4444
ORT_RETURN_IF_ERROR(AddToModelBuilderImpl(model_builder, node, logger));
4545
LOGS(logger, VERBOSE) << "Operator name: [" << node.Name()
4646
<< "] type: [" << node.OpType() << "] was added";
@@ -50,8 +50,12 @@ Status BaseOpBuilder::AddToModelBuilder(ModelBuilder& model_builder, const Node&
5050
// Operator support related.
5151

5252
bool BaseOpBuilder::IsOpSupported(const InitializedTensorSet& initializers, const Node& node,
53-
const WebnnDeviceType device_type, const logging::Logger& logger) const {
54-
if (!HasSupportedInputs(node, device_type, logger))
53+
const WebnnDeviceType device_type, const emscripten::val& wnn_limits,
54+
const logging::Logger& logger) const {
55+
if (!HasSupportedInputs(node, wnn_limits, logger))
56+
return false;
57+
58+
if (!HasSupportedOutputsImpl(node, wnn_limits, logger))
5559
return false;
5660

5761
// We do not support external initializers for now.
@@ -64,7 +68,7 @@ bool BaseOpBuilder::IsOpSupported(const InitializedTensorSet& initializers, cons
6468
return IsOpSupportedImpl(initializers, node, device_type, logger);
6569
}
6670

67-
bool BaseOpBuilder::HasSupportedInputs(const Node& node, const WebnnDeviceType device_type,
71+
bool BaseOpBuilder::HasSupportedInputs(const Node& node, const emscripten::val& wnn_limits,
6872
const logging::Logger& logger) const {
6973
const auto node_name = MakeString("Node [", node.Name(), "] type [", node.OpType(), "]");
7074
for (const auto* input : node.InputDefs()) {
@@ -73,39 +77,33 @@ bool BaseOpBuilder::HasSupportedInputs(const Node& node, const WebnnDeviceType d
7377
}
7478
}
7579

76-
// WebNN CPU backend (TFLite) will enable float16 input data type soon,
77-
// temporarily fallback float16 input data type for WebNN CPU.
78-
if (device_type == WebnnDeviceType::CPU) {
79-
const auto& input = *node.InputDefs()[0];
80-
81-
int32_t input_type;
82-
if (!GetType(input, input_type, logger))
83-
return false;
84-
if (input_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16)
85-
return false;
86-
}
87-
88-
return HasSupportedInputsImpl(node, device_type, logger);
80+
return HasSupportedInputsImpl(node, wnn_limits, logger);
8981
}
9082

9183
bool BaseOpBuilder::HasSupportedInputsImpl(const Node& node,
92-
const WebnnDeviceType /* device_type */,
84+
const emscripten::val& wnn_limits,
9385
const logging::Logger& logger) const {
9486
// We only check the type of input 0 by default, specific op builder can override this.
9587
const auto& input = *node.InputDefs()[0];
96-
88+
const auto& op_type = node.OpType();
9789
int32_t input_type;
9890
if (!GetType(input, input_type, logger))
9991
return false;
10092

101-
if (!IsSupportedDataType(input_type, webnn_supported_data_types)) {
102-
LOGS(logger, VERBOSE) << "[" << node.OpType()
103-
<< "] Input type: [" << input_type
104-
<< "] is not supported for now";
93+
return IsDataTypeSupportedByOp(op_type, input_type, wnn_limits, "input", "Input", logger);
94+
}
95+
96+
bool BaseOpBuilder::HasSupportedOutputsImpl(const Node& node,
97+
const emscripten::val& wnn_limits,
98+
const logging::Logger& logger) const {
99+
// We only check the type of output 0 by default, specific op builder can override this.
100+
const auto& output = *node.OutputDefs()[0];
101+
const auto& op_type = node.OpType();
102+
int32_t output_type;
103+
if (!GetType(output, output_type, logger))
105104
return false;
106-
}
107105

108-
return true;
106+
return IsDataTypeSupportedByOp(op_type, output_type, wnn_limits, "output", "Output", logger);
109107
}
110108

111109
bool BaseOpBuilder::HasSupportedOpSet(const Node& node,

onnxruntime/core/providers/webnn/builders/impl/base_op_builder.h

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,16 +28,19 @@ class BaseOpBuilder : public IOpBuilder {
2828
// Operator support related.
2929
public:
3030
bool IsOpSupported(const InitializedTensorSet& initializers, const Node& node,
31-
const WebnnDeviceType device_type, const logging::Logger& logger) const override;
31+
const WebnnDeviceType device_type, const emscripten::val& wnn_limits,
32+
const logging::Logger& logger) const override;
3233

3334
protected:
3435
virtual bool IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& /* node */,
3536
const WebnnDeviceType /* device_type */, const logging::Logger& /* logger */) const {
3637
return true;
3738
}
3839

39-
virtual bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type,
40+
virtual bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits,
4041
const logging::Logger& logger) const;
42+
virtual bool HasSupportedOutputsImpl(const Node& node, const emscripten::val& wnn_limits,
43+
const logging::Logger& logger) const;
4144

4245
// ONNX Runtime only *guarantees* support for models stamped
4346
// with opset version 7 or above for opset domain 'ai.onnx'.
@@ -50,7 +53,7 @@ class BaseOpBuilder : public IOpBuilder {
5053

5154
private:
5255
bool HasSupportedOpSet(const Node& node, const logging::Logger& logger) const;
53-
bool HasSupportedInputs(const Node& node, const WebnnDeviceType device_type, const logging::Logger& logger) const;
56+
bool HasSupportedInputs(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const;
5457
};
5558

5659
} // namespace webnn

0 commit comments

Comments
 (0)