Skip to content

Commit 5555630

Browse files
committed
Add split functions
1 parent 6e2e646 commit 5555630

File tree

5 files changed

+321
-11
lines changed

5 files changed

+321
-11
lines changed

fastdeploy/function/split.cc

Lines changed: 160 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,160 @@
1+
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "fastdeploy/function/split.h"
16+
#include "fastdeploy/utils/utils.h"
17+
#include <cstring>
18+
19+
namespace fastdeploy {
20+
namespace function {
21+
22+
/*
23+
* All tensors' dimension should be the same and the values of
24+
* each dimension must be the same, except the axis dimension.
25+
*/
26+
template <typename T> struct SplitFunctor {
27+
public:
28+
void operator()(const FDTensor& input,
29+
const std::vector<const FDTensor*>& ref_inputs, int axis,
30+
std::vector<FDTensor>* outputs) {
31+
if (input.Numel() == 0) {
32+
return;
33+
}
34+
35+
size_t num = outputs->size();
36+
37+
int input_rows = 1;
38+
auto dim_0 = ref_inputs[0]->Shape();
39+
for (int i = 0; i < axis; ++i) {
40+
input_rows *= dim_0[i];
41+
}
42+
43+
int input_cols = 0;
44+
45+
std::vector<int64_t> output_cols(outputs->size());
46+
for (size_t i = 0; i < num; ++i) {
47+
int t_cols = ref_inputs[i]->Numel() / input_rows;
48+
input_cols += t_cols;
49+
output_cols[i] = t_cols;
50+
}
51+
52+
// computation
53+
for (int k = 0; k < input_rows; ++k) {
54+
const T* src_ptr =
55+
reinterpret_cast<const T*>(input.Data()) + k * input_cols;
56+
int col_idx = 0;
57+
for (size_t j = 0; j < num; ++j) {
58+
int col_len = output_cols[j];
59+
auto* out_tensor = &(outputs->at(j));
60+
if (out_tensor != nullptr) {
61+
T* dst_ptr = reinterpret_cast<T*>(out_tensor->Data()) + k * col_len;
62+
std::memcpy(dst_ptr, src_ptr + col_idx, sizeof(T) * col_len);
63+
}
64+
col_idx += col_len;
65+
}
66+
}
67+
}
68+
};
69+
70+
inline int GetSplitAxisValue(const FDTensor& x, int axis) {
71+
int rank = x.Shape().size();
72+
FDASSERT(axis >= -rank && axis < rank,
73+
"The axis is expected to be in range of [%d, %d), but got %d", -rank,
74+
rank, axis);
75+
if (axis < 0) {
76+
axis = axis + rank;
77+
}
78+
return axis;
79+
}
80+
81+
void CreateSplitOutputs(const FDTensor& x,
82+
const std::vector<int>& sections_data,
83+
std::vector<FDTensor>* outs, int axis) {
84+
axis = GetSplitAxisValue(x, axis);
85+
auto input_axis_dim = x.Shape().at(axis);
86+
std::vector<int> sections_vec;
87+
const int unknow_dim_val = -1;
88+
int unknow_dim_idx = -1;
89+
int num_of_unknow = 0;
90+
int sum_of_section = 0;
91+
92+
for (size_t i = 0; i < sections_data.size(); ++i) {
93+
sections_vec.push_back(sections_data[i]);
94+
if (sections_data[i] == unknow_dim_val) {
95+
num_of_unknow++;
96+
unknow_dim_idx = i;
97+
} else {
98+
sum_of_section += sections_data[i];
99+
}
100+
}
101+
102+
FDASSERT(num_of_unknow <= 1,
103+
"Only one dimension value of Attr(num_or_sections) "
104+
"in SplitOp can be -1. "
105+
"But received Attr(num_or_sections) = [%s].",
106+
Str(sections_data).c_str());
107+
if (unknow_dim_idx != -1) {
108+
// for example, input shape = [4 ,5], axis = 1, sections = [2, 3, -1].
109+
// input_axis_dim = 5, sum_of_sections = 5.
110+
// the following check will fail.
111+
FDASSERT(sum_of_section < input_axis_dim,
112+
"Sum of Attr(num_or_sections) other than unknown section "
113+
"must be less than the input's "
114+
"size "
115+
"along the split dimension. But received Attr(num_or_sections) "
116+
"= [%s], input(X)'s shape = [%s], Attr(dim) = %d.",
117+
Str(sections_data).c_str(), Str(x.Shape()).c_str(), axis);
118+
sections_vec[unknow_dim_idx] = input_axis_dim - sum_of_section;
119+
} else {
120+
FDASSERT(sum_of_section == input_axis_dim,
121+
"Sum of Attr(num_or_sections) must be equal to the input's "
122+
"size "
123+
"along the split dimension. But received Attr(num_or_sections)"
124+
" = [%s], input(X)'s shape = [%s], Attr(dim) = %d.",
125+
Str(sections_data).c_str(), Str(x.Shape()).c_str(), axis);
126+
}
127+
// fill out dims
128+
std::vector<std::vector<int64_t>> out_dims(sections_vec.size(), x.Shape());
129+
for (size_t i = 0; i < sections_vec.size(); ++i) {
130+
out_dims[i][axis] = sections_vec[i];
131+
}
132+
for (size_t i = 0; i < sections_vec.size(); ++i) {
133+
(*outs)[i].Allocate(out_dims[i], x.Dtype());
134+
}
135+
}
136+
137+
template <typename T>
138+
void SplitKernel(const FDTensor& x, const std::vector<int>& section,
139+
std::vector<FDTensor>* outs, int axis) {
140+
size_t out_number = section.size();
141+
outs->resize(out_number);
142+
CreateSplitOutputs(x, section, outs, axis);
143+
144+
std::vector<const FDTensor*> shape_refer;
145+
for (size_t j = 0; j < outs->size(); ++j) {
146+
shape_refer.emplace_back(&((*outs)[j]));
147+
}
148+
SplitFunctor<T> functor;
149+
functor(x, shape_refer, axis, outs);
150+
}
151+
152+
void Split(const FDTensor& x, const std::vector<int>& num_or_sections,
153+
std::vector<FDTensor>* out, int axis) {
154+
FD_VISIT_ALL_TYPES(x.Dtype(), "Split", ([&] {
155+
SplitKernel<data_t>(x, num_or_sections, out, axis);
156+
}));
157+
}
158+
159+
} // namespace function
160+
} // namespace fastdeploy

