Daniel Konegen / MNIST_example

Dependencies:   mbed-os

Embed: (wiki syntax)

« Back to documentation index

activations.cc Source File

activations.cc

00001 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
00002 
00003 Licensed under the Apache License, Version 2.0 (the "License");
00004 you may not use this file except in compliance with the License.
00005 You may obtain a copy of the License at
00006 
00007     http://www.apache.org/licenses/LICENSE-2.0
00008 
00009 Unless required by applicable law or agreed to in writing, software
00010 distributed under the License is distributed on an "AS IS" BASIS,
00011 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
00012 See the License for the specific language governing permissions and
00013 limitations under the License.
00014 ==============================================================================*/
00015 
00016 #include "tensorflow/lite/c/builtin_op_data.h"
00017 #include "tensorflow/lite/c/c_api_internal.h"
00018 #include "tensorflow/lite/experimental/micro/micro_utils.h"
00019 #include "tensorflow/lite/kernels/internal/common.h"
00020 #include "tensorflow/lite/kernels/internal/quantization_util.h"
00021 #include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
00022 #include "tensorflow/lite/kernels/kernel_util.h"
00023 #include "tensorflow/lite/kernels/op_macros.h"
00024 
00025 namespace tflite {
00026 namespace ops {
00027 namespace micro {
00028 namespace activations {
00029 
// Both activation kernels in this file read exactly one input tensor and
// write exactly one output tensor; these are the indices handed to
// GetInput()/GetOutput().
constexpr int kInputTensor{0};
constexpr int kOutputTensor{0};
00032 
00033 template <typename Q>
00034 inline void ReluQuantized(int32_t lower, const RuntimeShape& input_shape,
00035                           const Q* input_data, const RuntimeShape& output_shape,
00036                           Q* output_data) {
00037   const int flat_size = MatchingFlatSize(input_shape, output_shape);
00038   for (int i = 0; i < flat_size; ++i) {
00039     const Q val = input_data[i];
00040     const Q clamped = val < lower ? lower : val;
00041     output_data[i] = clamped;
00042   }
00043 }
00044 
00045 inline void ReluFloat(const RuntimeShape& input_shape, const float* input_data,
00046                       const RuntimeShape& output_shape, float* output_data) {
00047   const int flat_size = MatchingFlatSize(input_shape, output_shape);
00048   for (int i = 0; i < flat_size; ++i) {
00049     const float val = input_data[i];
00050     const float lower = 0.0f;
00051     const float clamped = val < lower ? lower : val;
00052     output_data[i] = clamped;
00053   }
00054 }
00055 
00056 inline void Relu6Float(const RuntimeShape& input_shape, const float* input_data,
00057                        const RuntimeShape& output_shape, float* output_data) {
00058   const int flat_size = MatchingFlatSize(input_shape, output_shape);
00059   for (int i = 0; i < flat_size; ++i) {
00060     const float val = input_data[i];
00061     const float upper = 6.0f;
00062     const float lower = 0.0f;
00063     const float clamped = val > upper ? upper : val < lower ? lower : val;
00064     output_data[i] = clamped;
00065   }
00066 }
00067 
00068 template <typename Q>
00069 inline void Relu6Quantized(Q lower, Q upper, const RuntimeShape& input_shape,
00070                            const Q* input_data,
00071                            const RuntimeShape& output_shape, Q* output_data) {
00072   const int flat_size = MatchingFlatSize(input_shape, output_shape);
00073   for (int i = 0; i < flat_size; ++i) {
00074     const Q val = input_data[i];
00075     const Q clamped = val > upper ? upper : val < lower ? lower : val;
00076     output_data[i] = clamped;
00077   }
00078 }
00079 
00080 TfLiteStatus ReluPrepare(TfLiteContext* context, TfLiteNode* node) {
00081   return kTfLiteOk;
00082 }
00083 
00084 TfLiteStatus ReluEval(TfLiteContext* context, TfLiteNode* node) {
00085   const TfLiteTensor* input = GetInput(context, node, kInputTensor);
00086   TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
00087 
00088   switch (input->type) {
00089     case kTfLiteFloat32: {
00090       ReluFloat(GetTensorShape(input), GetTensorData<float>(input),
00091                 GetTensorShape(output), GetTensorData<float>(output));
00092 
00093       return kTfLiteOk;
00094     }
00095     case kTfLiteInt8: {
00096       ReluQuantized<int8_t>(input->params.zero_point, GetTensorShape(input),
00097                             GetTensorData<int8_t>(input),
00098                             GetTensorShape(output),
00099                             GetTensorData<int8_t>(output));
00100       return kTfLiteOk;
00101     }
00102     case kTfLiteUInt8: {
00103       ReluQuantized<uint8_t>(input->params.zero_point, GetTensorShape(input),
00104                              GetTensorData<uint8_t>(input),
00105                              GetTensorShape(output),
00106                              GetTensorData<uint8_t>(output));
00107       return kTfLiteOk;
00108     }
00109     default: {
00110       context->ReportError(context,
00111                            "Only float32 is supported currently, got %s",
00112                            TfLiteTypeGetName(input->type));
00113       return kTfLiteError;
00114     }
00115   }
00116 }
00117 
00118 TfLiteStatus Relu6Prepare(TfLiteContext* context, TfLiteNode* node) {
00119   return kTfLiteOk;
00120 }
00121 
00122 TfLiteStatus Relu6Eval(TfLiteContext* context, TfLiteNode* node) {
00123   const TfLiteTensor* input = GetInput(context, node, kInputTensor);
00124   TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
00125 
00126   switch (input->type) {
00127     case kTfLiteFloat32: {
00128       Relu6Float(GetTensorShape(input), GetTensorData<float>(input),
00129                  GetTensorShape(output), GetTensorData<float>(output));
00130 
00131       return kTfLiteOk;
00132     }
00133     case kTfLiteInt8: {
00134       const int8_t six = FloatToAsymmetricQuantizedInt8(
00135           6.0f, input->params.scale, input->params.zero_point);
00136       const int8_t zero = input->params.zero_point;
00137       Relu6Quantized<int8_t>(
00138           zero, six, GetTensorShape(input), GetTensorData<int8_t>(input),
00139           GetTensorShape(output), GetTensorData<int8_t>(output));
00140       return kTfLiteOk;
00141     }
00142     case kTfLiteUInt8: {
00143       const uint8_t six = FloatToAsymmetricQuantizedUInt8(
00144           6.0f, input->params.scale, input->params.zero_point);
00145       const uint8_t zero = input->params.zero_point;
00146       Relu6Quantized<uint8_t>(
00147           zero, six, GetTensorShape(input), GetTensorData<uint8_t>(input),
00148           GetTensorShape(output), GetTensorData<uint8_t>(output));
00149       return kTfLiteOk;
00150     }
00151     default: {
00152       context->ReportError(context,
00153                            "Only float32 is supported currently, got %s",
00154                            TfLiteTypeGetName(input->type));
00155       return kTfLiteError;
00156     }
00157   }
00158 }
00159 
00160 }  // namespace activations
00161 
00162 TfLiteRegistration* Register_RELU() {
00163   static TfLiteRegistration r = {/*init=*/nullptr,
00164                                  /*free=*/nullptr, activations::ReluPrepare,
00165                                  activations::ReluEval};
00166   return &r;
00167 }
00168 
00169 TfLiteRegistration* Register_RELU6() {
00170   static TfLiteRegistration r = {/*init=*/nullptr,
00171                                  /*free=*/nullptr, activations::Relu6Prepare,
00172                                  activations::Relu6Eval};
00173   return &r;
00174 }
00175 
00176 }  // namespace micro
00177 }  // namespace ops
00178 }  // namespace tflite