Daniel Konegen / MNIST_example

Dependencies: mbed-os

common.h

/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_COMMON_H_
#define TENSORFLOW_LITE_KERNELS_INTERNAL_COMMON_H_

#ifndef ALLOW_SLOW_GENERIC_DEPTHWISECONV_FALLBACK
#ifdef GEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK
#define ALLOW_SLOW_GENERIC_DEPTHWISECONV_FALLBACK
#endif
#endif

#include "fixedpoint/fixedpoint.h"
#include "tensorflow/lite/kernels/internal/optimized/neon_check.h"
#include "tensorflow/lite/kernels/internal/types.h"

namespace tflite {
inline void GetActivationMinMax(FusedActivationFunctionType ac,
                                float* output_activation_min,
                                float* output_activation_max) {
  switch (ac) {
    case FusedActivationFunctionType::kNone:
      *output_activation_min = std::numeric_limits<float>::lowest();
      *output_activation_max = std::numeric_limits<float>::max();
      break;
    case FusedActivationFunctionType::kRelu:
      *output_activation_min = 0.f;
      *output_activation_max = std::numeric_limits<float>::max();
      break;
    case FusedActivationFunctionType::kRelu1:
      *output_activation_min = -1.f;
      *output_activation_max = 1.f;
      break;
    case FusedActivationFunctionType::kRelu6:
      *output_activation_min = 0.f;
      *output_activation_max = 6.f;
      break;
  }
}
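
// [Editor's example, not part of the upstream header] A minimal sketch of how
// these bounds are meant to be used together with ActivationFunctionWithMinMax
// (defined just below), assuming a fused Relu6 activation:
//
//   float act_min, act_max;
//   GetActivationMinMax(FusedActivationFunctionType::kRelu6, &act_min,
//                       &act_max);
//   float y = ActivationFunctionWithMinMax(7.5f, act_min, act_max);
//   // act_min == 0.f, act_max == 6.f, so y is clamped to 6.f.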

inline float ActivationFunctionWithMinMax(float x, float output_activation_min,
                                          float output_activation_max) {
  return std::min(std::max(x, output_activation_min), output_activation_max);
}

// Legacy function, left for compatibility only.
template <FusedActivationFunctionType Ac>
float ActivationFunction(float x) {
  float output_activation_min, output_activation_max;
  GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
  return ActivationFunctionWithMinMax(x, output_activation_min,
                                      output_activation_max);
}

inline void BiasAndClamp(float clamp_min, float clamp_max, int bias_size,
                         const float* bias_data, int array_size,
                         float* array_data) {
  // Note: see b/132215220: in May 2019 we thought it would be OK to replace
  // this with the Eigen one-liner:
  //   return (array.colwise() + bias).cwiseMax(clamp_min).cwiseMin(clamp_max);
  // This turned out to severely regress performance: +4ms (i.e. 8%) on
  // MobileNet v2 / 1.0 / 224. So we keep custom NEON code for now.
  TFLITE_DCHECK_EQ((array_size % bias_size), 0);
#ifdef USE_NEON
  float* array_ptr = array_data;
  float* array_end_ptr = array_ptr + array_size;
  const auto clamp_min_vec = vdupq_n_f32(clamp_min);
  const auto clamp_max_vec = vdupq_n_f32(clamp_max);
  for (; array_ptr != array_end_ptr; array_ptr += bias_size) {
    int i = 0;
    for (; i <= bias_size - 16; i += 16) {
      auto b0 = vld1q_f32(bias_data + i);
      auto b1 = vld1q_f32(bias_data + i + 4);
      auto b2 = vld1q_f32(bias_data + i + 8);
      auto b3 = vld1q_f32(bias_data + i + 12);
      auto a0 = vld1q_f32(array_ptr + i);
      auto a1 = vld1q_f32(array_ptr + i + 4);
      auto a2 = vld1q_f32(array_ptr + i + 8);
      auto a3 = vld1q_f32(array_ptr + i + 12);
      auto x0 = vaddq_f32(a0, b0);
      auto x1 = vaddq_f32(a1, b1);
      auto x2 = vaddq_f32(a2, b2);
      auto x3 = vaddq_f32(a3, b3);
      x0 = vmaxq_f32(clamp_min_vec, x0);
      x1 = vmaxq_f32(clamp_min_vec, x1);
      x2 = vmaxq_f32(clamp_min_vec, x2);
      x3 = vmaxq_f32(clamp_min_vec, x3);
      x0 = vminq_f32(clamp_max_vec, x0);
      x1 = vminq_f32(clamp_max_vec, x1);
      x2 = vminq_f32(clamp_max_vec, x2);
      x3 = vminq_f32(clamp_max_vec, x3);
      vst1q_f32(array_ptr + i, x0);
      vst1q_f32(array_ptr + i + 4, x1);
      vst1q_f32(array_ptr + i + 8, x2);
      vst1q_f32(array_ptr + i + 12, x3);
    }
    for (; i <= bias_size - 4; i += 4) {
      auto b = vld1q_f32(bias_data + i);
      auto a = vld1q_f32(array_ptr + i);
      auto x = vaddq_f32(a, b);
      x = vmaxq_f32(clamp_min_vec, x);
      x = vminq_f32(clamp_max_vec, x);
      vst1q_f32(array_ptr + i, x);
    }
    for (; i < bias_size; i++) {
      array_ptr[i] = ActivationFunctionWithMinMax(array_ptr[i] + bias_data[i],
                                                  clamp_min, clamp_max);
    }
  }
#else  // not NEON
  for (int array_offset = 0; array_offset < array_size;
       array_offset += bias_size) {
    for (int i = 0; i < bias_size; i++) {
      array_data[array_offset + i] = ActivationFunctionWithMinMax(
          array_data[array_offset + i] + bias_data[i], clamp_min, clamp_max);
    }
  }
#endif
}
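
// [Editor's example, not part of the upstream header] BiasAndClamp requires
// array_size to be a whole multiple of bias_size: the bias vector is added to
// each consecutive bias_size-long chunk of the array, then the result is
// clamped. A small sketch:
//
//   float bias[2] = {1.f, -1.f};
//   float data[4] = {0.f, 10.f, 3.f, 4.f};
//   BiasAndClamp(/*clamp_min=*/0.f, /*clamp_max=*/6.f, /*bias_size=*/2, bias,
//                /*array_size=*/4, data);
//   // data is now {1.f, 6.f, 4.f, 3.f}.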

inline int32 MultiplyByQuantizedMultiplierSmallerThanOneExp(
    int32 x, int32 quantized_multiplier, int left_shift) {
  using gemmlowp::RoundingDivideByPOT;
  using gemmlowp::SaturatingRoundingDoublingHighMul;
  return RoundingDivideByPOT(
      SaturatingRoundingDoublingHighMul(x, quantized_multiplier), -left_shift);
}
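
// [Editor's note, not part of the upstream header] Here quantized_multiplier
// is a Q0.31 value and left_shift is a base-2 exponent, expected to be <= 0
// for a real multiplier smaller than one: the pair encodes
// quantized_multiplier * 2^(left_shift - 31). A worked example, under that
// assumption:
//
//   // Real multiplier 0.25 = 0.5 * 2^-1, so quantized_multiplier = 1 << 30
//   // (0.5 in Q0.31) and left_shift = -1.
//   int32 y = MultiplyByQuantizedMultiplierSmallerThanOneExp(80, 1 << 30, -1);
//   // y == 20, i.e. 80 * 0.25.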

inline int32 MultiplyByQuantizedMultiplierGreaterThanOne(
    int32 x, int32 quantized_multiplier, int left_shift) {
  using gemmlowp::SaturatingRoundingDoublingHighMul;
  return SaturatingRoundingDoublingHighMul(x * (1 << left_shift),
                                           quantized_multiplier);
}

inline int32 MultiplyByQuantizedMultiplier(int32 x, int32 quantized_multiplier,
                                           int shift) {
  using gemmlowp::RoundingDivideByPOT;
  using gemmlowp::SaturatingRoundingDoublingHighMul;
  int left_shift = shift > 0 ? shift : 0;
  int right_shift = shift > 0 ? 0 : -shift;
  return RoundingDivideByPOT(SaturatingRoundingDoublingHighMul(
                                 x * (1 << left_shift), quantized_multiplier),
                             right_shift);
}
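
// [Editor's note, not part of the upstream header] In the general form above,
// a positive shift is applied as a left shift before the doubling high-mul and
// a negative shift as a rounding right shift after it. A worked example:
//
//   // quantized_multiplier = 1 << 30 encodes 0.5 in Q0.31; with shift = 1 the
//   // effective multiplier is 0.5 * 2^1 = 1.0.
//   int32 y = MultiplyByQuantizedMultiplier(100, 1 << 30, 1);
//   // y == 100.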

template <typename T>
int CountLeadingZeros(T integer_input) {
  static_assert(std::is_unsigned<T>::value,
                "Only unsigned integer types handled.");
#if defined(__GNUC__)
  return integer_input ? __builtin_clz(integer_input)
                       : std::numeric_limits<T>::digits;
#else
  if (integer_input == 0) {
    return std::numeric_limits<T>::digits;
  }

  const T one_in_leading_positive = static_cast<T>(1)
                                    << (std::numeric_limits<T>::digits - 1);
  int leading_zeros = 0;
  while (integer_input < one_in_leading_positive) {
    integer_input <<= 1;
    ++leading_zeros;
  }
  return leading_zeros;
#endif
}
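
// [Editor's example, not part of the upstream header]
//   CountLeadingZeros(static_cast<uint32>(1)) == 31  // only bit 0 is set
//   CountLeadingZeros(static_cast<uint32>(0)) == 32  // zero -> digits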

template <typename T>
inline int CountLeadingSignBits(T integer_input) {
  static_assert(std::is_signed<T>::value, "Only signed integer types handled.");
#if defined(__GNUC__) && !defined(__clang__)
  return integer_input ? __builtin_clrsb(integer_input)
                       : std::numeric_limits<T>::digits;
#else
  using U = typename std::make_unsigned<T>::type;
  return integer_input >= 0
             ? CountLeadingZeros(static_cast<U>(integer_input)) - 1
             : integer_input != std::numeric_limits<T>::min()
                   ? CountLeadingZeros(2 * static_cast<U>(-integer_input) - 1)
                   : 0;
#endif
}

// TODO(b/77858996): Add these to gemmlowp.
template <typename IntegerType>
IntegerType SaturatingAddNonGemmlowp(IntegerType a, IntegerType b) {
  static_assert(std::is_same<IntegerType, void>::value, "unimplemented");
  return a;
}

template <>
inline std::int32_t SaturatingAddNonGemmlowp(std::int32_t a, std::int32_t b) {
  std::int64_t a64 = a;
  std::int64_t b64 = b;
  std::int64_t sum = a64 + b64;
  return static_cast<std::int32_t>(std::min(
      static_cast<std::int64_t>(std::numeric_limits<std::int32_t>::max()),
      std::max(
          static_cast<std::int64_t>(std::numeric_limits<std::int32_t>::min()),
          sum)));
}

template <typename tRawType, int tIntegerBits>
gemmlowp::FixedPoint<tRawType, tIntegerBits> SaturatingAddNonGemmlowp(
    gemmlowp::FixedPoint<tRawType, tIntegerBits> a,
    gemmlowp::FixedPoint<tRawType, tIntegerBits> b) {
  return gemmlowp::FixedPoint<tRawType, tIntegerBits>::FromRaw(
      SaturatingAddNonGemmlowp(a.raw(), b.raw()));
}

template <typename IntegerType>
IntegerType SaturatingSub(IntegerType a, IntegerType b) {
  static_assert(std::is_same<IntegerType, void>::value, "unimplemented");
  return a;
}

template <>
inline std::int16_t SaturatingSub(std::int16_t a, std::int16_t b) {
  std::int32_t a32 = a;
  std::int32_t b32 = b;
  std::int32_t diff = a32 - b32;
  return static_cast<std::int16_t>(
      std::min(static_cast<int32_t>(32767),
               std::max(static_cast<int32_t>(-32768), diff)));
}

template <>
inline std::int32_t SaturatingSub(std::int32_t a, std::int32_t b) {
  std::int64_t a64 = a;
  std::int64_t b64 = b;
  std::int64_t diff = a64 - b64;
  return static_cast<std::int32_t>(std::min(
      static_cast<std::int64_t>(std::numeric_limits<std::int32_t>::max()),
      std::max(
          static_cast<std::int64_t>(std::numeric_limits<std::int32_t>::min()),
          diff)));
}

template <typename tRawType, int tIntegerBits>
gemmlowp::FixedPoint<tRawType, tIntegerBits> SaturatingSub(
    gemmlowp::FixedPoint<tRawType, tIntegerBits> a,
    gemmlowp::FixedPoint<tRawType, tIntegerBits> b) {
  return gemmlowp::FixedPoint<tRawType, tIntegerBits>::FromRaw(
      SaturatingSub(a.raw(), b.raw()));
}
// End section to be moved to gemmlowp.
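
// [Editor's example, not part of the upstream header] The saturating helpers
// clamp instead of wrapping on overflow:
//
//   SaturatingAddNonGemmlowp(std::int32_t(2147483000), std::int32_t(1000));
//   // == 2147483647 (INT32_MAX), not a wrapped negative value.
//   SaturatingSub(std::int16_t(-30000), std::int16_t(10000));
//   // == -32768 (INT16_MIN).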

template <typename IntegerType>
IntegerType SaturatingRoundingMultiplyByPOTParam(IntegerType x, int exponent) {
  if (exponent == 0) {
    return x;
  }
  using ScalarIntegerType =
      typename gemmlowp::FixedPointRawTypeTraits<IntegerType>::ScalarRawType;
  const IntegerType min =
      gemmlowp::Dup<IntegerType>(std::numeric_limits<ScalarIntegerType>::min());
  const IntegerType max =
      gemmlowp::Dup<IntegerType>(std::numeric_limits<ScalarIntegerType>::max());
  const int ScalarIntegerTypeBits = 8 * sizeof(ScalarIntegerType);

  const std::int32_t threshold =
      ((1 << (ScalarIntegerTypeBits - 1 - exponent)) - 1);
  const IntegerType positive_mask =
      gemmlowp::MaskIfGreaterThan(x, gemmlowp::Dup<IntegerType>(threshold));
  const IntegerType negative_mask =
      gemmlowp::MaskIfLessThan(x, gemmlowp::Dup<IntegerType>(-threshold));

  IntegerType result = gemmlowp::ShiftLeft(x, exponent);
  result = gemmlowp::SelectUsingMask(positive_mask, max, result);
  result = gemmlowp::SelectUsingMask(negative_mask, min, result);
  return result;
}

// If we want to leave IntegerBits fixed, then multiplication
// by a power of two has to be saturating/rounding, not exact anymore.
template <typename tRawType, int tIntegerBits>
gemmlowp::FixedPoint<tRawType, tIntegerBits>
SaturatingRoundingMultiplyByPOTParam(
    gemmlowp::FixedPoint<tRawType, tIntegerBits> a, int exponent) {
  return gemmlowp::FixedPoint<tRawType, tIntegerBits>::FromRaw(
      SaturatingRoundingMultiplyByPOTParam(a.raw(), exponent));
}
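
// [Editor's example, not part of the upstream header] For scalar int32 and
// exponent 3, the saturation threshold above is (1 << (32 - 1 - 3)) - 1 =
// (1 << 28) - 1, so:
//
//   SaturatingRoundingMultiplyByPOTParam(std::int32_t(1) << 20, 3);
//   // == 1 << 23 (exact shift, no saturation)
//   SaturatingRoundingMultiplyByPOTParam(std::int32_t(1) << 29, 3);
//   // == 2147483647 (saturated to INT32_MAX)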

// Minimum output bits to accommodate log of maximum input range.  It actually
// does not matter if one considers, say, [-64,64] or [-64,64).
//
// For example, run this through Octave:
// [0:127; ...
//  ceil(log(abs( log(2.^(0:127))+1 ))/log(2)); ...
//  ceil(log(abs( log(2.^(0:127))+1 ))/log(2))]
constexpr int min_log_x_output_bits(int input_bits) {
  return input_bits > 90
             ? 7
             : input_bits > 44
                   ? 6
                   : input_bits > 21
                         ? 5
                         : input_bits > 10
                               ? 4
                               : input_bits > 4 ? 3 : input_bits > 1 ? 2 : 1;
}
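
// [Editor's example, not part of the upstream header] With 5 integer bits the
// largest input is just under 2^5 = 32, and ln(32) ~= 3.47 needs 3 integer
// bits, so min_log_x_output_bits(5) == 3.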

// Although currently the name of this function says that it cannot handle
// values less than 1, in practice it can handle as low as 1/x_max, where
// x_max is the largest representable input.  In other words, the output range
// is symmetric.
template <int OutputIntegerBits, int InputIntegerBits>
inline gemmlowp::FixedPoint<int32, OutputIntegerBits>
log_x_for_x_greater_than_or_equal_to_1_impl(
    gemmlowp::FixedPoint<int32, InputIntegerBits> input_val) {
  // assert(__builtin_clz(0u) >= std::numeric_limits<uint32>::digits - 1);
  // assert(__builtin_clz(0u) <= std::numeric_limits<uint32>::digits);
  using FixedPoint0 = gemmlowp::FixedPoint<int32, 0>;
  // The reason for accumulating the result with an extra bit of headroom is
  // that z_pow_2_adj * log_2 might be saturated, and adding num_scaled *
  // recip_denom will otherwise introduce an error.
  static constexpr int kAccumIntegerBits = OutputIntegerBits + 1;
  using FixedPointAccum = gemmlowp::FixedPoint<int32, kAccumIntegerBits>;

  const FixedPoint0 log_2 = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(
      FixedPoint0, 1488522236, std::log(2.0));
  const FixedPoint0 sqrt_sqrt_half = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(
      FixedPoint0, 1805811301, std::sqrt(std::sqrt(0.5)));
  const FixedPoint0 sqrt_half = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(
      FixedPoint0, 1518500250, std::sqrt(0.5));
  const FixedPoint0 one_quarter =
      GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(FixedPoint0, 536870912, 1.0 / 4.0);

  const FixedPoint0 alpha_n = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(
      FixedPoint0, 117049297, 11.0 / 240.0 * std::sqrt(std::sqrt(2.0)));
  const FixedPoint0 alpha_d = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(
      FixedPoint0, 127690142, 1.0 / 20.0 * std::sqrt(std::sqrt(2.0)));
  const FixedPoint0 alpha_i = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(
      FixedPoint0, 1057819769,
      2.0 / std::sqrt(std::sqrt(2.0)) - std::sqrt(std::sqrt(2.0)));
  const FixedPoint0 alpha_f = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(
      FixedPoint0, 638450708, 1.0 / 4.0 * std::sqrt(std::sqrt(2.0)));

  const FixedPointAccum shifted_quarter =
      gemmlowp::Rescale<kAccumIntegerBits>(one_quarter);

  // Reinterpret the input value as Q0.31, because we will figure out the
  // required shift "ourselves" instead of using, say, Rescale.
  FixedPoint0 z_a = FixedPoint0::FromRaw(input_val.raw());
  // z_a_pow_2 = input_integer_bits - z_a_headroom;
  int z_a_headroom_plus_1 = CountLeadingZeros(static_cast<uint32>(z_a.raw()));
  FixedPoint0 r_a_tmp =
      SaturatingRoundingMultiplyByPOTParam(z_a, (z_a_headroom_plus_1 - 1));
  const int32 r_a_raw =
      SaturatingRoundingMultiplyByPOTParam((r_a_tmp * sqrt_half).raw(), 1);
  // z_pow_2_adj = max(z_pow_2_a - 0.75, z_pow_2_b - 0.25);
  // z_pow_2_adj = max(InputIntegerBits - z_a_headroom_plus_1 + 0.25,
  //                   InputIntegerBits - z_b_headroom - 0.25);
  const FixedPointAccum z_a_pow_2_adj = SaturatingAddNonGemmlowp(
      FixedPointAccum::FromRaw(SaturatingRoundingMultiplyByPOTParam(
          InputIntegerBits - z_a_headroom_plus_1, 31 - kAccumIntegerBits)),
      shifted_quarter);

  // z_b is treated like z_a, but premultiplying by sqrt(0.5).
  FixedPoint0 z_b = z_a * sqrt_half;
  int z_b_headroom = CountLeadingZeros(static_cast<uint32>(z_b.raw())) - 1;
  const int32 r_b_raw =
      SaturatingRoundingMultiplyByPOTParam(z_a.raw(), z_b_headroom);
  const FixedPointAccum z_b_pow_2_adj = SaturatingSub(
      FixedPointAccum::FromRaw(SaturatingRoundingMultiplyByPOTParam(
          InputIntegerBits - z_b_headroom, 31 - kAccumIntegerBits)),
      shifted_quarter);

  const FixedPoint0 r = FixedPoint0::FromRaw(std::min(r_a_raw, r_b_raw));
  const FixedPointAccum z_pow_2_adj = FixedPointAccum::FromRaw(
      std::max(z_a_pow_2_adj.raw(), z_b_pow_2_adj.raw()));

  const FixedPoint0 p = gemmlowp::RoundingHalfSum(r, sqrt_sqrt_half);
  FixedPoint0 q = r - sqrt_sqrt_half;
  q = q + q;

  const FixedPoint0 common_sq = q * q;
  const FixedPoint0 num = q * r + q * common_sq * alpha_n;
  const FixedPoint0 denom_minus_one_0 =
      p * (alpha_i + q + alpha_d * common_sq) + alpha_f * q;
  const FixedPoint0 recip_denom =
      one_over_one_plus_x_for_x_in_0_1(denom_minus_one_0);

  const FixedPointAccum num_scaled = gemmlowp::Rescale<kAccumIntegerBits>(num);
  return gemmlowp::Rescale<OutputIntegerBits>(z_pow_2_adj * log_2 +
                                              num_scaled * recip_denom);
}

template <int OutputIntegerBits, int InputIntegerBits>
inline gemmlowp::FixedPoint<int32, OutputIntegerBits>
log_x_for_x_greater_than_or_equal_to_1(
    gemmlowp::FixedPoint<int32, InputIntegerBits> input_val) {
  static_assert(
      OutputIntegerBits >= min_log_x_output_bits(InputIntegerBits),
      "Output integer bits must be sufficient to accommodate logs of inputs.");
  return log_x_for_x_greater_than_or_equal_to_1_impl<OutputIntegerBits,
                                                     InputIntegerBits>(
      input_val);
}
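
// [Editor's usage sketch, not part of the upstream header] InputIntegerBits is
// deduced from the argument, so only OutputIntegerBits needs to be spelled
// out:
//
//   using In = gemmlowp::FixedPoint<int32, 5>;  // Q5.26 input
//   auto ln_x = log_x_for_x_greater_than_or_equal_to_1<3>(In::FromDouble(8.0));
//   // ln_x holds ln(8) ~= 2.079 in Q3.28.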

inline int32 GetReciprocal(int32 x, int x_integer_digits,
                           int* num_bits_over_unit) {
  int headroom_plus_one = CountLeadingZeros(static_cast<uint32>(x));
  // This is the number of bits to the left of the binary point above 1.0.
  // Consider x=1.25.  In that case shifted_scale=0.8 and
  // no later adjustment will be needed.
  *num_bits_over_unit = x_integer_digits - headroom_plus_one;
  const int32 shifted_sum_minus_one =
      static_cast<int32>((static_cast<uint32>(x) << headroom_plus_one) -
                         (static_cast<uint32>(1) << 31));

  gemmlowp::FixedPoint<int32, 0> shifted_scale =
      gemmlowp::one_over_one_plus_x_for_x_in_0_1(
          gemmlowp::FixedPoint<int32, 0>::FromRaw(shifted_sum_minus_one));
  return shifted_scale.raw();
}
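
// [Editor's worked example, not part of the upstream header] Continuing the
// x = 1.25 case from the comment above: with x = 0x14000000 in Q3.28
// (x_integer_digits == 3), headroom_plus_one == 3, so num_bits_over_unit == 0
// and the returned raw value represents 1/1.25 == 0.8 in Q0.31; the reciprocal
// is then 0.8 * 2^-num_bits_over_unit == 0.8.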

inline void GetInvSqrtQuantizedMultiplierExp(int32 input, int reverse_shift,
                                             int32* output_inv_sqrt,
                                             int* output_shift) {
  *output_shift = 11;
  while (input >= (1 << 29)) {
    input /= 4;
    ++*output_shift;
  }
  TFLITE_DCHECK_GT(input, 0);
  const unsigned max_left_shift_bits =
      CountLeadingZeros(static_cast<uint32>(input)) - 1;
  const unsigned max_left_shift_bit_pairs = max_left_shift_bits / 2;
  const unsigned left_shift_bit_pairs = max_left_shift_bit_pairs - 1;
  *output_shift -= left_shift_bit_pairs;
  input <<= 2 * left_shift_bit_pairs;
  TFLITE_DCHECK_GE(input, (1 << 27));
  TFLITE_DCHECK_LT(input, (1 << 29));
  using gemmlowp::FixedPoint;
  using gemmlowp::Rescale;
  using gemmlowp::SaturatingRoundingMultiplyByPOT;
  // Using 3 integer bits gives us enough room for the internal arithmetic in
  // this Newton-Raphson iteration.
  using F3 = FixedPoint<int32, 3>;
  using F0 = FixedPoint<int32, 0>;
  const F3 fixedpoint_input = F3::FromRaw(input >> 1);
  const F3 fixedpoint_half_input =
      SaturatingRoundingMultiplyByPOT<-1>(fixedpoint_input);
  const F3 fixedpoint_half_three =
      GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F3, (1 << 28) + (1 << 27), 1.5);
  // Newton-Raphson iteration
  // Naive unoptimized starting guess: x = 1
  F3 x = F3::One();
  // Naive unoptimized number of iterations: 5
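  // [Editor's note] Each pass below applies the Newton-Raphson step for
  // f(x) = 1/x^2 - input, i.e. x <- x * (3 - input * x^2) / 2, written here
  // as 1.5 * x - (input / 2) * x^3.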
  for (int i = 0; i < 5; i++) {
    const F3 x3 = Rescale<3>(x * x * x);
    x = Rescale<3>(fixedpoint_half_three * x - fixedpoint_half_input * x3);
  }
  const F0 fixedpoint_half_sqrt_2 =
      GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F0, 1518500250, std::sqrt(2.) / 2.);
  x = x * fixedpoint_half_sqrt_2;
  *output_inv_sqrt = x.raw();
  if (*output_shift < 0) {
    *output_inv_sqrt <<= -*output_shift;
    *output_shift = 0;
  }
  // Convert right shift (right is positive) to left shift.
  *output_shift *= reverse_shift;
}

// DO NOT USE THIS STRUCT FOR NEW FUNCTIONALITY BEYOND IMPLEMENTING
// BROADCASTING.
//
// NdArrayDesc<N> describes the shape and memory layout of an N-dimensional
// rectangular array of numbers.
//
// NdArrayDesc<N> is basically identical to Dims<N> defined in types.h.
// However, as Dims<N> is to be deprecated, this class exists as an adaptor
// to enable simple unoptimized implementations of element-wise broadcasting
// operations.
template <int N>
struct NdArrayDesc {
  // The "extent" of each dimension. Indices along dimension d must be in the
  // half-open interval [0, extents[d]).
  int extents[N];

  // The number of *elements* (not bytes) between consecutive indices of each
  // dimension.
  int strides[N];
};

// DO NOT USE THIS FUNCTION FOR NEW FUNCTIONALITY BEYOND IMPLEMENTING
// BROADCASTING.
//
// Same as Offset(), except it takes an NdArrayDesc<N> instead of a Dims<N>.
inline int SubscriptToIndex(const NdArrayDesc<4>& desc, int i0, int i1, int i2,
                            int i3) {
  TFLITE_DCHECK(i0 >= 0 && i0 < desc.extents[0]);
  TFLITE_DCHECK(i1 >= 0 && i1 < desc.extents[1]);
  TFLITE_DCHECK(i2 >= 0 && i2 < desc.extents[2]);
  TFLITE_DCHECK(i3 >= 0 && i3 < desc.extents[3]);
  return i0 * desc.strides[0] + i1 * desc.strides[1] + i2 * desc.strides[2] +
         i3 * desc.strides[3];
}
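
// [Editor's example, not part of the upstream header] For a contiguous
// 1x4x4x3 array, extents would be {1, 4, 4, 3} and strides {48, 12, 3, 1}, so
// SubscriptToIndex(desc, 0, 1, 2, 0) == 1 * 12 + 2 * 3 == 18.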

// Given the dimensions of the operands for an element-wise binary broadcast,
// adjusts them so that they can be directly iterated over with simple loops.
// Returns the adjusted dims as instances of NdArrayDesc in 'desc0_out' and
// 'desc1_out'. 'desc0_out' and 'desc1_out' cannot be nullptr.
//
// This function assumes that the two input shapes are compatible up to
// broadcasting and the shorter one has already been prepended with 1s to be the
// same length. E.g., if shape0 is (1, 16, 16, 64) and shape1 is (1, 64),
// shape1 must already have been prepended to be (1, 1, 1, 64). Recall that
// Dims<N> refer to shapes in reverse order. In this case, input0_dims will be
// (64, 16, 16, 1) and input1_dims will be (64, 1, 1, 1).
//
// When two shapes are compatible up to broadcasting, for each dimension d,
// the input extents are either equal, or one of them is 1.
//
// This function performs the following for each dimension d:
// - If the extents are equal, then do nothing since the loop that walks over
//   both of the input arrays is correct.
// - Otherwise, one (and only one) of the extents must be 1. Say extent0 is 1
//   and extent1 is e1. Then set extent0 to e1 and stride0 *to 0*. This allows
//   array0 to be referenced *at any index* in dimension d and still access the
//   same slice.
template <int N>
inline void NdArrayDescsForElementwiseBroadcast(const Dims<N>& input0_dims,
                                                const Dims<N>& input1_dims,
                                                NdArrayDesc<N>* desc0_out,
                                                NdArrayDesc<N>* desc1_out) {
  TFLITE_DCHECK(desc0_out != nullptr);
  TFLITE_DCHECK(desc1_out != nullptr);

  // Copy dims to desc.
  for (int i = 0; i < N; ++i) {
    desc0_out->extents[i] = input0_dims.sizes[i];
    desc0_out->strides[i] = input0_dims.strides[i];
    desc1_out->extents[i] = input1_dims.sizes[i];
    desc1_out->strides[i] = input1_dims.strides[i];
  }

  // Walk over each dimension. If the extents are equal do nothing.
  // Otherwise, set the desc with extent 1 to have extent equal to the other and
  // stride 0.
  for (int i = 0; i < N; ++i) {
    const int extent0 = ArraySize(input0_dims, i);
    const int extent1 = ArraySize(input1_dims, i);
    if (extent0 != extent1) {
      if (extent0 == 1) {
        desc0_out->strides[i] = 0;
        desc0_out->extents[i] = extent1;
      } else {
        TFLITE_DCHECK_EQ(extent1, 1);
        desc1_out->strides[i] = 0;
        desc1_out->extents[i] = extent0;
      }
    }
  }
}
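
// [Editor's example, not part of the upstream header] Following the shapes in
// the comment above: with input0_dims of (64, 16, 16, 1) and input1_dims of
// (64, 1, 1, 1), the two middle dimensions of desc1_out come out with extent
// 16 and stride 0, so any index in those dimensions reads the same 64-element
// slice of input1.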

template <int N>
inline void NdArrayDescsForElementwiseBroadcast(
    const RuntimeShape& input0_shape, const RuntimeShape& input1_shape,
    NdArrayDesc<N>* desc0_out, NdArrayDesc<N>* desc1_out) {
  TFLITE_DCHECK(desc0_out != nullptr);
  TFLITE_DCHECK(desc1_out != nullptr);

  auto extended_input0_shape = RuntimeShape::ExtendedShape(N, input0_shape);
  auto extended_input1_shape = RuntimeShape::ExtendedShape(N, input1_shape);

  // Copy dims to desc, calculating strides.
  int desc0_stride = 1;
  int desc1_stride = 1;
  for (int i = N - 1; i >= 0; --i) {
    desc0_out->extents[i] = extended_input0_shape.Dims(i);
    desc0_out->strides[i] = desc0_stride;
    desc0_stride *= extended_input0_shape.Dims(i);
    desc1_out->extents[i] = extended_input1_shape.Dims(i);
    desc1_out->strides[i] = desc1_stride;
    desc1_stride *= extended_input1_shape.Dims(i);
  }

  // Walk over each dimension. If the extents are equal do nothing.
  // Otherwise, set the desc with extent 1 to have extent equal to the other and
  // stride 0.
  for (int i = 0; i < N; ++i) {
    const int extent0 = extended_input0_shape.Dims(i);
    const int extent1 = extended_input1_shape.Dims(i);
    if (extent0 != extent1) {
      if (extent0 == 1) {
        desc0_out->strides[i] = 0;
        desc0_out->extents[i] = extent1;
      } else {
        TFLITE_DCHECK_EQ(extent1, 1);
        desc1_out->strides[i] = 0;
        desc1_out->extents[i] = extent0;
      }
    }
  }
}

// Copied from gemmlowp::RoundDown when we dropped direct dependency on
// gemmlowp.
//
// Returns the runtime argument rounded down to the nearest multiple of
// the fixed Modulus.
template <unsigned Modulus, typename Integer>
Integer RoundDown(Integer i) {
  return i - (i % Modulus);
}

// Copied from gemmlowp::RoundUp when we dropped direct dependency on
// gemmlowp.
//
// Returns the runtime argument rounded up to the nearest multiple of
// the fixed Modulus.
template <unsigned Modulus, typename Integer>
Integer RoundUp(Integer i) {
  return RoundDown<Modulus>(i + Modulus - 1);
}

// Copied from gemmlowp::CeilQuotient when we dropped direct dependency on
// gemmlowp.
//
// Returns the quotient a / b rounded up ('ceil') to the nearest integer.
template <typename Integer>
Integer CeilQuotient(Integer a, Integer b) {
  return (a + b - 1) / b;
}
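
// [Editor's examples, not part of the upstream header]
//   RoundDown<8>(13) == 8
//   RoundUp<8>(13) == 16
//   CeilQuotient(13, 4) == 4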

// This function is a copy of gemmlowp::HowManyThreads, copied when we dropped
// the direct dependency of internal/optimized/ on gemmlowp.
//
// It computes a reasonable number of threads to use for a GEMM of shape
// (rows, cols, depth).
//
// TODO(b/131910176): get rid of this function by switching each call site
// to its own more sensible logic for its own workload.
template <int KernelRows>
inline int LegacyHowManyThreads(int max_num_threads, int rows, int cols,
                                int depth) {
  // Early-exit in the default case where multi-threading is disabled.
  if (max_num_threads == 1) {
    return 1;
  }

  // Ensure that each thread has KernelRows rows to process, if at all possible.
  int thread_count = std::min(max_num_threads, rows / KernelRows);

  // Limit the number of threads according to the overall size of the problem.
  if (thread_count > 1) {
    // Empirically determined value.
    static constexpr std::uint64_t min_cubic_size_per_thread = 64 * 1024;

    // We can only multiply two out of three sizes without risking overflow.
    const std::uint64_t cubic_size =
        std::uint64_t(rows) * std::uint64_t(cols) * std::uint64_t(depth);

    thread_count = std::min(
        thread_count, static_cast<int>(cubic_size / min_cubic_size_per_thread));
  }

  if (thread_count < 1) {
    thread_count = 1;
  }

  assert(thread_count > 0 && thread_count <= max_num_threads);
  return thread_count;
}
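
// [Editor's worked example, not part of the upstream header] With
// KernelRows = 8, max_num_threads = 4 and a 1024 x 1024 x 64 GEMM:
// rows / KernelRows = 128, so thread_count starts at min(4, 128) = 4;
// cubic_size = 2^26 = 67108864, and 67108864 / 65536 = 1024 keeps it at 4.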

template <typename T>
void optimized_ops_preload_l1_stream(const T* ptr) {
#ifdef __GNUC__
  // builtin offered by GCC-compatible compilers including clang
  __builtin_prefetch(ptr, /* 0 means read */ 0, /* 0 means no locality */ 0);
#else
  (void)ptr;
#endif
}

template <typename T>
void optimized_ops_preload_l1_keep(const T* ptr) {
#ifdef __GNUC__
  // builtin offered by GCC-compatible compilers including clang
  __builtin_prefetch(ptr, /* 0 means read */ 0, /* 3 means high locality */ 3);
#else
  (void)ptr;
#endif
}

template <typename T>
void optimized_ops_prefetch_write_l1_keep(const T* ptr) {
#ifdef __GNUC__
  // builtin offered by GCC-compatible compilers including clang
  __builtin_prefetch(ptr, /* 1 means write */ 1, /* 3 means high locality */ 3);
#else
  (void)ptr;
#endif
}

}  // namespace tflite

#endif  // TENSORFLOW_LITE_KERNELS_INTERNAL_COMMON_H_