Daniel Konegen / MNIST_example

Dependencies:   mbed-os

Embed: (wiki syntax)

« Back to documentation index

Show/hide line numbers comparisons.cc Source File

comparisons.cc

00001 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
00002 
00003 Licensed under the Apache License, Version 2.0 (the "License");
00004 you may not use this file except in compliance with the License.
00005 You may obtain a copy of the License at
00006 
00007     http://www.apache.org/licenses/LICENSE-2.0
00008 
00009 Unless required by applicable law or agreed to in writing, software
00010 distributed under the License is distributed on an "AS IS" BASIS,
00011 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
00012 See the License for the specific language governing permissions and
00013 limitations under the License.
00014 ==============================================================================*/
00015 #include "tensorflow/lite/kernels/internal/reference/comparisons.h"
00016 
00017 #include "tensorflow/lite/c/c_api_internal.h"
00018 #include "tensorflow/lite/kernels/internal/quantization_util.h"
00019 #include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
00020 #include "tensorflow/lite/kernels/kernel_util.h"
00021 
00022 namespace tflite {
00023 namespace ops {
00024 namespace micro {
00025 namespace comparisons {
00026 namespace {
00027 
// Slot indices handed to GetInput()/GetOutput() by the Eval functions in this
// file: two input operands and a single output.
constexpr int kInputTensor1{0};
constexpr int kInputTensor2{1};
constexpr int kOutputTensor{0};
00031 
00032 // TODO(ruic): optimize macros below to using template functions.
00033 #define TF_LITE_QUANTIZE_COMPARISON(opname)                                    \
00034   template <typename input_dtype>                                              \
00035   void EvalQuantized##opname(TfLiteContext* context, TfLiteNode* node,         \
00036                              const TfLiteTensor* input1,                       \
00037                              const TfLiteTensor* input2, TfLiteTensor* output, \
00038                              bool requires_broadcast) {                        \
00039     if (input1->type == kTfLiteUInt8 || input1->type == kTfLiteInt8) {         \
00040       auto input1_offset = -input1->params.zero_point;                         \
00041       auto input2_offset = -input2->params.zero_point;                         \
00042       const int left_shift = 8;                                                \
00043                                                                                \
00044       int32 input1_multiplier;                                                 \
00045       int input1_shift;                                                        \
00046       QuantizeMultiplierSmallerThanOneExp(input1->params.scale,                \
00047                                           &input1_multiplier, &input1_shift);  \
00048       int32 input2_multiplier;                                                 \
00049       int input2_shift;                                                        \
00050       QuantizeMultiplierSmallerThanOneExp(input2->params.scale,                \
00051                                           &input2_multiplier, &input2_shift);  \
00052                                                                                \
00053       ComparisonParams op_params;                                              \
00054       op_params.left_shift = left_shift;                                       \
00055       op_params.input1_offset = input1_offset;                                 \
00056       op_params.input1_multiplier = input1_multiplier;                         \
00057       op_params.input1_shift = input1_shift;                                   \
00058       op_params.input2_offset = input2_offset;                                 \
00059       op_params.input2_multiplier = input2_multiplier;                         \
00060       op_params.input2_shift = input2_shift;                                   \
00061       if (requires_broadcast) {                                                \
00062         reference_ops::Broadcast4DSlow##opname##WithScaling(                   \
00063             op_params, GetTensorShape(input1),                                 \
00064             GetTensorData<input_dtype>(input1), GetTensorShape(input2),        \
00065             GetTensorData<input_dtype>(input2), GetTensorShape(output),        \
00066             GetTensorData<bool>(output));                                      \
00067       } else {                                                                 \
00068         reference_ops::opname##WithScaling(                                    \
00069             op_params, GetTensorShape(input1),                                 \
00070             GetTensorData<input_dtype>(input1), GetTensorShape(input2),        \
00071             GetTensorData<input_dtype>(input2), GetTensorShape(output),        \
00072             GetTensorData<bool>(output));                                      \
00073       }                                                                        \
00074     }                                                                          \
00075   }
00076 TF_LITE_QUANTIZE_COMPARISON(Equal);
00077 TF_LITE_QUANTIZE_COMPARISON(NotEqual);
00078 TF_LITE_QUANTIZE_COMPARISON(Greater);
00079 TF_LITE_QUANTIZE_COMPARISON(GreaterEqual);
00080 TF_LITE_QUANTIZE_COMPARISON(Less);
00081 TF_LITE_QUANTIZE_COMPARISON(LessEqual);
00082 #undef TF_LITE_QUANTIZE_COMPARISON
00083 
// Expands to a braced statement that runs the unscaled reference kernel for
// `opname` on element type `type`. The expansion refers to variables named
// input1, input2 and output, which must exist at the point of use. When
// `requires_broadcast` is true the Broadcast4DSlow variant is called,
// otherwise the plain element-wise one. op_params is passed as declared; this
// macro does not assign any of its fields.
#define TF_LITE_COMPARISON(type, opname, requires_broadcast)                \
  {                                                                         \
    ComparisonParams op_params;                                             \
    if (requires_broadcast) {                                               \
      reference_ops::Broadcast4DSlow##opname##NoScaling(                    \
          op_params, GetTensorShape(input1), GetTensorData<type>(input1),   \
          GetTensorShape(input2), GetTensorData<type>(input2),              \
          GetTensorShape(output), GetTensorData<bool>(output));             \
    } else {                                                                \
      reference_ops::opname##NoScaling(                                     \
          op_params, GetTensorShape(input1), GetTensorData<type>(input1),   \
          GetTensorShape(input2), GetTensorData<type>(input2),              \
          GetTensorShape(output), GetTensorData<bool>(output));             \
    }                                                                       \
  }
