
Commit d14828c

[Backend] Add AdaptivePool2d for TensorRT plugin (#668)
* add adaptivepool2d for tensorrt plugin
* update code
* update code
* update code to fix bug
1 parent cc74fb8 commit d14828c


9 files changed: +461 additions, -31 deletions


CMakeLists.txt

Lines changed: 3 additions & 2 deletions
@@ -164,7 +164,7 @@ configure_file(${PROJECT_SOURCE_DIR}/${CSRCS_DIR_NAME}/fastdeploy/pybind/main.cc
 file(GLOB_RECURSE ALL_DEPLOY_SRCS ${PROJECT_SOURCE_DIR}/${CSRCS_DIR_NAME}/fastdeploy/*.cc)
 file(GLOB_RECURSE FDTENSOR_FUNC_SRCS ${PROJECT_SOURCE_DIR}/${CSRCS_DIR_NAME}/fastdeploy/function/*.cc)
 file(GLOB_RECURSE FDTENSOR_FUNC_CUDA_SRCS ${PROJECT_SOURCE_DIR}/${CSRCS_DIR_NAME}/fastdeploy/function/*.cu)
-file(GLOB_RECURSE DEPLOY_ORT_CUDA_SRCS ${PROJECT_SOURCE_DIR}/${CSRCS_DIR_NAME}/fastdeploy/backends/ort/*.cu)
+file(GLOB_RECURSE DEPLOY_OP_CUDA_KERNEL_SRCS ${PROJECT_SOURCE_DIR}/${CSRCS_DIR_NAME}/fastdeploy/backends/op_cuda_kernels/*.cu)
 file(GLOB_RECURSE DEPLOY_ORT_SRCS ${PROJECT_SOURCE_DIR}/${CSRCS_DIR_NAME}/fastdeploy/backends/ort/*.cc)
 file(GLOB_RECURSE DEPLOY_PADDLE_SRCS ${PROJECT_SOURCE_DIR}/${CSRCS_DIR_NAME}/fastdeploy/backends/paddle/*.cc)
 file(GLOB_RECURSE DEPLOY_POROS_SRCS ${PROJECT_SOURCE_DIR}/${CSRCS_DIR_NAME}/fastdeploy/backends/poros/*.cc)

@@ -202,7 +202,7 @@ if(ENABLE_ORT_BACKEND)
   include(${PROJECT_SOURCE_DIR}/cmake/onnxruntime.cmake)
   list(APPEND DEPEND_LIBS external_onnxruntime)
   if(WITH_GPU)
-    list(APPEND ALL_DEPLOY_SRCS ${DEPLOY_ORT_CUDA_SRCS})
+    list(APPEND ALL_DEPLOY_SRCS ${DEPLOY_OP_CUDA_KERNEL_SRCS})
   endif()
 endif()

@@ -361,6 +361,7 @@ if(ENABLE_TRT_BACKEND)
   find_library(TRT_ONNX_LIB nvonnxparser ${TRT_LIB_DIR} NO_DEFAULT_PATH)
   find_library(TRT_PLUGIN_LIB nvinfer_plugin ${TRT_LIB_DIR} NO_DEFAULT_PATH)
   list(APPEND DEPEND_LIBS ${TRT_INFER_LIB} ${TRT_ONNX_LIB} ${TRT_PLUGIN_LIB})
+  list(APPEND ALL_DEPLOY_SRCS ${DEPLOY_OP_CUDA_KERNEL_SRCS})

   if(NOT BUILD_ON_JETSON)
     if(NOT EXISTS "${CMAKE_CURRENT_BINARY_DIR}/third_libs/install/tensorrt")

fastdeploy/backends/ort/ops/adaptive_pooling.cu renamed to fastdeploy/backends/op_cuda_kernels/adaptive_pool2d_kernel.cu

Lines changed: 11 additions & 14 deletions
@@ -1,34 +1,31 @@
-#include "adaptive_pool2d.h"
-#include <cuda.h>
-#include <cuda_runtime.h>
-#include <cstdint>
-#include <iostream>
-#include <vector>
-#include <math.h>
+#include "adaptive_pool2d_kernel.h"
+
 namespace fastdeploy {

 __global__ void CudaCastKernel(const float* in, float* out, int edge, int out_bc_offset, int in_bc_offset, int ih, int iw, int oh, int ow, bool is_avg) {
   int position = blockDim.x * blockIdx.x + threadIdx.x;
-  if (position >= edge) return;
+  if (position >= edge) {
+    return;
+  }
   int offset = floorf(float(position) / out_bc_offset);
   int h = floorf(float(position % out_bc_offset) / ow);
   int w = (position % out_bc_offset) % ow;
   int hstart = floorf(static_cast<float>(h * ih) / oh);
   int hend = ceilf(static_cast<float>((h + 1) * ih) / oh);
   int wstart = floorf(static_cast<float>(w * iw) / ow);
   int wend = ceilf(static_cast<float>((w + 1) * iw) / ow);
-  if(is_avg){
+  if(is_avg) {
     out[position] = 0.0;
-  }else{
+  } else {
     out[position] = in[offset * in_bc_offset + hstart * iw + wstart];
   }
   for (int h = hstart; h < hend; ++h) {
     for (int w = wstart; w < wend; ++w) {
       int input_idx = h * iw + w;
-      if(is_avg){
+      if(is_avg) {
         out[position] = out[position] + in[offset * in_bc_offset + input_idx];
-      }else{
-        out[position] = max(out[position], in[offset * in_bc_offset + input_idx]);
+      } else {
+        out[position] = max(out[position], in[offset * in_bc_offset + input_idx]);
       }
     }
   }

@@ -40,7 +37,7 @@ void CudaAdaptivePool(const std::vector<int64_t>& input_dims, const std::vector<
   int out_bc_offset = output_dims[2] * output_dims[3];
   int in_bc_offset = input_dims[2] * input_dims[3];
   int jobs = 1;
-  for(int i : output_dims){
+  for(int i : output_dims) {
     jobs *= i;
   }
   bool is_avg = pooling_type == "avg";
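
The kernel assigns one thread per output element: thread position decodes a fused batch-channel index plus an output cell (h, w), then pools the input window [floor(h*ih/oh), ceil((h+1)*ih/oh)) x [floor(w*iw/ow), ceil((w+1)*iw/ow)). As a sanity check of that index math, here is a minimal host-side C++ sketch of the same adaptive pooling. RefAdaptivePool2d is an illustrative name, not part of the commit, and the final division for "avg" pooling is assumed to happen in the portion of the kernel this hunk does not show.

#include <algorithm>
#include <cmath>
#include <string>
#include <vector>

// Reference (CPU) adaptive 2-D pooling over an NCHW float tensor.
// Mirrors the window arithmetic used by CudaCastKernel: "avg" sums the window
// and divides by its size, anything else takes the window maximum.
std::vector<float> RefAdaptivePool2d(const std::vector<float>& in,
                                     int n, int c, int ih, int iw,
                                     int oh, int ow,
                                     const std::string& pooling_type) {
  std::vector<float> out(static_cast<size_t>(n) * c * oh * ow);
  for (int bc = 0; bc < n * c; ++bc) {  // fused batch * channel index
    for (int h = 0; h < oh; ++h) {
      int hstart = static_cast<int>(std::floor(float(h * ih) / oh));
      int hend = static_cast<int>(std::ceil(float((h + 1) * ih) / oh));
      for (int w = 0; w < ow; ++w) {
        int wstart = static_cast<int>(std::floor(float(w * iw) / ow));
        int wend = static_cast<int>(std::ceil(float((w + 1) * iw) / ow));
        float val = (pooling_type == "avg")
                        ? 0.0f
                        : in[bc * ih * iw + hstart * iw + wstart];
        for (int y = hstart; y < hend; ++y) {
          for (int x = wstart; x < wend; ++x) {
            float v = in[bc * ih * iw + y * iw + x];
            val = (pooling_type == "avg") ? val + v : std::max(val, v);
          }
        }
        if (pooling_type == "avg") {
          val /= float((hend - hstart) * (wend - wstart));
        }
        out[bc * oh * ow + h * ow + w] = val;
      }
    }
  }
  return out;
}
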
fastdeploy/backends/op_cuda_kernels/adaptive_pool2d_kernel.h

Lines changed: 35 additions & 0 deletions

@@ -0,0 +1,35 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <cuda.h>
#include <cuda_runtime.h>
#include <cstdint>
#include <iostream>
#include <vector>
#include <math.h>

namespace fastdeploy {

void CudaAdaptivePool(const std::vector<int64_t>& input_dims,
                      const std::vector<int64_t>& output_dims,
                      float* output,
                      const float* input,
                      void* compute_stream,
                      const std::string& pooling_type);

}  // namespace fastdeploy
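
This header is the only interface the backends need from the shared CUDA kernel. Below is a minimal, hypothetical calling sketch (not part of the commit) that pools a 1x1x4x4 tensor down to 1x1x2x2. It assumes the device buffers are sized from the NCHW dims, passes a cudaStream_t through the void* compute_stream parameter, and the expected numbers assume "avg" pooling divides each window sum by the window size.

#include <cstdint>
#include <cuda_runtime.h>
#include <vector>
#include "fastdeploy/backends/op_cuda_kernels/adaptive_pool2d_kernel.h"

int main() {
  // NCHW: adaptively average-pool a 1x1x4x4 input down to 1x1x2x2.
  std::vector<int64_t> in_dims = {1, 1, 4, 4}, out_dims = {1, 1, 2, 2};
  std::vector<float> h_in(16);
  for (int i = 0; i < 16; ++i) h_in[i] = static_cast<float>(i);

  float *d_in = nullptr, *d_out = nullptr;
  cudaMalloc(&d_in, h_in.size() * sizeof(float));
  cudaMalloc(&d_out, 4 * sizeof(float));
  cudaMemcpy(d_in, h_in.data(), h_in.size() * sizeof(float), cudaMemcpyHostToDevice);

  cudaStream_t stream;
  cudaStreamCreate(&stream);
  // The backends forward their CUDA stream through the void* parameter.
  fastdeploy::CudaAdaptivePool(in_dims, out_dims, d_out, d_in, stream, "avg");

  std::vector<float> h_out(4);
  cudaMemcpyAsync(h_out.data(), d_out, 4 * sizeof(float), cudaMemcpyDeviceToHost, stream);
  cudaStreamSynchronize(stream);
  // Expected output for the 0..15 ramp input: 2.5, 4.5, 10.5, 12.5.

  cudaStreamDestroy(stream);
  cudaFree(d_in);
  cudaFree(d_out);
  return 0;
}
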

fastdeploy/backends/ort/ops/adaptive_pool2d.cc

Lines changed: 1 addition & 6 deletions
@@ -14,14 +14,9 @@

 #ifndef NON_64_PLATFORM

-#include "fastdeploy/backends/ort/ops/adaptive_pool2d.h"
-#include <algorithm>
-#include <cmath>
-#include "fastdeploy/core/fd_tensor.h"
-#include "fastdeploy/utils/utils.h"
+#include "adaptive_pool2d.h"

 namespace fastdeploy {
-
 struct OrtTensorDimensions : std::vector<int64_t> {
   OrtTensorDimensions(Ort::CustomOpApi ort, const OrtValue* value) {
     OrtTensorTypeAndShapeInfo* info = ort.GetTensorTypeAndShape(value);

fastdeploy/backends/ort/ops/adaptive_pool2d.h

Lines changed: 7 additions & 7 deletions
@@ -16,19 +16,19 @@

 #include <map>
 #include <string>
+#include <algorithm>
+#include <cmath>
+#include "fastdeploy/core/fd_tensor.h"
+#include "fastdeploy/utils/utils.h"

 #ifndef NON_64_PLATFORM
 #include "onnxruntime_cxx_api.h" // NOLINT

-namespace fastdeploy {
 #ifdef WITH_GPU
-void CudaAdaptivePool(const std::vector<int64_t>& input_dims,
-                      const std::vector<int64_t>& output_dims,
-                      float* output,
-                      const float* input,
-                      void* compute_stream,
-                      const std::string& pooling_type);
+#include "fastdeploy/backends/op_cuda_kernels/adaptive_pool2d_kernel.h"
 #endif
+
+namespace fastdeploy {
 struct AdaptivePool2dKernel {
  protected:
   std::string pooling_type_ = "avg";
Lines changed: 206 additions & 0 deletions
@@ -0,0 +1,206 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "adaptive_pool2d.h"

namespace fastdeploy {

nvinfer1::PluginFieldCollection AdaptivePool2dPluginCreator::mFC{};
std::vector<nvinfer1::PluginField> AdaptivePool2dPluginCreator::mPluginAttributes;

pluginStatus_t AdaptivePool2dInference(cudaStream_t stream, int32_t n, const void* input, void* output);

AdaptivePool2d::AdaptivePool2d(std::vector<int32_t> output_size, std::string pooling_type) {
  output_size_ = output_size;
  pooling_type_ = pooling_type;
}

AdaptivePool2d::AdaptivePool2d(const void* buffer, size_t length) {
  const char *d = reinterpret_cast<const char*>(buffer), *a = d;
  output_size_.resize(4);
  for (int64_t i = 0; i < 4; i++) {
    output_size_[i] = read<int32_t>(d);
  }
  if (read<int32_t>(d) == 0) {
    pooling_type_ = "avg";
  } else {
    pooling_type_ = "max";
  }
  FDASSERT(d == a + length, "deserialize failed.");
}

int AdaptivePool2d::getNbOutputs() const noexcept {
  return 1;
}

nvinfer1::DimsExprs AdaptivePool2d::getOutputDimensions(
    int outputIndex, const nvinfer1::DimsExprs* inputs,
    int nbInputs, nvinfer1::IExprBuilder& exprBuilder) noexcept {
  try {
    nvinfer1::DimsExprs output(inputs[0]);
    output.d[2] = exprBuilder.constant(static_cast<int32_t>(output_size_[2]));
    output.d[3] = exprBuilder.constant(static_cast<int32_t>(output_size_[3]));
    return output;
  } catch (const std::exception& e) {
    FDASSERT(false, "getOutputDimensions failed: %s.", e.what());
  }
  return nvinfer1::DimsExprs{};
}

int AdaptivePool2d::enqueue(const nvinfer1::PluginTensorDesc* inputDesc,
                            const nvinfer1::PluginTensorDesc* outputDesc,
                            const void* const* inputs, void* const* outputs,
                            void* workspace, cudaStream_t stream) noexcept {
  if (inputDesc[0].type != nvinfer1::DataType::kFLOAT) {
    return -1;
  }
  auto const* data = static_cast<float const*>(inputs[0]);
  auto* result = static_cast<float*>(outputs[0]);
  int nums = outputDesc[0].dims.d[0] * outputDesc[0].dims.d[1] *
             outputDesc[0].dims.d[2] * outputDesc[0].dims.d[3];
  std::vector<int64_t> input_size, output_size;
  for (int i = 0; i < 4; i++) {
    input_size.push_back(inputDesc[0].dims.d[i]);
    output_size.push_back(outputDesc[0].dims.d[i]);
  }
  CudaAdaptivePool(input_size, output_size, result, data, stream, pooling_type_);
  return cudaPeekAtLastError();
}

size_t AdaptivePool2d::getSerializationSize() const noexcept {
  return 5 * sizeof(int32_t);
}

void AdaptivePool2d::serialize(void* buffer) const noexcept {
  char *d = reinterpret_cast<char*>(buffer), *a = d;
  for (int64_t i = 0; i < 4; i++) {
    write(d, output_size_[i]);
  }
  int32_t pooling_type_val = 0;
  if (pooling_type_ != "avg") {
    pooling_type_val = 1;
  }
  write(d, pooling_type_val);
  FDASSERT(d == a + getSerializationSize(), "d == a + getSerializationSize()");
}

nvinfer1::DataType AdaptivePool2d::getOutputDataType(
    int index, const nvinfer1::DataType* inputType, int nbInputs) const noexcept {
  return inputType[0];
}

bool AdaptivePool2d::supportsFormatCombination(
    int pos, const nvinfer1::PluginTensorDesc* inOut, int nbInputs, int nbOutputs) noexcept {
  return (inOut[pos].format == nvinfer1::PluginFormat::kLINEAR);
}

int AdaptivePool2d::initialize() noexcept {
  return 0;
}

void AdaptivePool2d::terminate() noexcept {
  return;
}

size_t AdaptivePool2d::getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs,
                                        int nbInputs,
                                        const nvinfer1::PluginTensorDesc* outputs,
                                        int nbOutputs) const noexcept {
  return 0;
}

const char* AdaptivePool2d::getPluginType() const noexcept {
  return "AdaptivePool2d";
}

const char* AdaptivePool2d::getPluginVersion() const noexcept {
  return "1";
}

void AdaptivePool2d::destroy() noexcept {
  return;
}

void AdaptivePool2d::configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in, int nbInputs,
                                     const nvinfer1::DynamicPluginTensorDesc* out, int nbOutputs) noexcept {
  return;
}

nvinfer1::IPluginV2DynamicExt* AdaptivePool2d::clone() const noexcept {
  try {
    nvinfer1::IPluginV2DynamicExt* plugin = new AdaptivePool2d(output_size_, pooling_type_);
    plugin->setPluginNamespace(mNamespace.c_str());
    return plugin;
  } catch (std::exception const& e) {
    FDASSERT(false, "clone failed: %s.", e.what());
  }
  return nullptr;
}

AdaptivePool2dPluginCreator::AdaptivePool2dPluginCreator() {
  mPluginAttributes.clear();
  mPluginAttributes.emplace_back(nvinfer1::PluginField("output_size", nullptr, nvinfer1::PluginFieldType::kINT32, 4));
  mPluginAttributes.emplace_back(nvinfer1::PluginField("pooling_type", nullptr, nvinfer1::PluginFieldType::kCHAR, 3));

  mFC.nbFields = mPluginAttributes.size();
  mFC.fields = mPluginAttributes.data();
}

const char* AdaptivePool2dPluginCreator::getPluginName() const noexcept {
  return "AdaptivePool2d";
}

const char* AdaptivePool2dPluginCreator::getPluginVersion() const noexcept {
  return "1";
}

const nvinfer1::PluginFieldCollection* AdaptivePool2dPluginCreator::getFieldNames() noexcept {
  return &mFC;
}

nvinfer1::IPluginV2DynamicExt* AdaptivePool2dPluginCreator::createPlugin(
    const char* name, const nvinfer1::PluginFieldCollection* fc) noexcept {
  try {
    const nvinfer1::PluginField* fields = fc->fields;
    auto const dims = static_cast<int32_t const*>(fields[0].data);
    output_size_.resize(4);
    for (int64_t i = 0; i < 4; i++) {
      output_size_[i] = dims[i];
    }

    const char* pooling_type_ptr = (static_cast<char const*>(fields[1].data));
    std::string pooling_type(pooling_type_ptr, 3);
    pooling_type_ = pooling_type;
    return new AdaptivePool2d(output_size_, pooling_type_);
  } catch (std::exception const& e) {
    FDASSERT(false, "createPlugin failed: %s.", e.what());
  }
  return nullptr;
}

nvinfer1::IPluginV2DynamicExt* AdaptivePool2dPluginCreator::deserializePlugin(
    const char* name, const void* serialData, size_t serialLength) noexcept {
  try {
    return new AdaptivePool2d(serialData, serialLength);
  } catch (std::exception const& e) {
    FDASSERT(false, "deserializePlugin failed: %s.", e.what());
  }
  return nullptr;
}

}  // namespace fastdeploy
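
The creator above advertises plugin name "AdaptivePool2d", version "1", and two fields (output_size, pooling_type), but the plugin header and the code that actually registers the creator with TensorRT are not part of this diff. A minimal sketch of the usual static-registration route follows; it assumes the class declarations from the unshown adaptive_pool2d.h are in scope, and FastDeploy may instead register the creator when its TensorRT backend initializes.

#include "adaptive_pool2d.h"   // assumed: declares AdaptivePool2dPluginCreator
#include "NvInferRuntime.h"    // provides REGISTER_TENSORRT_PLUGIN

namespace fastdeploy {

// Static registration: constructs a plugin registrar at library load time and
// adds the creator to TensorRT's global registry, so engine building and
// deserialization can resolve "AdaptivePool2d" (version "1") to this plugin.
REGISTER_TENSORRT_PLUGIN(AdaptivePool2dPluginCreator);

}  // namespace fastdeploy
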
