Important changes to repositories hosted on mbed.com
Mbed-hosted Mercurial repositories are deprecated and are due to be permanently deleted in July 2026.
To keep a copy of this software, download the repository Zip archive or clone it locally using Mercurial.
It is also possible to export all your personal repositories from the account settings page.
fixedpoint.h
00001 // Copyright 2015 The Gemmlowp Authors. All Rights Reserved. 00002 // 00003 // Licensed under the Apache License, Version 2.0 (the "License"); 00004 // you may not use this file except in compliance with the License. 00005 // You may obtain a copy of the License at 00006 // 00007 // http://www.apache.org/licenses/LICENSE-2.0 00008 // 00009 // Unless required by applicable law or agreed to in writing, software 00010 // distributed under the License is distributed on an "AS IS" BASIS, 00011 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 00012 // See the License for the specific language governing permissions and 00013 // limitations under the License. 00014 00015 // fixedpoint.h: fixed-point arithmetic, with basic operations and 00016 // a few math functions such as tanh. 00017 00018 #ifndef GEMMLOWP_INTERNAL_FIXEDPOINT_H_ 00019 #define GEMMLOWP_INTERNAL_FIXEDPOINT_H_ 00020 00021 #include <algorithm> 00022 #include <cassert> 00023 #include <cmath> 00024 #include <cstdint> 00025 #include <limits> 00026 00027 #include "../internal/detect_platform.h" 00028 00029 namespace gemmlowp { 00030 00031 // Part 1: Low-level integer-arithmetic primitives. 00032 // The implementations here are generic implementations valid for 00033 // scalar types (e.g. std::int32_t). Architecture-specific SIMD types 00034 // (e.g. NEON int32x4_t) may be supported by providing 00035 // specializations for them in separate files. 00036 // 00037 // The purpose of these primitives is two-fold: 00038 // - They will be used to implement higher-level fixed-point 00039 // abstractions, namely the FixedPoint class and its arithmetic 00040 // operators. 00041 // - They will be directly used to implement some more involved 00042 // fixed-point computations, e.g. the fixed-point implementation 00043 // of math functions such as tanh. 00044 00045 // Some compile-time traits around raw types to handle SIMD aspects: 00046 // number of lanes, underlying scalar type. 
template <typename tIntegerType>
struct FixedPointRawTypeTraits {};

// Plain 32-bit scalar: a single lane whose scalar type is the type itself.
template <>
struct FixedPointRawTypeTraits<std::int32_t> {
  using ScalarRawType = std::int32_t;
  static constexpr int kLanes = 1;
};

// Plain 16-bit scalar: a single lane whose scalar type is the type itself.
template <>
struct FixedPointRawTypeTraits<std::int16_t> {
  using ScalarRawType = std::int16_t;
  static constexpr int kLanes = 1;
};

// Returns a SIMD value duplicating a scalar value across all lanes.
// This generic version simply converts the scalar to tRawType, which is all
// that is needed for the single-lane scalar types declared above.
template <typename tRawType>
tRawType Dup(typename FixedPointRawTypeTraits<tRawType>::ScalarRawType x) {
  const tRawType duplicated = x;
  return duplicated;
}

// Plain bit-wise AND
template <typename tIntegerType>
tIntegerType BitAnd(tIntegerType a, tIntegerType b) {
  // The cast spells out the conversion back from the promoted operand type.
  return static_cast<tIntegerType>(a & b);
}

// Plain bit-wise OR
template <typename tIntegerType>
tIntegerType BitOr(tIntegerType a, tIntegerType b) {
  return static_cast<tIntegerType>(a | b);
}

// Plain bit-wise XOR
template <typename tIntegerType>
tIntegerType BitXor(tIntegerType a, tIntegerType b) {
  return static_cast<tIntegerType>(a ^ b);
}

// Plain bit-wise NOT
template <typename tIntegerType>
tIntegerType BitNot(tIntegerType a) {
  return static_cast<tIntegerType>(~a);
}

// Integer addition. Not saturating. Overflow is undefined behavior.
template <typename tIntegerType>
tIntegerType Add(tIntegerType a, tIntegerType b) {
  return static_cast<tIntegerType>(a + b);
}

// Integer multiplication. Not saturating. Overflow is undefined behavior.
template <typename tIntegerType>
tIntegerType Mul(tIntegerType a, tIntegerType b) {
  return static_cast<tIntegerType>(a * b);
}

// Integer subtraction. Not saturating. Overflow is undefined behavior.
template <typename tIntegerType>
tIntegerType Sub(tIntegerType a, tIntegerType b) {
  return static_cast<tIntegerType>(a - b);
}

// Integer unary negative. Not saturating. Overflow is undefined behavior.
template <typename tIntegerType>
tIntegerType Neg(tIntegerType a) {
  return -a;
}

// Integer arithmetic left-shift, equivalent to multiplying with a power of two.
// Negative values are OK. In case of overflow, no Undefined
// Behavior, but the results are implementation-defined (in practice,
// they currently are saturated, but we make no commitment to that). The idea
// is that the caller will want to implement the overflowing cases with
// saturation with compare-and-mask, so we don't care about the results
// in the overflow case, we just want to avoid undefined behavior.
//
// tIntegerType may be int32 or any narrower signed type.
// offset must be non-negative and small enough that a * 2^offset fits in
// 64 bits; any offset in [0, 31] satisfies this for the supported types.
template <typename tIntegerType>
tIntegerType ShiftLeft(tIntegerType a, int offset) {
  const std::int64_t wide_a = static_cast<std::int64_t>(a);
  // Form the power of two in 64 bits. Computing `1 << offset` as a plain int
  // wraps to a negative value at offset == 31, which flips the sign of the
  // product and makes the clamp below saturate in the wrong direction.
  const std::int64_t power_of_two = static_cast<std::int64_t>(1) << offset;
  const std::int64_t wide_shifted = wide_a * power_of_two;
  const std::int64_t lowest = std::numeric_limits<tIntegerType>::min();
  const std::int64_t highest = std::numeric_limits<tIntegerType>::max();
  if (wide_shifted < lowest) {
    return static_cast<tIntegerType>(lowest);
  }
  if (wide_shifted > highest) {
    return static_cast<tIntegerType>(highest);
  }
  return static_cast<tIntegerType>(wide_shifted);
}

