diff --git a/backends/xnnpack/runtime/XNNCompiler.cpp b/backends/xnnpack/runtime/XNNCompiler.cpp index 312cbc17b9..dcee8258ef 100644 --- a/backends/xnnpack/runtime/XNNCompiler.cpp +++ b/backends/xnnpack/runtime/XNNCompiler.cpp @@ -11,6 +11,7 @@ #include #include #include +#include #include #include #include @@ -615,10 +616,14 @@ Error defineAddNode( std::pair min_max = getOutputMinMax(node); auto graph_node = node->xnode_union_as_XNNAdd(); - xnn_status status = xnn_define_add2( + + struct xnn_binary_params params = { + .output_min = min_max.first, .output_max = min_max.second}; + + xnn_status status = xnn_define_binary( subgraph_ptr, - min_max.first, - min_max.second, + xnn_binary_add, + ¶ms, remapped_ids.at(graph_node->input1_id()), remapped_ids.at(graph_node->input2_id()), remapped_ids.at(graph_node->output_id()), @@ -644,8 +649,11 @@ Error defineMinimumNode( MAYBE_UNUSED(graph); auto graph_node = node->xnode_union_as_XNNMinimum(); - xnn_status status = xnn_define_minimum2( + + xnn_status status = xnn_define_binary( subgraph_ptr, + xnn_binary_minimum, + nullptr, remapped_ids.at(graph_node->input1_id()), remapped_ids.at(graph_node->input2_id()), remapped_ids.at(graph_node->output_id()), @@ -673,10 +681,14 @@ Error defineSubtractNode( auto graph_node = node->xnode_union_as_XNNSubtract(); std::pair min_max = getOutputMinMax(node); - xnn_status status = xnn_define_subtract( + + struct xnn_binary_params params = { + .output_min = min_max.first, .output_max = min_max.second}; + + xnn_status status = xnn_define_binary( subgraph_ptr, - min_max.first, - min_max.second, + xnn_binary_subtract, + ¶ms, remapped_ids.at(graph_node->input1_id()), remapped_ids.at(graph_node->input2_id()), remapped_ids.at(graph_node->output_id()), @@ -704,10 +716,14 @@ Error defineMultiplyNode( auto graph_node = node->xnode_union_as_XNNMultiply(); std::pair min_max = getOutputMinMax(node); - xnn_status status = xnn_define_multiply2( + + struct xnn_binary_params params = { + .output_min = min_max.first, .output_max = min_max.second}; + + xnn_status status = xnn_define_binary( subgraph_ptr, - min_max.first, - min_max.second, + xnn_binary_multiply, + ¶ms, remapped_ids.at(graph_node->input1_id()), remapped_ids.at(graph_node->input2_id()), remapped_ids.at(graph_node->output_id()), @@ -857,10 +873,14 @@ Error defineClampNode( std::pair min_max = getOutputMinMax(node); auto graph_node = node->xnode_union_as_XNNClamp(); - xnn_status status = xnn_define_clamp( + + union xnn_unary_params params = { + .clamp = {.min = min_max.first, .max = min_max.second}}; + + xnn_status status = xnn_define_unary( subgraph_ptr, - min_max.first, - min_max.second, + xnn_unary_clamp, + ¶ms, remapped_ids.at(graph_node->input_id()), remapped_ids.at(graph_node->output_id()), graph_node->flags()); @@ -916,8 +936,10 @@ Error defineSigmoidNode( MAYBE_UNUSED(graph); auto graph_node = node->xnode_union_as_XNNSigmoid(); - xnn_status status = xnn_define_sigmoid( + xnn_status status = xnn_define_unary( subgraph_ptr, + xnn_unary_sigmoid, + nullptr, remapped_ids.at(graph_node->input_id()), remapped_ids.at(graph_node->output_id()), graph_node->flags()); @@ -944,8 +966,10 @@ Error defineFloorNode( MAYBE_UNUSED(graph); auto graph_node = node->xnode_union_as_XNNFloor(); - xnn_status status = xnn_define_floor( + xnn_status status = xnn_define_unary( subgraph_ptr, + xnn_unary_floor, + nullptr, remapped_ids.at(graph_node->input_id()), remapped_ids.at(graph_node->output_id()), graph_node->flags()); @@ -1167,10 +1191,14 @@ Error defineDivNode( auto graph_node = node->xnode_union_as_XNNDiv(); std::pair min_max = getOutputMinMax(node); - xnn_status status = xnn_define_divide( + + struct xnn_binary_params params = { + .output_min = min_max.first, .output_max = min_max.second}; + + xnn_status status = xnn_define_binary( subgraph_ptr, - min_max.first, - min_max.second, + xnn_binary_divide, + ¶ms, remapped_ids.at(graph_node->input1_id()), remapped_ids.at(graph_node->input2_id()), remapped_ids.at(graph_node->output_id()), @@ -1415,8 +1443,10 @@ Error defineSquareRootNode( auto graph_node = node->xnode_union_as_XNNSquareRoot(); - xnn_status status = xnn_define_square_root( + xnn_status status = xnn_define_unary( subgraph_ptr, + xnn_unary_square_root, + nullptr, remapped_ids.at(graph_node->input_id()), remapped_ids.at(graph_node->output_id()), graph_node->flags()); @@ -1445,8 +1475,10 @@ Error defineReciprocalSquareRootNode( auto graph_node = node->xnode_union_as_XNNReciprocalSquareRoot(); - xnn_status status = xnn_define_reciprocal_square_root( + xnn_status status = xnn_define_unary( subgraph_ptr, + xnn_unary_reciprocal_square_root, + nullptr, remapped_ids.at(graph_node->input_id()), remapped_ids.at(graph_node->output_id()), graph_node->flags()); @@ -1475,8 +1507,10 @@ Error defineLogNode( auto graph_node = node->xnode_union_as_XNNLog(); - xnn_status status = xnn_define_log( + xnn_status status = xnn_define_unary( subgraph_ptr, + xnn_unary_log, + nullptr, remapped_ids.at(graph_node->input_id()), remapped_ids.at(graph_node->output_id()), graph_node->flags()); @@ -1505,8 +1539,10 @@ Error defineGeluNode( auto graph_node = node->xnode_union_as_XNNGelu(); - xnn_status status = xnn_define_gelu( + xnn_status status = xnn_define_unary( subgraph_ptr, + xnn_unary_gelu, + nullptr, remapped_ids.at(graph_node->input_id()), remapped_ids.at(graph_node->output_id()), graph_node->flags()); @@ -1535,8 +1571,10 @@ Error defineCeilingNode( auto graph_node = node->xnode_union_as_XNNCeiling(); - xnn_status status = xnn_define_ceiling( + xnn_status status = xnn_define_unary( subgraph_ptr, + xnn_unary_ceiling, + nullptr, remapped_ids.at(graph_node->input_id()), remapped_ids.at(graph_node->output_id()), graph_node->flags()); @@ -1565,8 +1603,10 @@ Error defineHardswishNode( auto graph_node = node->xnode_union_as_XNNHardswish(); - xnn_status status = xnn_define_hardswish( + xnn_status status = xnn_define_unary( subgraph_ptr, + xnn_unary_hardswish, + nullptr, remapped_ids.at(graph_node->input_id()), remapped_ids.at(graph_node->output_id()), graph_node->flags()); @@ -1595,9 +1635,13 @@ Error defineLeakyReLUNode( auto graph_node = node->xnode_union_as_XNNLeakyReLU(); - xnn_status status = xnn_define_leaky_relu( + union xnn_unary_params params = { + .leaky_relu = {.negative_slope = graph_node->negative_slope()}}; + + xnn_status status = xnn_define_unary( subgraph_ptr, - graph_node->negative_slope(), + xnn_unary_leaky_relu, + ¶ms, remapped_ids.at(graph_node->input_id()), remapped_ids.at(graph_node->output_id()), graph_node->flags()); @@ -1626,8 +1670,10 @@ Error defineMaximumNode( auto graph_node = node->xnode_union_as_XNNMaximum(); - xnn_status status = xnn_define_maximum2( + xnn_status status = xnn_define_binary( subgraph_ptr, + xnn_binary_maximum, + nullptr, remapped_ids.at(graph_node->input1_id()), remapped_ids.at(graph_node->input2_id()), remapped_ids.at(graph_node->output_id()), @@ -1656,8 +1702,10 @@ Error defineNegateNode( auto graph_node = node->xnode_union_as_XNNNegate(); - xnn_status status = xnn_define_negate( + xnn_status status = xnn_define_unary( subgraph_ptr, + xnn_unary_negate, + nullptr, remapped_ids.at(graph_node->input_id()), remapped_ids.at(graph_node->output_id()), graph_node->flags()); @@ -1685,8 +1733,10 @@ Error defineSquareNode( auto graph_node = node->xnode_union_as_XNNSquare(); - xnn_status status = xnn_define_square( + xnn_status status = xnn_define_unary( subgraph_ptr, + xnn_unary_square, + nullptr, remapped_ids.at(graph_node->input_id()), remapped_ids.at(graph_node->output_id()), graph_node->flags()); @@ -1714,9 +1764,12 @@ Error defineELUNode( auto graph_node = node->xnode_union_as_XNNELU(); - xnn_status status = xnn_define_elu( + union xnn_unary_params params = {.elu = {.alpha = graph_node->alpha()}}; + + xnn_status status = xnn_define_unary( subgraph_ptr, - graph_node->alpha(), + xnn_unary_elu, + ¶ms, remapped_ids.at(graph_node->input_id()), remapped_ids.at(graph_node->output_id()), graph_node->flags()); @@ -1744,8 +1797,10 @@ Error defineAbsNode( auto graph_node = node->xnode_union_as_XNNAbs(); - xnn_status status = xnn_define_abs( + xnn_status status = xnn_define_unary( subgraph_ptr, + xnn_unary_abs, + nullptr, remapped_ids.at(graph_node->input_id()), remapped_ids.at(graph_node->output_id()), graph_node->flags());