Important changes to repositories hosted on mbed.com
Mbed hosted mercurial repositories are deprecated and are due to be permanently deleted in July 2026.
To keep a copy of this software, download the repository Zip archive or clone it locally using Mercurial.
It is also possible to export all your personal repositories from the account settings page.
TriangularMatrixMatrix.h
00001 // This file is part of Eigen, a lightweight C++ template library 00002 // for linear algebra. 00003 // 00004 // Copyright (C) 2009 Gael Guennebaud <gael.guennebaud@inria.fr> 00005 // 00006 // This Source Code Form is subject to the terms of the Mozilla 00007 // Public License v. 2.0. If a copy of the MPL was not distributed 00008 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/. 00009 00010 #ifndef EIGEN_TRIANGULAR_MATRIX_MATRIX_H 00011 #define EIGEN_TRIANGULAR_MATRIX_MATRIX_H 00012 00013 namespace Eigen { 00014 00015 namespace internal { 00016 00017 // template<typename Scalar, int mr, int StorageOrder, bool Conjugate, int Mode> 00018 // struct gemm_pack_lhs_triangular 00019 // { 00020 // Matrix<Scalar,mr,mr, 00021 // void operator()(Scalar* blockA, const EIGEN_RESTRICT Scalar* _lhs, int lhsStride, int depth, int rows) 00022 // { 00023 // conj_if<NumTraits<Scalar>::IsComplex && Conjugate> cj; 00024 // const_blas_data_mapper<Scalar, StorageOrder> lhs(_lhs,lhsStride); 00025 // int count = 0; 00026 // const int peeled_mc = (rows/mr)*mr; 00027 // for(int i=0; i<peeled_mc; i+=mr) 00028 // { 00029 // for(int k=0; k<depth; k++) 00030 // for(int w=0; w<mr; w++) 00031 // blockA[count++] = cj(lhs(i+w, k)); 00032 // } 00033 // for(int i=peeled_mc; i<rows; i++) 00034 // { 00035 // for(int k=0; k<depth; k++) 00036 // blockA[count++] = cj(lhs(i, k)); 00037 // } 00038 // } 00039 // }; 00040 00041 /* Optimized triangular matrix * matrix (_TRMM++) product built on top of 00042 * the general matrix matrix product. 
00043 */ 00044 template <typename Scalar, typename Index, 00045 int Mode, bool LhsIsTriangular, 00046 int LhsStorageOrder, bool ConjugateLhs, 00047 int RhsStorageOrder, bool ConjugateRhs, 00048 int ResStorageOrder, int Version = Specialized> 00049 struct product_triangular_matrix_matrix; 00050 00051 template <typename Scalar, typename Index, 00052 int Mode, bool LhsIsTriangular, 00053 int LhsStorageOrder, bool ConjugateLhs, 00054 int RhsStorageOrder, bool ConjugateRhs, int Version> 00055 struct product_triangular_matrix_matrix<Scalar,Index,Mode,LhsIsTriangular, 00056 LhsStorageOrder,ConjugateLhs, 00057 RhsStorageOrder,ConjugateRhs,RowMajor,Version> 00058 { 00059 static EIGEN_STRONG_INLINE void run( 00060 Index rows, Index cols, Index depth, 00061 const Scalar* lhs, Index lhsStride, 00062 const Scalar* rhs, Index rhsStride, 00063 Scalar* res, Index resStride, 00064 const Scalar& alpha, level3_blocking<Scalar,Scalar>& blocking) 00065 { 00066 product_triangular_matrix_matrix<Scalar, Index, 00067 (Mode&(UnitDiag|ZeroDiag)) | ((Mode&Upper) ? Lower : Upper), 00068 (!LhsIsTriangular), 00069 RhsStorageOrder==RowMajor ? ColMajor : RowMajor, 00070 ConjugateRhs, 00071 LhsStorageOrder==RowMajor ? 
ColMajor : RowMajor, 00072 ConjugateLhs, 00073 ColMajor> 00074 ::run(cols, rows, depth, rhs, rhsStride, lhs, lhsStride, res, resStride, alpha, blocking); 00075 } 00076 }; 00077 00078 // implements col-major += alpha * op(triangular) * op(general) 00079 template <typename Scalar, typename Index, int Mode, 00080 int LhsStorageOrder, bool ConjugateLhs, 00081 int RhsStorageOrder, bool ConjugateRhs, int Version> 00082 struct product_triangular_matrix_matrix<Scalar,Index,Mode,true, 00083 LhsStorageOrder,ConjugateLhs, 00084 RhsStorageOrder,ConjugateRhs,ColMajor,Version> 00085 { 00086 00087 typedef gebp_traits<Scalar,Scalar> Traits; 00088 enum { 00089 SmallPanelWidth = 2 * EIGEN_PLAIN_ENUM_MAX(Traits::mr,Traits::nr), 00090 IsLower = (Mode&Lower) == Lower, 00091 SetDiag = (Mode&(ZeroDiag|UnitDiag)) ? 0 : 1 00092 }; 00093 00094 static EIGEN_DONT_INLINE void run( 00095 Index _rows, Index _cols, Index _depth, 00096 const Scalar* _lhs, Index lhsStride, 00097 const Scalar* _rhs, Index rhsStride, 00098 Scalar* res, Index resStride, 00099 const Scalar& alpha, level3_blocking<Scalar,Scalar>& blocking); 00100 }; 00101 00102 template <typename Scalar, typename Index, int Mode, 00103 int LhsStorageOrder, bool ConjugateLhs, 00104 int RhsStorageOrder, bool ConjugateRhs, int Version> 00105 EIGEN_DONT_INLINE void product_triangular_matrix_matrix<Scalar,Index,Mode,true, 00106 LhsStorageOrder,ConjugateLhs, 00107 RhsStorageOrder,ConjugateRhs,ColMajor,Version>::run( 00108 Index _rows, Index _cols, Index _depth, 00109 const Scalar* _lhs, Index lhsStride, 00110 const Scalar* _rhs, Index rhsStride, 00111 Scalar* res, Index resStride, 00112 const Scalar& alpha, level3_blocking<Scalar,Scalar>& blocking) 00113 { 00114 // strip zeros 00115 Index diagSize = (std::min)(_rows,_depth); 00116 Index rows = IsLower ? _rows : diagSize; 00117 Index depth = IsLower ? 
diagSize : _depth; 00118 Index cols = _cols; 00119 00120 const_blas_data_mapper<Scalar, Index, LhsStorageOrder> lhs(_lhs,lhsStride); 00121 const_blas_data_mapper<Scalar, Index, RhsStorageOrder> rhs(_rhs,rhsStride); 00122 00123 Index kc = blocking.kc(); // cache block size along the K direction 00124 Index mc = (std::min)(rows,blocking.mc()); // cache block size along the M direction 00125 00126 std::size_t sizeA = kc*mc; 00127 std::size_t sizeB = kc*cols; 00128 std::size_t sizeW = kc*Traits::WorkSpaceFactor; 00129 00130 ei_declare_aligned_stack_constructed_variable(Scalar, blockA, sizeA, blocking.blockA()); 00131 ei_declare_aligned_stack_constructed_variable(Scalar, blockB, sizeB, blocking.blockB()); 00132 ei_declare_aligned_stack_constructed_variable(Scalar, blockW, sizeW, blocking.blockW()); 00133 00134 Matrix<Scalar,SmallPanelWidth,SmallPanelWidth,LhsStorageOrder> triangularBuffer; 00135 triangularBuffer.setZero(); 00136 if((Mode&ZeroDiag)==ZeroDiag) 00137 triangularBuffer.diagonal().setZero(); 00138 else 00139 triangularBuffer.diagonal().setOnes(); 00140 00141 gebp_kernel<Scalar, Scalar, Index, Traits::mr, Traits::nr, ConjugateLhs, ConjugateRhs> gebp_kernel; 00142 gemm_pack_lhs<Scalar, Index, Traits::mr, Traits::LhsProgress, LhsStorageOrder> pack_lhs; 00143 gemm_pack_rhs<Scalar, Index, Traits::nr,RhsStorageOrder> pack_rhs; 00144 00145 for(Index k2=IsLower ? depth : 0; 00146 IsLower ? k2>0 : k2<depth; 00147 IsLower ? k2-=kc : k2+=kc) 00148 { 00149 Index actual_kc = (std::min)(IsLower ? k2 : depth-k2, kc); 00150 Index actual_k2 = IsLower ? 
k2-actual_kc : k2; 00151 00152 // align blocks with the end of the triangular part for trapezoidal lhs 00153 if((!IsLower)&&(k2<rows)&&(k2+actual_kc>rows)) 00154 { 00155 actual_kc = rows-k2; 00156 k2 = k2+actual_kc-kc; 00157 } 00158 00159 pack_rhs(blockB, &rhs(actual_k2,0), rhsStride, actual_kc, cols); 00160 00161 // the selected lhs's panel has to be split in three different parts: 00162 // 1 - the part which is zero => skip it 00163 // 2 - the diagonal block => special kernel 00164 // 3 - the dense panel below (lower case) or above (upper case) the diagonal block => GEPP 00165 00166 // the block diagonal, if any: 00167 if(IsLower || actual_k2<rows) 00168 { 00169 // for each small vertical panels of lhs 00170 for (Index k1=0; k1<actual_kc; k1+=SmallPanelWidth) 00171 { 00172 Index actualPanelWidth = std::min<Index>(actual_kc-k1, SmallPanelWidth); 00173 Index lengthTarget = IsLower ? actual_kc-k1-actualPanelWidth : k1; 00174 Index startBlock = actual_k2+k1; 00175 Index blockBOffset = k1; 00176 00177 // => GEBP with the micro triangular block 00178 // The trick is to pack this micro block while filling the opposite triangular part with zeros. 00179 // To this end we do an extra triangular copy to a small temporary buffer 00180 for (Index k=0;k<actualPanelWidth;++k) 00181 { 00182 if (SetDiag) 00183 triangularBuffer.coeffRef(k,k) = lhs(startBlock+k,startBlock+k); 00184 for (Index i=IsLower ? k+1 : 0; IsLower ? i<actualPanelWidth : i<k; ++i) 00185 triangularBuffer.coeffRef(i,k) = lhs(startBlock+i,startBlock+k); 00186 } 00187 pack_lhs(blockA, triangularBuffer.data(), triangularBuffer.outerStride(), actualPanelWidth, actualPanelWidth); 00188 00189 gebp_kernel(res+startBlock, resStride, blockA, blockB, actualPanelWidth, actualPanelWidth, cols, alpha, 00190 actualPanelWidth, actual_kc, 0, blockBOffset, blockW); 00191 00192 // GEBP with remaining micro panel 00193 if (lengthTarget>0) 00194 { 00195 Index startTarget = IsLower ? 
actual_k2+k1+actualPanelWidth : actual_k2; 00196 00197 pack_lhs(blockA, &lhs(startTarget,startBlock), lhsStride, actualPanelWidth, lengthTarget); 00198 00199 gebp_kernel(res+startTarget, resStride, blockA, blockB, lengthTarget, actualPanelWidth, cols, alpha, 00200 actualPanelWidth, actual_kc, 0, blockBOffset, blockW); 00201 } 00202 } 00203 } 00204 // the part below (lower case) or above (upper case) the diagonal => GEPP 00205 { 00206 Index start = IsLower ? k2 : 0; 00207 Index end = IsLower ? rows : (std::min)(actual_k2,rows); 00208 for(Index i2=start; i2<end; i2+=mc) 00209 { 00210 const Index actual_mc = (std::min)(i2+mc,end)-i2; 00211 gemm_pack_lhs<Scalar, Index, Traits::mr,Traits::LhsProgress, LhsStorageOrder,false>() 00212 (blockA, &lhs(i2, actual_k2), lhsStride, actual_kc, actual_mc); 00213 00214 gebp_kernel(res+i2, resStride, blockA, blockB, actual_mc, actual_kc, cols, alpha, -1, -1, 0, 0, blockW); 00215 } 00216 } 00217 } 00218 } 00219 00220 // implements col-major += alpha * op(general) * op(triangular) 00221 template <typename Scalar, typename Index, int Mode, 00222 int LhsStorageOrder, bool ConjugateLhs, 00223 int RhsStorageOrder, bool ConjugateRhs, int Version> 00224 struct product_triangular_matrix_matrix<Scalar,Index,Mode,false, 00225 LhsStorageOrder,ConjugateLhs, 00226 RhsStorageOrder,ConjugateRhs,ColMajor,Version> 00227 { 00228 typedef gebp_traits<Scalar,Scalar> Traits; 00229 enum { 00230 SmallPanelWidth = EIGEN_PLAIN_ENUM_MAX(Traits::mr,Traits::nr), 00231 IsLower = (Mode&Lower) == Lower, 00232 SetDiag = (Mode&(ZeroDiag|UnitDiag)) ? 
0 : 1 00233 }; 00234 00235 static EIGEN_DONT_INLINE void run( 00236 Index _rows, Index _cols, Index _depth, 00237 const Scalar* _lhs, Index lhsStride, 00238 const Scalar* _rhs, Index rhsStride, 00239 Scalar* res, Index resStride, 00240 const Scalar& alpha, level3_blocking<Scalar,Scalar>& blocking); 00241 }; 00242 00243 template <typename Scalar, typename Index, int Mode, 00244 int LhsStorageOrder, bool ConjugateLhs, 00245 int RhsStorageOrder, bool ConjugateRhs, int Version> 00246 EIGEN_DONT_INLINE void product_triangular_matrix_matrix<Scalar,Index,Mode,false, 00247 LhsStorageOrder,ConjugateLhs, 00248 RhsStorageOrder,ConjugateRhs,ColMajor,Version>::run( 00249 Index _rows, Index _cols, Index _depth, 00250 const Scalar* _lhs, Index lhsStride, 00251 const Scalar* _rhs, Index rhsStride, 00252 Scalar* res, Index resStride, 00253 const Scalar& alpha, level3_blocking<Scalar,Scalar>& blocking) 00254 { 00255 // strip zeros 00256 Index diagSize = (std::min)(_cols,_depth); 00257 Index rows = _rows; 00258 Index depth = IsLower ? _depth : diagSize; 00259 Index cols = IsLower ? 
diagSize : _cols; 00260 00261 const_blas_data_mapper<Scalar, Index, LhsStorageOrder> lhs(_lhs,lhsStride); 00262 const_blas_data_mapper<Scalar, Index, RhsStorageOrder> rhs(_rhs,rhsStride); 00263 00264 Index kc = blocking.kc(); // cache block size along the K direction 00265 Index mc = (std::min)(rows,blocking.mc()); // cache block size along the M direction 00266 00267 std::size_t sizeA = kc*mc; 00268 std::size_t sizeB = kc*cols; 00269 std::size_t sizeW = kc*Traits::WorkSpaceFactor; 00270 00271 ei_declare_aligned_stack_constructed_variable(Scalar, blockA, sizeA, blocking.blockA()); 00272 ei_declare_aligned_stack_constructed_variable(Scalar, blockB, sizeB, blocking.blockB()); 00273 ei_declare_aligned_stack_constructed_variable(Scalar, blockW, sizeW, blocking.blockW()); 00274 00275 Matrix<Scalar,SmallPanelWidth,SmallPanelWidth,RhsStorageOrder> triangularBuffer; 00276 triangularBuffer.setZero(); 00277 if((Mode&ZeroDiag)==ZeroDiag) 00278 triangularBuffer.diagonal().setZero(); 00279 else 00280 triangularBuffer.diagonal().setOnes(); 00281 00282 gebp_kernel<Scalar, Scalar, Index, Traits::mr, Traits::nr, ConjugateLhs, ConjugateRhs> gebp_kernel; 00283 gemm_pack_lhs<Scalar, Index, Traits::mr, Traits::LhsProgress, LhsStorageOrder> pack_lhs; 00284 gemm_pack_rhs<Scalar, Index, Traits::nr,RhsStorageOrder> pack_rhs; 00285 gemm_pack_rhs<Scalar, Index, Traits::nr,RhsStorageOrder,false,true> pack_rhs_panel; 00286 00287 for(Index k2=IsLower ? 0 : depth; 00288 IsLower ? k2<depth : k2>0; 00289 IsLower ? k2+=kc : k2-=kc) 00290 { 00291 Index actual_kc = (std::min)(IsLower ? depth-k2 : k2, kc); 00292 Index actual_k2 = IsLower ? k2 : k2-actual_kc; 00293 00294 // align blocks with the end of the triangular part for trapezoidal rhs 00295 if(IsLower && (k2<cols) && (actual_k2+actual_kc>cols)) 00296 { 00297 actual_kc = cols-k2; 00298 k2 = actual_k2 + actual_kc - kc; 00299 } 00300 00301 // remaining size 00302 Index rs = IsLower ? 
(std::min)(cols,actual_k2) : cols - k2; 00303 // size of the triangular part 00304 Index ts = (IsLower && actual_k2>=cols) ? 0 : actual_kc; 00305 00306 Scalar* geb = blockB+ts*ts; 00307 00308 pack_rhs(geb, &rhs(actual_k2,IsLower ? 0 : k2), rhsStride, actual_kc, rs); 00309 00310 // pack the triangular part of the rhs padding the unrolled blocks with zeros 00311 if(ts>0) 00312 { 00313 for (Index j2=0; j2<actual_kc; j2+=SmallPanelWidth) 00314 { 00315 Index actualPanelWidth = std::min<Index>(actual_kc-j2, SmallPanelWidth); 00316 Index actual_j2 = actual_k2 + j2; 00317 Index panelOffset = IsLower ? j2+actualPanelWidth : 0; 00318 Index panelLength = IsLower ? actual_kc-j2-actualPanelWidth : j2; 00319 // general part 00320 pack_rhs_panel(blockB+j2*actual_kc, 00321 &rhs(actual_k2+panelOffset, actual_j2), rhsStride, 00322 panelLength, actualPanelWidth, 00323 actual_kc, panelOffset); 00324 00325 // append the triangular part via a temporary buffer 00326 for (Index j=0;j<actualPanelWidth;++j) 00327 { 00328 if (SetDiag) 00329 triangularBuffer.coeffRef(j,j) = rhs(actual_j2+j,actual_j2+j); 00330 for (Index k=IsLower ? j+1 : 0; IsLower ? k<actualPanelWidth : k<j; ++k) 00331 triangularBuffer.coeffRef(k,j) = rhs(actual_j2+k,actual_j2+j); 00332 } 00333 00334 pack_rhs_panel(blockB+j2*actual_kc, 00335 triangularBuffer.data(), triangularBuffer.outerStride(), 00336 actualPanelWidth, actualPanelWidth, 00337 actual_kc, j2); 00338 } 00339 } 00340 00341 for (Index i2=0; i2<rows; i2+=mc) 00342 { 00343 const Index actual_mc = (std::min)(mc,rows-i2); 00344 pack_lhs(blockA, &lhs(i2, actual_k2), lhsStride, actual_kc, actual_mc); 00345 00346 // triangular kernel 00347 if(ts>0) 00348 { 00349 for (Index j2=0; j2<actual_kc; j2+=SmallPanelWidth) 00350 { 00351 Index actualPanelWidth = std::min<Index>(actual_kc-j2, SmallPanelWidth); 00352 Index panelLength = IsLower ? actual_kc-j2 : j2+actualPanelWidth; 00353 Index blockOffset = IsLower ? 
j2 : 0; 00354 00355 gebp_kernel(res+i2+(actual_k2+j2)*resStride, resStride, 00356 blockA, blockB+j2*actual_kc, 00357 actual_mc, panelLength, actualPanelWidth, 00358 alpha, 00359 actual_kc, actual_kc, // strides 00360 blockOffset, blockOffset,// offsets 00361 blockW); // workspace 00362 } 00363 } 00364 gebp_kernel(res+i2+(IsLower ? 0 : k2)*resStride, resStride, 00365 blockA, geb, actual_mc, actual_kc, rs, 00366 alpha, 00367 -1, -1, 0, 0, blockW); 00368 } 00369 } 00370 } 00371 00372 /*************************************************************************** 00373 * Wrapper to product_triangular_matrix_matrix 00374 ***************************************************************************/ 00375 00376 template<int Mode, bool LhsIsTriangular, typename Lhs, typename Rhs> 00377 struct traits<TriangularProduct<Mode,LhsIsTriangular,Lhs,false,Rhs,false> > 00378 : traits<ProductBase<TriangularProduct<Mode,LhsIsTriangular,Lhs,false,Rhs,false>, Lhs, Rhs> > 00379 {}; 00380 00381 } // end namespace internal 00382 00383 template<int Mode, bool LhsIsTriangular, typename Lhs, typename Rhs> 00384 struct TriangularProduct<Mode,LhsIsTriangular,Lhs,false,Rhs,false> 00385 : public ProductBase<TriangularProduct<Mode,LhsIsTriangular,Lhs,false,Rhs,false>, Lhs, Rhs > 00386 { 00387 EIGEN_PRODUCT_PUBLIC_INTERFACE(TriangularProduct) 00388 00389 TriangularProduct(const Lhs& lhs, const Rhs& rhs) : Base(lhs,rhs) {} 00390 00391 template<typename Dest> void scaleAndAddTo(Dest& dst, const Scalar& alpha) const 00392 { 00393 typename internal::add_const_on_value_type<ActualLhsType>::type lhs = LhsBlasTraits::extract(m_lhs); 00394 typename internal::add_const_on_value_type<ActualRhsType>::type rhs = RhsBlasTraits::extract(m_rhs); 00395 00396 Scalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(m_lhs) 00397 * RhsBlasTraits::extractScalarFactor(m_rhs); 00398 00399 typedef internal::gemm_blocking_space<(Dest::Flags&RowMajorBit) ? 
RowMajor : ColMajor,Scalar,Scalar, 00400 Lhs::MaxRowsAtCompileTime, Rhs::MaxColsAtCompileTime, Lhs::MaxColsAtCompileTime,4> BlockingType; 00401 00402 enum { IsLower = (Mode&Lower) == Lower }; 00403 Index stripedRows = ((!LhsIsTriangular) || (IsLower)) ? lhs.rows() : (std::min)(lhs.rows(),lhs.cols()); 00404 Index stripedCols = ((LhsIsTriangular) || (!IsLower)) ? rhs.cols() : (std::min)(rhs.cols(),rhs.rows()); 00405 Index stripedDepth = LhsIsTriangular ? ((!IsLower) ? lhs.cols() : (std::min)(lhs.cols(),lhs.rows())) 00406 : ((IsLower) ? rhs.rows() : (std::min)(rhs.rows(),rhs.cols())); 00407 00408 BlockingType blocking(stripedRows, stripedCols, stripedDepth); 00409 00410 internal::product_triangular_matrix_matrix<Scalar, Index, 00411 Mode, LhsIsTriangular, 00412 (internal::traits<_ActualLhsType>::Flags&RowMajorBit) ? RowMajor : ColMajor, LhsBlasTraits::NeedToConjugate, 00413 (internal::traits<_ActualRhsType>::Flags&RowMajorBit) ? RowMajor : ColMajor, RhsBlasTraits::NeedToConjugate, 00414 (internal::traits<Dest >::Flags&RowMajorBit) ? RowMajor : ColMajor> 00415 ::run( 00416 stripedRows, stripedCols, stripedDepth, // sizes 00417 &lhs.coeffRef(0,0), lhs.outerStride(), // lhs info 00418 &rhs.coeffRef(0,0), rhs.outerStride(), // rhs info 00419 &dst.coeffRef(0,0), dst.outerStride(), // result info 00420 actualAlpha, blocking 00421 ); 00422 } 00423 }; 00424 00425 } // end namespace Eigen 00426 00427 #endif // EIGEN_TRIANGULAR_MATRIX_MATRIX_H
Generated on Thu Nov 17 2022 22:01:30 by
1.7.2