fastdeploy/function/split.h

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#pragma once
16+
17+
#include "fastdeploy/core/fd_tensor.h"
18+
19+
namespace fastdeploy {
20+
namespace function {
21+
22+
/** Split the input tensor into multiple sub-Tensors.
23+
@param x The input tensor.
24+
@param num_or_sections f num_or_sections is an int, then num_or_sections
25+
indicates the number of equal sized sub-Tensors that the x will
26+
be divided into.
27+
@param out The output vector tensor which stores the result.
28+
@param axis Axis which will be splitted.
29+
*/
30+
31+
FASTDEPLOY_DECL void Split(const FDTensor& x,
32+
const std::vector<int>& num_or_sections,
33+
std::vector<FDTensor>* out, int axis = 0);
34+
35+
} // namespace function
36+
} // namespace fastdeploy

fastdeploy/utils/utils.cc

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -56,14 +56,4 @@ std::vector<int64_t> GetStride(const std::vector<int64_t>& dims) {
5656
return result;
5757
}
5858

59-
std::string Str(const std::vector<int64_t>& shape) {
60-
std::ostringstream oss;
61-
oss << "[ " << shape[0];
62-
for (int i = 1; i < shape.size(); ++i) {
63-
oss << " ," << shape[i];
64-
}
65-
oss << " ]";
66-
return oss.str();
67-
}
68-
6959
} // namespace fastdeploy

fastdeploy/utils/utils.h

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
#include <numeric>
2323
#include <sstream>
2424
#include <string>
25+
#include <type_traits>
2526
#include <vector>
2627

2728
#if defined(_WIN32)
@@ -186,6 +187,16 @@ FASTDEPLOY_DECL bool ReadBinaryFromFile(const std::string& file,
186187
FASTDEPLOY_DECL std::vector<int64_t>
187188
GetStride(const std::vector<int64_t>& dims);
188189

189-
FASTDEPLOY_DECL std::string Str(const std::vector<int64_t>& shape);
190+
template <typename T, typename std::enable_if<std::is_integral<T>::value,
191+
bool>::type = true>
192+
std::string Str(const std::vector<T>& shape) {
193+
std::ostringstream oss;
194+
oss << "[ " << shape[0];
195+
for (int i = 1; i < shape.size(); ++i) {
196+
oss << " ," << shape[i];
197+
}
198+
oss << " ]";
199+
return oss.str();
200+
}
190201

191202
} // namespace fastdeploy