// Integer arithmetic right-shift. Not rounding.
// Relying on implementation-defined, but in-practice-consistent,
// C++ compiler behavior.
template <typename tIntegerType>
tIntegerType ShiftRight(tIntegerType a, int offset) {
  return a >> offset;
}

// Each bit of the result is set to the corresponding bit of either then_val or
// else_val depending on whether the corresponding bit of if_mask is set.
// Equivalent to the VBSL instruction in ARM NEON.
template <typename tIntegerType>
tIntegerType SelectUsingMask(tIntegerType if_mask, tIntegerType then_val,
                             tIntegerType else_val) {
  // Keep then_val where the mask bit is set and else_val where it is clear.
  const tIntegerType from_then = BitAnd(if_mask, then_val);
  const tIntegerType from_else = BitAnd(BitNot(if_mask), else_val);
  // The two halves cannot both have the same bit set, so XOR merges them.
  return BitXor(from_then, from_else);
}

// For each input scalar, the corresponding bits of the result are set if the
// input scalar is non-zero.
template <typename tIntegerType>
tIntegerType MaskIfNonZero(tIntegerType a) {
  const tIntegerType no_bits_set = 0;
  if (a) {
    return BitNot(no_bits_set);
  }
  return no_bits_set;
}

// For each input scalar, the corresponding bits of the result are set if the
// input scalar is zero.
template <typename tIntegerType>
tIntegerType MaskIfZero(tIntegerType a) {
  const bool is_zero = !a;
  return MaskIfNonZero<tIntegerType>(is_zero);
}

// For each pair of input scalars, the corresponding bits of the result are
// set if the input scalars are equal.
template <typename tIntegerType>
tIntegerType MaskIfEqual(tIntegerType a, tIntegerType b) {
  const bool holds = (a == b);
  return MaskIfNonZero<tIntegerType>(holds);
}

// For each pair of input scalars, the corresponding bits of the result are
// set if the input scalars are not equal.
template <typename tIntegerType>
tIntegerType MaskIfNotEqual(tIntegerType a, tIntegerType b) {
  const bool holds = (a != b);
  return MaskIfNonZero<tIntegerType>(holds);
}

// For each pair of input scalars, the corresponding bits of the result are
// set if the input scalars a, b satisfy a > b.
template <typename tIntegerType>
tIntegerType MaskIfGreaterThan(tIntegerType a, tIntegerType b) {
  const bool holds = (a > b);
  return MaskIfNonZero<tIntegerType>(holds);
}

// For each pair of input scalars, the corresponding bits of the result are
// set if the input scalars a, b satisfy a >= b.
00191 template <typename tIntegerType> 00192 tIntegerType MaskIfGreaterThanOrEqual(tIntegerType a, tIntegerType b) { 00193 return MaskIfNonZero<tIntegerType>(a >= b); 00194 } 00195 00196 // For each pair of input scalars, the corresponding bits of the result are 00197 // set if the input scalars a, b satisfy a < b. 00198 template <typename tIntegerType> 00199 tIntegerType MaskIfLessThan(tIntegerType a, tIntegerType b) { 00200 return MaskIfNonZero<tIntegerType>(a < b); 00201 } 00202 00203 // For each pair of input scalars, the corresponding bits of the result are 00204 // set if the input scalars a, b satisfy a <= b. 00205 template <typename tIntegerType> 00206 tIntegerType MaskIfLessThanOrEqual(tIntegerType a, tIntegerType b) { 00207 return MaskIfNonZero<tIntegerType>(a <= b); 00208 } 00209 00210 // Returns true if all of the input scalars are nonzero. 00211 // This function may currently assume that each of the input scalars has either 00212 // all or none of its bits set. Otherwise, its behavior is currently undefined. 00213 template <typename tIntegerType> 00214 bool All(tIntegerType a) { 00215 return a; 00216 } 00217 00218 // Returns true if any of the input scalars are nonzero. 00219 // This function may currently assume that each of the input scalars has either 00220 // all or none of its bits set. Otherwise, its behavior is currently undefined. 00221 template <typename tIntegerType> 00222 bool Any(tIntegerType a) { 00223 return a; 00224 } 00225 00226 // Returns (a+b)/2, rounded to the nearest integer. 00227 // Equivalent to VRHADD in the ARM NEON instruction set. 
template <typename IntegerType>
IntegerType RoundingHalfSum(IntegerType a, IntegerType b) {
  // The condition is false for every usable type, so instantiating this
  // generic fallback (i.e. using a type with no specialization) fails to
  // compile.
  static_assert(std::is_same<IntegerType, void>::value, "unimplemented");
  static_cast<void>(b);
  return a;
}

// 32-bit scalar version: the sum is formed in 64 bits so it cannot overflow.
template <>
inline std::int32_t RoundingHalfSum(std::int32_t a, std::int32_t b) {
  const std::int64_t wide_sum =
      static_cast<std::int64_t>(a) + static_cast<std::int64_t>(b);
  // Push the sum one step away from zero so that the truncating division
  // below rounds halves away from zero.
  const std::int64_t nudged = (wide_sum >= 0) ? wide_sum + 1 : wide_sum - 1;
  return static_cast<std::int32_t>(nudged / 2);
}

