start the wrapper burrito
tmatrix.h
- Committer:
- sepham
- Date:
- 2019-07-29
- Revision:
- 8:217f510db255
- Parent:
- 1:aac28ffd63ed
File content as of revision 8:217f510db255:
#ifndef TMATRIX_H #define TMATRIX_H /** * @file tmatrix.h * * @brief A dimension-templatized class for matrices of values. */ #include <cmath> #include <mbed.h> // Structures for static assert. http://www.boost.org template <bool x> struct STATIC_ASSERTION_FAILURE; template <> struct STATIC_ASSERTION_FAILURE<true> { enum { value = 1 }; }; template<int x> struct static_assert_test{}; #define STATIC_ASSERT( B ) \ typedef __attribute__((unused)) static_assert_test<sizeof(STATIC_ASSERTION_FAILURE<(bool)(B)>) > \ static_assert_typedef##__LINE__ #if DEBUG #define ERROR_CHECK(X) (X) #else #define ERROR_CHECK(X) #endif // Forward Decl. template <uint16_t, uint16_t, typename> class BasicMatrix; template <uint16_t, uint16_t, typename> class TMatrix; class TMatrixDummy { }; /** * @brief Class that layers on operator[] functionality for * typical matrices. */ template <uint16_t Rows, uint16_t Cols, typename value_type> class BasicIndexMatrix : public BasicMatrix<Rows,Cols,value_type> { typedef BasicMatrix<Rows,Cols,value_type> BaseType; protected: BasicIndexMatrix(TMatrixDummy d) : BaseType(d) { } public: BasicIndexMatrix() { } BasicIndexMatrix(const value_type* data) : BaseType(data) {} const value_type* operator[](uint16_t r) const { ERROR_CHECK(if (r >= Rows) { std::clog << "Invalid row index " << r << std::endl; return &BaseType::mData[0]; } ) return &BaseType::mData[r*Cols]; } value_type* operator[](uint16_t r) { ERROR_CHECK(if (r >= Rows) { std::clog << "Invalid row index " << r << std::endl; return &BaseType::mData[0]; } ) return &BaseType::mData[r*Cols]; } }; /** * @brief Specialization of BasicIndexMatrix that provides * single-indexing operator for column vectors. */ template <uint16_t Rows, typename value_type> class BasicIndexMatrix<Rows,1,value_type> : public BasicMatrix<Rows,1,value_type> { typedef BasicMatrix<Rows,1,value_type> BaseType; protected: BasicIndexMatrix(TMatrixDummy dummy) : BaseType(dummy) {} public: BasicIndexMatrix() { } BasicIndexMatrix(const value_type* data) : BaseType(data) {} value_type operator[](uint16_t r) const { ERROR_CHECK(if (r >= Rows) { std::clog << "Invalid vector index " << r << std::endl; return BaseType::mData[0]; }) return BaseType::mData[r]; } value_type& operator[](uint16_t r) { ERROR_CHECK(if (r >= Rows) { std::clog << "Invalid vector index " << r << std::endl; return BaseType::mData[0]; }) return BaseType::mData[r]; } value_type norm() const { return sqrt(norm2()); } value_type norm2() const { double normSum = 0; for (uint32_t i = 0; i < Rows; i++) { normSum += BaseType::mData[i]*BaseType::mData[i]; } return normSum; } /** @brief Returns matrix with vector elements on diagonal. */ TMatrix<Rows, Rows, value_type> diag(void) const { TMatrix<Rows, Rows, value_type> d; for (uint32_t i = 0; i < Rows; i++) d.element(i,i, BaseType::mData[i]); return d; } }; /** * @brief A dimension-templatized class for matrices of values. * * This template class generically supports any constant-sized * matrix of values. The @p Rows and @p Cols template parameters * define the size of the matrix at @e compile-time. Hence, the * size of the matrix cannot be chosen at runtime. However, the * dimensions are appropriately type-checked at compile-time where * possible. * * By default, the matrix contains values of type @p double. The @p * value_type template parameter may be selected to allow matrices * of integers or floats. * * @note At present, type cohersion between matrices with different * @p value_type parameters is not implemented. It is recommended * that matrices with value type @p double be used for all numerical * computation. * * Note that the special cases of row and column vectors are * subsumed by this class. */ template <uint16_t Rows, uint16_t Cols, typename value_type = double> class TMatrix : public BasicIndexMatrix<Rows, Cols, value_type> { typedef BasicIndexMatrix<Rows, Cols, value_type> BaseType; template <uint16_t R, uint16_t C, typename vt> friend class BasicMatrix; TMatrix(TMatrixDummy d) : BaseType(d) {} public: TMatrix() : BaseType() { } TMatrix(const value_type* data) : BaseType(data) {} }; /** * @brief Template specialization of TMatrix for Vector4. */ template <typename value_type> class TMatrix<4,1, value_type> : public BasicIndexMatrix<4,1, value_type> { typedef BasicIndexMatrix<4, 1, value_type> BaseType; template <uint16_t R, uint16_t C, typename vt> friend class BasicMatrix; TMatrix(TMatrixDummy d) : BaseType(d) {} public: TMatrix() { } TMatrix(const value_type* data) : BaseType(data) {} TMatrix(value_type a0, value_type a1, value_type a2, value_type a3) : BaseType(TMatrixDummy()) { BaseType::mData[0] = a0; BaseType::mData[1] = a1; BaseType::mData[2] = a2; BaseType::mData[3] = a3; } }; /** * @brief Template specialization of TMatrix for Vector3. */ template <typename value_type> class TMatrix<3,1, value_type> : public BasicIndexMatrix<3,1, value_type> { typedef BasicIndexMatrix<3, 1, value_type> BaseType; template <uint16_t R, uint16_t C, typename vt> friend class BasicMatrix; TMatrix(TMatrixDummy d) : BaseType(d) {} public: TMatrix() { } TMatrix(const value_type* data) : BaseType(data) {} TMatrix(value_type a0, value_type a1, value_type a2) : BaseType(TMatrixDummy()) { BaseType::mData[0] = a0; BaseType::mData[1] = a1; BaseType::mData[2] = a2; } TMatrix<3,1,value_type> cross(const TMatrix<3,1, value_type>& v) const { const TMatrix<3,1,value_type>& u = *this; return TMatrix<3,1,value_type>(u[1]*v[2]-u[2]*v[1], u[2]*v[0]-u[0]*v[2], u[0]*v[1]-u[1]*v[0]); } }; /** * @brief Template specialization of TMatrix for Vector2. */ template <typename value_type> class TMatrix<2,1, value_type> : public BasicIndexMatrix<2,1, value_type> { typedef BasicIndexMatrix<2, 1, value_type> BaseType; template <uint16_t R, uint16_t C, typename vt> friend class BasicMatrix; TMatrix(TMatrixDummy d) : BaseType(d) {} public: TMatrix() { } TMatrix(const value_type* data) : BaseType(data) {} TMatrix(value_type a0, value_type a1) : BaseType(TMatrixDummy()) { BaseType::mData[0] = a0; BaseType::mData[1] = a1; } }; /** * @brief Template specialization of TMatrix for a 1x1 vector. */ template <typename value_type> class TMatrix<1,1, value_type> : public BasicIndexMatrix<1,1, value_type> { typedef BasicIndexMatrix<1, 1, value_type> BaseType; template <uint16_t R, uint16_t C, typename vt> friend class BasicMatrix; TMatrix(TMatrixDummy dummy) : BaseType(dummy) {} public: TMatrix() { } TMatrix(const value_type* data) : BaseType(data) {} // explicit conversion from value_type explicit TMatrix(value_type a0) : BaseType(TMatrixDummy()) { // don't initialize BaseType::mData[0] = a0; } // implicit conversion to value_type operator value_type() const { return BaseType::mData[0]; } const TMatrix<1,1, value_type>& operator=(const TMatrix<1,1, value_type>& m) { BaseType::operator=(m); return *this; } double operator=(double a0) { BaseType::mData[0] = a0; return BaseType::mData[0]; } }; /** * @brief Base class implementing standard matrix functionality. */ template <uint16_t Rows, uint16_t Cols, typename value_type = double> class BasicMatrix { protected: value_type mData[Rows*Cols]; // Constructs uninitialized matrix. BasicMatrix(TMatrixDummy dummy) {} public: BasicMatrix() { // constructs zero matrix for (uint16_t i = 0; i < Rows*Cols; ++i) { mData[i] = 0; } } BasicMatrix(const value_type* data) { // constructs from array MBED_ASSERT(data); for (uint16_t i = 0; i < Rows*Cols; ++i) { mData[i] = data[i]; } } uint32_t rows() const { return Rows; } uint32_t columns() const { return Cols; } uint32_t elementCount() const { return Cols*Rows; } value_type element(uint16_t row, uint16_t col) const { ERROR_CHECK(if (row >= rows() || col >= columns()) { std::cerr << "Illegal read access: " << row << ", " << col << " in " << Rows << "x" << Cols << " matrix." << std::endl; return mData[0]; }) return mData[row*Cols+col]; } value_type& element(uint16_t row, uint16_t col) { ERROR_CHECK(if (row >= rows() || col >= columns()) { std::cerr << "Illegal read access: " << row << ", " << col << " in " << Rows << "x" << Cols << " matrix." << std::endl; return mData[0]; }) return mData[row*Cols+col]; } void element(uint16_t row, uint16_t col, value_type value) { ERROR_CHECK(if (row >= rows() || col >= columns()) { std::cerr << "Illegal write access: " << row << ", " << col << " in " << Rows << "x" << Cols << " matrix." << std::endl; return ; }) mData[row*Cols+col] = value; } TMatrix<Rows*Cols,1, value_type> vec() const { return TMatrix<Rows*Cols,1, value_type>(mData); } void vec(const TMatrix<Rows*Cols, 1, value_type>& vector) { for (uint32_t i = 0; i < Rows*Cols; i++) { mData[i] = vector.mData[i]; } } template <uint16_t R, uint16_t C, uint16_t RowRangeSize, uint16_t ColRangeSize> TMatrix<RowRangeSize, ColRangeSize, value_type> subMatrix(void) const { STATIC_ASSERT((R+RowRangeSize <= Rows) && (C+ColRangeSize <= Cols)); TMatrix<RowRangeSize, ColRangeSize, value_type> result; for (uint32_t i = 0; i < RowRangeSize; i++) { for (uint32_t j = 0; j < ColRangeSize; j++) { result.element(i,j, element(i+R, j+C)); } } return result; } template <uint16_t R, uint16_t C, uint16_t RowRangeSize, uint16_t ColRangeSize> void subMatrix(const TMatrix<RowRangeSize, ColRangeSize, value_type>& m) { STATIC_ASSERT((R+RowRangeSize <= Rows) && (C+ColRangeSize <= Cols)); for (uint32_t i = 0; i < RowRangeSize; i++) { for (uint32_t j = 0; j < ColRangeSize; j++) { element(i+R,j+C, m.element(i, j)); } } } /** * @brief Matrix multiplication operator. * @return A matrix where result is the matrix product * (*this) * rhs. */ template <uint16_t RhsCols> TMatrix<Rows, RhsCols, value_type> operator*(const TMatrix<Cols, RhsCols, value_type>& rhs) const { TMatrix<Rows, RhsCols, value_type> result; const value_type* rPtr = rhs.row(0); for (uint32_t i = 0; i < Rows; i++) { const value_type* rL = row(i); const value_type* cR = rPtr; value_type* resultRow = result.row(i); for (uint32_t j = 0; j < RhsCols; j++) { const value_type* rR = cR; // start at first element of right col const value_type* cL = rL; // start at first element of left row double r = 0; for (uint32_t k = 0; k < Cols; k++) { r += (*cL)*(*rR); cL++; // step to next col of left matrix rR += Cols; // step to next row of right matrix } resultRow[j] = r; cR++; // step to next column of right matrix } } return result; } /** * @brief Element-wise addition operator. * @return A matrix where result(i,j) = (*this)(i,j) + rhs(i,j). */ TMatrix<Rows, Cols, value_type> operator+(const TMatrix<Rows, Cols, value_type>& rhs) const { TMatrixDummy dummy; TMatrix<Rows, Cols, value_type> result(dummy); for (uint32_t i = 0; i < Rows*Cols; i++) { result.mData[i] = mData[i] + rhs.mData[i]; } return result; } /** * @brief Element-wise subtraction operator. * @return A matrix where result(i,j) = (*this)(i,j) - rhs(i,j). */ TMatrix<Rows, Cols, value_type> operator-(const TMatrix<Rows, Cols, value_type>& rhs) const { TMatrixDummy dummy; TMatrix<Rows, Cols, value_type> result(dummy); for (uint32_t i = 0; i < Rows*Cols; i++) { result.mData[i] = mData[i] - rhs.mData[i]; } return result; } /** * @brief Scalar multiplication operator. * @return A matrix where result(i,j) = (*this)(i,j) * s. */ TMatrix<Rows, Cols, value_type> operator*(value_type s) const { TMatrixDummy dummy; TMatrix<Rows, Cols, value_type> result(dummy); for (uint32_t i = 0; i < Rows*Cols; i++) { result.mData[i] = mData[i] * s; } return result; } /** * @brief Scalar division operator. * @return A matrix where result(i,j) = (*this)(i,j) / s. */ TMatrix<Rows, Cols, value_type> operator/(value_type s) const { TMatrixDummy dummy; TMatrix<Rows, Cols, value_type> result(dummy); for (uint32_t i = 0; i < Rows*Cols; i++) { result.mData[i] = mData[i] / s; } return result; } /** * @brief Unary negation operator. * @return A matrix where result(i,j) = -(*this)(i,j). */ TMatrix<Rows, Cols, value_type> operator-(void) const { TMatrixDummy dummy; TMatrix<Rows, Cols, value_type> result(dummy); for (uint32_t i = 0; i < Rows*Cols; i++) { result.mData[i] = -mData[i]; } return result; } /** * @brief Returns the matrix transpose of this matrix. * * @return A TMatrix of dimension @p Cols by @p Rows where * result(i,j) = (*this)(j,i) for each element. */ TMatrix<Cols, Rows, value_type> transpose(void) const { TMatrixDummy dummy; TMatrix<Cols, Rows, value_type> result(dummy); for (uint16_t i = 0; i < Rows; i++) { for (uint16_t j = 0; j < Cols; j++) { result.element(j,i, element(i,j)); } } return result; } /** * @brief Returns the diagonal elements of the matrix. * * @return A column vector @p v with dimension MIN(Rows,Cols) where * @p v[i] = (*this)[i][i]. */ TMatrix<(Rows>Cols)?Cols:Rows, 1, value_type> diag() const { TMatrixDummy dummy; TMatrix<(Rows>Cols)?Cols:Rows, 1> d(dummy); for (uint32_t i = 0; i < d.rows(); i++) { d[i] = mData[i*(Cols + 1)]; } return d; } /** @brief Returns the sum of the matrix entries. */ value_type sum(void) const { value_type s = 0; for (uint32_t i = 0; i < Rows*Cols; i++) { s += mData[i]; } return s; } /** @brief Returns the sum of the log of the matrix entries. */ value_type sumLog(void) const { value_type s = 0; for (uint32_t i = 0; i < Rows*Cols; i++) { s += log(mData[i]); } return s; } /** @brief Returns this vector with its elements replaced by their reciprocals. */ TMatrix<Rows,Cols, value_type> recip(void) const { TMatrixDummy dummy; TMatrix<Rows,Cols, value_type> result(dummy); for (uint32_t i = 0; i < Rows*Cols; i++) { result.mData[i] = 1.0/mData[i]; } return result; } /** @brief Returns this vector with its elements replaced by their reciprocals, * unless a value is less than epsilon, in which case it is left as zero. * * This is used mostly for pseudo-inverse computations. */ TMatrix<Rows,Cols, value_type> pseudoRecip(double epsilon = 1e-50) const { TMatrixDummy dummy; TMatrix<Rows,Cols, value_type> result(dummy); for (uint32_t i = 0; i < Rows*Cols; i++) { if (fabs(mData[i]) >= epsilon) { result.mData[i] = 1.0/mData[i]; } else { result.mData[i] = 0; } } return result; } /** * @brief Returns an "identity" matrix with dimensions given by the * class's template parameters. * * In the case that @p Rows != @p Cols, this matrix is simply the * one where the Aii elements for i < min(Rows, Cols) are 1, and all * other elements are 0. * * @return A TMatrix<Rows, Cols, value_type> with off-diagonal * elements set to 0, and diagonal elements set to 1. */ static TMatrix<Rows, Cols, value_type> identity() { TMatrix<Rows, Cols, value_type> id; for (uint16_t i = 0; i < Rows && i < Cols; i++) { id.element(i,i) = 1; } return id; } /** * @brief Returns a ones matrix with dimensions given by the * class's template parameters. * * @return A TMatrix<Rows, Cols, value_type> with all elements set * to 1. */ static TMatrix<Rows, Cols, value_type> one() { TMatrix<Rows, Cols, value_type> ones; for (uint16_t i = 0; i < Rows; i++) { for (uint16_t j = 0; j < Cols; j++) { ones.element(i,j, 1); } } return ones; } /** * @brief Returns a zero matrix with dimensions given by the * class's template parameters. * * @return A TMatrix<Rows, Cols, value_type> containing all 0. */ static TMatrix<Rows, Cols, value_type> zero() { return TMatrix<Rows, Cols, value_type>(); } value_type* row(uint32_t i) { return &mData[i*Cols]; } const value_type* row(uint32_t i) const { return &mData[i*Cols]; } /** * @brief Checks to see if any of this matrix's elements are NaN. */ bool hasNaN(void) const { for (uint32_t i = 0; i < Rows*Cols; i++) { if (isnan(mData[i])) { return true; } } return false; } void print(Stream & os, bool oneLine = false) const { for (uint16_t i = 0; i < Rows; i++) { for (uint16_t j = 0; j < Cols; j++) { os.printf("%.06f ", element(i, j)); } if(!oneLine) { os.printf("\n"); } } } private: template <uint16_t Rows2, uint16_t Cols2, typename value_type2> friend TMatrix<Rows2,Cols2,value_type2> operator*(double s, const TMatrix<Rows2, Cols2, value_type2>& m); }; typedef TMatrix<2,2, float> TMatrix2; typedef TMatrix<3,3, float> TMatrix3; typedef TMatrix<4,4, float> TMatrix4; typedef TMatrix<2,1, float> TVector2; typedef TMatrix<3,1, float> TVector3; typedef TMatrix<4,1, float> TVector4; // left-side scalar multiply template <uint16_t Rows, uint16_t Cols, typename value_type> TMatrix<Rows,Cols,value_type> operator*(double s, const TMatrix<Rows, Cols, value_type>& m) { return m * s; } #endif /* TMATRIX_H */