Daniel Konegen / MNIST_example

Dependencies: mbed-os


fully_connected.cc

/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/kernels/internal/reference/fully_connected.h"

#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/c_api_internal.h"
#include "tensorflow/lite/kernels/internal/common.h"
#include "tensorflow/lite/kernels/internal/quantization_util.h"
#include "tensorflow/lite/kernels/internal/reference/integer_ops/fully_connected.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"

namespace tflite {
namespace ops {
namespace micro {
namespace fully_connected {
namespace {

struct OpData {
  // The scaling factor from input to output (aka the 'real multiplier') can
  // be represented as a fixed point multiplier plus a left shift.
  int32_t output_multiplier;
  int output_shift;
  // The range of the fused activation layer. For example for kNone and
  // uint8_t these would be 0 and 255.
  int32_t output_activation_min;
  int32_t output_activation_max;
  // The index of the temporary tensor where the quantized inputs are cached.
  int input_quantized_index;
};
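
// Worked example (illustrative, not from the original source):
// QuantizeMultiplier decomposes a real multiplier M into a Q31 fixed-point
// value q and an exponent e such that M ~= q * 2^(e - 31). For M = 0.375,
// frexp yields 0.75 * 2^-1, so q = round(0.75 * 2^31) = 1610612736 and
// e = -1; CalculateOpData below then stores output_multiplier = 1610612736
// and output_shift = -e = 1 (i.e. a right shift by one).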

constexpr int kInputTensor = 0;
constexpr int kWeightsTensor = 1;
constexpr int kBiasTensor = 2;
constexpr int kOutputTensor = 0;

TfLiteStatus CalculateOpData(TfLiteContext* context,
                             TfLiteFullyConnectedParams* params,
                             TfLiteType data_type, const TfLiteTensor* input,
                             const TfLiteTensor* filter,
                             const TfLiteTensor* bias, TfLiteTensor* output,
                             OpData* data) {
  TfLiteStatus status = kTfLiteOk;
  if (data_type != kTfLiteFloat32) {
    double real_multiplier = 0.0;
    TF_LITE_ENSURE_STATUS(GetQuantizedConvolutionMultipler(
        context, input, filter, bias, output, &real_multiplier));
    int exponent;
    QuantizeMultiplier(real_multiplier, &data->output_multiplier, &exponent);
    data->output_shift = -exponent;
    TF_LITE_ENSURE_STATUS(CalculateActivationRangeQuantized(
        context, params->activation, output, &data->output_activation_min,
        &data->output_activation_max));
  }
  return status;
}

}  // namespace

void* Init(TfLiteContext* context, const char* buffer, size_t length) {
  return nullptr;
}

void Free(TfLiteContext* context, void* buffer) {}

TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  return kTfLiteOk;
}
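
// Note (descriptive, not in the original source): Init, Free and Prepare are
// no-ops in this micro port. No per-node state is persisted between calls,
// so the quantization parameters in OpData are recomputed inside Eval on
// every invocation.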

TfLiteStatus EvalQuantizedInt8(TfLiteContext* context, TfLiteNode* node,
                               TfLiteFullyConnectedParams* params, OpData* data,
                               const TfLiteTensor* input,
                               const TfLiteTensor* filter,
                               const TfLiteTensor* bias, TfLiteTensor* output) {
  FullyConnectedParams op_params;
  op_params.input_offset = -input->params.zero_point;
  op_params.weights_offset = -filter->params.zero_point;
  op_params.output_offset = output->params.zero_point;
  op_params.output_multiplier = data->output_multiplier;
  // TODO(b/138810107): Figure out whether output shift should be inverted
  op_params.output_shift = -data->output_shift;
  op_params.quantized_activation_min = data->output_activation_min;
  op_params.quantized_activation_max = data->output_activation_max;

  reference_integer_ops::FullyConnected(
      op_params, GetTensorShape(input), GetTensorData<int8_t>(input),
      GetTensorShape(filter), GetTensorData<int8_t>(filter),
      GetTensorShape(bias), GetTensorData<int32_t>(bias),
      GetTensorShape(output), GetTensorData<int8_t>(output));
  return kTfLiteOk;
}
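
// For reference (a sketch of the reference_integer_ops kernel invoked above,
// not code from this file): each output element is computed roughly as
//   acc  = sum_i (input[i] + input_offset) * (filter[i] + weights_offset)
//   acc += bias[out_channel]  (when a bias tensor is present)
//   out  = Clamp(MultiplyByQuantizedMultiplier(acc, output_multiplier,
//                                              output_shift) + output_offset,
//                quantized_activation_min, quantized_activation_max)
// Note that the zero points are negated once here, so the kernel can add the
// offsets instead of subtracting the zero points.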

TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node,
                           TfLiteFullyConnectedParams* params, OpData* data,
                           const TfLiteTensor* input,
                           const TfLiteTensor* filter, const TfLiteTensor* bias,
                           TfLiteTensor* output) {
  const int32_t input_offset = -input->params.zero_point;
  const int32_t filter_offset = -filter->params.zero_point;
  const int32_t output_offset = output->params.zero_point;

  tflite::FullyConnectedParams op_params;
  op_params.input_offset = input_offset;
  op_params.weights_offset = filter_offset;
  op_params.output_offset = output_offset;
  op_params.output_multiplier = data->output_multiplier;
  // Legacy ops used mixed left and right shifts. Now all are +ve-means-left.
  op_params.output_shift = -data->output_shift;
  op_params.quantized_activation_min = data->output_activation_min;
  op_params.quantized_activation_max = data->output_activation_max;

#define TF_LITE_FULLY_CONNECTED(output_data_type)                      \
  reference_ops::FullyConnected(                                       \
      op_params, GetTensorShape(input), GetTensorData<uint8_t>(input), \
      GetTensorShape(filter), GetTensorData<uint8_t>(filter),          \
      GetTensorShape(bias), GetTensorData<int32_t>(bias),              \
      GetTensorShape(output), GetTensorData<output_data_type>(output))
  switch (output->type) {
    case kTfLiteUInt8:
      TF_LITE_FULLY_CONNECTED(uint8_t);
      break;
    case kTfLiteInt16:
      TF_LITE_FULLY_CONNECTED(int16_t);
      break;
    default:
      context->ReportError(
          context,
          "Quantized FullyConnected expects output data type uint8 or int16");
      return kTfLiteError;
  }

  return kTfLiteOk;
}
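
// Note (descriptive, not in the original source): TF_LITE_FULLY_CONNECTED(T)
// expands to a reference_ops::FullyConnected call that always reads uint8
// input and filter data and int32 bias data; only the output data type T
// differs between the uint8 and int16 cases above.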

TfLiteStatus EvalFloat(TfLiteContext* context, TfLiteNode* node,
                       TfLiteFullyConnectedParams* params, OpData* data,
                       const TfLiteTensor* input, const TfLiteTensor* filter,
                       const TfLiteTensor* bias, TfLiteTensor* output) {
  float output_activation_min, output_activation_max;
  CalculateActivationRange(params->activation, &output_activation_min,
                           &output_activation_max);
  tflite::FullyConnectedParams op_params;
  op_params.float_activation_min = output_activation_min;
  op_params.float_activation_max = output_activation_max;
  tflite::reference_ops::FullyConnected(
      op_params, GetTensorShape(input), GetTensorData<float>(input),
      GetTensorShape(filter), GetTensorData<float>(filter),
      GetTensorShape(bias), GetTensorData<float>(bias), GetTensorShape(output),
      GetTensorData<float>(output));
  return kTfLiteOk;
}
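
// Note (descriptive, not in the original source): the float path needs no
// quantization parameters. CalculateActivationRange maps the fused
// activation to a [min, max] clamp, e.g. [0, +infinity) for kTfLiteActRelu.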

TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  auto* params =
      reinterpret_cast<TfLiteFullyConnectedParams*>(node->builtin_data);

  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
  const TfLiteTensor* filter = GetInput(context, node, kWeightsTensor);
  const TfLiteTensor* bias = GetOptionalInputTensor(context, node, kBiasTensor);
  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);

  TfLiteType data_type = input->type;
  OpData local_data_object;
  OpData* data = &local_data_object;
  TF_LITE_ENSURE_STATUS(CalculateOpData(context, params, data_type, input,
                                        filter, bias, output, data));

  switch (filter->type) {  // Already know in/out types are same.
    case kTfLiteFloat32:
      return EvalFloat(context, node, params, data, input, filter, bias,
                       output);
    case kTfLiteInt8:
      return EvalQuantizedInt8(context, node, params, data, input, filter, bias,
                               output);

    case kTfLiteUInt8:
      return EvalQuantized(context, node, params, data, input, filter, bias,
                           output);

    default:
      context->ReportError(context, "Type %d not currently supported.",
                           filter->type);
      return kTfLiteError;
  }
  return kTfLiteOk;
}

}  // namespace fully_connected

TfLiteRegistration* Register_FULLY_CONNECTED() {
  static TfLiteRegistration r = {fully_connected::Init, fully_connected::Free,
                                 fully_connected::Prepare,
                                 fully_connected::Eval};
  return &r;
}

}  // namespace micro
}  // namespace ops
}  // namespace tflite
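
Usage note: below is a minimal sketch of how this kernel is typically wired into a TensorFlow Lite Micro interpreter in an application such as this MNIST example. The header path and resolver API are assumptions based on TensorFlow Lite Micro releases of the same vintage, not code from this repository:

// Illustrative sketch; header path and API may differ between TF versions.
#include "tensorflow/lite/micro/micro_mutable_op_resolver.h"

void RegisterOps(tflite::MicroMutableOpResolver* resolver) {
  // Map the FULLY_CONNECTED builtin in the flatbuffer model to the
  // micro kernel implemented in this file.
  resolver->AddBuiltin(tflite::BuiltinOperator_FULLY_CONNECTED,
                       tflite::ops::micro::Register_FULLY_CONNECTED());
}

The resolver is then handed to tflite::MicroInterpreter together with the model, a tensor arena and an error reporter.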