tests/function/test_split.cc

Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "fastdeploy/core/fd_tensor.h"
16+
#include "fastdeploy/function/split.h"
17+
#include "glog/logging.h"
18+
#include "gtest_utils.h"
19+
#include "gtest/gtest.h"
20+
#include <array>
21+
#include <vector>
22+
23+
namespace fastdeploy {
24+
namespace function {
25+
26+
std::vector<float> CreateTestData() {
27+
// Shape: [2, 3, 4]
28+
std::vector<float> x_data = {
29+
0.8428625, 0.6461913, 0.13740455, 0.11430702, 0.659926, 0.535816,
30+
0.7429162, 0.8456049, 0.21228176, 0.29970083, 0.8621713, 0.40894133,
31+
0.12684688, 0.1566195, 0.42884097, 0.8476526, 0.2458633, 0.669046,
32+
0.87888306, 0.6762589, 0.666453, 0.32523027, 0.4139388, 0.8341406};
33+
return x_data;
34+
}
35+
36+
TEST(fastdeploy, split_axis0) {
37+
CheckShape check_shape;
38+
CheckData check_data;
39+
FDTensor x;
40+
std::vector<FDTensor> out;
41+
auto test_data = CreateTestData();
42+
x.SetExternalData({2, 3, 4}, FDDataType::FP32, test_data.data());
43+
44+
Split(x, {1, 1}, &out, 0);
45+
ASSERT_EQ(out.size(), 2);
46+
check_shape(out[0].Shape(), {1, 3, 4});
47+
check_shape(out[1].Shape(), {1, 3, 4});
48+
std::vector<float> result1 = {0.842862, 0.646191, 0.137405, 0.114307,
49+
0.659926, 0.535816, 0.742916, 0.845605,
50+
0.212282, 0.299701, 0.862171, 0.408941};
51+
std::vector<float> result2 = {0.126847, 0.15662, 0.428841, 0.847653,
52+
0.245863, 0.669046, 0.878883, 0.676259,
53+
0.666453, 0.32523, 0.413939, 0.834141};
54+
check_data(reinterpret_cast<const float*>(out[0].Data()), result1.data(),
55+
result1.size());
56+
check_data(reinterpret_cast<const float*>(out[1].Data()), result2.data(),
57+
result2.size());
58+
}
59+
60+
TEST(fastdeploy, split_axis1) {
61+
CheckShape check_shape;
62+
CheckData check_data;
63+
FDTensor x;
64+
std::vector<FDTensor> out;
65+
auto test_data = CreateTestData();
66+
x.SetExternalData({2, 3, 4}, FDDataType::FP32, test_data.data());
67+
68+
Split(x, {2, 1}, &out, 1);
69+
ASSERT_EQ(out.size(), 2);
70+
check_shape(out[0].Shape(), {2, 2, 4});
71+
check_shape(out[1].Shape(), {2, 1, 4});
72+
std::vector<float> result1 = {0.842862, 0.646191, 0.137405, 0.114307,
73+
0.659926, 0.535816, 0.742916, 0.845605,
74+
0.126847, 0.15662, 0.428841, 0.847653,
75+
0.245863, 0.669046, 0.878883, 0.676259};
76+
std::vector<float> result2 = {0.212282, 0.299701, 0.862171, 0.408941,
77+
0.666453, 0.32523, 0.413939, 0.834141};
78+
check_data(reinterpret_cast<const float*>(out[0].Data()), result1.data(),
79+
result1.size());
80+
check_data(reinterpret_cast<const float*>(out[1].Data()), result2.data(),
81+
result2.size());
82+
}
83+
84+
TEST(fastdeploy, split_axis2) {
85+
CheckShape check_shape;
86+
CheckData check_data;
87+
FDTensor x;
88+
std::vector<FDTensor> out;
89+
auto test_data = CreateTestData();
90+
x.SetExternalData({2, 3, 4}, FDDataType::FP32, test_data.data());
91+
92+
Split(x, {1, 2, 1}, &out, 2);
93+
ASSERT_EQ(out.size(), 3);
94+
check_shape(out[0].Shape(), {2, 3, 1});
95+
check_shape(out[1].Shape(), {2, 3, 2});
96+
check_shape(out[2].Shape(), {2, 3, 1});
97+
std::vector<float> result1 = {0.842862, 0.659926, 0.212282,
98+
0.126847, 0.245863, 0.666453};
99+
std::vector<float> result2 = {0.646191, 0.137405, 0.535816, 0.742916,
100+
0.299701, 0.862171, 0.15662, 0.428841,
101+
0.669046, 0.878883, 0.32523, 0.413939};
102+
std::vector<float> result3 = {0.114307, 0.845605, 0.408941,
103+
0.847653, 0.676259, 0.834141};
104+
check_data(reinterpret_cast<const float*>(out[0].Data()), result1.data(),
105+
result1.size());
106+
check_data(reinterpret_cast<const float*>(out[1].Data()), result2.data(),
107+
result2.size());
108+
check_data(reinterpret_cast<const float*>(out[2].Data()), result3.data(),
109+
result3.size());
110+
}
111+
112+
} // namespace function
113+
} // namespace fastdeploy

0 commit comments

Comments
 (0)