Contains added code for STM32L432KC compatibility.

Dependents:   BNO080_stm32_compatible

Embed: (wiki syntax)

« Back to documentation index

Show/hide line numbers tmatrix.h Source File

tmatrix.h

Go to the documentation of this file.
00001 #ifndef TMATRIX_H
00002 #define TMATRIX_H
00003 
00004 /** 
00005  * @file tmatrix.h
00006  *
00007  * @brief A dimension-templatized class for matrices of values.
00008  */
00009 #include <cmath>
00010 #include <mbed.h>
00011 
// Compile-time assertion helpers in the style of Boost's static assert
// (http://www.boost.org).  STATIC_ASSERTION_FAILURE<false> is declared but
// never defined, so taking sizeof() of it is ill-formed and compilation
// stops when the asserted condition is false.
template <bool x> struct STATIC_ASSERTION_FAILURE;
template <> struct STATIC_ASSERTION_FAILURE<true> { enum { value = 1 }; };
template<int x> struct static_assert_test{};

// Two-level concatenation: operands of ## are not macro-expanded, so a
// direct `name##__LINE__` would paste the literal token "__LINE__" and every
// assertion would declare the same typedef name.  The extra level lets
// __LINE__ expand to the current line number first.
#define TMATRIX_CONCAT_IMPL(a, b) a##b
#define TMATRIX_CONCAT(a, b) TMATRIX_CONCAT_IMPL(a, b)

// STATIC_ASSERT(B): declares an unused typedef named after the current line
// whose definition only compiles when the constant expression B is true.
#define STATIC_ASSERT( B ) \
  typedef __attribute__((unused)) static_assert_test<sizeof(STATIC_ASSERTION_FAILURE<(bool)(B)>) > \
  TMATRIX_CONCAT(static_assert_typedef, __LINE__)
