Robert Lopez / CMSIS5
Embed: (wiki syntax)

« Back to documentation index

Show/hide line numbers arm_mat_mult_fast_q31.c Source File

arm_mat_mult_fast_q31.c

00001 /* ----------------------------------------------------------------------
00002  * Project:      CMSIS DSP Library
00003  * Title:        arm_mat_mult_fast_q31.c
00004  * Description:  Q31 matrix multiplication (fast variant)
00005  *
00006  * $Date:        27. January 2017
00007  * $Revision:    V.1.5.1
00008  *
00009  * Target Processor: Cortex-M cores
00010  * -------------------------------------------------------------------- */
00011 /*
00012  * Copyright (C) 2010-2017 ARM Limited or its affiliates. All rights reserved.
00013  *
00014  * SPDX-License-Identifier: Apache-2.0
00015  *
00016  * Licensed under the Apache License, Version 2.0 (the License); you may
00017  * not use this file except in compliance with the License.
00018  * You may obtain a copy of the License at
00019  *
00020  * www.apache.org/licenses/LICENSE-2.0
00021  *
00022  * Unless required by applicable law or agreed to in writing, software
00023  * distributed under the License is distributed on an AS IS BASIS, WITHOUT
00024  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
00025  * See the License for the specific language governing permissions and
00026  * limitations under the License.
00027  */
00028 
00029 #include "arm_math.h"
00030 
00031 /**
00032  * @ingroup groupMatrix
00033  */
00034 
00035 /**
00036  * @addtogroup MatrixMult
00037  * @{
00038  */
00039 
00040 /**
00041  * @brief Q31 matrix multiplication (fast variant) for Cortex-M3 and Cortex-M4
00042  * @param[in]       *pSrcA points to the first input matrix structure
00043  * @param[in]       *pSrcB points to the second input matrix structure
00044  * @param[out]      *pDst points to output matrix structure
00045  * @return          The function returns either
00046  * <code>ARM_MATH_SIZE_MISMATCH</code> or <code>ARM_MATH_SUCCESS</code> based on the outcome of size checking.
00047  *
00048  * @details
00049  * <b>Scaling and Overflow Behavior:</b>
00050  *
00051  * \par
00052  * The difference between the function arm_mat_mult_q31() and this fast variant is that
00053  * the fast variant uses a 32-bit rather than a 64-bit accumulator.
00054  * The result of each 1.31 x 1.31 multiplication is truncated to
00055  * 2.30 format. These intermediate results are accumulated in a 32-bit register in 2.30
00056  * format. Finally, the accumulator is shifted left by one bit to convert it to a 1.31 result; no saturation is applied.
00057  *
00058  * \par
00059  * The fast version has the same overflow behavior as the standard version but provides
00060  * less precision since it discards the low 32 bits of each multiplication result.
00061  * In order to avoid overflows completely the input signals must be scaled down.
00062  * Scale down one of the input matrices by log2(numColsA) bits to
00063  * avoid overflows, as a total of numColsA additions are computed internally for each
00064  * output element.
00065  *
00066  * \par
00067  * See <code>arm_mat_mult_q31()</code> for a slower implementation of this function
00068  * which uses 64-bit accumulation to provide higher precision.
00069  */
00070 
00071 arm_status arm_mat_mult_fast_q31(
00072   const arm_matrix_instance_q31 * pSrcA,
00073   const arm_matrix_instance_q31 * pSrcB,
00074   arm_matrix_instance_q31 * pDst)
00075 {
00076   q31_t *pInA = pSrcA->pData;                    /* input data matrix pointer A */
00077   q31_t *pInB = pSrcB->pData;                    /* input data matrix pointer B */
00078   q31_t *px;                                     /* Temporary output data matrix pointer */
00079   q31_t sum;                                     /* Accumulator */
00080   uint16_t numRowsA = pSrcA->numRows;            /* number of rows of input matrix A    */
00081   uint16_t numColsB = pSrcB->numCols;            /* number of columns of input matrix B */
00082   uint16_t numColsA = pSrcA->numCols;            /* number of columns of input matrix A */
00083   uint32_t col, i = 0U, j, row = numRowsA, colCnt;  /* loop counters */
00084   arm_status status;                             /* status of matrix multiplication */
00085   q31_t inA1, inB1;
00086 
00087 #if defined (ARM_MATH_DSP)
00088 
00089   q31_t sum2, sum3, sum4;
00090   q31_t inA2, inB2;
00091   q31_t *pInA2;
00092   q31_t *px2;
00093 
00094 #endif
00095 
00096 #ifdef ARM_MATH_MATRIX_CHECK
00097 
00098   /* Check for matrix mismatch condition */
00099   if ((pSrcA->numCols != pSrcB->numRows) ||
00100      (pSrcA->numRows != pDst->numRows) || (pSrcB->numCols != pDst->numCols))
00101   {
00102     /* Set status as ARM_MATH_SIZE_MISMATCH */
00103     status = ARM_MATH_SIZE_MISMATCH;
00104   }
00105   else
00106 #endif /*      #ifdef ARM_MATH_MATRIX_CHECK    */
00107 
00108   {
00109 
00110     px = pDst->pData;
00111 
00112 #if defined (ARM_MATH_DSP)
00113     row = row >> 1;
00114     px2 = px + numColsB;
00115 #endif
00116 
00117     /* The following loop performs the dot-product of each row in pSrcA with each column in pSrcB */
00118     /* row loop */
00119     while (row > 0U)
00120     {
00121 
00122       /* For every row wise process, the column loop counter is to be initiated */
00123       col = numColsB;
00124 
00125       /* For every row wise process, the pIn2 pointer is set
00126        ** to the starting address of the pSrcB data */
00127       pInB = pSrcB->pData;
00128 
00129       j = 0U;
00130 
00131 #if defined (ARM_MATH_DSP)
00132       col = col >> 1;
00133 #endif
00134 
00135       /* column loop */
00136       while (col > 0U)
00137       {
00138         /* Set the variable sum, that acts as accumulator, to zero */
00139         sum = 0;
00140 
00141         /* Initiate data pointers */
00142         pInA = pSrcA->pData + i;
00143         pInB  = pSrcB->pData + j;
00144 
00145 #if defined (ARM_MATH_DSP)
00146         sum2 = 0;
00147         sum3 = 0;
00148         sum4 = 0;
00149         pInA2 = pInA + numColsA;
00150         colCnt = numColsA;
00151 #else
00152         colCnt = numColsA >> 2;
00153 #endif
00154 
00155         /* matrix multiplication */
00156         while (colCnt > 0U)
00157         {
00158 
00159 #if defined (ARM_MATH_DSP)
00160           inA1 = *pInA++;
00161           inB1 = pInB[0];
00162           inA2 = *pInA2++;
00163           inB2 = pInB[1];
00164           pInB += numColsB;
00165 
00166           sum  = __SMMLA(inA1, inB1, sum);
00167           sum2 = __SMMLA(inA1, inB2, sum2);
00168           sum3 = __SMMLA(inA2, inB1, sum3);
00169           sum4 = __SMMLA(inA2, inB2, sum4);
00170 #else
00171           /* c(m,n) = a(1,1)*b(1,1) + a(1,2) * b(2,1) + .... + a(m,p)*b(p,n) */
00172           /* Perform the multiply-accumulates */
00173           inB1 = *pInB;
00174           pInB += numColsB;
00175           inA1 = pInA[0];
00176           sum = __SMMLA(inA1, inB1, sum);
00177 
00178           inB1 = *pInB;
00179           pInB += numColsB;
00180           inA1 = pInA[1];
00181           sum = __SMMLA(inA1, inB1, sum);
00182 
00183           inB1 = *pInB;
00184           pInB += numColsB;
00185           inA1 = pInA[2];
00186           sum = __SMMLA(inA1, inB1, sum);
00187 
00188           inB1 = *pInB;
00189           pInB += numColsB;
00190           inA1 = pInA[3];
00191           sum = __SMMLA(inA1, inB1, sum);
00192 
00193           pInA += 4U;
00194 #endif
00195 
00196           /* Decrement the loop counter */
00197           colCnt--;
00198         }
00199 
00200 #ifdef ARM_MATH_CM0_FAMILY
00201         /* If the columns of pSrcA is not a multiple of 4, compute any remaining output samples here. */
00202         colCnt = numColsA % 0x4U;
00203         while (colCnt > 0U)
00204         {
00205           sum = __SMMLA(*pInA++, *pInB, sum);
00206           pInB += numColsB;
00207           colCnt--;
00208         }
00209         j++;
00210 #endif
00211 
00212         /* Convert the result from 2.30 to 1.31 format and store in destination buffer */
00213         *px++  = sum << 1;
00214 
00215 #if defined (ARM_MATH_DSP)
00216         *px++  = sum2 << 1;
00217         *px2++ = sum3 << 1;
00218         *px2++ = sum4 << 1;
00219         j += 2;
00220 #endif
00221 
00222         /* Decrement the column loop counter */
00223         col--;
00224 
00225       }
00226 
00227       i = i + numColsA;
00228 
00229 #if defined (ARM_MATH_DSP)
00230       i = i + numColsA;
00231       px = px2 + (numColsB & 1U);
00232       px2 = px + numColsB;
00233 #endif
00234 
00235       /* Decrement the row loop counter */
00236       row--;
00237 
00238     }
00239 
00240     /* Compute any remaining odd row/column below */
00241 
00242 #if defined (ARM_MATH_DSP)
00243 
00244     /* Compute remaining output column */
00245     if (numColsB & 1U) {
00246 
00247       /* Avoid redundant computation of last element */
00248       row = numRowsA & (~0x1);
00249 
00250       /* Point to remaining unfilled column in output matrix */
00251       px = pDst->pData+numColsB-1;
00252       pInA = pSrcA->pData;
00253 
00254       /* row loop */
00255       while (row > 0)
00256       {
00257 
00258         /* point to last column in matrix B */
00259         pInB  = pSrcB->pData + numColsB-1;
00260 
00261         /* Set the variable sum, that acts as accumulator, to zero */
00262         sum  = 0;
00263 
00264         /* Compute 4 columns at once */
00265         colCnt = numColsA >> 2;
00266 
00267         /* matrix multiplication */
00268         while (colCnt > 0U)
00269         {
00270           inA1 = *pInA++;
00271           inA2 = *pInA++;
00272           inB1 = *pInB;
00273           pInB += numColsB;
00274           inB2 = *pInB;
00275           pInB += numColsB;
00276           sum = __SMMLA(inA1, inB1, sum);
00277           sum = __SMMLA(inA2, inB2, sum);
00278 
00279           inA1 = *pInA++;
00280           inA2 = *pInA++;
00281           inB1 = *pInB;
00282           pInB += numColsB;
00283           inB2 = *pInB;
00284           pInB += numColsB;
00285           sum = __SMMLA(inA1, inB1, sum);
00286           sum = __SMMLA(inA2, inB2, sum);
00287 
00288           /* Decrement the loop counter */
00289           colCnt--;
00290         }
00291 
00292         colCnt = numColsA & 3U;
00293         while (colCnt > 0U) {
00294           sum = __SMMLA(*pInA++, *pInB, sum);
00295           pInB += numColsB;
00296           colCnt--;
00297         }
00298 
00299         /* Convert the result from 2.30 to 1.31 format and store in destination buffer */
00300         *px = sum << 1;
00301         px += numColsB;
00302 
00303         /* Decrement the row loop counter */
00304         row--;
00305       }
00306     }
00307 
00308     /* Compute remaining output row */
00309     if (numRowsA & 1U) {
00310 
00311       /* point to last row in output matrix */
00312       px = pDst->pData+(numColsB)*(numRowsA-1);
00313 
00314       col = numColsB;
00315       i = 0U;
00316 
00317       /* col loop */
00318       while (col > 0)
00319       {
00320 
00321         /* point to last row in matrix A */
00322         pInA = pSrcA->pData + (numRowsA-1)*numColsA;
00323         pInB  = pSrcB->pData + i;
00324 
00325         /* Set the variable sum, that acts as accumulator, to zero */
00326         sum  = 0;
00327 
00328         /* Compute 4 columns at once */
00329         colCnt = numColsA >> 2;
00330 
00331         /* matrix multiplication */
00332         while (colCnt > 0U)
00333         {
00334           inA1 = *pInA++;
00335           inA2 = *pInA++;
00336           inB1 = *pInB;
00337           pInB += numColsB;
00338           inB2 = *pInB;
00339           pInB += numColsB;
00340           sum = __SMMLA(inA1, inB1, sum);
00341           sum = __SMMLA(inA2, inB2, sum);
00342 
00343           inA1 = *pInA++;
00344           inA2 = *pInA++;
00345           inB1 = *pInB;
00346           pInB += numColsB;
00347           inB2 = *pInB;
00348           pInB += numColsB;
00349           sum = __SMMLA(inA1, inB1, sum);
00350           sum = __SMMLA(inA2, inB2, sum);
00351 
00352           /* Decrement the loop counter */
00353           colCnt--;
00354         }
00355 
00356         colCnt = numColsA & 3U;
00357         while (colCnt > 0U) {
00358           sum = __SMMLA(*pInA++, *pInB, sum);
00359           pInB += numColsB;
00360           colCnt--;
00361         }
00362 
00363         /* Saturate and store the result in the destination buffer */
00364         *px++ = sum << 1;
00365         i++;
00366 
00367         /* Decrement the col loop counter */
00368         col--;
00369       }
00370     }
00371 
00372 #endif /* #if defined (ARM_MATH_DSP) */
00373 
00374     /* set status as ARM_MATH_SUCCESS */
00375     status = ARM_MATH_SUCCESS;
00376   }
00377 
00378   /* Return to application */
00379   return (status);
00380 }
00381 
00382 /**
00383  * @} end of MatrixMult group
00384  */
00385