Added CMSIS5 DSP and NN folder. Needs some work

Embed: (wiki syntax)

« Back to documentation index

Show/hide line numbers arm_nn_mat_mult_kernel_q7_q15_reordered.c Source File

arm_nn_mat_mult_kernel_q7_q15_reordered.c

00001 /*
00002  * Copyright (C) 2010-2018 Arm Limited or its affiliates. All rights reserved.
00003  *
00004  * SPDX-License-Identifier: Apache-2.0
00005  *
00006  * Licensed under the Apache License, Version 2.0 (the License); you may
00007  * not use this file except in compliance with the License.
00008  * You may obtain a copy of the License at
00009  *
00010  * www.apache.org/licenses/LICENSE-2.0
00011  *
00012  * Unless required by applicable law or agreed to in writing, software
00013  * distributed under the License is distributed on an AS IS BASIS, WITHOUT
00014  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
00015  * See the License for the specific language governing permissions and
00016  * limitations under the License.
00017  */
00018 
00019 /* ----------------------------------------------------------------------
00020  * Project:      CMSIS NN Library
00021  * Title:        arm_nn_mat_mult_kernel_q7_q15_reordered.c
00022  * Description:  Matrix-multiplication function for convolution with reordered columns
00023  *
00024  * $Date:        17. January 2018
00025  * $Revision:    V.1.0.0
00026  *
00027  * Target Processor:  Cortex-M cores
00028  * -------------------------------------------------------------------- */
00029 
00030 #include "arm_nnfunctions.h"
00031 #include "arm_math.h"
00032 
00033   /**
00034    * @brief Matrix-multiplication function for convolution with reordered columns
00035    * @param[in]       pA          pointer to operand A
00036    * @param[in]       pInBuffer   pointer to operand B, always conssists of 2 vectors
00037    * @param[in]       ch_im_out   numRow of A
00038    * @param[in]       numCol_A    numCol of A
00039    * @param[in]       bias_shift  amount of left-shift for bias
00040    * @param[in]       out_shift   amount of right-shift for output
00041    * @param[in]       bias        the bias
00042    * @param[in,out]   pOut        pointer to output
00043    * @return     The function returns the incremented output pointer
00044    *
00045    * @details
00046    *
00047    * This function assumes that data in pInBuffer are reordered
00048    */
00049 
00050 q7_t     *arm_nn_mat_mult_kernel_q7_q15_reordered(const q7_t * pA,
00051                                                   const q15_t * pInBuffer,
00052                                                   const uint16_t ch_im_out,
00053                                                   const uint16_t numCol_A,
00054                                                   const uint16_t bias_shift,
00055                                                   const uint16_t out_shift, 
00056                                                   const q7_t * bias, 
00057                                                   q7_t * pOut)
00058 {
00059 
00060 #if defined (ARM_MATH_DSP)
00061     /* set up the second output pointers */
00062     q7_t     *pOut2 = pOut + ch_im_out;
00063     int       i;
00064 
00065     /* this loop over rows in A */
00066     for (i = 0; i < ch_im_out; i += 2)
00067     {
00068         /* setup pointers for B */
00069         const q15_t *pB = pInBuffer;
00070         const q15_t *pB2 = pB + numCol_A;
00071 
00072         /* align the second pointer for A */
00073         const q7_t *pA2 = pA + numCol_A;
00074 
00075         /* init the sum with bias */
00076         q31_t     sum =  ((q31_t)(bias[i]) << bias_shift) + NN_ROUND(out_shift);
00077         q31_t     sum2 = ((q31_t)(bias[i]) << bias_shift) + NN_ROUND(out_shift);
00078         q31_t     sum3 = ((q31_t)(bias[i + 1]) << bias_shift) + NN_ROUND(out_shift);
00079         q31_t     sum4 = ((q31_t)(bias[i + 1]) << bias_shift) + NN_ROUND(out_shift);
00080 
00081         uint16_t  colCnt = numCol_A >> 2;
00082         /* accumulate over the vector */
00083         while (colCnt)
00084         {
00085             q31_t     inA11, inA12, inA21, inA22;
00086             q31_t     inB1 = *__SIMD32(pB)++;
00087             q31_t     inB2 = *__SIMD32(pB2)++;
00088 
00089             pA = (q7_t *) read_and_pad_reordered((void *)pA, &inA11, &inA12);
00090             pA2 = (q7_t *) read_and_pad_reordered((void *)pA2, &inA21, &inA22);
00091 
00092             sum = __SMLAD(inA11, inB1, sum);
00093             sum2 = __SMLAD(inA11, inB2, sum2);
00094             sum3 = __SMLAD(inA21, inB1, sum3);
00095             sum4 = __SMLAD(inA21, inB2, sum4);
00096 
00097             inB1 = *__SIMD32(pB)++;
00098             inB2 = *__SIMD32(pB2)++;
00099 
00100             sum = __SMLAD(inA12, inB1, sum);
00101             sum2 = __SMLAD(inA12, inB2, sum2);
00102             sum3 = __SMLAD(inA22, inB1, sum3);
00103             sum4 = __SMLAD(inA22, inB2, sum4);
00104 
00105             colCnt--;
00106         }                       /* while over colCnt */
00107         colCnt = numCol_A & 0x3;
00108         while (colCnt)
00109         {
00110             q7_t      inA1 = *pA++;
00111             q15_t     inB1 = *pB++;
00112             q7_t      inA2 = *pA2++;
00113             q15_t     inB2 = *pB2++;
00114 
00115             sum += inA1 * inB1;
00116             sum2 += inA1 * inB2;
00117             sum3 += inA2 * inB1;
00118             sum4 += inA2 * inB2;
00119             colCnt--;
00120         }                       /* while over colCnt */
00121         *pOut++ = (q7_t) __SSAT((sum >> out_shift), 8);
00122         *pOut++ = (q7_t) __SSAT((sum3 >> out_shift), 8);
00123         *pOut2++ = (q7_t) __SSAT((sum2 >> out_shift), 8);
00124         *pOut2++ = (q7_t) __SSAT((sum4 >> out_shift), 8);
00125 
00126         /* skip the row computed with A2 */
00127         pA += numCol_A;
00128     }                           /* for over ch_im_out */
00129 
00130     pOut += ch_im_out;
00131 
00132     /* return the new output pointer with offset */
00133     return pOut;
00134 #else
00135     /* To be completed */
00136     return NULL;
00137 #endif                          /* ARM_MATH_DSP */
00138 }
00139