Skip to content

Commit 12af6b4

Browse files
committed
Add abs functions
1 parent a14c550 commit 12af6b4

File tree

4 files changed

+34
-0
lines changed

4 files changed

+34
-0
lines changed

fastdeploy/function/math.cc

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ DEFINE_ACTIVATION_KERNEL(Sqrt, SqrtFunctor)
3939
DEFINE_ACTIVATION_KERNEL(Log, LogFunctor)
4040
DEFINE_ACTIVATION_KERNEL(Round, RoundFunctor)
4141
DEFINE_ACTIVATION_KERNEL(Exp, ExpFunctor)
42+
DEFINE_ACTIVATION_KERNEL(Abs, AbsFunctor)
4243

4344
void Sqrt(const FDTensor& x, FDTensor* out) {
4445
FD_VISIT_FLOAT_TYPES(x.dtype, "SqrtKernel",
@@ -60,5 +61,10 @@ void Exp(const FDTensor& x, FDTensor* out) {
6061
([&] { ExpKernel<data_t>(x, out); }));
6162
}
6263

64+
void Abs(const FDTensor& x, FDTensor* out) {
65+
FD_VISIT_FLOAT_TYPES(x.dtype, "AbsKernel",
66+
([&] { AbsKernel<data_t>(x, out); }));
67+
}
68+
6369
} // namespace function
6470
} // namespace fastdeploy

fastdeploy/function/math.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,5 +43,11 @@ FASTDEPLOY_DECL void Round(const FDTensor& x, FDTensor* out);
4343
*/
4444
FASTDEPLOY_DECL void Exp(const FDTensor& x, FDTensor* out);
4545

46+
/** This operator is used to perform elementwise abs for input X. Only for float type FDTensor
47+
@param x The input tensor.
48+
@param out The output tensor which stores the result.
49+
*/
50+
FASTDEPLOY_DECL void Abs(const FDTensor& x, FDTensor* out);
51+
4652
} // namespace function
4753
} // namespace fastdeploy

fastdeploy/function/math_functor.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,5 +52,14 @@ template <typename T> struct SqrtFunctor {
5252
}
5353
};
5454

55+
// abs(x) = x if x > 0 else -x
56+
template <typename T> struct AbsFunctor {
57+
template <typename Device, typename X, typename Out>
58+
void operator()(Device d, X x, Out out) const {
59+
out.device(d) =
60+
x.unaryExpr([](T v) { return v > static_cast<T>(0) ? v : -v; });
61+
}
62+
};
63+
5564
} // namespace function
5665
} // namespace fastdeploy

tests/function/test_math.cc

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,5 +83,18 @@ TEST(fastdeploy, exp_sqrt_round_log) {
8383
log_result.size());
8484
}
8585

86+
TEST(fastdeploy, abs) {
87+
CheckShape check_shape;
88+
CheckData check_data;
89+
FDTensor x, y;
90+
std::vector<float> test_data = {-1, 2, 3, -5, -4, -6};
91+
x.SetExternalData({2, 3}, FDDataType::FP32, test_data.data());
92+
std::vector<float> result = {1, 2, 3, 5, 4, 6};
93+
Abs(x, &y);
94+
check_shape(y.shape, {2, 3});
95+
check_data(reinterpret_cast<const float*>(y.Data()), result.data(),
96+
result.size());
97+
}
98+
8699
} // namespace function
87100
} // namespace fastdeploy

0 commit comments

Comments
 (0)