00097 
00098 TfLiteStatus EqualEval(TfLiteContext* context, TfLiteNode* node) {
00099   const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
00100   const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
00101   TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
00102   bool requires_broadcast = !HaveSameShapes(input1, input2);
00103   switch (input1->type) {
00104     case kTfLiteBool:
00105       TF_LITE_COMPARISON(bool, Equal, requires_broadcast);
00106       break;
00107     case kTfLiteFloat32:
00108       TF_LITE_COMPARISON(float, Equal, requires_broadcast);
00109       break;
00110     case kTfLiteInt32:
00111       TF_LITE_COMPARISON(int32_t, Equal, requires_broadcast);
00112       break;
00113     case kTfLiteInt64:
00114       TF_LITE_COMPARISON(int64_t, Equal, requires_broadcast);
00115       break;
00116     case kTfLiteUInt8:
00117       EvalQuantizedEqual<uint8_t>(context, node, input1, input2, output,
00118                                   requires_broadcast);
00119       break;
00120     case kTfLiteInt8:
00121       EvalQuantizedEqual<int8_t>(context, node, input1, input2, output,
00122                                  requires_broadcast);
00123       break;
00124     default:
00125       context->ReportError(
00126           context, "Does not support type %d, requires bool|float|int|uint8",
00127           input1->type);
00128       return kTfLiteError;
00129   }
00130   return kTfLiteOk;
00131 }
00132 
00133 // TODO(renjieliu): Refactor the logic to avoid duplications.
00134 TfLiteStatus NotEqualEval(TfLiteContext* context, TfLiteNode* node) {
00135   const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
00136   const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
00137   TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
00138   bool requires_broadcast = !HaveSameShapes(input1, input2);
00139   switch (input1->type) {
00140     case kTfLiteBool:
00141       TF_LITE_COMPARISON(bool, NotEqual, requires_broadcast);
00142       break;
00143     case kTfLiteFloat32:
00144       TF_LITE_COMPARISON(float, NotEqual, requires_broadcast);
00145       break;
00146     case kTfLiteInt32:
00147       TF_LITE_COMPARISON(int32_t, NotEqual, requires_broadcast);
00148       break;
00149     case kTfLiteInt64:
00150       TF_LITE_COMPARISON(int64_t, NotEqual, requires_broadcast);
00151       break;
00152     case kTfLiteUInt8:
00153       EvalQuantizedNotEqual<uint8_t>(context, node, input1, input2, output,
00154                                      requires_broadcast);
00155       break;
00156     case kTfLiteInt8:
00157       EvalQuantizedNotEqual<int8_t>(context, node, input1, input2, output,
00158                                     requires_broadcast);
00159       break;
00160     default:
00161       context->ReportError(
00162           context, "Does not support type %d, requires bool|float|int|uint8",
00163           input1->type);
00164       return kTfLiteError;
00165   }
00166   return kTfLiteOk;
00167 }
00168 
00169 TfLiteStatus GreaterEval(TfLiteContext* context, TfLiteNode* node) {
00170   const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
00171   const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
00172   TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
00173   bool requires_broadcast = !HaveSameShapes(input1, input2);
00174   switch (input1->type) {
00175     case kTfLiteFloat32:
00176       TF_LITE_COMPARISON(float, Greater, requires_broadcast);
00177       break;
00178     case kTfLiteInt32:
00179       TF_LITE_COMPARISON(int32_t, Greater, requires_broadcast);
00180       break;
00181     case kTfLiteInt64:
00182       TF_LITE_COMPARISON(int64_t, Greater, requires_broadcast);
00183       break;
00184     case kTfLiteUInt8:
00185       EvalQuantizedGreater<uint8_t>(context, node, input1, input2, output,
00186                                     requires_broadcast);
00187       break;
00188     case kTfLiteInt8:
00189       EvalQuantizedGreater<int8_t>(context, node, input1, input2, output,
00190                                    requires_broadcast);
00191       break;
00192     default:
00193       context->ReportError(context,
00194                            "Does not support type %d, requires float|int|uint8",
00195                            input1->type);
00196       return kTfLiteError;
00197   }
00198   return kTfLiteOk;
00199 }
00200 
00201 TfLiteStatus GreaterEqualEval(TfLiteContext* context, TfLiteNode* node) {
00202   const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
00203   const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
00204   TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
00205   bool requires_broadcast = !HaveSameShapes(input1, input2);
00206   switch (input1->type) {
00207     case kTfLiteFloat32:
00208       TF_LITE_COMPARISON(float, GreaterEqual, requires_broadcast);
00209       break;
00210     case kTfLiteInt32:
00211       TF_LITE_COMPARISON(int32_t, GreaterEqual, requires_broadcast);
00212       break;
00213     case kTfLiteInt64:
00214       TF_LITE_COMPARISON(int64_t, GreaterEqual, requires_broadcast);
00215       break;
00216     case kTfLiteUInt8:
00217       EvalQuantizedGreaterEqual<uint8_t>(context, node, input1, input2, output,
00218                                          requires_broadcast);
00219       break;
00220     case kTfLiteInt8:
00221       EvalQuantizedGreaterEqual<int8_t>(context, node, input1, input2, output,
00222                                         requires_broadcast);
00223       break;
00224     default:
00225       context->ReportError(context,
00226                            "Does not support type %d, requires float|int|uint8",
00227                            input1->type);
00228       return kTfLiteError;
00229   }
00230   return kTfLiteOk;
00231 }
00232 
00233 TfLiteStatus LessEval(TfLiteContext* context, TfLiteNode* node) {
00234   const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
00235   const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
00236   TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
00237   bool requires_broadcast = !HaveSameShapes(input1, input2);
00238   switch (input1->type) {
00239     case kTfLiteFloat32:
00240       TF_LITE_COMPARISON(float, Less, requires_broadcast);
00241       break;
00242     case kTfLiteInt32:
00243       TF_LITE_COMPARISON(int32_t, Less, requires_broadcast);
00244       break;
00245     case kTfLiteInt64:
00246       TF_LITE_COMPARISON(int64_t, Less, requires_broadcast);
00247       break;
00248     case kTfLiteUInt8:
00249       EvalQuantizedLess<uint8_t>(context, node, input1, input2, output,
00250                                  requires_broadcast);
00251       break;
00252     case kTfLiteInt8:
00253       EvalQuantizedLess<int8_t>(context, node, input1, input2, output,
00254                                 requires_broadcast);
00255       break;
00256     default:
00257       context->ReportError(context,
00258                            "Does not support type %d, requires float|int|uint8",
00259                            input1->type);
00260       return kTfLiteError;
00261   }
00262   return kTfLiteOk;
00263 }
00264 
00265 TfLiteStatus LessEqualEval(TfLiteContext* context, TfLiteNode* node) {
00266   const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
00267   const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
00268   TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
00269   bool requires_broadcast = !HaveSameShapes(input1, input2);
00270   switch (input1->type) {
00271     case kTfLiteFloat32:
00272       TF_LITE_COMPARISON(float, LessEqual, requires_broadcast);
00273       break;
00274     case kTfLiteInt32:
00275       TF_LITE_COMPARISON(int32_t, LessEqual, requires_broadcast);
00276       break;
00277     case kTfLiteInt64:
00278       TF_LITE_COMPARISON(int64_t, LessEqual, requires_broadcast);
00279       break;
00280     case kTfLiteUInt8:
00281       EvalQuantizedLessEqual<uint8_t>(context, node, input1, input2, output,
00282                                       requires_broadcast);
00283       break;
00284     case kTfLiteInt8:
00285       EvalQuantizedLessEqual<int8_t>(context, node, input1, input2, output,
00286                                      requires_broadcast);
00287       break;
00288     default:
00289       context->ReportError(context,
00290                            "Does not support type %d, requires float|int|uint8",
00291                            input1->type);
00292       return kTfLiteError;
00293   }
00294   return kTfLiteOk;
00295 }
00296 
00297 }  // namespace
00298 }  // namespace comparisons
00299 
00300 TfLiteRegistration* Register_EQUAL() {
00301   static TfLiteRegistration r = {nullptr, nullptr, nullptr,
00302                                  comparisons::EqualEval};
00303   return &r;
00304 }
00305 
00306 TfLiteRegistration* Register_NOT_EQUAL() {
00307   static TfLiteRegistration r = {nullptr, nullptr, nullptr,
00308                                  comparisons::NotEqualEval};
00309   return &r;
00310 }
00311 
00312 TfLiteRegistration* Register_GREATER() {
00313   static TfLiteRegistration r = {nullptr, nullptr, nullptr,
00314                                  comparisons::GreaterEval};
00315   return &r;
00316 }
00317 
00318 TfLiteRegistration* Register_GREATER_EQUAL() {
00319   static TfLiteRegistration r = {nullptr, nullptr, nullptr,
00320                                  comparisons::GreaterEqualEval};
00321   return &r;
00322 }
00323 
00324 TfLiteRegistration* Register_LESS() {
00325   static TfLiteRegistration r = {nullptr, nullptr, nullptr,
00326                                  comparisons::LessEval};
00327   return &r;
00328 }
00329 
00330 TfLiteRegistration* Register_LESS_EQUAL() {
00331   static TfLiteRegistration r = {nullptr, nullptr, nullptr,
00332                                  comparisons::LessEqualEval};
00333   return &r;
00334 }
00335 
00336 }  // namespace micro
00337 }  // namespace ops
00338 }  // namespace tflite