Important changes to repositories hosted on mbed.com
Mbed-hosted Mercurial repositories are deprecated and are scheduled to be permanently deleted in July 2026.
To keep a copy of this software, download the repository Zip archive or clone it locally using Mercurial.
It is also possible to export all your personal repositories from the account settings page.
comparisons.h
00001 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. 00002 00003 Licensed under the Apache License, Version 2.0 (the "License"); 00004 you may not use this file except in compliance with the License. 00005 You may obtain a copy of the License at 00006 00007 http://www.apache.org/licenses/LICENSE-2.0 00008 00009 Unless required by applicable law or agreed to in writing, software 00010 distributed under the License is distributed on an "AS IS" BASIS, 00011 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 00012 See the License for the specific language governing permissions and 00013 limitations under the License. 00014 ==============================================================================*/ 00015 #ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_COMPARISONS_H_ 00016 #define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_COMPARISONS_H_ 00017 00018 #include "tensorflow/lite/kernels/internal/common.h" 00019 #include "tensorflow/lite/kernels/internal/types.h" 00020 00021 namespace tflite { 00022 00023 namespace reference_ops { 00024 00025 template <typename T> 00026 inline bool EqualFn(T lhs, T rhs) { 00027 return lhs == rhs; 00028 } 00029 00030 template <typename T> 00031 inline bool NotEqualFn(T lhs, T rhs) { 00032 return lhs != rhs; 00033 } 00034 00035 template <typename T> 00036 inline bool GreaterFn(T lhs, T rhs) { 00037 return lhs > rhs; 00038 } 00039 template <typename T> 00040 inline bool GreaterEqualFn(T lhs, T rhs) { 00041 return lhs >= rhs; 00042 } 00043 template <typename T> 00044 inline bool LessFn(T lhs, T rhs) { 00045 return lhs < rhs; 00046 } 00047 template <typename T> 00048 inline bool LessEqualFn(T lhs, T rhs) { 00049 return lhs <= rhs; 00050 } 00051 00052 template <typename T> 00053 using ComparisonFn = bool (*)(T, T); 00054 00055 template <typename T, ComparisonFn<T> F> 00056 inline void ComparisonImpl( 00057 const ComparisonParams& op_params, const RuntimeShape& input1_shape, 00058 const T* input1_data, 
const RuntimeShape& input2_shape, 00059 const T* input2_data, const RuntimeShape& output_shape, bool* output_data) { 00060 const int64_t flatsize = 00061 MatchingFlatSize(input1_shape, input2_shape, output_shape); 00062 for (int64_t i = 0; i < flatsize; ++i) { 00063 output_data[i] = F(input1_data[i], input2_data[i]); 00064 } 00065 } 00066 00067 template <ComparisonFn<float> F> 00068 inline void Comparison(const ComparisonParams& op_params, 00069 const RuntimeShape& input1_shape, 00070 const float* input1_data, 00071 const RuntimeShape& input2_shape, 00072 const float* input2_data, 00073 const RuntimeShape& output_shape, bool* output_data) { 00074 ComparisonImpl<float, F>(op_params, input1_shape, input1_data, input2_shape, 00075 input2_data, output_shape, output_data); 00076 } 00077 00078 template <typename T, ComparisonFn<int32> F> 00079 inline void ComparisonWithScaling( 00080 const ComparisonParams& op_params, const RuntimeShape& input1_shape, 00081 const T* input1_data, const RuntimeShape& input2_shape, 00082 const T* input2_data, const RuntimeShape& output_shape, bool* output_data) { 00083 int left_shift = op_params.left_shift; 00084 int32 input1_offset = op_params.input1_offset; 00085 int32 input1_multiplier = op_params.input1_multiplier; 00086 int input1_shift = op_params.input1_shift; 00087 int32 input2_offset = op_params.input2_offset; 00088 int32 input2_multiplier = op_params.input2_multiplier; 00089 int input2_shift = op_params.input2_shift; 00090 00091 const int64_t flatsize = 00092 MatchingFlatSize(input1_shape, input2_shape, output_shape); 00093 for (int64_t i = 0; i < flatsize; ++i) { 00094 const int32 input1_val = input1_offset + input1_data[i]; 00095 const int32 input2_val = input2_offset + input2_data[i]; 00096 const int32 shifted_input1_val = input1_val * (1 << left_shift); 00097 const int32 shifted_input2_val = input2_val * (1 << left_shift); 00098 const int32 scaled_input1_val = 00099 MultiplyByQuantizedMultiplierSmallerThanOneExp( 00100 
shifted_input1_val, input1_multiplier, input1_shift); 00101 const int32 scaled_input2_val = 00102 MultiplyByQuantizedMultiplierSmallerThanOneExp( 00103 shifted_input2_val, input2_multiplier, input2_shift); 00104 output_data[i] = F(scaled_input1_val, scaled_input2_val); 00105 } 00106 } 00107 00108 template <typename T, ComparisonFn<T> F> 00109 inline void BroadcastComparison4DSlowImpl( 00110 const ComparisonParams& op_params, 00111 const RuntimeShape& unextended_input1_shape, const T* input1_data, 00112 const RuntimeShape& unextended_input2_shape, const T* input2_data, 00113 const RuntimeShape& unextended_output_shape, bool* output_data) { 00114 TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), 4); 00115 TFLITE_DCHECK_LE(unextended_input2_shape.DimensionsCount(), 4); 00116 TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4); 00117 const RuntimeShape output_shape = 00118 RuntimeShape::ExtendedShape(4, unextended_output_shape); 00119 00120 NdArrayDesc<4> desc1; 00121 NdArrayDesc<4> desc2; 00122 NdArrayDescsForElementwiseBroadcast(unextended_input1_shape, 00123 unextended_input2_shape, &desc1, &desc2); 00124 00125 for (int b = 0; b < output_shape.Dims(0); ++b) { 00126 for (int y = 0; y < output_shape.Dims(1); ++y) { 00127 for (int x = 0; x < output_shape.Dims(2); ++x) { 00128 for (int c = 0; c < output_shape.Dims(3); ++c) { 00129 output_data[Offset(output_shape, b, y, x, c)] = 00130 F(input1_data[SubscriptToIndex(desc1, b, y, x, c)], 00131 input2_data[SubscriptToIndex(desc2, b, y, x, c)]); 00132 } 00133 } 00134 } 00135 } 00136 } 00137 template <ComparisonFn<float> F> 00138 inline void BroadcastComparison4DSlow(const ComparisonParams& op_params, 00139 const RuntimeShape& input1_shape, 00140 const float* input1_data, 00141 const RuntimeShape& input2_shape, 00142 const float* input2_data, 00143 const RuntimeShape& output_shape, 00144 bool* output_data) { 00145 BroadcastComparison4DSlowImpl<float, F>(op_params, input1_shape, input1_data, 00146 
input2_shape, input2_data, 00147 output_shape, output_data); 00148 } 00149 00150 template <typename T, ComparisonFn<int32> F> 00151 inline void BroadcastComparison4DSlowWithScaling( 00152 const ComparisonParams& op_params, 00153 const RuntimeShape& unextended_input1_shape, const T* input1_data, 00154 const RuntimeShape& unextended_input2_shape, const T* input2_data, 00155 const RuntimeShape& unextended_output_shape, bool* output_data) { 00156 TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), 4); 00157 TFLITE_DCHECK_LE(unextended_input2_shape.DimensionsCount(), 4); 00158 TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4); 00159 const RuntimeShape output_shape = 00160 RuntimeShape::ExtendedShape(4, unextended_output_shape); 00161 00162 NdArrayDesc<4> desc1; 00163 NdArrayDesc<4> desc2; 00164 NdArrayDescsForElementwiseBroadcast(unextended_input1_shape, 00165 unextended_input2_shape, &desc1, &desc2); 00166 00167 int left_shift = op_params.left_shift; 00168 int32 input1_offset = op_params.input1_offset; 00169 int32 input1_multiplier = op_params.input1_multiplier; 00170 int input1_shift = op_params.input1_shift; 00171 int32 input2_offset = op_params.input2_offset; 00172 int32 input2_multiplier = op_params.input2_multiplier; 00173 int input2_shift = op_params.input2_shift; 00174 00175 for (int b = 0; b < output_shape.Dims(0); ++b) { 00176 for (int y = 0; y < output_shape.Dims(1); ++y) { 00177 for (int x = 0; x < output_shape.Dims(2); ++x) { 00178 for (int c = 0; c < output_shape.Dims(3); ++c) { 00179 const int32 input1_val = 00180 input1_offset + input1_data[SubscriptToIndex(desc1, b, y, x, c)]; 00181 const int32 input2_val = 00182 input2_offset + input2_data[SubscriptToIndex(desc2, b, y, x, c)]; 00183 const int32 shifted_input1_val = input1_val * (1 << left_shift); 00184 const int32 shifted_input2_val = input2_val * (1 << left_shift); 00185 const int32 scaled_input1_val = 00186 MultiplyByQuantizedMultiplierSmallerThanOneExp( 00187 shifted_input1_val, 
input1_multiplier, input1_shift); 00188 const int32 scaled_input2_val = 00189 MultiplyByQuantizedMultiplierSmallerThanOneExp( 00190 shifted_input2_val, input2_multiplier, input2_shift); 00191 output_data[Offset(output_shape, b, y, x, c)] = 00192 F(scaled_input1_val, scaled_input2_val); 00193 } 00194 } 00195 } 00196 } 00197 } 00198 00199 #define TFLITE_COMPARISON_OP(name) \ 00200 inline void name(const ComparisonParams& op_params, \ 00201 const RuntimeShape& input1_shape, const float* input1_data, \ 00202 const RuntimeShape& input2_shape, const float* input2_data, \ 00203 const RuntimeShape& output_shape, bool* output_data) { \ 00204 Comparison<name##Fn>(op_params, input1_shape, input1_data, input2_shape, \ 00205 input2_data, output_shape, output_data); \ 00206 } \ 00207 template <typename T> \ 00208 inline void name##NoScaling( \ 00209 const ComparisonParams& op_params, const RuntimeShape& input1_shape, \ 00210 const T* input1_data, const RuntimeShape& input2_shape, \ 00211 const T* input2_data, const RuntimeShape& output_shape, \ 00212 bool* output_data) { \ 00213 ComparisonImpl<T, name##Fn>(op_params, input1_shape, input1_data, \ 00214 input2_shape, input2_data, output_shape, \ 00215 output_data); \ 00216 } \ 00217 template <typename T> \ 00218 inline void name##WithScaling( \ 00219 const ComparisonParams& op_params, const RuntimeShape& input1_shape, \ 00220 const T* input1_data, const RuntimeShape& input2_shape, \ 00221 const T* input2_data, const RuntimeShape& output_shape, \ 00222 bool* output_data) { \ 00223 ComparisonWithScaling<T, name##Fn>(op_params, input1_shape, input1_data, \ 00224 input2_shape, input2_data, \ 00225 output_shape, output_data); \ 00226 } \ 00227 template <typename T> \ 00228 inline void Broadcast4DSlow##name##NoScaling( \ 00229 const ComparisonParams& op_params, const RuntimeShape& input1_shape, \ 00230 const T* input1_data, const RuntimeShape& input2_shape, \ 00231 const T* input2_data, const RuntimeShape& output_shape, \ 00232 bool* 
output_data) { \ 00233 BroadcastComparison4DSlowImpl<T, name##Fn>( \ 00234 op_params, input1_shape, input1_data, input2_shape, input2_data, \ 00235 output_shape, output_data); \ 00236 } \ 00237 inline void Broadcast4DSlow##name( \ 00238 const ComparisonParams& op_params, const RuntimeShape& input1_shape, \ 00239 const float* input1_data, const RuntimeShape& input2_shape, \ 00240 const float* input2_data, const RuntimeShape& output_shape, \ 00241 bool* output_data) { \ 00242 BroadcastComparison4DSlow<name##Fn>(op_params, input1_shape, input1_data, \ 00243 input2_shape, input2_data, \ 00244 output_shape, output_data); \ 00245 } \ 00246 template <typename T> \ 00247 inline void Broadcast4DSlow##name##WithScaling( \ 00248 const ComparisonParams& op_params, const RuntimeShape& input1_shape, \ 00249 const T* input1_data, const RuntimeShape& input2_shape, \ 00250 const T* input2_data, const RuntimeShape& output_shape, \ 00251 bool* output_data) { \ 00252 BroadcastComparison4DSlowWithScaling<T, name##Fn>( \ 00253 op_params, input1_shape, input1_data, input2_shape, input2_data, \ 00254 output_shape, output_data); \ 00255 } 00256 TFLITE_COMPARISON_OP(Equal); 00257 TFLITE_COMPARISON_OP(NotEqual); 00258 TFLITE_COMPARISON_OP(Greater); 00259 TFLITE_COMPARISON_OP(GreaterEqual); 00260 TFLITE_COMPARISON_OP(Less); 00261 TFLITE_COMPARISON_OP(LessEqual); 00262 #undef TFLITE_COMPARISON_OP 00263 00264 } // namespace reference_ops 00265 } // namespace tflite 00266 00267 #endif // TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_COMPARISONS_H_
Generated on Wed Jul 13 2022 16:03:35 by
