Important changes to repositories hosted on mbed.com
Mbed-hosted Mercurial repositories are deprecated and are due to be permanently deleted in July 2026.
To keep a copy of this software, download the repository Zip archive or clone it locally using Mercurial.
It is also possible to export all your personal repositories from the account settings page.
arg_min_max.cc
00001 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 00002 00003 Licensed under the Apache License, Version 2.0 (the "License"); 00004 you may not use this file except in compliance with the License. 00005 You may obtain a copy of the License at 00006 00007 http://www.apache.org/licenses/LICENSE-2.0 00008 00009 Unless required by applicable law or agreed to in writing, software 00010 distributed under the License is distributed on an "AS IS" BASIS, 00011 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 00012 See the License for the specific language governing permissions and 00013 limitations under the License. 00014 ==============================================================================*/ 00015 00016 #include "tensorflow/lite/kernels/internal/reference/arg_min_max.h" 00017 00018 #include "tensorflow/lite/c/builtin_op_data.h" 00019 #include "tensorflow/lite/c/c_api_internal.h" 00020 #include "tensorflow/lite/experimental/micro/kernels/micro_utils.h" 00021 #include "tensorflow/lite/kernels/internal/tensor_ctypes.h" 00022 #include "tensorflow/lite/kernels/kernel_util.h" 00023 00024 namespace tflite { 00025 namespace ops { 00026 namespace micro { 00027 namespace arg_min_max { 00028 00029 constexpr int kInputTensor = 0; 00030 constexpr int kAxis = 1; 00031 constexpr int kOutputTensor = 0; 00032 00033 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { 00034 return kTfLiteOk; 00035 } 00036 00037 template <typename T1, typename T2, typename T3> 00038 inline void ArgMinMaxHelper(const RuntimeShape& input1_shape, 00039 const T1* input1_data, const T3* input2_data, 00040 const RuntimeShape& output_shape, T2* output_data, 00041 bool is_arg_max) { 00042 if (is_arg_max) { 00043 reference_ops::ArgMinMax(input1_shape, input1_data, input2_data, 00044 output_shape, output_data, micro::Greater()); 00045 } else { 00046 reference_ops::ArgMinMax(input1_shape, input1_data, input2_data, 00047 output_shape, output_data, 
micro::Less()); 00048 } 00049 } 00050 00051 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node, bool is_arg_max) { 00052 const TfLiteTensor* input = GetInput(context, node, kInputTensor); 00053 const TfLiteTensor* axis = GetInput(context, node, kAxis); 00054 TfLiteTensor* output = GetOutput(context, node, kOutputTensor); 00055 00056 #define TF_LITE_ARG_MIN_MAX(data_type, axis_type, output_type) \ 00057 ArgMinMaxHelper(GetTensorShape(input), GetTensorData<data_type>(input), \ 00058 GetTensorData<axis_type>(axis), GetTensorShape(output), \ 00059 GetTensorData<output_type>(output), is_arg_max) 00060 if (axis->type == kTfLiteInt32) { 00061 if (output->type == kTfLiteInt32) { 00062 switch (input->type) { 00063 case kTfLiteFloat32: 00064 TF_LITE_ARG_MIN_MAX(float, int32_t, int32_t); 00065 break; 00066 case kTfLiteUInt8: 00067 TF_LITE_ARG_MIN_MAX(uint8_t, int32_t, int32_t); 00068 break; 00069 case kTfLiteInt8: 00070 TF_LITE_ARG_MIN_MAX(int8_t, int32_t, int32_t); 00071 break; 00072 default: 00073 context->ReportError(context, 00074 "Only float32, uint8 and int8 are " 00075 "supported currently, got %s.", 00076 TfLiteTypeGetName(input->type)); 00077 return kTfLiteError; 00078 } 00079 } else { 00080 context->ReportError(context, 00081 "Only int32 are supported currently, got %s.", 00082 TfLiteTypeGetName(output->type)); 00083 return kTfLiteError; 00084 } 00085 } else { 00086 context->ReportError(context, "Only int32 are supported currently, got %s.", 00087 TfLiteTypeGetName(axis->type)); 00088 return kTfLiteError; 00089 } 00090 00091 #undef TF_LITE_ARG_MIN_MAX 00092 00093 return kTfLiteOk; 00094 } 00095 00096 TfLiteStatus ArgMinEval(TfLiteContext* context, TfLiteNode* node) { 00097 return Eval(context, node, false); 00098 } 00099 00100 TfLiteStatus ArgMaxEval(TfLiteContext* context, TfLiteNode* node) { 00101 return Eval(context, node, true); 00102 } 00103 00104 } // namespace arg_min_max 00105 00106 TfLiteRegistration* Register_ARG_MAX() { 00107 static 
TfLiteRegistration r = {nullptr, nullptr, arg_min_max::Prepare, 00108 arg_min_max::ArgMaxEval}; 00109 return &r; 00110 } 00111 00112 TfLiteRegistration* Register_ARG_MIN() { 00113 static TfLiteRegistration r = {nullptr, nullptr, arg_min_max::Prepare, 00114 arg_min_max::ArgMinEval}; 00115 return &r; 00116 } 00117 00118 } // namespace micro 00119 } // namespace ops 00120 } // namespace tflite
Generated on Wed Jul 13 2022 16:03:34 by
1.7.2