#include "HAPS_EKF.hpp"
#include "Matrix.h"
#include "MatrixMath.h"
#include <cmath>
#include "Vector3.hpp"


using namespace std;

/**
 * Constructs the filter at the identity attitude.
 *
 * Both attitude quaternions start at the unit quaternion [1 0 0 0] (scalar
 * first): qhat (the EKF estimate) and qhat_gyro (the gyro-only integration).
 * The DCM D starts at the matching 3x3 identity.
 * Matrix::add(row, col, value) is used with 1-based indices to store a single
 * element (presumably overwriting it — confirm in Matrix.h).
 */
HAPS_EKF::HAPS_EKF()
    :qhat(4,1), qhat_gyro(4,1), Phat(4,4), Qgyro(3,3), Racc(3,3), Rmag(3,3), D(3,3)
{ 
    // Identity rotation.
    qhat      << 1 << 0 << 0 << 0;
    // qhat_gyro must also start as a unit quaternion. Otherwise it keeps the
    // default contents of Matrix(4,1) (presumably all zeros). The
    // renormalisation in updateBetweenMeasures() would then divide by a zero
    // norm, turning qhat_gyro / rpy_g into NaN whenever propagation runs
    // before triad() has been called.
    qhat_gyro << 1 << 0 << 0 << 0;
    
    // Initial state covariance: variance 0.05 on each quaternion component.
    Phat.add(1,1,0.05);
    Phat.add(2,2,0.05);
    Phat.add(3,3,0.05);
    Phat.add(4,4,0.05);
    
    // DCM of the identity attitude (consistent with qhat above).
    D.add(1,1,1.0);
    D.add(2,2,1.0);
    D.add(3,3,1.0);
    
    // Gyro-rate noise covariance used in the prediction step (tuning value).
    Qgyro.add(1,1,0.0224);
    Qgyro.add(2,2,0.0224); 
    Qgyro.add(3,3,0.0224);  
    
    // Racc: base variance 0.0330 scaled by a tuning factor of 200.
    Racc.add(1,1,0.0330*200);
    Racc.add(2,2,0.0330*200);
    Racc.add(3,3,0.0330*200);
    
    // Rmag: unit variance per axis.
    Rmag.add(1,1,1.0);
    Rmag.add(2,2,1.0);
    Rmag.add(3,3,1.0); 
}

/**
 * Time update (prediction) driven by one gyro sample.
 *
 * Steps:
 *  1. Builds the first-order transition matrix F = I + (att_dt/2) * Omega(gyro).
 *  2. Propagates the covariance as Phat <- F Phat F^T + G Qgyro G^T.
 *  3. Propagates both qhat (EKF estimate) and qhat_gyro (pure gyro
 *     integration) through F and renormalises each to unit length.
 *  4. Rebuilds the DCM D from the propagated qhat.
 *
 * @param gyro    body angular rate (presumably rad/s — confirm with caller).
 * @param att_dt  integration step (presumably seconds).
 */