// 16-bit scalar version: same scheme, with the sum formed in 32 bits.
template <>
inline std::int16_t RoundingHalfSum(std::int16_t a, std::int16_t b) {
  const std::int32_t wide_sum =
      static_cast<std::int32_t>(a) + static_cast<std::int32_t>(b);
  const std::int32_t nudged = (wide_sum >= 0) ? wide_sum + 1 : wide_sum - 1;
  return static_cast<std::int16_t>(nudged / 2);
}

// Returns a+b clamped to the representable range of IntegerType.
// The generic version is a compile-time error; only the specializations are
// usable.
template <typename IntegerType>
IntegerType SaturatingAdd(IntegerType a, IntegerType b) {
  static_assert(std::is_same<IntegerType, void>::value, "unimplemented");
  static_cast<void>(b);
  return a;
}

// So far this is only needed for int16.
template <>
inline std::int16_t SaturatingAdd(std::int16_t a, std::int16_t b) {
  const std::int32_t kInt16Max = 32767;
  const std::int32_t kInt16Min = -32768;
  // Add in 32 bits, where the sum of two int16 values always fits.
  const std::int32_t wide_sum =
      static_cast<std::int32_t>(a) + static_cast<std::int32_t>(b);
  if (wide_sum > kInt16Max) {
    return static_cast<std::int16_t>(kInt16Max);
  }
  if (wide_sum < kInt16Min) {
    return static_cast<std::int16_t>(kInt16Min);
  }
  return static_cast<std::int16_t>(wide_sum);
}

// Returns a+b, saturating if the integers are 16bit or narrower,
// otherwise just a plain addition.
00273 template <typename IntegerType, bool Is16Bit> 00274 struct AddSaturatingIf16BitImpl { 00275 static IntegerType Run(IntegerType a, IntegerType b) { return Add(a, b); } 00276 }; 00277 template <typename IntegerType> 00278 struct AddSaturatingIf16BitImpl<IntegerType, true> { 00279 static IntegerType Run(IntegerType a, IntegerType b) { 00280 return SaturatingAdd(a, b); 00281 } 00282 }; 00283 template <typename IntegerType> 00284 IntegerType AddSaturatingIf16Bit(IntegerType a, IntegerType b) { 00285 using ScalarType = 00286 typename FixedPointRawTypeTraits<IntegerType>::ScalarRawType; 00287 return AddSaturatingIf16BitImpl<IntegerType, sizeof(ScalarType) == 2>::Run(a, 00288 b); 00289 } 00290 00291 // Returns the integer that represents the product of two fixed-point 00292 // numbers, interpreting all integers as fixed-point values in the 00293 // interval [-1, 1), rounding to the nearest value, and saturating 00294 // -1 * -1 to the maximum value (since 1 is not in the half-open 00295 // interval [-1, 1)). 00296 // 00297 // [The explanation below specializes to std::int32_t for example purpose.] 00298 // 00299 // The mapping between IntegerType and the interval [-1, 1) is unique and 00300 // implied by IntegerType, which is assumed to be signed. For example, 00301 // for IntegerType==std::int32_t, the mapping is 00302 // real_value = integer_value / 2^31. 00303 // So in this case, and leaving aside rounding and saturating, this 00304 // function computes ((a / 2^31) * (b / 2^31)) * 2^31, which simplifies to 00305 // (a * b) / 2^31. 00306 // 00307 // The 'doubling' part in the name of this function comes from the fact that 00308 // this operation is very close to a "multiply-high" operation, keeping only 00309 // the top half bits, except that that would be effectively computing 00310 // (a * b) / 2^32, 00311 // so here we are computing 2x that, since 00312 // 1/2^31 = 2 * 1/2^32. 
// The idea is to use all of the available 32 bits in the destination int32
// value.
//
// [End of the explanation specializing to int32.]
//
// This is equivalent to the VQRDMULH instruction in ARM NEON.
template <typename IntegerType>
IntegerType SaturatingRoundingDoublingHighMul(IntegerType a, IntegerType b) {
  // The condition is false for every usable type, so this generic fallback
  // cannot be instantiated; only the specializations below are usable.
  static_assert(std::is_same<IntegerType, void>::value, "unimplemented");
  static_cast<void>(b);
  return a;
}

// This function implements the same computation as the ARMv7 NEON VQRDMULH
// instruction.
template <>
inline std::int32_t SaturatingRoundingDoublingHighMul(std::int32_t a,
                                                      std::int32_t b) {
  typedef std::numeric_limits<std::int32_t> Limits;
  // min * min is the one input pair that gets pinned to the maximum value.
  if (a == Limits::min() && b == Limits::min()) {
    return Limits::max();
  }
  const std::int64_t product =
      static_cast<std::int64_t>(a) * static_cast<std::int64_t>(b);
  // Half of the divisor, signed like the product, so that the truncating
  // division below rounds to nearest.
  const std::int32_t rounding_nudge =
      (product < 0) ? (1 - (1 << 30)) : (1 << 30);
  return static_cast<std::int32_t>((product + rounding_nudge) / (1ll << 31));
}

// 16-bit version of the same computation: the product is formed in 32 bits
// and divided by 2^15.
template <>
inline std::int16_t SaturatingRoundingDoublingHighMul(std::int16_t a,
                                                      std::int16_t b) {
  typedef std::numeric_limits<std::int16_t> Limits;
  if (a == Limits::min() && b == Limits::min()) {
    return Limits::max();
  }
  const std::int32_t product =
      static_cast<std::int32_t>(a) * static_cast<std::int32_t>(b);
  const std::int16_t rounding_nudge =
      (product < 0) ? (1 - (1 << 14)) : (1 << 14);
  return static_cast<std::int16_t>((product + rounding_nudge) / (1 << 15));
}

