Daniel Konegen / MNIST_example

Dependencies:   mbed-os

Embed: (wiki syntax)

« Back to documentation index

arg_min_max.cc Source File

arg_min_max.cc

00001 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
00002 
00003 Licensed under the Apache License, Version 2.0 (the "License");
00004 you may not use this file except in compliance with the License.
00005 You may obtain a copy of the License at
00006 
00007     http://www.apache.org/licenses/LICENSE-2.0
00008 
00009 Unless required by applicable law or agreed to in writing, software
00010 distributed under the License is distributed on an "AS IS" BASIS,
00011 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
00012 See the License for the specific language governing permissions and
00013 limitations under the License.
00014 ==============================================================================*/
00015 
00016 #include "tensorflow/lite/kernels/internal/reference/arg_min_max.h"
00017 
00018 #include "tensorflow/lite/c/builtin_op_data.h"
00019 #include "tensorflow/lite/c/c_api_internal.h"
00020 #include "tensorflow/lite/experimental/micro/kernels/micro_utils.h"
00021 #include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
00022 #include "tensorflow/lite/kernels/kernel_util.h"
00023 
00024 namespace tflite {
00025 namespace ops {
00026 namespace micro {
00027 namespace arg_min_max {
00028 
// Node tensor slots used by ARG_MAX / ARG_MIN.
// Input slots:
constexpr int kInputTensor{0};  // values searched for the min/max element
constexpr int kAxis{1};         // axis tensor passed to the reference kernel
// Output slots:
constexpr int kOutputTensor{0};  // receives the selected indices
00032 
00033 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
00034   return kTfLiteOk;
00035 }
00036 
00037 template <typename T1, typename T2, typename T3>
00038 inline void ArgMinMaxHelper(const RuntimeShape& input1_shape,
00039                             const T1* input1_data, const T3* input2_data,
00040                             const RuntimeShape& output_shape, T2* output_data,
00041                             bool is_arg_max) {
00042   if (is_arg_max) {
00043     reference_ops::ArgMinMax(input1_shape, input1_data, input2_data,
00044                              output_shape, output_data, micro::Greater());
00045   } else {
00046     reference_ops::ArgMinMax(input1_shape, input1_data, input2_data,
00047                              output_shape, output_data, micro::Less());
00048   }
00049 }
00050 
00051 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node, bool is_arg_max) {
00052   const TfLiteTensor* input = GetInput(context, node, kInputTensor);
00053   const TfLiteTensor* axis = GetInput(context, node, kAxis);
00054   TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
00055 
00056 #define TF_LITE_ARG_MIN_MAX(data_type, axis_type, output_type)            \
00057   ArgMinMaxHelper(GetTensorShape(input), GetTensorData<data_type>(input), \
00058                   GetTensorData<axis_type>(axis), GetTensorShape(output), \
00059                   GetTensorData<output_type>(output), is_arg_max)
00060   if (axis->type == kTfLiteInt32) {
00061     if (output->type == kTfLiteInt32) {
00062       switch (input->type) {
00063         case kTfLiteFloat32:
00064           TF_LITE_ARG_MIN_MAX(float, int32_t, int32_t);
00065           break;
00066         case kTfLiteUInt8:
00067           TF_LITE_ARG_MIN_MAX(uint8_t, int32_t, int32_t);
00068           break;
00069         case kTfLiteInt8:
00070           TF_LITE_ARG_MIN_MAX(int8_t, int32_t, int32_t);
00071           break;
00072         default:
00073           context->ReportError(context,
00074                                "Only float32, uint8 and int8 are "
00075                                "supported currently, got %s.",
00076                                TfLiteTypeGetName(input->type));
00077           return kTfLiteError;
00078       }
00079     } else {
00080       context->ReportError(context,
00081                            "Only int32 are supported currently, got %s.",
00082                            TfLiteTypeGetName(output->type));
00083       return kTfLiteError;
00084     }
00085   } else {
00086     context->ReportError(context, "Only int32 are supported currently, got %s.",
00087                          TfLiteTypeGetName(axis->type));
00088     return kTfLiteError;
00089   }
00090 
00091 #undef TF_LITE_ARG_MIN_MAX
00092 
00093   return kTfLiteOk;
00094 }
00095 
00096 TfLiteStatus ArgMinEval(TfLiteContext* context, TfLiteNode* node) {
00097   return Eval(context, node, false);
00098 }
00099 
00100 TfLiteStatus ArgMaxEval(TfLiteContext* context, TfLiteNode* node) {
00101   return Eval(context, node, true);
00102 }
00103 
00104 }  // namespace arg_min_max
00105 
00106 TfLiteRegistration* Register_ARG_MAX() {
00107   static TfLiteRegistration r = {nullptr, nullptr, arg_min_max::Prepare,
00108                                  arg_min_max::ArgMaxEval};
00109   return &r;
00110 }
00111 
00112 TfLiteRegistration* Register_ARG_MIN() {
00113   static TfLiteRegistration r = {nullptr, nullptr, arg_min_max::Prepare,
00114                                  arg_min_max::ArgMinEval};
00115   return &r;
00116 }
00117 
00118 }  // namespace micro
00119 }  // namespace ops
00120 }  // namespace tflite