void HAPS_EKF::updateBetweenMeasures(Vector3 gyro, float att_dt)
{
    // Prior quaternion (scalar first); the noise-input matrix is built from it.
    const float p0 = qhat.getNumber( 1, 1 );
    const float p1 = qhat.getNumber( 2, 1 );
    const float p2 = qhat.getNumber( 3, 1 );
    const float p3 = qhat.getNumber( 4, 1 );

    // Maps gyro-rate noise onto the quaternion. Its overall sign is irrelevant
    // because it only enters through G*Qgyro*G^T.
    // NOTE(review): this term is not scaled by att_dt, so the injected noise
    // per step is independent of the step length; Qgyro is presumably tuned
    // for a fixed update rate.
    Matrix G(4,3);
    G <<  p1 <<  p2 <<  p3
      << -p0 <<  p3 << -p2
      << -p3 << -p0 <<  p1
      <<  p2 << -p1 << -p0;
    G *= 0.5f;

    // Half-angle increments. Kept in double so the values streamed into F are
    // bit-identical to evaluating `gyro.x*0.5*att_dt` in place.
    const double hx = gyro.x*0.5*att_dt;
    const double hy = gyro.y*0.5*att_dt;
    const double hz = gyro.z*0.5*att_dt;

    Matrix F(4,4);
    F << 1.0 << -hx << -hy << -hz
      <<  hx << 1.0 <<  hz << -hy
      <<  hy << -hz << 1.0 <<  hx
      <<  hz <<  hy << -hx << 1.0;

    // Covariance propagation (independent of the new quaternion values).
    Matrix propagated = F*Phat*MatrixMath::Transpose(F);
    Matrix injected   = G*Qgyro*MatrixMath::Transpose(G);
    Phat = propagated + injected;

    // State propagation, each followed by renormalisation to unit length.
    qhat = F*qhat;
    float norm = sqrt(MatrixMath::dot(MatrixMath::Transpose(qhat),qhat));
    qhat *= (1.0f/ norm);

    qhat_gyro = F*qhat_gyro;
    norm = sqrt(MatrixMath::dot(MatrixMath::Transpose(qhat_gyro),qhat_gyro));
    qhat_gyro *= (1.0f/ norm);

    // Rebuild the DCM D from the propagated (normalised) estimate.
    const float a = qhat.getNumber( 1, 1 );
    const float b = qhat.getNumber( 2, 1 );
    const float c = qhat.getNumber( 3, 1 );
    const float d = qhat.getNumber( 4, 1 );

    D.add(1,1,a*a + b*b - c*c - d*d);
    D.add(1,2,2*(b*c + a*d));
    D.add(1,3,2*(b*d - a*c));
    D.add(2,1,2*(b*c - a*d));
    D.add(2,2,a*a - b*b + c*c - d*d);
    D.add(2,3,2*(c*d + a*b));
    D.add(3,1,2*(b*d + a*c));
    D.add(3,2,2*(c*d - a*b));
    D.add(3,3,a*a - b*b - c*c + d*d);
}

/**
 * EKF measurement update for one reference-vector observation.
 *
 * Measurement model: v = C(q) * u + noise.
 *  - C(q) is the DCM built from the quaternion with the same formula used for
 *    the member D.
 *  - u is the reference vector.
 *  - v is the same vector as measured by the sensor.
 *
 * @param _v  measured vector (e.g. accelerometer or magnetometer sample).
 * @param _u  corresponding reference vector.
 * @param R   3x3 measurement-noise covariance for this observation (read only).
 *
 * Updates qhat (renormalised to unit length) and Phat (Joseph form).
 * The member D is not modified: it keeps the DCM from the most recent
 * updateBetweenMeasures()/triad() call.
 */
