Important changes to repositories hosted on mbed.com
Mbed-hosted Mercurial repositories are deprecated and are due to be permanently deleted in July 2026.
To keep a copy of this software, download the repository Zip archive or clone it locally using Mercurial.
It is also possible to export all your personal repositories from the account settings page.
pooling.cc
00001 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. 00002 00003 Licensed under the Apache License, Version 2.0 (the "License"); 00004 you may not use this file except in compliance with the License. 00005 You may obtain a copy of the License at 00006 00007 http://www.apache.org/licenses/LICENSE-2.0 00008 00009 Unless required by applicable law or agreed to in writing, software 00010 distributed under the License is distributed on an "AS IS" BASIS, 00011 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 00012 See the License for the specific language governing permissions and 00013 limitations under the License. 00014 ==============================================================================*/ 00015 #include "tensorflow/lite/kernels/internal/reference/pooling.h" 00016 00017 #include "tensorflow/lite/c/builtin_op_data.h" 00018 #include "tensorflow/lite/kernels/internal/reference/integer_ops/pooling.h" 00019 #include "tensorflow/lite/kernels/internal/tensor_ctypes.h" 00020 #include "tensorflow/lite/kernels/kernel_util.h" 00021 #include "tensorflow/lite/kernels/padding.h" 00022 00023 namespace tflite { 00024 namespace ops { 00025 namespace micro { 00026 namespace pooling { 00027 00028 namespace { 00029 00030 constexpr int kInputTensor = 0; 00031 constexpr int kOutputTensor = 0; 00032 00033 struct OpData { 00034 TfLitePaddingValues padding; 00035 }; 00036 00037 TfLiteStatus CalculateOpData(const TfLiteContext* context, 00038 const TfLitePoolParams* params, 00039 const TfLiteTensor* input, 00040 const TfLiteTensor* output, OpData* data) { 00041 // input: batch, height, width, channel 00042 int height = SizeOfDimension(input, 1); 00043 int width = SizeOfDimension(input, 2); 00044 00045 int out_height, out_width; 00046 00047 data->padding = ComputePaddingHeightWidth( 00048 params->stride_height, params->stride_width, 00049 /*dilation_rate_height=*/1, 00050 /*dilation_rate_width=*/1, height, width, params->filter_height, 00051 
params->filter_width, params->padding, &out_height, &out_width); 00052 00053 return kTfLiteOk; 00054 } 00055 00056 void AverageEvalFloat(const TfLiteContext* context, const TfLiteNode* node, 00057 const TfLitePoolParams* params, const OpData* data, 00058 const TfLiteTensor* input, TfLiteTensor* output) { 00059 float activation_min, activation_max; 00060 CalculateActivationRange(params->activation, &activation_min, 00061 &activation_max); 00062 00063 PoolParams op_params; 00064 op_params.stride_height = params->stride_height; 00065 op_params.stride_width = params->stride_width; 00066 op_params.filter_height = params->filter_height; 00067 op_params.filter_width = params->filter_width; 00068 op_params.padding_values.height = data->padding.height; 00069 op_params.padding_values.width = data->padding.width; 00070 op_params.float_activation_min = activation_min; 00071 op_params.float_activation_max = activation_max; 00072 reference_ops::AveragePool( 00073 op_params, GetTensorShape(input), GetTensorData<float>(input), 00074 GetTensorShape(output), GetTensorData<float>(output)); 00075 } 00076 00077 void AverageEvalUint8(const TfLiteContext* context, const TfLiteNode* node, 00078 const TfLitePoolParams* params, const OpData* data, 00079 const TfLiteTensor* input, TfLiteTensor* output) { 00080 int32_t activation_min, activation_max; 00081 CalculateActivationRangeUint8(params->activation, output, &activation_min, 00082 &activation_max); 00083 00084 PoolParams op_params; 00085 op_params.stride_height = params->stride_height; 00086 op_params.stride_width = params->stride_width; 00087 op_params.filter_height = params->filter_height; 00088 op_params.filter_width = params->filter_width; 00089 op_params.padding_values.height = data->padding.height; 00090 op_params.padding_values.width = data->padding.width; 00091 op_params.quantized_activation_min = activation_min; 00092 op_params.quantized_activation_max = activation_max; 00093 reference_ops::AveragePool( 00094 op_params, 
GetTensorShape(input), GetTensorData<uint8_t>(input), 00095 GetTensorShape(output), GetTensorData<uint8_t>(output)); 00096 } 00097 00098 void AverageEvalInt8(const TfLiteContext* context, const TfLiteNode* node, 00099 const TfLitePoolParams* params, const OpData* data, 00100 const TfLiteTensor* input, TfLiteTensor* output) { 00101 int32_t activation_min, activation_max; 00102 CalculateActivationRangeInt8(params->activation, output, &activation_min, 00103 &activation_max); 00104 00105 PoolParams op_params; 00106 op_params.stride_height = params->stride_height; 00107 op_params.stride_width = params->stride_width; 00108 op_params.filter_height = params->filter_height; 00109 op_params.filter_width = params->filter_width; 00110 op_params.padding_values.height = data->padding.height; 00111 op_params.padding_values.width = data->padding.width; 00112 op_params.quantized_activation_min = activation_min; 00113 op_params.quantized_activation_max = activation_max; 00114 reference_integer_ops::AveragePool( 00115 op_params, GetTensorShape(input), GetTensorData<int8_t>(input), 00116 GetTensorShape(output), GetTensorData<int8_t>(output)); 00117 } 00118 00119 void MaxEvalFloat(TfLiteContext* context, TfLiteNode* node, 00120 TfLitePoolParams* params, OpData* data, 00121 const TfLiteTensor* input, TfLiteTensor* output) { 00122 float activation_min, activation_max; 00123 CalculateActivationRange(params->activation, &activation_min, 00124 &activation_max); 00125 00126 tflite::PoolParams op_params; 00127 op_params.stride_height = params->stride_height; 00128 op_params.stride_width = params->stride_width; 00129 op_params.filter_height = params->filter_height; 00130 op_params.filter_width = params->filter_width; 00131 op_params.padding_values.height = data->padding.height; 00132 op_params.padding_values.width = data->padding.width; 00133 op_params.float_activation_min = activation_min; 00134 op_params.float_activation_max = activation_max; 00135 reference_ops::MaxPool(op_params, 
GetTensorShape(input), 00136 GetTensorData<float>(input), GetTensorShape(output), 00137 GetTensorData<float>(output)); 00138 } 00139 00140 void MaxEvalQuantizedUInt8(TfLiteContext* context, TfLiteNode* node, 00141 TfLitePoolParams* params, OpData* data, 00142 const TfLiteTensor* input, TfLiteTensor* output) { 00143 int32_t activation_min, activation_max; 00144 CalculateActivationRangeUint8(params->activation, output, &activation_min, 00145 &activation_max); 00146 00147 tflite::PoolParams op_params; 00148 op_params.stride_height = params->stride_height; 00149 op_params.stride_width = params->stride_width; 00150 op_params.filter_height = params->filter_height; 00151 op_params.filter_width = params->filter_width; 00152 op_params.padding_values.height = data->padding.height; 00153 op_params.padding_values.width = data->padding.width; 00154 op_params.quantized_activation_min = activation_min; 00155 op_params.quantized_activation_max = activation_max; 00156 reference_ops::MaxPool(op_params, GetTensorShape(input), 00157 GetTensorData<uint8_t>(input), GetTensorShape(output), 00158 GetTensorData<uint8_t>(output)); 00159 } 00160 00161 } // namespace 00162 00163 void* Init(TfLiteContext* context, const char* buffer, size_t length) { 00164 return nullptr; 00165 } 00166 00167 void Free(TfLiteContext* context, void* buffer) {} 00168 00169 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { 00170 return kTfLiteOk; 00171 } 00172 00173 TfLiteStatus AverageEval(TfLiteContext* context, TfLiteNode* node) { 00174 auto* params = reinterpret_cast<TfLitePoolParams*>(node->builtin_data); 00175 OpData data; 00176 00177 const TfLiteTensor* input = GetInput(context, node, kInputTensor); 00178 TfLiteTensor* output = GetOutput(context, node, kOutputTensor); 00179 00180 TF_LITE_ENSURE_STATUS(CalculateOpData(context, params, input, output, &data)); 00181 00182 // Inputs and outputs share the same type, guarenteed by the converter. 
00183 switch (input->type) { 00184 case kTfLiteFloat32: 00185 AverageEvalFloat(context, node, params, &data, input, output); 00186 break; 00187 case kTfLiteUInt8: 00188 AverageEvalUint8(context, node, params, &data, input, output); 00189 break; 00190 case kTfLiteInt8: 00191 AverageEvalInt8(context, node, params, &data, input, output); 00192 break; 00193 default: 00194 context->ReportError(context, "Input type %s is not currently supported", 00195 TfLiteTypeGetName(input->type)); 00196 return kTfLiteError; 00197 } 00198 return kTfLiteOk; 00199 } 00200 00201 TfLiteStatus MaxEval(TfLiteContext* context, TfLiteNode* node) { 00202 auto* params = reinterpret_cast<TfLitePoolParams*>(node->builtin_data); 00203 OpData data; 00204 00205 const TfLiteTensor* input = GetInput(context, node, kInputTensor); 00206 TfLiteTensor* output = GetOutput(context, node, kOutputTensor); 00207 00208 TF_LITE_ENSURE_STATUS(CalculateOpData(context, params, input, output, &data)); 00209 00210 switch (input->type) { 00211 case kTfLiteFloat32: 00212 MaxEvalFloat(context, node, params, &data, input, output); 00213 break; 00214 case kTfLiteUInt8: 00215 MaxEvalQuantizedUInt8(context, node, params, &data, input, output); 00216 break; 00217 default: 00218 context->ReportError(context, "Type %s not currently supported.", 00219 TfLiteTypeGetName(input->type)); 00220 return kTfLiteError; 00221 } 00222 return kTfLiteOk; 00223 } 00224 00225 } // namespace pooling 00226 00227 TfLiteRegistration* Register_AVERAGE_POOL_2D() { 00228 static TfLiteRegistration r = { 00229 pooling::Init, 00230 pooling::Free, 00231 pooling::Prepare, 00232 pooling::AverageEval, 00233 }; 00234 return &r; 00235 } 00236 00237 TfLiteRegistration* Register_MAX_POOL_2D() { 00238 static TfLiteRegistration r = {pooling::Init, pooling::Free, pooling::Prepare, 00239 pooling::MaxEval}; 00240 return &r; 00241 } 00242 00243 } // namespace micro 00244 } // namespace ops 00245 } // namespace tflite
Generated on Wed Jul 13 2022 16:03:35 by
