Added CMSIS5 DSP and NN folder. Needs some work

Embed: (wiki syntax)

« Back to documentation index

Show/hide line numbers arm_fully_connected_mat_q7_vec_q15_opt.c Source File

arm_fully_connected_mat_q7_vec_q15_opt.c

00001 /*
00002  * Copyright (C) 2010-2018 Arm Limited or its affiliates. All rights reserved.
00003  *
00004  * SPDX-License-Identifier: Apache-2.0
00005  *
00006  * Licensed under the Apache License, Version 2.0 (the License); you may
00007  * not use this file except in compliance with the License.
00008  * You may obtain a copy of the License at
00009  *
00010  * www.apache.org/licenses/LICENSE-2.0
00011  *
00012  * Unless required by applicable law or agreed to in writing, software
00013  * distributed under the License is distributed on an AS IS BASIS, WITHOUT
00014  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
00015  * See the License for the specific language governing permissions and
00016  * limitations under the License.
00017  */
00018 
00019 /* ----------------------------------------------------------------------
00020  * Project:      CMSIS NN Library
00021  * Title:        arm_fully_connected_mat_q7_vec_q15_opt.c
00022  * Description:  Mixed Q15-Q7 opt fully-connected layer function
00023  *
00024  * $Date:        17. January 2018
00025  * $Revision:    V.1.0.0
00026  *
00027  * Target Processor:  Cortex-M cores
00028  *
00029  * -------------------------------------------------------------------- */
00030 
00031 #include "arm_math.h"
00032 #include "arm_nnfunctions.h"
00033 
00034 /**
00035  *  @ingroup groupNN
00036  */
00037 
00038 /**
00039  * @addtogroup FC
00040  * @{
00041  */
00042 
00043   /**
00044    * @brief Mixed Q15-Q7 opt fully-connected layer function
00045    * @param[in]       pV          pointer to input vector
00046    * @param[in]       pM          pointer to matrix weights
00047    * @param[in]       dim_vec     length of the vector
00048    * @param[in]       num_of_rows number of rows in weight matrix
00049    * @param[in]       bias_shift  amount of left-shift for bias
00050    * @param[in]       out_shift   amount of right-shift for output
00051    * @param[in]       bias        pointer to bias
00052    * @param[in,out]   pOut        pointer to output vector
00053    * @param[in,out]   vec_buffer  pointer to buffer space for input
00054    * @return     The function returns <code>ARM_MATH_SUCCESS</code>
00055    *
00056    * @details
00057    *
00058    * <b>Buffer size:</b>
00059    *
00060    * vec_buffer size: 0
00061    *
00062    *  Q7_Q15 version of the fully connected layer
00063    *
00064    *  Weights are in q7_t and Activations are in q15_t
00065    *
00066    *  Limitation: x4 version requires weight reordering to work
00067    *
00068    *  Here we use only one pointer to read 4 rows in the weight
00069    *  matrix. So if the original q7_t matrix looks like this:
00070    *
00071    *  | a11 | a12 | a13 | a14 | a15 | a16 | a17 |
00072    *
00073    *  | a21 | a22 | a23 | a24 | a25 | a26 | a27 |
00074    *
00075    *  | a31 | a32 | a33 | a34 | a35 | a36 | a37 |
00076    *
00077    *  | a41 | a42 | a43 | a44 | a45 | a46 | a47 |
00078    *
00079    *  | a51 | a52 | a53 | a54 | a55 | a56 | a57 |
00080    *
00081    *  | a61 | a62 | a63 | a64 | a65 | a66 | a67 |
00082    *
00083    *  We operate on multiple-of-4 rows, so the first four rows become
00084    *
00085    *  | a11 | a21 | a12 | a22 | a31 | a41 | a32 | a42 |
00086    *
00087    *  | a13 | a23 | a14 | a24 | a33 | a43 | a34 | a44 |
00088    *
00089    *  | a15 | a25 | a16 | a26 | a35 | a45 | a36 | a46 |
00090    *
00091    *  The left-over columns will be stored in order,
00092    *  which is:
00093    *  | a17 | a27 | a37 | a47 |
00094    *
00095    *  For the left-over rows, we do 1x1 computation, so the data remains
00096    *  in its original order.
00097    *
00098    *  So the stored weight matrix looks like this:
00099    *
00100    *  | a11 | a21 | a12 | a22 | a31 | a41 |
00101    *
00102    *  | a32 | a42 | a13 | a23 | a14 | a24 |
00103    *
00104    *  | a33 | a43 | a34 | a44 | a15 | a25 |
00105    *
00106    *  | a16 | a26 | a35 | a45 | a36 | a46 |
00107    *
00108    *  | a17 | a27 | a37 | a47 | a51 | a52 |
00109    *
00110    *  | a53 | a54 | a55 | a56 | a57 | a61 |
00111    *
00112    *  | a62 | a63 | a64 | a65 | a66 | a67 |
00113    *
00114    */
00115 
00116 arm_status
00117 arm_fully_connected_mat_q7_vec_q15_opt(const q15_t * pV,
00118                                        const q7_t * pM,
00119                                        const uint16_t dim_vec,
00120                                        const uint16_t num_of_rows,
00121                                        const uint16_t bias_shift,
00122                                        const uint16_t out_shift, const q7_t * bias, q15_t * pOut, q15_t * vec_buffer)
00123 {
00124 
00125 #if defined (ARM_MATH_DSP)
00126     /* Run the following code for Cortex-M4 and Cortex-M7 */
00127 
      /* NOTE: vec_buffer is not used by this variant (documented buffer
       * size is 0); it is kept in the signature for API consistency with
       * the other fully-connected functions. */
00128     const q7_t *pB = pM;        /* weight matrix, interleaved x4 layout */
00129     q15_t    *pO = pOut;        /* output write pointer */
00130     const q7_t *pBias = bias;
00131     const q15_t *pA = pV;       /* activation read pointer, reset per row group */
00132 
00133     uint16_t  rowCnt = num_of_rows >> 2;   /* rows are processed 4 at a time */
00134 
00135     while (rowCnt)
00136     {
          /* bias is pre-shifted into the accumulator together with the
           * rounding constant for the final right shift */
00137         q31_t     sum =  ((q31_t)(*pBias++) << bias_shift) + NN_ROUND(out_shift);
00138         q31_t     sum2 = ((q31_t)(*pBias++) << bias_shift) + NN_ROUND(out_shift);
00139         q31_t     sum3 = ((q31_t)(*pBias++) << bias_shift) + NN_ROUND(out_shift);
00140         q31_t     sum4 = ((q31_t)(*pBias++) << bias_shift) + NN_ROUND(out_shift);
00141 
00142         uint16_t  colCnt = dim_vec >> 1;   /* two q15 activations per iteration */
00143 
00144         pA = pV;
00145 
00146 #ifdef USE_INTRINSIC
00147 
00148 #ifndef ARM_MATH_BIG_ENDIAN
00149 
00150         while (colCnt)
00151         {
00152             q31_t     inM11, inM12, inM13, inM14;
00153             q31_t     inV;
00154 
00155             inV = *__SIMD32(pA)++;
00156             inM11 = *__SIMD32(pB)++;
00157             inM12 = __SXTB16(__ROR(inM11, 8));
00158             inM11 = __SXTB16(inM11);
00159             sum = __SMLAD(inM11, inV, sum);
00160             sum2 = __SMLAD(inM12, inV, sum2);
00161             inM13 = *__SIMD32(pB)++;
00162             inM14 = __SXTB16(__ROR(inM13, 8));
00163             inM13 = __SXTB16(inM13);
00164             sum3 = __SMLAD(inM13, inV, sum3);
00165             sum4 = __SMLAD(inM14, inV, sum4);
00166             colCnt--;
00167         }
00168 
00169 #else
00170 
00171         while (colCnt)
00172         {
00173             q31_t     inM11, inM12, inM13, inM14;
00174             q31_t     inV;
00175 
00176             inV = *__SIMD32(pA)++;
00177             inM11 = *__SIMD32(pB)++;
00178             inM12 = __SXTB16(__ROR(inM11, 8));
00179             inM11 = __SXTB16(inM11);
00180             sum = __SMLAD(inM12, inV, sum);
00181             sum2 = __SMLAD(inM11, inV, sum2);
00182             inM13 = *__SIMD32(pB)++;
00183             inM14 = __SXTB16(__ROR(inM13, 8));
00184             inM13 = __SXTB16(inM13);
00185             sum3 = __SMLAD(inM14, inV, sum3);
00186             sum4 = __SMLAD(inM13, inV, sum4);
00187             colCnt--;
00188         }
00189 
00190 #endif                          /* ARM_MATH_BIG_ENDIAN */
00191 
00192 #else
00193 
00194         /*
00195          * register needed:
00196          * loop counter: colCnt
00197          * accumulators: sum, sum2, sum3, sum4
00198          * pointers: pB, pA
00199          * weight data: inM11, inM12, inM13, inM14
00200          * activation data: inV
         *
         * colCnt is decremented by the "subs" inside the asm, so it must be
         * declared as a read-write ("+r") operand, and the condition flags
         * must be listed in the clobbers ("cc"); an input-only constraint
         * would let the compiler assume colCnt and the flags are unchanged
         * across the asm statement, which is undefined behavior.
00201          */
00202 
00203 #ifndef ARM_MATH_BIG_ENDIAN
00204         asm volatile ("COL_LOOP_%=:\n"
00205                       "ldr.w r4, [%[pA]], #4\n"
00206                       "ldr.w r1, [%[pB]], #8\n"
00207                       "mov.w r0, r1, ror #8\n"
00208                       "sxtb16 r0, r0\n"
00209                       "sxtb16 r1, r1\n"
00210                       "smlad %[sum], r4, r1, %[sum]\n"
00211                       "smlad %[sum2], r4, r0, %[sum2]\n"
00212                       "ldr.w r3, [%[pB], #-4]\n"
00213                       "mov.w r2, r3, ror #8\n"
00214                       "sxtb16 r2, r2\n"
00215                       "sxtb16 r3, r3\n"
00216                       "smlad %[sum3], r4, r3, %[sum3]\n"
00217                       "smlad %[sum4], r4, r2, %[sum4]\n"
00218                       "subs %[colCnt], #1\n"
00219                       "bne COL_LOOP_%=\n":[sum] "+r"(sum),
00220                       [sum2] "+r"(sum2),[sum3] "+r"(sum3),[sum4] "+r"(sum4),
00221                       [pB] "+r"(pB),[pA] "+r"(pA),[colCnt] "+r"(colCnt)::"r0", "r1", "r2", "r3", "r4", "cc");
00222 #else
00223         asm volatile ("COL_LOOP_%=:\n"
00224                       "ldr.w r4, [%[pA]], #4\n"
00225                       "ldr.w r1, [%[pB]], #8\n"
00226                       "mov.w r0, r1, ror #8\n"
00227                       "sxtb16 r0, r0\n"
00228                       "sxtb16 r1, r1\n"
00229                       "smlad %[sum], r4, r0, %[sum]\n"
00230                       "smlad %[sum2], r4, r1, %[sum2]\n"
00231                       "ldr.w r3, [%[pB], #-4]\n"
00232                       "mov.w r2, r3, ror #8\n"
00233                       "sxtb16 r2, r2\n"
00234                       "sxtb16 r3, r3\n"
00235                       "smlad %[sum3], r4, r2, %[sum3]\n"
00236                       "smlad %[sum4], r4, r3, %[sum4]\n"
00237                       "subs %[colCnt], #1\n"
00238                       "bne COL_LOOP_%=\n":[sum] "+r"(sum),
00239                       [sum2] "+r"(sum2),[sum3] "+r"(sum3),[sum4] "+r"(sum4),
00240                       [pB] "+r"(pB),[pA] "+r"(pA),[colCnt] "+r"(colCnt)::"r0", "r1", "r2", "r3", "r4", "cc");
00241 #endif                          /* ARM_MATH_BIG_ENDIAN */
00242 
00243 #endif                          /* USE_INTRINSIC */
00244 
          /* left-over column (odd dim_vec): weights for this column are
           * stored in row order a(r,c), a(r+1,c), ... (see header comment) */
00245         colCnt = dim_vec & 0x1;
00246         while (colCnt)
00247         {
00248             q15_t     inV = *pA++;
00249             q7_t      inM = *pB++;
00250             q7_t      inM2 = *pB++;
00251             q7_t      inM3 = *pB++;
00252             q7_t      inM4 = *pB++;
00253 
00254             sum += inV * inM;
00255             sum2 += inV * inM2;
00256             sum3 += inV * inM3;
00257             sum4 += inV * inM4;
00258             colCnt--;
00259         }                       /* while over colCnt */
00260         *pO++ = (q15_t) (__SSAT((sum >> out_shift), 16));
00261         *pO++ = (q15_t) (__SSAT((sum2 >> out_shift), 16));
00262         *pO++ = (q15_t) (__SSAT((sum3 >> out_shift), 16));
00263         *pO++ = (q15_t) (__SSAT((sum4 >> out_shift), 16));
00264 
00265         /* adjust the pointers and counters */
00266         rowCnt--;
00267     }
00268 
00269     /* left-over part of the rows: plain 1x1 computation, weights in
00270        original row-major order */
00270     rowCnt = num_of_rows & 0x3;
00271 
00272     while (rowCnt)
00273     {
00274         q31_t     sum = ((q31_t)(*pBias++) << bias_shift) + NN_ROUND(out_shift);
00275 
00276         uint16_t  colCnt = dim_vec >> 2;
00277 
00278         pA = pV;
00279 
00280         while (colCnt)
00281         {
00282             q31_t     inV1, inV2, inM11, inM12;
00283 
00284             pB = (q7_t *) read_and_pad((void *)pB, &inM11, &inM12);
00285 
00286             inV1 = *__SIMD32(pA)++;
00287             sum = __SMLAD(inV1, inM11, sum);
00288 
00289             inV2 = *__SIMD32(pA)++;
00290             sum = __SMLAD(inV2, inM12, sum);
00291 
00292             colCnt--;
00293         }
00294 
00295         /* left-over of the vector */
00296         colCnt = dim_vec & 0x3;
00297         while (colCnt)
00298         {
00299             q15_t     inV = *pA++;
00300             q7_t      inM = *pB++;
00301             sum += inV * inM;
00302             colCnt--;
00303         }
00304 
00305         *pO++ = (q15_t) (__SSAT((sum >> out_shift), 16));
00306 
00307         rowCnt--;
00308     }
00309 
00310 #else
00311     /* Run the following code as reference implementation for Cortex-M0 and Cortex-M3 */
00312     uint16_t  rowCnt = num_of_rows >> 2;
00313     const q7_t *pB = pM;
00314     const q15_t *pA;
00315     q15_t    *pO = pOut;
00316     const q7_t *pBias = bias;
00317 
00318     while (rowCnt)
00319     {
00320         q31_t     sum =  ((q31_t)(*pBias++) << bias_shift) + NN_ROUND(out_shift);
00321         q31_t     sum2 = ((q31_t)(*pBias++) << bias_shift) + NN_ROUND(out_shift);
00322         q31_t     sum3 = ((q31_t)(*pBias++) << bias_shift) + NN_ROUND(out_shift); 
00323         q31_t     sum4 = ((q31_t)(*pBias++) << bias_shift) + NN_ROUND(out_shift); 
00324         uint16_t  colCnt = dim_vec >> 1;
00325 
00326         pA = pV;
00327 
          /* the interleaved read order (inB1, inB3, inB2, inB4) mirrors the
           * x4 weight reordering described in the function header */
00328         while (colCnt)
00329         {
00330             q15_t     inA1 = *pA++;
00331             q15_t     inA2 = *pA++;
00332 
00333             q7_t      inB1 = *pB++;
00334             q7_t      inB3 = *pB++;
00335             q7_t      inB2 = *pB++;
00336             q7_t      inB4 = *pB++;
00337 
00338             sum += inA1 * inB1 + inA2 * inB2;
00339             sum2 += inA1 * inB3 + inA2 * inB4;
00340 
00341             inB1 = *pB++;
00342             inB3 = *pB++;
00343             inB2 = *pB++;
00344             inB4 = *pB++;
00345 
00346             sum3 += inA1 * inB1 + inA2 * inB2;
00347             sum4 += inA1 * inB3 + inA2 * inB4;
00348 
00349             colCnt--;
00350         }
00351 
00352         colCnt = dim_vec & 0x1;
00353         while (colCnt)
00354         {
00355             q15_t     inA = *pA++;
00356             q7_t      inB = *pB++;
00357             sum += inA * inB;
00358             inB = *pB++;
00359             sum2 += inA * inB;
00360             inB = *pB++;
00361             sum3 += inA * inB;
00362             inB = *pB++;
00363             sum4 += inA * inB;
00364 
00365             colCnt--;
00366         }
00367         *pO++ = (q15_t) __SSAT((sum >> out_shift), 16);
00368         *pO++ = (q15_t) __SSAT((sum2 >> out_shift), 16);
00369         *pO++ = (q15_t) __SSAT((sum3 >> out_shift), 16);
00370         *pO++ = (q15_t) __SSAT((sum4 >> out_shift), 16);
00371 
00372         rowCnt--;
00373     }
00374 
00375     rowCnt = num_of_rows & 0x3;
00376 
00377     while (rowCnt)
00378     {
00379         int       ip_out = ((q31_t)(*pBias++) << bias_shift) + NN_ROUND(out_shift);
00380         int       j;
00381 
00382         pA = pV;
00383         for (j = 0; j < dim_vec; j++)
00384         {
00385             q15_t     inA = *pA++;
00386             q7_t      inB = *pB++;
00387             ip_out += inA * inB;
00388         }
00389         *pO++ = (q15_t) __SSAT((ip_out >> out_shift), 16);
00390 
00391         rowCnt--;
00392     }
00393 
00394 #endif                          /* ARM_MATH_DSP */
00395 
00396     /* Return ARM_MATH_SUCCESS */
00397     return (ARM_MATH_SUCCESS);
00398 
00399 }
00400 
00401 /**
00402  * @} end of FC group
00403  */
00404