void HAPS_EKF::updateAcrossMeasures(Vector3 _v, Vector3 _u, Matrix& R)
{        
    Matrix u(3,1);
    Matrix v(3,1);
    
    u << _u.x << _u.y << _u.z;
    v << _v.x << _v.y << _v.z;
    
    float q0 = qhat.getNumber( 1, 1 );
    float q1 = qhat.getNumber( 2, 1 ); 
    float q2 = qhat.getNumber( 3, 1 ); 
    float q3 = qhat.getNumber( 4, 1 ); 
    
    // A_i = dC/dq_i, the partial derivatives of the DCM w.r.t. each component.
    Matrix A1(3,3);
    A1 << q0 << q3 << -q2
       <<-q3 << q0 << q1
       <<q2  <<-q1 <<q0;
    A1 *= 2.0f;
    
    Matrix A2(3,3);   
    A2 << q1 << q2 << q3
       << q2 <<-q1 << q0
       << q3 <<-q0 <<-q1;
    A2 *= 2.0f;
    
    Matrix A3(3,3);
    A3 <<-q2 << q1 <<-q0
       << q1 << q2 << q3
       << q0 << q3 <<-q2;
    A3 *= 2.0f;
    
    Matrix A4(3,3);
    A4 <<-q3 << q0 << q1
       <<-q0 <<-q3 << q2
       << q1 << q2 << q3;
    A4 *= 2.0f;
    
    // Measurement Jacobian H = d(C(q) u)/dq; column i is (dC/dq_i) * u.
    Matrix H(3,4);

    Matrix ab1(A1*u);
    Matrix ab2(A2*u);
    Matrix ab3(A3*u);
    Matrix ab4(A4*u);

    H << ab1.getNumber( 1, 1 ) << ab2.getNumber( 1, 1 ) << ab3.getNumber( 1, 1 ) << ab4.getNumber( 1, 1 )
      << ab1.getNumber( 2, 1 ) << ab2.getNumber( 2, 1 ) << ab3.getNumber( 2, 1 ) << ab4.getNumber( 2, 1 )
      << ab1.getNumber( 3, 1 ) << ab2.getNumber( 3, 1 ) << ab3.getNumber( 3, 1 ) << ab4.getNumber( 3, 1 );
    
    // Predicted DCM at the current estimate (same formula as the member D).
    // It is rebuilt here rather than reading D so that the predicted
    // measurement uses the same linearisation point as H. D is only refreshed
    // by updateBetweenMeasures()/triad(), so after an earlier measurement
    // update in the same cycle it would be stale. When D is fresh the two are
    // identical.
    Matrix Cq(3,3);
    Cq << q0*q0 + q1*q1 - q2*q2 - q3*q3 << 2*(q1*q2 + q0*q3)             << 2*(q1*q3 - q0*q2)
       << 2*(q1*q2 - q0*q3)             << q0*q0 - q1*q1 + q2*q2 - q3*q3 << 2*(q2*q3 + q0*q1)
       << 2*(q1*q3 + q0*q2)             << 2*(q2*q3 - q0*q1)             << q0*q0 - q1*q1 - q2*q2 + q3*q3;
    
    // Kalman gain.
    Matrix K(4,3);
    K = (Phat*MatrixMath::Transpose(H))*MatrixMath::Inv(H*Phat*MatrixMath::Transpose(H)+R);
    
    // State correction from the innovation, then renormalise to a unit quaternion.
    Matrix dq(4,1);
    dq = K*(v-Cq*u);
    qhat = qhat+dq;
    
    float qnorm = sqrt(MatrixMath::dot(MatrixMath::Transpose(qhat),qhat));
    qhat *= (1.0f/ qnorm);
    
    // Joseph-form covariance update: (I-KH) P (I-KH)^T + K R K^T.
    Matrix eye4(4,4);
    eye4 << 1 << 0 << 0 << 0
         << 0 << 1 << 0 << 0
         << 0 << 0 << 1 << 0
         << 0 << 0 << 0 << 1;
    Phat = (eye4-K*H)*Phat*MatrixMath::Transpose(eye4-K*H)+K*R*MatrixMath::Transpose(K);
}

// Clamps x into [-1, 1]. Keeps asinf() inside its domain: rounding in the
// quaternion products can push the pitch term marginally past +/-1 near
// +/-90 deg pitch, which would otherwise produce NaN. A NaN input is
// returned unchanged.
static float clampToUnitRange(float x)
{
    if (x > 1.0f)
        return 1.0f;
    if (x < -1.0f)
        return -1.0f;
    return x;
}

/**
 * Converts the filter quaternion and the gyro-only quaternion to Euler angles.
 *
 * @param rpy       out: roll (x), pitch (y), yaw (z) from qhat, in radians.
 * @param rpy_g     out: the same angles computed from qhat_gyro.
 * @param rpy_align offsets subtracted from roll and pitch only (yaw is not
 *                  offset); presumably in radians as well.
 */
void HAPS_EKF::computeAngles(Vector3& rpy, Vector3& rpy_g, Vector3 rpy_align)
{
    float q0 = qhat.getNumber( 1, 1 );
    float q1 = qhat.getNumber( 2, 1 ); 
    float q2 = qhat.getNumber( 3, 1 ); 
    float q3 = qhat.getNumber( 4, 1 ); 
    rpy.x = atan2f(q0*q1 + q2*q3, 0.5f - q1*q1 - q2*q2)-rpy_align.x;
    rpy.y = asinf(clampToUnitRange(-2.0f * (q1*q3 - q0*q2)))-rpy_align.y;
    rpy.z = atan2f(q1*q2 + q0*q3, 0.5f - q2*q2 - q3*q3);
    
    q0 = qhat_gyro.getNumber( 1, 1 );
    q1 = qhat_gyro.getNumber( 2, 1 ); 
    q2 = qhat_gyro.getNumber( 3, 1 ); 
    q3 = qhat_gyro.getNumber( 4, 1 ); 
    rpy_g.x = atan2f(q0*q1 + q2*q3, 0.5f - q1*q1 - q2*q2)-rpy_align.x;
    rpy_g.y = asinf(clampToUnitRange(-2.0f * (q1*q3 - q0*q2)))-rpy_align.y;
    rpy_g.z = atan2f(q1*q2 + q0*q3, 0.5f - q2*q2 - q3*q3);
}

