arm_fully_connected_q7_opt.c

/*
 * Copyright (C) 2010-2018 Arm Limited or its affiliates. All rights reserved.
 *
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License"); you may
 * not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

/* ----------------------------------------------------------------------
 * Project:      CMSIS NN Library
 * Title:        arm_fully_connected_q7_opt.c
 * Description:  Q7 basic fully-connected layer function
 *
 * $Date:        17. January 2018
 * $Revision:    V.1.0.0
 *
 * Target Processor:  Cortex-M cores
 *
 * -------------------------------------------------------------------- */

#include "arm_math.h"
#include "arm_nnfunctions.h"
/**
 *  @ingroup groupNN
 */

/**
 * @addtogroup FC
 * @{
 */

  /**
   * @brief Q7 opt fully-connected layer function
   * @param[in]       pV          pointer to input vector
   * @param[in]       pM          pointer to matrix weights
   * @param[in]       dim_vec     length of the vector
   * @param[in]       num_of_rows number of rows in weight matrix
   * @param[in]       bias_shift  amount of left-shift for bias
   * @param[in]       out_shift   amount of right-shift for output
   * @param[in]       bias        pointer to bias
   * @param[in,out]   pOut        pointer to output vector
   * @param[in,out]   vec_buffer  pointer to buffer space for input
   * @return     The function returns <code>ARM_MATH_SUCCESS</code>
   *
   * @details
   *
   * <b>Buffer size:</b>
   *
   * vec_buffer size: dim_vec
   *
   * This opt function is designed to work with an interleaved weight
   * matrix. The input vector is assumed to be in q7_t format; we call
   * arm_q7_to_q15_reordered_no_shift to expand it into q15_t format
   * with a matching re-ordering (refer to that function's comments for
   * details). Here we use a single pointer to read 4 rows of the
   * weight matrix. So if the original q7_t matrix looks like this:
   *
   *  | a11 | a12 | a13 | a14 | a15 | a16 | a17 |
   *
   *  | a21 | a22 | a23 | a24 | a25 | a26 | a27 |
   *
   *  | a31 | a32 | a33 | a34 | a35 | a36 | a37 |
   *
   *  | a41 | a42 | a43 | a44 | a45 | a46 | a47 |
   *
   *  | a51 | a52 | a53 | a54 | a55 | a56 | a57 |
   *
   *  | a61 | a62 | a63 | a64 | a65 | a66 | a67 |
   *
   *  We operate on four rows at a time, so the first four rows become
   *
   *  | a11 | a21 | a13 | a23 | a31 | a41 | a33 | a43 |
   *
   *  | a12 | a22 | a14 | a24 | a32 | a42 | a34 | a44 |
   *
   *  | a15 | a25 | a35 | a45 | a16 | a26 | a36 | a46 |
   *
   *  Within the kernel, we first read the re-ordered vector in as:
   *
   *  | b1  | b3  | and | b2  | b4  |
   *
   *  The four q31_t weight words then look like
   *
   *  | a11 | a13 |, | a21 | a23 |, | a31 | a33 |, | a41 | a43 |
   *
   *  | a12 | a14 |, | a22 | a24 |, | a32 | a34 |, | a42 | a44 |
   *
   *  The left-over columns remain in their original order, i.e.:
   *
   *  | a17 | a27 | a37 | a47 |
   *
   *  For the left-over rows we do a 1x1 computation, so that data
   *  remains in its original order.
   *
   *  The stored weight matrix therefore looks like this:
   *
   *  | a11 | a21 | a13 | a23 | a31 | a41 |
   *
   *  | a33 | a43 | a12 | a22 | a14 | a24 |
   *
   *  | a32 | a42 | a34 | a44 | a15 | a25 |
   *
   *  | a35 | a45 | a16 | a26 | a36 | a46 |
   *
   *  | a17 | a27 | a37 | a47 | a51 | a52 |
   *
   *  | a53 | a54 | a55 | a56 | a57 | a61 |
   *
   *  | a62 | a63 | a64 | a65 | a66 | a67 |
   *
   *  An illustrative helper that produces this interleaved layout is
   *  sketched after the function body below, and a minimal calling
   *  example appears at the end of this file.
   */

arm_status
arm_fully_connected_q7_opt(const q7_t * pV,
                           const q7_t * pM,
                           const uint16_t dim_vec,
                           const uint16_t num_of_rows,
                           const uint16_t bias_shift,
                           const uint16_t out_shift,
                           const q7_t * bias,
                           q7_t * pOut,
                           q15_t * vec_buffer)
{

#if defined (ARM_MATH_DSP)
    /* Run the following code for Cortex-M4 and Cortex-M7 */

    const q7_t *pB = pM;
    q7_t     *pO = pOut;
    const q7_t *pBias = bias;
    q15_t    *pA;
    uint16_t  rowCnt = num_of_rows >> 2;

    /* expand the q7_t input to q15_t, re-ordered to match the interleaved weights */
    arm_q7_to_q15_reordered_no_shift(pV, vec_buffer, dim_vec);

    while (rowCnt)
    {
        /* initialize the four row accumulators with the left-shifted bias
         * plus the rounding offset for the final right shift */
        q31_t     sum =  ((q31_t)(*pBias++) << bias_shift) + NN_ROUND(out_shift);
        q31_t     sum2 = ((q31_t)(*pBias++) << bias_shift) + NN_ROUND(out_shift);
        q31_t     sum3 = ((q31_t)(*pBias++) << bias_shift) + NN_ROUND(out_shift);
        q31_t     sum4 = ((q31_t)(*pBias++) << bias_shift) + NN_ROUND(out_shift);

        uint16_t  colCnt = dim_vec >> 2;

        pA = vec_buffer;

#ifdef USE_INTRINSIC

#ifndef ARM_MATH_BIG_ENDIAN
        while (colCnt)
        {
            q31_t     inM11, inM12, inM13, inM14;
            q31_t     inV;

            /* each half of the loop body reads one vector word (two q15_t)
             * and two weight words (eight q7_t); __SXTB16 extracts bytes
             * 0 and 2, __SXTB16(__ROR(x, 8)) extracts bytes 1 and 3 */
            inV = *__SIMD32(pA)++;
            inM11 = *__SIMD32(pB)++;
            inM12 = __SXTB16(__ROR(inM11, 8));
            inM11 = __SXTB16(inM11);
            sum = __SMLAD(inM11, inV, sum);
            sum2 = __SMLAD(inM12, inV, sum2);
            inM13 = *__SIMD32(pB)++;
            inM14 = __SXTB16(__ROR(inM13, 8));
            inM13 = __SXTB16(inM13);
            sum3 = __SMLAD(inM13, inV, sum3);
            sum4 = __SMLAD(inM14, inV, sum4);

            inV = *__SIMD32(pA)++;
            inM11 = *__SIMD32(pB)++;
            inM12 = __SXTB16(__ROR(inM11, 8));
            inM11 = __SXTB16(inM11);
            sum = __SMLAD(inM11, inV, sum);
            sum2 = __SMLAD(inM12, inV, sum2);
            inM13 = *__SIMD32(pB)++;
            inM14 = __SXTB16(__ROR(inM13, 8));
            inM13 = __SXTB16(inM13);
            sum3 = __SMLAD(inM13, inV, sum3);
            sum4 = __SMLAD(inM14, inV, sum4);
            colCnt--;
        }
#else
        /* big-endian: the byte lanes land in the opposite halves, so the
         * roles of inM11/inM12 (and inM13/inM14) are exchanged */
        while (colCnt)
        {
            q31_t     inM11, inM12, inM13, inM14;
            q31_t     inV;

            inV = *__SIMD32(pA)++;
            inM11 = *__SIMD32(pB)++;
            inM12 = __SXTB16(__ROR(inM11, 8));
            inM11 = __SXTB16(inM11);
            sum = __SMLAD(inM12, inV, sum);
            sum2 = __SMLAD(inM11, inV, sum2);
            inM13 = *__SIMD32(pB)++;
            inM14 = __SXTB16(__ROR(inM13, 8));
            inM13 = __SXTB16(inM13);
            sum3 = __SMLAD(inM14, inV, sum3);
            sum4 = __SMLAD(inM13, inV, sum4);

            inV = *__SIMD32(pA)++;
            inM11 = *__SIMD32(pB)++;
            inM12 = __SXTB16(__ROR(inM11, 8));
            inM11 = __SXTB16(inM11);
            sum = __SMLAD(inM12, inV, sum);
            sum2 = __SMLAD(inM11, inV, sum2);
            inM13 = *__SIMD32(pB)++;
            inM14 = __SXTB16(__ROR(inM13, 8));
            inM13 = __SXTB16(inM13);
            sum3 = __SMLAD(inM14, inV, sum3);
            sum4 = __SMLAD(inM13, inV, sum4);
            colCnt--;
        }
#endif                          /* ARM_MATH_BIG_ENDIAN */

#else

        /*
         * registers needed:
         * loop counter: colCnt
         * accumulators: sum, sum2, sum3, sum4
         * pointers: pB, pA
         * weight data: inM11, inM12, inM13, inM14
         * activation data: inV
         */

#ifndef ARM_MATH_BIG_ENDIAN
        asm volatile ("COL_LOOP_%=:\n"
                      "ldr.w r4, [%[pA]], #8\n"
                      "ldr.w r1, [%[pB]], #16\n"
                      "mov.w r0, r1, ror #8\n"
                      "sxtb16 r0, r0\n"
                      "sxtb16 r1, r1\n"
                      "smlad %[sum], r4, r1, %[sum]\n"
                      "smlad %[sum2], r4, r0, %[sum2]\n"
                      "ldr.w r3, [%[pB], #-12]\n"
                      "mov.w r2, r3, ror #8\n"
                      "sxtb16 r2, r2\n"
                      "sxtb16 r3, r3\n"
                      "smlad %[sum3], r4, r3, %[sum3]\n"
                      "smlad %[sum4], r4, r2, %[sum4]\n"
                      "ldr.w r4, [%[pA], #-4]\n"
                      "ldr.w r1, [%[pB], #-8]\n"
                      "mov.w r0, r1, ror #8\n"
                      "sxtb16 r0, r0\n"
                      "sxtb16 r1, r1\n"
                      "smlad %[sum], r4, r1, %[sum]\n"
                      "smlad %[sum2], r4, r0, %[sum2]\n"
                      "ldr.w r3, [%[pB], #-4]\n"
                      "mov.w r2, r3, ror #8\n"
                      "sxtb16 r2, r2\n"
                      "sxtb16 r3, r3\n"
                      "smlad %[sum3], r4, r3, %[sum3]\n"
                      "smlad %[sum4], r4, r2, %[sum4]\n"
                      "subs %[colCnt], #1\n"
                      "bne COL_LOOP_%=\n" : [sum] "+r"(sum),
                      [sum2] "+r"(sum2), [sum3] "+r"(sum3),
                      [sum4] "+r"(sum4), [pB] "+r"(pB), [pA] "+r"(pA) : [colCnt] "r"(colCnt) : "r0", "r1", "r2", "r3", "r4");
#else
        asm volatile ("COL_LOOP_%=:\n"
                      "ldr.w r4, [%[pA]], #8\n"
                      "ldr.w r1, [%[pB]], #16\n"
                      "mov.w r0, r1, ror #8\n"
                      "sxtb16 r0, r0\n"
                      "sxtb16 r1, r1\n"
                      "smlad %[sum], r4, r0, %[sum]\n"
                      "smlad %[sum2], r4, r1, %[sum2]\n"
                      "ldr.w r3, [%[pB], #-12]\n"
                      "mov.w r2, r3, ror #8\n"
                      "sxtb16 r2, r2\n"
                      "sxtb16 r3, r3\n"
                      "smlad %[sum3], r4, r2, %[sum3]\n"
                      "smlad %[sum4], r4, r3, %[sum4]\n"
                      "ldr.w r4, [%[pA], #-4]\n"
                      "ldr.w r1, [%[pB], #-8]\n"
                      "mov.w r0, r1, ror #8\n"
                      "sxtb16 r0, r0\n"
                      "sxtb16 r1, r1\n"
                      "smlad %[sum], r4, r0, %[sum]\n"
                      "smlad %[sum2], r4, r1, %[sum2]\n"
                      "ldr.w r3, [%[pB], #-4]\n"
                      "mov.w r2, r3, ror #8\n"
                      "sxtb16 r2, r2\n"
                      "sxtb16 r3, r3\n"
                      "smlad %[sum3], r4, r2, %[sum3]\n"
                      "smlad %[sum4], r4, r3, %[sum4]\n"
                      "subs %[colCnt], #1\n"
                      "bne COL_LOOP_%=\n" : [sum] "+r"(sum),
                      [sum2] "+r"(sum2), [sum3] "+r"(sum3),
                      [sum4] "+r"(sum4), [pB] "+r"(pB), [pA] "+r"(pA) : [colCnt] "r"(colCnt) : "r0", "r1", "r2", "r3", "r4");
#endif                          /* ARM_MATH_BIG_ENDIAN */

#endif                          /* USE_INTRINSIC */

        colCnt = dim_vec & 0x3;
        while (colCnt)
        {
            q15_t     inV = *pA++;
            q7_t      inM = *pB++;
            q7_t      inM2 = *pB++;
            q7_t      inM3 = *pB++;
            q7_t      inM4 = *pB++;

            sum += inV * inM;
            sum2 += inV * inM2;
            sum3 += inV * inM3;
            sum4 += inV * inM4;
            colCnt--;
        }                       /* while over colCnt */
        *pO++ = (q7_t) (__SSAT((sum >> out_shift), 8));
        *pO++ = (q7_t) (__SSAT((sum2 >> out_shift), 8));
        *pO++ = (q7_t) (__SSAT((sum3 >> out_shift), 8));
        *pO++ = (q7_t) (__SSAT((sum4 >> out_shift), 8));

        /* adjust the pointers and counters */
        rowCnt--;
    }

    /* left-over part of the rows */
    rowCnt = num_of_rows & 0x3;

    while (rowCnt)
    {
        q31_t     sum = ((q31_t)(*pBias++) << bias_shift) + NN_ROUND(out_shift);
        uint16_t  colCnt = dim_vec >> 2;

        pA = vec_buffer;

        while (colCnt)
        {
            q31_t     inV1, inV2, inM11, inM12;

            pB = (q7_t *) read_and_pad_reordered((void *)pB, &inM11, &inM12);

            inV1 = *__SIMD32(pA)++;
            sum = __SMLAD(inV1, inM11, sum);

            inV2 = *__SIMD32(pA)++;
            sum = __SMLAD(inV2, inM12, sum);

            colCnt--;
        }

        /* left-over of the vector */
        colCnt = dim_vec & 0x3;
        while (colCnt)
        {
            q15_t     inV = *pA++;
            q7_t      inM = *pB++;
            sum += inV * inM;
            colCnt--;
        }

        *pO++ = (q7_t) (__SSAT((sum >> out_shift), 8));

        rowCnt--;
    }

#else
    /* Run the following code as reference implementation for Cortex-M0 and Cortex-M3 */
    uint16_t  rowCnt = num_of_rows >> 2;
    const q7_t *pB = pM;
    const q7_t *pA;
    q7_t     *pO = pOut;
    const q7_t *pBias = bias;

    while (rowCnt)
    {
        q31_t     sum =  ((q31_t)(*pBias++) << bias_shift) + NN_ROUND(out_shift);
        q31_t     sum2 = ((q31_t)(*pBias++) << bias_shift) + NN_ROUND(out_shift);
        q31_t     sum3 = ((q31_t)(*pBias++) << bias_shift) + NN_ROUND(out_shift);
        q31_t     sum4 = ((q31_t)(*pBias++) << bias_shift) + NN_ROUND(out_shift);

        uint16_t  colCnt = dim_vec >> 2;

        pA = pV;

        while (colCnt)
        {
            /* the weights are stored in the interleaved order described
             * above, so the vector elements are loaded under matching
             * shuffled names (inA1, inA3, inA2, inA4) */
            q7_t      inA1 = *pA++;
            q7_t      inA3 = *pA++;
            q7_t      inA2 = *pA++;
            q7_t      inA4 = *pA++;

            q7_t      inB1 = *pB++;
            q7_t      inB3 = *pB++;
            q7_t      inB2 = *pB++;
            q7_t      inB4 = *pB++;

            sum += inA1 * inB1 + inA2 * inB2;
            sum2 += inA1 * inB3 + inA2 * inB4;

            inB1 = *pB++;
            inB3 = *pB++;
            inB2 = *pB++;
            inB4 = *pB++;

            sum3 += inA1 * inB1 + inA2 * inB2;
            sum4 += inA1 * inB3 + inA2 * inB4;

            inB1 = *pB++;
            inB3 = *pB++;
            inB2 = *pB++;
            inB4 = *pB++;

            sum += inA3 * inB1 + inA4 * inB2;
            sum2 += inA3 * inB3 + inA4 * inB4;

            inB1 = *pB++;
            inB3 = *pB++;
            inB2 = *pB++;
            inB4 = *pB++;

            sum3 += inA3 * inB1 + inA4 * inB2;
            sum4 += inA3 * inB3 + inA4 * inB4;

            colCnt--;
        }
        colCnt = dim_vec & 0x3;
        while (colCnt)
        {
            q7_t      inA = *pA++;
            q7_t      inB = *pB++;
            sum += inA * inB;
            inB = *pB++;
            sum2 += inA * inB;
            inB = *pB++;
            sum3 += inA * inB;
            inB = *pB++;
            sum4 += inA * inB;

            colCnt--;
        }
        *pO++ = (q7_t) __SSAT((sum >> out_shift), 8);
        *pO++ = (q7_t) __SSAT((sum2 >> out_shift), 8);
        *pO++ = (q7_t) __SSAT((sum3 >> out_shift), 8);
        *pO++ = (q7_t) __SSAT((sum4 >> out_shift), 8);

        rowCnt--;
    }

    rowCnt = num_of_rows & 0x3;

    while (rowCnt)
    {
        int       ip_out = ((q31_t)(*pBias++) << bias_shift) + NN_ROUND(out_shift);

        int       j;

        pA = pV;
        for (j = 0; j < dim_vec; j++)
        {
            q7_t      inA = *pA++;
            q7_t      inB = *pB++;
            ip_out += inA * inB;
        }
        *pO++ = (q7_t) __SSAT((ip_out >> out_shift), 8);

        rowCnt--;
    }

#endif                          /* ARM_MATH_DSP */

    /* Return ARM_MATH_SUCCESS */
    return (ARM_MATH_SUCCESS);

}
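
/*
 * Illustrative sketch (not part of CMSIS-NN): one way to interleave a
 * row-major q7_t weight matrix into the layout documented above, so it
 * can be fed to arm_fully_connected_q7_opt. The helper name and the
 * assumption that pDst holds dim_vec * num_of_rows entries are ours,
 * not the library's; treat this as a reference for the storage format,
 * not as the official conversion tool. Kept under #if 0 so it does not
 * affect the build.
 */
#if 0
static void reorder_weights_q7_opt(const q7_t * pSrc, q7_t * pDst,
                                   uint16_t dim_vec, uint16_t num_of_rows)
{
    uint32_t  k = 0;
    uint16_t  r, c, half, pair;

    /* groups of four rows */
    for (r = 0; r + 4 <= num_of_rows; r += 4)
    {
        /* groups of four columns: columns c and c+2 first, for row pairs
         * (r, r+1) then (r+2, r+3); then columns c+1 and c+3 likewise */
        for (c = 0; c + 4 <= dim_vec; c += 4)
        {
            for (half = 0; half < 2; half++)
            {
                for (pair = 0; pair < 2; pair++)
                {
                    pDst[k++] = pSrc[(r + 2 * pair)     * dim_vec + c + half];
                    pDst[k++] = pSrc[(r + 2 * pair + 1) * dim_vec + c + half];
                    pDst[k++] = pSrc[(r + 2 * pair)     * dim_vec + c + half + 2];
                    pDst[k++] = pSrc[(r + 2 * pair + 1) * dim_vec + c + half + 2];
                }
            }
        }
        /* left-over columns: column-major within the four-row group */
        for (; c < dim_vec; c++)
        {
            pDst[k++] = pSrc[(r + 0) * dim_vec + c];
            pDst[k++] = pSrc[(r + 1) * dim_vec + c];
            pDst[k++] = pSrc[(r + 2) * dim_vec + c];
            pDst[k++] = pSrc[(r + 3) * dim_vec + c];
        }
    }
    /* left-over rows stay in their original row-major order */
    for (; r < num_of_rows; r++)
    {
        for (c = 0; c < dim_vec; c++)
        {
            pDst[k++] = pSrc[r * dim_vec + c];
        }
    }
}
#endif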

/**
 * @} end of FC group
 */
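
A minimal calling sketch, not from the library, is shown below. The sizes
and the bias_shift/out_shift values are illustrative placeholders (the
correct shifts depend on the fixed-point formats chosen at quantization
time); the weight array must already be interleaved as described above,
for example with a helper like the reorder_weights_q7_opt sketch, and
vec_buffer must provide dim_vec q15_t entries of scratch space.

#include "arm_math.h"
#include "arm_nnfunctions.h"

#define IN_DIM   64                       /* example dim_vec (placeholder) */
#define OUT_DIM  10                       /* example num_of_rows (placeholder) */

static const q7_t weights_opt[IN_DIM * OUT_DIM] = { 0 };  /* interleaved as above */
static const q7_t bias[OUT_DIM] = { 0 };
static q7_t input[IN_DIM];
static q7_t output[OUT_DIM];
static q15_t vec_buffer[IN_DIM];          /* scratch: dim_vec q15_t entries */

void run_fc_layer(void)
{
    /* bias_shift = 0 and out_shift = 7 are placeholders; take them from
     * your layer's q-format bookkeeping */
    arm_status status = arm_fully_connected_q7_opt(input, weights_opt,
                                                   IN_DIM, OUT_DIM,
                                                   0, 7,
                                                   bias, output, vec_buffer);
    (void) status;                        /* always ARM_MATH_SUCCESS */
}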