fully_connected.cc
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/kernels/internal/reference/fully_connected.h"

#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/c_api_internal.h"
#include "tensorflow/lite/kernels/internal/common.h"
#include "tensorflow/lite/kernels/internal/quantization_util.h"
#include "tensorflow/lite/kernels/internal/reference/integer_ops/fully_connected.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"

namespace tflite {
namespace ops {
namespace micro {
namespace fully_connected {
namespace {

struct OpData {
  // The scaling factor from input to output (aka the 'real multiplier') can
  // be represented as a fixed point multiplier plus a left shift.
  int32_t output_multiplier;
  int output_shift;
  // The range of the fused activation layer. For example for kNone and
  // uint8_t these would be 0 and 255.
  int32_t output_activation_min;
  int32_t output_activation_max;
  // The index of the temporary tensor where the quantized inputs are cached.
  int input_quantized_index;
};

constexpr int kInputTensor = 0;
constexpr int kWeightsTensor = 1;
constexpr int kBiasTensor = 2;
constexpr int kOutputTensor = 0;

TfLiteStatus CalculateOpData(TfLiteContext* context,
                             TfLiteFullyConnectedParams* params,
                             TfLiteType data_type, const TfLiteTensor* input,
                             const TfLiteTensor* filter,
                             const TfLiteTensor* bias, TfLiteTensor* output,
                             OpData* data) {
  TfLiteStatus status = kTfLiteOk;
  if (data_type != kTfLiteFloat32) {
    double real_multiplier = 0.0;
    TF_LITE_ENSURE_STATUS(GetQuantizedConvolutionMultipler(
        context, input, filter, bias, output, &real_multiplier));
    int exponent;
    QuantizeMultiplier(real_multiplier, &data->output_multiplier, &exponent);
    data->output_shift = -exponent;
    TF_LITE_ENSURE_STATUS(CalculateActivationRangeQuantized(
        context, params->activation, output, &data->output_activation_min,
        &data->output_activation_max));
  }
  return status;
}

}  // namespace
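// --- Illustrative note (added commentary, not part of the original file) ---
// QuantizeMultiplier() above decomposes the real-valued rescale factor into
// a Q31 fixed-point multiplier and a power-of-two exponent, so that
//   real_multiplier ~= (output_multiplier / 2^31) * 2^exponent.
// Worked example: real_multiplier = 0.0045 factors as 0.576 * 2^-7, so
//   output_multiplier = round(0.576 * 2^31) = 1236950581 and exponent = -7;
// since this kernel stores the negated exponent, data->output_shift = 7.
// At runtime the reference kernels apply this as a rounding high-multiply by
// output_multiplier followed by the shift, approximating multiplication by
// real_multiplier without any floating-point arithmetic.
// ----------------------------------------------------------------------------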
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
  return nullptr;
}

void Free(TfLiteContext* context, void* buffer) {}

TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  return kTfLiteOk;
}

TfLiteStatus EvalQuantizedInt8(TfLiteContext* context, TfLiteNode* node,
                               TfLiteFullyConnectedParams* params, OpData* data,
                               const TfLiteTensor* input,
                               const TfLiteTensor* filter,
                               const TfLiteTensor* bias, TfLiteTensor* output) {
  FullyConnectedParams op_params;
  op_params.input_offset = -input->params.zero_point;
  op_params.weights_offset = -filter->params.zero_point;
  op_params.output_offset = output->params.zero_point;
  op_params.output_multiplier = data->output_multiplier;
  // TODO(b/138810107): Figure out whether output shift should be inverted
  op_params.output_shift = -data->output_shift;
  op_params.quantized_activation_min = data->output_activation_min;
  op_params.quantized_activation_max = data->output_activation_max;

  reference_integer_ops::FullyConnected(
      op_params, GetTensorShape(input), GetTensorData<int8_t>(input),
      GetTensorShape(filter), GetTensorData<int8_t>(filter),
      GetTensorShape(bias), GetTensorData<int32_t>(bias),
      GetTensorShape(output), GetTensorData<int8_t>(output));
  return kTfLiteOk;
}
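// --- Illustrative note (added commentary, not part of the original file) ---
// Per output element, reference_integer_ops::FullyConnected computes, in
// effect (pseudo-code, simplified from the reference implementation):
//
//   int32_t acc = 0;
//   for (int d = 0; d < accum_depth; ++d) {
//     acc += (input[d] + input_offset) *
//            (filter[out_c * accum_depth + d] + weights_offset);
//   }
//   if (bias != nullptr) acc += bias[out_c];
//   acc = MultiplyByQuantizedMultiplier(acc, output_multiplier, output_shift);
//   acc += output_offset;
//   acc = clamp(acc, quantized_activation_min, quantized_activation_max);
//   output[out_c] = static_cast<int8_t>(acc);
//
// The offsets are the negated quantization zero points, which is why
// op_params.input_offset above is set to -input->params.zero_point.
// ----------------------------------------------------------------------------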
TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node,
                           TfLiteFullyConnectedParams* params, OpData* data,
                           const TfLiteTensor* input,
                           const TfLiteTensor* filter, const TfLiteTensor* bias,
                           TfLiteTensor* output) {
  const int32_t input_offset = -input->params.zero_point;
  const int32_t filter_offset = -filter->params.zero_point;
  const int32_t output_offset = output->params.zero_point;

  tflite::FullyConnectedParams op_params;
  op_params.input_offset = input_offset;
  op_params.weights_offset = filter_offset;
  op_params.output_offset = output_offset;
  op_params.output_multiplier = data->output_multiplier;
  // Legacy ops used mixed left and right shifts. Now all are +ve-means-left.
  op_params.output_shift = -data->output_shift;
  op_params.quantized_activation_min = data->output_activation_min;
  op_params.quantized_activation_max = data->output_activation_max;

#define TF_LITE_FULLY_CONNECTED(output_data_type)                      \
  reference_ops::FullyConnected(                                       \
      op_params, GetTensorShape(input), GetTensorData<uint8_t>(input), \
      GetTensorShape(filter), GetTensorData<uint8_t>(filter),          \
      GetTensorShape(bias), GetTensorData<int32_t>(bias),              \
      GetTensorShape(output), GetTensorData<output_data_type>(output))
  switch (output->type) {
    case kTfLiteUInt8:
      TF_LITE_FULLY_CONNECTED(uint8_t);
      break;
    case kTfLiteInt16:
      TF_LITE_FULLY_CONNECTED(int16_t);
      break;
    default:
      context->ReportError(
          context,
          "Quantized FullyConnected expects output data type uint8 or int16");
      return kTfLiteError;
  }

  return kTfLiteOk;
}

TfLiteStatus EvalFloat(TfLiteContext* context, TfLiteNode* node,
                       TfLiteFullyConnectedParams* params, OpData* data,
                       const TfLiteTensor* input, const TfLiteTensor* filter,
                       const TfLiteTensor* bias, TfLiteTensor* output) {
  float output_activation_min, output_activation_max;
  CalculateActivationRange(params->activation, &output_activation_min,
                           &output_activation_max);
  tflite::FullyConnectedParams op_params;
  op_params.float_activation_min = output_activation_min;
  op_params.float_activation_max = output_activation_max;
  tflite::reference_ops::FullyConnected(
      op_params, GetTensorShape(input), GetTensorData<float>(input),
      GetTensorShape(filter), GetTensorData<float>(filter),
      GetTensorShape(bias), GetTensorData<float>(bias), GetTensorShape(output),
      GetTensorData<float>(output));
  return kTfLiteOk;
}

TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  auto* params =
      reinterpret_cast<TfLiteFullyConnectedParams*>(node->builtin_data);

  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
  const TfLiteTensor* filter = GetInput(context, node, kWeightsTensor);
  const TfLiteTensor* bias = GetOptionalInputTensor(context, node, kBiasTensor);
  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);

  TfLiteType data_type = input->type;
  OpData local_data_object;
  OpData* data = &local_data_object;
  TF_LITE_ENSURE_STATUS(CalculateOpData(context, params, data_type, input,
                                        filter, bias, output, data));

  switch (filter->type) {  // Already know in/out types are same.
    case kTfLiteFloat32:
      return EvalFloat(context, node, params, data, input, filter, bias,
                       output);
    case kTfLiteInt8:
      return EvalQuantizedInt8(context, node, params, data, input, filter,
                               bias, output);

    case kTfLiteUInt8:
      return EvalQuantized(context, node, params, data, input, filter, bias,
                           output);

    default:
      context->ReportError(context, "Type %d not currently supported.",
                           filter->type);
      return kTfLiteError;
  }
  return kTfLiteOk;
}

}  // namespace fully_connected

TfLiteRegistration* Register_FULLY_CONNECTED() {
  static TfLiteRegistration r = {fully_connected::Init, fully_connected::Free,
                                 fully_connected::Prepare,
                                 fully_connected::Eval};
  return &r;
}

}  // namespace micro
}  // namespace ops
}  // namespace tflite
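For context, here is a minimal sketch of how this kernel is typically wired into a TensorFlow Lite Micro application of this vintage. The helper name RunOnce, the arena size, and the header paths are illustrative assumptions rather than part of this file (the micro headers moved from experimental/micro/ to micro/ in later releases), so check the names against your source tree:

#include <cstdint>

#include "tensorflow/lite/experimental/micro/micro_interpreter.h"
#include "tensorflow/lite/experimental/micro/micro_mutable_op_resolver.h"

// Assumed working-memory budget; size this for the actual model.
constexpr size_t kArenaSize = 10 * 1024;
static uint8_t tensor_arena[kArenaSize];

TfLiteStatus RunOnce(const tflite::Model* model,
                     tflite::ErrorReporter* error_reporter) {
  // Map the FULLY_CONNECTED builtin to the registration defined above.
  tflite::MicroMutableOpResolver resolver;
  resolver.AddBuiltin(tflite::BuiltinOperator_FULLY_CONNECTED,
                      tflite::ops::micro::Register_FULLY_CONNECTED());

  tflite::MicroInterpreter interpreter(model, resolver, tensor_arena,
                                       kArenaSize, error_reporter);
  if (interpreter.AllocateTensors() != kTfLiteOk) return kTfLiteError;

  // Populate interpreter.input(0) here; Invoke() reaches Eval() above, which
  // dispatches on the filter type (float32, int8, or uint8).
  return interpreter.Invoke();
}

Register_FULLY_CONNECTED() hands the resolver a static TfLiteRegistration whose Init/Free/Prepare/Eval pointers are the functions in this file. All per-call state is recomputed inside Eval() via CalculateOpData(), which is why Init(), Free(), and Prepare() are no-ops here.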