Important changes to repositories hosted on mbed.com
Mbed-hosted Mercurial repositories are deprecated and are due to be permanently deleted in July 2026.
To keep a copy of this software, download the repository ZIP archive or clone it locally using Mercurial.
It is also possible to export all your personal repositories from the account settings page.
arg_min_max.h
00001 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. 00002 00003 Licensed under the Apache License, Version 2.0 (the "License"); 00004 you may not use this file except in compliance with the License. 00005 You may obtain a copy of the License at 00006 00007 http://www.apache.org/licenses/LICENSE-2.0 00008 00009 Unless required by applicable law or agreed to in writing, software 00010 distributed under the License is distributed on an "AS IS" BASIS, 00011 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 00012 See the License for the specific language governing permissions and 00013 limitations under the License. 00014 ==============================================================================*/ 00015 #ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_ARG_MIN_MAX_H_ 00016 #define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_ARG_MIN_MAX_H_ 00017 00018 #include "tensorflow/lite/kernels/internal/types.h" 00019 00020 namespace tflite { 00021 00022 namespace reference_ops { 00023 00024 template <typename T1, typename T2, typename T3, typename Cmp> 00025 void ArgMinMax(const RuntimeShape& input1_shape, const T1* input1_data, 00026 const T3* input2_data, const RuntimeShape& output_shape, 00027 T2* output_data, const Cmp& cmp) { 00028 TFLITE_DCHECK_GT(input1_shape.DimensionsCount(), 0); 00029 TFLITE_DCHECK_EQ(input1_shape.DimensionsCount() - 1, 00030 output_shape.DimensionsCount()); 00031 int axis = input2_data[0]; 00032 if (axis < 0) { 00033 axis += input1_shape.DimensionsCount(); 00034 } 00035 const int axis_size = input1_shape.Dims(axis); 00036 00037 int outer_size = 1; 00038 for (int i = 0; i < axis; ++i) { 00039 TFLITE_DCHECK_EQ(input1_shape.Dims(i), output_shape.Dims(i)); 00040 outer_size *= input1_shape.Dims(i); 00041 } 00042 00043 int inner_size = 1; 00044 const int dims_count = input1_shape.DimensionsCount(); 00045 for (int i = axis + 1; i < dims_count; ++i) { 00046 TFLITE_DCHECK_EQ(input1_shape.Dims(i), output_shape.Dims(i - 
1)); 00047 inner_size *= input1_shape.Dims(i); 00048 } 00049 for (int outer = 0; outer < outer_size; ++outer) { 00050 for (int inner = 0; inner < inner_size; ++inner) { 00051 auto min_max_value = input1_data[outer * axis_size * inner_size + inner]; 00052 T2 min_max_index = 0; 00053 for (int i = 1; i < axis_size; ++i) { 00054 const auto& curr_value = 00055 input1_data[(outer * axis_size + i) * inner_size + inner]; 00056 if (cmp(curr_value, min_max_value)) { 00057 min_max_value = curr_value; 00058 min_max_index = static_cast<T2>(i); 00059 } 00060 } 00061 output_data[outer * inner_size + inner] = min_max_index; 00062 } 00063 } 00064 } 00065 } // namespace reference_ops 00066 } // namespace tflite 00067 00068 #endif // TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_ARG_MIN_MAX_H_
Generated on Wed Jul 13 2022 16:03:34 by
