conv.cc
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/kernels/internal/reference/conv.h"

#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/c_api_internal.h"
#include "tensorflow/lite/kernels/internal/common.h"
#include "tensorflow/lite/kernels/internal/quantization_util.h"
#include "tensorflow/lite/kernels/internal/reference/integer_ops/conv.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/kernels/padding.h"

namespace tflite {
namespace ops {
namespace micro {
namespace conv {

constexpr int kInputTensor = 0;
constexpr int kFilterTensor = 1;
constexpr int kBiasTensor = 2;
constexpr int kOutputTensor = 0;
constexpr int kMaxChannels = 256;

// This file has two implementations of Conv.

const int kTensorNotAllocated = -1;

struct OpData {
  TfLitePaddingValues padding;
  // The scaling factor from input to output (aka the 'real multiplier') can
  // be represented as a fixed point multiplier plus a left shift.
  int32_t output_multiplier;
  int output_shift;

  // Per channel output multiplier and shift.
  // TODO(b/141139247): Allocate these dynamically when possible.
  int32_t per_channel_output_multiplier[kMaxChannels];
  int32_t per_channel_output_shift[kMaxChannels];

  // The range of the fused activation layer. For example for kNone and
  // uint8_t these would be 0 and 255.
  int32_t output_activation_min;
  int32_t output_activation_max;
};

inline PaddingType RuntimePaddingType(TfLitePadding padding) {
  switch (padding) {
    case TfLitePadding::kTfLitePaddingSame:
      return PaddingType::kSame;
    case TfLitePadding::kTfLitePaddingValid:
      return PaddingType::kValid;
    case TfLitePadding::kTfLitePaddingUnknown:
    default:
      return PaddingType::kNone;
  }
}

TfLiteStatus CalculateOpData(TfLiteContext* context, TfLiteNode* node,
                             TfLiteConvParams* params, int width, int height,
                             int filter_width, int filter_height, int out_width,
                             int out_height, const TfLiteType data_type,
                             OpData* data) {
  bool has_bias = node->inputs->size == 3;
  // Check number of inputs/outputs
  TF_LITE_ENSURE(context, has_bias || node->inputs->size == 2);
  TF_LITE_ENSURE_EQ(context, node->outputs->size, 1);

  // Matching GetWindowedOutputSize in TensorFlow.
  auto padding = params->padding;
  data->padding = ComputePaddingHeightWidth(
      params->stride_height, params->stride_width,
      params->dilation_height_factor, params->dilation_width_factor, height,
      width, filter_height, filter_width, padding, &out_height, &out_width);

  // Note that quantized inference requires that all tensors have their
  // parameters set. This is usually done during quantized training.
  if (data_type != kTfLiteFloat32) {
    const TfLiteTensor* input = GetInput(context, node, kInputTensor);
    const TfLiteTensor* filter = GetInput(context, node, kFilterTensor);
    const TfLiteTensor* bias =
        GetOptionalInputTensor(context, node, kBiasTensor);
    TfLiteTensor* output = GetOutput(context, node, kOutputTensor);

    TF_LITE_ENSURE_STATUS(tflite::PopulateConvolutionQuantizationParams(
        context, input, filter, bias, output, params->activation,
        &data->output_multiplier, &data->output_shift,
        &data->output_activation_min, &data->output_activation_max,
        data->per_channel_output_multiplier,
        reinterpret_cast<int*>(data->per_channel_output_shift)));
  }
  return kTfLiteOk;
}

void* Init(TfLiteContext* context, const char* buffer, size_t length) {
  return nullptr;
}

void Free(TfLiteContext* context, void* buffer) {}

TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  return kTfLiteOk;
}

void EvalQuantized(TfLiteContext* context, TfLiteNode* node,
                   TfLiteConvParams* params, OpData* data,
                   const TfLiteTensor* input, const TfLiteTensor* filter,
                   const TfLiteTensor* bias, TfLiteTensor* im2col,
                   TfLiteTensor* hwcn_weights, TfLiteTensor* output) {
  const int32_t input_offset = -input->params.zero_point;
  const int32_t filter_offset = -filter->params.zero_point;
  const int32_t output_offset = output->params.zero_point;

  ConvParams op_params;
  op_params.padding_type = RuntimePaddingType(params->padding);
  op_params.padding_values.width = data->padding.width;
  op_params.padding_values.height = data->padding.height;
  op_params.stride_width = params->stride_width;
  op_params.stride_height = params->stride_height;
  op_params.dilation_width_factor = params->dilation_width_factor;
  op_params.dilation_height_factor = params->dilation_height_factor;
  op_params.input_offset = input_offset;
  op_params.weights_offset = filter_offset;
  op_params.output_offset = output_offset;
  op_params.output_multiplier = data->output_multiplier;
  op_params.output_shift = -data->output_shift;
  op_params.quantized_activation_min = data->output_activation_min;
  op_params.quantized_activation_max = data->output_activation_max;
  reference_ops::Conv(op_params, GetTensorShape(input),
                      GetTensorData<uint8_t>(input), GetTensorShape(filter),
                      GetTensorData<uint8_t>(filter), GetTensorShape(bias),
                      GetTensorData<int32_t>(bias), GetTensorShape(output),
                      GetTensorData<uint8_t>(output), GetTensorShape(im2col),
                      GetTensorData<uint8_t>(im2col), nullptr);
}

void EvalQuantizedPerChannel(TfLiteContext* context, TfLiteNode* node,
                             TfLiteConvParams* params, OpData* data,
                             const TfLiteTensor* input,
                             const TfLiteTensor* filter,
                             const TfLiteTensor* bias, TfLiteTensor* output,
                             TfLiteTensor* im2col) {
  ConvParams op_params;
  op_params.input_offset = -input->params.zero_point;
  op_params.output_offset = output->params.zero_point;
  op_params.stride_height = params->stride_height;
  op_params.stride_width = params->stride_width;
  op_params.dilation_height_factor = params->dilation_height_factor;
  op_params.dilation_width_factor = params->dilation_width_factor;
  op_params.padding_values.height = data->padding.height;
  op_params.padding_values.width = data->padding.width;

  reference_integer_ops::ConvPerChannel(
      op_params, data->per_channel_output_multiplier,
      data->per_channel_output_shift, GetTensorShape(input),
      GetTensorData<int8>(input), GetTensorShape(filter),
      GetTensorData<int8>(filter), GetTensorShape(bias),
      GetTensorData<int32>(bias), GetTensorShape(output),
      GetTensorData<int8>(output));
}

void EvalFloat(TfLiteContext* context, TfLiteNode* node,
               TfLiteConvParams* params, OpData* data,
               const TfLiteTensor* input, const TfLiteTensor* filter,
               const TfLiteTensor* bias, TfLiteTensor* im2col,
               TfLiteTensor* hwcn_weights, TfLiteTensor* output) {
  float output_activation_min, output_activation_max;
  CalculateActivationRange(params->activation, &output_activation_min,
                           &output_activation_max);

  ConvParams op_params;
  op_params.padding_type = RuntimePaddingType(params->padding);
  op_params.padding_values.width = data->padding.width;
  op_params.padding_values.height = data->padding.height;
  op_params.stride_width = params->stride_width;
  op_params.stride_height = params->stride_height;
  op_params.dilation_width_factor = params->dilation_width_factor;
  op_params.dilation_height_factor = params->dilation_height_factor;
  op_params.float_activation_min = output_activation_min;
  op_params.float_activation_max = output_activation_max;

  reference_ops::Conv(op_params, GetTensorShape(input),
                      GetTensorData<float>(input), GetTensorShape(filter),
                      GetTensorData<float>(filter), GetTensorShape(bias),
                      GetTensorData<float>(bias), GetTensorShape(output),
                      GetTensorData<float>(output), GetTensorShape(im2col),
                      GetTensorData<float>(im2col));
}

TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  auto* params = reinterpret_cast<TfLiteConvParams*>(node->builtin_data);

  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
  const TfLiteTensor* filter = GetInput(context, node, kFilterTensor);
  const TfLiteTensor* bias = GetOptionalInputTensor(context, node, kBiasTensor);

  int input_width = input->dims->data[2];
  int input_height = input->dims->data[1];
  int filter_width = filter->dims->data[2];
  int filter_height = filter->dims->data[1];
  int output_width = output->dims->data[2];
  int output_height = output->dims->data[1];

  OpData data;

  // All per-channel quantized tensors need valid zero point and scale arrays.
  if (input->type == kTfLiteInt8) {
    TF_LITE_ENSURE_EQ(context, filter->quantization.type,
                      kTfLiteAffineQuantization);

    const auto* affine_quantization =
        reinterpret_cast<TfLiteAffineQuantization*>(
            filter->quantization.params);
    TF_LITE_ENSURE(context, affine_quantization);
    TF_LITE_ENSURE(context, affine_quantization->scale);
    TF_LITE_ENSURE(context, affine_quantization->zero_point);
    // Conv is quantized along dimension 0:
    // https://www.tensorflow.org/lite/performance/quantization_spec
    TF_LITE_ENSURE_EQ(context, filter->dims->data[0],
                      affine_quantization->scale->size);
    TF_LITE_ENSURE_EQ(context, filter->dims->data[0],
                      affine_quantization->zero_point->size);
  }

  TF_LITE_ENSURE_STATUS(CalculateOpData(
      context, node, params, input_width, input_height, filter_width,
      filter_height, output_width, output_height, input->type, &data));

  switch (input->type) {  // Already know in/out types are same.
    case kTfLiteFloat32:
      EvalFloat(context, node, params, &data, input, filter, bias, nullptr,
                nullptr, output);
      break;
    case kTfLiteInt8:
      EvalQuantizedPerChannel(context, node, params, &data, input, filter,
                              bias, output, nullptr);
      break;
    case kTfLiteUInt8:
      EvalQuantized(context, node, params, &data, input, filter, bias, nullptr,
                    nullptr, output);
      break;
    default:
      context->ReportError(context, "Type %s (%d) not supported.",
                           TfLiteTypeGetName(input->type), input->type);
      return kTfLiteError;
  }
  return kTfLiteOk;
}

}  // namespace conv

TfLiteRegistration* Register_CONV_2D() {
  static TfLiteRegistration r = {conv::Init, conv::Free, conv::Prepare,
                                 conv::Eval};
  return &r;
}

}  // namespace micro
}  // namespace ops
}  // namespace tflite
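For context, applications reach this kernel through Register_CONV_2D(): the returned TfLiteRegistration bundles the Init/Free/Prepare/Eval entry points above, and Eval dispatches to the float, uint8, or per-channel int8 path based on the input tensor type. The following is a minimal, hedged sketch of how a 2019-era TFLite Micro application might wire this op into an interpreter. It is not part of conv.cc; the include paths, the MicroMutableOpResolver::AddBuiltin signature, the MicroInterpreter constructor, and the arena size have all changed between TensorFlow releases, and g_model_data / RunConvModel are hypothetical names, so treat this as illustrative only.

#include <cstdint>

// Assumed 2019-era header locations; these moved in later releases.
#include "tensorflow/lite/micro/micro_error_reporter.h"
#include "tensorflow/lite/micro/micro_interpreter.h"
#include "tensorflow/lite/micro/micro_mutable_op_resolver.h"
#include "tensorflow/lite/schema/schema_generated.h"

// Hypothetical model: a flatbuffer exported by the TFLite converter and
// compiled into the binary.
extern const unsigned char g_model_data[];

// Working memory for tensor allocation; the size is model-dependent and this
// value is only a placeholder assumption.
constexpr int kTensorArenaSize = 16 * 1024;
static uint8_t tensor_arena[kTensorArenaSize];

TfLiteStatus RunConvModel() {
  static tflite::MicroErrorReporter micro_error_reporter;
  const tflite::Model* model = tflite::GetModel(g_model_data);

  // Register only the kernels the model needs. CONV_2D resolves to the
  // Register_CONV_2D() defined at the bottom of conv.cc.
  static tflite::MicroMutableOpResolver resolver;
  resolver.AddBuiltin(tflite::BuiltinOperator_CONV_2D,
                      tflite::ops::micro::Register_CONV_2D());

  tflite::MicroInterpreter interpreter(model, resolver, tensor_arena,
                                       kTensorArenaSize,
                                       &micro_error_reporter);
  if (interpreter.AllocateTensors() != kTfLiteOk) {
    return kTfLiteError;
  }

  // Fill interpreter.input(0) with data before this call; Invoke() runs the
  // graph, and each CONV_2D node lands in conv::Eval() above.
  return interpreter.Invoke();
}

Note that because Prepare() here is a no-op, all shape and quantization checks run inside Eval() via CalculateOpData on every invocation; later TFLM versions moved that work into Prepare.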