Important changes to repositories hosted on mbed.com
Mbed hosted mercurial repositories are deprecated and are due to be permanently deleted in July 2026.
To keep a copy of this software, download the repository Zip archive or clone it locally using Mercurial.
It is also possible to export all your personal repositories from the account settings page.
comparisons.cc
00001 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. 00002 00003 Licensed under the Apache License, Version 2.0 (the "License"); 00004 you may not use this file except in compliance with the License. 00005 You may obtain a copy of the License at 00006 00007 http://www.apache.org/licenses/LICENSE-2.0 00008 00009 Unless required by applicable law or agreed to in writing, software 00010 distributed under the License is distributed on an "AS IS" BASIS, 00011 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 00012 See the License for the specific language governing permissions and 00013 limitations under the License. 00014 ==============================================================================*/ 00015 #include "tensorflow/lite/kernels/internal/reference/comparisons.h" 00016 00017 #include "tensorflow/lite/c/c_api_internal.h" 00018 #include "tensorflow/lite/kernels/internal/quantization_util.h" 00019 #include "tensorflow/lite/kernels/internal/tensor_ctypes.h" 00020 #include "tensorflow/lite/kernels/kernel_util.h" 00021 00022 namespace tflite { 00023 namespace ops { 00024 namespace micro { 00025 namespace comparisons { 00026 namespace { 00027 00028 constexpr int kInputTensor1 = 0; 00029 constexpr int kInputTensor2 = 1; 00030 constexpr int kOutputTensor = 0; 00031 00032 // TODO(ruic): optimize macros below to using template functions. 
00033 #define TF_LITE_QUANTIZE_COMPARISON(opname) \ 00034 template <typename input_dtype> \ 00035 void EvalQuantized##opname(TfLiteContext* context, TfLiteNode* node, \ 00036 const TfLiteTensor* input1, \ 00037 const TfLiteTensor* input2, TfLiteTensor* output, \ 00038 bool requires_broadcast) { \ 00039 if (input1->type == kTfLiteUInt8 || input1->type == kTfLiteInt8) { \ 00040 auto input1_offset = -input1->params.zero_point; \ 00041 auto input2_offset = -input2->params.zero_point; \ 00042 const int left_shift = 8; \ 00043 \ 00044 int32 input1_multiplier; \ 00045 int input1_shift; \ 00046 QuantizeMultiplierSmallerThanOneExp(input1->params.scale, \ 00047 &input1_multiplier, &input1_shift); \ 00048 int32 input2_multiplier; \ 00049 int input2_shift; \ 00050 QuantizeMultiplierSmallerThanOneExp(input2->params.scale, \ 00051 &input2_multiplier, &input2_shift); \ 00052 \ 00053 ComparisonParams op_params; \ 00054 op_params.left_shift = left_shift; \ 00055 op_params.input1_offset = input1_offset; \ 00056 op_params.input1_multiplier = input1_multiplier; \ 00057 op_params.input1_shift = input1_shift; \ 00058 op_params.input2_offset = input2_offset; \ 00059 op_params.input2_multiplier = input2_multiplier; \ 00060 op_params.input2_shift = input2_shift; \ 00061 if (requires_broadcast) { \ 00062 reference_ops::Broadcast4DSlow##opname##WithScaling( \ 00063 op_params, GetTensorShape(input1), \ 00064 GetTensorData<input_dtype>(input1), GetTensorShape(input2), \ 00065 GetTensorData<input_dtype>(input2), GetTensorShape(output), \ 00066 GetTensorData<bool>(output)); \ 00067 } else { \ 00068 reference_ops::opname##WithScaling( \ 00069 op_params, GetTensorShape(input1), \ 00070 GetTensorData<input_dtype>(input1), GetTensorShape(input2), \ 00071 GetTensorData<input_dtype>(input2), GetTensorShape(output), \ 00072 GetTensorData<bool>(output)); \ 00073 } \ 00074 } \ 00075 } 00076 TF_LITE_QUANTIZE_COMPARISON(Equal); 00077 TF_LITE_QUANTIZE_COMPARISON(NotEqual); 00078 
TF_LITE_QUANTIZE_COMPARISON(Greater); 00079 TF_LITE_QUANTIZE_COMPARISON(GreaterEqual); 00080 TF_LITE_QUANTIZE_COMPARISON(Less); 00081 TF_LITE_QUANTIZE_COMPARISON(LessEqual); 00082 #undef TF_LITE_QUANTIZE_COMPARISON 00083 00084 #define TF_LITE_COMPARISON(type, opname, requires_broadcast) \ 00085 { \ 00086 ComparisonParams op_params; \ 00087 requires_broadcast \ 00088 ? reference_ops::Broadcast4DSlow##opname##NoScaling( \ 00089 op_params, GetTensorShape(input1), GetTensorData<type>(input1), \ 00090 GetTensorShape(input2), GetTensorData<type>(input2), \ 00091 GetTensorShape(output), GetTensorData<bool>(output)) \ 00092 : reference_ops::opname##NoScaling( \ 00093 op_params, GetTensorShape(input1), GetTensorData<type>(input1), \ 00094 GetTensorShape(input2), GetTensorData<type>(input2), \ 00095 GetTensorShape(output), GetTensorData<bool>(output)); \ 00096 } 00097 00098 TfLiteStatus EqualEval(TfLiteContext* context, TfLiteNode* node) { 00099 const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1); 00100 const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2); 00101 TfLiteTensor* output = GetOutput(context, node, kOutputTensor); 00102 bool requires_broadcast = !HaveSameShapes(input1, input2); 00103 switch (input1->type) { 00104 case kTfLiteBool: 00105 TF_LITE_COMPARISON(bool, Equal, requires_broadcast); 00106 break; 00107 case kTfLiteFloat32: 00108 TF_LITE_COMPARISON(float, Equal, requires_broadcast); 00109 break; 00110 case kTfLiteInt32: 00111 TF_LITE_COMPARISON(int32_t, Equal, requires_broadcast); 00112 break; 00113 case kTfLiteInt64: 00114 TF_LITE_COMPARISON(int64_t, Equal, requires_broadcast); 00115 break; 00116 case kTfLiteUInt8: 00117 EvalQuantizedEqual<uint8_t>(context, node, input1, input2, output, 00118 requires_broadcast); 00119 break; 00120 case kTfLiteInt8: 00121 EvalQuantizedEqual<int8_t>(context, node, input1, input2, output, 00122 requires_broadcast); 00123 break; 00124 default: 00125 context->ReportError( 00126 context, "Does not 
support type %d, requires bool|float|int|uint8", 00127 input1->type); 00128 return kTfLiteError; 00129 } 00130 return kTfLiteOk; 00131 } 00132 00133 // TODO(renjieliu): Refactor the logic to avoid duplications. 00134 TfLiteStatus NotEqualEval(TfLiteContext* context, TfLiteNode* node) { 00135 const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1); 00136 const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2); 00137 TfLiteTensor* output = GetOutput(context, node, kOutputTensor); 00138 bool requires_broadcast = !HaveSameShapes(input1, input2); 00139 switch (input1->type) { 00140 case kTfLiteBool: 00141 TF_LITE_COMPARISON(bool, NotEqual, requires_broadcast); 00142 break; 00143 case kTfLiteFloat32: 00144 TF_LITE_COMPARISON(float, NotEqual, requires_broadcast); 00145 break; 00146 case kTfLiteInt32: 00147 TF_LITE_COMPARISON(int32_t, NotEqual, requires_broadcast); 00148 break; 00149 case kTfLiteInt64: 00150 TF_LITE_COMPARISON(int64_t, NotEqual, requires_broadcast); 00151 break; 00152 case kTfLiteUInt8: 00153 EvalQuantizedNotEqual<uint8_t>(context, node, input1, input2, output, 00154 requires_broadcast); 00155 break; 00156 case kTfLiteInt8: 00157 EvalQuantizedNotEqual<int8_t>(context, node, input1, input2, output, 00158 requires_broadcast); 00159 break; 00160 default: 00161 context->ReportError( 00162 context, "Does not support type %d, requires bool|float|int|uint8", 00163 input1->type); 00164 return kTfLiteError; 00165 } 00166 return kTfLiteOk; 00167 } 00168 00169 TfLiteStatus GreaterEval(TfLiteContext* context, TfLiteNode* node) { 00170 const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1); 00171 const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2); 00172 TfLiteTensor* output = GetOutput(context, node, kOutputTensor); 00173 bool requires_broadcast = !HaveSameShapes(input1, input2); 00174 switch (input1->type) { 00175 case kTfLiteFloat32: 00176 TF_LITE_COMPARISON(float, Greater, requires_broadcast); 00177 break; 00178 
case kTfLiteInt32: 00179 TF_LITE_COMPARISON(int32_t, Greater, requires_broadcast); 00180 break; 00181 case kTfLiteInt64: 00182 TF_LITE_COMPARISON(int64_t, Greater, requires_broadcast); 00183 break; 00184 case kTfLiteUInt8: 00185 EvalQuantizedGreater<uint8_t>(context, node, input1, input2, output, 00186 requires_broadcast); 00187 break; 00188 case kTfLiteInt8: 00189 EvalQuantizedGreater<int8_t>(context, node, input1, input2, output, 00190 requires_broadcast); 00191 break; 00192 default: 00193 context->ReportError(context, 00194 "Does not support type %d, requires float|int|uint8", 00195 input1->type); 00196 return kTfLiteError; 00197 } 00198 return kTfLiteOk; 00199 } 00200 00201 TfLiteStatus GreaterEqualEval(TfLiteContext* context, TfLiteNode* node) { 00202 const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1); 00203 const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2); 00204 TfLiteTensor* output = GetOutput(context, node, kOutputTensor); 00205 bool requires_broadcast = !HaveSameShapes(input1, input2); 00206 switch (input1->type) { 00207 case kTfLiteFloat32: 00208 TF_LITE_COMPARISON(float, GreaterEqual, requires_broadcast); 00209 break; 00210 case kTfLiteInt32: 00211 TF_LITE_COMPARISON(int32_t, GreaterEqual, requires_broadcast); 00212 break; 00213 case kTfLiteInt64: 00214 TF_LITE_COMPARISON(int64_t, GreaterEqual, requires_broadcast); 00215 break; 00216 case kTfLiteUInt8: 00217 EvalQuantizedGreaterEqual<uint8_t>(context, node, input1, input2, output, 00218 requires_broadcast); 00219 break; 00220 case kTfLiteInt8: 00221 EvalQuantizedGreaterEqual<int8_t>(context, node, input1, input2, output, 00222 requires_broadcast); 00223 break; 00224 default: 00225 context->ReportError(context, 00226 "Does not support type %d, requires float|int|uint8", 00227 input1->type); 00228 return kTfLiteError; 00229 } 00230 return kTfLiteOk; 00231 } 00232 00233 TfLiteStatus LessEval(TfLiteContext* context, TfLiteNode* node) { 00234 const TfLiteTensor* input1 = 
GetInput(context, node, kInputTensor1); 00235 const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2); 00236 TfLiteTensor* output = GetOutput(context, node, kOutputTensor); 00237 bool requires_broadcast = !HaveSameShapes(input1, input2); 00238 switch (input1->type) { 00239 case kTfLiteFloat32: 00240 TF_LITE_COMPARISON(float, Less, requires_broadcast); 00241 break; 00242 case kTfLiteInt32: 00243 TF_LITE_COMPARISON(int32_t, Less, requires_broadcast); 00244 break; 00245 case kTfLiteInt64: 00246 TF_LITE_COMPARISON(int64_t, Less, requires_broadcast); 00247 break; 00248 case kTfLiteUInt8: 00249 EvalQuantizedLess<uint8_t>(context, node, input1, input2, output, 00250 requires_broadcast); 00251 break; 00252 case kTfLiteInt8: 00253 EvalQuantizedLess<int8_t>(context, node, input1, input2, output, 00254 requires_broadcast); 00255 break; 00256 default: 00257 context->ReportError(context, 00258 "Does not support type %d, requires float|int|uint8", 00259 input1->type); 00260 return kTfLiteError; 00261 } 00262 return kTfLiteOk; 00263 } 00264 00265 TfLiteStatus LessEqualEval(TfLiteContext* context, TfLiteNode* node) { 00266 const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1); 00267 const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2); 00268 TfLiteTensor* output = GetOutput(context, node, kOutputTensor); 00269 bool requires_broadcast = !HaveSameShapes(input1, input2); 00270 switch (input1->type) { 00271 case kTfLiteFloat32: 00272 TF_LITE_COMPARISON(float, LessEqual, requires_broadcast); 00273 break; 00274 case kTfLiteInt32: 00275 TF_LITE_COMPARISON(int32_t, LessEqual, requires_broadcast); 00276 break; 00277 case kTfLiteInt64: 00278 TF_LITE_COMPARISON(int64_t, LessEqual, requires_broadcast); 00279 break; 00280 case kTfLiteUInt8: 00281 EvalQuantizedLessEqual<uint8_t>(context, node, input1, input2, output, 00282 requires_broadcast); 00283 break; 00284 case kTfLiteInt8: 00285 EvalQuantizedLessEqual<int8_t>(context, node, input1, input2, output, 
00286 requires_broadcast); 00287 break; 00288 default: 00289 context->ReportError(context, 00290 "Does not support type %d, requires float|int|uint8", 00291 input1->type); 00292 return kTfLiteError; 00293 } 00294 return kTfLiteOk; 00295 } 00296 00297 } // namespace 00298 } // namespace comparisons 00299 00300 TfLiteRegistration* Register_EQUAL() { 00301 static TfLiteRegistration r = {nullptr, nullptr, nullptr, 00302 comparisons::EqualEval}; 00303 return &r; 00304 } 00305 00306 TfLiteRegistration* Register_NOT_EQUAL() { 00307 static TfLiteRegistration r = {nullptr, nullptr, nullptr, 00308 comparisons::NotEqualEval}; 00309 return &r; 00310 } 00311 00312 TfLiteRegistration* Register_GREATER() { 00313 static TfLiteRegistration r = {nullptr, nullptr, nullptr, 00314 comparisons::GreaterEval}; 00315 return &r; 00316 } 00317 00318 TfLiteRegistration* Register_GREATER_EQUAL() { 00319 static TfLiteRegistration r = {nullptr, nullptr, nullptr, 00320 comparisons::GreaterEqualEval}; 00321 return &r; 00322 } 00323 00324 TfLiteRegistration* Register_LESS() { 00325 static TfLiteRegistration r = {nullptr, nullptr, nullptr, 00326 comparisons::LessEval}; 00327 return &r; 00328 } 00329 00330 TfLiteRegistration* Register_LESS_EQUAL() { 00331 static TfLiteRegistration r = {nullptr, nullptr, nullptr, 00332 comparisons::LessEqualEval}; 00333 return &r; 00334 } 00335 00336 } // namespace micro 00337 } // namespace ops 00338 } // namespace tflite
Generated on Wed Jul 13 2022 16:03:35 by Doxygen 1.7.2