00019 
// ERROR_CHECK(X): compiles the statement X only when DEBUG is non-zero.
// The bounds checks passed to it report through std::clog / std::cerr, so
// <iostream> is pulled in only for those builds.  In release builds the
// argument is discarded entirely.
#if DEBUG
#include <iostream>
// Expand to the bare statement: the argument is an `if` block, and wrapping
// it in parentheses (as `(X)`) would not be valid C++.
#define ERROR_CHECK(X) X
#else
#define ERROR_CHECK(X)
#endif
00025 
00026 // Forward Decl.
00027 template <uint16_t, uint16_t, typename> class BasicMatrix;
00028 template <uint16_t, uint16_t, typename> class TMatrix;
00029 class TMatrixDummy { };
00030 
00031 /**
00032  * @brief Class that layers on operator[] functionality for
00033  * typical matrices.
00034  */
00035 template <uint16_t Rows, uint16_t Cols, typename value_type>
00036 class BasicIndexMatrix : public BasicMatrix<Rows,Cols,value_type> {
00037     typedef BasicMatrix<Rows,Cols,value_type> BaseType;
00038 protected:
00039     BasicIndexMatrix(TMatrixDummy d) : BaseType(d) { }
00040 public:
00041     BasicIndexMatrix() { }
00042     BasicIndexMatrix(const value_type* data) : BaseType(data) {}
00043 
00044     const value_type* operator[](uint16_t r) const {
00045         ERROR_CHECK(if (r >= Rows) {
00046             std::clog << "Invalid row index " << r << std::endl;
00047             return &BaseType::mData[0];
00048         }
00049         )
00050         return &BaseType::mData[r*Cols];
00051     }
00052 
00053     value_type* operator[](uint16_t r) {
00054         ERROR_CHECK(if (r >= Rows) {
00055             std::clog << "Invalid row index " << r << std::endl;
00056             return &BaseType::mData[0];
00057         }
00058         )
00059         return &BaseType::mData[r*Cols];
00060     }
00061 };
00062 
00063 /**
00064  * @brief Specialization of BasicIndexMatrix that provides
00065  * single-indexing operator for column vectors.
00066  */
00067 template <uint16_t Rows, typename value_type>
00068 class BasicIndexMatrix<Rows,1,value_type> :
00069         public BasicMatrix<Rows,1,value_type> {
00070     typedef BasicMatrix<Rows,1,value_type>  BaseType ;
00071 protected:
00072     BasicIndexMatrix(TMatrixDummy dummy) : BaseType (dummy) {}
00073 public:
00074     BasicIndexMatrix() { }
00075     BasicIndexMatrix(const value_type* data) : BaseType (data) {}
00076 
00077     value_type operator[](uint16_t r) const {
00078         ERROR_CHECK(if (r >= Rows) {
00079             std::clog << "Invalid vector index " << r << std::endl;
00080             return BaseType::mData[0];
00081         })
00082         return BaseType::mData[r];
00083     }
00084     value_type& operator[](uint16_t r) {
00085         ERROR_CHECK(if (r >= Rows) {
00086             std::clog << "Invalid vector index " << r << std::endl;
00087             return BaseType::mData[0];
00088         })
00089         return BaseType::mData[r];
00090     }
00091 
00092     value_type norm() const {
00093         return sqrt(norm2());
00094     }
00095 
00096     value_type norm2() const {
00097         double normSum = 0;
00098         for (uint32_t i = 0; i < Rows; i++) {
00099             normSum += BaseType::mData[i]*BaseType::mData[i];
00100         }
00101         return normSum;
00102     }
00103 
00104     /** @brief Returns matrix with vector elements on diagonal. */
00105     TMatrix<Rows, Rows, value_type> diag(void) const {
00106         TMatrix<Rows, Rows, value_type> d;
00107         for (uint32_t i = 0; i < Rows; i++) d.element(i,i, BaseType::mData[i]);
00108         return d;
00109     }
00110 
00111 };
00112 
00113 /**
00114  * @brief A dimension-templatized class for matrices of values.
00115  *
00116  * This template class generically supports any constant-sized
00117  * matrix of values.  The @p Rows and @p Cols template parameters
00118  * define the size of the matrix at @e compile-time.  Hence, the
00119  * size of the matrix cannot be chosen at runtime.  However, the
00120  * dimensions are appropriately type-checked at compile-time where
00121  * possible.
00122  *
00123  * By default, the matrix contains values of type @p double.  The @p
00124  * value_type template parameter may be selected to allow matrices
00125  * of integers or floats.
00126  *
00127  * @note At present, type cohersion between matrices with different
00128  * @p value_type parameters is not implemented.  It is recommended
00129  * that matrices with value type @p double be used for all numerical
00130  * computation.
00131  *
00132  * Note that the special cases of row and column vectors are
00133  * subsumed by this class.
00134  */
00135 template <uint16_t Rows, uint16_t Cols, typename value_type = double>
00136 class TMatrix : public BasicIndexMatrix<Rows, Cols, value_type> {
00137     typedef BasicIndexMatrix<Rows, Cols, value_type>  BaseType;
00138     template <uint16_t R, uint16_t C, typename vt> friend class BasicMatrix;
00139     TMatrix(TMatrixDummy d) : BaseType(d) {}
00140 
00141 public:
00142     TMatrix() : BaseType() { }
00143     TMatrix(const value_type* data) : BaseType(data) {}
00144 };
00145 
00146 /**
00147  * @brief Template specialization of TMatrix for Vector4.
00148  */
00149 template <typename value_type>
00150 class TMatrix<4,1, value_type> : public BasicIndexMatrix<4,1, value_type> {
00151     typedef BasicIndexMatrix<4, 1, value_type>   BaseType ;
00152     template <uint16_t R, uint16_t C, typename vt> friend class BasicMatrix;
00153     TMatrix(TMatrixDummy d) : BaseType (d) {}
00154 
00155 public:
00156     TMatrix() { }
00157     TMatrix(const value_type* data) : BaseType (data) {}
00158     TMatrix(value_type a0, value_type a1, value_type a2, value_type a3) : BaseType (TMatrixDummy()) {
00159         BaseType::mData[0] = a0;
00160         BaseType::mData[1] = a1;
00161         BaseType::mData[2] = a2;
00162         BaseType::mData[3] = a3;
00163     }
00164 };
00165 
00166 /**
00167  * @brief Template specialization of TMatrix for Vector3.
00168  */
00169 template <typename value_type>
00170 class TMatrix<3,1, value_type> : public BasicIndexMatrix<3,1, value_type> {
00171     typedef BasicIndexMatrix<3, 1, value_type>   BaseType ;
00172     template <uint16_t R, uint16_t C, typename vt> friend class BasicMatrix;
00173     TMatrix(TMatrixDummy d) : BaseType (d) {}
00174 
00175 public:
00176     TMatrix() { }
00177     TMatrix(const value_type* data) : BaseType (data) {}
00178     TMatrix(value_type a0, value_type a1, value_type a2) : BaseType (TMatrixDummy()) {
00179         BaseType::mData[0] = a0;
00180         BaseType::mData[1] = a1;
00181         BaseType::mData[2] = a2;
00182     }
00183 
00184     TMatrix<3,1,value_type> cross(const TMatrix<3,1, value_type>& v) const {
00185         const TMatrix<3,1,value_type>& u = *this;
00186         return TMatrix<3,1,value_type>(u[1]*v[2]-u[2]*v[1],
00187                                        u[2]*v[0]-u[0]*v[2],
00188                                        u[0]*v[1]-u[1]*v[0]);
00189     }
00190 };
00191 
00192 /**
00193  * @brief Template specialization of TMatrix for Vector2.
00194  */
00195 template <typename value_type>
00196 class TMatrix<2,1, value_type> : public BasicIndexMatrix<2,1, value_type> {
00197     typedef BasicIndexMatrix<2, 1, value_type>   BaseType ;
00198     template <uint16_t R, uint16_t C, typename vt> friend class BasicMatrix;
00199     TMatrix(TMatrixDummy d) : BaseType (d) {}
00200 
00201 public:
00202     TMatrix() { }
00203     TMatrix(const value_type* data) : BaseType (data) {}
00204     TMatrix(value_type a0, value_type a1) : BaseType (TMatrixDummy()) {
00205         BaseType::mData[0] = a0;
00206         BaseType::mData[1] = a1;
00207     }
00208 };
00209 
00210 /**
00211  * @brief Template specialization of TMatrix for a 1x1 vector.
00212  */
00213 template <typename value_type>
00214 class TMatrix<1,1, value_type> : public BasicIndexMatrix<1,1, value_type> {
00215     typedef BasicIndexMatrix<1, 1, value_type>   BaseType ;
00216     template <uint16_t R, uint16_t C, typename vt> friend class BasicMatrix;
00217     TMatrix(TMatrixDummy dummy) : BaseType (dummy) {}
00218 
00219 public:
00220     TMatrix() { }
00221     TMatrix(const value_type* data) : BaseType (data) {}
00222 
00223     // explicit conversion from value_type
00224     explicit TMatrix(value_type a0) : BaseType (TMatrixDummy()) { // don't initialize
00225         BaseType::mData[0] = a0;
00226     }
00227 
00228     // implicit conversion to value_type
00229     operator value_type() const {
00230         return BaseType::mData[0];
00231     }
00232 
00233     const TMatrix<1,1, value_type>&
00234     operator=(const TMatrix<1,1, value_type>& m) {
00235         BaseType::operator=(m);
00236         return *this;
00237     }
00238 
00239     double operator=(double a0) {
00240         BaseType::mData[0] = a0;
00241         return BaseType::mData[0];
00242     }
00243 };
00244 
00245 
00246 /**
00247  * @brief Base class implementing standard matrix functionality.
00248  */
00249 template <uint16_t Rows, uint16_t Cols, typename value_type = double>
00250 class BasicMatrix {
00251 protected:
00252     value_type mData[Rows*Cols];
00253 
00254     // Constructs uninitialized matrix.
00255     BasicMatrix(TMatrixDummy dummy) {}
00256 public:
00257     BasicMatrix() { // constructs zero matrix
00258         for (uint16_t i = 0; i < Rows*Cols; ++i) {
00259             mData[i] = 0;
00260         }
00261     }
00262     BasicMatrix(const value_type* data) { // constructs from array
00263         MBED_ASSERT(data);
00264 
00265         for (uint16_t i = 0; i < Rows*Cols; ++i) {
00266             mData[i] = data[i];
00267         }
00268     }
00269 
00270     uint32_t rows() const { return Rows; }
00271     uint32_t columns() const { return Cols; }
00272     uint32_t elementCount() const { return Cols*Rows; }
00273 
00274     value_type element(uint16_t row, uint16_t col) const {
00275         ERROR_CHECK(if (row >= rows() || col >= columns()) {
00276             std::cerr << "Illegal read access: " << row << ", " << col
00277                       << " in " << Rows << "x" << Cols << " matrix." << std::endl;
00278             return mData[0];
00279         })
00280         return mData[row*Cols+col];
00281     }
00282 
00283     value_type& element(uint16_t row, uint16_t col) {
00284         ERROR_CHECK(if (row >= rows() || col >= columns()) {
00285             std::cerr << "Illegal read access: " << row << ", " << col
00286                       << " in " << Rows << "x" << Cols << " matrix." << std::endl;
00287             return mData[0];
00288         })
00289         return mData[row*Cols+col];
00290     }
00291 
00292     void element(uint16_t row, uint16_t col, value_type value) {
00293         ERROR_CHECK(if (row >= rows() || col >= columns()) {
00294             std::cerr << "Illegal write access: " << row << ", " << col
00295                       << " in " << Rows << "x" << Cols << " matrix." << std::endl;
00296             return ;
00297         })
00298         mData[row*Cols+col] = value;
00299     }
00300 
00301     TMatrix<Rows*Cols,1, value_type> vec() const {
00302         return TMatrix<Rows*Cols,1, value_type>(mData);
00303     }
00304 
00305     void vec(const TMatrix<Rows*Cols, 1, value_type>& vector) {
00306         for (uint32_t i = 0; i < Rows*Cols; i++) {
00307             mData[i] = vector.mData[i];
00308         }
00309     }
00310 
00311     template <uint16_t R, uint16_t C, uint16_t RowRangeSize, uint16_t ColRangeSize>
00312     TMatrix<RowRangeSize, ColRangeSize, value_type> subMatrix(void) const {
00313         STATIC_ASSERT((R+RowRangeSize <= Rows) &&
00314                       (C+ColRangeSize <= Cols));
00315         TMatrix<RowRangeSize, ColRangeSize, value_type> result;
00316         for (uint32_t i = 0; i < RowRangeSize; i++) {
00317             for (uint32_t j = 0; j < ColRangeSize; j++) {
00318                 result.element(i,j, element(i+R, j+C));
00319             }
00320         }
00321         return result;
00322     }
00323 
00324 
00325     template <uint16_t R, uint16_t C, uint16_t RowRangeSize, uint16_t ColRangeSize>
00326     void subMatrix(const TMatrix<RowRangeSize, ColRangeSize, value_type>& m) {
00327         STATIC_ASSERT((R+RowRangeSize <= Rows) &&
00328                       (C+ColRangeSize <= Cols));
00329         for (uint32_t i = 0; i < RowRangeSize; i++) {
00330             for (uint32_t j = 0; j < ColRangeSize; j++) {
00331                 element(i+R,j+C, m.element(i, j));
00332             }
00333         }
00334     }
00335 
00336     /**
00337      * @brief Matrix multiplication operator.
00338      * @return A matrix where result is the matrix product
00339      * (*this) * rhs.
00340      */
00341     template <uint16_t RhsCols>
00342     TMatrix<Rows, RhsCols, value_type> operator*(const TMatrix<Cols, RhsCols, value_type>& rhs) const {
00343 
00344         TMatrix<Rows, RhsCols, value_type> result;
00345         const value_type* rPtr = rhs.row(0);
00346         for (uint32_t i = 0; i < Rows; i++)
00347         {
00348             const value_type* rL = row(i);
00349             const value_type* cR = rPtr;
00350             value_type* resultRow = result.row(i);
00351             for (uint32_t j = 0; j < RhsCols; j++)
00352             {
00353                 const value_type* rR = cR; // start at first element of right col
00354                 const value_type* cL = rL; // start at first element of left row
00355                 double r = 0;
00356                 for (uint32_t k = 0; k < Cols; k++)
00357                 {
00358                     r += (*cL)*(*rR);
00359                     cL++; // step to next col of left matrix
00360                     rR += Cols; // step to next row of right matrix
00361                 }
00362                 resultRow[j] = r;
00363                 cR++; // step to next column of right matrix
00364             }
00365         }
00366         return result;
00367     }
00368 
00369     /**
00370      * @brief Element-wise addition operator.
00371      * @return A matrix where result(i,j) = (*this)(i,j) + rhs(i,j).
00372      */
00373     TMatrix<Rows, Cols, value_type> operator+(const TMatrix<Rows, Cols, value_type>& rhs) const {
00374         TMatrixDummy dummy;
00375         TMatrix<Rows, Cols, value_type> result(dummy);
00376         for (uint32_t i = 0;  i < Rows*Cols;  i++) {
00377             result.mData[i] = mData[i] + rhs.mData[i];
00378         }
00379         return result;
00380     }
00381 
00382     /**
00383      * @brief Element-wise subtraction operator.
00384      * @return A matrix where result(i,j) = (*this)(i,j) - rhs(i,j).
00385      */
00386     TMatrix<Rows, Cols, value_type> operator-(const TMatrix<Rows, Cols, value_type>& rhs) const {
00387         TMatrixDummy dummy;
00388         TMatrix<Rows, Cols, value_type> result(dummy);
00389         for (uint32_t i = 0;  i < Rows*Cols;  i++) {
00390             result.mData[i] = mData[i] - rhs.mData[i];
00391         }
00392         return result;
00393     }
00394 
00395     /**
00396      * @brief Scalar multiplication operator.
00397      * @return A matrix where result(i,j) = (*this)(i,j) * s.
00398      */
00399     TMatrix<Rows, Cols, value_type> operator*(value_type s) const {
00400         TMatrixDummy dummy;
00401         TMatrix<Rows, Cols, value_type> result(dummy);
00402         for (uint32_t i = 0;  i < Rows*Cols;  i++) {
00403             result.mData[i] = mData[i] * s;
00404         }
00405         return result;
00406     }
00407 
00408     /**
00409      * @brief Scalar division operator.
00410      * @return A matrix where result(i,j) = (*this)(i,j) / s.
00411      */
00412     TMatrix<Rows, Cols, value_type> operator/(value_type s) const {
00413         TMatrixDummy dummy;
00414         TMatrix<Rows, Cols, value_type> result(dummy);
00415         for (uint32_t i = 0;  i < Rows*Cols;  i++) {
00416             result.mData[i] = mData[i] / s;
00417         }
00418         return result;
00419     }
00420 
00421     /**
00422      * @brief Unary negation operator.
00423      * @return A matrix where result(i,j) = -(*this)(i,j).
00424      */
00425     TMatrix<Rows, Cols, value_type> operator-(void) const {
00426         TMatrixDummy dummy;
00427         TMatrix<Rows, Cols, value_type> result(dummy);
00428         for (uint32_t i = 0;  i < Rows*Cols;  i++) {
00429             result.mData[i] = -mData[i];
00430         }
00431         return result;
00432     }
00433 
00434     /**
00435      * @brief Returns the matrix transpose of this matrix.
00436      *
00437      * @return A TMatrix of dimension @p Cols by @p Rows where
00438      * result(i,j) = (*this)(j,i) for each element.
00439      */
00440     TMatrix<Cols, Rows, value_type> transpose(void) const {
00441         TMatrixDummy dummy;
00442         TMatrix<Cols, Rows, value_type> result(dummy);
00443         for (uint16_t i = 0;  i < Rows;  i++) {
00444             for (uint16_t j = 0;  j < Cols;  j++) {
00445                 result.element(j,i, element(i,j));
00446             }
00447         }
00448         return result;
00449     }
00450 
00451     /**
00452      * @brief Returns the diagonal elements of the matrix.
00453      *
00454      * @return A column vector @p v with dimension MIN(Rows,Cols) where
00455      * @p v[i] = (*this)[i][i].
00456      */
00457     TMatrix<(Rows>Cols)?Cols:Rows, 1, value_type> diag() const {
00458         TMatrixDummy dummy;
00459         TMatrix<(Rows>Cols)?Cols:Rows, 1> d(dummy);
00460         for (uint32_t i = 0; i < d.rows(); i++) {
00461             d[i] = mData[i*(Cols + 1)];
00462         }
00463         return d;
00464     }
00465 
00466     /** @brief Returns the sum of the matrix entries. */
00467     value_type sum(void) const {
00468         value_type s = 0;
00469         for (uint32_t i = 0; i < Rows*Cols; i++) { s += mData[i]; }
00470         return s;
00471     }
00472     /** @brief Returns the sum of the log of the matrix entries.
00473      */
00474     value_type sumLog(void) const {
00475         value_type s = 0;
00476         for (uint32_t i = 0; i < Rows*Cols; i++) { s += log(mData[i]); }
00477         return s;
00478     }
00479 
00480     /** @brief Returns this vector with its elements replaced by their reciprocals. */
00481     TMatrix<Rows,Cols, value_type> recip(void) const {
00482         TMatrixDummy dummy;
00483         TMatrix<Rows,Cols, value_type> result(dummy);
00484         for (uint32_t i = 0; i < Rows*Cols; i++) {
00485             result.mData[i] = 1.0/mData[i];
00486         }
00487         return result;
00488     }
00489 
00490     /** @brief Returns this vector with its elements replaced by their reciprocals,
00491      * unless a value is less than epsilon, in which case it is left as zero.
00492      *
00493      * This is used mostly for pseudo-inverse computations.
00494      */
00495     TMatrix<Rows,Cols, value_type> pseudoRecip(double epsilon = 1e-50) const {
00496         TMatrixDummy dummy;
00497         TMatrix<Rows,Cols, value_type> result(dummy);
00498         for (uint32_t i = 0; i < Rows*Cols; i++) {
00499             if (fabs(mData[i]) >= epsilon) {
00500                 result.mData[i] = 1.0/mData[i];
00501             } else {
00502                 result.mData[i] = 0;
00503             }
00504         }
00505         return result;
00506     }
00507 
00508     /**
00509      * @brief Returns an "identity" matrix with dimensions given by the
00510      * class's template parameters.
00511      *
00512      * In the case that @p Rows != @p Cols, this matrix is simply the
00513      * one where the Aii elements for i < min(Rows, Cols) are 1, and all
00514      * other elements are 0.
00515      *
00516      * @return A TMatrix<Rows, Cols, value_type> with off-diagonal
00517      * elements set to 0, and diagonal elements set to 1.
00518      */
00519     static TMatrix<Rows, Cols, value_type> identity() {
00520         TMatrix<Rows, Cols, value_type> id;
00521         for (uint16_t i = 0; i < Rows && i < Cols; i++) {
00522             id.element(i,i) = 1;
00523         }
00524         return id;
00525     }
00526 
00527     /**
00528      * @brief Returns a ones matrix with dimensions given by the
00529      * class's template parameters.
00530      *
00531      * @return A TMatrix<Rows, Cols, value_type> with all elements set
00532      * to 1.
00533      */
00534     static TMatrix<Rows, Cols, value_type> one() {
00535         TMatrix<Rows, Cols, value_type> ones;
00536         for (uint16_t i = 0; i < Rows; i++) {
00537             for (uint16_t j = 0; j < Cols; j++) {
00538                 ones.element(i,j, 1);
00539             }
00540         }
00541         return ones;
00542     }
00543 
00544     /**
00545      * @brief Returns a zero matrix with dimensions given by the
00546      * class's template parameters.
00547      *
00548      * @return A TMatrix<Rows, Cols, value_type> containing all 0.
00549      */
00550     static TMatrix<Rows, Cols, value_type> zero() {
00551         return TMatrix<Rows, Cols, value_type>();
00552     }
00553 
00554     value_type* row(uint32_t i) { return &mData[i*Cols]; }
00555     const value_type* row(uint32_t i) const { return &mData[i*Cols]; }
00556 
00557     /**
00558      * @brief Checks to see if any of this matrix's elements are NaN.
00559      */
00560     bool hasNaN(void) const {
00561         for (uint32_t i = 0; i < Rows*Cols; i++) {
00562             if (isnan(mData[i])) {
00563                 return true;
00564             }
00565         }
00566         return false;
00567     }
00568 
00569     void print(Stream & os, bool oneLine = false) const {
00570         for (uint16_t i = 0; i < Rows; i++) {
00571             for (uint16_t j = 0; j < Cols; j++) {
00572                 os.printf("%.06f ", element(i, j));
00573             }
00574 
00575             if(!oneLine)
00576             {
00577                 os.printf("\n");
00578             }
00579         }
00580     }
00581 
00582 private:
00583 
00584     template <uint16_t Rows2, uint16_t Cols2, typename value_type2>
00585     friend TMatrix<Rows2,Cols2,value_type2> operator*(double s, const TMatrix<Rows2, Cols2, value_type2>& m);
00586 };
00587 
00588 typedef TMatrix<2,2, float>  TMatrix2;
00589 typedef TMatrix<3,3, float>  TMatrix3;
00590 typedef TMatrix<4,4, float>  TMatrix4;
00591 typedef TMatrix<2,1, float>  TVector2;
00592 typedef TMatrix<3,1, float>   TVector3 ;
00593 typedef TMatrix<4,1, float>  TVector4;
00594 
00595 // left-side scalar multiply
00596 template <uint16_t Rows, uint16_t Cols, typename value_type>
00597 TMatrix<Rows,Cols,value_type> operator*(double s, const TMatrix<Rows, Cols, value_type>& m) {
00598     return m * s;
00599 }
00600 
00601 
00602 #endif /* TMATRIX_H */