add.cc
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/kernels/internal/reference/add.h"

#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/c_api_internal.h"
#include "tensorflow/lite/kernels/internal/quantization_util.h"
#include "tensorflow/lite/kernels/internal/reference/integer_ops/add.h"
#include "tensorflow/lite/kernels/internal/reference/process_broadcast_shapes.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/kernels/op_macros.h"

namespace tflite {
namespace ops {
namespace micro {
namespace add {

constexpr int kInputTensor1 = 0;
constexpr int kInputTensor2 = 1;
constexpr int kOutputTensor = 0;

struct OpData {
  bool requires_broadcast;

  // These fields are used in both the general 8-bit -> 8bit quantized path,
  // and the special 16-bit -> 16bit quantized path
  int input1_shift;
  int input2_shift;
  int32 output_activation_min;
  int32 output_activation_max;

  // These fields are used only in the general 8-bit -> 8bit quantized path
  int32 input1_multiplier;
  int32 input2_multiplier;
  int32 output_multiplier;
  int output_shift;
  int left_shift;
  int32 input1_offset;
  int32 input2_offset;
  int32 output_offset;
};

void* Init(TfLiteContext* context, const char* buffer, size_t length) {
  return nullptr;
}

void Free(TfLiteContext* context, void* buffer) {}

TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  return kTfLiteOk;
}

TfLiteStatus CalculateOpData(TfLiteContext* context, TfLiteAddParams* params,
                             const TfLiteTensor* input1,
                             const TfLiteTensor* input2, TfLiteTensor* output,
                             OpData* data) {
  data->requires_broadcast = !HaveSameShapes(input1, input2);

  if (output->type == kTfLiteUInt8 || output->type == kTfLiteInt8) {
    // 8bit -> 8bit general quantized path, with general rescalings
    data->input1_offset = -input1->params.zero_point;
    data->input2_offset = -input2->params.zero_point;
    data->output_offset = output->params.zero_point;
    data->left_shift = 20;
    const double twice_max_input_scale =
        2 * std::max(input1->params.scale, input2->params.scale);
    const double real_input1_multiplier =
        input1->params.scale / twice_max_input_scale;
    const double real_input2_multiplier =
        input2->params.scale / twice_max_input_scale;
    const double real_output_multiplier =
        twice_max_input_scale /
        ((1 << data->left_shift) * output->params.scale);
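    // Each real multiplier above lies in (0, 1), so it can be represented as
    // a 32-bit fixed-point multiplier plus a power-of-two exponent; the
    // left_shift of 20 pre-scales the inputs so that the all-integer
    // rescaling below loses minimal precision.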
    QuantizeMultiplierSmallerThanOneExp(
        real_input1_multiplier, &data->input1_multiplier, &data->input1_shift);

    QuantizeMultiplierSmallerThanOneExp(
        real_input2_multiplier, &data->input2_multiplier, &data->input2_shift);

    QuantizeMultiplierSmallerThanOneExp(
        real_output_multiplier, &data->output_multiplier, &data->output_shift);

    if (output->type == kTfLiteUInt8) {
      CalculateActivationRangeUint8(params->activation, output,
                                    &data->output_activation_min,
                                    &data->output_activation_max);
    } else {
      CalculateActivationRangeInt8(params->activation, output,
                                   &data->output_activation_min,
                                   &data->output_activation_max);
    }
  }

  return kTfLiteOk;
}

void EvalAdd(TfLiteContext* context, TfLiteNode* node, TfLiteAddParams* params,
             const OpData* data, const TfLiteTensor* input1,
             const TfLiteTensor* input2, TfLiteTensor* output) {
  float output_activation_min, output_activation_max;
  CalculateActivationRange(params->activation, &output_activation_min,
                           &output_activation_max);
  tflite::ArithmeticParams op_params;
  SetActivationParams(output_activation_min, output_activation_max,
                      &op_params);
#define TF_LITE_ADD(opname)                                                   \
  reference_ops::opname(op_params, GetTensorShape(input1),                    \
                        GetTensorData<float>(input1), GetTensorShape(input2), \
                        GetTensorData<float>(input2), GetTensorShape(output), \
                        GetTensorData<float>(output))
  if (data->requires_broadcast) {
    TF_LITE_ADD(BroadcastAdd4DSlow);
  } else {
    TF_LITE_ADD(Add);
  }
#undef TF_LITE_ADD
}

TfLiteStatus EvalAddQuantized(TfLiteContext* context, TfLiteNode* node,
                              TfLiteAddParams* params, const OpData* data,
                              const TfLiteTensor* input1,
                              const TfLiteTensor* input2,
                              TfLiteTensor* output) {
  if (output->type == kTfLiteUInt8 || output->type == kTfLiteInt8) {
    tflite::ArithmeticParams op_params;
    op_params.left_shift = data->left_shift;
    op_params.input1_offset = data->input1_offset;
    op_params.input1_multiplier = data->input1_multiplier;
    op_params.input1_shift = data->input1_shift;
    op_params.input2_offset = data->input2_offset;
    op_params.input2_multiplier = data->input2_multiplier;
    op_params.input2_shift = data->input2_shift;
    op_params.output_offset = data->output_offset;
    op_params.output_multiplier = data->output_multiplier;
    op_params.output_shift = data->output_shift;
    SetActivationParams(data->output_activation_min,
                        data->output_activation_max, &op_params);
    bool need_broadcast = reference_ops::ProcessBroadcastShapes(
        GetTensorShape(input1), GetTensorShape(input2), &op_params);
#define TF_LITE_ADD(type, opname, dtype)                             \
  type::opname(op_params, GetTensorShape(input1),                    \
               GetTensorData<dtype>(input1), GetTensorShape(input2), \
               GetTensorData<dtype>(input2), GetTensorShape(output), \
               GetTensorData<dtype>(output));
    if (output->type == kTfLiteInt8) {
      if (need_broadcast) {
        TF_LITE_ADD(reference_integer_ops, BroadcastAdd4DSlow, int8_t);
      } else {
        TF_LITE_ADD(reference_integer_ops, Add, int8_t);
      }
    } else {
      if (need_broadcast) {
        TF_LITE_ADD(reference_ops, BroadcastAdd4DSlow, uint8_t);
      } else {
        TF_LITE_ADD(reference_ops, Add, uint8_t);
      }
    }
#undef TF_LITE_ADD
  }

  return kTfLiteOk;
}
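// Init and Prepare above are intentionally no-ops, so Eval recomputes OpData
// on every call instead of caching it between invocations; this keeps the
// kernel free of persistent per-op state.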
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  auto* params = reinterpret_cast<TfLiteAddParams*>(node->builtin_data);

  const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
  const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);

  OpData data;
  TF_LITE_ENSURE_STATUS(
      CalculateOpData(context, params, input1, input2, output, &data));

  if (output->type == kTfLiteFloat32) {
    EvalAdd(context, node, params, &data, input1, input2, output);
  } else if (output->type == kTfLiteUInt8 || output->type == kTfLiteInt8) {
    TF_LITE_ENSURE_OK(context, EvalAddQuantized(context, node, params, &data,
                                                input1, input2, output));
  } else {
    context->ReportError(context,
                         "Inputs and outputs not all float|uint8|int8 types.");
    return kTfLiteError;
  }

  return kTfLiteOk;
}

}  // namespace add

TfLiteRegistration* Register_ADD() {
  static TfLiteRegistration r = {add::Init, add::Free, add::Prepare, add::Eval};
  return &r;
}

}  // namespace micro
}  // namespace ops
}  // namespace tflite
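For context, a minimal sketch of how this kernel is typically made available to a TFLite Micro interpreter. This is not part of add.cc; the resolver header path, the MicroMutableOpResolver class, and the AddBuiltin signature are assumptions based on the experimental micro API of this period and may differ in other releases.

// Hypothetical registration sketch (not part of add.cc): wires Register_ADD
// into a mutable op resolver so the interpreter can dispatch ADD nodes.
#include "tensorflow/lite/experimental/micro/micro_mutable_op_resolver.h"
#include "tensorflow/lite/schema/schema_generated.h"

// Forward declaration of the registration hook defined at the end of add.cc.
namespace tflite {
namespace ops {
namespace micro {
TfLiteRegistration* Register_ADD();
}  // namespace micro
}  // namespace ops
}  // namespace tflite

// Assumed helper name; BuiltinOperator_ADD comes from the flatbuffer schema.
void RegisterAddKernel(tflite::MicroMutableOpResolver* resolver) {
  resolver->AddBuiltin(tflite::BuiltinOperator_ADD,
                       tflite::ops::micro::Register_ADD());
}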