Skip to content

Commit a23786c

Browse files
authored
Backend Shape and Datatype Access API (#247)
Adding output shape and datatype accessors to backend API
1 parent 1dcf5bb commit a23786c

File tree

4 files changed

+432
-1
lines changed

4 files changed

+432
-1
lines changed

include/triton/core/tritonbackend.h

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ struct TRITONBACKEND_Batcher;
9494
/// }
9595
///
9696
#define TRITONBACKEND_API_VERSION_MAJOR 1
97-
#define TRITONBACKEND_API_VERSION_MINOR 15
97+
#define TRITONBACKEND_API_VERSION_MINOR 16
9898

9999
/// Get the TRITONBACKEND API version supported by Triton. This value
100100
/// can be compared against the TRITONBACKEND_API_VERSION_MAJOR and
@@ -1569,6 +1569,44 @@ TRITONBACKEND_ISPEC TRITONSERVER_Error* TRITONBACKEND_ModelBatchInitialize(
15691569
TRITONBACKEND_ISPEC TRITONSERVER_Error* TRITONBACKEND_ModelBatchFinalize(
15701570
void* userp);
15711571

1572+
/// Get all information about an output tensor by its name. The caller does
1573+
/// not own any of the referenced return values and must not modify or delete
1574+
/// them. The lifetime of all returned values extends until 'response' is
1575+
/// deleted.
1576+
///
1577+
/// \param response The response object.
1578+
/// \param name The name of the output to look up.
1579+
/// \param datatype Returns the type of the output.
1580+
/// \param shape Returns a pointer to the shape of the output. The pointed-to
/// array holds 'dim_count' dimensions.
1581+
/// \param dim_count Returns the number of dimensions of the returned
1582+
/// shape.
1583+
/// \return a TRITONSERVER_Error indicating success or failure. A
/// TRITONSERVER_ERROR_NOT_FOUND error is returned if 'response' has no output
/// called 'name'.
1584+
TRITONBACKEND_ISPEC TRITONSERVER_Error*
1585+
TRITONBACKEND_InferenceResponseOutputByName(
1586+
TRITONBACKEND_Response* response, const char* name,
1587+
TRITONSERVER_DataType* datatype, const int64_t** shape,
1588+
uint64_t* dim_count);
1589+
1590+
/// Get all information about an output tensor by its index. The caller does
1591+
/// not own any of the referenced return values and must not modify or delete
1592+
/// them. The lifetime of all returned values extends until 'response' is
1593+
/// deleted.
1594+
///
1595+
/// \param response The response object.
1596+
/// \param index The index of the output tensor, must be 0 <= index <
1597+
/// count, where 'count' is the value returned by
1598+
/// TRITONSERVER_InferenceResponseOutputCount.
1599+
/// \param name Returns the name of the output.
1600+
/// \param datatype Returns the type of the output.
1601+
/// \param shape Returns a pointer to the shape of the output. The pointed-to
/// array holds 'dim_count' dimensions.
1602+
/// \param dim_count Returns the number of dimensions of the returned
1603+
/// shape.
1604+
/// \return a TRITONSERVER_Error indicating success or failure. A
/// TRITONSERVER_ERROR_INVALID_ARG error is returned if 'index' is out of
/// bounds.
// NOTE(review): this declaration is marked TRITONSERVER_DECLSPEC while
// TRITONBACKEND_InferenceResponseOutputByName is marked TRITONBACKEND_ISPEC;
// confirm which export macro is intended for these two accessors.
1605+
TRITONSERVER_DECLSPEC TRITONSERVER_Error* TRITONBACKEND_InferenceResponseOutput(
1606+
TRITONBACKEND_Response* response, const uint32_t index, const char** name,
1607+
TRITONSERVER_DataType* datatype, const int64_t** shape,
1608+
uint64_t* dim_count);
1609+
15721610
#ifdef __cplusplus
15731611
}
15741612
#endif

src/backend_model.cc

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1750,6 +1750,60 @@ TRITONBACKEND_BackendAttributeSetParallelModelInstanceLoading(
17501750
return nullptr;
17511751
}
17521752

1753+
// Look up the output tensor called 'name' in 'response' and report its
// datatype and shape through the out-parameters.
//
// On success returns nullptr; 'shape' then points at the dimension array owned
// by the matching response output (nothing is copied) and 'dim_count' holds
// the number of dimensions. Returns a TRITONSERVER_ERROR_INVALID_ARG error if
// 'name' is null and a TRITONSERVER_ERROR_NOT_FOUND error if no output in the
// response carries that name.
TRITONAPI_DECLSPEC TRITONSERVER_Error*
TRITONBACKEND_InferenceResponseOutputByName(
    TRITONBACKEND_Response* response, const char* name,
    TRITONSERVER_DataType* datatype, const int64_t** shape, uint64_t* dim_count)
{
  // Building a std::string from a null pointer is undefined behavior, so
  // reject it before the name is used.
  if (name == nullptr) {
    return TRITONSERVER_ErrorNew(
        TRITONSERVER_ERROR_INVALID_ARG, "Output name must not be null.");
  }

  InferenceResponse* tr = reinterpret_cast<InferenceResponse*>(response);
  const std::string output_name(name);

  for (const auto& output : tr->Outputs()) {
    if (output.Name() == output_name) {
      *datatype = DataTypeToTriton(output.DType());

      // data() is valid even when the shape has zero dimensions, whereas
      // taking the address of element 0 of an empty vector is undefined.
      const std::vector<int64_t>& oshape = output.Shape();
      *shape = oshape.data();
      *dim_count = oshape.size();
      return nullptr;  // success
    }
  }

  return TRITONSERVER_ErrorNew(
      TRITONSERVER_ERROR_NOT_FOUND,
      ("Output name " + output_name + " not found.").c_str());
}
1777+
1778+
// Report the name, datatype and shape of the output tensor stored at position
// 'index' in 'response'.
//
// Returns a TRITONSERVER_ERROR_INVALID_ARG error when 'index' is not smaller
// than the number of outputs held by the response. Otherwise every
// out-parameter is filled and nullptr is returned; 'name' and 'shape' point at
// storage owned by the response output (nothing is copied).
TRITONAPI_DECLSPEC TRITONSERVER_Error*
TRITONBACKEND_InferenceResponseOutput(
    TRITONBACKEND_Response* response, const uint32_t index, const char** name,
    TRITONSERVER_DataType* datatype, const int64_t** shape, uint64_t* dim_count)
{
  InferenceResponse* tr = reinterpret_cast<InferenceResponse*>(response);

  const auto& outputs = tr->Outputs();
  if (index >= outputs.size()) {
    return TRITONSERVER_ErrorNew(
        TRITONSERVER_ERROR_INVALID_ARG,
        ("out of bounds index " + std::to_string(index) +
         std::string(": response has ") + std::to_string(outputs.size()) +
         " outputs")
            .c_str());
  }

  const InferenceResponse::Output& output = outputs[index];

  *name = output.Name().c_str();
  *datatype = DataTypeToTriton(output.DType());

  // data() is valid even when the shape has zero dimensions, whereas taking
  // the address of element 0 of an empty vector is undefined behavior.
  const std::vector<int64_t>& oshape = output.Shape();
  *shape = oshape.data();
  *dim_count = oshape.size();

  return nullptr;  // success
}
1806+
17531807
} // extern C
17541808

