Important changes to repositories hosted on mbed.com
Mbed-hosted Mercurial repositories are deprecated and are due to be permanently deleted in July 2026.
To keep a copy of this software download the repository Zip archive or clone locally using Mercurial.
It is also possible to export all your personal repositories from the account settings page.
SolveTriangular.h
00001 // This file is part of Eigen, a lightweight C++ template library 00002 // for linear algebra. 00003 // 00004 // Copyright (C) 2008-2009 Gael Guennebaud <gael.guennebaud@inria.fr> 00005 // 00006 // This Source Code Form is subject to the terms of the Mozilla 00007 // Public License v. 2.0. If a copy of the MPL was not distributed 00008 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/. 00009 00010 #ifndef EIGEN_SOLVETRIANGULAR_H 00011 #define EIGEN_SOLVETRIANGULAR_H 00012 00013 namespace Eigen { 00014 00015 namespace internal { 00016 00017 // Forward declarations: 00018 // The following two routines are implemented in the products/TriangularSolver*.h files 00019 template<typename LhsScalar, typename RhsScalar, typename Index, int Side, int Mode, bool Conjugate, int StorageOrder> 00020 struct triangular_solve_vector; 00021 00022 template <typename Scalar, typename Index, int Side, int Mode, bool Conjugate, int TriStorageOrder, int OtherStorageOrder> 00023 struct triangular_solve_matrix; 00024 00025 // small helper struct extracting some traits on the underlying solver operation 00026 template<typename Lhs, typename Rhs, int Side> 00027 class trsolve_traits 00028 { 00029 private: 00030 enum { 00031 RhsIsVectorAtCompileTime = (Side==OnTheLeft ? Rhs::ColsAtCompileTime : Rhs::RowsAtCompileTime)==1 00032 }; 00033 public: 00034 enum { 00035 Unrolling = (RhsIsVectorAtCompileTime && Rhs::SizeAtCompileTime != Dynamic && Rhs::SizeAtCompileTime <= 8) 00036 ? CompleteUnrolling : NoUnrolling, 00037 RhsVectors = RhsIsVectorAtCompileTime ? 
1 : Dynamic 00038 }; 00039 }; 00040 00041 template<typename Lhs, typename Rhs, 00042 int Side, // can be OnTheLeft/OnTheRight 00043 int Mode, // can be Upper/Lower | UnitDiag 00044 int Unrolling = trsolve_traits<Lhs,Rhs,Side>::Unrolling, 00045 int RhsVectors = trsolve_traits<Lhs,Rhs,Side>::RhsVectors 00046 > 00047 struct triangular_solver_selector; 00048 00049 template<typename Lhs, typename Rhs, int Side, int Mode> 00050 struct triangular_solver_selector<Lhs,Rhs,Side,Mode,NoUnrolling,1> 00051 { 00052 typedef typename Lhs::Scalar LhsScalar; 00053 typedef typename Rhs::Scalar RhsScalar; 00054 typedef blas_traits<Lhs> LhsProductTraits; 00055 typedef typename LhsProductTraits::ExtractType ActualLhsType; 00056 typedef Map<Matrix<RhsScalar,Dynamic,1>, Aligned> MappedRhs; 00057 static void run(const Lhs& lhs, Rhs& rhs) 00058 { 00059 ActualLhsType actualLhs = LhsProductTraits::extract(lhs); 00060 00061 // FIXME find a way to allow an inner stride if packet_traits<Scalar>::size==1 00062 00063 bool useRhsDirectly = Rhs::InnerStrideAtCompileTime==1 || rhs.innerStride()==1; 00064 00065 ei_declare_aligned_stack_constructed_variable(RhsScalar,actualRhs,rhs.size(), 00066 (useRhsDirectly ? rhs.data() : 0)); 00067 00068 if(!useRhsDirectly) 00069 MappedRhs(actualRhs,rhs.size()) = rhs; 00070 00071 triangular_solve_vector<LhsScalar, RhsScalar, typename Lhs::Index, Side, Mode, LhsProductTraits::NeedToConjugate, 00072 (int(Lhs::Flags) & RowMajorBit) ? 
RowMajor : ColMajor> 00073 ::run(actualLhs.cols(), actualLhs.data(), actualLhs.outerStride(), actualRhs); 00074 00075 if(!useRhsDirectly) 00076 rhs = MappedRhs(actualRhs, rhs.size()); 00077 } 00078 }; 00079 00080 // the rhs is a matrix 00081 template<typename Lhs, typename Rhs, int Side, int Mode> 00082 struct triangular_solver_selector<Lhs,Rhs,Side,Mode,NoUnrolling,Dynamic> 00083 { 00084 typedef typename Rhs::Scalar Scalar; 00085 typedef typename Rhs::Index Index; 00086 typedef blas_traits<Lhs> LhsProductTraits; 00087 typedef typename LhsProductTraits::DirectLinearAccessType ActualLhsType; 00088 00089 static void run(const Lhs& lhs, Rhs& rhs) 00090 { 00091 typename internal::add_const_on_value_type<ActualLhsType>::type actualLhs = LhsProductTraits::extract(lhs); 00092 00093 const Index size = lhs.rows(); 00094 const Index othersize = Side==OnTheLeft? rhs.cols() : rhs.rows(); 00095 00096 typedef internal::gemm_blocking_space<(Rhs::Flags&RowMajorBit) ? RowMajor : ColMajor,Scalar,Scalar, 00097 Rhs::MaxRowsAtCompileTime, Rhs::MaxColsAtCompileTime, Lhs::MaxRowsAtCompileTime,4> BlockingType; 00098 00099 BlockingType blocking(rhs.rows(), rhs.cols(), size); 00100 00101 triangular_solve_matrix<Scalar,Index,Side,Mode,LhsProductTraits::NeedToConjugate,(int(Lhs::Flags) & RowMajorBit) ? RowMajor : ColMajor, 00102 (Rhs::Flags&RowMajorBit) ? 
RowMajor : ColMajor> 00103 ::run(size, othersize, &actualLhs.coeffRef(0,0), actualLhs.outerStride(), &rhs.coeffRef(0,0), rhs.outerStride(), blocking); 00104 } 00105 }; 00106 00107 /*************************************************************************** 00108 * meta-unrolling implementation 00109 ***************************************************************************/ 00110 00111 template<typename Lhs, typename Rhs, int Mode, int Index, int Size, 00112 bool Stop = Index==Size> 00113 struct triangular_solver_unroller; 00114 00115 template<typename Lhs, typename Rhs, int Mode, int Index, int Size> 00116 struct triangular_solver_unroller<Lhs,Rhs,Mode,Index,Size,false> { 00117 enum { 00118 IsLower = ((Mode&Lower)==Lower), 00119 RowIndex = IsLower ? Index : Size - Index - 1, 00120 S = IsLower ? 0 : RowIndex+1 00121 }; 00122 static void run(const Lhs& lhs, Rhs& rhs) 00123 { 00124 if (Index>0) 00125 rhs.coeffRef(RowIndex) -= lhs.row(RowIndex).template segment<Index>(S).transpose() 00126 .cwiseProduct(rhs.template segment<Index>(S)).sum(); 00127 00128 if(!(Mode & UnitDiag)) 00129 rhs.coeffRef(RowIndex) /= lhs.coeff(RowIndex,RowIndex); 00130 00131 triangular_solver_unroller<Lhs,Rhs,Mode,Index+1,Size>::run(lhs,rhs); 00132 } 00133 }; 00134 00135 template<typename Lhs, typename Rhs, int Mode, int Index, int Size> 00136 struct triangular_solver_unroller<Lhs,Rhs,Mode,Index,Size,true> { 00137 static void run(const Lhs&, Rhs&) {} 00138 }; 00139 00140 template<typename Lhs, typename Rhs, int Mode> 00141 struct triangular_solver_selector<Lhs,Rhs,OnTheLeft,Mode,CompleteUnrolling,1> { 00142 static void run(const Lhs& lhs, Rhs& rhs) 00143 { triangular_solver_unroller<Lhs,Rhs,Mode,0,Rhs::SizeAtCompileTime>::run(lhs,rhs); } 00144 }; 00145 00146 template<typename Lhs, typename Rhs, int Mode> 00147 struct triangular_solver_selector<Lhs,Rhs,OnTheRight,Mode,CompleteUnrolling,1> { 00148 static void run(const Lhs& lhs, Rhs& rhs) 00149 { 00150 Transpose<const Lhs> trLhs(lhs); 00151 
Transpose<Rhs> trRhs(rhs); 00152 00153 triangular_solver_unroller<Transpose<const Lhs>,Transpose<Rhs>, 00154 ((Mode&Upper)==Upper ? Lower : Upper) | (Mode&UnitDiag), 00155 0,Rhs::SizeAtCompileTime>::run(trLhs,trRhs); 00156 } 00157 }; 00158 00159 } // end namespace internal 00160 00161 /*************************************************************************** 00162 * TriangularView methods 00163 ***************************************************************************/ 00164 00165 /** "in-place" version of TriangularView::solve() where the result is written in \a other 00166 * 00167 * \warning The parameter is only marked 'const' to make the C++ compiler accept a temporary expression here. 00168 * This function will const_cast it, so constness isn't honored here. 00169 * 00170 * See TriangularView:solve() for the details. 00171 */ 00172 template<typename MatrixType, unsigned int Mode> 00173 template<int Side, typename OtherDerived> 00174 void TriangularView<MatrixType,Mode>::solveInPlace(const MatrixBase<OtherDerived>& _other) const 00175 { 00176 OtherDerived& other = _other.const_cast_derived(); 00177 eigen_assert( cols() == rows() && ((Side==OnTheLeft && cols() == other.rows()) || (Side==OnTheRight && cols() == other.cols())) ); 00178 eigen_assert((!(Mode & ZeroDiag)) && bool(Mode & (Upper|Lower))); 00179 00180 enum { copy = internal::traits<OtherDerived>::Flags & RowMajorBit && OtherDerived::IsVectorAtCompileTime }; 00181 typedef typename internal::conditional<copy, 00182 typename internal::plain_matrix_type_column_major<OtherDerived>::type, OtherDerived&>::type OtherCopy; 00183 OtherCopy otherCopy(other); 00184 00185 internal::triangular_solver_selector<MatrixType, typename internal::remove_reference<OtherCopy>::type, 00186 Side, Mode>::run(nestedExpression(), otherCopy); 00187 00188 if (copy) 00189 other = otherCopy; 00190 } 00191 00192 /** \returns the product of the inverse of \c *this with \a other, \a *this being triangular. 
00193 * 00194 * This function computes the inverse-matrix matrix product inverse(\c *this) * \a other if 00195 * \a Side==OnTheLeft (the default), or the right-inverse-multiply \a other * inverse(\c *this) if 00196 * \a Side==OnTheRight. 00197 * 00198 * The matrix \c *this must be triangular and invertible (i.e., all the coefficients of the 00199 * diagonal must be non zero). It works as a forward (resp. backward) substitution if \c *this 00200 * is an upper (resp. lower) triangular matrix. 00201 * 00202 * Example: \include MatrixBase_marked.cpp 00203 * Output: \verbinclude MatrixBase_marked.out 00204 * 00205 * This function returns an expression of the inverse-multiply and can works in-place if it is assigned 00206 * to the same matrix or vector \a other. 00207 * 00208 * For users coming from BLAS, this function (and more specifically solveInPlace()) offer 00209 * all the operations supported by the \c *TRSV and \c *TRSM BLAS routines. 00210 * 00211 * \sa TriangularView::solveInPlace() 00212 */ 00213 template<typename Derived, unsigned int Mode> 00214 template<int Side, typename Other> 00215 const internal::triangular_solve_retval<Side,TriangularView<Derived,Mode>,Other> 00216 TriangularView<Derived,Mode>::solve(const MatrixBase<Other>& other) const 00217 { 00218 return internal::triangular_solve_retval<Side,TriangularView,Other>(*this, other.derived()); 00219 } 00220 00221 namespace internal { 00222 00223 00224 template<int Side, typename TriangularType, typename Rhs> 00225 struct traits<triangular_solve_retval<Side, TriangularType, Rhs> > 00226 { 00227 typedef typename internal::plain_matrix_type_column_major<Rhs>::type ReturnType; 00228 }; 00229 00230 template<int Side, typename TriangularType, typename Rhs> struct triangular_solve_retval 00231 : public ReturnByValue<triangular_solve_retval<Side, TriangularType, Rhs> > 00232 { 00233 typedef typename remove_all<typename Rhs::Nested>::type RhsNestedCleaned; 00234 typedef ReturnByValue<triangular_solve_retval> 
Base; 00235 typedef typename Base::Index Index; 00236 00237 triangular_solve_retval(const TriangularType& tri, const Rhs& rhs) 00238 : m_triangularMatrix(tri), m_rhs(rhs) 00239 {} 00240 00241 inline Index rows() const { return m_rhs.rows(); } 00242 inline Index cols() const { return m_rhs.cols(); } 00243 00244 template<typename Dest> inline void evalTo(Dest& dst) const 00245 { 00246 if(!(is_same<RhsNestedCleaned,Dest>::value && extract_data(dst) == extract_data(m_rhs))) 00247 dst = m_rhs; 00248 m_triangularMatrix.template solveInPlace<Side>(dst); 00249 } 00250 00251 protected: 00252 const TriangularType& m_triangularMatrix; 00253 typename Rhs::Nested m_rhs; 00254 }; 00255 00256 } // namespace internal 00257 00258 } // end namespace Eigen 00259 00260 #endif // EIGEN_SOLVETRIANGULAR_H
Generated on Thu Nov 17 2022 22:01:30 by Doxygen 1.7.2