Daniel Konegen / MNIST_example

Dependencies:   mbed-os

Embed: (wiki syntax)

« Back to documentation index

Show/hide line numbers maximum_minimum.cc Source File

maximum_minimum.cc

00001 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
00002 
00003 Licensed under the Apache License, Version 2.0 (the "License");
00004 you may not use this file except in compliance with the License.
00005 You may obtain a copy of the License at
00006 
00007     http://www.apache.org/licenses/LICENSE-2.0
00008 
00009 Unless required by applicable law or agreed to in writing, software
00010 distributed under the License is distributed on an "AS IS" BASIS,
00011 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
00012 See the License for the specific language governing permissions and
00013 limitations under the License.
00014 ==============================================================================*/
00015 
00016 #include "tensorflow/lite/kernels/internal/reference/maximum_minimum.h"
00017 
00018 #include "tensorflow/lite/c/builtin_op_data.h"
00019 #include "tensorflow/lite/c/c_api_internal.h"
00020 #include "tensorflow/lite/kernels/internal/common.h"
00021 #include "tensorflow/lite/kernels/internal/quantization_util.h"
00022 #include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
00023 #include "tensorflow/lite/kernels/kernel_util.h"
00024 #include "tensorflow/lite/kernels/op_macros.h"
00025 
00026 namespace tflite {
00027 namespace ops {
00028 namespace micro {
00029 namespace maximum_minimum {
00030 namespace {
00031 
// Implementation variants for TFMaximum/TFMinimum. Only the reference
// implementation (built on reference_ops) exists in this file.
enum KernelType { kReference = 0 };

// Positions of the tensors within the node's input and output index lists.
constexpr int kInputTensor1{0};  // first operand
constexpr int kInputTensor2{1};  // second operand
constexpr int kOutputTensor{0};  // element-wise result
00040 
00041 struct OpContext {
00042   OpContext(TfLiteContext* context, TfLiteNode* node) {
00043     input1 = GetInput(context, node, kInputTensor1);
00044     input2 = GetInput(context, node, kInputTensor2);
00045     output = GetOutput(context, node, kOutputTensor);
00046   }
00047   const TfLiteTensor* input1;
00048   const TfLiteTensor* input2;
00049   TfLiteTensor* output;
00050 };
00051 
// Binary functor selecting the larger of two elements.
// When the comparison is false (equal values, or either operand is NaN for
// floating-point types) the second argument is returned.
struct MaximumOp {
  template <typename T>
  static T op(T el1, T el2) {
    if (el1 > el2) {
      return el1;
    }
    return el2;
  }
};
00058 
// Binary functor selecting the smaller of two elements.
// When the comparison is false (equal values, or either operand is NaN for
// floating-point types) the second argument is returned.
struct MinimumOp {
  template <typename T>
  static T op(T el1, T el2) {
    if (el1 < el2) {
      return el1;
    }
    return el2;
  }
};
00065 
00066 }  // namespace
00067 
00068 template <typename data_type, typename op_type>
00069 void TFLiteOperation(TfLiteContext* context, TfLiteNode* node,
00070                      const OpContext& op_context) {
00071   reference_ops::MaximumMinimumBroadcast4DSlow(
00072       GetTensorShape(op_context.input1),
00073       GetTensorData<data_type>(op_context.input1),
00074       GetTensorShape(op_context.input2),
00075       GetTensorData<data_type>(op_context.input2),
00076       GetTensorShape(op_context.output),
00077       GetTensorData<data_type>(op_context.output),
00078       op_type::template op<data_type>);
00079 }
00080 
00081 template <KernelType kernel_type, typename OpType>
00082 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
00083   OpContext op_context(context, node);
00084 
00085   if (kernel_type == kReference) {
00086     switch (op_context.output->type) {
00087       case kTfLiteFloat32:
00088         TFLiteOperation<float, OpType>(context, node, op_context);
00089         break;
00090       case kTfLiteUInt8:
00091         TFLiteOperation<uint8_t, OpType>(context, node, op_context);
00092         break;
00093       case kTfLiteInt8:
00094         TFLiteOperation<int8_t, OpType>(context, node, op_context);
00095         break;
00096       case kTfLiteInt32:
00097         TFLiteOperation<int32_t, OpType>(context, node, op_context);
00098         break;
00099       case kTfLiteInt64:
00100         TFLiteOperation<int64_t, OpType>(context, node, op_context);
00101         break;
00102       default:
00103         context->ReportError(
00104             context, "Type %s (%d) is not supported by Maximum/Minimum.",
00105             TfLiteTypeGetName(op_context.output->type),
00106             op_context.output->type);
00107         return kTfLiteError;
00108     }
00109   } else {
00110     context->ReportError(context,
00111                          "Kernel type not supported by Maximum/Minimum.");
00112     return kTfLiteError;
00113   }
00114   return kTfLiteOk;
00115 }
00116 
00117 }  // namespace maximum_minimum
00118 
00119 TfLiteRegistration* Register_MAXIMUM() {
00120   static TfLiteRegistration r = {
00121       /* init */ nullptr,
00122       /* free */ nullptr,
00123       /* prepare */ nullptr,
00124       maximum_minimum::Eval<maximum_minimum::kReference,
00125                             maximum_minimum::MaximumOp>};
00126   return &r;
00127 }
00128 
00129 TfLiteRegistration* Register_MINIMUM() {
00130   static TfLiteRegistration r = {
00131       /* init */ nullptr,
00132       /* free */ nullptr,
00133       /* prepare */ nullptr,
00134       maximum_minimum::Eval<maximum_minimum::kReference,
00135                             maximum_minimum::MinimumOp>};
00136   return &r;
00137 }
00138 
00139 }  // namespace micro
00140 }  // namespace ops
00141 }  // namespace tflite