Daniel Konegen / MNIST_example

Dependencies: mbed-os


add.cc

/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/kernels/internal/reference/add.h"

#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/c_api_internal.h"
#include "tensorflow/lite/kernels/internal/quantization_util.h"
#include "tensorflow/lite/kernels/internal/reference/integer_ops/add.h"
#include "tensorflow/lite/kernels/internal/reference/process_broadcast_shapes.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/kernels/op_macros.h"

namespace tflite {
namespace ops {
namespace micro {
namespace add {

constexpr int kInputTensor1 = 0;
constexpr int kInputTensor2 = 1;
constexpr int kOutputTensor = 0;

struct OpData {
  bool requires_broadcast;

  // These fields are used in both the general 8-bit -> 8-bit quantized path
  // and the special 16-bit -> 16-bit quantized path. Note that only the
  // 8-bit path is exercised here: Eval below handles float32, uint8, and
  // int8 outputs.
  int input1_shift;
  int input2_shift;
  int32 output_activation_min;
  int32 output_activation_max;

  // These fields are used only in the general 8-bit -> 8-bit quantized path.
  int32 input1_multiplier;
  int32 input2_multiplier;
  int32 output_multiplier;
  int output_shift;
  int left_shift;
  int32 input1_offset;
  int32 input2_offset;
  int32 output_offset;
};

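// Lifecycle note: Init, Free, and Prepare below are deliberate no-ops in
// this micro kernel. The OpData above is recomputed on the stack inside
// Eval (via CalculateOpData) on every invocation, so no persistent state
// needs to be allocated or released.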
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
  return nullptr;
}

void Free(TfLiteContext* context, void* buffer) {}

TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  return kTfLiteOk;
}

TfLiteStatus CalculateOpData(TfLiteContext* context, TfLiteAddParams* params,
                             const TfLiteTensor* input1,
                             const TfLiteTensor* input2, TfLiteTensor* output,
                             OpData* data) {
  data->requires_broadcast = !HaveSameShapes(input1, input2);

  if (output->type == kTfLiteUInt8 || output->type == kTfLiteInt8) {
    // 8-bit -> 8-bit general quantized path, with general rescalings.
    data->input1_offset = -input1->params.zero_point;
    data->input2_offset = -input2->params.zero_point;
    data->output_offset = output->params.zero_point;
    data->left_shift = 20;
    const double twice_max_input_scale =
        2 * std::max(input1->params.scale, input2->params.scale);
    const double real_input1_multiplier =
        input1->params.scale / twice_max_input_scale;
    const double real_input2_multiplier =
        input2->params.scale / twice_max_input_scale;
    const double real_output_multiplier =
        twice_max_input_scale /
        ((1 << data->left_shift) * output->params.scale);

    QuantizeMultiplierSmallerThanOneExp(
        real_input1_multiplier, &data->input1_multiplier, &data->input1_shift);

    QuantizeMultiplierSmallerThanOneExp(
        real_input2_multiplier, &data->input2_multiplier, &data->input2_shift);

    QuantizeMultiplierSmallerThanOneExp(
        real_output_multiplier, &data->output_multiplier, &data->output_shift);

    if (output->type == kTfLiteUInt8) {
      CalculateActivationRangeUint8(params->activation, output,
                                    &data->output_activation_min,
                                    &data->output_activation_max);
    } else {
      CalculateActivationRangeInt8(params->activation, output,
                                   &data->output_activation_min,
                                   &data->output_activation_max);
    }
  }

  return kTfLiteOk;
}
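
// Worked example with hypothetical scales: if input1 scale = 0.02,
// input2 scale = 0.01, and output scale = 0.02, then
// twice_max_input_scale = 0.04, real_input1_multiplier = 0.5,
// real_input2_multiplier = 0.25, and
// real_output_multiplier = 0.04 / ((1 << 20) * 0.02) = 2^-19.
// Each value is below 1.0, as QuantizeMultiplierSmallerThanOneExp
// requires; it encodes each as a Q31 fixed-point multiplier plus a
// right shift.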

void EvalAdd(TfLiteContext* context, TfLiteNode* node, TfLiteAddParams* params,
             const OpData* data, const TfLiteTensor* input1,
             const TfLiteTensor* input2, TfLiteTensor* output) {
  float output_activation_min, output_activation_max;
  CalculateActivationRange(params->activation, &output_activation_min,
                           &output_activation_max);
  tflite::ArithmeticParams op_params;
  SetActivationParams(output_activation_min, output_activation_max, &op_params);
#define TF_LITE_ADD(opname)                                                   \
  reference_ops::opname(op_params, GetTensorShape(input1),                    \
                        GetTensorData<float>(input1), GetTensorShape(input2), \
                        GetTensorData<float>(input2), GetTensorShape(output), \
                        GetTensorData<float>(output))
  if (data->requires_broadcast) {
    TF_LITE_ADD(BroadcastAdd4DSlow);
  } else {
    TF_LITE_ADD(Add);
  }
#undef TF_LITE_ADD
}
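
// For reference, TF_LITE_ADD(Add) above expands to:
//   reference_ops::Add(op_params, GetTensorShape(input1),
//                      GetTensorData<float>(input1), GetTensorShape(input2),
//                      GetTensorData<float>(input2), GetTensorShape(output),
//                      GetTensorData<float>(output))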

TfLiteStatus EvalAddQuantized(TfLiteContext* context, TfLiteNode* node,
                              TfLiteAddParams* params, const OpData* data,
                              const TfLiteTensor* input1,
                              const TfLiteTensor* input2,
                              TfLiteTensor* output) {
  if (output->type == kTfLiteUInt8 || output->type == kTfLiteInt8) {
    tflite::ArithmeticParams op_params;
    op_params.left_shift = data->left_shift;
    op_params.input1_offset = data->input1_offset;
    op_params.input1_multiplier = data->input1_multiplier;
    op_params.input1_shift = data->input1_shift;
    op_params.input2_offset = data->input2_offset;
    op_params.input2_multiplier = data->input2_multiplier;
    op_params.input2_shift = data->input2_shift;
    op_params.output_offset = data->output_offset;
    op_params.output_multiplier = data->output_multiplier;
    op_params.output_shift = data->output_shift;
    SetActivationParams(data->output_activation_min,
                        data->output_activation_max, &op_params);
    bool need_broadcast = reference_ops::ProcessBroadcastShapes(
        GetTensorShape(input1), GetTensorShape(input2), &op_params);
#define TF_LITE_ADD(type, opname, dtype)                             \
  type::opname(op_params, GetTensorShape(input1),                    \
               GetTensorData<dtype>(input1), GetTensorShape(input2), \
               GetTensorData<dtype>(input2), GetTensorShape(output), \
               GetTensorData<dtype>(output));
    if (output->type == kTfLiteInt8) {
      if (need_broadcast) {
        TF_LITE_ADD(reference_integer_ops, BroadcastAdd4DSlow, int8_t);
      } else {
        TF_LITE_ADD(reference_integer_ops, Add, int8_t);
      }
    } else {
      if (need_broadcast) {
        TF_LITE_ADD(reference_ops, BroadcastAdd4DSlow, uint8_t);
      } else {
        TF_LITE_ADD(reference_ops, Add, uint8_t);
      }
    }
#undef TF_LITE_ADD
  }

  return kTfLiteOk;
}
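
// In the quantized path, the reference kernels first add each input's
// offset, shift the result left by op_params.left_shift for headroom,
// and scale it by the corresponding Q31 multiplier; the sum is then
// rescaled to the output's scale, offset, and activation range.
// ProcessBroadcastShapes decides at run time whether the slower
// broadcasting variant is required.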

TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  auto* params = reinterpret_cast<TfLiteAddParams*>(node->builtin_data);

  const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
  const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);

  OpData data;
  TF_LITE_ENSURE_STATUS(
      CalculateOpData(context, params, input1, input2, output, &data));

  if (output->type == kTfLiteFloat32) {
    EvalAdd(context, node, params, &data, input1, input2, output);
  } else if (output->type == kTfLiteUInt8 || output->type == kTfLiteInt8) {
    TF_LITE_ENSURE_OK(context, EvalAddQuantized(context, node, params, &data,
                                                input1, input2, output));
  } else {
    context->ReportError(context,
                         "Inputs and outputs not all float|uint8|int8 types.");
    return kTfLiteError;
  }

  return kTfLiteOk;
}

}  // namespace add

TfLiteRegistration* Register_ADD() {
  static TfLiteRegistration r = {add::Init, add::Free, add::Prepare, add::Eval};
  return &r;
}

}  // namespace micro
}  // namespace ops
}  // namespace tflite
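
Usage note: the kernel above is exposed through Register_ADD(). A minimal
registration sketch follows, assuming the 2019-era TensorFlow Lite Micro
resolver API that this example's include paths suggest (the header path and
class name may differ in later releases):

#include "tensorflow/lite/micro/micro_mutable_op_resolver.h"

// Map the ADD builtin operator code to the kernel's TfLiteRegistration so
// the interpreter can resolve it when loading a model.
tflite::MicroMutableOpResolver resolver;
resolver.AddBuiltin(tflite::BuiltinOperator_ADD,
                    tflite::ops::micro::Register_ADD());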