// Correctly-rounded-to-nearest division by a power-of-two.
// Also known as a rounding arithmetic right shift.
00356 template <typename IntegerType> 00357 inline IntegerType RoundingDivideByPOT(IntegerType x, int exponent) { 00358 assert(exponent >= 0); 00359 assert(exponent <= 31); 00360 const IntegerType mask = Dup<IntegerType>((1ll << exponent) - 1); 00361 const IntegerType zero = Dup<IntegerType>(0); 00362 const IntegerType one = Dup<IntegerType>(1); 00363 const IntegerType remainder = BitAnd(x, mask); 00364 const IntegerType threshold = 00365 Add(ShiftRight(mask, 1), BitAnd(MaskIfLessThan(x, zero), one)); 00366 return Add(ShiftRight(x, exponent), 00367 BitAnd(MaskIfGreaterThan(remainder, threshold), one)); 00368 } 00369 00370 // Returns the product of a run-time integer value by a compile-time power 00371 // of two, with either a positive exponent (equivalent to an arithmetic 00372 // left shift, saturating) or a negative exponent (equivalent to an arithmetic 00373 // right shift, rounding to nearest). 00374 template <int Exponent, typename IntegerType, 00375 int ExponentSign = (Exponent > 0 ? 1 : Exponent < 0 ? 
-1 : 0)> 00376 struct ImplSaturatingRoundingMultiplyByPOT {}; 00377 00378 template <int Exponent, typename IntegerType> 00379 struct ImplSaturatingRoundingMultiplyByPOT<Exponent, IntegerType, 0> { 00380 static IntegerType eval(IntegerType x) { return x; } 00381 }; 00382 00383 template <int Exponent, typename IntegerType> 00384 struct ImplSaturatingRoundingMultiplyByPOT<Exponent, IntegerType, 1> { 00385 static IntegerType eval(IntegerType x) { 00386 using ScalarIntegerType = 00387 typename FixedPointRawTypeTraits<IntegerType>::ScalarRawType; 00388 const IntegerType min = 00389 Dup<IntegerType>(std::numeric_limits<ScalarIntegerType>::min()); 00390 const IntegerType max = 00391 Dup<IntegerType>(std::numeric_limits<ScalarIntegerType>::max()); 00392 const int ScalarIntegerTypeBits = 8 * sizeof(ScalarIntegerType); 00393 00394 const std::int32_t threshold = 00395 ((1 << (ScalarIntegerTypeBits - 1 - Exponent)) - 1); 00396 const IntegerType positive_mask = 00397 MaskIfGreaterThan(x, Dup<IntegerType>(threshold)); 00398 const IntegerType negative_mask = 00399 MaskIfLessThan(x, Dup<IntegerType>(-threshold)); 00400 00401 IntegerType result = ShiftLeft(x, Exponent); 00402 result = SelectUsingMask(positive_mask, max, result); 00403 result = SelectUsingMask(negative_mask, min, result); 00404 return result; 00405 } 00406 }; 00407 00408 template <int Exponent, typename IntegerType> 00409 struct ImplSaturatingRoundingMultiplyByPOT<Exponent, IntegerType, -1> { 00410 static IntegerType eval(IntegerType x) { 00411 return RoundingDivideByPOT<IntegerType>(x, -Exponent); 00412 } 00413 }; 00414 00415 template <int Exponent, typename IntegerType> 00416 IntegerType SaturatingRoundingMultiplyByPOT(IntegerType x) { 00417 return ImplSaturatingRoundingMultiplyByPOT<Exponent, IntegerType>::eval(x); 00418 } 00419 00420 // Part 2: the FixedPoint class. 
00421 00422 // A FixedPoint object represents a fixed-point value stored in the underlying 00423 // integer type tRawType, if tRawType is a plain scalar integer type. 00424 // Alternatively, tRawType may be a SIMD type (e.g. NEON int32x4_t) in which 00425 // case a FixedPoint object represents a corresponding SIMD vector of fixed 00426 // point values. 00427 // 00428 // tIntegerBits describes the range of the fixed-point format: if 00429 // tIntegerBits == m then the range of representable values is the half-open 00430 // interval [-2^m; 2^m) where the open boundary on the right side means that 00431 // 2^m is not representable (how close the maximum representable value is to 00432 // it, depends on bit-depth of tRawType). 00433 // 00434 // In "Q format notation", 00435 // https://en.wikipedia.org/wiki/Q_(number_format) 00436 // we are describing the format 00437 // Qm.n 00438 // where 00439 // m = tIntegerBits 00440 // and 00441 // n = NumberOfBits(tRawType) - (m + 1) 00442 // Note that the (m + 1) in the above line is because we adopt the convention 00443 // that we count the integer bits exclusively of the sign bit; so (m + 1) is 00444 // the total number of integer bits inclusive of the sign bit. 00445 // 00446 // Accordingly, the number of integral representable values in our range 00447 // [-2^m ; 2^m) 00448 // is equal to 2^(m+1). 
00449 template <typename tRawType, int tIntegerBits> 00450 class FixedPoint { 00451 public: 00452 typedef tRawType RawType; 00453 00454 typedef FixedPointRawTypeTraits<RawType> RawTypeTraits; 00455 typedef typename RawTypeTraits::ScalarRawType ScalarRawType; 00456 00457 static constexpr int kTotalBits = 8 * sizeof(ScalarRawType); 00458 static constexpr int kIntegerBits = tIntegerBits; 00459 static constexpr int kFractionalBits = kTotalBits - 1 - kIntegerBits; 00460 static_assert(kIntegerBits >= 0 && kIntegerBits < kTotalBits, 00461 "bad IntegerBits"); 00462 00463 typedef FixedPoint<ScalarRawType, kIntegerBits> ScalarFixedPointType; 00464 00465 static const ScalarRawType ScalarRawMin() { 00466 return std::numeric_limits<ScalarRawType>::min(); 00467 } 00468 00469 static const ScalarRawType ScalarRawMax() { 00470 return std::numeric_limits<ScalarRawType>::max(); 00471 } 00472 00473 static const ScalarRawType RawMin() { 00474 return VectorFromScalar(ScalarRawMin()); 00475 } 00476 00477 static const ScalarRawType RawMax() { 00478 return VectorFromScalar(ScalarRawMax()); 00479 } 00480 00481 static FixedPoint FromRaw(RawType x) { 00482 FixedPoint retval; 00483 retval.raw() = x; 00484 return retval; 00485 } 00486 00487 static FixedPoint FromScalarRaw(ScalarRawType x) { 00488 FixedPoint retval; 00489 retval.raw() = Dup<RawType>(x); 00490 return retval; 00491 } 00492 00493 static FixedPoint FromScalarFixedPoint(ScalarFixedPointType x) { 00494 return FromScalarRaw(x.raw()); 00495 } 00496 00497 template <int Exponent> 00498 static FixedPoint ConstantPOT() { 00499 static constexpr int kOffset = kFractionalBits + Exponent; 00500 static_assert( 00501 kOffset < 31, 00502 "Constant not exactly representable in this fixed-point format"); 00503 return FromScalarRaw(ScalarRawType(1) << kOffset); 00504 } 00505 00506 static FixedPoint Zero() { return FromScalarRaw(0); } 00507 00508 static FixedPoint One() { 00509 return FromScalarRaw( 00510 kIntegerBits == 0 00511 ? 
ScalarRawMax() 00512 : (ScalarRawType(1) << (kIntegerBits == 0 ? 0 : kFractionalBits))); 00513 } 00514 00515 static FixedPoint FromDouble(double x) { 00516 const double min_bound = static_cast<double>(ScalarRawMin()); 00517 const double max_bound = static_cast<double>(ScalarRawMax()); 00518 return FromScalarRaw(static_cast<ScalarRawType>(std::min( 00519 std::max(round(x * static_cast<double>(1ll << kFractionalBits)), 00520 min_bound), 00521 max_bound))); 00522 } 00523 00524 RawType raw() const { return i_; } 00525 RawType& raw() { return i_; } 00526 00527 private: 00528 RawType i_; 00529 }; 00530 00531 // Part 3: implementation of arithmetic operators for the 00532 // FixedPoint class, and a few related functions. 00533 00534 // A FixedPoint multiplication is just a 00535 // SaturatingRoundingDoublingHighMul operation on the underlying 00536 // raw integer values. The IntegerBits simply add up, as is obvious 00537 // from the fact that the range is [-2^IntegerBits, 2^IntegerBits). 00538 template <typename tRawType, int tIntegerBits_a, int tIntegerBits_b> 00539 FixedPoint<tRawType, tIntegerBits_a + tIntegerBits_b> operator*( 00540 FixedPoint<tRawType, tIntegerBits_a> a, 00541 FixedPoint<tRawType, tIntegerBits_b> b) { 00542 FixedPoint<tRawType, tIntegerBits_a + tIntegerBits_b> c; 00543 c.raw() = SaturatingRoundingDoublingHighMul(a.raw(), b.raw()); 00544 return c; 00545 } 00546 00547 // Tweaking IntegerBits gives exact multiplication by a power of two. 00548 template <int tExponent, typename tRawType, int tIntegerBits> 00549 FixedPoint<tRawType, tExponent + tIntegerBits> ExactMulByPot( 00550 FixedPoint<tRawType, tIntegerBits> a) { 00551 FixedPoint<tRawType, tExponent + tIntegerBits> c; 00552 c.raw() = a.raw(); 00553 return c; 00554 } 00555 00556 // If we want to leave IntegerBits fixed, then multiplication 00557 // by a power of two has to be saturating/rounding, not exact anymore. 
00558 template <int tExponent, typename tRawType, int tIntegerBits> 00559 FixedPoint<tRawType, tIntegerBits> SaturatingRoundingMultiplyByPOT( 00560 FixedPoint<tRawType, tIntegerBits> a) { 00561 return FixedPoint<tRawType, tIntegerBits>::FromRaw( 00562 SaturatingRoundingMultiplyByPOT<tExponent>(a.raw())); 00563 } 00564 00565 // Generic arithmetic operators. 00566 00567 #define MAKE_FIXEDPOINT_UNARY_FUNC(FuncName, ImplFuncName) \ 00568 template <typename tRawType, int tIntegerBits> \ 00569 FixedPoint<tRawType, tIntegerBits> FuncName( \ 00570 FixedPoint<tRawType, tIntegerBits> a) { \ 00571 return FixedPoint<tRawType, tIntegerBits>::FromRaw(ImplFuncName(a.raw())); \ 00572 } 00573 00574 #define MAKE_FIXEDPOINT_BINARY_FUNC(FuncName, ImplFuncName) \ 00575 template <typename tRawType, int tIntegerBits> \ 00576 FixedPoint<tRawType, tIntegerBits> FuncName( \ 00577 FixedPoint<tRawType, tIntegerBits> a, \ 00578 FixedPoint<tRawType, tIntegerBits> b) { \ 00579 return FixedPoint<tRawType, tIntegerBits>::FromRaw( \ 00580 ImplFuncName(a.raw(), b.raw())); \ 00581 } 00582 00583 MAKE_FIXEDPOINT_UNARY_FUNC(operator-, Neg) 00584 MAKE_FIXEDPOINT_UNARY_FUNC(operator~, BitNot) 00585 MAKE_FIXEDPOINT_BINARY_FUNC(operator+, Add) 00586 MAKE_FIXEDPOINT_BINARY_FUNC(operator-, Sub) 00587 MAKE_FIXEDPOINT_BINARY_FUNC(operator&, BitAnd) 00588 MAKE_FIXEDPOINT_BINARY_FUNC(operator^, BitXor) 00589 MAKE_FIXEDPOINT_BINARY_FUNC(operator|, BitOr) 00590 MAKE_FIXEDPOINT_BINARY_FUNC(RoundingHalfSum, RoundingHalfSum) 00591 00592 #undef MAKE_FIXEDPOINT_UNARY_FUNC 00593 #undef MAKE_FIXEDPOINT_BINARY_FUNC 00594 00595 #define MAKE_FIXEDPOINT_UNARY_FUNC_RETURNING_RAW(FuncName) \ 00596 template <typename tRawType, int tIntegerBits> \ 00597 tRawType FuncName(FixedPoint<tRawType, tIntegerBits> a) { \ 00598 return FuncName(a.raw()); \ 00599 } 00600 00601 #define MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(FuncName) \ 00602 template <typename tRawType, int tIntegerBits> \ 00603 tRawType FuncName(FixedPoint<tRawType, 
tIntegerBits> a, \ 00604 FixedPoint<tRawType, tIntegerBits> b) { \ 00605 return FuncName(a.raw(), b.raw()); \ 00606 } 00607 00608 MAKE_FIXEDPOINT_UNARY_FUNC_RETURNING_RAW(MaskIfZero) 00609 MAKE_FIXEDPOINT_UNARY_FUNC_RETURNING_RAW(MaskIfNonZero) 00610 MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(MaskIfEqual) 00611 MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(MaskIfNotEqual) 00612 MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(MaskIfGreaterThan) 00613 MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(MaskIfGreaterThanOrEqual) 00614 MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(MaskIfLessThan) 00615 MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(MaskIfLessThanOrEqual) 00616 00617 #undef MAKE_FIXEDPOINT_UNARY_FUNC_RETURNING_RAW 00618 #undef MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW 00619 00620 template <typename tRawType, int tIntegerBits> 00621 FixedPoint<tRawType, tIntegerBits> SelectUsingMask( 00622 tRawType if_mask, FixedPoint<tRawType, tIntegerBits> then_val, 00623 FixedPoint<tRawType, tIntegerBits> else_val) { 00624 return FixedPoint<tRawType, tIntegerBits>::FromRaw( 00625 SelectUsingMask(if_mask, then_val.raw(), else_val.raw())); 00626 } 00627 00628 template <typename tRawType, int tIntegerBits> 00629 bool operator==(FixedPoint<tRawType, tIntegerBits> a, 00630 FixedPoint<tRawType, tIntegerBits> b) { 00631 return All(MaskIfEqual(a.raw(), b.raw())); 00632 } 00633 00634 template <typename tRawType, int tIntegerBits> 00635 bool operator!=(FixedPoint<tRawType, tIntegerBits> a, 00636 FixedPoint<tRawType, tIntegerBits> b) { 00637 return !(a == b); 00638 } 00639 00640 template <typename tRawType, int tIntegerBits> 00641 FixedPoint<tRawType, tIntegerBits> SaturatingAdd( 00642 FixedPoint<tRawType, tIntegerBits> a, 00643 FixedPoint<tRawType, tIntegerBits> b) { 00644 return FixedPoint<tRawType, tIntegerBits>::FromRaw( 00645 SaturatingAdd(a.raw(), b.raw())); 00646 } 00647 00648 template <typename tRawType, int tIntegerBits> 00649 FixedPoint<tRawType, tIntegerBits> AddSaturatingIf16Bit( 00650 
FixedPoint<tRawType, tIntegerBits> a, 00651 FixedPoint<tRawType, tIntegerBits> b) { 00652 return FixedPoint<tRawType, tIntegerBits>::FromRaw( 00653 AddSaturatingIf16Bit(a.raw(), b.raw())); 00654 } 00655 00656 // Conversion to floating-point. 00657 template <typename tRawType, int tIntegerBits> 00658 double ToDouble(FixedPoint<tRawType, tIntegerBits> x) { 00659 static_assert(FixedPointRawTypeTraits<tRawType>::kLanes == 1, 00660 "not applicable to SIMD types"); 00661 typedef FixedPoint<tRawType, tIntegerBits> F; 00662 return x.raw() / static_cast<double>(1ll << F::kFractionalBits); 00663 } 00664 00665 // Rescale changes the number of IntegerBits and updates the underlying 00666 // raw integer value accordingly. 00667 template <int tIntegerBitsDst, typename tRawType, int tIntegerBitsSrc> 00668 FixedPoint<tRawType, tIntegerBitsDst> Rescale( 00669 FixedPoint<tRawType, tIntegerBitsSrc> x) { 00670 static constexpr int kExponent = tIntegerBitsSrc - tIntegerBitsDst; 00671 FixedPoint<tRawType, tIntegerBitsDst> result; 00672 result.raw() = SaturatingRoundingMultiplyByPOT<kExponent>(x.raw()); 00673 return result; 00674 } 00675 00676 // CheckedFixedPointConstant allows to specify fixed-point constants 00677 // initialized as real numbers, in a way that does not compile floating-point 00678 // arithmetic in production code, yet still checks agreement with the 00679 // floating-point expressions when asserts are enabled. 00680 // 00681 // The raw integer value provided is always a int32, encoding a 32-bit 00682 // fixed-point value, regardless of the actual Scalar type. This allows 00683 // writing generic code that applies just as well to the 32-bit and 16-bit 00684 // cases. In the 16-bit case, the raw integer value is internally 00685 // rounding-shifted by 16 bits to the right. 
00686 template <typename FixedPointType> 00687 inline typename FixedPointType::ScalarRawType RescaleConstantInitializer( 00688 std::int32_t int32_value) { 00689 typedef typename FixedPointType::ScalarRawType ScalarRawType; 00690 static constexpr int ScalarTypeBits = 8 * sizeof(ScalarRawType); 00691 return static_cast<ScalarRawType>( 00692 RoundingDivideByPOT<std::int32_t>(int32_value, 32 - ScalarTypeBits)); 00693 } 00694 #ifdef GEMMLOWP_ENABLE_FIXEDPOINT_CONSTANTS_CHECKS 00695 template <typename FixedPointType> 00696 FixedPointType CheckedFixedPointConstant(std::int32_t raw_value, 00697 double double_value) { 00698 const FixedPointType result = FixedPointType::FromScalarRaw(raw_value); 00699 assert(result == FixedPointType::FromDouble(double_value)); 00700 return result; 00701 } 00702 #define GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(FixedPointType, \ 00703 ScalarRawInt32Value, DoubleValue) \ 00704 (gemmlowp::CheckedFixedPointConstant<FixedPointType>( \ 00705 gemmlowp::RescaleConstantInitializer<FixedPointType>( \ 00706 ScalarRawInt32Value), \ 00707 DoubleValue)) 00708 00709 #else 00710 #define GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(FixedPointType, \ 00711 ScalarRawInt32Value, DoubleValue) \ 00712 (FixedPointType::FromScalarRaw( \ 00713 gemmlowp::RescaleConstantInitializer<FixedPointType>( \ 00714 ScalarRawInt32Value))) 00715 #endif 00716 00717 // Implementation of exponential function. 00718 00719 // Returns exp(x) for x in [-1/4, 0). 00720 template <typename tRawType> 00721 FixedPoint<tRawType, 0> exp_on_interval_between_negative_one_quarter_and_0_excl( 00722 FixedPoint<tRawType, 0> a) { 00723 typedef FixedPoint<tRawType, 0> F; 00724 const F constant_term = 00725 GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F, 1895147668, std::exp(-1.0 / 8.0)); 00726 const F constant_1_over_3 = 00727 GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F, 715827883, 1.0 / 3.0); 00728 // We're evaluating a Taylor expansion around -1/8, so we do the change of 00729 // variable: x = a + 1/8. 
00730 // In fixed-point with 0 integer bits, 1/8 is represented by 1 << 28. 00731 F x = a + F::template ConstantPOT<-3>(); 00732 F x2 = x * x; 00733 F x3 = x2 * x; 00734 F x4 = x2 * x2; 00735 F x4_over_4 = SaturatingRoundingMultiplyByPOT<-2>(x4); 00736 F x4_over_24_plus_x3_over_6_plus_x2_over_2 = 00737 SaturatingRoundingMultiplyByPOT<-1>( 00738 ((x4_over_4 + x3) * constant_1_over_3) + x2); 00739 return AddSaturatingIf16Bit( 00740 constant_term, 00741 constant_term * (x + x4_over_24_plus_x3_over_6_plus_x2_over_2)); 00742 } 00743 00744 // Returns exp(x) for x < 0. 00745 template <typename tRawType, int tIntegerBits> 00746 FixedPoint<tRawType, 0> exp_on_negative_values( 00747 FixedPoint<tRawType, tIntegerBits> a) { 00748 typedef FixedPoint<tRawType, tIntegerBits> InputF; 00749 typedef FixedPoint<tRawType, 0> ResultF; 00750 static constexpr int kFractionalBits = InputF::kFractionalBits; 00751 static constexpr int kIntegerBits = InputF::kIntegerBits; 00752 const InputF kOneQuarter = InputF::template ConstantPOT<-2>(); 00753 InputF mask = kOneQuarter - InputF::FromScalarRaw(1); 00754 InputF a_mod_quarter_minus_one_quarter = (a & mask) - kOneQuarter; 00755 ResultF result = exp_on_interval_between_negative_one_quarter_and_0_excl( 00756 Rescale<0>(a_mod_quarter_minus_one_quarter)); 00757 tRawType remainder = (a_mod_quarter_minus_one_quarter - a).raw(); 00758 00759 #define GEMMLOWP_EXP_BARREL_SHIFTER(Exponent, FixedPointMultiplier) \ 00760 if (kIntegerBits > Exponent) { \ 00761 const ResultF kMultiplier = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT( \ 00762 ResultF, FixedPointMultiplier, std::exp(-std::pow(2.0, Exponent))); \ 00763 static constexpr int kShiftAmount = \ 00764 kIntegerBits > Exponent ? 
kFractionalBits + Exponent : 0; \ 00765 result = SelectUsingMask( \ 00766 MaskIfNonZero(BitAnd(remainder, Dup<tRawType>(1 << kShiftAmount))), \ 00767 result * kMultiplier, result); \ 00768 } 00769 00770 GEMMLOWP_EXP_BARREL_SHIFTER(-2, 1672461947); 00771 GEMMLOWP_EXP_BARREL_SHIFTER(-1, 1302514674); 00772 GEMMLOWP_EXP_BARREL_SHIFTER(+0, 790015084); 00773 GEMMLOWP_EXP_BARREL_SHIFTER(+1, 290630308); 00774 GEMMLOWP_EXP_BARREL_SHIFTER(+2, 39332535); 00775 GEMMLOWP_EXP_BARREL_SHIFTER(+3, 720401); 00776 GEMMLOWP_EXP_BARREL_SHIFTER(+4, 242); 00777 00778 #undef GEMMLOWP_EXP_BARREL_SHIFTER 00779 00780 static constexpr int clampB = kIntegerBits > 5 ? 36 - kIntegerBits : 0; 00781 if (kIntegerBits > 5) { 00782 const InputF clamp = 00783 GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(InputF, -(1 << clampB), -32.0); 00784 result = SelectUsingMask(MaskIfLessThan(a, clamp), ResultF::Zero(), result); 00785 } 00786 00787 result = SelectUsingMask(MaskIfZero(a), ResultF::One(), result); 00788 return result; 00789 } 00790 00791 // Implementation of tanh: (1 - exp(-2x)) / (1 + exp(-2x)). 00792 00793 // Returns (1 - x) / (1 + x) for x in (0, 1). 00794 template <typename tRawType> 00795 FixedPoint<tRawType, 0> one_minus_x_over_one_plus_x_for_x_in_0_1( 00796 FixedPoint<tRawType, 0> a) { 00797 typedef FixedPoint<tRawType, 0> F0; 00798 typedef FixedPoint<tRawType, 2> F2; 00799 F0 half_denominator = RoundingHalfSum(a, F0::One()); 00800 // Newton-Raphson division 00801 // https://en.wikipedia.org/wiki/Division_algorithm#Newton.E2.80.93Raphson_division 00802 // Refer to that page for the logic behind the 48/17 and 32/17 constants. 
00803 const F2 constant_48_over_17 = 00804 GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F2, 1515870810, 48.0 / 17.0); 00805 const F2 constant_neg_32_over_17 = 00806 GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F2, -1010580540, -32.0 / 17.0); 00807 F2 x = constant_48_over_17 + half_denominator * constant_neg_32_over_17; 00808 for (int i = 0; i < 3; i++) { 00809 F2 half_denominator_times_x = half_denominator * x; 00810 F2 one_minus_half_denominator_times_x = 00811 F2::One() - half_denominator_times_x; 00812 x = x + Rescale<2>(x * one_minus_half_denominator_times_x); 00813 } 00814 return Rescale<0>(x - F2::One()); 00815 } 00816 00817 // Returns -tanh(x) for x < 0. 00818 template <typename tRawType, int tIntegerBits> 00819 FixedPoint<tRawType, 0> neg_tanh_on_negative_values( 00820 FixedPoint<tRawType, tIntegerBits> a) { 00821 return one_minus_x_over_one_plus_x_for_x_in_0_1( 00822 exp_on_negative_values(ExactMulByPot<1>(a))); 00823 } 00824 00825 // Returns tanh(x) for any x. 00826 template <typename tRawType, int tIntegerBits> 00827 FixedPoint<tRawType, 0> tanh(FixedPoint<tRawType, tIntegerBits> a) { 00828 typedef FixedPoint<tRawType, tIntegerBits> InputF; 00829 typedef FixedPoint<tRawType, 0> ResultF; 00830 tRawType mask_if_negative = MaskIfLessThan(a, InputF::Zero()); 00831 tRawType mask_if_zero = MaskIfZero(a); 00832 InputF n = SelectUsingMask(mask_if_negative, a, -a); 00833 ResultF t = neg_tanh_on_negative_values(n); 00834 return SelectUsingMask(mask_if_zero, ResultF::Zero(), 00835 SelectUsingMask(mask_if_negative, -t, t)); 00836 } 00837 00838 // Implementation of logistic function. 00839 00840 // Returns 1 / (1 + x) for x in (0, 1). 
00841 template <typename tRawType> 00842 FixedPoint<tRawType, 0> one_over_one_plus_x_for_x_in_0_1( 00843 FixedPoint<tRawType, 0> a) { 00844 typedef FixedPoint<tRawType, 0> F0; 00845 typedef FixedPoint<tRawType, 2> F2; 00846 F0 half_denominator = RoundingHalfSum(a, F0::One()); 00847 // Newton-Raphson division 00848 // https://en.wikipedia.org/wiki/Division_algorithm#Newton.E2.80.93Raphson_division 00849 // Refer to that page for the logic behind the 48/17 and 32/17 constants. 00850 const F2 constant_48_over_17 = 00851 GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F2, 1515870810, 48.0 / 17.0); 00852 const F2 constant_neg_32_over_17 = 00853 GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F2, -1010580540, -32.0 / 17.0); 00854 F2 x = constant_48_over_17 + half_denominator * constant_neg_32_over_17; 00855 for (int i = 0; i < 3; i++) { 00856 F2 half_denominator_times_x = half_denominator * x; 00857 F2 one_minus_half_denominator_times_x = 00858 F2::One() - half_denominator_times_x; 00859 x = x + Rescale<2>(x * one_minus_half_denominator_times_x); 00860 } 00861 return Rescale<0>(ExactMulByPot<-1>(x)); 00862 } 00863 00864 // Returns logistic(x) = 1 / (1 + exp(-x)) for x > 0. 00865 template <typename tRawType, int tIntegerBits> 00866 FixedPoint<tRawType, 0> logistic_on_positive_values( 00867 FixedPoint<tRawType, tIntegerBits> a) { 00868 return one_over_one_plus_x_for_x_in_0_1(exp_on_negative_values(-a)); 00869 } 00870 00871 // Returns logistic(x) = 1 / (1 + exp(-x)) for any x. 
00872 template <typename tRawType, int tIntegerBits> 00873 FixedPoint<tRawType, 0> logistic(FixedPoint<tRawType, tIntegerBits> a) { 00874 typedef FixedPoint<tRawType, tIntegerBits> InputF; 00875 typedef FixedPoint<tRawType, 0> ResultF; 00876 tRawType mask_if_positive = MaskIfGreaterThan(a, InputF::Zero()); 00877 tRawType mask_if_zero = MaskIfZero(a); 00878 InputF abs_input = SelectUsingMask(mask_if_positive, a, -a); 00879 ResultF result_if_positive = logistic_on_positive_values(abs_input); 00880 ResultF result_if_negative = ResultF::One() - result_if_positive; 00881 const ResultF one_half = 00882 GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(ResultF, 1 << 30, 0.5); 00883 return SelectUsingMask(mask_if_zero, one_half, 00884 SelectUsingMask(mask_if_positive, result_if_positive, 00885 result_if_negative)); 00886 } 00887 00888 } // end namespace gemmlowp 00889 00890 #ifdef GEMMLOWP_NEON 00891 #include "./fixedpoint_neon.h" 00892 #elif defined(GEMMLOWP_AVX2) 00893 #include "./fixedpoint_avx.h" 00894 #elif defined(GEMMLOWP_SSE4) 00895 #include "./fixedpoint_sse.h" 00896 #elif defined(GEMMLOWP_MSA) 00897 #include "./fixedpoint_msa.h" 00898 #endif 00899 00900 #endif // GEMMLOWP_INTERNAL_FIXEDPOINT_H_
Generated on Wed Jul 13 2022 16:03:35 by
