Daniel Konegen / MNIST_example

Dependencies:   mbed-os

Embed: (wiki syntax)

« Back to documentation index

Show/hide line numbers round.h Source File

round.h

00001 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
00002 
00003 Licensed under the Apache License, Version 2.0 (the "License");
00004 you may not use this file except in compliance with the License.
00005 You may obtain a copy of the License at
00006 
00007     http://www.apache.org/licenses/LICENSE-2.0
00008 
00009 Unless required by applicable law or agreed to in writing, software
00010 distributed under the License is distributed on an "AS IS" BASIS,
00011 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
00012 See the License for the specific language governing permissions and
00013 limitations under the License.
00014 ==============================================================================*/
00015 #ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_ROUND_H_
00016 #define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_ROUND_H_
00017 
00018 #include <cmath>
00019 
00020 #include "tensorflow/lite/kernels/internal/types.h"
00021 
00022 namespace tflite {
00023 
00024 namespace reference_ops {
00025 
// Rounds `value` to the nearest integer, resolving exact .5 ties toward the
// even neighbour (round-half-to-even, a.k.a. "banker's rounding"), which is
// the rounding TensorFlow's tf.round performs. It is implemented by hand so it
// does not depend on the floating-point environment's current rounding mode.
//
// Special values: NaN stays NaN and +/-infinity is returned unchanged. A zero
// result keeps the sign of the input (e.g. -0.3f -> -0.0f), as rint-style
// rounding does.
inline float RoundToNearest(float value) {
  const float floor_val = std::floor(value);
  // Fractional part, in [0, 1) for finite inputs (NaN for NaN/infinity, which
  // then falls through to floor_val + 1.0f and stays NaN/infinite).
  const float diff = value - floor_val;
  // Parity is tested in floating point so the tie branch does not depend on
  // the width of `int`; short-circuiting keeps fmod off the common path.
  const bool round_down =
      diff < 0.5f || (diff == 0.5f && std::fmod(floor_val, 2.0f) == 0.0f);
  const float rounded = round_down ? floor_val : floor_val + 1.0f;
  // Rounding never produces a result of the opposite sign, so copysign only
  // changes zero results, restoring -0.0f for small negative inputs.
  return std::copysign(rounded, value);
}
00036 
00037 inline void Round(const RuntimeShape& input_shape, const float* input_data,
00038                   const RuntimeShape& output_shape, float* output_data) {
00039   const int flat_size = MatchingFlatSize(input_shape, output_shape);
00040   for (int i = 0; i < flat_size; ++i) {
00041     // Note that this implementation matches that of tensorFlow tf.round
00042     // and corresponds to the bankers rounding method.
00043     // cfenv (for fesetround) is not yet supported universally on Android, so
00044     // using a work around.
00045     output_data[i] = RoundToNearest(input_data[i]);
00046   }
00047 }
00048 
00049 }  // namespace reference_ops
00050 }  // namespace tflite
00051 #endif  // TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_ROUND_H_