Commit 9d35a74

Automated sync from github.com/tensorflow/tensorflow (#3100)
BUG=automated sync from upstream
NO_CHECK_TFLITE_FILES=automated sync from upstream
1 parent 15b55b7 commit 9d35a74

4 files changed: +18 / -15 lines

tensorflow/lite/kernels/internal/reference/floor.h

Lines changed: 5 additions & 3 deletions
@@ -23,13 +23,15 @@ namespace tflite {
 
 namespace reference_ops {
 
-inline void Floor(const RuntimeShape& input_shape, const float* input_data,
-                  const RuntimeShape& output_shape, float* output_data) {
+template <typename T>
+inline void Floor(const RuntimeShape& input_shape, const T* input_data,
+                  const RuntimeShape& output_shape, T* output_data) {
   const int flat_size = MatchingFlatSize(input_shape, output_shape);
 
   for (int i = 0; i < flat_size; i++) {
     int offset = i;
-    output_data[offset] = std::floor(input_data[offset]);
+    output_data[offset] =
+        static_cast<T>(std::floor(static_cast<float>(input_data[offset])));
   }
 }
 
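For context, a minimal usage sketch (not part of the commit): the now-templated Floor kernel called with plain float buffers, where T is deduced from the pointer arguments so existing float callers compile unchanged. It assumes the TFLite Micro source tree is on the include path; RuntimeShape and its initializer-list constructor come from tensorflow/lite/kernels/internal/types.h, which floor.h includes.

#include <cstdio>

#include "tensorflow/lite/kernels/internal/reference/floor.h"

int main() {
  // A 1x1x2x2 shape (4 elements); Floor only requires the flat sizes to match.
  const tflite::RuntimeShape shape({1, 1, 2, 2});
  const float input[4] = {1.7f, -0.2f, 3.0f, -2.5f};
  float output[4] = {};

  // T is deduced as float, so behavior matches the pre-template kernel.
  tflite::reference_ops::Floor(shape, input, shape, output);

  for (float v : output) std::printf("%g ", v);  // prints: 1 -1 3 -3
  std::printf("\n");
  return 0;
}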

tensorflow/lite/kernels/internal/reference/logistic.h

Lines changed: 5 additions & 4 deletions
@@ -27,8 +27,9 @@ limitations under the License.
 namespace tflite {
 namespace reference_ops {
 
-inline void Logistic(const RuntimeShape& input_shape, const float* input_data,
-                     const RuntimeShape& output_shape, float* output_data) {
+template <typename T>
+inline void Logistic(const RuntimeShape& input_shape, const T* input_data,
+                     const RuntimeShape& output_shape, T* output_data) {
   const float cutoff_upper = 16.619047164916992188f;
   const float cutoff_lower = -9.f;
 
@@ -43,7 +44,7 @@ inline void Logistic(const RuntimeShape& input_shape, const float* input_data,
   // optimized kernels. (check the definition of scalar_logistic_op<float>)
 
   for (int i = 0; i < flat_size; i++) {
-    float val = input_data[i];
+    T val = input_data[i];
     float result;
     if (val > cutoff_upper) {
       result = 1.0f;
@@ -52,7 +53,7 @@ inline void Logistic(const RuntimeShape& input_shape, const float* input_data,
     } else {
       result = 1.f / (1.f + std::exp(-val));
     }
-    output_data[i] = result;
+    output_data[i] = static_cast<T>(result);
   }
 }
 
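The cutoff logic around this hunk saturates the sigmoid: above cutoff_upper the float result rounds to exactly 1.0f, and below cutoff_lower a cheaper approximation is used; the commit keeps all of that math in float and only casts the final result back to T. Below is a standalone sketch of the same idea, not the TFLite code: the helper name SigmoidSketch is made up, and the lower-cutoff branch uses the e^x approximation as an assumption, since that branch sits outside this hunk.

#include <cmath>
#include <cstdio>

template <typename T>
T SigmoidSketch(T x) {
  const float cutoff_upper = 16.619047164916992188f;  // constants taken from the diff
  const float cutoff_lower = -9.f;
  const float val = static_cast<float>(x);
  float result;
  if (val > cutoff_upper) {
    result = 1.0f;  // 1/(1+e^-x) is 1.0f to float precision here
  } else if (val < cutoff_lower) {
    result = std::exp(val);  // e^x ~= 1/(1+e^-x) for very negative x (assumed branch)
  } else {
    result = 1.f / (1.f + std::exp(-val));
  }
  return static_cast<T>(result);  // cast back to the element type, as in the diff
}

int main() {
  std::printf("%g %g %g\n", SigmoidSketch(0.0f), SigmoidSketch(20.0f),
              SigmoidSketch(-20.0f));  // prints: 0.5 1 2.06115e-09
  return 0;
}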

tensorflow/lite/kernels/internal/reference/round.h

Lines changed: 4 additions & 3 deletions
@@ -34,15 +34,16 @@ inline float RoundToNearest(float value) {
   }
 }
 
-inline void Round(const RuntimeShape& input_shape, const float* input_data,
-                  const RuntimeShape& output_shape, float* output_data) {
+template <typename Scalar>
+inline void Round(const RuntimeShape& input_shape, const Scalar* input_data,
+                  const RuntimeShape& output_shape, Scalar* output_data) {
   const int flat_size = MatchingFlatSize(input_shape, output_shape);
   for (int i = 0; i < flat_size; ++i) {
     // Note that this implementation matches that of tensorFlow tf.round
     // and corresponds to the bankers rounding method.
     // cfenv (for fesetround) is not yet supported universally on Android, so
     // using a work around.
-    output_data[i] = RoundToNearest(input_data[i]);
+    output_data[i] = static_cast<Scalar>(RoundToNearest(input_data[i]));
   }
 }
 
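RoundToNearest, called above but not touched by this commit, implements the round-half-to-even ("bankers") rule that the in-code comment refers to, avoiding cfenv/fesetround. The following is a standalone illustration of that rule, not the TFLite implementation (RoundHalfToEvenSketch is a made-up name): ties go to the nearest even integer, matching tf.round.

#include <cmath>
#include <cstdio>

float RoundHalfToEvenSketch(float value) {
  const float floor_val = std::floor(value);
  const float diff = value - floor_val;
  if (diff < 0.5f) return floor_val;
  if (diff > 0.5f) return floor_val + 1.0f;
  // Exact tie: pick whichever neighboring integer is even.
  return (std::fmod(floor_val, 2.0f) == 0.0f) ? floor_val : floor_val + 1.0f;
}

int main() {
  std::printf("%g %g %g %g\n", RoundHalfToEvenSketch(0.5f),
              RoundHalfToEvenSketch(1.5f), RoundHalfToEvenSketch(2.5f),
              RoundHalfToEvenSketch(-1.5f));  // prints: 0 2 2 -2
  return 0;
}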

tensorflow/lite/kernels/internal/reference/tanh.h

Lines changed: 4 additions & 5 deletions
@@ -26,14 +26,13 @@ limitations under the License.
 namespace tflite {
 namespace reference_ops {
 
-inline void Tanh(const RuntimeShape& input_shape, const float* input_data,
-                 const RuntimeShape& output_shape, float* output_data) {
+template <typename T>
+inline void Tanh(const RuntimeShape& input_shape, const T* input_data,
+                 const RuntimeShape& output_shape, T* output_data) {
   const int flat_size = MatchingFlatSize(input_shape, output_shape);
 
   for (int i = 0; i < flat_size; i++) {
-    float val = input_data[i];
-    float result = std::tanh(val);
-    output_data[i] = result;
+    output_data[i] = static_cast<T>(std::tanh(input_data[i]));
   }
 }
 
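Taken together, the four kernels now follow one pattern: template on the element type T, do the arithmetic in floating point, and cast the result back to T when storing. A generic standalone sketch of that elementwise loop (ElementwiseSketch is a made-up helper, not TFLite code), exercised here with std::tanh over float data:

#include <cmath>
#include <cstdio>

// Mirror of the templated reference-kernel loop: walk the flat buffer,
// compute in float, cast the result back to the element type T.
template <typename T, typename F>
void ElementwiseSketch(const T* input, T* output, int flat_size, F op) {
  for (int i = 0; i < flat_size; ++i) {
    output[i] = static_cast<T>(op(static_cast<float>(input[i])));
  }
}

int main() {
  const float input[4] = {-2.0f, -0.5f, 0.5f, 2.0f};
  float output[4] = {};
  ElementwiseSketch(input, output, 4, [](float v) { return std::tanh(v); });
  for (float v : output) std::printf("%g ", v);  // ~ -0.964028 -0.462117 0.462117 0.964028
  std::printf("\n");
  return 0;
}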

0 commit comments