17551809
}} // namespace triton::core

src/test/CMakeLists.txt

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -534,3 +534,43 @@ install(
534534
TARGETS register_api_test
535535
RUNTIME DESTINATION bin
536536
)
537+
538+
#
539+
# Backend Output Detail Unittest
540+
#
541+
# Build the test executable from its single source file.
add_executable(
541+
backend_output_detail_test
542+
backend_output_detail_test.cc
543+
)
544+
545+
# Build with the install RPATH, which is left empty, and do not append the
# linker search paths to it.
set_target_properties(
546+
backend_output_detail_test
547+
PROPERTIES
548+
SKIP_BUILD_RPATH TRUE
549+
BUILD_WITH_INSTALL_RPATH TRUE
550+
INSTALL_RPATH_USE_LINK_PATH FALSE
551+
INSTALL_RPATH ""
552+
)
553+
554+
# Headers come from the parent source directory, the include directory two
# levels up, and GoogleTest.
target_include_directories(
555+
backend_output_detail_test
556+
PRIVATE
557+
${CMAKE_CURRENT_SOURCE_DIR}/..
558+
${CMAKE_CURRENT_SOURCE_DIR}/../../include
559+
${GTEST_INCLUDE_DIRS}
560+
)
561+
562+
# Link against the common error/logging libraries, triton-core, and
# GoogleTest (gtest_main supplies the test entry point).
target_link_libraries(
563+
backend_output_detail_test
564+
PRIVATE
565+
triton-common-error # from repo-common
566+
triton-common-logging # from repo-common
567+
triton-core
568+
GTest::gtest
569+
GTest::gtest_main
570+
)
571+
572+
# Install the test executable into the bin directory.
install(
573+
TARGETS backend_output_detail_test
574+
RUNTIME DESTINATION bin
575+
)

0 commit comments

Comments
 (0)