// Writes the cross product a x b into out. out must not alias a or b.
static void triadCross3(const float a[3], const float b[3], float out[3])
{
    out[0] = a[1]*b[2] - a[2]*b[1];
    out[1] = a[2]*b[0] - a[0]*b[2];
    out[2] = a[0]*b[1] - a[1]*b[0];
}

// Scales v to unit length in place. A zero-length v (e.g. the cross product of
// two parallel vectors) yields non-finite components.
static void triadNormalize3(float v[3])
{
    const float inv = 1.0f / sqrtf(v[0]*v[0] + v[1]*v[1] + v[2]*v[2]);
    v[0] *= inv;
    v[1] *= inv;
    v[2] *= inv;
}

// Builds the orthonormal TRIAD basis {t1, t2, t3} from a primary vector p and
// a secondary vector s:
//   t1 = p/|p|,  t2 = unit(t1 x s),  t3 = unit(t1 x t2).
static void triadBasis(float px, float py, float pz,
                       float sx, float sy, float sz,
                       float t1[3], float t2[3], float t3[3])
{
    const float s[3] = { sx, sy, sz };
    t1[0] = px;
    t1[1] = py;
    t1[2] = pz;
    // The primary vector must be unit length for the basis to be orthonormal.
    triadNormalize3(t1);
    triadCross3(t1, s, t2);
    triadNormalize3(t2);
    triadCross3(t1, t2, t3);
    triadNormalize3(t3);
}

/**
 * Initialises the attitude with the TRIAD algorithm from one pair of vector
 * observations.
 *
 * @param fb  primary vector measured in the body frame (presumably the
 *            accelerometer / specific-force reading).
 * @param fn  the same primary vector expressed in the navigation frame.
 * @param mb  secondary vector measured in the body frame (presumably the
 *            magnetometer reading).
 * @param mn  the same secondary vector expressed in the navigation frame.
 *
 * Inputs need not be unit length.
 *
 * Side effects:
 *  - Sets qhat and qhat_gyro to the resulting unit quaternion (scalar first,
 *    q0 >= 0).
 *  - Rebuilds the DCM D from it, using the same formula as
 *    updateBetweenMeasures().
 *
 * Precondition: fb must not be parallel to mb, nor fn to mn. Otherwise the
 * cross products vanish and the result is NaN.
 */
