Eigen Matrix Class Library
Dependents: Eigen_test Odometry_test AttitudeEstimation_usingTicker MPU9250_Quaternion_Binary_Serial ... more
TriangularMatrixMatrix_MKL.h
00001 /* 00002 Copyright (c) 2011, Intel Corporation. All rights reserved. 00003 00004 Redistribution and use in source and binary forms, with or without modification, 00005 are permitted provided that the following conditions are met: 00006 00007 * Redistributions of source code must retain the above copyright notice, this 00008 list of conditions and the following disclaimer. 00009 * Redistributions in binary form must reproduce the above copyright notice, 00010 this list of conditions and the following disclaimer in the documentation 00011 and/or other materials provided with the distribution. 00012 * Neither the name of Intel Corporation nor the names of its contributors may 00013 be used to endorse or promote products derived from this software without 00014 specific prior written permission. 00015 00016 THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 00017 ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 00018 WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 00019 DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR 00020 ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 00021 (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 00022 LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON 00023 ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 00024 (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 00025 SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 00026 00027 ******************************************************************************** 00028 * Content : Eigen bindings to Intel(R) MKL 00029 * Triangular matrix * matrix product functionality based on ?TRMM. 
00030 ******************************************************************************** 00031 */ 00032 00033 #ifndef EIGEN_TRIANGULAR_MATRIX_MATRIX_MKL_H 00034 #define EIGEN_TRIANGULAR_MATRIX_MATRIX_MKL_H 00035 00036 namespace Eigen { 00037 00038 namespace internal { 00039 00040 00041 template <typename Scalar, typename Index, 00042 int Mode, bool LhsIsTriangular, 00043 int LhsStorageOrder, bool ConjugateLhs, 00044 int RhsStorageOrder, bool ConjugateRhs, 00045 int ResStorageOrder> 00046 struct product_triangular_matrix_matrix_trmm : 00047 product_triangular_matrix_matrix<Scalar,Index,Mode, 00048 LhsIsTriangular,LhsStorageOrder,ConjugateLhs, 00049 RhsStorageOrder, ConjugateRhs, ResStorageOrder, BuiltIn> {}; 00050 00051 00052 // try to go to BLAS specialization 00053 #define EIGEN_MKL_TRMM_SPECIALIZE(Scalar, LhsIsTriangular) \ 00054 template <typename Index, int Mode, \ 00055 int LhsStorageOrder, bool ConjugateLhs, \ 00056 int RhsStorageOrder, bool ConjugateRhs> \ 00057 struct product_triangular_matrix_matrix<Scalar,Index, Mode, LhsIsTriangular, \ 00058 LhsStorageOrder,ConjugateLhs, RhsStorageOrder,ConjugateRhs,ColMajor,Specialized> { \ 00059 static inline void run(Index _rows, Index _cols, Index _depth, const Scalar* _lhs, Index lhsStride,\ 00060 const Scalar* _rhs, Index rhsStride, Scalar* res, Index resStride, Scalar alpha, level3_blocking<Scalar,Scalar>& blocking) { \ 00061 product_triangular_matrix_matrix_trmm<Scalar,Index,Mode, \ 00062 LhsIsTriangular,LhsStorageOrder,ConjugateLhs, \ 00063 RhsStorageOrder, ConjugateRhs, ColMajor>::run( \ 00064 _rows, _cols, _depth, _lhs, lhsStride, _rhs, rhsStride, res, resStride, alpha, blocking); \ 00065 } \ 00066 }; 00067 00068 EIGEN_MKL_TRMM_SPECIALIZE(double, true) 00069 EIGEN_MKL_TRMM_SPECIALIZE(double, false) 00070 EIGEN_MKL_TRMM_SPECIALIZE(dcomplex, true) 00071 EIGEN_MKL_TRMM_SPECIALIZE(dcomplex, false) 00072 EIGEN_MKL_TRMM_SPECIALIZE(float, true) 00073 EIGEN_MKL_TRMM_SPECIALIZE(float, false) 00074 
EIGEN_MKL_TRMM_SPECIALIZE(scomplex, true) 00075 EIGEN_MKL_TRMM_SPECIALIZE(scomplex, false) 00076 00077 // implements col-major += alpha * op(triangular) * op(general) 00078 #define EIGEN_MKL_TRMM_L(EIGTYPE, MKLTYPE, EIGPREFIX, MKLPREFIX) \ 00079 template <typename Index, int Mode, \ 00080 int LhsStorageOrder, bool ConjugateLhs, \ 00081 int RhsStorageOrder, bool ConjugateRhs> \ 00082 struct product_triangular_matrix_matrix_trmm<EIGTYPE,Index,Mode,true, \ 00083 LhsStorageOrder,ConjugateLhs,RhsStorageOrder,ConjugateRhs,ColMajor> \ 00084 { \ 00085 enum { \ 00086 IsLower = (Mode&Lower) == Lower, \ 00087 SetDiag = (Mode&(ZeroDiag|UnitDiag)) ? 0 : 1, \ 00088 IsUnitDiag = (Mode&UnitDiag) ? 1 : 0, \ 00089 IsZeroDiag = (Mode&ZeroDiag) ? 1 : 0, \ 00090 LowUp = IsLower ? Lower : Upper, \ 00091 conjA = ((LhsStorageOrder==ColMajor) && ConjugateLhs) ? 1 : 0 \ 00092 }; \ 00093 \ 00094 static void run( \ 00095 Index _rows, Index _cols, Index _depth, \ 00096 const EIGTYPE* _lhs, Index lhsStride, \ 00097 const EIGTYPE* _rhs, Index rhsStride, \ 00098 EIGTYPE* res, Index resStride, \ 00099 EIGTYPE alpha, level3_blocking<EIGTYPE,EIGTYPE>& blocking) \ 00100 { \ 00101 Index diagSize = (std::min)(_rows,_depth); \ 00102 Index rows = IsLower ? _rows : diagSize; \ 00103 Index depth = IsLower ? diagSize : _depth; \ 00104 Index cols = _cols; \ 00105 \ 00106 typedef Matrix<EIGTYPE, Dynamic, Dynamic, LhsStorageOrder> MatrixLhs; \ 00107 typedef Matrix<EIGTYPE, Dynamic, Dynamic, RhsStorageOrder> MatrixRhs; \ 00108 \ 00109 /* Non-square case - doesn't fit to MKL ?TRMM. 
Fall to default triangular product or call MKL ?GEMM*/ \ 00110 if (rows != depth) { \ 00111 \ 00112 int nthr = mkl_domain_get_max_threads(EIGEN_MKL_DOMAIN_BLAS); \ 00113 \ 00114 if (((nthr==1) && (((std::max)(rows,depth)-diagSize)/(double)diagSize < 0.5))) { \ 00115 /* Most likely no benefit to call TRMM or GEMM from MKL*/ \ 00116 product_triangular_matrix_matrix<EIGTYPE,Index,Mode,true, \ 00117 LhsStorageOrder,ConjugateLhs, RhsStorageOrder, ConjugateRhs, ColMajor, BuiltIn>::run( \ 00118 _rows, _cols, _depth, _lhs, lhsStride, _rhs, rhsStride, res, resStride, alpha, blocking); \ 00119 /*std::cout << "TRMM_L: A is not square! Go to Eigen TRMM implementation!\n";*/ \ 00120 } else { \ 00121 /* Make sense to call GEMM */ \ 00122 Map<const MatrixLhs, 0, OuterStride<> > lhsMap(_lhs,rows,depth,OuterStride<>(lhsStride)); \ 00123 MatrixLhs aa_tmp=lhsMap.template triangularView<Mode>(); \ 00124 MKL_INT aStride = aa_tmp.outerStride(); \ 00125 gemm_blocking_space<ColMajor,EIGTYPE,EIGTYPE,Dynamic,Dynamic,Dynamic> gemm_blocking(_rows,_cols,_depth); \ 00126 general_matrix_matrix_product<Index,EIGTYPE,LhsStorageOrder,ConjugateLhs,EIGTYPE,RhsStorageOrder,ConjugateRhs,ColMajor>::run( \ 00127 rows, cols, depth, aa_tmp.data(), aStride, _rhs, rhsStride, res, resStride, alpha, gemm_blocking, 0); \ 00128 \ 00129 /*std::cout << "TRMM_L: A is not square! Go to MKL GEMM implementation! " << nthr<<" \n";*/ \ 00130 } \ 00131 return; \ 00132 } \ 00133 char side = 'L', transa, uplo, diag = 'N'; \ 00134 EIGTYPE *b; \ 00135 const EIGTYPE *a; \ 00136 MKL_INT m, n, lda, ldb; \ 00137 MKLTYPE alpha_; \ 00138 \ 00139 /* Set alpha_*/ \ 00140 assign_scalar_eig2mkl<MKLTYPE, EIGTYPE>(alpha_, alpha); \ 00141 \ 00142 /* Set m, n */ \ 00143 m = (MKL_INT)diagSize; \ 00144 n = (MKL_INT)cols; \ 00145 \ 00146 /* Set trans */ \ 00147 transa = (LhsStorageOrder==RowMajor) ? ((ConjugateLhs) ? 
'C' : 'T') : 'N'; \ 00148 \ 00149 /* Set b, ldb */ \ 00150 Map<const MatrixRhs, 0, OuterStride<> > rhs(_rhs,depth,cols,OuterStride<>(rhsStride)); \ 00151 MatrixX##EIGPREFIX b_tmp; \ 00152 \ 00153 if (ConjugateRhs) b_tmp = rhs.conjugate(); else b_tmp = rhs; \ 00154 b = b_tmp.data(); \ 00155 ldb = b_tmp.outerStride(); \ 00156 \ 00157 /* Set uplo */ \ 00158 uplo = IsLower ? 'L' : 'U'; \ 00159 if (LhsStorageOrder==RowMajor) uplo = (uplo == 'L') ? 'U' : 'L'; \ 00160 /* Set a, lda */ \ 00161 Map<const MatrixLhs, 0, OuterStride<> > lhs(_lhs,rows,depth,OuterStride<>(lhsStride)); \ 00162 MatrixLhs a_tmp; \ 00163 \ 00164 if ((conjA!=0) || (SetDiag==0)) { \ 00165 if (conjA) a_tmp = lhs.conjugate(); else a_tmp = lhs; \ 00166 if (IsZeroDiag) \ 00167 a_tmp.diagonal().setZero(); \ 00168 else if (IsUnitDiag) \ 00169 a_tmp.diagonal().setOnes();\ 00170 a = a_tmp.data(); \ 00171 lda = a_tmp.outerStride(); \ 00172 } else { \ 00173 a = _lhs; \ 00174 lda = lhsStride; \ 00175 } \ 00176 /*std::cout << "TRMM_L: A is square! Go to MKL TRMM implementation! 
\n";*/ \ 00177 /* call ?trmm*/ \ 00178 MKLPREFIX##trmm(&side, &uplo, &transa, &diag, &m, &n, &alpha_, (const MKLTYPE*)a, &lda, (MKLTYPE*)b, &ldb); \ 00179 \ 00180 /* Add op(a_triangular)*b into res*/ \ 00181 Map<MatrixX##EIGPREFIX, 0, OuterStride<> > res_tmp(res,rows,cols,OuterStride<>(resStride)); \ 00182 res_tmp=res_tmp+b_tmp; \ 00183 } \ 00184 }; 00185 00186 EIGEN_MKL_TRMM_L(double, double, d, d) 00187 EIGEN_MKL_TRMM_L(dcomplex, MKL_Complex16, cd, z) 00188 EIGEN_MKL_TRMM_L(float, float, f, s) 00189 EIGEN_MKL_TRMM_L(scomplex, MKL_Complex8, cf, c) 00190 00191 // implements col-major += alpha * op(general) * op(triangular) 00192 #define EIGEN_MKL_TRMM_R(EIGTYPE, MKLTYPE, EIGPREFIX, MKLPREFIX) \ 00193 template <typename Index, int Mode, \ 00194 int LhsStorageOrder, bool ConjugateLhs, \ 00195 int RhsStorageOrder, bool ConjugateRhs> \ 00196 struct product_triangular_matrix_matrix_trmm<EIGTYPE,Index,Mode,false, \ 00197 LhsStorageOrder,ConjugateLhs,RhsStorageOrder,ConjugateRhs,ColMajor> \ 00198 { \ 00199 enum { \ 00200 IsLower = (Mode&Lower) == Lower, \ 00201 SetDiag = (Mode&(ZeroDiag|UnitDiag)) ? 0 : 1, \ 00202 IsUnitDiag = (Mode&UnitDiag) ? 1 : 0, \ 00203 IsZeroDiag = (Mode&ZeroDiag) ? 1 : 0, \ 00204 LowUp = IsLower ? Lower : Upper, \ 00205 conjA = ((RhsStorageOrder==ColMajor) && ConjugateRhs) ? 1 : 0 \ 00206 }; \ 00207 \ 00208 static void run( \ 00209 Index _rows, Index _cols, Index _depth, \ 00210 const EIGTYPE* _lhs, Index lhsStride, \ 00211 const EIGTYPE* _rhs, Index rhsStride, \ 00212 EIGTYPE* res, Index resStride, \ 00213 EIGTYPE alpha, level3_blocking<EIGTYPE,EIGTYPE>& blocking) \ 00214 { \ 00215 Index diagSize = (std::min)(_cols,_depth); \ 00216 Index rows = _rows; \ 00217 Index depth = IsLower ? _depth : diagSize; \ 00218 Index cols = IsLower ? 
diagSize : _cols; \ 00219 \ 00220 typedef Matrix<EIGTYPE, Dynamic, Dynamic, LhsStorageOrder> MatrixLhs; \ 00221 typedef Matrix<EIGTYPE, Dynamic, Dynamic, RhsStorageOrder> MatrixRhs; \ 00222 \ 00223 /* Non-square case - doesn't fit to MKL ?TRMM. Fall to default triangular product or call MKL ?GEMM*/ \ 00224 if (cols != depth) { \ 00225 \ 00226 int nthr = mkl_domain_get_max_threads(EIGEN_MKL_DOMAIN_BLAS); \ 00227 \ 00228 if ((nthr==1) && (((std::max)(cols,depth)-diagSize)/(double)diagSize < 0.5)) { \ 00229 /* Most likely no benefit to call TRMM or GEMM from MKL*/ \ 00230 product_triangular_matrix_matrix<EIGTYPE,Index,Mode,false, \ 00231 LhsStorageOrder,ConjugateLhs, RhsStorageOrder, ConjugateRhs, ColMajor, BuiltIn>::run( \ 00232 _rows, _cols, _depth, _lhs, lhsStride, _rhs, rhsStride, res, resStride, alpha, blocking); \ 00233 /*std::cout << "TRMM_R: A is not square! Go to Eigen TRMM implementation!\n";*/ \ 00234 } else { \ 00235 /* Make sense to call GEMM */ \ 00236 Map<const MatrixRhs, 0, OuterStride<> > rhsMap(_rhs,depth,cols, OuterStride<>(rhsStride)); \ 00237 MatrixRhs aa_tmp=rhsMap.template triangularView<Mode>(); \ 00238 MKL_INT aStride = aa_tmp.outerStride(); \ 00239 gemm_blocking_space<ColMajor,EIGTYPE,EIGTYPE,Dynamic,Dynamic,Dynamic> gemm_blocking(_rows,_cols,_depth); \ 00240 general_matrix_matrix_product<Index,EIGTYPE,LhsStorageOrder,ConjugateLhs,EIGTYPE,RhsStorageOrder,ConjugateRhs,ColMajor>::run( \ 00241 rows, cols, depth, _lhs, lhsStride, aa_tmp.data(), aStride, res, resStride, alpha, gemm_blocking, 0); \ 00242 \ 00243 /*std::cout << "TRMM_R: A is not square! Go to MKL GEMM implementation! 
" << nthr<<" \n";*/ \ 00244 } \ 00245 return; \ 00246 } \ 00247 char side = 'R', transa, uplo, diag = 'N'; \ 00248 EIGTYPE *b; \ 00249 const EIGTYPE *a; \ 00250 MKL_INT m, n, lda, ldb; \ 00251 MKLTYPE alpha_; \ 00252 \ 00253 /* Set alpha_*/ \ 00254 assign_scalar_eig2mkl<MKLTYPE, EIGTYPE>(alpha_, alpha); \ 00255 \ 00256 /* Set m, n */ \ 00257 m = (MKL_INT)rows; \ 00258 n = (MKL_INT)diagSize; \ 00259 \ 00260 /* Set trans */ \ 00261 transa = (RhsStorageOrder==RowMajor) ? ((ConjugateRhs) ? 'C' : 'T') : 'N'; \ 00262 \ 00263 /* Set b, ldb */ \ 00264 Map<const MatrixLhs, 0, OuterStride<> > lhs(_lhs,rows,depth,OuterStride<>(lhsStride)); \ 00265 MatrixX##EIGPREFIX b_tmp; \ 00266 \ 00267 if (ConjugateLhs) b_tmp = lhs.conjugate(); else b_tmp = lhs; \ 00268 b = b_tmp.data(); \ 00269 ldb = b_tmp.outerStride(); \ 00270 \ 00271 /* Set uplo */ \ 00272 uplo = IsLower ? 'L' : 'U'; \ 00273 if (RhsStorageOrder==RowMajor) uplo = (uplo == 'L') ? 'U' : 'L'; \ 00274 /* Set a, lda */ \ 00275 Map<const MatrixRhs, 0, OuterStride<> > rhs(_rhs,depth,cols, OuterStride<>(rhsStride)); \ 00276 MatrixRhs a_tmp; \ 00277 \ 00278 if ((conjA!=0) || (SetDiag==0)) { \ 00279 if (conjA) a_tmp = rhs.conjugate(); else a_tmp = rhs; \ 00280 if (IsZeroDiag) \ 00281 a_tmp.diagonal().setZero(); \ 00282 else if (IsUnitDiag) \ 00283 a_tmp.diagonal().setOnes();\ 00284 a = a_tmp.data(); \ 00285 lda = a_tmp.outerStride(); \ 00286 } else { \ 00287 a = _rhs; \ 00288 lda = rhsStride; \ 00289 } \ 00290 /*std::cout << "TRMM_R: A is square! Go to MKL TRMM implementation! 
\n";*/ \ 00291 /* call ?trmm*/ \ 00292 MKLPREFIX##trmm(&side, &uplo, &transa, &diag, &m, &n, &alpha_, (const MKLTYPE*)a, &lda, (MKLTYPE*)b, &ldb); \ 00293 \ 00294 /* Add op(a_triangular)*b into res*/ \ 00295 Map<MatrixX##EIGPREFIX, 0, OuterStride<> > res_tmp(res,rows,cols,OuterStride<>(resStride)); \ 00296 res_tmp=res_tmp+b_tmp; \ 00297 } \ 00298 }; 00299 00300 EIGEN_MKL_TRMM_R(double, double, d, d) 00301 EIGEN_MKL_TRMM_R(dcomplex, MKL_Complex16, cd, z) 00302 EIGEN_MKL_TRMM_R(float, float, f, s) 00303 EIGEN_MKL_TRMM_R(scomplex, MKL_Complex8, cf, c) 00304 00305 } // end namespace internal 00306 00307 } // end namespace Eigen 00308 00309 #endif // EIGEN_TRIANGULAR_MATRIX_MATRIX_MKL_H
Generated on Tue Jul 12 2022 17:47:01 by 1.7.2