Daniel Konegen / MNIST_example

Dependencies:   mbed-os

Embed: (wiki syntax)

« Back to documentation index

Show/hide line numbers arg_min_max.h Source File

arg_min_max.h

00001 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
00002 
00003 Licensed under the Apache License, Version 2.0 (the "License");
00004 you may not use this file except in compliance with the License.
00005 You may obtain a copy of the License at
00006 
00007     http://www.apache.org/licenses/LICENSE-2.0
00008 
00009 Unless required by applicable law or agreed to in writing, software
00010 distributed under the License is distributed on an "AS IS" BASIS,
00011 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
00012 See the License for the specific language governing permissions and
00013 limitations under the License.
00014 ==============================================================================*/
00015 #ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_ARG_MIN_MAX_H_
00016 #define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_ARG_MIN_MAX_H_
00017 
00018 #include "tensorflow/lite/kernels/internal/types.h"
00019 
00020 namespace tflite {
00021 
00022 namespace reference_ops {
00023 
00024 template <typename T1, typename T2, typename T3, typename Cmp>
00025 void ArgMinMax(const RuntimeShape& input1_shape, const T1* input1_data,
00026                const T3* input2_data, const RuntimeShape& output_shape,
00027                T2* output_data, const Cmp& cmp) {
00028   TFLITE_DCHECK_GT(input1_shape.DimensionsCount(), 0);
00029   TFLITE_DCHECK_EQ(input1_shape.DimensionsCount() - 1,
00030                    output_shape.DimensionsCount());
00031   int axis = input2_data[0];
00032   if (axis < 0) {
00033     axis += input1_shape.DimensionsCount();
00034   }
00035   const int axis_size = input1_shape.Dims(axis);
00036 
00037   int outer_size = 1;
00038   for (int i = 0; i < axis; ++i) {
00039     TFLITE_DCHECK_EQ(input1_shape.Dims(i), output_shape.Dims(i));
00040     outer_size *= input1_shape.Dims(i);
00041   }
00042 
00043   int inner_size = 1;
00044   const int dims_count = input1_shape.DimensionsCount();
00045   for (int i = axis + 1; i < dims_count; ++i) {
00046     TFLITE_DCHECK_EQ(input1_shape.Dims(i), output_shape.Dims(i - 1));
00047     inner_size *= input1_shape.Dims(i);
00048   }
00049   for (int outer = 0; outer < outer_size; ++outer) {
00050     for (int inner = 0; inner < inner_size; ++inner) {
00051       auto min_max_value = input1_data[outer * axis_size * inner_size + inner];
00052       T2 min_max_index = 0;
00053       for (int i = 1; i < axis_size; ++i) {
00054         const auto& curr_value =
00055             input1_data[(outer * axis_size + i) * inner_size + inner];
00056         if (cmp(curr_value, min_max_value)) {
00057           min_max_value = curr_value;
00058           min_max_index = static_cast<T2>(i);
00059         }
00060       }
00061       output_data[outer * inner_size + inner] = min_max_index;
00062     }
00063   }
00064 }
00065 }  // namespace reference_ops
00066 }  // namespace tflite
00067 
00068 #endif  // TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_ARG_MIN_MAX_H_