void HAPS_EKF::triad(Vector3 fb, Vector3 fn, Vector3 mb, Vector3 mn){
    float b1[3], b2[3], b3[3];   // body-frame basis
    float n1[3], n2[3], n3[3];   // navigation-frame basis
    triadBasis(fb.x, fb.y, fb.z, mb.x, mb.y, mb.z, b1, b2, b3);
    triadBasis(fn.x, fn.y, fn.z, mn.x, mn.y, mn.z, n1, n2, n3);

    // Body-to-navigation DCM C = [n1 n2 n3] * [b1 b2 b3]^T, so that C*b_k = n_k.
    float C[3][3];
    for (int i = 0; i < 3; ++i)
        for (int j = 0; j < 3; ++j)
            C[i][j] = n1[i]*b1[j] + n2[i]*b2[j] + n3[i]*b3[j];

    // Quaternion from DCM with Shepperd's method. The branch is chosen by the
    // largest of {trace, C11, C22, C33}, so the square-root argument is >= 1
    // and the divisor never vanishes. The single trace-based formula divides
    // by zero for a 180-degree rotation.
    const float tr = C[0][0] + C[1][1] + C[2][2];
    float q0, q1, q2, q3;
    if (tr >= C[0][0] && tr >= C[1][1] && tr >= C[2][2]) {
        const float s = sqrtf(1.0f + tr);                          // 2|q0|
        q0 = 0.5f * s;
        q1 = (C[2][1] - C[1][2]) / (2.0f * s);
        q2 = (C[0][2] - C[2][0]) / (2.0f * s);
        q3 = (C[1][0] - C[0][1]) / (2.0f * s);
    } else if (C[0][0] >= C[1][1] && C[0][0] >= C[2][2]) {
        const float s = sqrtf(1.0f + C[0][0] - C[1][1] - C[2][2]); // 2|q1|
        q1 = 0.5f * s;
        q0 = (C[2][1] - C[1][2]) / (2.0f * s);
        q2 = (C[0][1] + C[1][0]) / (2.0f * s);
        q3 = (C[0][2] + C[2][0]) / (2.0f * s);
    } else if (C[1][1] >= C[2][2]) {
        const float s = sqrtf(1.0f - C[0][0] + C[1][1] - C[2][2]); // 2|q2|
        q2 = 0.5f * s;
        q0 = (C[0][2] - C[2][0]) / (2.0f * s);
        q1 = (C[0][1] + C[1][0]) / (2.0f * s);
        q3 = (C[1][2] + C[2][1]) / (2.0f * s);
    } else {
        const float s = sqrtf(1.0f - C[0][0] - C[1][1] + C[2][2]); // 2|q3|
        q3 = 0.5f * s;
        q0 = (C[1][0] - C[0][1]) / (2.0f * s);
        q1 = (C[0][2] + C[2][0]) / (2.0f * s);
        q2 = (C[1][2] + C[2][1]) / (2.0f * s);
    }

    // q and -q describe the same rotation; keep the scalar part non-negative
    // (the trace-based formula always produced q0 >= 0).
    if (q0 < 0.0f) {
        q0 = -q0;
        q1 = -q1;
        q2 = -q2;
        q3 = -q3;
    }

    // Renormalise to remove rounding drift.
    const float qn = sqrtf(q0*q0 + q1*q1 + q2*q2 + q3*q3);
    q0 /= qn;
    q1 /= qn;
    q2 /= qn;
    q3 /= qn;

    qhat.add(1,1,q0);
    qhat.add(2,1,q1);
    qhat.add(3,1,q2);
    qhat.add(4,1,q3);

    qhat_gyro = qhat;

    D.add(1,1,q0*q0 + q1*q1 - q2*q2 - q3*q3);
    D.add(1,2,2*(q1*q2 + q0*q3));
    D.add(1,3,2*(q1*q3 - q0*q2));
    D.add(2,1,2*(q1*q2 - q0*q3));
    D.add(2,2,q0*q0 - q1*q1 + q2*q2 - q3*q3);
    D.add(2,3,2*(q2*q3 + q0*q1));
    D.add(3,1,2*(q1*q3 + q0*q2));
    D.add(3,2,2*(q2*q3 - q0*q1));
    D.add(3,3,q0*q0 - q1*q1 - q2*q2 + q3*q3);
}

/**
 * Builds a magnetic reference vector from a magnetometer sample.
 *
 * The sample is rotated with D^T. Its first two components are then
 * collapsed into one, giving (sqrt(x^2 + y^2), 0, z): the horizontal
 * direction is discarded and only the horizontal magnitude and the third
 * component are kept.
 *
 * @param m  magnetometer sample (presumably body frame — confirm with caller).
 * @return   the reference vector described above.
 */
Vector3 HAPS_EKF::calcMagRef(Vector3 m)
{
    Matrix sample(3,1);
    sample << m.x << m.y << m.z;

    Matrix rotated = MatrixMath::Transpose(D)*sample;

    const float horizontal = sqrt(rotated(1,1)*rotated(1,1)+rotated(2,1)*rotated(2,1));
    const float vertical   = rotated(3,1);
    return Vector3(horizontal, 0.0f, vertical);
}

/**
 * Removes the expected (static) acceleration from a filtered accelerometer
 * sample.
 *
 * The reference vector accref is rotated with the current DCM D (row r of D
 * dotted with accref) and subtracted component-wise from LPacc.
 *
 * @param LPacc   accelerometer sample (name suggests low-pass filtered).
 * @param accref  reference acceleration vector to rotate by D.
 * @return        LPacc - D * accref.
 */
Vector3 HAPS_EKF::calcDynAcc(Vector3 LPacc, Vector3 accref)
{
    const float dynX = LPacc.x - (D.getNumber( 1, 1 )*accref.x + D.getNumber( 1, 2 )*accref.y + D.getNumber( 1, 3 )*accref.z);
    const float dynY = LPacc.y - (D.getNumber( 2, 1 )*accref.x + D.getNumber( 2, 2 )*accref.y + D.getNumber( 2, 3 )*accref.z);
    const float dynZ = LPacc.z - (D.getNumber( 3, 1 )*accref.x + D.getNumber( 3, 2 )*accref.y + D.getNumber( 3, 3 )*accref.z);

    return Vector3(dynX, dynY, dynZ);
}
