Daniel Konegen / MNIST_example

Dependencies:   mbed-os

Embed: (wiki syntax)

« Back to documentation index

Show/hide line numbers fixedpoint.h Source File

fixedpoint.h

00001 // Copyright 2015 The Gemmlowp Authors. All Rights Reserved.
00002 //
00003 // Licensed under the Apache License, Version 2.0 (the "License");
00004 // you may not use this file except in compliance with the License.
00005 // You may obtain a copy of the License at
00006 //
00007 //     http://www.apache.org/licenses/LICENSE-2.0
00008 //
00009 // Unless required by applicable law or agreed to in writing, software
00010 // distributed under the License is distributed on an "AS IS" BASIS,
00011 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
00012 // See the License for the specific language governing permissions and
00013 // limitations under the License.
00014 
00015 // fixedpoint.h: fixed-point arithmetic, with basic operations and
00016 // a few math functions such as tanh.
00017 
00018 #ifndef GEMMLOWP_INTERNAL_FIXEDPOINT_H_
00019 #define GEMMLOWP_INTERNAL_FIXEDPOINT_H_
00020 
00021 #include <algorithm>
00022 #include <cassert>
00023 #include <cmath>
00024 #include <cstdint>
00025 #include <limits>
00026 
00027 #include "../internal/detect_platform.h"
00028 
00029 namespace gemmlowp {
00030 
00031 // Part 1: Low-level integer-arithmetic primitives.
00032 // The implementations here are generic implementations valid for
00033 // scalar types (e.g. std::int32_t). Architecture-specific SIMD types
00034 // (e.g. NEON int32x4_t) may be supported by providing
00035 // specializations for them in separate files.
00036 //
00037 // The purpose of these primitives is two-fold:
00038 //  - They will be used to implement higher-level fixed-point
00039 //    abstractions, namely the FixedPoint class and its arithmetic
00040 //    operators.
00041 //  - They will be directly used to implement some more involved
00042 //    fixed-point computations, e.g. the fixed-point implementation
00043 //    of math functions such as tanh.
00044 
// Compile-time description of a raw storage type usable by FixedPoint.
// Each supported raw type exposes:
//   - ScalarRawType: the integer type held in a single lane;
//   - kLanes: the number of such lanes (1 for plain scalar integers).
// Architecture-specific SIMD raw types are expected to add their own
// specializations in separate files. The primary template is intentionally
// empty, so naming the traits of an unsupported raw type fails to compile.
template <typename tIntegerType>
struct FixedPointRawTypeTraits {};

// Plain 32-bit integers: one 32-bit lane.
template <>
struct FixedPointRawTypeTraits<std::int32_t> {
  using ScalarRawType = std::int32_t;
  static constexpr int kLanes = 1;
};

// Plain 16-bit integers: one 16-bit lane.
template <>
struct FixedPointRawTypeTraits<std::int16_t> {
  using ScalarRawType = std::int16_t;
  static constexpr int kLanes = 1;
};
00061 
00062 // Returns a SIMD value duplicating a scalar value across all lanes.
00063 template <typename tRawType>
00064 tRawType Dup(typename FixedPointRawTypeTraits<tRawType>::ScalarRawType x) {
00065   return x;
00066 }
00067 
// Bit-wise primitives on raw integer values. Each forwards to the matching
// built-in operator and converts the (possibly int-promoted) result back to
// tIntegerType.

// Bit-wise AND of `a` and `b`.
template <typename tIntegerType>
tIntegerType BitAnd(tIntegerType a, tIntegerType b) {
  return static_cast<tIntegerType>(a & b);
}

// Bit-wise OR of `a` and `b`.
template <typename tIntegerType>
tIntegerType BitOr(tIntegerType a, tIntegerType b) {
  return static_cast<tIntegerType>(a | b);
}

// Bit-wise exclusive OR of `a` and `b`.
template <typename tIntegerType>
tIntegerType BitXor(tIntegerType a, tIntegerType b) {
  return static_cast<tIntegerType>(a ^ b);
}

// Bit-wise complement of `a`.
template <typename tIntegerType>
tIntegerType BitNot(tIntegerType a) {
  return static_cast<tIntegerType>(~a);
}
00091 
// Integer addition, not saturating. For std::int32_t, overflow is undefined
// behavior exactly as with the built-in `+`. Narrower types are promoted to
// int first, so their sum cannot overflow; an out-of-range sum is instead
// narrowed when converted back to tIntegerType.
template <typename tIntegerType>
tIntegerType Add(tIntegerType a, tIntegerType b) {
  return static_cast<tIntegerType>(a + b);
}
00097 
// Integer multiplication, not saturating. For std::int32_t, overflow is
// undefined behavior. Narrower types are promoted to int before multiplying;
// an out-of-range product is narrowed when converted back to tIntegerType.
template <typename tIntegerType>
tIntegerType Mul(tIntegerType a, tIntegerType b) {
  return a * b;
}

// Integer subtraction, not saturating. For std::int32_t, overflow is
// undefined behavior. Narrower types are promoted to int before subtracting;
// an out-of-range difference is narrowed when converted back to tIntegerType.
template <typename tIntegerType>
tIntegerType Sub(tIntegerType a, tIntegerType b) {
  return a - b;
}
00108 
// Integer negation (unary minus), not saturating. For std::int32_t,
// negating the minimum value overflows, which is undefined behavior.
// Narrower types are promoted to int first, and the result is narrowed back.
template <typename tIntegerType>
tIntegerType Neg(tIntegerType a) {
  const tIntegerType negated = -a;
  return negated;
}
00114 
// Integer arithmetic left-shift, equivalent to multiplying `a` by 2^offset.
// Negative values of `a` are OK. Overflowing results do not cause undefined
// behavior. They are currently saturated to tIntegerType's min/max, but
// callers should not rely on that exact value. The intent is that callers
// handle the overflowing cases themselves (e.g. with compare-and-mask
// saturation), so only the absence of undefined behavior matters here.
//
// tIntegerType may be int32 or any narrower signed type. `offset` must be
// non-negative.
template <typename tIntegerType>
tIntegerType ShiftLeft(tIntegerType a, int offset) {
  assert(offset >= 0);
  const tIntegerType min = std::numeric_limits<tIntegerType>::min();
  const tIntegerType max = std::numeric_limits<tIntegerType>::max();
  const std::int64_t wide_a = static_cast<std::int64_t>(a);
  if (offset >= 32) {
    // |a| <= 2^31 for every supported type, so any non-zero `a` overflows.
    // Saturate by sign without ever forming 2^offset (which could itself
    // overflow 64 bits).
    return wide_a > 0 ? max : wide_a < 0 ? min : a;
  }
  // Use a 64-bit multiplier: `1 << offset` on a plain int overflows at
  // offset == 31. With offset <= 31 and |a| <= 2^31, the product fits in
  // 64 bits.
  const std::int64_t wide_shifted = wide_a * (std::int64_t{1} << offset);
  if (wide_shifted < min) {
    return min;
  }
  if (wide_shifted > max) {
    return max;
  }
  return static_cast<tIntegerType>(wide_shifted);
}
00135 
// Integer arithmetic right-shift by `offset` bits, with no rounding to
// nearest. Right-shifting a negative signed value is implementation-defined
// before C++20. In practice compilers sign-extend, so the result is
// floor(a / 2^offset), and this function relies on that.
template <typename tIntegerType>
tIntegerType ShiftRight(tIntegerType a, int offset) {
  const tIntegerType shifted = a >> offset;
  return shifted;
}
00143 
// Bit-wise select: every bit of the result comes from `then_val` where the
// corresponding bit of `if_mask` is set, and from `else_val` where it is
// clear. Equivalent to the VBSL instruction in ARM NEON.
template <typename tIntegerType>
tIntegerType SelectUsingMask(tIntegerType if_mask, tIntegerType then_val,
                             tIntegerType else_val) {
  const tIntegerType from_then = BitAnd(if_mask, then_val);
  const tIntegerType from_else = BitAnd(BitNot(if_mask), else_val);
  // The two halves never share a set bit, so XOR merges them.
  return BitXor(from_then, from_else);
}
00152 
// Mask-producing primitives. Each one returns a value with all bits set when
// its condition holds, and all bits clear otherwise. The result can
// therefore be combined directly with BitAnd / SelectUsingMask.

// Condition: `a` is non-zero.
template <typename tIntegerType>
tIntegerType MaskIfNonZero(tIntegerType a) {
  static constexpr tIntegerType zero = 0;
  if (!a) {
    return zero;
  }
  return BitNot(zero);
}

// Condition: `a` is zero.
template <typename tIntegerType>
tIntegerType MaskIfZero(tIntegerType a) {
  const bool holds = !a;
  return MaskIfNonZero<tIntegerType>(holds);
}

// Condition: a == b.
template <typename tIntegerType>
tIntegerType MaskIfEqual(tIntegerType a, tIntegerType b) {
  const bool holds = a == b;
  return MaskIfNonZero<tIntegerType>(holds);
}

// Condition: a != b.
template <typename tIntegerType>
tIntegerType MaskIfNotEqual(tIntegerType a, tIntegerType b) {
  const bool holds = a != b;
  return MaskIfNonZero<tIntegerType>(holds);
}

// Condition: a > b.
template <typename tIntegerType>
tIntegerType MaskIfGreaterThan(tIntegerType a, tIntegerType b) {
  const bool holds = a > b;
  return MaskIfNonZero<tIntegerType>(holds);
}

// Condition: a >= b.
template <typename tIntegerType>
tIntegerType MaskIfGreaterThanOrEqual(tIntegerType a, tIntegerType b) {
  const bool holds = a >= b;
  return MaskIfNonZero<tIntegerType>(holds);
}

// Condition: a < b.
template <typename tIntegerType>
tIntegerType MaskIfLessThan(tIntegerType a, tIntegerType b) {
  const bool holds = a < b;
  return MaskIfNonZero<tIntegerType>(holds);
}

// Condition: a <= b.
template <typename tIntegerType>
tIntegerType MaskIfLessThanOrEqual(tIntegerType a, tIntegerType b) {
  const bool holds = a <= b;
  return MaskIfNonZero<tIntegerType>(holds);
}
00209 
// Mask reductions. A plain scalar has a single lane, so both functions just
// test the value for non-zero. Callers must pass masks in which each lane has
// either all or none of its bits set; the behavior for other inputs is
// currently undefined.

// True if every lane of `a` is non-zero.
template <typename tIntegerType>
bool All(tIntegerType a) {
  return static_cast<bool>(a);
}

// True if at least one lane of `a` is non-zero.
template <typename tIntegerType>
bool Any(tIntegerType a) {
  return static_cast<bool>(a);
}
00225 
// Returns (a + b) / 2 rounded to the nearest integer. The sum is computed in
// a wider type, so it cannot overflow. Ties (a + b odd) are rounded away from
// zero: RoundingHalfSum(1, 2) == 2 and RoundingHalfSum(-1, -2) == -2.
//
// NOTE(review): this is close to, but not exactly, ARM NEON VRHADD. VRHADD
// computes (a + b + 1) >> 1, so it rounds ties toward +infinity and gives -1
// for the second example above. The two agree whenever a + b >= 0 or a + b
// is even. SIMD specializations living in other files may follow the VRHADD
// convention; confirm before relying on bit-exact agreement for negative
// ties.
//
// The primary template is deliberately unimplemented; only the
// specializations below may be instantiated.
template <typename IntegerType>
IntegerType RoundingHalfSum(IntegerType a, IntegerType b) {
  static_assert(std::is_same<IntegerType, void>::value, "unimplemented");
  (void)b;
  return a;
}

template <>
inline std::int32_t RoundingHalfSum(std::int32_t a, std::int32_t b) {
  std::int64_t a64 = a;
  std::int64_t b64 = b;
  std::int64_t sum = a64 + b64;
  // Nudge by +-1 before the truncating division so that halves round away
  // from zero.
  std::int64_t sign = sum >= 0 ? 1 : -1;
  return static_cast<std::int32_t>((sum + sign) / 2);
}

template <>
inline std::int16_t RoundingHalfSum(std::int16_t a, std::int16_t b) {
  std::int32_t a32 = a;
  std::int32_t b32 = b;
  std::int32_t sum = a32 + b32;
  // Same rounding as the int32 specialization.
  std::int32_t sign = sum >= 0 ? 1 : -1;
  return static_cast<std::int16_t>((sum + sign) / 2);
}
00252 
// Saturating addition: returns a + b clamped to IntegerType's range.
// The primary template is deliberately unimplemented; so far only the
// std::int16_t specialization below is needed.
template <typename IntegerType>
IntegerType SaturatingAdd(IntegerType a, IntegerType b) {
  static_assert(std::is_same<IntegerType, void>::value, "unimplemented");
  (void)b;
  return a;
}

// int16 specialization. The sum is formed in int32, where it cannot
// overflow, and is then clamped to [-32768, 32767].
template <>
inline std::int16_t SaturatingAdd(std::int16_t a, std::int16_t b) {
  const std::int32_t wide_sum =
      static_cast<std::int32_t>(a) + static_cast<std::int32_t>(b);
  if (wide_sum > 32767) {
    return 32767;
  }
  if (wide_sum < -32768) {
    return -32768;
  }
  return static_cast<std::int16_t>(wide_sum);
}
00270 
00271 // Returns a+b, saturating if the integers are 16bit or narrower,
00272 // otherwise just a plain addition.
00273 template <typename IntegerType, bool Is16Bit>
00274 struct AddSaturatingIf16BitImpl {
00275   static IntegerType Run(IntegerType a, IntegerType b) { return Add(a, b); }
00276 };
00277 template <typename IntegerType>
00278 struct AddSaturatingIf16BitImpl<IntegerType, true> {
00279   static IntegerType Run(IntegerType a, IntegerType b) {
00280     return SaturatingAdd(a, b);
00281   }
00282 };
00283 template <typename IntegerType>
00284 IntegerType AddSaturatingIf16Bit(IntegerType a, IntegerType b) {
00285   using ScalarType =
00286       typename FixedPointRawTypeTraits<IntegerType>::ScalarRawType;
00287   return AddSaturatingIf16BitImpl<IntegerType, sizeof(ScalarType) == 2>::Run(a,
00288                                                                              b);
00289 }
00290 
// Returns the integer representing the product of two fixed-point numbers.
// Every IntegerType value is read as a fixed-point number in [-1, 1). The
// product is rounded to nearest. The one overflowing case, -1 * -1 == 1
// (outside the half-open interval), saturates to the maximum value.
//
// Worked through for std::int32_t (other widths are analogous): the mapping
// is
//   real_value = integer_value / 2^31,
// so, leaving rounding and saturation aside, the result is
//   ((a / 2^31) * (b / 2^31)) * 2^31 = (a * b) / 2^31.
// A plain "multiply-high" keeps the top 32 bits of the 64-bit product, i.e.
// (a * b) / 2^32. Dividing by 2^31 instead gives twice that, hence the
// "doubling" in the name, and uses all 32 bits of the int32 destination.
//
// Same computation as the ARM NEON VQRDMULH instruction. The primary
// template is deliberately unimplemented; only the specializations below may
// be instantiated.
template <typename IntegerType>
IntegerType SaturatingRoundingDoublingHighMul(IntegerType a, IntegerType b) {
  static_assert(std::is_same<IntegerType, void>::value, "unimplemented");
  (void)b;
  return a;
}

// int32 specialization (same computation as ARMv7 NEON VQRDMULH).
template <>
inline std::int32_t SaturatingRoundingDoublingHighMul(std::int32_t a,
                                                      std::int32_t b) {
  // min * min is the only product that does not fit; saturate it.
  if (a == std::numeric_limits<std::int32_t>::min() && b == a) {
    return std::numeric_limits<std::int32_t>::max();
  }
  const std::int64_t product =
      static_cast<std::int64_t>(a) * static_cast<std::int64_t>(b);
  // Bias by (almost) half of the 2^31 divisor, toward the product's sign, so
  // that the truncating division below rounds to nearest.
  const std::int64_t bias = product >= 0 ? (std::int64_t{1} << 30)
                                         : (1 - (std::int64_t{1} << 30));
  return static_cast<std::int32_t>((product + bias) /
                                   (std::int64_t{1} << 31));
}

// int16 specialization: same scheme with a 32-bit intermediate.
template <>
inline std::int16_t SaturatingRoundingDoublingHighMul(std::int16_t a,
                                                      std::int16_t b) {
  if (a == std::numeric_limits<std::int16_t>::min() && b == a) {
    return std::numeric_limits<std::int16_t>::max();
  }
  const std::int32_t product =
      static_cast<std::int32_t>(a) * static_cast<std::int32_t>(b);
  const std::int32_t bias = product >= 0 ? (1 << 14) : (1 - (1 << 14));
  return static_cast<std::int16_t>((product + bias) / (1 << 15));
}
00353 
00354 // Correctly-rounded-to-nearest division by a power-of-two.
00355 // Also known as a rounding arithmetic right shift.
00356 template <typename IntegerType>
00357 inline IntegerType RoundingDivideByPOT(IntegerType x, int exponent) {
00358   assert(exponent >= 0);
00359   assert(exponent <= 31);
00360   const IntegerType mask = Dup<IntegerType>((1ll << exponent) - 1);
00361   const IntegerType zero = Dup<IntegerType>(0);
00362   const IntegerType one = Dup<IntegerType>(1);
00363   const IntegerType remainder = BitAnd(x, mask);
00364   const IntegerType threshold =
00365       Add(ShiftRight(mask, 1), BitAnd(MaskIfLessThan(x, zero), one));
00366   return Add(ShiftRight(x, exponent),
00367              BitAnd(MaskIfGreaterThan(remainder, threshold), one));
00368 }
00369 
00370 // Returns the product of a run-time integer value by a compile-time power
00371 // of two, with either a positive exponent (equivalent to an arithmetic
00372 // left shift, saturating) or a negative exponent (equivalent to an arithmetic
00373 // right shift, rounding to nearest).
00374 template <int Exponent, typename IntegerType,
00375           int ExponentSign = (Exponent > 0 ? 1 : Exponent < 0 ? -1 : 0)>
00376 struct ImplSaturatingRoundingMultiplyByPOT {};
00377 
00378 template <int Exponent, typename IntegerType>
00379 struct ImplSaturatingRoundingMultiplyByPOT<Exponent, IntegerType, 0> {
00380   static IntegerType eval(IntegerType x) { return x; }
00381 };
00382 
00383 template <int Exponent, typename IntegerType>
00384 struct ImplSaturatingRoundingMultiplyByPOT<Exponent, IntegerType, 1> {
00385   static IntegerType eval(IntegerType x) {
00386     using ScalarIntegerType =
00387         typename FixedPointRawTypeTraits<IntegerType>::ScalarRawType;
00388     const IntegerType min =
00389         Dup<IntegerType>(std::numeric_limits<ScalarIntegerType>::min());
00390     const IntegerType max =
00391         Dup<IntegerType>(std::numeric_limits<ScalarIntegerType>::max());
00392     const int ScalarIntegerTypeBits = 8 * sizeof(ScalarIntegerType);
00393 
00394     const std::int32_t threshold =
00395         ((1 << (ScalarIntegerTypeBits - 1 - Exponent)) - 1);
00396     const IntegerType positive_mask =
00397         MaskIfGreaterThan(x, Dup<IntegerType>(threshold));
00398     const IntegerType negative_mask =
00399         MaskIfLessThan(x, Dup<IntegerType>(-threshold));
00400 
00401     IntegerType result = ShiftLeft(x, Exponent);
00402     result = SelectUsingMask(positive_mask, max, result);
00403     result = SelectUsingMask(negative_mask, min, result);
00404     return result;
00405   }
00406 };
00407 
00408 template <int Exponent, typename IntegerType>
00409 struct ImplSaturatingRoundingMultiplyByPOT<Exponent, IntegerType, -1> {
00410   static IntegerType eval(IntegerType x) {
00411     return RoundingDivideByPOT<IntegerType>(x, -Exponent);
00412   }
00413 };
00414 
00415 template <int Exponent, typename IntegerType>
00416 IntegerType SaturatingRoundingMultiplyByPOT(IntegerType x) {
00417   return ImplSaturatingRoundingMultiplyByPOT<Exponent, IntegerType>::eval(x);
00418 }
00419 
00420 // Part 2: the FixedPoint class.
00421 
00422 // A FixedPoint object represents a fixed-point value stored in the underlying
00423 // integer type tRawType, if tRawType is a plain scalar integer type.
00424 // Alternatively, tRawType may be a SIMD type (e.g. NEON int32x4_t) in which
00425 // case a FixedPoint object represents a corresponding SIMD vector of fixed
00426 // point values.
00427 //
00428 // tIntegerBits describes the range of the fixed-point format: if
00429 // tIntegerBits == m then the range of representable values is the half-open
00430 // interval [-2^m; 2^m) where the open boundary on the right side means that
00431 // 2^m is not representable (how close the maximum representable value is to
00432 // it, depends on bit-depth of tRawType).
00433 //
00434 // In "Q format notation",
00435 //   https://en.wikipedia.org/wiki/Q_(number_format)
00436 // we are describing the format
00437 //   Qm.n
00438 // where
00439 //   m = tIntegerBits
00440 // and
00441 //   n = NumberOfBits(tRawType) - (m + 1)
00442 // Note that the (m + 1) in the above line is because we adopt the convention
00443 // that we count the integer bits exclusively of the sign bit; so (m + 1) is
00444 // the total number of integer bits inclusive of the sign bit.
00445 //
00446 // Accordingly, the number of integral representable values in our range
00447 //   [-2^m ; 2^m)
00448 // is equal to 2^(m+1).
00449 template <typename tRawType, int tIntegerBits>
00450 class FixedPoint {
00451  public:
00452   typedef tRawType RawType;
00453 
00454   typedef FixedPointRawTypeTraits<RawType> RawTypeTraits;
00455   typedef typename RawTypeTraits::ScalarRawType ScalarRawType;
00456 
00457   static constexpr int kTotalBits = 8 * sizeof(ScalarRawType);
00458   static constexpr int kIntegerBits = tIntegerBits;
00459   static constexpr int kFractionalBits = kTotalBits - 1 - kIntegerBits;
00460   static_assert(kIntegerBits >= 0 && kIntegerBits < kTotalBits,
00461                 "bad IntegerBits");
00462 
00463   typedef FixedPoint<ScalarRawType, kIntegerBits> ScalarFixedPointType;
00464 
00465   static const ScalarRawType ScalarRawMin() {
00466     return std::numeric_limits<ScalarRawType>::min();
00467   }
00468 
00469   static const ScalarRawType ScalarRawMax() {
00470     return std::numeric_limits<ScalarRawType>::max();
00471   }
00472 
00473   static const ScalarRawType RawMin() {
00474     return VectorFromScalar(ScalarRawMin());
00475   }
00476 
00477   static const ScalarRawType RawMax() {
00478     return VectorFromScalar(ScalarRawMax());
00479   }
00480 
00481   static FixedPoint FromRaw(RawType x) {
00482     FixedPoint retval;
00483     retval.raw() = x;
00484     return retval;
00485   }
00486 
00487   static FixedPoint FromScalarRaw(ScalarRawType x) {
00488     FixedPoint retval;
00489     retval.raw() = Dup<RawType>(x);
00490     return retval;
00491   }
00492 
00493   static FixedPoint FromScalarFixedPoint(ScalarFixedPointType x) {
00494     return FromScalarRaw(x.raw());
00495   }
00496 
00497   template <int Exponent>
00498   static FixedPoint ConstantPOT() {
00499     static constexpr int kOffset = kFractionalBits + Exponent;
00500     static_assert(
00501         kOffset < 31,
00502         "Constant not exactly representable in this fixed-point format");
00503     return FromScalarRaw(ScalarRawType(1) << kOffset);
00504   }
00505 
00506   static FixedPoint Zero() { return FromScalarRaw(0); }
00507 
00508   static FixedPoint One() {
00509     return FromScalarRaw(
00510         kIntegerBits == 0
00511             ? ScalarRawMax()
00512             : (ScalarRawType(1) << (kIntegerBits == 0 ? 0 : kFractionalBits)));
00513   }
00514 
00515   static FixedPoint FromDouble(double x) {
00516     const double min_bound = static_cast<double>(ScalarRawMin());
00517     const double max_bound = static_cast<double>(ScalarRawMax());
00518     return FromScalarRaw(static_cast<ScalarRawType>(std::min(
00519         std::max(round(x * static_cast<double>(1ll << kFractionalBits)),
00520                  min_bound),
00521         max_bound)));
00522   }
00523 
00524   RawType raw() const { return i_; }
00525   RawType& raw() { return i_; }
00526 
00527  private:
00528   RawType i_;
00529 };
00530 
00531 // Part 3: implementation of arithmetic operators for the
00532 // FixedPoint class, and a few related functions.
00533 
00534 // A FixedPoint multiplication is just a
00535 // SaturatingRoundingDoublingHighMul operation on the underlying
00536 // raw integer values. The IntegerBits simply add up, as is obvious
00537 // from the fact that the range is [-2^IntegerBits, 2^IntegerBits).
00538 template <typename tRawType, int tIntegerBits_a, int tIntegerBits_b>
00539 FixedPoint<tRawType, tIntegerBits_a + tIntegerBits_b> operator*(
00540     FixedPoint<tRawType, tIntegerBits_a> a,
00541     FixedPoint<tRawType, tIntegerBits_b> b) {
00542   FixedPoint<tRawType, tIntegerBits_a + tIntegerBits_b> c;
00543   c.raw() = SaturatingRoundingDoublingHighMul(a.raw(), b.raw());
00544   return c;
00545 }
00546 
00547 // Tweaking IntegerBits gives exact multiplication by a power of two.
00548 template <int tExponent, typename tRawType, int tIntegerBits>
00549 FixedPoint<tRawType, tExponent + tIntegerBits> ExactMulByPot(
00550     FixedPoint<tRawType, tIntegerBits> a) {
00551   FixedPoint<tRawType, tExponent + tIntegerBits> c;
00552   c.raw() = a.raw();
00553   return c;
00554 }
00555 
00556 // If we want to leave IntegerBits fixed, then multiplication
00557 // by a power of two has to be saturating/rounding, not exact anymore.
00558 template <int tExponent, typename tRawType, int tIntegerBits>
00559 FixedPoint<tRawType, tIntegerBits> SaturatingRoundingMultiplyByPOT(
00560     FixedPoint<tRawType, tIntegerBits> a) {
00561   return FixedPoint<tRawType, tIntegerBits>::FromRaw(
00562       SaturatingRoundingMultiplyByPOT<tExponent>(a.raw()));
00563 }
00564 
00565 // Generic arithmetic operators.
00566 
00567 #define MAKE_FIXEDPOINT_UNARY_FUNC(FuncName, ImplFuncName)                     \
00568   template <typename tRawType, int tIntegerBits>                               \
00569   FixedPoint<tRawType, tIntegerBits> FuncName(                                 \
00570       FixedPoint<tRawType, tIntegerBits> a) {                                  \
00571     return FixedPoint<tRawType, tIntegerBits>::FromRaw(ImplFuncName(a.raw())); \
00572   }
00573 
00574 #define MAKE_FIXEDPOINT_BINARY_FUNC(FuncName, ImplFuncName) \
00575   template <typename tRawType, int tIntegerBits>            \
00576   FixedPoint<tRawType, tIntegerBits> FuncName(              \
00577       FixedPoint<tRawType, tIntegerBits> a,                 \
00578       FixedPoint<tRawType, tIntegerBits> b) {               \
00579     return FixedPoint<tRawType, tIntegerBits>::FromRaw(     \
00580         ImplFuncName(a.raw(), b.raw()));                    \
00581   }
00582 
00583 MAKE_FIXEDPOINT_UNARY_FUNC(operator-, Neg)
00584 MAKE_FIXEDPOINT_UNARY_FUNC(operator~, BitNot)
00585 MAKE_FIXEDPOINT_BINARY_FUNC(operator+, Add)
00586 MAKE_FIXEDPOINT_BINARY_FUNC(operator-, Sub)
00587 MAKE_FIXEDPOINT_BINARY_FUNC(operator&, BitAnd)
00588 MAKE_FIXEDPOINT_BINARY_FUNC(operator^, BitXor)
00589 MAKE_FIXEDPOINT_BINARY_FUNC(operator|, BitOr)
00590 MAKE_FIXEDPOINT_BINARY_FUNC(RoundingHalfSum, RoundingHalfSum)
00591 
00592 #undef MAKE_FIXEDPOINT_UNARY_FUNC
00593 #undef MAKE_FIXEDPOINT_BINARY_FUNC
00594 
00595 #define MAKE_FIXEDPOINT_UNARY_FUNC_RETURNING_RAW(FuncName)  \
00596   template <typename tRawType, int tIntegerBits>            \
00597   tRawType FuncName(FixedPoint<tRawType, tIntegerBits> a) { \
00598     return FuncName(a.raw());                               \
00599   }
00600 
00601 #define MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(FuncName) \
00602   template <typename tRawType, int tIntegerBits>            \
00603   tRawType FuncName(FixedPoint<tRawType, tIntegerBits> a,   \
00604                     FixedPoint<tRawType, tIntegerBits> b) { \
00605     return FuncName(a.raw(), b.raw());                      \
00606   }
00607 
00608 MAKE_FIXEDPOINT_UNARY_FUNC_RETURNING_RAW(MaskIfZero)
00609 MAKE_FIXEDPOINT_UNARY_FUNC_RETURNING_RAW(MaskIfNonZero)
00610 MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(MaskIfEqual)
00611 MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(MaskIfNotEqual)
00612 MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(MaskIfGreaterThan)
00613 MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(MaskIfGreaterThanOrEqual)
00614 MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(MaskIfLessThan)
00615 MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(MaskIfLessThanOrEqual)
00616 
00617 #undef MAKE_FIXEDPOINT_UNARY_FUNC_RETURNING_RAW
00618 #undef MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW
00619 
00620 template <typename tRawType, int tIntegerBits>
00621 FixedPoint<tRawType, tIntegerBits> SelectUsingMask(
00622     tRawType if_mask, FixedPoint<tRawType, tIntegerBits> then_val,
00623     FixedPoint<tRawType, tIntegerBits> else_val) {
00624   return FixedPoint<tRawType, tIntegerBits>::FromRaw(
00625       SelectUsingMask(if_mask, then_val.raw(), else_val.raw()));
00626 }
00627 
00628 template <typename tRawType, int tIntegerBits>
00629 bool operator==(FixedPoint<tRawType, tIntegerBits> a,
00630                 FixedPoint<tRawType, tIntegerBits> b) {
00631   return All(MaskIfEqual(a.raw(), b.raw()));
00632 }
00633 
00634 template <typename tRawType, int tIntegerBits>
00635 bool operator!=(FixedPoint<tRawType, tIntegerBits> a,
00636                 FixedPoint<tRawType, tIntegerBits> b) {
00637   return !(a == b);
00638 }
00639 
00640 template <typename tRawType, int tIntegerBits>
00641 FixedPoint<tRawType, tIntegerBits> SaturatingAdd(
00642     FixedPoint<tRawType, tIntegerBits> a,
00643     FixedPoint<tRawType, tIntegerBits> b) {
00644   return FixedPoint<tRawType, tIntegerBits>::FromRaw(
00645       SaturatingAdd(a.raw(), b.raw()));
00646 }
00647 
00648 template <typename tRawType, int tIntegerBits>
00649 FixedPoint<tRawType, tIntegerBits> AddSaturatingIf16Bit(
00650     FixedPoint<tRawType, tIntegerBits> a,
00651     FixedPoint<tRawType, tIntegerBits> b) {
00652   return FixedPoint<tRawType, tIntegerBits>::FromRaw(
00653       AddSaturatingIf16Bit(a.raw(), b.raw()));
00654 }
00655 
00656 // Conversion to floating-point.
00657 template <typename tRawType, int tIntegerBits>
00658 double ToDouble(FixedPoint<tRawType, tIntegerBits> x) {
00659   static_assert(FixedPointRawTypeTraits<tRawType>::kLanes == 1,
00660                 "not applicable to SIMD types");
00661   typedef FixedPoint<tRawType, tIntegerBits> F;
00662   return x.raw() / static_cast<double>(1ll << F::kFractionalBits);
00663 }
00664 
00665 // Rescale changes the number of IntegerBits and updates the underlying
00666 // raw integer value accordingly.
00667 template <int tIntegerBitsDst, typename tRawType, int tIntegerBitsSrc>
00668 FixedPoint<tRawType, tIntegerBitsDst> Rescale(
00669     FixedPoint<tRawType, tIntegerBitsSrc> x) {
00670   static constexpr int kExponent = tIntegerBitsSrc - tIntegerBitsDst;
00671   FixedPoint<tRawType, tIntegerBitsDst> result;
00672   result.raw() = SaturatingRoundingMultiplyByPOT<kExponent>(x.raw());
00673   return result;
00674 }
00675 
00676 // CheckedFixedPointConstant allows to specify fixed-point constants
00677 // initialized as real numbers, in a way that does not compile floating-point
00678 // arithmetic in production code, yet still checks agreement with the
00679 // floating-point expressions when asserts are enabled.
00680 //
00681 // The raw integer value provided is always a int32, encoding a 32-bit
00682 // fixed-point value, regardless of the actual Scalar type. This allows
00683 // writing generic code that applies just as well to the 32-bit and 16-bit
00684 // cases. In the 16-bit case, the raw integer value is internally
00685 // rounding-shifted by 16 bits to the right.
00686 template <typename FixedPointType>
00687 inline typename FixedPointType::ScalarRawType RescaleConstantInitializer(
00688     std::int32_t int32_value) {
00689   typedef typename FixedPointType::ScalarRawType ScalarRawType;
00690   static constexpr int ScalarTypeBits = 8 * sizeof(ScalarRawType);
00691   return static_cast<ScalarRawType>(
00692       RoundingDivideByPOT<std::int32_t>(int32_value, 32 - ScalarTypeBits));
00693 }
#ifdef GEMMLOWP_ENABLE_FIXEDPOINT_CONSTANTS_CHECKS
// Checked construction. Wraps `raw_value` (already rescaled to the scalar
// width by the macro below) in a FixedPointType and asserts that it equals
// FixedPointType::FromDouble(double_value). The FromDouble conversion only
// happens inside the assert.
template <typename FixedPointType>
FixedPointType CheckedFixedPointConstant(std::int32_t raw_value,
                                         double double_value) {
  const FixedPointType constant = FixedPointType::FromScalarRaw(raw_value);
  assert(constant == FixedPointType::FromDouble(double_value));
  return constant;
}

// Builds a FixedPointType constant from the 32-bit raw value RawInt32 and
// cross-checks it against the real number RealValue.
#define GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(FixedPointType, RawInt32,    \
                                             RealValue)                   \
  (gemmlowp::CheckedFixedPointConstant<FixedPointType>(                   \
      gemmlowp::RescaleConstantInitializer<FixedPointType>(RawInt32),     \
      RealValue))

