Daniel Konegen / MNIST_example

Dependencies:   mbed-os


fixedpoint_sse.h

// Copyright 2015 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// fixedpoint_sse.h: optimized SSE specializations of the templates
// in fixedpoint.h.

#ifndef GEMMLOWP_INTERNAL_FIXEDPOINT_SSE_H_
#define GEMMLOWP_INTERNAL_FIXEDPOINT_SSE_H_

#include <smmintrin.h>
#include "fixedpoint.h"

namespace gemmlowp {

// SSE intrinsics are not finely typed: there is a single __m128i vector
// type that does not distinguish between "int32x4" and "int16x8" use
// cases, unlike the NEON equivalents. Because the initial focus was on
// int32x4, these fixedpoint templates were specialized directly for
// __m128i, hardcoding the int32x4 semantics and leaving no room for
// int16x8 semantics. We amend that by adding a separate data type,
// int16x8_m128i, that wraps __m128i while remaining a distinct type.
struct int16x8_m128i {
  int16x8_m128i() {}
  explicit int16x8_m128i(__m128i w) : v(w) {}
  ~int16x8_m128i() {}

  __m128i v;
};
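
// The wrapper is what lets overload resolution tell the two lane layouts
// apart. A hypothetical snippet, for illustration only:
//
//   __m128i raw = _mm_set1_epi16(3);   // raw bits carry no lane layout
//   int16x8_m128i wrapped(raw);
//   Add(raw, raw);                     // picks the int32x4 specialization
//   Add(wrapped, wrapped);             // picks the int16x8 specialization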

template <>
struct FixedPointRawTypeTraits<__m128i> {
  typedef std::int32_t ScalarRawType;
  static constexpr int kLanes = 4;
};

template <>
struct FixedPointRawTypeTraits<int16x8_m128i> {
  typedef std::int16_t ScalarRawType;
  static constexpr int kLanes = 8;
};
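
// The lane counts follow directly from the 128-bit register width:
// 128 / 32 = 4 int32 lanes, and 128 / 16 = 8 int16 lanes.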

template <>
inline __m128i BitAnd(__m128i a, __m128i b) {
  return _mm_and_si128(a, b);
}

template <>
inline int16x8_m128i BitAnd(int16x8_m128i a, int16x8_m128i b) {
  return int16x8_m128i(_mm_and_si128(a.v, b.v));
}

template <>
inline __m128i BitOr(__m128i a, __m128i b) {
  return _mm_or_si128(a, b);
}

template <>
inline int16x8_m128i BitOr(int16x8_m128i a, int16x8_m128i b) {
  return int16x8_m128i(_mm_or_si128(a.v, b.v));
}

template <>
inline __m128i BitXor(__m128i a, __m128i b) {
  return _mm_xor_si128(a, b);
}

template <>
inline int16x8_m128i BitXor(int16x8_m128i a, int16x8_m128i b) {
  return int16x8_m128i(_mm_xor_si128(a.v, b.v));
}

template <>
inline __m128i BitNot(__m128i a) {
  return _mm_andnot_si128(a, _mm_set1_epi32(-1));
}

template <>
inline int16x8_m128i BitNot(int16x8_m128i a) {
  return int16x8_m128i(_mm_andnot_si128(a.v, _mm_set1_epi16(-1)));
}
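
// SSE has no single-instruction bitwise NOT: _mm_andnot_si128(x, y) computes
// (~x) & y, so ANDing ~a with an all-ones constant yields ~a.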

template <>
inline __m128i Add(__m128i a, __m128i b) {
  return _mm_add_epi32(a, b);
}

template <>
inline int16x8_m128i Add(int16x8_m128i a, int16x8_m128i b) {
  return int16x8_m128i(_mm_add_epi16(a.v, b.v));
}

template <>
inline __m128i Mul(__m128i a, __m128i b) {
  return _mm_mullo_epi32(a, b);
}

template <>
inline int16x8_m128i Mul(int16x8_m128i a, int16x8_m128i b) {
  return int16x8_m128i(_mm_mullo_epi16(a.v, b.v));
}

template <>
inline __m128i Sub(__m128i a, __m128i b) {
  return _mm_sub_epi32(a, b);
}

template <>
inline int16x8_m128i Sub(int16x8_m128i a, int16x8_m128i b) {
  return int16x8_m128i(_mm_sub_epi16(a.v, b.v));
}

template <>
inline __m128i Neg(__m128i a) {
  return _mm_sign_epi32(a, _mm_set1_epi32(-1));
}

template <>
inline int16x8_m128i Neg(int16x8_m128i a) {
  return int16x8_m128i(_mm_sign_epi16(a.v, _mm_set1_epi16(-1)));
}
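
// _mm_sign_epi32(x, y) negates each lane of x where the matching lane of y
// is negative, so the all-(-1) second operand negates every lane. As with
// scalar negation, an INT_MIN lane wraps back to INT_MIN.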

template <>
inline __m128i ShiftLeft(__m128i a, int offset) {
  return _mm_slli_epi32(a, offset);
}

template <>
inline int16x8_m128i ShiftLeft(int16x8_m128i a, int offset) {
  return int16x8_m128i(_mm_slli_epi16(a.v, offset));
}

template <>
inline __m128i ShiftRight(__m128i a, int offset) {
  return _mm_srai_epi32(a, offset);
}

template <>
inline int16x8_m128i ShiftRight(int16x8_m128i a, int offset) {
  return int16x8_m128i(_mm_srai_epi16(a.v, offset));
}

template <>
inline __m128i SelectUsingMask(__m128i if_mask, __m128i then_val,
                               __m128i else_val) {
  // borrowed from Intel's arm_neon_sse.h header.
  return _mm_or_si128(_mm_and_si128(if_mask, then_val),
                      _mm_andnot_si128(if_mask, else_val));
}

template <>
inline int16x8_m128i SelectUsingMask(int16x8_m128i if_mask,
                                     int16x8_m128i then_val,
                                     int16x8_m128i else_val) {
  // borrowed from Intel's arm_neon_sse.h header.
  return int16x8_m128i(SelectUsingMask(if_mask.v, then_val.v, else_val.v));
}
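
// The classic branchless select: (mask & then) | (~mask & else). With the
// all-ones/all-zeroes lane masks produced by the MaskIf* functions below,
// each lane takes then_val where its mask lane is set and else_val elsewhere.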

template <>
inline __m128i MaskIfEqual(__m128i a, __m128i b) {
  return _mm_cmpeq_epi32(a, b);
}

template <>
inline int16x8_m128i MaskIfEqual(int16x8_m128i a, int16x8_m128i b) {
  return int16x8_m128i(_mm_cmpeq_epi16(a.v, b.v));
}

template <>
inline __m128i MaskIfNotEqual(__m128i a, __m128i b) {
  return BitNot(MaskIfEqual(a, b));
}

template <>
inline int16x8_m128i MaskIfNotEqual(int16x8_m128i a, int16x8_m128i b) {
  return BitNot(MaskIfEqual(a, b));
}

template <>
inline __m128i MaskIfZero(__m128i a) {
  return MaskIfEqual(a, _mm_set1_epi32(0));
}

template <>
inline int16x8_m128i MaskIfZero(int16x8_m128i a) {
  return MaskIfEqual(a, int16x8_m128i(_mm_set1_epi16(0)));
}

template <>
inline __m128i MaskIfNonZero(__m128i a) {
  return MaskIfNotEqual(a, _mm_set1_epi32(0));
}

template <>
inline int16x8_m128i MaskIfNonZero(int16x8_m128i a) {
  return MaskIfNotEqual(a, int16x8_m128i(_mm_set1_epi16(0)));
}

template <>
inline __m128i MaskIfGreaterThan(__m128i a, __m128i b) {
  return _mm_cmpgt_epi32(a, b);
}

template <>
inline int16x8_m128i MaskIfGreaterThan(int16x8_m128i a, int16x8_m128i b) {
  return int16x8_m128i(_mm_cmpgt_epi16(a.v, b.v));
}

template <>
inline __m128i MaskIfLessThan(__m128i a, __m128i b) {
  return _mm_cmplt_epi32(a, b);
}

template <>
inline int16x8_m128i MaskIfLessThan(int16x8_m128i a, int16x8_m128i b) {
  return int16x8_m128i(_mm_cmplt_epi16(a.v, b.v));
}

template <>
inline __m128i MaskIfGreaterThanOrEqual(__m128i a, __m128i b) {
  return BitNot(MaskIfLessThan(a, b));
}

template <>
inline int16x8_m128i MaskIfGreaterThanOrEqual(int16x8_m128i a,
                                              int16x8_m128i b) {
  return BitNot(MaskIfLessThan(a, b));
}

template <>
inline __m128i MaskIfLessThanOrEqual(__m128i a, __m128i b) {
  return BitNot(MaskIfGreaterThan(a, b));
}

template <>
inline int16x8_m128i MaskIfLessThanOrEqual(int16x8_m128i a, int16x8_m128i b) {
  return BitNot(MaskIfGreaterThan(a, b));
}

/* Assumptions:
   - All and Any are used only on masks.
   - Masks are all-ones for true lanes, all-zeroes otherwise.
Hence, All means all 128 bits set, and Any means any bit set.
*/

template <>
inline bool All(__m128i a) {
  // _mm_testc_si128(x, y) returns ((~x & y) == 0), so testing a against
  // all-ones checks that every bit of a is set; testing a against itself
  // would always return true.
  return _mm_testc_si128(a, _mm_set1_epi32(-1));
}

template <>
inline bool All(int16x8_m128i a) {
  return _mm_testc_si128(a.v, _mm_set1_epi16(-1));
}

template <>
inline bool Any(__m128i a) {
  // _mm_testz_si128(a, a) returns ((a & a) == 0), i.e. whether a is all-zero.
  return !_mm_testz_si128(a, a);
}

template <>
inline bool Any(int16x8_m128i a) {
  return !_mm_testz_si128(a.v, a.v);
}

template <>
inline __m128i RoundingHalfSum(__m128i a, __m128i b) {
  /* An alternative would be to halve the inputs before adding, which avoids
   * the overflow and the costly test for whether an overflow occurred on the
   * signed add:
   *
   *   __m128i round_bit_mask = _mm_set1_epi32(1);
   *   __m128i a_over_2 = _mm_srai_epi32(a, 1);
   *   __m128i b_over_2 = _mm_srai_epi32(b, 1);
   *   __m128i sum = Add(a_over_2, b_over_2);
   *   __m128i round_bit =
   *       _mm_sign_epi32(BitAnd(BitOr(a, b), round_bit_mask), sum);
   *   return Add(sum, round_bit);
   *
   * The version below instead detects overflow and XORs the sign bit back in
   * if an overflow happened. */
  __m128i one, sign_bit_mask, sum, rounded_half_sum, overflow, result;
  one = _mm_set1_epi32(1);
  sign_bit_mask = _mm_set1_epi32(0x80000000);
  sum = Add(a, b);
  rounded_half_sum = _mm_srai_epi32(Add(sum, one), 1);
  overflow =
      BitAnd(BitAnd(BitXor(a, rounded_half_sum), BitXor(b, rounded_half_sum)),
             sign_bit_mask);
  result = BitXor(rounded_half_sum, overflow);
  return result;
}
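
// Worked example of the overflow correction above: for a = b = INT32_MAX,
// sum wraps to -2 and rounded_half_sum to -1 (0xffffffff), whose sign differs
// from that of both inputs, so overflow = 0x80000000 and the final XOR flips
// the result to 0x7fffffff = INT32_MAX, the exact value of (a + b + 1) >> 1.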

template <>
inline int16x8_m128i RoundingHalfSum(int16x8_m128i a, int16x8_m128i b) {
  // Idea: go to unsigned to use _mm_avg_epu16,
  // borrowed from Intel's arm_neon_sse.h header.
  __m128i constant_neg_32768 = _mm_set1_epi16(-32768);
  __m128i a_unsigned = _mm_sub_epi16(a.v, constant_neg_32768);
  __m128i b_unsigned = _mm_sub_epi16(b.v, constant_neg_32768);
  __m128i avg_unsigned = _mm_avg_epu16(a_unsigned, b_unsigned);
  __m128i avg = _mm_add_epi16(avg_unsigned, constant_neg_32768);
  return int16x8_m128i(avg);
}
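
// Example: a = 1, b = 2 bias to 32769 and 32770; _mm_avg_epu16 computes
// (32769 + 32770 + 1) >> 1 = 32770 with 17-bit internal precision (so the
// biased sum cannot overflow), and un-biasing gives back 2, the exact value
// of (1 + 2 + 1) >> 1.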

template <>
inline __m128i SaturatingRoundingDoublingHighMul(__m128i a, __m128i b) {
  __m128i min, max, saturation_mask, a0_a2, a1_a3, b0_b2, b1_b3;
  __m128i a0b0_a2b2, a1b1_a3b3, a0b0_a2b2_rounded, a1b1_a3b3_rounded;
  __m128i a0b0_a2b2_rounded_2x, a1b1_a3b3_rounded_2x, result;
  __m128i nudge;

  // saturation can only happen if a == b == INT_MIN: the exact result, 2^31,
  // does not fit in int32, so those lanes must clamp to INT_MAX (matching
  // the scalar reference in fixedpoint.h and NEON's vqrdmulhq_s32)
  min = _mm_set1_epi32(std::numeric_limits<std::int32_t>::min());
  max = _mm_set1_epi32(std::numeric_limits<std::int32_t>::max());
  saturation_mask = BitAnd(MaskIfEqual(a, b), MaskIfEqual(a, min));

  // a = a0 | a1 | a2 | a3
  // b = b0 | b1 | b2 | b3
  a0_a2 = a;
  a1_a3 = _mm_srli_si128(a, 4);
  b0_b2 = b;
  b1_b3 = _mm_srli_si128(b, 4);

  // _mm_mul_epi32 multiplies the even-indexed lanes into 64-bit products
  a0b0_a2b2 = _mm_mul_epi32(a0_a2, b0_b2);
  a1b1_a3b3 = _mm_mul_epi32(a1_a3, b1_b3);

  // do the rounding, taking into account that the result will be doubled
  nudge = _mm_set1_epi64x(1 << 30);
  a0b0_a2b2_rounded = _mm_add_epi64(a0b0_a2b2, nudge);
  a1b1_a3b3_rounded = _mm_add_epi64(a1b1_a3b3, nudge);

  // do the doubling
  a0b0_a2b2_rounded_2x = _mm_slli_epi64(a0b0_a2b2_rounded, 1);
  a1b1_a3b3_rounded_2x = _mm_slli_epi64(a1b1_a3b3_rounded, 1);

  // get the high 32 bits of each 64-bit product
  result = _mm_blend_epi16(_mm_srli_si128(a0b0_a2b2_rounded_2x, 4),
                           a1b1_a3b3_rounded_2x, 0xcc);

  // saturate the lanes that overflowed
  return SelectUsingMask(saturation_mask, max, result);
}
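
// A per-lane scalar sketch of what the vector code above computes
// (hypothetical helper, for illustration only):
//
//   std::int32_t SRDHM(std::int32_t a, std::int32_t b) {
//     if (a == b && a == std::numeric_limits<std::int32_t>::min()) {
//       return std::numeric_limits<std::int32_t>::max();  // saturate
//     }
//     std::int64_t ab = std::int64_t(a) * std::int64_t(b);
//     std::int64_t rounded_2x = (ab + (std::int64_t(1) << 30)) << 1;
//     return std::int32_t(rounded_2x >> 32);  // high 32 bits
//   }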

template <>
inline int16x8_m128i SaturatingRoundingDoublingHighMul(int16x8_m128i a,
                                                       int16x8_m128i b) {
  // Idea: use _mm_mulhrs_epi16, then saturate with a bit-operation,
  // borrowed from Intel's arm_neon_sse.h header.
  __m128i result_unsaturated = _mm_mulhrs_epi16(a.v, b.v);
  __m128i saturation_mask =
      _mm_cmpeq_epi16(result_unsaturated, _mm_set1_epi16(0x8000));
  __m128i result = _mm_xor_si128(result_unsaturated, saturation_mask);
  return int16x8_m128i(result);
}
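
// Example: for a = b = INT16_MIN, _mm_mulhrs_epi16 computes
// (((-32768 * -32768) >> 14) + 1) >> 1 = 32768, which wraps to 0x8000 in
// 16 bits; the compare then yields an all-ones mask and the XOR flips it to
// 0x7fff = INT16_MAX, the saturated result.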

template <>
inline __m128i Dup<__m128i>(std::int32_t x) {
  return _mm_set1_epi32(x);
}

template <>
inline int16x8_m128i Dup<int16x8_m128i>(std::int16_t x) {
  return int16x8_m128i(_mm_set1_epi16(x));
}

// So far this is only needed for int16.
template <>
inline int16x8_m128i SaturatingAdd(int16x8_m128i a, int16x8_m128i b) {
  return int16x8_m128i(_mm_adds_epi16(a.v, b.v));
}
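
// _mm_adds_epi16 clamps instead of wrapping: e.g. 30000 + 10000 saturates
// to 32767 = INT16_MAX rather than wrapping to -25536.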

}  // end namespace gemmlowp

#endif  // GEMMLOWP_INTERNAL_FIXEDPOINT_SSE_H_