common.h
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_COMMON_H_
#define TENSORFLOW_LITE_KERNELS_INTERNAL_COMMON_H_

#ifndef ALLOW_SLOW_GENERIC_DEPTHWISECONV_FALLBACK
#ifdef GEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK
#define ALLOW_SLOW_GENERIC_DEPTHWISECONV_FALLBACK
#endif
#endif

#include "fixedpoint/fixedpoint.h"
#include "tensorflow/lite/kernels/internal/optimized/neon_check.h"
#include "tensorflow/lite/kernels/internal/types.h"

namespace tflite {

inline void GetActivationMinMax(FusedActivationFunctionType ac,
                                float* output_activation_min,
                                float* output_activation_max) {
  switch (ac) {
    case FusedActivationFunctionType::kNone:
      *output_activation_min = std::numeric_limits<float>::lowest();
      *output_activation_max = std::numeric_limits<float>::max();
      break;
    case FusedActivationFunctionType::kRelu:
      *output_activation_min = 0.f;
      *output_activation_max = std::numeric_limits<float>::max();
      break;
    case FusedActivationFunctionType::kRelu1:
      *output_activation_min = -1.f;
      *output_activation_max = 1.f;
      break;
    case FusedActivationFunctionType::kRelu6:
      *output_activation_min = 0.f;
      *output_activation_max = 6.f;
      break;
  }
}

inline float ActivationFunctionWithMinMax(float x, float output_activation_min,
                                          float output_activation_max) {
  return std::min(std::max(x, output_activation_min), output_activation_max);
}

// Legacy function, left for compatibility only.
template <FusedActivationFunctionType Ac>
float ActivationFunction(float x) {
  float output_activation_min, output_activation_max;
  GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
  return ActivationFunctionWithMinMax(x, output_activation_min,
                                      output_activation_max);
}

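// ---------------------------------------------------------------------------
// Editor's illustrative sketch (not part of the original TensorFlow Lite
// header): shows how GetActivationMinMax() and ActivationFunctionWithMinMax()
// combine to apply a fused ReLU6 clamp. The function name ApplyRelu6Example
// is hypothetical.
inline float ApplyRelu6Example(float x) {
  float act_min, act_max;
  GetActivationMinMax(FusedActivationFunctionType::kRelu6, &act_min, &act_max);
  // For kRelu6 this clamps x into [0.0f, 6.0f].
  return ActivationFunctionWithMinMax(x, act_min, act_max);
}
// ---------------------------------------------------------------------------
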
inline void BiasAndClamp(float clamp_min, float clamp_max, int bias_size,
                         const float* bias_data, int array_size,
                         float* array_data) {
  // Note: see b/132215220: in May 2019 we thought it would be OK to replace
  // this with the Eigen one-liner:
  //   return (array.colwise() + bias).cwiseMax(clamp_min).cwiseMin(clamp_max).
  // This turned out to severely regress performance: +4ms (i.e. 8%) on
  // MobileNet v2 / 1.0 / 224. So we keep custom NEON code for now.
  TFLITE_DCHECK_EQ((array_size % bias_size), 0);
#ifdef USE_NEON
  float* array_ptr = array_data;
  float* array_end_ptr = array_ptr + array_size;
  const auto clamp_min_vec = vdupq_n_f32(clamp_min);
  const auto clamp_max_vec = vdupq_n_f32(clamp_max);
  for (; array_ptr != array_end_ptr; array_ptr += bias_size) {
    int i = 0;
    for (; i <= bias_size - 16; i += 16) {
      auto b0 = vld1q_f32(bias_data + i);
      auto b1 = vld1q_f32(bias_data + i + 4);
      auto b2 = vld1q_f32(bias_data + i + 8);
      auto b3 = vld1q_f32(bias_data + i + 12);
      auto a0 = vld1q_f32(array_ptr + i);
      auto a1 = vld1q_f32(array_ptr + i + 4);
      auto a2 = vld1q_f32(array_ptr + i + 8);
      auto a3 = vld1q_f32(array_ptr + i + 12);
      auto x0 = vaddq_f32(a0, b0);
      auto x1 = vaddq_f32(a1, b1);
      auto x2 = vaddq_f32(a2, b2);
      auto x3 = vaddq_f32(a3, b3);
      x0 = vmaxq_f32(clamp_min_vec, x0);
      x1 = vmaxq_f32(clamp_min_vec, x1);
      x2 = vmaxq_f32(clamp_min_vec, x2);
      x3 = vmaxq_f32(clamp_min_vec, x3);
      x0 = vminq_f32(clamp_max_vec, x0);
      x1 = vminq_f32(clamp_max_vec, x1);
      x2 = vminq_f32(clamp_max_vec, x2);
      x3 = vminq_f32(clamp_max_vec, x3);
      vst1q_f32(array_ptr + i, x0);
      vst1q_f32(array_ptr + i + 4, x1);
      vst1q_f32(array_ptr + i + 8, x2);
      vst1q_f32(array_ptr + i + 12, x3);
    }
    for (; i <= bias_size - 4; i += 4) {
      auto b = vld1q_f32(bias_data + i);
      auto a = vld1q_f32(array_ptr + i);
      auto x = vaddq_f32(a, b);
      x = vmaxq_f32(clamp_min_vec, x);
      x = vminq_f32(clamp_max_vec, x);
      vst1q_f32(array_ptr + i, x);
    }
    for (; i < bias_size; i++) {
      array_ptr[i] = ActivationFunctionWithMinMax(array_ptr[i] + bias_data[i],
                                                  clamp_min, clamp_max);
    }
  }
#else  // not NEON
  for (int array_offset = 0; array_offset < array_size;
       array_offset += bias_size) {
    for (int i = 0; i < bias_size; i++) {
      array_data[array_offset + i] = ActivationFunctionWithMinMax(
          array_data[array_offset + i] + bias_data[i], clamp_min, clamp_max);
    }
  }
#endif
}

inline int32 MultiplyByQuantizedMultiplierSmallerThanOneExp(
    int32 x, int32 quantized_multiplier, int left_shift) {
  using gemmlowp::RoundingDivideByPOT;
  using gemmlowp::SaturatingRoundingDoublingHighMul;
  return RoundingDivideByPOT(
      SaturatingRoundingDoublingHighMul(x, quantized_multiplier), -left_shift);
}

inline int32 MultiplyByQuantizedMultiplierGreaterThanOne(
    int32 x, int32 quantized_multiplier, int left_shift) {
  using gemmlowp::SaturatingRoundingDoublingHighMul;
  return SaturatingRoundingDoublingHighMul(x * (1 << left_shift),
                                           quantized_multiplier);
}

inline int32 MultiplyByQuantizedMultiplier(int32 x, int32 quantized_multiplier,
                                           int shift) {
  using gemmlowp::RoundingDivideByPOT;
  using gemmlowp::SaturatingRoundingDoublingHighMul;
  int left_shift = shift > 0 ? shift : 0;
  int right_shift = shift > 0 ? 0 : -shift;
  return RoundingDivideByPOT(SaturatingRoundingDoublingHighMul(
                                 x * (1 << left_shift), quantized_multiplier),
                             right_shift);
}

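// ---------------------------------------------------------------------------
// Editor's illustrative sketch (not part of the original header): requantizing
// an int32 accumulator with MultiplyByQuantizedMultiplier(). The pair
// (quantized_multiplier, shift) is normally precomputed offline from a real
// multiplier (e.g. by TFLite's QuantizeMultiplier() helper in
// quantization_util.h). The values below encode a real multiplier of 0.5
// (0.5 * 2^31 = 1073741824, shift = 0), so the call returns approximately
// x / 2 with rounding. The function name RequantizeByHalfExample is
// hypothetical.
inline int32 RequantizeByHalfExample(int32 x) {
  const int32 quantized_multiplier = 1073741824;  // 0.5 in Q0.31.
  const int shift = 0;                            // No extra power-of-two scaling.
  return MultiplyByQuantizedMultiplier(x, quantized_multiplier, shift);
}
// ---------------------------------------------------------------------------
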
template <typename T>
int CountLeadingZeros(T integer_input) {
  static_assert(std::is_unsigned<T>::value,
                "Only unsigned integer types handled.");
#if defined(__GNUC__)
  return integer_input ? __builtin_clz(integer_input)
                       : std::numeric_limits<T>::digits;
#else
  if (integer_input == 0) {
    return std::numeric_limits<T>::digits;
  }

  const T one_in_leading_positive = static_cast<T>(1)
                                    << (std::numeric_limits<T>::digits - 1);
  int leading_zeros = 0;
  while (integer_input < one_in_leading_positive) {
    integer_input <<= 1;
    ++leading_zeros;
  }
  return leading_zeros;
#endif
}

template <typename T>
inline int CountLeadingSignBits(T integer_input) {
  static_assert(std::is_signed<T>::value, "Only signed integer types handled.");
#if defined(__GNUC__) && !defined(__clang__)
  return integer_input ? __builtin_clrsb(integer_input)
                       : std::numeric_limits<T>::digits;
#else
  using U = typename std::make_unsigned<T>::type;
  return integer_input >= 0
             ? CountLeadingZeros(static_cast<U>(integer_input)) - 1
             : integer_input != std::numeric_limits<T>::min()
                   ? CountLeadingZeros(2 * static_cast<U>(-integer_input) - 1)
                   : 0;
#endif
}

// TODO(b/77858996): Add these to gemmlowp.
template <typename IntegerType>
IntegerType SaturatingAddNonGemmlowp(IntegerType a, IntegerType b) {
  static_assert(std::is_same<IntegerType, void>::value, "unimplemented");
  return a;
}

template <>
inline std::int32_t SaturatingAddNonGemmlowp(std::int32_t a, std::int32_t b) {
  std::int64_t a64 = a;
  std::int64_t b64 = b;
  std::int64_t sum = a64 + b64;
  return static_cast<std::int32_t>(std::min(
      static_cast<std::int64_t>(std::numeric_limits<std::int32_t>::max()),
      std::max(
          static_cast<std::int64_t>(std::numeric_limits<std::int32_t>::min()),
          sum)));
}

template <typename tRawType, int tIntegerBits>
gemmlowp::FixedPoint<tRawType, tIntegerBits> SaturatingAddNonGemmlowp(
    gemmlowp::FixedPoint<tRawType, tIntegerBits> a,
    gemmlowp::FixedPoint<tRawType, tIntegerBits> b) {
  return gemmlowp::FixedPoint<tRawType, tIntegerBits>::FromRaw(
      SaturatingAddNonGemmlowp(a.raw(), b.raw()));
}

template <typename IntegerType>
IntegerType SaturatingSub(IntegerType a, IntegerType b) {
  static_assert(std::is_same<IntegerType, void>::value, "unimplemented");
  return a;
}

template <>
inline std::int16_t SaturatingSub(std::int16_t a, std::int16_t b) {
  std::int32_t a32 = a;
  std::int32_t b32 = b;
  std::int32_t diff = a32 - b32;
  return static_cast<std::int16_t>(
      std::min(static_cast<int32_t>(32767),
               std::max(static_cast<int32_t>(-32768), diff)));
}

template <>
inline std::int32_t SaturatingSub(std::int32_t a, std::int32_t b) {
  std::int64_t a64 = a;
  std::int64_t b64 = b;
  std::int64_t diff = a64 - b64;
  return static_cast<std::int32_t>(std::min(
      static_cast<std::int64_t>(std::numeric_limits<std::int32_t>::max()),
      std::max(
          static_cast<std::int64_t>(std::numeric_limits<std::int32_t>::min()),
          diff)));
}

template <typename tRawType, int tIntegerBits>
gemmlowp::FixedPoint<tRawType, tIntegerBits> SaturatingSub(
    gemmlowp::FixedPoint<tRawType, tIntegerBits> a,
    gemmlowp::FixedPoint<tRawType, tIntegerBits> b) {
  return gemmlowp::FixedPoint<tRawType, tIntegerBits>::FromRaw(
      SaturatingSub(a.raw(), b.raw()));
}
// End section to be moved to gemmlowp.

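// ---------------------------------------------------------------------------
// Editor's illustrative sketch (not part of the original header): expected
// behaviour of the helpers above on a few concrete values. The function name
// SaturationAndClzExamples is hypothetical.
inline void SaturationAndClzExamples() {
  // CountLeadingZeros is defined for unsigned types; zero maps to the full
  // bit width of the type.
  int clz_one = CountLeadingZeros(static_cast<uint32>(1));   // 31
  int clz_zero = CountLeadingZeros(static_cast<uint32>(0));  // 32
  // SaturatingSub clamps to the limits of the narrow type instead of wrapping:
  // -30000 - 10000 saturates to -32768 for int16.
  std::int16_t sat = SaturatingSub(std::int16_t(-30000), std::int16_t(10000));
  (void)clz_one;
  (void)clz_zero;
  (void)sat;
}
// ---------------------------------------------------------------------------
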
template <typename IntegerType>
IntegerType SaturatingRoundingMultiplyByPOTParam(IntegerType x, int exponent) {
  if (exponent == 0) {
    return x;
  }
  using ScalarIntegerType =
      typename gemmlowp::FixedPointRawTypeTraits<IntegerType>::ScalarRawType;
  const IntegerType min =
      gemmlowp::Dup<IntegerType>(std::numeric_limits<ScalarIntegerType>::min());
  const IntegerType max =
      gemmlowp::Dup<IntegerType>(std::numeric_limits<ScalarIntegerType>::max());
  const int ScalarIntegerTypeBits = 8 * sizeof(ScalarIntegerType);

  const std::int32_t threshold =
      ((1 << (ScalarIntegerTypeBits - 1 - exponent)) - 1);
  const IntegerType positive_mask =
      gemmlowp::MaskIfGreaterThan(x, gemmlowp::Dup<IntegerType>(threshold));
  const IntegerType negative_mask =
      gemmlowp::MaskIfLessThan(x, gemmlowp::Dup<IntegerType>(-threshold));

  IntegerType result = gemmlowp::ShiftLeft(x, exponent);
  result = gemmlowp::SelectUsingMask(positive_mask, max, result);
  result = gemmlowp::SelectUsingMask(negative_mask, min, result);
  return result;
}

// If we want to leave IntegerBits fixed, then multiplication
// by a power of two has to be saturating/rounding, not exact anymore.
template <typename tRawType, int tIntegerBits>
gemmlowp::FixedPoint<tRawType, tIntegerBits>
SaturatingRoundingMultiplyByPOTParam(
    gemmlowp::FixedPoint<tRawType, tIntegerBits> a, int exponent) {
  return gemmlowp::FixedPoint<tRawType, tIntegerBits>::FromRaw(
      SaturatingRoundingMultiplyByPOTParam(a.raw(), exponent));
}

// Minimum output bits to accommodate log of maximum input range. It actually
// does not matter if one considers, say, [-64,64] or [-64,64).
//
// For example, run this through Octave:
// [0:127; ...
//  ceil(log(abs( log(2.^(0:127))+1 ))/log(2)); ...
//  ceil(log(abs( log(2.^(0:127))+1 ))/log(2))]
constexpr int min_log_x_output_bits(int input_bits) {
  return input_bits > 90
             ? 7
             : input_bits > 44
                   ? 6
                   : input_bits > 21
                         ? 5
                         : input_bits > 10
                               ? 4
                               : input_bits > 4 ? 3 : input_bits > 1 ? 2 : 1;
}

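// ---------------------------------------------------------------------------
// Editor's illustrative check (not part of the original header): a few
// compile-time spot checks of min_log_x_output_bits(), matching the table
// encoded in the conditional above.
static_assert(min_log_x_output_bits(1) == 1, "");
static_assert(min_log_x_output_bits(5) == 3, "");
static_assert(min_log_x_output_bits(90) == 6, "");
static_assert(min_log_x_output_bits(91) == 7, "");
// ---------------------------------------------------------------------------
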
// Although currently the name of this function says that it cannot handle
// values less than 1, in practice it can handle as low as 1/x_max, where
// x_max is the largest representable input. In other words, the output range
// is symmetric.
template <int OutputIntegerBits, int InputIntegerBits>
inline gemmlowp::FixedPoint<int32, OutputIntegerBits>
log_x_for_x_greater_than_or_equal_to_1_impl(
    gemmlowp::FixedPoint<int32, InputIntegerBits> input_val) {
  // assert(__builtin_clz(0u) >= std::numeric_limits<uint32>::digits - 1);
  // assert(__builtin_clz(0u) <= std::numeric_limits<uint32>::digits);
  using FixedPoint0 = gemmlowp::FixedPoint<int32, 0>;
  // The reason for accumulating the result with an extra bit of headroom is
  // that z_pow_2_adj * log_2 might be saturated, and adding num_scaled *
  // recip_denom will otherwise introduce an error.
  static constexpr int kAccumIntegerBits = OutputIntegerBits + 1;
  using FixedPointAccum = gemmlowp::FixedPoint<int32, kAccumIntegerBits>;

  const FixedPoint0 log_2 = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(
      FixedPoint0, 1488522236, std::log(2.0));
  const FixedPoint0 sqrt_sqrt_half = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(
      FixedPoint0, 1805811301, std::sqrt(std::sqrt(0.5)));
  const FixedPoint0 sqrt_half = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(
      FixedPoint0, 1518500250, std::sqrt(0.5));
  const FixedPoint0 one_quarter =
      GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(FixedPoint0, 536870912, 1.0 / 4.0);

  const FixedPoint0 alpha_n = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(
      FixedPoint0, 117049297, 11.0 / 240.0 * std::sqrt(std::sqrt(2.0)));
  const FixedPoint0 alpha_d = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(
      FixedPoint0, 127690142, 1.0 / 20.0 * std::sqrt(std::sqrt(2.0)));
  const FixedPoint0 alpha_i = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(
      FixedPoint0, 1057819769,
      2.0 / std::sqrt(std::sqrt(2.0)) - std::sqrt(std::sqrt(2.0)));
  const FixedPoint0 alpha_f = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(
      FixedPoint0, 638450708, 1.0 / 4.0 * std::sqrt(std::sqrt(2.0)));

  const FixedPointAccum shifted_quarter =
      gemmlowp::Rescale<kAccumIntegerBits>(one_quarter);

  // Reinterpret the input value as Q0.31, because we will figure out the
  // required shift "ourselves" instead of using, say, Rescale.
  FixedPoint0 z_a = FixedPoint0::FromRaw(input_val.raw());
  // z_a_pow_2 = input_integer_bits - z_a_headroom;
  int z_a_headroom_plus_1 = CountLeadingZeros(static_cast<uint32>(z_a.raw()));
  FixedPoint0 r_a_tmp =
      SaturatingRoundingMultiplyByPOTParam(z_a, (z_a_headroom_plus_1 - 1));
  const int32 r_a_raw =
      SaturatingRoundingMultiplyByPOTParam((r_a_tmp * sqrt_half).raw(), 1);
  // z_pow_2_adj = max(z_pow_2_a - 0.75, z_pow_2_b - 0.25);
  // z_pow_2_adj = max(InputIntegerBits - z_a_headroom_plus_1 + 0.25,
  //                   InputIntegerBits - z_b_headroom - 0.25);
  const FixedPointAccum z_a_pow_2_adj = SaturatingAddNonGemmlowp(
      FixedPointAccum::FromRaw(SaturatingRoundingMultiplyByPOTParam(
          InputIntegerBits - z_a_headroom_plus_1, 31 - kAccumIntegerBits)),
      shifted_quarter);

  // z_b is treated like z_a, but premultiplying by sqrt(0.5).
  FixedPoint0 z_b = z_a * sqrt_half;
  int z_b_headroom = CountLeadingZeros(static_cast<uint32>(z_b.raw())) - 1;
  const int32 r_b_raw =
      SaturatingRoundingMultiplyByPOTParam(z_a.raw(), z_b_headroom);
  const FixedPointAccum z_b_pow_2_adj = SaturatingSub(
      FixedPointAccum::FromRaw(SaturatingRoundingMultiplyByPOTParam(
          InputIntegerBits - z_b_headroom, 31 - kAccumIntegerBits)),
      shifted_quarter);

  const FixedPoint0 r = FixedPoint0::FromRaw(std::min(r_a_raw, r_b_raw));
  const FixedPointAccum z_pow_2_adj = FixedPointAccum::FromRaw(
      std::max(z_a_pow_2_adj.raw(), z_b_pow_2_adj.raw()));

  const FixedPoint0 p = gemmlowp::RoundingHalfSum(r, sqrt_sqrt_half);
  FixedPoint0 q = r - sqrt_sqrt_half;
  q = q + q;

  const FixedPoint0 common_sq = q * q;
  const FixedPoint0 num = q * r + q * common_sq * alpha_n;
  const FixedPoint0 denom_minus_one_0 =
      p * (alpha_i + q + alpha_d * common_sq) + alpha_f * q;
  const FixedPoint0 recip_denom =
      one_over_one_plus_x_for_x_in_0_1(denom_minus_one_0);

  const FixedPointAccum num_scaled = gemmlowp::Rescale<kAccumIntegerBits>(num);
  return gemmlowp::Rescale<OutputIntegerBits>(z_pow_2_adj * log_2 +
                                              num_scaled * recip_denom);
}

template <int OutputIntegerBits, int InputIntegerBits>
inline gemmlowp::FixedPoint<int32, OutputIntegerBits>
log_x_for_x_greater_than_or_equal_to_1(
    gemmlowp::FixedPoint<int32, InputIntegerBits> input_val) {
  static_assert(
      OutputIntegerBits >= min_log_x_output_bits(InputIntegerBits),
      "Output integer bits must be sufficient to accommodate logs of inputs.");
  return log_x_for_x_greater_than_or_equal_to_1_impl<OutputIntegerBits,
                                                     InputIntegerBits>(
      input_val);
}

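// ---------------------------------------------------------------------------
// Editor's illustrative sketch (not part of the original header): taking the
// natural log of a Q5.26 fixed-point value known to be >= 1. Five output
// integer bits satisfy the static_assert above, since
// min_log_x_output_bits(5) == 3. The function name LogQ5_26Example is
// hypothetical.
inline int32 LogQ5_26Example(int32 raw_q5_26) {
  using FixedPointIn = gemmlowp::FixedPoint<int32, 5>;
  const auto log_val =
      log_x_for_x_greater_than_or_equal_to_1<5>(FixedPointIn::FromRaw(raw_q5_26));
  return log_val.raw();  // ln(x), also as a raw Q5.26 value.
}
// ---------------------------------------------------------------------------
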
inline int32 GetReciprocal(int32 x, int x_integer_digits,
                           int* num_bits_over_unit) {
  int headroom_plus_one = CountLeadingZeros(static_cast<uint32>(x));
  // This is the number of bits to the left of the binary point above 1.0.
  // Consider x=1.25. In that case shifted_scale=0.8 and
  // no later adjustment will be needed.
  *num_bits_over_unit = x_integer_digits - headroom_plus_one;
  const int32 shifted_sum_minus_one =
      static_cast<int32>((static_cast<uint32>(x) << headroom_plus_one) -
                         (static_cast<uint32>(1) << 31));

  gemmlowp::FixedPoint<int32, 0> shifted_scale =
      gemmlowp::one_over_one_plus_x_for_x_in_0_1(
          gemmlowp::FixedPoint<int32, 0>::FromRaw(shifted_sum_minus_one));
  return shifted_scale.raw();
}

inline void GetInvSqrtQuantizedMultiplierExp(int32 input, int reverse_shift,
                                             int32* output_inv_sqrt,
                                             int* output_shift) {
  *output_shift = 11;
  while (input >= (1 << 29)) {
    input /= 4;
    ++*output_shift;
  }
  TFLITE_DCHECK_GT(input, 0);
  const unsigned max_left_shift_bits =
      CountLeadingZeros(static_cast<uint32>(input)) - 1;
  const unsigned max_left_shift_bit_pairs = max_left_shift_bits / 2;
  const unsigned left_shift_bit_pairs = max_left_shift_bit_pairs - 1;
  *output_shift -= left_shift_bit_pairs;
  input <<= 2 * left_shift_bit_pairs;
  TFLITE_DCHECK_GE(input, (1 << 27));
  TFLITE_DCHECK_LT(input, (1 << 29));
  using gemmlowp::FixedPoint;
  using gemmlowp::Rescale;
  using gemmlowp::SaturatingRoundingMultiplyByPOT;
  // Using 3 integer bits gives us enough room for the internal arithmetic in
  // this Newton-Raphson iteration.
  using F3 = FixedPoint<int32, 3>;
  using F0 = FixedPoint<int32, 0>;
  const F3 fixedpoint_input = F3::FromRaw(input >> 1);
  const F3 fixedpoint_half_input =
      SaturatingRoundingMultiplyByPOT<-1>(fixedpoint_input);
  const F3 fixedpoint_half_three =
      GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F3, (1 << 28) + (1 << 27), 1.5);
  // Newton-Raphson iteration
  // Naive unoptimized starting guess: x = 1
  F3 x = F3::One();
  // Naive unoptimized number of iterations: 5
  for (int i = 0; i < 5; i++) {
    const F3 x3 = Rescale<3>(x * x * x);
    x = Rescale<3>(fixedpoint_half_three * x - fixedpoint_half_input * x3);
  }
  const F0 fixedpoint_half_sqrt_2 =
      GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F0, 1518500250, std::sqrt(2.) / 2.);
  x = x * fixedpoint_half_sqrt_2;
  *output_inv_sqrt = x.raw();
  if (*output_shift < 0) {
    *output_inv_sqrt <<= -*output_shift;
    *output_shift = 0;
  }
  // Convert right shift (right is positive) to left shift.
  *output_shift *= reverse_shift;
}

// DO NOT USE THIS STRUCT FOR NEW FUNCTIONALITY BEYOND IMPLEMENTING
// BROADCASTING.
//
// NdArrayDesc<N> describes the shape and memory layout of an N-dimensional
// rectangular array of numbers.
//
// NdArrayDesc<N> is basically identical to Dims<N> defined in types.h.
// However, as Dims<N> is to be deprecated, this class exists as an adaptor
// to enable simple unoptimized implementations of element-wise broadcasting
// operations.
template <int N>
struct NdArrayDesc {
  // The "extent" of each dimension. Indices along dimension d must be in the
  // half-open interval [0, extents[d]).
  int extents[N];

  // The number of *elements* (not bytes) between consecutive indices of each
  // dimension.
  int strides[N];
};

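// ---------------------------------------------------------------------------
// Editor's illustrative sketch (not part of the original header): an
// NdArrayDesc<4> describing a contiguous array of shape 2x3x4x5. Strides are
// in elements, with the last dimension innermost. The function name
// ContiguousDescExample is hypothetical.
inline NdArrayDesc<4> ContiguousDescExample() {
  NdArrayDesc<4> desc;
  desc.extents[0] = 2;
  desc.extents[1] = 3;
  desc.extents[2] = 4;
  desc.extents[3] = 5;
  desc.strides[0] = 3 * 4 * 5;  // 60
  desc.strides[1] = 4 * 5;      // 20
  desc.strides[2] = 5;
  desc.strides[3] = 1;
  return desc;
}
// ---------------------------------------------------------------------------
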
// DO NOT USE THIS FUNCTION FOR NEW FUNCTIONALITY BEYOND IMPLEMENTING
// BROADCASTING.
//
// Same as Offset(), except takes an NdArrayDesc<N> instead of Dims<N>.
inline int SubscriptToIndex(const NdArrayDesc<4>& desc, int i0, int i1, int i2,
                            int i3) {
  TFLITE_DCHECK(i0 >= 0 && i0 < desc.extents[0]);
  TFLITE_DCHECK(i1 >= 0 && i1 < desc.extents[1]);
  TFLITE_DCHECK(i2 >= 0 && i2 < desc.extents[2]);
  TFLITE_DCHECK(i3 >= 0 && i3 < desc.extents[3]);
  return i0 * desc.strides[0] + i1 * desc.strides[1] + i2 * desc.strides[2] +
         i3 * desc.strides[3];
}

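// ---------------------------------------------------------------------------
// Editor's illustrative sketch (not part of the original header), continuing
// the hypothetical ContiguousDescExample() above: for a contiguous 2x3x4x5
// descriptor, subscript (1, 2, 3, 4) maps to the flat offset
// 1*60 + 2*20 + 3*5 + 4*1 = 119. A stride of 0 in a dimension (as set up by
// the broadcast helpers below) makes every subscript of that dimension map to
// the same offset.
inline int ContiguousOffsetExample() {
  const NdArrayDesc<4> desc = ContiguousDescExample();
  return SubscriptToIndex(desc, 1, 2, 3, 4);  // 119
}
// ---------------------------------------------------------------------------
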
// Given the dimensions of the operands for an element-wise binary broadcast,
// adjusts them so that they can be directly iterated over with simple loops.
// Returns the adjusted dims as instances of NdArrayDesc in 'desc0_out' and
// 'desc1_out'. 'desc0_out' and 'desc1_out' cannot be nullptr.
//
// This function assumes that the two input shapes are compatible up to
// broadcasting and the shorter one has already been prepended with 1s to be
// the same length. E.g., if shape0 is (1, 16, 16, 64) and shape1 is (1, 64),
// shape1 must already have been prepended to be (1, 1, 1, 64). Recall that
// Dims<N> refer to shapes in reverse order. In this case, input0_dims will be
// (64, 16, 16, 1) and input1_dims will be (64, 1, 1, 1).
//
// When two shapes are compatible up to broadcasting, for each dimension d,
// the input extents are either equal, or one of them is 1.
//
// This function performs the following for each dimension d:
// - If the extents are equal, then do nothing since the loop that walks over
//   both of the input arrays is correct.
// - Otherwise, one (and only one) of the extents must be 1. Say extent0 is 1
//   and extent1 is e1. Then set extent0 to e1 and stride0 *to 0*. This allows
//   array0 to be referenced *at any index* in dimension d and still access the
//   same slice.
template <int N>
inline void NdArrayDescsForElementwiseBroadcast(const Dims<N>& input0_dims,
                                                const Dims<N>& input1_dims,
                                                NdArrayDesc<N>* desc0_out,
                                                NdArrayDesc<N>* desc1_out) {
  TFLITE_DCHECK(desc0_out != nullptr);
  TFLITE_DCHECK(desc1_out != nullptr);

  // Copy dims to desc.
  for (int i = 0; i < N; ++i) {
    desc0_out->extents[i] = input0_dims.sizes[i];
    desc0_out->strides[i] = input0_dims.strides[i];
    desc1_out->extents[i] = input1_dims.sizes[i];
    desc1_out->strides[i] = input1_dims.strides[i];
  }

  // Walk over each dimension. If the extents are equal do nothing.
  // Otherwise, set the desc with extent 1 to have extent equal to the other
  // and stride 0.
  for (int i = 0; i < N; ++i) {
    const int extent0 = ArraySize(input0_dims, i);
    const int extent1 = ArraySize(input1_dims, i);
    if (extent0 != extent1) {
      if (extent0 == 1) {
        desc0_out->strides[i] = 0;
        desc0_out->extents[i] = extent1;
      } else {
        TFLITE_DCHECK_EQ(extent1, 1);
        desc1_out->strides[i] = 0;
        desc1_out->extents[i] = extent0;
      }
    }
  }
}

template <int N>
inline void NdArrayDescsForElementwiseBroadcast(
    const RuntimeShape& input0_shape, const RuntimeShape& input1_shape,
    NdArrayDesc<N>* desc0_out, NdArrayDesc<N>* desc1_out) {
  TFLITE_DCHECK(desc0_out != nullptr);
  TFLITE_DCHECK(desc1_out != nullptr);

  auto extended_input0_shape = RuntimeShape::ExtendedShape(N, input0_shape);
  auto extended_input1_shape = RuntimeShape::ExtendedShape(N, input1_shape);

  // Copy dims to desc, calculating strides.
  int desc0_stride = 1;
  int desc1_stride = 1;
  for (int i = N - 1; i >= 0; --i) {
    desc0_out->extents[i] = extended_input0_shape.Dims(i);
    desc0_out->strides[i] = desc0_stride;
    desc0_stride *= extended_input0_shape.Dims(i);
    desc1_out->extents[i] = extended_input1_shape.Dims(i);
    desc1_out->strides[i] = desc1_stride;
    desc1_stride *= extended_input1_shape.Dims(i);
  }

  // Walk over each dimension. If the extents are equal do nothing.
  // Otherwise, set the desc with extent 1 to have extent equal to the other
  // and stride 0.
  for (int i = 0; i < N; ++i) {
    const int extent0 = extended_input0_shape.Dims(i);
    const int extent1 = extended_input1_shape.Dims(i);
    if (extent0 != extent1) {
      if (extent0 == 1) {
        desc0_out->strides[i] = 0;
        desc0_out->extents[i] = extent1;
      } else {
        TFLITE_DCHECK_EQ(extent1, 1);
        desc1_out->strides[i] = 0;
        desc1_out->extents[i] = extent0;
      }
    }
  }
}

// Copied from gemmlowp::RoundDown when we dropped direct dependency on
// gemmlowp.
//
// Returns the runtime argument rounded down to the nearest multiple of
// the fixed Modulus.
template <unsigned Modulus, typename Integer>
Integer RoundDown(Integer i) {
  return i - (i % Modulus);
}

// Copied from gemmlowp::RoundUp when we dropped direct dependency on
// gemmlowp.
//
// Returns the runtime argument rounded up to the nearest multiple of
// the fixed Modulus.
template <unsigned Modulus, typename Integer>
Integer RoundUp(Integer i) {
  return RoundDown<Modulus>(i + Modulus - 1);
}

// Copied from gemmlowp::CeilQuotient when we dropped direct dependency on
// gemmlowp.
//
// Returns the quotient a / b rounded up ('ceil') to the nearest integer.
template <typename Integer>
Integer CeilQuotient(Integer a, Integer b) {
  return (a + b - 1) / b;
}

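// ---------------------------------------------------------------------------
// Editor's illustrative sketch (not part of the original header): the typical
// element-wise broadcast pattern built from the helpers above, here for a
// float Add. It is a simplified version of what the reference kernels do and
// assumes Offset() and RuntimeShape from types.h; the function name
// BroadcastAddExample is hypothetical.
inline void BroadcastAddExample(const RuntimeShape& shape0, const float* data0,
                                const RuntimeShape& shape1, const float* data1,
                                const RuntimeShape& output_shape,
                                float* output_data) {
  NdArrayDesc<4> desc0;
  NdArrayDesc<4> desc1;
  NdArrayDescsForElementwiseBroadcast(shape0, shape1, &desc0, &desc1);
  const RuntimeShape extended_output =
      RuntimeShape::ExtendedShape(4, output_shape);
  for (int b = 0; b < extended_output.Dims(0); ++b) {
    for (int y = 0; y < extended_output.Dims(1); ++y) {
      for (int x = 0; x < extended_output.Dims(2); ++x) {
        for (int c = 0; c < extended_output.Dims(3); ++c) {
          // Broadcast dimensions have stride 0 in their desc, so both inputs
          // can be indexed with the full output subscript.
          output_data[Offset(extended_output, b, y, x, c)] =
              data0[SubscriptToIndex(desc0, b, y, x, c)] +
              data1[SubscriptToIndex(desc1, b, y, x, c)];
        }
      }
    }
  }
}
// ---------------------------------------------------------------------------
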
// This function is a copy of gemmlowp::HowManyThreads, copied when we dropped
// the direct dependency of internal/optimized/ on gemmlowp.
//
// It computes a reasonable number of threads to use for a GEMM of shape
// (rows, cols, depth).
//
// TODO(b/131910176): get rid of this function by switching each call site
// to its own more sensible logic for its own workload.
template <int KernelRows>
inline int LegacyHowManyThreads(int max_num_threads, int rows, int cols,
                                int depth) {
  // Early-exit in the default case where multi-threading is disabled.
  if (max_num_threads == 1) {
    return 1;
  }

  // Ensure that each thread has KernelRows rows to process, if at all
  // possible.
  int thread_count = std::min(max_num_threads, rows / KernelRows);

  // Limit the number of threads according to the overall size of the problem.
  if (thread_count > 1) {
    // Empirically determined value.
    static constexpr std::uint64_t min_cubic_size_per_thread = 64 * 1024;

    // We can only multiply two out of three sizes without risking overflow
    const std::uint64_t cubic_size =
        std::uint64_t(rows) * std::uint64_t(cols) * std::uint64_t(depth);

    thread_count = std::min(
        thread_count, static_cast<int>(cubic_size / min_cubic_size_per_thread));
  }

  if (thread_count < 1) {
    thread_count = 1;
  }

  assert(thread_count > 0 && thread_count <= max_num_threads);
  return thread_count;
}

template <typename T>
void optimized_ops_preload_l1_stream(const T* ptr) {
#ifdef __GNUC__
  // builtin offered by GCC-compatible compilers including clang
  __builtin_prefetch(ptr, /* 0 means read */ 0, /* 0 means no locality */ 0);
#else
  (void)ptr;
#endif
}

template <typename T>
void optimized_ops_preload_l1_keep(const T* ptr) {
#ifdef __GNUC__
  // builtin offered by GCC-compatible compilers including clang
  __builtin_prefetch(ptr, /* 0 means read */ 0, /* 3 means high locality */ 3);
#else
  (void)ptr;
#endif
}

template <typename T>
void optimized_ops_prefetch_write_l1_keep(const T* ptr) {
#ifdef __GNUC__
  // builtin offered by GCC-compatible compilers including clang
  __builtin_prefetch(ptr, /* 1 means write */ 1, /* 3 means high locality */ 3);
#else
  (void)ptr;
#endif
}

}  // namespace tflite

#endif  // TENSORFLOW_LITE_KERNELS_INTERNAL_COMMON_H_
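
// ---------------------------------------------------------------------------
// Editor's illustrative sketch (not part of common.h): choosing a GEMM thread
// count with LegacyHowManyThreads<> and prefetching the first cache line of
// each operand. KernelRows = 8 and the function name PrepareGemmExample are
// hypothetical; a snippet like this would normally live in a .cc file that
// includes this header, and is appended here purely for illustration.
inline int PrepareGemmExample(int max_num_threads, int rows, int cols,
                              int depth, const float* lhs, const float* rhs) {
  tflite::optimized_ops_preload_l1_keep(lhs);
  tflite::optimized_ops_preload_l1_keep(rhs);
  // Returns a value in [1, max_num_threads], scaled down for small problems.
  return tflite::LegacyHowManyThreads<8>(max_num_threads, rows, cols, depth);
}
// ---------------------------------------------------------------------------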