Daniel Konegen / MNIST_example

Dependencies:   mbed-os

Embed: (wiki syntax)

« Back to documentation index

Show/hide line numbers comparisons.h Source File

comparisons.h

00001 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
00002 
00003 Licensed under the Apache License, Version 2.0 (the "License");
00004 you may not use this file except in compliance with the License.
00005 You may obtain a copy of the License at
00006 
00007     http://www.apache.org/licenses/LICENSE-2.0
00008 
00009 Unless required by applicable law or agreed to in writing, software
00010 distributed under the License is distributed on an "AS IS" BASIS,
00011 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
00012 See the License for the specific language governing permissions and
00013 limitations under the License.
00014 ==============================================================================*/
00015 #ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_COMPARISONS_H_
00016 #define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_COMPARISONS_H_
00017 
00018 #include "tensorflow/lite/kernels/internal/common.h"
00019 #include "tensorflow/lite/kernels/internal/types.h"
00020 
00021 namespace tflite {
00022 
00023 namespace reference_ops {
00024 
// Equality predicate: yields true when `lhs == rhs` holds for type T.
// Usable as a ComparisonFn<T> non-type template argument.
template <typename T>
inline bool EqualFn(T lhs, T rhs) {
  const bool are_equal = (lhs == rhs);
  return are_equal;
}
00029 
// Inequality predicate: yields true when `lhs != rhs` holds for type T.
// Usable as a ComparisonFn<T> non-type template argument.
template <typename T>
inline bool NotEqualFn(T lhs, T rhs) {
  const bool are_different = (lhs != rhs);
  return are_different;
}
00034 
// Strict "greater than" predicate: yields true when `lhs > rhs`.
// Usable as a ComparisonFn<T> non-type template argument.
template <typename T>
inline bool GreaterFn(T lhs, T rhs) {
  const bool is_greater = (lhs > rhs);
  return is_greater;
}
// "Greater than or equal" predicate: yields true when `lhs >= rhs`.
// Usable as a ComparisonFn<T> non-type template argument.
template <typename T>
inline bool GreaterEqualFn(T lhs, T rhs) {
  const bool is_greater_or_equal = (lhs >= rhs);
  return is_greater_or_equal;
}
// Strict "less than" predicate: yields true when `lhs < rhs`.
// Usable as a ComparisonFn<T> non-type template argument.
template <typename T>
inline bool LessFn(T lhs, T rhs) {
  const bool is_less = (lhs < rhs);
  return is_less;
}
// "Less than or equal" predicate: yields true when `lhs <= rhs`.
// Usable as a ComparisonFn<T> non-type template argument.
template <typename T>
inline bool LessEqualFn(T lhs, T rhs) {
  const bool is_less_or_equal = (lhs <= rhs);
  return is_less_or_equal;
}
00051 
// Pointer to a binary predicate over two values of type T. The comparison
// kernels below take one of these as a non-type template parameter, so the
// predicate is fixed at compile time.
template <typename T>
using ComparisonFn = bool (*)(T lhs, T rhs);
00054 
00055 template <typename T, ComparisonFn<T> F>
00056 inline void ComparisonImpl(
00057     const ComparisonParams& op_params, const RuntimeShape& input1_shape,
00058     const T* input1_data, const RuntimeShape& input2_shape,
00059     const T* input2_data, const RuntimeShape& output_shape, bool* output_data) {
00060   const int64_t flatsize =
00061       MatchingFlatSize(input1_shape, input2_shape, output_shape);
00062   for (int64_t i = 0; i < flatsize; ++i) {
00063     output_data[i] = F(input1_data[i], input2_data[i]);
00064   }
00065 }
00066 
00067 template <ComparisonFn<float> F>
00068 inline void Comparison(const ComparisonParams& op_params,
00069                        const RuntimeShape& input1_shape,
00070                        const float* input1_data,
00071                        const RuntimeShape& input2_shape,
00072                        const float* input2_data,
00073                        const RuntimeShape& output_shape, bool* output_data) {
00074   ComparisonImpl<float, F>(op_params, input1_shape, input1_data, input2_shape,
00075                            input2_data, output_shape, output_data);
00076 }
00077 
00078 template <typename T, ComparisonFn<int32> F>
00079 inline void ComparisonWithScaling(
00080     const ComparisonParams& op_params, const RuntimeShape& input1_shape,
00081     const T* input1_data, const RuntimeShape& input2_shape,
00082     const T* input2_data, const RuntimeShape& output_shape, bool* output_data) {
00083   int left_shift = op_params.left_shift;
00084   int32 input1_offset = op_params.input1_offset;
00085   int32 input1_multiplier = op_params.input1_multiplier;
00086   int input1_shift = op_params.input1_shift;
00087   int32 input2_offset = op_params.input2_offset;
00088   int32 input2_multiplier = op_params.input2_multiplier;
00089   int input2_shift = op_params.input2_shift;
00090 
00091   const int64_t flatsize =
00092       MatchingFlatSize(input1_shape, input2_shape, output_shape);
00093   for (int64_t i = 0; i < flatsize; ++i) {
00094     const int32 input1_val = input1_offset + input1_data[i];
00095     const int32 input2_val = input2_offset + input2_data[i];
00096     const int32 shifted_input1_val = input1_val * (1 << left_shift);
00097     const int32 shifted_input2_val = input2_val * (1 << left_shift);
00098     const int32 scaled_input1_val =
00099         MultiplyByQuantizedMultiplierSmallerThanOneExp(
00100             shifted_input1_val, input1_multiplier, input1_shift);
00101     const int32 scaled_input2_val =
00102         MultiplyByQuantizedMultiplierSmallerThanOneExp(
00103             shifted_input2_val, input2_multiplier, input2_shift);
00104     output_data[i] = F(scaled_input1_val, scaled_input2_val);
00105   }
00106 }
00107 
00108 template <typename T, ComparisonFn<T> F>
00109 inline void BroadcastComparison4DSlowImpl(
00110     const ComparisonParams& op_params,
00111     const RuntimeShape& unextended_input1_shape, const T* input1_data,
00112     const RuntimeShape& unextended_input2_shape, const T* input2_data,
00113     const RuntimeShape& unextended_output_shape, bool* output_data) {
00114   TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), 4);
00115   TFLITE_DCHECK_LE(unextended_input2_shape.DimensionsCount(), 4);
00116   TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
00117   const RuntimeShape output_shape =
00118       RuntimeShape::ExtendedShape(4, unextended_output_shape);
00119 
00120   NdArrayDesc<4> desc1;
00121   NdArrayDesc<4> desc2;
00122   NdArrayDescsForElementwiseBroadcast(unextended_input1_shape,
00123                                       unextended_input2_shape, &desc1, &desc2);
00124 
00125   for (int b = 0; b < output_shape.Dims(0); ++b) {
00126     for (int y = 0; y < output_shape.Dims(1); ++y) {
00127       for (int x = 0; x < output_shape.Dims(2); ++x) {
00128         for (int c = 0; c < output_shape.Dims(3); ++c) {
00129           output_data[Offset(output_shape, b, y, x, c)] =
00130               F(input1_data[SubscriptToIndex(desc1, b, y, x, c)],
00131                 input2_data[SubscriptToIndex(desc2, b, y, x, c)]);
00132         }
00133       }
00134     }
00135   }
00136 }
00137 template <ComparisonFn<float> F>
00138 inline void BroadcastComparison4DSlow(const ComparisonParams& op_params,
00139                                       const RuntimeShape& input1_shape,
00140                                       const float* input1_data,
00141                                       const RuntimeShape& input2_shape,
00142                                       const float* input2_data,
00143                                       const RuntimeShape& output_shape,
00144                                       bool* output_data) {
00145   BroadcastComparison4DSlowImpl<float, F>(op_params, input1_shape, input1_data,
00146                                           input2_shape, input2_data,
00147                                           output_shape, output_data);
00148 }
00149 
00150 template <typename T, ComparisonFn<int32> F>
00151 inline void BroadcastComparison4DSlowWithScaling(
00152     const ComparisonParams& op_params,
00153     const RuntimeShape& unextended_input1_shape, const T* input1_data,
00154     const RuntimeShape& unextended_input2_shape, const T* input2_data,
00155     const RuntimeShape& unextended_output_shape, bool* output_data) {
00156   TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), 4);
00157   TFLITE_DCHECK_LE(unextended_input2_shape.DimensionsCount(), 4);
00158   TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
00159   const RuntimeShape output_shape =
00160       RuntimeShape::ExtendedShape(4, unextended_output_shape);
00161 
00162   NdArrayDesc<4> desc1;
00163   NdArrayDesc<4> desc2;
00164   NdArrayDescsForElementwiseBroadcast(unextended_input1_shape,
00165                                       unextended_input2_shape, &desc1, &desc2);
00166 
00167   int left_shift = op_params.left_shift;
00168   int32 input1_offset = op_params.input1_offset;
00169   int32 input1_multiplier = op_params.input1_multiplier;
00170   int input1_shift = op_params.input1_shift;
00171   int32 input2_offset = op_params.input2_offset;
00172   int32 input2_multiplier = op_params.input2_multiplier;
00173   int input2_shift = op_params.input2_shift;
00174 
00175   for (int b = 0; b < output_shape.Dims(0); ++b) {
00176     for (int y = 0; y < output_shape.Dims(1); ++y) {
00177       for (int x = 0; x < output_shape.Dims(2); ++x) {
00178         for (int c = 0; c < output_shape.Dims(3); ++c) {
00179           const int32 input1_val =
00180               input1_offset + input1_data[SubscriptToIndex(desc1, b, y, x, c)];
00181           const int32 input2_val =
00182               input2_offset + input2_data[SubscriptToIndex(desc2, b, y, x, c)];
00183           const int32 shifted_input1_val = input1_val * (1 << left_shift);
00184           const int32 shifted_input2_val = input2_val * (1 << left_shift);
00185           const int32 scaled_input1_val =
00186               MultiplyByQuantizedMultiplierSmallerThanOneExp(
00187                   shifted_input1_val, input1_multiplier, input1_shift);
00188           const int32 scaled_input2_val =
00189               MultiplyByQuantizedMultiplierSmallerThanOneExp(
00190                   shifted_input2_val, input2_multiplier, input2_shift);
00191           output_data[Offset(output_shape, b, y, x, c)] =
00192               F(scaled_input1_val, scaled_input2_val);
00193         }
00194       }
00195     }
00196   }
00197 }
00198 
00199 #define TFLITE_COMPARISON_OP(name)                                             \
00200   inline void name(const ComparisonParams& op_params,                          \
00201                    const RuntimeShape& input1_shape, const float* input1_data, \
00202                    const RuntimeShape& input2_shape, const float* input2_data, \
00203                    const RuntimeShape& output_shape, bool* output_data) {      \
00204     Comparison<name##Fn>(op_params, input1_shape, input1_data, input2_shape,   \
00205                          input2_data, output_shape, output_data);              \
00206   }                                                                            \
00207   template <typename T>                                                        \
00208   inline void name##NoScaling(                                                 \
00209       const ComparisonParams& op_params, const RuntimeShape& input1_shape,     \
00210       const T* input1_data, const RuntimeShape& input2_shape,                  \
00211       const T* input2_data, const RuntimeShape& output_shape,                  \
00212       bool* output_data) {                                                     \
00213     ComparisonImpl<T, name##Fn>(op_params, input1_shape, input1_data,          \
00214                                 input2_shape, input2_data, output_shape,       \
00215                                 output_data);                                  \
00216   }                                                                            \
00217   template <typename T>                                                        \
00218   inline void name##WithScaling(                                               \
00219       const ComparisonParams& op_params, const RuntimeShape& input1_shape,     \
00220       const T* input1_data, const RuntimeShape& input2_shape,                  \
00221       const T* input2_data, const RuntimeShape& output_shape,                  \
00222       bool* output_data) {                                                     \
00223     ComparisonWithScaling<T, name##Fn>(op_params, input1_shape, input1_data,   \
00224                                        input2_shape, input2_data,              \
00225                                        output_shape, output_data);             \
00226   }                                                                            \
00227   template <typename T>                                                        \
00228   inline void Broadcast4DSlow##name##NoScaling(                                \
00229       const ComparisonParams& op_params, const RuntimeShape& input1_shape,     \
00230       const T* input1_data, const RuntimeShape& input2_shape,                  \
00231       const T* input2_data, const RuntimeShape& output_shape,                  \
00232       bool* output_data) {                                                     \
00233     BroadcastComparison4DSlowImpl<T, name##Fn>(                                \
00234         op_params, input1_shape, input1_data, input2_shape, input2_data,       \
00235         output_shape, output_data);                                            \
00236   }                                                                            \
00237   inline void Broadcast4DSlow##name(                                           \
00238       const ComparisonParams& op_params, const RuntimeShape& input1_shape,     \
00239       const float* input1_data, const RuntimeShape& input2_shape,              \
00240       const float* input2_data, const RuntimeShape& output_shape,              \
00241       bool* output_data) {                                                     \
00242     BroadcastComparison4DSlow<name##Fn>(op_params, input1_shape, input1_data,  \
00243                                         input2_shape, input2_data,             \
00244                                         output_shape, output_data);            \
00245   }                                                                            \
00246   template <typename T>                                                        \
00247   inline void Broadcast4DSlow##name##WithScaling(                              \
00248       const ComparisonParams& op_params, const RuntimeShape& input1_shape,     \
00249       const T* input1_data, const RuntimeShape& input2_shape,                  \
00250       const T* input2_data, const RuntimeShape& output_shape,                  \
00251       bool* output_data) {                                                     \
00252     BroadcastComparison4DSlowWithScaling<T, name##Fn>(                         \
00253         op_params, input1_shape, input1_data, input2_shape, input2_data,       \
00254         output_shape, output_data);                                            \
00255   }
00256 TFLITE_COMPARISON_OP(Equal);
00257 TFLITE_COMPARISON_OP(NotEqual);
00258 TFLITE_COMPARISON_OP(Greater);
00259 TFLITE_COMPARISON_OP(GreaterEqual);
00260 TFLITE_COMPARISON_OP(Less);
00261 TFLITE_COMPARISON_OP(LessEqual);
00262 #undef TFLITE_COMPARISON_OP
00263 
00264 }  // namespace reference_ops
00265 }  // namespace tflite
00266 
00267 #endif  // TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_COMPARISONS_H_