Important changes to repositories hosted on mbed.com
Mbed-hosted Mercurial repositories are deprecated and are due to be permanently deleted in July 2026.
To keep a copy of this software, download the repository Zip archive or clone it locally using Mercurial.
It is also possible to export all your personal repositories from the account settings page.
elementwise.cc
00001 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. 00002 00003 Licensed under the Apache License, Version 2.0 (the "License"); 00004 you may not use this file except in compliance with the License. 00005 You may obtain a copy of the License at 00006 00007 http://www.apache.org/licenses/LICENSE-2.0 00008 00009 Unless required by applicable law or agreed to in writing, software 00010 distributed under the License is distributed on an "AS IS" BASIS, 00011 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 00012 See the License for the specific language governing permissions and 00013 limitations under the License. 00014 ==============================================================================*/ 00015 00016 #include <cmath> 00017 00018 #include "tensorflow/lite/c/c_api_internal.h" 00019 #include "tensorflow/lite/kernels/internal/tensor_ctypes.h" 00020 #include "tensorflow/lite/kernels/kernel_util.h" 00021 00022 namespace tflite { 00023 namespace ops { 00024 namespace micro { 00025 namespace elementwise { 00026 namespace { 00027 00028 bool IsNumericSupportedType(const TfLiteType type) { 00029 return type == kTfLiteFloat32; 00030 } 00031 00032 bool IsLogicalSupportedType(const TfLiteType type) { 00033 return type == kTfLiteBool; 00034 } 00035 00036 typedef bool (*IsSupportedType)(TfLiteType); 00037 template <IsSupportedType> 00038 TfLiteStatus GenericPrepare(TfLiteContext* context, TfLiteNode* node) { 00039 TF_LITE_ENSURE_EQ(context, NumInputs(node), 1); 00040 TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); 00041 const TfLiteTensor* input = GetInput(context, node, 0); 00042 TfLiteTensor* output = GetOutput(context, node, 0); 00043 TF_LITE_ENSURE_EQ(context, input->type, output->type); 00044 if (!IsSupportedType(input->type)) { 00045 context->ReportError(context, "Input data type %s (%d) is not supported.", 00046 TfLiteTypeGetName(input->type), input->type); 00047 return kTfLiteError; 00048 } 00049 return kTfLiteOk; 00050 } 
00051 00052 template <typename T> 00053 inline TfLiteStatus EvalImpl(TfLiteContext* context, TfLiteNode* node, 00054 T func(T), TfLiteType expected_type) { 00055 const TfLiteTensor* input = GetInput(context, node, 0); 00056 TfLiteTensor* output = GetOutput(context, node, 0); 00057 TF_LITE_ENSURE_EQ(context, input->type, expected_type); 00058 const int64_t num_elements = NumElements(input); 00059 const T* in_data = GetTensorData<T>(input); 00060 T* out_data = GetTensorData<T>(output); 00061 for (int64_t i = 0; i < num_elements; ++i) { 00062 out_data[i] = func(in_data[i]); 00063 } 00064 return kTfLiteOk; 00065 } 00066 00067 inline TfLiteStatus EvalNumeric(TfLiteContext* context, TfLiteNode* node, 00068 float float_func(float)) { 00069 return EvalImpl<float>(context, node, float_func, kTfLiteFloat32); 00070 } 00071 00072 inline TfLiteStatus EvalLogical(TfLiteContext* context, TfLiteNode* node, 00073 bool bool_func(bool)) { 00074 return EvalImpl<bool>(context, node, bool_func, kTfLiteBool); 00075 } 00076 00077 TfLiteStatus AbsEval(TfLiteContext* context, TfLiteNode* node) { 00078 return EvalNumeric(context, node, std::abs); 00079 } 00080 00081 TfLiteStatus SinEval(TfLiteContext* context, TfLiteNode* node) { 00082 return EvalNumeric(context, node, std::sin); 00083 } 00084 00085 TfLiteStatus CosEval(TfLiteContext* context, TfLiteNode* node) { 00086 return EvalNumeric(context, node, std::cos); 00087 } 00088 00089 TfLiteStatus LogEval(TfLiteContext* context, TfLiteNode* node) { 00090 return EvalNumeric(context, node, std::log); 00091 } 00092 00093 TfLiteStatus SqrtEval(TfLiteContext* context, TfLiteNode* node) { 00094 return EvalNumeric(context, node, std::sqrt); 00095 } 00096 00097 TfLiteStatus RsqrtEval(TfLiteContext* context, TfLiteNode* node) { 00098 return EvalNumeric(context, node, [](float f) { return 1.f / std::sqrt(f); }); 00099 } 00100 00101 TfLiteStatus SquareEval(TfLiteContext* context, TfLiteNode* node) { 00102 return EvalNumeric(context, node, [](float f) { 
return f * f; }); 00103 } 00104 00105 TfLiteStatus LogicalNotEval(TfLiteContext* context, TfLiteNode* node) { 00106 return EvalLogical(context, node, [](bool v) { return !v; }); 00107 } 00108 00109 } // namespace 00110 } // namespace elementwise 00111 00112 TfLiteRegistration* Register_ABS() { 00113 static TfLiteRegistration r = { 00114 /* init */ nullptr, /* free */ nullptr, 00115 elementwise::GenericPrepare<elementwise::IsNumericSupportedType>, 00116 elementwise::AbsEval}; 00117 return &r; 00118 } 00119 00120 TfLiteRegistration* Register_SIN() { 00121 static TfLiteRegistration r = { 00122 /* init */ nullptr, /* free */ nullptr, 00123 elementwise::GenericPrepare<elementwise::IsNumericSupportedType>, 00124 elementwise::SinEval}; 00125 return &r; 00126 } 00127 00128 TfLiteRegistration* Register_COS() { 00129 static TfLiteRegistration r = { 00130 /* init */ nullptr, /* free */ nullptr, 00131 elementwise::GenericPrepare<elementwise::IsNumericSupportedType>, 00132 elementwise::CosEval}; 00133 return &r; 00134 } 00135 00136 TfLiteRegistration* Register_LOG() { 00137 static TfLiteRegistration r = { 00138 /* init */ nullptr, /* free */ nullptr, 00139 elementwise::GenericPrepare<elementwise::IsNumericSupportedType>, 00140 elementwise::LogEval}; 00141 return &r; 00142 } 00143 00144 TfLiteRegistration* Register_SQRT() { 00145 static TfLiteRegistration r = { 00146 /* init */ nullptr, /* free */ nullptr, 00147 elementwise::GenericPrepare<elementwise::IsNumericSupportedType>, 00148 elementwise::SqrtEval}; 00149 return &r; 00150 } 00151 00152 TfLiteRegistration* Register_RSQRT() { 00153 static TfLiteRegistration r = { 00154 /* init */ nullptr, /* free */ nullptr, 00155 elementwise::GenericPrepare<elementwise::IsNumericSupportedType>, 00156 elementwise::RsqrtEval}; 00157 return &r; 00158 } 00159 00160 TfLiteRegistration* Register_SQUARE() { 00161 static TfLiteRegistration r = { 00162 /* init */ nullptr, /* free */ nullptr, 00163 
elementwise::GenericPrepare<elementwise::IsNumericSupportedType>, 00164 elementwise::SquareEval}; 00165 return &r; 00166 } 00167 00168 TfLiteRegistration* Register_LOGICAL_NOT() { 00169 static TfLiteRegistration r = { 00170 /*init=*/nullptr, /*free=*/nullptr, 00171 elementwise::GenericPrepare<elementwise::IsLogicalSupportedType>, 00172 elementwise::LogicalNotEval}; 00173 return &r; 00174 } 00175 00176 } // namespace micro 00177 } // namespace ops 00178 } // namespace tflite
Generated on Wed Jul 13 2022 16:03:35 by