#else
// Unchecked construction. RealValue does not appear in the expansion, so it
// is never evaluated or compiled.
#define GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(FixedPointType, RawInt32,    \
                                             RealValue)                   \
  (FixedPointType::FromScalarRaw(                                         \
      gemmlowp::RescaleConstantInitializer<FixedPointType>(RawInt32)))
#endif
00716 
00717 // Implementation of exponential function.
00718 
00719 // Returns exp(x) for x in [-1/4, 0).
00720 template <typename tRawType>
00721 FixedPoint<tRawType, 0> exp_on_interval_between_negative_one_quarter_and_0_excl(
00722     FixedPoint<tRawType, 0> a) {
00723   typedef FixedPoint<tRawType, 0> F;
00724   const F constant_term =
00725       GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F, 1895147668, std::exp(-1.0 / 8.0));
00726   const F constant_1_over_3 =
00727       GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F, 715827883, 1.0 / 3.0);
00728   // We're evaluating a Taylor expansion around -1/8, so we do the change of
00729   // variable: x = a + 1/8.
00730   // In fixed-point with 0 integer bits, 1/8 is represented by 1 << 28.
00731   F x = a + F::template ConstantPOT<-3>();
00732   F x2 = x * x;
00733   F x3 = x2 * x;
00734   F x4 = x2 * x2;
00735   F x4_over_4 = SaturatingRoundingMultiplyByPOT<-2>(x4);
00736   F x4_over_24_plus_x3_over_6_plus_x2_over_2 =
00737       SaturatingRoundingMultiplyByPOT<-1>(
00738           ((x4_over_4 + x3) * constant_1_over_3) + x2);
00739   return AddSaturatingIf16Bit(
00740       constant_term,
00741       constant_term * (x + x4_over_24_plus_x3_over_6_plus_x2_over_2));
00742 }
00743 
// Returns exp(a) for a < 0, as a FixedPoint with 0 integer bits. a == 0 is
// also accepted and maps to ResultF::One().
//
// Method: write a = r - R, where r lies in [-1/4, 0) and R >= 0 is a multiple
// of 1/4.
//  - exp(r) comes from
//    exp_on_interval_between_negative_one_quarter_and_0_excl.
//  - exp(-R) is applied by multiplying in exp(-2^k) for every bit 2^k set in R,
//    for k = -2 .. 4.
// If the input format can hold values below -32 (kIntegerBits > 5), such
// inputs produce 0. Their R would have bits above 2^4, which the multiplier
// chain does not cover.
template <typename tRawType, int tIntegerBits>
FixedPoint<tRawType, 0> exp_on_negative_values(
    FixedPoint<tRawType, tIntegerBits> a) {
  typedef FixedPoint<tRawType, tIntegerBits> InputF;
  typedef FixedPoint<tRawType, 0> ResultF;
  static constexpr int kFractionalBits = InputF::kFractionalBits;
  static constexpr int kIntegerBits = InputF::kIntegerBits;
  const InputF kOneQuarter = InputF::template ConstantPOT<-2>();
  // Raw bits worth strictly less than 1/4.
  InputF mask = kOneQuarter - InputF::FromScalarRaw(1);
  // r = (a mod 1/4) - 1/4, which lies in [-1/4, 0).
  InputF a_mod_quarter_minus_one_quarter = (a & mask) - kOneQuarter;
  ResultF result = exp_on_interval_between_negative_one_quarter_and_0_excl(
      Rescale<0>(a_mod_quarter_minus_one_quarter));
  // R = r - a, taken as a raw value, so that a = r - R. For a < 0, R is a
  // non-negative multiple of 1/4. Its set bits select the exp(-2^k) factors
  // applied below.
  tRawType remainder = (a_mod_quarter_minus_one_quarter - a).raw();

// GEMMLOWP_EXP_BARREL_SHIFTER(Exponent, FixedPointMultiplier):
//  - If the bit of `remainder` worth 2^Exponent (raw bit
//    kFractionalBits + Exponent) is set, multiply `result` by exp(-2^Exponent).
//    FixedPointMultiplier is that factor as a raw value at 32-bit scale.
//  - The choice is made with masks (SelectUsingMask), not a data-dependent
//    branch.
//  - The step only takes effect when kIntegerBits > Exponent. That `if` is a
//    plain (non-constexpr) if, so its body is compiled in every instantiation.
//    kShiftAmount therefore falls back to 0 in the dead case to keep the
//    shift well-formed.
#define GEMMLOWP_EXP_BARREL_SHIFTER(Exponent, FixedPointMultiplier)         \
  if (kIntegerBits > Exponent) {                                            \
    const ResultF kMultiplier = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(       \
        ResultF, FixedPointMultiplier, std::exp(-std::pow(2.0, Exponent))); \
    static constexpr int kShiftAmount =                                     \
        kIntegerBits > Exponent ? kFractionalBits + Exponent : 0;           \
    result = SelectUsingMask(                                               \
        MaskIfNonZero(BitAnd(remainder, Dup<tRawType>(1 << kShiftAmount))), \
        result * kMultiplier, result);                                      \
  }

  // Factors applied, in this order (the order matters for rounding):
  // exp(-1/4), exp(-1/2), exp(-1), exp(-2), exp(-4), exp(-8), exp(-16).
  GEMMLOWP_EXP_BARREL_SHIFTER(-2, 1672461947);
  GEMMLOWP_EXP_BARREL_SHIFTER(-1, 1302514674);
  GEMMLOWP_EXP_BARREL_SHIFTER(+0, 790015084);
  GEMMLOWP_EXP_BARREL_SHIFTER(+1, 290630308);
  GEMMLOWP_EXP_BARREL_SHIFTER(+2, 39332535);
  GEMMLOWP_EXP_BARREL_SHIFTER(+3, 720401);
  GEMMLOWP_EXP_BARREL_SHIFTER(+4, 242);

#undef GEMMLOWP_EXP_BARREL_SHIFTER

  // Inputs below -32 would need remainder bits above 2^4, so they are flushed
  // to 0. -(1 << clampB) is -32.0 as a raw value at 32-bit scale for this
  // format. As with kShiftAmount, the ternary keeps the shift well-formed in
  // instantiations where the branch is dead.
  static constexpr int clampB = kIntegerBits > 5 ? 36 - kIntegerBits : 0;
  if (kIntegerBits > 5) {
    const InputF clamp =
        GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(InputF, -(1 << clampB), -32.0);
    result = SelectUsingMask(MaskIfLessThan(a, clamp), ResultF::Zero(), result);
  }

  // The decomposition above assumes a < 0; for a == 0 the remainder would be
  // negative. exp(0) is reported as ResultF::One().
  result = SelectUsingMask(MaskIfZero(a), ResultF::One(), result);
  return result;
}
00790 
00791 // Implementation of tanh: (1 - exp(-2x)) / (1 + exp(-2x)).
00792 
00793 // Returns (1 - x) / (1 + x) for x in (0, 1).
00794 template <typename tRawType>
00795 FixedPoint<tRawType, 0> one_minus_x_over_one_plus_x_for_x_in_0_1(
00796     FixedPoint<tRawType, 0> a) {
00797   typedef FixedPoint<tRawType, 0> F0;
00798   typedef FixedPoint<tRawType, 2> F2;
00799   F0 half_denominator = RoundingHalfSum(a, F0::One());
00800   // Newton-Raphson division
00801   // https://en.wikipedia.org/wiki/Division_algorithm#Newton.E2.80.93Raphson_division
00802   // Refer to that page for the logic behind the 48/17 and 32/17 constants.
00803   const F2 constant_48_over_17 =
00804       GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F2, 1515870810, 48.0 / 17.0);
00805   const F2 constant_neg_32_over_17 =
00806       GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F2, -1010580540, -32.0 / 17.0);
00807   F2 x = constant_48_over_17 + half_denominator * constant_neg_32_over_17;
00808   for (int i = 0; i < 3; i++) {
00809     F2 half_denominator_times_x = half_denominator * x;
00810     F2 one_minus_half_denominator_times_x =
00811         F2::One() - half_denominator_times_x;
00812     x = x + Rescale<2>(x * one_minus_half_denominator_times_x);
00813   }
00814   return Rescale<0>(x - F2::One());
00815 }
00816 
00817 // Returns -tanh(x) for x < 0.
00818 template <typename tRawType, int tIntegerBits>
00819 FixedPoint<tRawType, 0> neg_tanh_on_negative_values(
00820     FixedPoint<tRawType, tIntegerBits> a) {
00821   return one_minus_x_over_one_plus_x_for_x_in_0_1(
00822       exp_on_negative_values(ExactMulByPot<1>(a)));
00823 }
00824 
00825 // Returns tanh(x) for any x.
00826 template <typename tRawType, int tIntegerBits>
00827 FixedPoint<tRawType, 0> tanh(FixedPoint<tRawType, tIntegerBits> a) {
00828   typedef FixedPoint<tRawType, tIntegerBits> InputF;
00829   typedef FixedPoint<tRawType, 0> ResultF;
00830   tRawType mask_if_negative = MaskIfLessThan(a, InputF::Zero());
00831   tRawType mask_if_zero = MaskIfZero(a);
00832   InputF n = SelectUsingMask(mask_if_negative, a, -a);
00833   ResultF t = neg_tanh_on_negative_values(n);
00834   return SelectUsingMask(mask_if_zero, ResultF::Zero(),
00835                          SelectUsingMask(mask_if_negative, -t, t));
00836 }
00837 
00838 // Implementation of logistic function.
00839 
00840 // Returns 1 / (1 + x) for x in (0, 1).
00841 template <typename tRawType>
00842 FixedPoint<tRawType, 0> one_over_one_plus_x_for_x_in_0_1(
00843     FixedPoint<tRawType, 0> a) {
00844   typedef FixedPoint<tRawType, 0> F0;
00845   typedef FixedPoint<tRawType, 2> F2;
00846   F0 half_denominator = RoundingHalfSum(a, F0::One());
00847   // Newton-Raphson division
00848   // https://en.wikipedia.org/wiki/Division_algorithm#Newton.E2.80.93Raphson_division
00849   // Refer to that page for the logic behind the 48/17 and 32/17 constants.
00850   const F2 constant_48_over_17 =
00851       GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F2, 1515870810, 48.0 / 17.0);
00852   const F2 constant_neg_32_over_17 =
00853       GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F2, -1010580540, -32.0 / 17.0);
00854   F2 x = constant_48_over_17 + half_denominator * constant_neg_32_over_17;
00855   for (int i = 0; i < 3; i++) {
00856     F2 half_denominator_times_x = half_denominator * x;
00857     F2 one_minus_half_denominator_times_x =
00858         F2::One() - half_denominator_times_x;
00859     x = x + Rescale<2>(x * one_minus_half_denominator_times_x);
00860   }
00861   return Rescale<0>(ExactMulByPot<-1>(x));
00862 }
00863 
00864 // Returns logistic(x) = 1 / (1 + exp(-x)) for x > 0.
00865 template <typename tRawType, int tIntegerBits>
00866 FixedPoint<tRawType, 0> logistic_on_positive_values(
00867     FixedPoint<tRawType, tIntegerBits> a) {
00868   return one_over_one_plus_x_for_x_in_0_1(exp_on_negative_values(-a));
00869 }
00870 
00871 // Returns logistic(x) = 1 / (1 + exp(-x)) for any x.
00872 template <typename tRawType, int tIntegerBits>
00873 FixedPoint<tRawType, 0> logistic(FixedPoint<tRawType, tIntegerBits> a) {
00874   typedef FixedPoint<tRawType, tIntegerBits> InputF;
00875   typedef FixedPoint<tRawType, 0> ResultF;
00876   tRawType mask_if_positive = MaskIfGreaterThan(a, InputF::Zero());
00877   tRawType mask_if_zero = MaskIfZero(a);
00878   InputF abs_input = SelectUsingMask(mask_if_positive, a, -a);
00879   ResultF result_if_positive = logistic_on_positive_values(abs_input);
00880   ResultF result_if_negative = ResultF::One() - result_if_positive;
00881   const ResultF one_half =
00882       GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(ResultF, 1 << 30, 0.5);
00883   return SelectUsingMask(mask_if_zero, one_half,
00884                          SelectUsingMask(mask_if_positive, result_if_positive,
00885                                          result_if_negative));
00886 }
00887 
00888 }  // end namespace gemmlowp
00889 
00890 #ifdef GEMMLOWP_NEON
00891 #include "./fixedpoint_neon.h"
00892 #elif defined(GEMMLOWP_AVX2)
00893 #include "./fixedpoint_avx.h"
00894 #elif defined(GEMMLOWP_SSE4)
00895 #include "./fixedpoint_sse.h"
00896 #elif defined(GEMMLOWP_MSA)
00897 #include "./fixedpoint_msa.h"
00898 #endif
00899 
00900 #endif  // GEMMLOWP_INTERNAL_FIXEDPOINT_H_