Robert Lopez / CMSIS5
Embed: (wiki syntax)

« Back to documentation index

Show/hide line numbers arm_fully_connected_q15_opt.c Source File

arm_fully_connected_q15_opt.c

00001 /*
00002  * Copyright (C) 2010-2018 Arm Limited or its affiliates. All rights reserved.
00003  *
00004  * SPDX-License-Identifier: Apache-2.0
00005  *
00006  * Licensed under the Apache License, Version 2.0 (the License); you may
00007  * not use this file except in compliance with the License.
00008  * You may obtain a copy of the License at
00009  *
00010  * www.apache.org/licenses/LICENSE-2.0
00011  *
00012  * Unless required by applicable law or agreed to in writing, software
00013  * distributed under the License is distributed on an AS IS BASIS, WITHOUT
00014  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
00015  * See the License for the specific language governing permissions and
00016  * limitations under the License.
00017  */
00018 
00019 /* ----------------------------------------------------------------------
00020  * Project:      CMSIS NN Library
00021  * Title:        arm_fully_connected_q15_opt.c
00022  * Description:  Q15 opt fully-connected layer function
00023  *
00024  * $Date:        17. January 2018
00025  * $Revision:    V.1.0.0
00026  *
00027  * Target Processor:  Cortex-M cores
00028  *
00029  * -------------------------------------------------------------------- */
00030 
00031 #include "arm_math.h"
00032 #include "arm_nnfunctions.h"
00033 
00034 /**
00035  *  @ingroup groupNN
00036  */
00037 
00038 /**
00039  * @addtogroup FC
00040  * @{
00041  */
00042 
00043   /**
00044    * @brief Q15 opt fully-connected layer function
00045    * @param[in]       pV          pointer to input vector
00046    * @param[in]       pM          pointer to matrix weights
00047    * @param[in]       dim_vec     length of the vector
00048    * @param[in]       num_of_rows number of rows in weight matrix
00049    * @param[in]       bias_shift  amount of left-shift for bias
00050    * @param[in]       out_shift   amount of right-shift for output
00051    * @param[in]       bias        pointer to bias
00052    * @param[in,out]   pOut        pointer to output vector
00053    * @param[in,out]   vec_buffer  pointer to buffer space for input
00054    * @return     The function returns <code>ARM_MATH_SUCCESS</code>
00055    *
00056    *
00057    * @details
00058    *
00059    * <b>Buffer size:</b>
00060    *
00061    * vec_buffer size: 0
00062    *
00063    *  Here we use only one pointer to read 4 rows in the weight
00064    *  matrix. So if the original matrix looks like this:
00065    *
00066    *  | a11 | a12 | a13 |
00067    *
00068    *  | a21 | a22 | a23 |
00069    *
00070    *  | a31 | a32 | a33 |
00071    *
00072    *  | a41 | a42 | a43 |
00073    *
00074    *  | a51 | a52 | a53 |
00075    *
00076    *  | a61 | a62 | a63 |
00077    *
00078    *  We operate on multiples of four rows, so the first four rows become
00079    *
00080    *  | a11 | a12 | a21 | a22 | a31 | a32 | a41 | a42 |
00081    *
00082    *  | a13 | a23 | a33 | a43 |
00083    *
00084    *  The remaining rows are kept in their original order.
00085    *
00086    *  So the stored weight matrix looks like this:
00087    *
00088    *
00089    *  | a11 | a12 | a21 | a22 | a31 | a32 | a41 | a42 |
00090    *
00091    *  | a13 | a23 | a33 | a43 | a51 | a52 | a53 | a61 |
00092    *
00093    *  | a62 | a63 |
00094    */
00095 
00096 arm_status
00097 arm_fully_connected_q15_opt(const q15_t * pV,
00098                             const q15_t * pM,
00099                             const uint16_t dim_vec,
00100                             const uint16_t num_of_rows,
00101                             const uint16_t bias_shift,
00102                             const uint16_t out_shift, 
00103                             const q15_t * bias, 
00104                             q15_t * pOut, 
00105                             q15_t * vec_buffer)
00106 {
00107 
00108 #if defined (ARM_MATH_DSP)
00109     /* Run the following code for Cortex-M4 and Cortex-M7 */
00110 
00111     const q15_t *pB = pM;
00112     q15_t    *pO = pOut;
00113     const q15_t *pBias = bias;
00114     const q15_t *pA = pV;
00115 
00116     uint16_t  rowCnt = num_of_rows >> 2;
00117 
00118     while (rowCnt)
00119     {
00120         q31_t     sum =  ((q31_t)(*pBias++) << bias_shift) + NN_ROUND(out_shift);
00121         q31_t     sum2 = ((q31_t)(*pBias++) << bias_shift) + NN_ROUND(out_shift); 
00122         q31_t     sum3 = ((q31_t)(*pBias++) << bias_shift) + NN_ROUND(out_shift); 
00123         q31_t     sum4 = ((q31_t)(*pBias++) << bias_shift) + NN_ROUND(out_shift); 
00124 
00125         uint16_t  colCnt = dim_vec >> 1;
00126 
00127         pA = pV;
00128 
00129 #ifdef USE_INTRINSIC
00130 
00131         while (colCnt)
00132         {
00133             q31_t     inM11, inM12, inM13, inM14;
00134             q31_t     inV;
00135 
00136             inV = *__SIMD32(pA)++;
00137             inM11 = *__SIMD32(pB)++;
00138             sum = __SMLAD(inV, inM11, sum);
00139             inM12 = *__SIMD32(pB)++;
00140             sum2 = __SMLAD(inV, inM12, sum2);
00141             inM13 = *__SIMD32(pB)++;
00142             sum3 = __SMLAD(inV, inM13, sum3);
00143             inM14 = *__SIMD32(pB)++;
00144             sum4 = __SMLAD(inV, inM14, sum4);
00145             colCnt--;
00146         }
00147 
00148 #else
00149 
00150         /*
00151          * register needed:
00152          * loop counter: colCnt
00153          * accumulators: sum, sum2, sum3, sum4
00154          * pointers: pB, pA
00155          * weight data: inM11, inM12, inM13, inM14
00156          * activation data: inV
00157          */
00158 
00159         asm volatile ("COL_LOOP_%=:\n"
00160                       "ldr.w r4, [%[pA]], #4\n"
00161                       "ldr.w r0, [%[pB]], #16\n"
00162                       "smlad %[sum], r4, r0, %[sum]\n"
00163                       "ldr.w r1, [%[pB] , #-12]\n"
00164                       "smlad %[sum2], r4, r1, %[sum2]\n"
00165                       "ldr.w r2, [%[pB] , #-8]\n"
00166                       "smlad %[sum3], r4, r2, %[sum3]\n"
00167                       "ldr.w r3, [%[pB] , #-4]\n"
00168                       "smlad %[sum4], r4, r3, %[sum4]\n"
00169                       "subs %[colCnt], #1\n"
00170                       "bne COL_LOOP_%=\n":[sum] "+r"(sum),
00171                       [sum2] "+r"(sum2),[sum3] "+r"(sum3),
00172                       [sum4] "+r"(sum4),[pB] "+r"(pB),[pA] "+r"(pA):[colCnt] "r"(colCnt):"r0", "r1", "r2", "r3", "r4");
00173 
00174 #endif                          /* USE_INTRINSIC */
00175 
00176         colCnt = dim_vec & 0x1;
00177         while (colCnt)
00178         {
00179 
00180             q15_t     inV = *pA++;
00181             q15_t     inM = *pB++;
00182             q15_t     inM2 = *pB++;
00183             q15_t     inM3 = *pB++;
00184             q15_t     inM4 = *pB++;
00185 
00186             sum += inV * inM;
00187             sum2 += inV * inM2;
00188             sum3 += inV * inM3;
00189             sum4 += inV * inM4;
00190             colCnt--;
00191         }                       /* while over colCnt */
00192         *pO++ = (q15_t) (__SSAT((sum >> out_shift), 16));
00193         *pO++ = (q15_t) (__SSAT((sum2 >> out_shift), 16));
00194         *pO++ = (q15_t) (__SSAT((sum3 >> out_shift), 16));
00195         *pO++ = (q15_t) (__SSAT((sum4 >> out_shift), 16));
00196 
00197         /* adjust the pointers and counters */
00198         rowCnt--;
00199     }
00200 
00201     /* left-over part of the rows */
00202     rowCnt = num_of_rows & 0x3;
00203 
00204     while (rowCnt)
00205     {
00206         q31_t     sum = ((q31_t)(*pBias++) << bias_shift) + NN_ROUND(out_shift);
00207 
00208         uint16_t  colCnt = dim_vec >> 2;
00209 
00210         pA = pV;
00211 
00212         while (colCnt)
00213         {
00214             q31_t     inV1, inV2, inM1, inM2;
00215 
00216             inM1 = *__SIMD32(pB)++;
00217             inV1 = *__SIMD32(pA)++;
00218             sum = __SMLAD(inV1, inM1, sum);
00219 
00220             inM2 = *__SIMD32(pB)++;
00221             inV2 = *__SIMD32(pA)++;
00222             sum = __SMLAD(inV2, inM2, sum);
00223 
00224             colCnt--;
00225         }
00226 
00227         /* left-over of the vector */
00228         colCnt = dim_vec & 0x3;
00229         while (colCnt)
00230         {
00231             q15_t     inV = *pA++;
00232             q15_t     inM = *pB++;
00233             sum += inV * inM;
00234             colCnt--;
00235         }
00236 
00237         *pO++ = (q15_t) (__SSAT((sum >> out_shift), 16));
00238 
00239         rowCnt--;
00240     }
00241 
00242 #else
00243     /* Run the following code as reference implementation for Cortex-M0 and Cortex-M3 */
00244     uint16_t  rowCnt = num_of_rows >> 2;
00245     const q15_t *pB = pM;
00246     const q15_t *pA;
00247     q15_t    *pO = pOut;
00248     const q15_t *pBias = bias;
00249 
00250     while (rowCnt)
00251     {
00252         q31_t     sum =  ((q31_t)(*pBias++) << bias_shift) + NN_ROUND(out_shift);
00253         q31_t     sum2 = ((q31_t)(*pBias++) << bias_shift) + NN_ROUND(out_shift);
00254         q31_t     sum3 = ((q31_t)(*pBias++) << bias_shift) + NN_ROUND(out_shift);
00255         q31_t     sum4 = ((q31_t)(*pBias++) << bias_shift) + NN_ROUND(out_shift);
00256 
00257         uint16_t  colCnt = dim_vec >> 1;
00258 
00259         pA = pV;
00260         while (colCnt)
00261         {
00262             q15_t     inA1 = *pA++;
00263             q15_t     inA2 = *pA++;
00264 
00265             q15_t     inB1 = *pB++;
00266             q15_t     inB2 = *pB++;
00267             sum += inA1 * inB1 + inA2 * inB2;
00268 
00269             inB1 = *pB++;
00270             inB2 = *pB++;
00271             sum2 += inA1 * inB1 + inA2 * inB2;
00272 
00273             inB1 = *pB++;
00274             inB2 = *pB++;
00275             sum3 += inA1 * inB1 + inA2 * inB2;
00276 
00277             inB1 = *pB++;
00278             inB2 = *pB++;
00279             sum4 += inA1 * inB1 + inA2 * inB2;
00280 
00281             colCnt--;
00282         }
00283         colCnt = dim_vec & 0x1;
00284         while (colCnt)
00285         {
00286             q15_t     inA = *pA++;
00287             q15_t     inB = *pB++;
00288             sum += inA * inB;
00289             inB = *pB++;
00290             sum2 += inA * inB;
00291             inB = *pB++;
00292             sum3 += inA * inB;
00293             inB = *pB++;
00294             sum4 += inA * inB;
00295             colCnt--;
00296         }
00297         *pO++ = (q15_t) __SSAT((sum >> out_shift), 16);
00298         *pO++ = (q15_t) __SSAT((sum2 >> out_shift), 16);
00299         *pO++ = (q15_t) __SSAT((sum3 >> out_shift), 16);
00300         *pO++ = (q15_t) __SSAT((sum4 >> out_shift), 16);
00301 
00302         rowCnt--;
00303     }
00304     rowCnt = num_of_rows & 0x3;
00305 
00306     while (rowCnt)
00307     {
00308         int       ip_out = ((q31_t)(*pBias++) << bias_shift) + NN_ROUND(out_shift);
00309         int       j;
00310 
00311         pA = pV;
00312         for (j = 0; j < dim_vec; j++)
00313         {
00314             q15_t     inA = *pA++;
00315             q15_t     inB = *pB++;
00316             ip_out += inA * inB;
00317         }
00318         *pO++ = (q15_t) __SSAT((ip_out >> out_shift), 16);
00319 
00320         rowCnt--;
00321     }
00322 
00323 #endif                          /* ARM_MATH_DSP */
00324 
00325     /* Return to ARM_MATH_SUCCESS */
00326     return (ARM_MATH_SUCCESS);
00327 
00328 }
00329 
00330 /**
00331  * @} end of FC group
00332  */
00333