Skip to content

Commit f7bf5a1

Browse files
[QNN EP] Ensure QNN EP rejects nodes with I/O of dynamic shape (microsoft#22066)
### Description

Updates QNN EP to properly reject nodes that have inputs or outputs with dynamic shapes.

### Motivation and Context

Currently, QNN EP does not properly offload subgraphs with dynamic shapes to the CPU EP. This PR ensures that QNN EP rejects nodes that consume or generate I/O with dynamic shapes.
1 parent 55ab13e commit f7bf5a1

File tree

4 files changed

+68
-3
lines changed

4 files changed

+68
-3
lines changed

onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.cc

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -308,8 +308,10 @@ bool QnnModelWrapper::GetOnnxShape(const NodeArg& node_arg, std::vector<uint32_t
308308
return true;
309309
}
310310

311-
// We already checked the shape has no dynamic dimension
312311
for (const auto& dim : shape_proto->dim()) {
312+
if (!dim.has_dim_value()) {
313+
return false; // Do not support dynamic shapes.
314+
}
313315
shape.push_back(SafeInt<uint32_t>(dim.dim_value()));
314316
}
315317

onnxruntime/test/providers/qnn/qnn_basic_test.cc

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -948,6 +948,63 @@ TEST_F(QnnHTPBackendTests, Float32ModelWithFP16PrecisionTest) {
948948
0.008f);
949949
}
950950

951+
// Test that QNN EP only handles nodes with static shapes and rejects nodes with dynamic shape I/O. The model is Conv -> Reshape -> Softmax.
952+
TEST_F(QnnHTPBackendTests, EPRejectsDynamicShapesF32) {
953+
// Local function that builds a model in which the last two nodes (Reshape and Softmax) use dynamic shapes.
954+
auto model_build_fn = [](ModelTestBuilder& builder) {
955+
NodeArg* input1 = builder.MakeInput<float>(std::vector<int64_t>{1, 2, 8, 8},
956+
GetFloatDataInRange(0.0f, 1.0f, 128));
957+
NodeArg* input2 = builder.MakeInput<int64_t>(std::vector<int64_t>{3}, std::vector<int64_t>{1, 2, 49});
958+
959+
// Add a Conv with known shapes (static input, initializer weight and bias). QNN EP should support it.
960+
NodeArg* weight = builder.MakeInitializer<float>(std::vector<int64_t>{2, 2, 2, 2},
961+
GetFloatDataInRange(-0.3f, 0.3f, 16));
962+
NodeArg* bias = builder.MakeInitializer<float>(std::vector<int64_t>{2}, {0.0f, 1.0f});
963+
964+
auto* conv_output = builder.MakeIntermediate();
965+
builder.AddNode("Conv", {input1, weight, bias}, {conv_output});
966+
967+
// Add a Reshape to a dynamic shape: the target shape is the graph input `input2`, not an initializer. QNN EP should reject this node.
968+
auto* reshape_output = builder.MakeIntermediate();
969+
builder.AddNode("Reshape", {conv_output, input2}, {reshape_output});
970+
971+
// Add a Softmax. QNN EP should reject this node because its input (the Reshape output) has a dynamic shape.
972+
NodeArg* output = builder.MakeOutput();
973+
builder.AddNode("Softmax", {reshape_output}, {output});
974+
};
975+
976+
// Local function that checks that the nodes with dynamic shape I/O (Reshape, Softmax) were assigned to CPU EP and every other node to QNN EP.
977+
std::function<void(const Graph&)> ep_graph_checker = [](const Graph& graph) {
978+
for (const Node& node : graph.Nodes()) {
979+
const std::string& ep_name = node.GetExecutionProviderType();
980+
const std::string& op_type = node.OpType();
981+
if (op_type == "Reshape" || op_type == "Softmax") {
982+
EXPECT_EQ(ep_name, kCpuExecutionProvider);
983+
} else {
984+
EXPECT_EQ(ep_name, kQnnExecutionProvider);
985+
}
986+
}
987+
};
988+
989+
ProviderOptions provider_options;
990+
#if defined(_WIN32)
991+
provider_options["backend_path"] = "QnnHtp.dll";
992+
#else
993+
provider_options["backend_path"] = "libQnnHtp.so";
994+
#endif
995+
provider_options["enable_htp_fp16_precision"] = "1"; // QNN EP will use fp16 precision.
996+
// CPU EP will use fp32, so we relax the accuracy requirement (abs_err of 1e-4f is passed below).
997+
998+
RunQnnModelTest(model_build_fn,
999+
provider_options,
1000+
/*opset*/ 19,
1001+
ExpectedEPNodeAssignment::Some,
1002+
/*abs_err*/ 1e-4f,
1003+
logging::Severity::kERROR,
1004+
/*verify_output*/ true,
1005+
&ep_graph_checker);
1006+
}
1007+
9511008
#endif // defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__)
9521009
#endif // !defined(ORT_MINIMAL_BUILD)
9531010

onnxruntime/test/providers/qnn/qnn_test_utils.cc

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,10 +98,12 @@ void TryEnableQNNSaver(ProviderOptions& qnn_options) {
9898

9999
void RunQnnModelTest(const GetTestModelFn& build_test_case, ProviderOptions provider_options,
100100
int opset_version, ExpectedEPNodeAssignment expected_ep_assignment,
101-
float fp32_abs_err, logging::Severity log_severity, bool verify_outputs) {
101+
float fp32_abs_err, logging::Severity log_severity, bool verify_outputs,
102+
std::function<void(const Graph&)>* ep_graph_checker) {
102103
EPVerificationParams verification_params;
103104
verification_params.ep_node_assignment = expected_ep_assignment;
104105
verification_params.fp32_abs_err = fp32_abs_err;
106+
verification_params.graph_verifier = ep_graph_checker;
105107
// Add kMSDomain to cover contrib op like Gelu
106108
const std::unordered_map<std::string, int> domain_to_version = {{"", opset_version}, {kMSDomain, 1}};
107109

onnxruntime/test/providers/qnn/qnn_test_utils.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1033,12 +1033,16 @@ inline GetTestQDQModelFn<QuantType> BuildQDQOpTestCase(
10331033
* \param expected_ep_assignment How many nodes are expected to be assigned to QNN (All, Some, or None).
10341034
* \param fp32_abs_err The acceptable error between CPU EP and QNN EP.
10351035
* \param log_severity The logger's minimum severity level.
1036+
* \param verify_outputs True to verify that the outputs match (within tolerance).
1037+
* \param ep_graph_checker Function called on the Graph generated for the EP's session. Used to check node
1038+
* EP assignment.
10361039
*/
10371040
void RunQnnModelTest(const GetTestModelFn& build_test_case, ProviderOptions provider_options,
10381041
int opset_version, ExpectedEPNodeAssignment expected_ep_assignment,
10391042
float fp32_abs_err = 1e-5f,
10401043
logging::Severity log_severity = logging::Severity::kERROR,
1041-
bool verify_outputs = true);
1044+
bool verify_outputs = true,
1045+
std::function<void(const Graph&)>* ep_graph_checker = nullptr);
10421046

10431047
enum class BackendSupport {
10441048
SUPPORT_UNKNOWN,

0 commit comments

Comments
 (0)