GeneralMatrixMatrix.h
// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2008-2009 Gael Guennebaud <gael.guennebaud@inria.fr>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.

#ifndef EIGEN_GENERAL_MATRIX_MATRIX_H
#define EIGEN_GENERAL_MATRIX_MATRIX_H

namespace Eigen {

namespace internal {

template<typename _LhsScalar, typename _RhsScalar> class level3_blocking;

/* Specialization for a row-major destination matrix => simple transposition of the product */
template<
  typename Index,
  typename LhsScalar, int LhsStorageOrder, bool ConjugateLhs,
  typename RhsScalar, int RhsStorageOrder, bool ConjugateRhs>
struct general_matrix_matrix_product<Index,LhsScalar,LhsStorageOrder,ConjugateLhs,RhsScalar,RhsStorageOrder,ConjugateRhs,RowMajor>
{
  typedef typename scalar_product_traits<LhsScalar, RhsScalar>::ReturnType ResScalar;
  static EIGEN_STRONG_INLINE void run(
    Index rows, Index cols, Index depth,
    const LhsScalar* lhs, Index lhsStride,
    const RhsScalar* rhs, Index rhsStride,
    ResScalar* res, Index resStride,
    ResScalar alpha,
    level3_blocking<RhsScalar,LhsScalar>& blocking,
    GemmParallelInfo<Index>* info = 0)
  {
    // transpose the product such that the result is column major
    general_matrix_matrix_product<Index,
      RhsScalar, RhsStorageOrder==RowMajor ? ColMajor : RowMajor, ConjugateRhs,
      LhsScalar, LhsStorageOrder==RowMajor ? ColMajor : RowMajor, ConjugateLhs,
      ColMajor>
    ::run(cols,rows,depth,rhs,rhsStride,lhs,lhsStride,res,resStride,alpha,blocking,info);
  }
};

/* Specialization for a col-major destination matrix
 * => Blocking algorithm following Goto's paper */
template<
  typename Index,
  typename LhsScalar, int LhsStorageOrder, bool ConjugateLhs,
  typename RhsScalar, int RhsStorageOrder, bool ConjugateRhs>
struct general_matrix_matrix_product<Index,LhsScalar,LhsStorageOrder,ConjugateLhs,RhsScalar,RhsStorageOrder,ConjugateRhs,ColMajor>
{

typedef typename scalar_product_traits<LhsScalar, RhsScalar>::ReturnType ResScalar;
static void run(Index rows, Index cols, Index depth,
  const LhsScalar* _lhs, Index lhsStride,
  const RhsScalar* _rhs, Index rhsStride,
  ResScalar* res, Index resStride,
  ResScalar alpha,
  level3_blocking<LhsScalar,RhsScalar>& blocking,
  GemmParallelInfo<Index>* info = 0)
{
  const_blas_data_mapper<LhsScalar, Index, LhsStorageOrder> lhs(_lhs,lhsStride);
  const_blas_data_mapper<RhsScalar, Index, RhsStorageOrder> rhs(_rhs,rhsStride);

  typedef gebp_traits<LhsScalar,RhsScalar> Traits;

  Index kc = blocking.kc();                   // cache block size along the K direction
  Index mc = (std::min)(rows,blocking.mc());  // cache block size along the M direction
  //Index nc = blocking.nc();                 // cache block size along the N direction

  gemm_pack_lhs<LhsScalar, Index, Traits::mr, Traits::LhsProgress, LhsStorageOrder> pack_lhs;
  gemm_pack_rhs<RhsScalar, Index, Traits::nr, RhsStorageOrder> pack_rhs;
  gebp_kernel<LhsScalar, RhsScalar, Index, Traits::mr, Traits::nr, ConjugateLhs, ConjugateRhs> gebp;

#ifdef EIGEN_HAS_OPENMP
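  // `info` is non-null only when this kernel is driven by parallelize_gemm() from
  // several OpenMP threads. Each thread computes a horizontal band of the result,
  // while the packed rhs buffer B' is shared; the users/sync counters manipulated
  // below hand its sub-blocks B'_j from one thread to the others.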
  if(info)
  {
    // this is the parallel version!
    Index tid = omp_get_thread_num();
    Index threads = omp_get_num_threads();

    std::size_t sizeA = kc*mc;
    std::size_t sizeW = kc*Traits::WorkSpaceFactor;
    ei_declare_aligned_stack_constructed_variable(LhsScalar, blockA, sizeA, 0);
    ei_declare_aligned_stack_constructed_variable(RhsScalar, w, sizeW, 0);

    RhsScalar* blockB = blocking.blockB();
    eigen_internal_assert(blockB!=0);

    // For each horizontal panel of the rhs, and corresponding vertical panel of the lhs...
    for(Index k=0; k<depth; k+=kc)
    {
      const Index actual_kc = (std::min)(k+kc,depth)-k; // => rows of B', and cols of the A'

      // In order to reduce the chance that a thread has to wait for the other,
      // let's start by packing A'.
      pack_lhs(blockA, &lhs(0,k), lhsStride, actual_kc, mc);

      // Pack B_k to B' in a parallel fashion:
      // each thread packs the sub block B_k,j to B'_j where j is the thread id.

      // However, before copying to B'_j, we have to make sure that no other thread is still using it,
      // i.e., we test that info[tid].users equals 0.
      // Then, we set info[tid].users to the number of threads to mark that all other threads are going to use it.
      while(info[tid].users!=0) {}
      info[tid].users += threads;

      pack_rhs(blockB+info[tid].rhs_start*actual_kc, &rhs(k,info[tid].rhs_start), rhsStride, actual_kc, info[tid].rhs_length);

      // Notify the other threads that the part B'_j is ready to go.
      info[tid].sync = k;

      // Computes C_i += A' * B' per B'_j
      for(Index shift=0; shift<threads; ++shift)
      {
        Index j = (tid+shift)%threads;

        // At this point we have to make sure that B'_j has been updated by the thread j,
        // we use testAndSetOrdered to mimic a volatile access.
        // However, no need to wait for the B' part which has been updated by the current thread!
        if(shift>0)
          while(info[j].sync!=k) {}

        gebp(res+info[j].rhs_start*resStride, resStride, blockA, blockB+info[j].rhs_start*actual_kc, mc, actual_kc, info[j].rhs_length, alpha, -1,-1,0,0, w);
      }

      // Then keep going as usual with the remaining A'
      for(Index i=mc; i<rows; i+=mc)
      {
        const Index actual_mc = (std::min)(i+mc,rows)-i;

        // pack A_i,k to A'
        pack_lhs(blockA, &lhs(i,k), lhsStride, actual_kc, actual_mc);

        // C_i += A' * B'
        gebp(res+i, resStride, blockA, blockB, actual_mc, actual_kc, cols, alpha, -1,-1,0,0, w);
      }

      // Release all the sub blocks B'_j of B' for the current thread,
      // i.e., we simply decrement the number of users by 1
      for(Index j=0; j<threads; ++j)
      {
        #pragma omp atomic
        info[j].users -= 1;
      }
    }
  }
  else
#endif // EIGEN_HAS_OPENMP
  {
    EIGEN_UNUSED_VARIABLE(info);

    // this is the sequential version!
    std::size_t sizeA = kc*mc;
    std::size_t sizeB = kc*cols;
    std::size_t sizeW = kc*Traits::WorkSpaceFactor;

    ei_declare_aligned_stack_constructed_variable(LhsScalar, blockA, sizeA, blocking.blockA());
    ei_declare_aligned_stack_constructed_variable(RhsScalar, blockB, sizeB, blocking.blockB());
    ei_declare_aligned_stack_constructed_variable(RhsScalar, blockW, sizeW, blocking.blockW());
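    // The loop nest below is the serial form of the same Goto blocking scheme:
    // the depth (k) dimension is sliced into kc-sized panels so that one packed
    // rhs panel stays resident in the L2 cache, and each lhs slice is further cut
    // into mc x kc blocks sized for the L1 cache before being handed to the gebp
    // micro-kernel.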
    // For each horizontal panel of the rhs, and corresponding panel of the lhs...
    // (==GEMM_VAR1)
    for(Index k2=0; k2<depth; k2+=kc)
    {
      const Index actual_kc = (std::min)(k2+kc,depth)-k2;

      // OK, here we have selected one horizontal panel of rhs and one vertical panel of lhs.
      // => Pack rhs's panel into a sequential chunk of memory (L2 caching)
      // Note that this panel will be read as many times as the number of blocks in the lhs's
      // vertical panel which is, in practice, a very low number.
      pack_rhs(blockB, &rhs(k2,0), rhsStride, actual_kc, cols);

      // For each mc x kc block of the lhs's vertical panel...
      // (==GEPP_VAR1)
      for(Index i2=0; i2<rows; i2+=mc)
      {
        const Index actual_mc = (std::min)(i2+mc,rows)-i2;

        // We pack the lhs's block into a sequential chunk of memory (L1 caching)
        // Note that this block will be read a very high number of times, which is equal to the number of
        // micro vertical panel of the large rhs's panel (e.g., cols/4 times).
        pack_lhs(blockA, &lhs(i2,k2), lhsStride, actual_kc, actual_mc);

        // Everything is packed, we can now call the block * panel kernel:
        gebp(res+i2, resStride, blockA, blockB, actual_mc, actual_kc, cols, alpha, -1, -1, 0, 0, blockW);
      }
    }
  }
}

};

/*********************************************************************************
*   Specialization of GeneralProduct<> for "large" GEMM, i.e.,
*   implementation of the high level wrapper to general_matrix_matrix_product
**********************************************************************************/

template<typename Lhs, typename Rhs>
struct traits<GeneralProduct<Lhs,Rhs,GemmProduct> >
 : traits<ProductBase<GeneralProduct<Lhs,Rhs,GemmProduct>, Lhs, Rhs> >
{};

template<typename Scalar, typename Index, typename Gemm, typename Lhs, typename Rhs, typename Dest, typename BlockingType>
struct gemm_functor
{
  gemm_functor(const Lhs& lhs, const Rhs& rhs, Dest& dest, const Scalar& actualAlpha,
               BlockingType& blocking)
    : m_lhs(lhs), m_rhs(rhs), m_dest(dest), m_actualAlpha(actualAlpha), m_blocking(blocking)
  {}

  void initParallelSession() const
  {
    m_blocking.allocateB();
  }

  void operator() (Index row, Index rows, Index col=0, Index cols=-1, GemmParallelInfo<Index>* info=0) const
  {
    if(cols==-1)
      cols = m_rhs.cols();

    Gemm::run(rows, cols, m_lhs.cols(),
              /*(const Scalar*)*/&m_lhs.coeffRef(row,0), m_lhs.outerStride(),
              /*(const Scalar*)*/&m_rhs.coeffRef(0,col), m_rhs.outerStride(),
              (Scalar*)&(m_dest.coeffRef(row,col)), m_dest.outerStride(),
              m_actualAlpha, m_blocking, info);
  }

  protected:
    const Lhs& m_lhs;
    const Rhs& m_rhs;
    Dest& m_dest;
    Scalar m_actualAlpha;
    BlockingType& m_blocking;
};

template<int StorageOrder, typename LhsScalar, typename RhsScalar, int MaxRows, int MaxCols, int MaxDepth, int KcFactor=1,
bool FiniteAtCompileTime = MaxRows!=Dynamic && MaxCols!=Dynamic && MaxDepth != Dynamic> class gemm_blocking_space;

template<typename _LhsScalar, typename _RhsScalar>
class level3_blocking
{
    typedef _LhsScalar LhsScalar;
    typedef _RhsScalar RhsScalar;

  protected:
    LhsScalar* m_blockA;
    RhsScalar* m_blockB;
    RhsScalar* m_blockW;

    DenseIndex m_mc;
    DenseIndex m_nc;
    DenseIndex m_kc;
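    // m_mc, m_nc and m_kc are the cache-level block sizes along the m, n and k
    // dimensions of the product; the three pointers above address the packed lhs
    // (A'), packed rhs (B') and gebp workspace buffers, which are owned by the
    // gemm_blocking_space subclasses below.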
  public:

    level3_blocking()
      : m_blockA(0), m_blockB(0), m_blockW(0), m_mc(0), m_nc(0), m_kc(0)
    {}

    inline DenseIndex mc() const { return m_mc; }
    inline DenseIndex nc() const { return m_nc; }
    inline DenseIndex kc() const { return m_kc; }

    inline LhsScalar* blockA() { return m_blockA; }
    inline RhsScalar* blockB() { return m_blockB; }
    inline RhsScalar* blockW() { return m_blockW; }
};

template<int StorageOrder, typename _LhsScalar, typename _RhsScalar, int MaxRows, int MaxCols, int MaxDepth, int KcFactor>
class gemm_blocking_space<StorageOrder,_LhsScalar,_RhsScalar,MaxRows, MaxCols, MaxDepth, KcFactor, true>
  : public level3_blocking<
      typename conditional<StorageOrder==RowMajor,_RhsScalar,_LhsScalar>::type,
      typename conditional<StorageOrder==RowMajor,_LhsScalar,_RhsScalar>::type>
{
    enum {
      Transpose = StorageOrder==RowMajor,
      ActualRows = Transpose ? MaxCols : MaxRows,
      ActualCols = Transpose ? MaxRows : MaxCols
    };
    typedef typename conditional<Transpose,_RhsScalar,_LhsScalar>::type LhsScalar;
    typedef typename conditional<Transpose,_LhsScalar,_RhsScalar>::type RhsScalar;
    typedef gebp_traits<LhsScalar,RhsScalar> Traits;
    enum {
      SizeA = ActualRows * MaxDepth,
      SizeB = ActualCols * MaxDepth,
      SizeW = MaxDepth * Traits::WorkSpaceFactor
    };

    EIGEN_ALIGN16 LhsScalar m_staticA[SizeA];
    EIGEN_ALIGN16 RhsScalar m_staticB[SizeB];
    EIGEN_ALIGN16 RhsScalar m_staticW[SizeW];

  public:

    gemm_blocking_space(DenseIndex /*rows*/, DenseIndex /*cols*/, DenseIndex /*depth*/)
    {
      this->m_mc = ActualRows;
      this->m_nc = ActualCols;
      this->m_kc = MaxDepth;
      this->m_blockA = m_staticA;
      this->m_blockB = m_staticB;
      this->m_blockW = m_staticW;
    }

    inline void allocateA() {}
    inline void allocateB() {}
    inline void allocateW() {}
    inline void allocateAll() {}
};

template<int StorageOrder, typename _LhsScalar, typename _RhsScalar, int MaxRows, int MaxCols, int MaxDepth, int KcFactor>
class gemm_blocking_space<StorageOrder,_LhsScalar,_RhsScalar,MaxRows, MaxCols, MaxDepth, KcFactor, false>
  : public level3_blocking<
      typename conditional<StorageOrder==RowMajor,_RhsScalar,_LhsScalar>::type,
      typename conditional<StorageOrder==RowMajor,_LhsScalar,_RhsScalar>::type>
{
    enum {
      Transpose = StorageOrder==RowMajor
    };
    typedef typename conditional<Transpose,_RhsScalar,_LhsScalar>::type LhsScalar;
    typedef typename conditional<Transpose,_LhsScalar,_RhsScalar>::type RhsScalar;
    typedef gebp_traits<LhsScalar,RhsScalar> Traits;

    DenseIndex m_sizeA;
    DenseIndex m_sizeB;
    DenseIndex m_sizeW;

  public:

    gemm_blocking_space(DenseIndex rows, DenseIndex cols, DenseIndex depth)
    {
      this->m_mc = Transpose ? cols : rows;
      this->m_nc = Transpose ? rows : cols;
      this->m_kc = depth;

      computeProductBlockingSizes<LhsScalar,RhsScalar,KcFactor>(this->m_kc, this->m_mc, this->m_nc);
      m_sizeA = this->m_mc * this->m_kc;
      m_sizeB = this->m_kc * this->m_nc;
      m_sizeW = this->m_kc*Traits::WorkSpaceFactor;
    }

    void allocateA()
    {
      if(this->m_blockA==0)
        this->m_blockA = aligned_new<LhsScalar>(m_sizeA);
    }

    void allocateB()
    {
      if(this->m_blockB==0)
        this->m_blockB = aligned_new<RhsScalar>(m_sizeB);
    }

    void allocateW()
    {
      if(this->m_blockW==0)
        this->m_blockW = aligned_new<RhsScalar>(m_sizeW);
    }

    void allocateAll()
    {
      allocateA();
      allocateB();
      allocateW();
    }

    ~gemm_blocking_space()
    {
      aligned_delete(this->m_blockA, m_sizeA);
      aligned_delete(this->m_blockB, m_sizeB);
      aligned_delete(this->m_blockW, m_sizeW);
    }
};

} // end namespace internal

template<typename Lhs, typename Rhs>
class GeneralProduct<Lhs, Rhs, GemmProduct>
  : public ProductBase<GeneralProduct<Lhs,Rhs,GemmProduct>, Lhs, Rhs>
{
    enum {
      MaxDepthAtCompileTime = EIGEN_SIZE_MIN_PREFER_FIXED(Lhs::MaxColsAtCompileTime,Rhs::MaxRowsAtCompileTime)
    };
  public:
    EIGEN_PRODUCT_PUBLIC_INTERFACE(GeneralProduct)

    typedef typename Lhs::Scalar LhsScalar;
    typedef typename Rhs::Scalar RhsScalar;
    typedef Scalar ResScalar;

    GeneralProduct(const Lhs& lhs, const Rhs& rhs) : Base(lhs,rhs)
    {
#if !(defined(EIGEN_NO_STATIC_ASSERT) && defined(EIGEN_NO_DEBUG))
      typedef internal::scalar_product_op<LhsScalar,RhsScalar> BinOp;
      EIGEN_CHECK_BINARY_COMPATIBILIY(BinOp,LhsScalar,RhsScalar);
#endif
    }

    template<typename Dest> void scaleAndAddTo(Dest& dst, const Scalar& alpha) const
    {
      eigen_assert(dst.rows()==m_lhs.rows() && dst.cols()==m_rhs.cols());
      if(m_lhs.cols()==0 || m_lhs.rows()==0 || m_rhs.cols()==0)
        return;

      typename internal::add_const_on_value_type<ActualLhsType>::type lhs = LhsBlasTraits::extract(m_lhs);
      typename internal::add_const_on_value_type<ActualRhsType>::type rhs = RhsBlasTraits::extract(m_rhs);

      Scalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(m_lhs)
                                 * RhsBlasTraits::extractScalarFactor(m_rhs);

      typedef internal::gemm_blocking_space<(Dest::Flags&RowMajorBit) ? RowMajor : ColMajor,LhsScalar,RhsScalar,
              Dest::MaxRowsAtCompileTime,Dest::MaxColsAtCompileTime,MaxDepthAtCompileTime> BlockingType;

      typedef internal::gemm_functor<
        Scalar, Index,
        internal::general_matrix_matrix_product<
          Index,
          LhsScalar, (_ActualLhsType::Flags&RowMajorBit) ? RowMajor : ColMajor, bool(LhsBlasTraits::NeedToConjugate),
          RhsScalar, (_ActualRhsType::Flags&RowMajorBit) ? RowMajor : ColMajor, bool(RhsBlasTraits::NeedToConjugate),
          (Dest::Flags&RowMajorBit) ? RowMajor : ColMajor>,
        _ActualLhsType, _ActualRhsType, Dest, BlockingType> GemmFunctor;

      BlockingType blocking(dst.rows(), dst.cols(), lhs.cols());

      internal::parallelize_gemm<(Dest::MaxRowsAtCompileTime>32 || Dest::MaxRowsAtCompileTime==Dynamic)>(GemmFunctor(lhs, rhs, dst, actualAlpha, blocking), this->rows(), this->cols(), Dest::Flags&RowMajorBit);
    }
};

} // end namespace Eigen

#endif // EIGEN_GENERAL_MATRIX_MATRIX_H
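For orientation, a minimal sketch of how this code path is reached from user code; the matrix sizes and the <Eigen/Dense> include are illustrative assumptions, not part of this file. A product of large dynamic-size matrices is evaluated through GeneralProduct<Lhs,Rhs,GemmProduct>::scaleAndAddTo above, which sets up a gemm_blocking_space and runs internal::general_matrix_matrix_product (in parallel when compiled with OpenMP):

// Usage sketch (assumes Eigen headers are available; sizes chosen arbitrarily).
#include <Eigen/Dense>
#include <iostream>

int main()
{
  // Dynamic-size operands select the GemmProduct specialization defined above.
  Eigen::MatrixXd A = Eigen::MatrixXd::Random(256, 96);
  Eigen::MatrixXd B = Eigen::MatrixXd::Random(96, 128);

  // Evaluating the product runs the blocked GEMM kernel implemented in this file.
  Eigen::MatrixXd C = A * B;

  std::cout << "C(0,0) = " << C(0,0) << std::endl;
  return 0;
}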