Daniel Konegen / MNIST_example

Dependencies:   mbed-os

Embed: (wiki syntax)

« Back to documentation index

Show/hide line numbers
elementwise.cc Source File

elementwise.cc

00001 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
00002 
00003 Licensed under the Apache License, Version 2.0 (the "License");
00004 you may not use this file except in compliance with the License.
00005 You may obtain a copy of the License at
00006 
00007     http://www.apache.org/licenses/LICENSE-2.0
00008 
00009 Unless required by applicable law or agreed to in writing, software
00010 distributed under the License is distributed on an "AS IS" BASIS,
00011 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
00012 See the License for the specific language governing permissions and
00013 limitations under the License.
00014 ==============================================================================*/
00015 
00016 #include <cmath>
00017 
00018 #include "tensorflow/lite/c/c_api_internal.h"
00019 #include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
00020 #include "tensorflow/lite/kernels/kernel_util.h"
00021 
00022 namespace tflite {
00023 namespace ops {
00024 namespace micro {
00025 namespace elementwise {
00026 namespace {
00027 
00028 bool IsNumericSupportedType(const TfLiteType type) {
00029   return type == kTfLiteFloat32;
00030 }
00031 
00032 bool IsLogicalSupportedType(const TfLiteType type) {
00033   return type == kTfLiteBool;
00034 }
00035 
00036 typedef bool (*IsSupportedType)(TfLiteType);
00037 template <IsSupportedType>
00038 TfLiteStatus GenericPrepare(TfLiteContext* context, TfLiteNode* node) {
00039   TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
00040   TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
00041   const TfLiteTensor* input = GetInput(context, node, 0);
00042   TfLiteTensor* output = GetOutput(context, node, 0);
00043   TF_LITE_ENSURE_EQ(context, input->type, output->type);
00044   if (!IsSupportedType(input->type)) {
00045     context->ReportError(context, "Input data type %s (%d) is not supported.",
00046                          TfLiteTypeGetName(input->type), input->type);
00047     return kTfLiteError;
00048   }
00049   return kTfLiteOk;
00050 }
00051 
00052 template <typename T>
00053 inline TfLiteStatus EvalImpl(TfLiteContext* context, TfLiteNode* node,
00054                              T func(T), TfLiteType expected_type) {
00055   const TfLiteTensor* input = GetInput(context, node, 0);
00056   TfLiteTensor* output = GetOutput(context, node, 0);
00057   TF_LITE_ENSURE_EQ(context, input->type, expected_type);
00058   const int64_t num_elements = NumElements(input);
00059   const T* in_data = GetTensorData<T>(input);
00060   T* out_data = GetTensorData<T>(output);
00061   for (int64_t i = 0; i < num_elements; ++i) {
00062     out_data[i] = func(in_data[i]);
00063   }
00064   return kTfLiteOk;
00065 }
00066 
00067 inline TfLiteStatus EvalNumeric(TfLiteContext* context, TfLiteNode* node,
00068                                 float float_func(float)) {
00069   return EvalImpl<float>(context, node, float_func, kTfLiteFloat32);
00070 }
00071 
00072 inline TfLiteStatus EvalLogical(TfLiteContext* context, TfLiteNode* node,
00073                                 bool bool_func(bool)) {
00074   return EvalImpl<bool>(context, node, bool_func, kTfLiteBool);
00075 }
00076 
00077 TfLiteStatus AbsEval(TfLiteContext* context, TfLiteNode* node) {
00078   return EvalNumeric(context, node, std::abs);
00079 }
00080 
00081 TfLiteStatus SinEval(TfLiteContext* context, TfLiteNode* node) {
00082   return EvalNumeric(context, node, std::sin);
00083 }
00084 
00085 TfLiteStatus CosEval(TfLiteContext* context, TfLiteNode* node) {
00086   return EvalNumeric(context, node, std::cos);
00087 }
00088 
00089 TfLiteStatus LogEval(TfLiteContext* context, TfLiteNode* node) {
00090   return EvalNumeric(context, node, std::log);
00091 }
00092 
00093 TfLiteStatus SqrtEval(TfLiteContext* context, TfLiteNode* node) {
00094   return EvalNumeric(context, node, std::sqrt);
00095 }
00096 
00097 TfLiteStatus RsqrtEval(TfLiteContext* context, TfLiteNode* node) {
00098   return EvalNumeric(context, node, [](float f) { return 1.f / std::sqrt(f); });
00099 }
00100 
00101 TfLiteStatus SquareEval(TfLiteContext* context, TfLiteNode* node) {
00102   return EvalNumeric(context, node, [](float f) { return f * f; });
00103 }
00104 
00105 TfLiteStatus LogicalNotEval(TfLiteContext* context, TfLiteNode* node) {
00106   return EvalLogical(context, node, [](bool v) { return !v; });
00107 }
00108 
00109 }  // namespace
00110 }  // namespace elementwise
00111 
00112 TfLiteRegistration* Register_ABS() {
00113   static TfLiteRegistration r = {
00114       /* init */ nullptr, /* free */ nullptr,
00115       elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
00116       elementwise::AbsEval};
00117   return &r;
00118 }
00119 
00120 TfLiteRegistration* Register_SIN() {
00121   static TfLiteRegistration r = {
00122       /* init */ nullptr, /* free */ nullptr,
00123       elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
00124       elementwise::SinEval};
00125   return &r;
00126 }
00127 
00128 TfLiteRegistration* Register_COS() {
00129   static TfLiteRegistration r = {
00130       /* init */ nullptr, /* free */ nullptr,
00131       elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
00132       elementwise::CosEval};
00133   return &r;
00134 }
00135 
00136 TfLiteRegistration* Register_LOG() {
00137   static TfLiteRegistration r = {
00138       /* init */ nullptr, /* free */ nullptr,
00139       elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
00140       elementwise::LogEval};
00141   return &r;
00142 }
00143 
00144 TfLiteRegistration* Register_SQRT() {
00145   static TfLiteRegistration r = {
00146       /* init */ nullptr, /* free */ nullptr,
00147       elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
00148       elementwise::SqrtEval};
00149   return &r;
00150 }
00151 
00152 TfLiteRegistration* Register_RSQRT() {
00153   static TfLiteRegistration r = {
00154       /* init */ nullptr, /* free */ nullptr,
00155       elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
00156       elementwise::RsqrtEval};
00157   return &r;
00158 }
00159 
00160 TfLiteRegistration* Register_SQUARE() {
00161   static TfLiteRegistration r = {
00162       /* init */ nullptr, /* free */ nullptr,
00163       elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
00164       elementwise::SquareEval};
00165   return &r;
00166 }
00167 
00168 TfLiteRegistration* Register_LOGICAL_NOT() {
00169   static TfLiteRegistration r = {
00170       /*init=*/nullptr, /*free=*/nullptr,
00171       elementwise::GenericPrepare<elementwise::IsLogicalSupportedType>,
00172       elementwise::LogicalNotEval};
00173   return &r;
00174 }
00175 
00176 }  // namespace micro
00177 }  // namespace ops
00178 }  // namespace tflite