Daniel Konegen / MNIST_example

Dependencies:   mbed-os

Embed: (wiki syntax)

« Back to documentation index

pooling.cc — source file listing (line numbers shown below can be toggled)

pooling.cc

00001 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
00002 
00003 Licensed under the Apache License, Version 2.0 (the "License");
00004 you may not use this file except in compliance with the License.
00005 You may obtain a copy of the License at
00006 
00007     http://www.apache.org/licenses/LICENSE-2.0
00008 
00009 Unless required by applicable law or agreed to in writing, software
00010 distributed under the License is distributed on an "AS IS" BASIS,
00011 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
00012 See the License for the specific language governing permissions and
00013 limitations under the License.
00014 ==============================================================================*/
00015 #include "tensorflow/lite/kernels/internal/reference/pooling.h"
00016 
00017 #include "tensorflow/lite/c/builtin_op_data.h"
00018 #include "tensorflow/lite/kernels/internal/reference/integer_ops/pooling.h"
00019 #include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
00020 #include "tensorflow/lite/kernels/kernel_util.h"
00021 #include "tensorflow/lite/kernels/padding.h"
00022 
00023 namespace tflite {
00024 namespace ops {
00025 namespace micro {
00026 namespace pooling {
00027 
00028 namespace {
00029 
// Indices of this op's tensors within the node's input/output arrays.
constexpr int kInputTensor = 0;
constexpr int kOutputTensor = 0;

// Per-invocation data derived from the node parameters: the explicit
// padding amounts computed from the SAME/VALID padding mode.
struct OpData {
  TfLitePaddingValues padding;
};
00036 
00037 TfLiteStatus CalculateOpData(const TfLiteContext* context,
00038                              const TfLitePoolParams* params,
00039                              const TfLiteTensor* input,
00040                              const TfLiteTensor* output, OpData* data) {
00041   // input: batch, height, width, channel
00042   int height = SizeOfDimension(input, 1);
00043   int width = SizeOfDimension(input, 2);
00044 
00045   int out_height, out_width;
00046 
00047   data->padding = ComputePaddingHeightWidth(
00048       params->stride_height, params->stride_width,
00049       /*dilation_rate_height=*/1,
00050       /*dilation_rate_width=*/1, height, width, params->filter_height,
00051       params->filter_width, params->padding, &out_height, &out_width);
00052 
00053   return kTfLiteOk;
00054 }
00055 
00056 void AverageEvalFloat(const TfLiteContext* context, const TfLiteNode* node,
00057                       const TfLitePoolParams* params, const OpData* data,
00058                       const TfLiteTensor* input, TfLiteTensor* output) {
00059   float activation_min, activation_max;
00060   CalculateActivationRange(params->activation, &activation_min,
00061                            &activation_max);
00062 
00063   PoolParams op_params;
00064   op_params.stride_height = params->stride_height;
00065   op_params.stride_width = params->stride_width;
00066   op_params.filter_height = params->filter_height;
00067   op_params.filter_width = params->filter_width;
00068   op_params.padding_values.height = data->padding.height;
00069   op_params.padding_values.width = data->padding.width;
00070   op_params.float_activation_min = activation_min;
00071   op_params.float_activation_max = activation_max;
00072   reference_ops::AveragePool(
00073       op_params, GetTensorShape(input), GetTensorData<float>(input),
00074       GetTensorShape(output), GetTensorData<float>(output));
00075 }
00076 
00077 void AverageEvalUint8(const TfLiteContext* context, const TfLiteNode* node,
00078                       const TfLitePoolParams* params, const OpData* data,
00079                       const TfLiteTensor* input, TfLiteTensor* output) {
00080   int32_t activation_min, activation_max;
00081   CalculateActivationRangeUint8(params->activation, output, &activation_min,
00082                                 &activation_max);
00083 
00084   PoolParams op_params;
00085   op_params.stride_height = params->stride_height;
00086   op_params.stride_width = params->stride_width;
00087   op_params.filter_height = params->filter_height;
00088   op_params.filter_width = params->filter_width;
00089   op_params.padding_values.height = data->padding.height;
00090   op_params.padding_values.width = data->padding.width;
00091   op_params.quantized_activation_min = activation_min;
00092   op_params.quantized_activation_max = activation_max;
00093   reference_ops::AveragePool(
00094       op_params, GetTensorShape(input), GetTensorData<uint8_t>(input),
00095       GetTensorShape(output), GetTensorData<uint8_t>(output));
00096 }
00097 
00098 void AverageEvalInt8(const TfLiteContext* context, const TfLiteNode* node,
00099                      const TfLitePoolParams* params, const OpData* data,
00100                      const TfLiteTensor* input, TfLiteTensor* output) {
00101   int32_t activation_min, activation_max;
00102   CalculateActivationRangeInt8(params->activation, output, &activation_min,
00103                                &activation_max);
00104 
00105   PoolParams op_params;
00106   op_params.stride_height = params->stride_height;
00107   op_params.stride_width = params->stride_width;
00108   op_params.filter_height = params->filter_height;
00109   op_params.filter_width = params->filter_width;
00110   op_params.padding_values.height = data->padding.height;
00111   op_params.padding_values.width = data->padding.width;
00112   op_params.quantized_activation_min = activation_min;
00113   op_params.quantized_activation_max = activation_max;
00114   reference_integer_ops::AveragePool(
00115       op_params, GetTensorShape(input), GetTensorData<int8_t>(input),
00116       GetTensorShape(output), GetTensorData<int8_t>(output));
00117 }
00118 
00119 void MaxEvalFloat(TfLiteContext* context, TfLiteNode* node,
00120                   TfLitePoolParams* params, OpData* data,
00121                   const TfLiteTensor* input, TfLiteTensor* output) {
00122   float activation_min, activation_max;
00123   CalculateActivationRange(params->activation, &activation_min,
00124                            &activation_max);
00125 
00126   tflite::PoolParams op_params;
00127   op_params.stride_height = params->stride_height;
00128   op_params.stride_width = params->stride_width;
00129   op_params.filter_height = params->filter_height;
00130   op_params.filter_width = params->filter_width;
00131   op_params.padding_values.height = data->padding.height;
00132   op_params.padding_values.width = data->padding.width;
00133   op_params.float_activation_min = activation_min;
00134   op_params.float_activation_max = activation_max;
00135   reference_ops::MaxPool(op_params, GetTensorShape(input),
00136                          GetTensorData<float>(input), GetTensorShape(output),
00137                          GetTensorData<float>(output));
00138 }
00139 
00140 void MaxEvalQuantizedUInt8(TfLiteContext* context, TfLiteNode* node,
00141                            TfLitePoolParams* params, OpData* data,
00142                            const TfLiteTensor* input, TfLiteTensor* output) {
00143   int32_t activation_min, activation_max;
00144   CalculateActivationRangeUint8(params->activation, output, &activation_min,
00145                                 &activation_max);
00146 
00147   tflite::PoolParams op_params;
00148   op_params.stride_height = params->stride_height;
00149   op_params.stride_width = params->stride_width;
00150   op_params.filter_height = params->filter_height;
00151   op_params.filter_width = params->filter_width;
00152   op_params.padding_values.height = data->padding.height;
00153   op_params.padding_values.width = data->padding.width;
00154   op_params.quantized_activation_min = activation_min;
00155   op_params.quantized_activation_max = activation_max;
00156   reference_ops::MaxPool(op_params, GetTensorShape(input),
00157                          GetTensorData<uint8_t>(input), GetTensorShape(output),
00158                          GetTensorData<uint8_t>(output));
00159 }
00160 
00161 }  // namespace
00162 
// No per-instance state is allocated; OpData is computed on the stack in
// each Eval call, so there is nothing to create here.
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
  return nullptr;
}
00166 
00167 void Free(TfLiteContext* context, void* buffer) {}
00168 
// All padding/shape work is deferred to Eval (via CalculateOpData), so
// Prepare is a no-op for this kernel.
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  return kTfLiteOk;
}
00172 
00173 TfLiteStatus AverageEval(TfLiteContext* context, TfLiteNode* node) {
00174   auto* params = reinterpret_cast<TfLitePoolParams*>(node->builtin_data);
00175   OpData data;
00176 
00177   const TfLiteTensor* input = GetInput(context, node, kInputTensor);
00178   TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
00179 
00180   TF_LITE_ENSURE_STATUS(CalculateOpData(context, params, input, output, &data));
00181 
00182   // Inputs and outputs share the same type, guarenteed by the converter.
00183   switch (input->type) {
00184     case kTfLiteFloat32:
00185       AverageEvalFloat(context, node, params, &data, input, output);
00186       break;
00187     case kTfLiteUInt8:
00188       AverageEvalUint8(context, node, params, &data, input, output);
00189       break;
00190     case kTfLiteInt8:
00191       AverageEvalInt8(context, node, params, &data, input, output);
00192       break;
00193     default:
00194       context->ReportError(context, "Input type %s is not currently supported",
00195                            TfLiteTypeGetName(input->type));
00196       return kTfLiteError;
00197   }
00198   return kTfLiteOk;
00199 }
00200 
00201 TfLiteStatus MaxEval(TfLiteContext* context, TfLiteNode* node) {
00202   auto* params = reinterpret_cast<TfLitePoolParams*>(node->builtin_data);
00203   OpData data;
00204 
00205   const TfLiteTensor* input = GetInput(context, node, kInputTensor);
00206   TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
00207 
00208   TF_LITE_ENSURE_STATUS(CalculateOpData(context, params, input, output, &data));
00209 
00210   switch (input->type) {
00211     case kTfLiteFloat32:
00212       MaxEvalFloat(context, node, params, &data, input, output);
00213       break;
00214     case kTfLiteUInt8:
00215       MaxEvalQuantizedUInt8(context, node, params, &data, input, output);
00216       break;
00217     default:
00218       context->ReportError(context, "Type %s not currently supported.",
00219                            TfLiteTypeGetName(input->type));
00220       return kTfLiteError;
00221   }
00222   return kTfLiteOk;
00223 }
00224 
00225 }  // namespace pooling
00226 
00227 TfLiteRegistration* Register_AVERAGE_POOL_2D() {
00228   static TfLiteRegistration r = {
00229       pooling::Init,
00230       pooling::Free,
00231       pooling::Prepare,
00232       pooling::AverageEval,
00233   };
00234   return &r;
00235 }
00236 
00237 TfLiteRegistration* Register_MAX_POOL_2D() {
00238   static TfLiteRegistration r = {pooling::Init, pooling::Free, pooling::Prepare,
00239                                  pooling::MaxEval};
00240   return &r;
00241 }
00242 
00243 }  // namespace micro
00244 }  // namespace ops
00245 }  // namespace tflite