#include "solaESKF.hpp"

solaESKF::solaESKF()
    :errState(18,1),Phat(18,18),Q(18,18)
{
    // Nominal state: origin, at rest, identity attitude (scalar-first
    // quaternion), zero IMU biases, gravity of 9.8 along the third axis.
    pihat.setZero();
    vihat.setZero();
    qhat << 1.0f, 0.0f, 0.0f, 0.0f;
    accBias.setZero();
    gyroBias.setZero();
    gravity << 0.0f, 0.0f, 9.8f;

    // Error-state layout (18 entries, 3 per group):
    //   0-2 position, 3-5 velocity, 6-8 angle error,
    //   9-11 acc bias, 12-14 gyro bias, 15-17 gravity
    nState = 18;
    errState = VectorXf::Zero(nState);
    Phat.setZero(nState, nState);
    Q.setZero(nState, nState);

    // Initial covariance: 0.1 on every group except gravity, which starts
    // almost perfectly known.
    setBlockDiag(Phat, 0.1f, 0, 14);          // position .. gyro bias
    setBlockDiag(Phat, 0.00000001f, 15, 17);  // gravity

    // Process noise. Position and gravity have no Q term.
    setBlockDiag(Q, 0.00025f, 3, 5);          // velocity
    setBlockDiag(Q, 0.005f/57.0f, 6, 8);      // angle error (57 ~ deg/rad; presumably a deg->rad conversion, confirm)
    setBlockDiag(Q, 0.001f, 9, 14);           // acc bias and gyro bias
}


/**
 * @brief Propagate the nominal state over one IMU step.
 *
 * @param acc    raw accelerometer reading (body frame); accBias is removed
 * @param gyro   raw gyro reading (body frame); gyroBias is removed
 * @param att_dt step length, in the same time unit as the IMU rates
 *
 * Attitude is advanced with the first-order increment quaternion
 * [1, 0.5*w*dt] and renormalised. Position and velocity are then advanced
 * with constant acceleration over the step:
 *   p += v*dt + 0.5*a*dt^2
 *   v += a*dt
 * Both updates use the velocity from the START of the step. The previous
 * code updated v first, which added an extra a*dt^2 to p (1.5*a*dt^2 in
 * total). That was inconsistent with the 0.5*dt^2 position coupling used
 * in updateErrState's Fx.
 */
void solaESKF::updateNominal(Vector3f acc, Vector3f gyro,float att_dt)
{
    Vector3f gyrom = gyro - gyroBias;
    Vector3f accm = acc - accBias;

    // Attitude: right-multiply by the small-rotation quaternion.
    Vector4f qint;
    qint << 1.0f, 0.5f*gyrom(0)*att_dt, 0.5f*gyrom(1)*att_dt, 0.5f*gyrom(2)*att_dt;
    qhat = quatmultiply(qhat, qint);
    qhat.normalize();

    // Navigation-frame acceleration: rotated specific force plus gravity.
    Matrix3f dcm;
    computeDcm(dcm, qhat);
    Vector3f accned = dcm*accm + gravity;

    // Position first (uses the pre-step velocity), then velocity.
    pihat += vihat*att_dt + 0.5f*accned*att_dt*att_dt;
    vihat += accned*att_dt;
}

// Propagates the error-state covariance Phat over one IMU step of length att_dt.
// Builds the 18x18 transition Jacobian Fx about the current nominal state
// (qhat, accBias, gyroBias) and sets Phat = Fx*Phat*Fx^T plus diagonal process
// noise taken from Q.
// Error-state layout (see constructor): 0-2 position, 3-5 velocity,
// 6-8 angle error, 9-11 acc bias, 12-14 gyro bias, 15-17 gravity.
void solaESKF::updateErrState(Vector3f acc, Vector3f gyro, float att_dt)
{
    // Bias-corrected IMU readings used as the linearisation point.
    Vector3f gyrom = gyro  - gyroBias;
    Vector3f accm = acc - accBias;

    Matrix3f dcm;
    computeDcm(dcm, qhat);
//    Matrix a2v = -dcm*MatrixMath::Matrixcross(accm(1,1),accm(2,1),accm(3,1))*att_dt;
    // d(velocity)/d(angle error) over one step: -R [accm]x dt.
    Matrix3f a2v = -dcm*solaESKF::Matrixcross(accm)*att_dt;
    // Second-order (0.5*dt) version, used for position and gyro-bias coupling.
    Matrix3f a2v2 = 0.5f*a2v*att_dt;

    MatrixXf Fx = MatrixXf::Zero(nState, nState);
    //position
    // Rows 0-2: p' = p + v*dt, plus 0.5*dt^2 terms from angle error,
    // acc bias (-0.5*R*dt^2) and gravity (+0.5*dt^2).
    Fx(0,0) =  1.0f;
    Fx(1,1) =  1.0f;
    Fx(2,2) =  1.0f;
    Fx(0,3) =  1.0f*att_dt;
    Fx(1,4) =  1.0f*att_dt;
    Fx(2,5) =  1.0f*att_dt;
    for (int i = 0; i < 3; i++){
        for (int j = 0; j < 3; j++){
            Fx(i,j+6) = a2v2(i,j);
            Fx(i,j+9) = -0.5f*dcm(i,j)*att_dt*att_dt;
        }
        Fx(i,i+15) = 0.5f*att_dt*att_dt;
    }


    //velocity
    // Rows 3-5: v' = v + a2v*dtheta - R*dt*dab + I*dt*dg. The gyro-bias
    // coupling (-a2v2) is the second-order effect of the bias on the angle error.
    Fx(3,3) =  1.0f;
    Fx(4,4) =  1.0f;
    Fx(5,5) =  1.0f;
    for (int i = 0; i < 3; i++){
        for (int j = 0; j < 3; j++){
            Fx(i+3,j+6) = a2v(i,j);
            Fx(i+3,j+9) = -dcm(i,j)*att_dt;
            Fx(i+3,j+12) = -a2v2(i,j);
        }
    }
    Fx(3,15) =  1.0f*att_dt;
    Fx(4,16) =  1.0f*att_dt;
    Fx(5,17) =  1.0f*att_dt;

    //angulat error
    // Rows 6-8: I - [gyrom]x * dt (first-order rotation over the step),
    // and -I*dt coupling to the gyro bias.
    Fx(6,6) =  1.0f;
    Fx(7,7) =  1.0f;
    Fx(8,8) =  1.0f;
    Fx(6,7) =  gyrom(2)*att_dt;
    Fx(6,8) = -gyrom(1)*att_dt;
    Fx(7,6) = -gyrom(2)*att_dt;
    Fx(7,8) =  gyrom(0)*att_dt;
    Fx(8,6) =  gyrom(1)*att_dt;
    Fx(8,7) = -gyrom(0)*att_dt;
    Fx(6,12) =  -1.0f*att_dt;
    Fx(7,13) =  -1.0f*att_dt;
    Fx(8,14) =  -1.0f*att_dt;

    //acc bias
    // Biases and gravity are modelled as constant (identity transition).
    Fx(9,9) =  1.0f;
    Fx(10,10) =  1.0f;
    Fx(11,11) =  1.0f;

    //gyro bias
    Fx(12,12) =  1.0f;
    Fx(13,13) =  1.0f;
    Fx(14,14) =  1.0f;

    //gravity bias
    Fx(15,15) =  1.0f;
    Fx(16,16) =  1.0f;
    Fx(17,17) =  1.0f;

    // The mean is not propagated: fuseErr2Nominal resets errState to zero after
    // every injection, so Fx*errState would presumably be zero here anyway.
    //errState = Fx * errState;
    Phat = Fx*Phat*Fx.transpose();
    // Diagonal process noise: velocity/angle-error entries (3-8) get Q*dt,
    // bias entries (9-14) get Q*dt^2; position and gravity get none.
    // NOTE(review): Sola's ESKF scales velocity/angle noise by dt^2 and bias
    // random walk by dt; confirm that Q's units match the scaling used here.
    for (int i = 0; i < nState; i++){
        if(i>2 && i<9){
            Phat(i,i)  += Q(i,i)*att_dt;
        }else if(i>8 && i<15){
            Phat(i,i)  += Q(i,i)* att_dt*att_dt;
        }      
    }
}

/**
 * @brief Accelerometer-only measurement update.
 *
 * Compares the bias-corrected accelerometer reading with gravity rotated into
 * the body frame (R^T * gravity). Innovation: -(acc - accBias) - R^T*gravity.
 * Jacobian columns: [R^T g]x on angle error, -I on acc bias, R^T on gravity.
 *
 * @param acc raw accelerometer reading (body frame)
 * @param R   3x3 measurement-noise covariance
 */
void solaESKF::updateAcc(Vector3f acc, Matrix3f R)
{
    const Vector3f accm = acc - accBias;

    Matrix3f dcm;
    computeDcm(dcm, qhat);
    const Matrix3f tdcm = dcm.transpose();
    const Vector3f gravBody = tdcm*gravity;
    const Matrix3f gravSkew = solaESKF::Matrixcross(gravBody);

    MatrixXf H = MatrixXf::Zero(3, nState);
    H.block<3,3>(0, 6) = gravSkew;                      // angle error
    H.block<3,3>(0, 9).diagonal().setConstant(-1.0f);   // acc bias
    H.block<3,3>(0, 15) = tdcm;                         // gravity

    MatrixXf K = (Phat*H.transpose())*(H*Phat*H.transpose()+R).inverse();
    Vector3f innovation = -accm-tdcm*gravity;

    errState = K * innovation;
    Phat = (MatrixXf::Identity(nState, nState)-K*H)*Phat;

    fuseErr2Nominal();
}

// Fuses a single heading (yaw) measurement.
//   a : measured yaw in radians (compared through sin/cos with the yaw of qhat)
//   R : 1x1 measurement-noise variance
// The observation Jacobian is d(yaw)/d(q) (Hh, 1x4), chained with
// d(q)/d(angle error) (Hdq, 4x3), and placed in columns 6-8 of H.
// Yaw here is atan2(SA2, SA3) with SA2 = 2(q0q3 + q1q2) = dcm(1,0) and
// SA3 = q0^2 + q1^2 - q2^2 - q3^2 = dcm(0,0) (see computeDcm).
// Returns without changing the state if neither Jacobian form is usable.
void solaESKF::updateHeading(float a, Matrix<float, 1, 1> R)
{
    float q0 = qhat(0);
    float q1 = qhat(1);
    float q2 = qhat(2);
    float q3 = qhat(3);

    // Form A: usable only when SA3^2 > 1e-6.
    bool canUseA = false;
    const float SA0 = 2.0f*q3;
    const float SA1 = 2.0f*q2;
    const float SA2 = SA0*q0 + SA1*q1;
    const float SA3 = q0*q0 + q1*q1 - q2*q2 - q3*q3;
    float SA4, SA5_inv;
    if ((SA3*SA3) > 1e-6f) {
        SA4 = 1.0f/(SA3*SA3);
        SA5_inv = SA2*SA2*SA4 + 1.0f;
        canUseA = std::abs(SA5_inv) > 1e-6f;
    }

    // Form B: alternative expression, usable only when SB2^2 > 1e-6
    // (SB2 equals SA2 and SB4 equals SA3).
    bool canUseB = false;
    const float SB0 = 2.0f*q0;
    const float SB1 = 2.0f*q1;
    const float SB2 = SB0*q3 + SB1*q2;
    const float SB4 = q0*q0 + q1*q1 - q2*q2 - q3*q3;
    float SB3, SB5_inv;
    if ((SB2*SB2) > 1e-6f) {
        SB3 = 1.0f/(SB2*SB2);
        SB5_inv = SB3*SB4*SB4 + 1;
        canUseB = std::abs(SB5_inv) > 1e-6f;
    }

    MatrixXf Hh = MatrixXf::Zero(1,4);

    // Choose between the two forms by comparing SA5_inv and SB5_inv. SA4/SA5_inv
    // and SB3/SB5_inv are only read when the matching canUse flag is true
    // (short-circuit evaluation).
    if (canUseA && (!canUseB || std::abs(SA5_inv) >= std::abs(SB5_inv))) {
        const float SA5 = 1.0f/SA5_inv;
        const float SA6 = 1.0f/SA3;
        const float SA7 = SA2*SA4;
        const float SA8 = 2.0f*SA7;
        const float SA9 = 2.0f*SA6;

        Hh(0,0) = SA5*(SA0*SA6 - SA8*q0);
        Hh(0,1) = SA5*(SA1*SA6 - SA8*q1);
        Hh(0,2) = SA5*(SA1*SA7 + SA9*q1);
        Hh(0,3) = SA5*(SA0*SA7 + SA9*q0);
    } else if (canUseB && (!canUseA || std::abs(SB5_inv) > std::abs(SA5_inv))) {
        const float SB5 = 1.0f/SB5_inv;
        const float SB6 = 1.0f/SB2;
        const float SB7 = SB3*SB4;
        const float SB8 = 2.0f*SB7;
        const float SB9 = 2.0f*SB6;

        Hh(0,0) = -SB5*(SB0*SB6 - SB8*q3);
        Hh(0,1) = -SB5*(SB1*SB6 - SB8*q2);
        Hh(0,2) = -SB5*(-SB1*SB7 - SB9*q2);
        Hh(0,3) = -SB5*(-SB0*SB7 - SB9*q3);
    } else {
        // Both dcm(0,0) and dcm(1,0) are ~0 (body x-axis nearly vertical):
        // yaw is not linearisable, so skip this update entirely.
        return;
    }
    
    // d(q * [1, dtheta/2]) / d(dtheta) at dtheta = 0, which matches how
    // fuseErr2Nominal injects the angle error into qhat.
    MatrixXf Hdq = MatrixXf::Zero(4,3);
    Hdq  << -0.5f*q1, -0.5f*q2, -0.5f*q3,
             0.5f*q0, -0.5f*q3,  0.5f*q2,
             0.5f*q3,  0.5f*q0, -0.5f*q1,
            -0.5f*q2,  0.5f*q1,  0.5f*q0;  
    
    MatrixXf  Hpart = Hh*Hdq;
    MatrixXf H=MatrixXf::Zero(1,nState);
    for(int j=0; j<3; j++){
        H(0,j+6) = Hpart(0,j);
    }
    
    // Current yaw of qhat (same formula as computeAngles' euler(2)); the
    // innovation is wrapped to [-pi, pi] via atan2(sin, cos).
    const float psi = std::atan2(qhat(1)*qhat(2) + qhat(0)*qhat(3), 0.5f - qhat(2)*qhat(2) - qhat(3)*qhat(3));
    Matrix<float, 1, 1> z;
    z << std::atan2(std::sin(a-psi), std::cos(a-psi));
    // Standard Kalman gain and (I - K H) covariance update, then inject.
    MatrixXf K = (Phat*H.transpose())*(H*Phat*H.transpose()+R).inverse();
    errState =  K * z;
    Phat = (MatrixXf::Identity(nState, nState)-K*H)*Phat;
    
    fuseErr2Nominal();
}

void solaESKF::updateIMU(float palt,Vector3f acc,float heading, Matrix<float, 5, 5> R)
{
    MatrixXf H = MatrixXf::Zero(5,nState);
    
    
    H(0,2) = 1.0f;
    
    Vector3f accm = acc - accBias;
    Matrix3f dcm;
    computeDcm(dcm, qhat);
    Matrix3f tdcm = dcm.transpose();
    Vector3f tdcm_g = tdcm*gravity;
    Matrix3f rotgrav = solaESKF::Matrixcross(tdcm_g);
    
    for (int i = 0; i < 3; i++){
        for (int j = 0; j < 3; j++){
            H(i+1,j+6) =  rotgrav(i,j);
            H(i+1,j+15) = tdcm(i,j);
        }
    }

    H(1,9) =   -1.0f;
    H(2,10) =  -1.0f;
    H(3,11) =  -1.0f;

    float q0 = qhat(0);
    float q1 = qhat(1);
    float q2 = qhat(2);
    float q3 = qhat(3);

    bool canUseA = false;
    const float SA0 = 2.0f*q3;
    const float SA1 = 2.0f*q2;
    const float SA2 = SA0*q0 + SA1*q1;
    const float SA3 = q0*q0 + q1*q1 - q2*q2 - q3*q3;
    float SA4, SA5_inv;
    if ((SA3*SA3) > 1e-6f) {
        SA4 = 1.0f/(SA3*SA3);
        SA5_inv = SA2*SA2*SA4 + 1.0f;
        canUseA = std::abs(SA5_inv) > 1e-6f;
    }

    bool canUseB = false;
    const float SB0 = 2.0f*q0;
    const float SB1 = 2.0f*q1;
    const float SB2 = SB0*q3 + SB1*q2;
    const float SB4 = q0*q0 + q1*q1 - q2*q2 - q3*q3;
    float SB3, SB5_inv;
    if ((SB2*SB2) > 1e-6f) {
        SB3 = 1.0f/(SB2*SB2);
        SB5_inv = SB3*SB4*SB4 + 1;
        canUseB = std::abs(SB5_inv) > 1e-6f;
    }

    MatrixXf Hh = MatrixXf::Zero(1,4);

    if (canUseA && (!canUseB || std::abs(SA5_inv) >= std::abs(SB5_inv))) {
        const float SA5 = 1.0f/SA5_inv;
        const float SA6 = 1.0f/SA3;
        const float SA7 = SA2*SA4;
        const float SA8 = 2.0f*SA7;
        const float SA9 = 2.0f*SA6;

        Hh(0,0) = SA5*(SA0*SA6 - SA8*q0);
        Hh(0,1) = SA5*(SA1*SA6 - SA8*q1);
        Hh(0,2) = SA5*(SA1*SA7 + SA9*q1);
        Hh(0,3) = SA5*(SA0*SA7 + SA9*q0);
    } else if (canUseB && (!canUseA || std::abs(SB5_inv) > std::abs(SA5_inv))) {
        const float SB5 = 1.0f/SB5_inv;
        const float SB6 = 1.0f/SB2;
        const float SB7 = SB3*SB4;
        const float SB8 = 2.0f*SB7;
        const float SB9 = 2.0f*SB6;

        Hh(0,0) = -SB5*(SB0*SB6 - SB8*q3);
        Hh(0,1) = -SB5*(SB1*SB6 - SB8*q2);
        Hh(0,2) = -SB5*(-SB1*SB7 - SB9*q2);
        Hh(0,3) = -SB5*(-SB0*SB7 - SB9*q3);
    } else {
        return;
    }
    
    MatrixXf Hdq = MatrixXf::Zero(4,3);
    Hdq  << -0.5f*q1, -0.5f*q2, -0.5f*q3,
             0.5f*q0, -0.5f*q3,  0.5f*q2,
             0.5f*q3,  0.5f*q0, -0.5f*q1,
            -0.5f*q2,  0.5f*q1,  0.5f*q0;  
    
    MatrixXf  Hpart = Hh*Hdq;
    for(int j=0; j<3; j++){
        H(4,j+6) = Hpart(0,j);
    }
    

    
    const float psi = std::atan2(qhat(1)*qhat(2) + qhat(0)*qhat(3), 0.5f - qhat(2)*qhat(2) - qhat(3)*qhat(3));
    Vector3f zacc = -accm-tdcm*gravity;
    VectorXf z = VectorXf::Zero(5);
    z << palt-pihat(2),zacc(0),zacc(1),zacc(2),std::atan2(std::sin(heading-psi), std::cos(heading-psi));
    MatrixXf K = (Phat*H.transpose())*(H*Phat*H.transpose()+R).inverse();
    errState =  K * z;
    Phat = (MatrixXf::Identity(nState, nState)-K*H)*Phat;
    
    fuseErr2Nominal();
}
/**
 * @brief Direct observation of the three position states.
 *
 * The first two components come from posgps; the third comes from palt
 * (posgps(2) is not used).
 *
 * @param R 3x3 measurement-noise covariance
 */
void solaESKF::updateGPSPosition(Vector3f posgps,float palt,Matrix3f R)
{
    MatrixXf H = MatrixXf::Zero(3, nState);
    H.block<3,3>(0,0).setIdentity();

    MatrixXf K = (Phat*H.transpose())*(H*Phat*H.transpose()+R).inverse();

    Vector3f innovation;
    innovation << posgps(0)-pihat(0),
                  posgps(1)-pihat(1),
                  palt - pihat(2);

    errState = K * innovation;
    Phat = (MatrixXf::Identity(nState, nState)-K*H)*Phat;
    fuseErr2Nominal();
}
/**
 * @brief Direct observation of the three velocity states.
 *
 * @param velgps measured velocity (same frame as vihat)
 * @param R      3x3 measurement-noise covariance
 */
void solaESKF::updateGPSVelocity(Vector3f velgps,Matrix3f R)
{
    MatrixXf H = MatrixXf::Zero(3, nState);
    H.block<3,3>(0,3).setIdentity();

    MatrixXf K = (Phat*H.transpose())*(H*Phat*H.transpose()+R).inverse();
    const Vector3f innovation = velgps - vihat;

    errState = K * innovation;
    Phat = (MatrixXf::Identity(nState, nState)-K*H)*Phat;
    fuseErr2Nominal();
}

/**
 * @brief Combined GPS + altitude + accelerometer + heading update (9 rows).
 *
 *   rows 0-1 : posgps(0..1) vs. pihat(0..1)
 *   row  2   : palt vs. pihat(2)
 *   rows 3-4 : velgps(0..1) vs. vihat(0..1)
 *   rows 5-7 : -(acc - accBias) - R^T*gravity (accelerometer vs. gravity
 *              rotated into the body frame)
 *   row  8   : heading (radians) vs. yaw of qhat, wrapped to [-pi, pi]
 *
 * @param R 9x9 measurement-noise covariance in the row order above
 *
 * If the yaw Jacobian is singular (dcm(0,0) and dcm(1,0) both ~0, i.e. the
 * body x-axis is nearly vertical), only the heading row is dropped. The
 * GPS, altitude and accelerometer rows are still fused. Previously the
 * whole update, including the GPS fix, was silently discarded in that case.
 */
void solaESKF::updateWhole(Vector3f posgps, float palt, Vector3f velgps,Vector3f acc,float heading, MatrixXf R)
{
    MatrixXf H = MatrixXf::Zero(9,nState);
    H(0,0)  = 1.0f;
    H(1,1)  = 1.0f;
    H(2,2)  = 1.0f;
    H(3,3)  = 1.0f;
    H(4,4)  = 1.0f;

    Vector3f accm = acc - accBias;
    Matrix3f dcm;
    computeDcm(dcm, qhat);
    Matrix3f tdcm = dcm.transpose();
    Vector3f tdcm_g = tdcm*gravity;
    Matrix3f rotgrav = solaESKF::Matrixcross(tdcm_g);

    // Accelerometer rows: angle error, gravity and acc-bias columns.
    for (int i = 0; i < 3; i++){
        for (int j = 0; j < 3; j++){
            H(i+5,j+6) =  rotgrav(i,j);
            H(i+5,j+15) = tdcm(i,j);
        }
    }

    H(5,9) =   -1.0f;
    H(6,10) =  -1.0f;
    H(7,11) =  -1.0f;

    // Heading row: d(yaw)/d(q) through one of two alternative symbolic forms.
    float q0 = qhat(0);
    float q1 = qhat(1);
    float q2 = qhat(2);
    float q3 = qhat(3);

    bool canUseA = false;
    const float SA0 = 2.0f*q3;
    const float SA1 = 2.0f*q2;
    const float SA2 = SA0*q0 + SA1*q1;
    const float SA3 = q0*q0 + q1*q1 - q2*q2 - q3*q3;
    float SA4 = 0.0f;
    float SA5_inv = 0.0f;
    if ((SA3*SA3) > 1e-6f) {
        SA4 = 1.0f/(SA3*SA3);
        SA5_inv = SA2*SA2*SA4 + 1.0f;
        canUseA = std::abs(SA5_inv) > 1e-6f;
    }

    bool canUseB = false;
    const float SB0 = 2.0f*q0;
    const float SB1 = 2.0f*q1;
    const float SB2 = SB0*q3 + SB1*q2;
    const float SB4 = q0*q0 + q1*q1 - q2*q2 - q3*q3;
    float SB3 = 0.0f;
    float SB5_inv = 0.0f;
    if ((SB2*SB2) > 1e-6f) {
        SB3 = 1.0f/(SB2*SB2);
        SB5_inv = SB3*SB4*SB4 + 1;
        canUseB = std::abs(SB5_inv) > 1e-6f;
    }

    MatrixXf Hh = MatrixXf::Zero(1,4);
    bool headingUsable = true;

    if (canUseA && (!canUseB || std::abs(SA5_inv) >= std::abs(SB5_inv))) {
        const float SA5 = 1.0f/SA5_inv;
        const float SA6 = 1.0f/SA3;
        const float SA7 = SA2*SA4;
        const float SA8 = 2.0f*SA7;
        const float SA9 = 2.0f*SA6;

        Hh(0,0) = SA5*(SA0*SA6 - SA8*q0);
        Hh(0,1) = SA5*(SA1*SA6 - SA8*q1);
        Hh(0,2) = SA5*(SA1*SA7 + SA9*q1);
        Hh(0,3) = SA5*(SA0*SA7 + SA9*q0);
    } else if (canUseB && (!canUseA || std::abs(SB5_inv) > std::abs(SA5_inv))) {
        const float SB5 = 1.0f/SB5_inv;
        const float SB6 = 1.0f/SB2;
        const float SB7 = SB3*SB4;
        const float SB8 = 2.0f*SB7;
        const float SB9 = 2.0f*SB6;

        Hh(0,0) = -SB5*(SB0*SB6 - SB8*q3);
        Hh(0,1) = -SB5*(SB1*SB6 - SB8*q2);
        Hh(0,2) = -SB5*(-SB1*SB7 - SB9*q2);
        Hh(0,3) = -SB5*(-SB0*SB7 - SB9*q3);
    } else {
        // Yaw is not linearisable here; drop only the heading row below.
        headingUsable = false;
    }

    if (headingUsable) {
        // d(q * [1, dtheta/2]) / d(dtheta) at dtheta = 0.
        MatrixXf Hdq = MatrixXf::Zero(4,3);
        Hdq  << -0.5f*q1, -0.5f*q2, -0.5f*q3,
                 0.5f*q0, -0.5f*q3,  0.5f*q2,
                 0.5f*q3,  0.5f*q0, -0.5f*q1,
                -0.5f*q2,  0.5f*q1,  0.5f*q0;

        MatrixXf Hpart = Hh*Hdq;
        for(int j=0; j<3; j++){
            H(8,j+6) = Hpart(0,j);
        }
    }

    const float psi = std::atan2(qhat(1)*qhat(2) + qhat(0)*qhat(3), 0.5f - qhat(2)*qhat(2) - qhat(3)*qhat(3));
    Vector3f zacc = -accm-tdcm*gravity;
    VectorXf z = VectorXf::Zero(9);
    z << posgps(0)-pihat(0), posgps(1)-pihat(1), palt-pihat(2), velgps(0)-vihat(0), velgps(1)-vihat(1), zacc(0),zacc(1),zacc(2),std::atan2(std::sin(heading-psi), std::cos(heading-psi));

    // Fuse all 9 rows, or only rows 0-7 when the heading row is unusable.
    const int nMeas = headingUsable ? 9 : 8;
    const MatrixXf Hm = H.topRows(nMeas);
    const VectorXf zm = z.head(nMeas);
    const MatrixXf Rm = R.topLeftCorner(nMeas, nMeas);
    MatrixXf K = (Phat*Hm.transpose())*(Hm*Phat*Hm.transpose()+Rm).inverse();
    errState =  K * zm;
    Phat = (MatrixXf::Identity(nState, nState)-K*Hm)*Phat;

    fuseErr2Nominal();
}

/**
 * @brief Five direct observations: position (first two components from
 *        posgps, third from palt) and the first two velocity components.
 *
 * @param R 5x5 measurement-noise covariance in that order
 */
void solaESKF::updateGPS(Vector3f posgps, float palt, Vector3f velgps, MatrixXf R)
{
    MatrixXf H = MatrixXf::Zero(5, nState);
    H.block<5,5>(0,0).setIdentity();

    MatrixXf K = (Phat*H.transpose())*(H*Phat*H.transpose()+R).inverse();

    Matrix<float, 5, 1> innovation;
    innovation << posgps(0)-pihat(0),
                  posgps(1)-pihat(1),
                  palt-pihat(2),
                  velgps(0)-vihat(0),
                  velgps(1)-vihat(1);

    errState = K * innovation;
    Phat = (MatrixXf::Identity(nState, nState)-K*H)*Phat;
    fuseErr2Nominal();
}


/**
 * @brief ZYX Euler angles (roll, pitch, yaw) in radians from qhat.
 *
 * The pitch term's asin argument is clamped to [-1, 1]. Near +/-90 deg
 * pitch, float rounding can push it slightly outside that range, and
 * std::asin would then return NaN.
 */
Vector3f solaESKF::computeAngles()
{
    Vector3f euler;
    euler(0) = std::atan2(qhat(0)*qhat(1) + qhat(2)*qhat(3), 0.5f - qhat(1)*qhat(1) - qhat(2)*qhat(2));

    float sinPitch = -2.0f * (qhat(1)*qhat(3) - qhat(0)*qhat(2));
    if (sinPitch > 1.0f) {
        sinPitch = 1.0f;
    } else if (sinPitch < -1.0f) {
        sinPitch = -1.0f;
    }
    euler(1) = std::asin(sinPitch);

    euler(2) = std::atan2(qhat(1)*qhat(2) + qhat(0)*qhat(3), 0.5f - qhat(2)*qhat(2) - qhat(3)*qhat(3));
    return euler;
}


void solaESKF::fuseErr2Nominal()
{
    //position
    pihat(0) += errState(0);
    pihat(1) += errState(1);
    pihat(2) += errState(2);
    
    //velocity
    vihat(0) += errState(3);
    vihat(1) += errState(4);
    vihat(2) += errState(5);
    
    //angle error
    Vector4f qerr;
    qerr << 1.0f, 0.5f*errState(6), 0.5f*errState(7), 0.5f*errState(8);
    qhat = quatmultiply(qhat, qerr);
    qhat.normalize();
    
    //acc bias
    accBias(0) += errState(9);
    accBias(1) += errState(10);
    accBias(2) += errState(11);
    
    //gyro bias
    gyroBias(0) += errState(12);
    gyroBias(1) += errState(13);
    gyroBias(2) += errState(14);

    //gravity bias
    gravity(0) += errState(15);
    gravity(1) += errState(16);
    gravity(2) += errState(17);
    
    errState = VectorXf::Zero(nState);
}

/**
 * @brief Add an externally supplied error vector to the nominal state.
 *
 * @param errVal error vector in the errState layout (needs at least 15
 *               entries). Position, velocity and the biases are added.
 *               Entries 6-8 rotate qhat via qhat * [1, dtheta/2].
 *
 * Unlike fuseErr2Nominal, gravity is left unchanged and errState is not
 * modified.
 */
void solaESKF::fuseCenter2Nominal(VectorXf errVal)
{
    pihat    += errVal.segment<3>(0);
    vihat    += errVal.segment<3>(3);

    Vector4f dq;
    dq << 1.0f, 0.5f*errVal(6), 0.5f*errVal(7), 0.5f*errVal(8);
    qhat = quatmultiply(qhat, dq);
    qhat.normalize();

    accBias  += errVal.segment<3>(9);
    gyroBias += errVal.segment<3>(12);
}

/**
 * @brief Hamilton product p * q for scalar-first quaternions [w, x, y, z].
 *
 * The result is not normalised.
 */
Vector4f solaESKF::quatmultiply(Vector4f p, Vector4f q)
{
    const float pw = p(0);
    const float px = p(1);
    const float py = p(2);
    const float pz = p(3);
    const float qw = q(0);
    const float qx = q(1);
    const float qy = q(2);
    const float qz = q(3);

    Vector4f product;
    product << pw*qw - px*qx - py*qy - pz*qz,
               pw*qx + px*qw + py*qz - pz*qy,
               pw*qy - px*qz + py*qw + pz*qx,
               pw*qz + px*qy - py*qx + pz*qw;
    return product;
}

/**
 * @brief Rotation matrix of the scalar-first quaternion quat, written into dcm.
 *
 * dcm maps body-frame vectors into the navigation frame (vector2NED uses
 * dcm*vecb); its transpose maps the other way.
 */
void solaESKF::computeDcm(Matrix3f& dcm, Vector4f quat)
{
    const float w = quat(0);
    const float x = quat(1);
    const float y = quat(2);
    const float z = quat(3);

    dcm << w*w + x*x - y*y - z*z, 2.0f*(x*y - w*z),      2.0f*(x*z + w*y),
           2.0f*(x*y + w*z),      w*w - x*x + y*y - z*z, 2.0f*(y*z - w*x),
           2.0f*(x*z - w*y),      2.0f*(y*z + w*x),      w*w - x*x - y*y + z*z;
}

/**
 * @brief Set qhat from ZYX Euler angles in radians.
 *
 * @param ex rotation about x (roll)
 * @param ey rotation about y (pitch)
 * @param ez rotation about z (yaw)
 */
void solaESKF::setQhat(float ex,float ey,float ez)
{
    const float cz = std::cos(0.5f*ez);
    const float cy = std::cos(0.5f*ey);
    const float cx = std::cos(0.5f*ex);

    const float sz = std::sin(0.5f*ez);
    const float sy = std::sin(0.5f*ey);
    const float sx = std::sin(0.5f*ex);

    qhat << cz*cy*cx + sz*sy*sx,
            cz*cy*sx - sz*sy*cx,
            cz*sy*cx + sz*cy*sx,
            sz*cy*cx - cz*sy*sx;
}

/**
 * @brief Bias-corrected accelerometer reading plus gravity expressed in the
 *        body frame: (acc - accBias) + R^T * gravity.
 */
Vector3f solaESKF::calcDynAcc(Vector3f acc)
{
    Matrix3f rot;
    computeDcm(rot, qhat);
    const Matrix3f rotT = rot.transpose();

    const Vector3f accm = acc - accBias;
    return accm+rotT*gravity;
}

/**
 * @brief Rotate a navigation-frame vector into the body frame using the
 *        transpose of the current attitude matrix.
 */
Vector3f solaESKF::vector2Body(Vector3f veci)
{
    Matrix3f bodyFromNav;
    computeDcm(bodyFromNav, qhat);
    bodyFromNav.transposeInPlace();
    return bodyFromNav*veci;
}

/**
 * @brief Rotate a body-frame vector into the navigation frame.
 */
Vector3f solaESKF::vector2NED(Vector3f vecb)
{
    Matrix3f navFromBody;
    computeDcm(navFromBody, qhat);
    const Vector3f rotated = navFromBody*vecb;
    return rotated;
}



/// @brief Overwrite the nominal gravity vector.
void solaESKF::setGravity(float gx,float gy,float gz)
{
    gravity << gx, gy, gz;
}

/// @return Copy of the nominal position.
Vector3f solaESKF::getPihat() { return pihat; }

/// @return Copy of the nominal velocity.
Vector3f solaESKF::getVihat() { return vihat; }

/// @return Copy of the nominal attitude quaternion (scalar first).
Vector4f solaESKF::getQhat() { return qhat; }

/// @return Copy of the accelerometer bias estimate.
Vector3f solaESKF::getAccBias() { return accBias; }

/// @return Copy of the gyro bias estimate.
Vector3f solaESKF::getGyroBias() { return gyroBias; }

/// @return Copy of the gravity estimate.
Vector3f solaESKF::getGravity() { return gravity; }

/// @return Copy of the current error-state vector.
VectorXf solaESKF::getErrState() { return errState; }

/**
 * @brief Pack the nominal state into one vector of length nState+1.
 *
 * Layout: [0-2] pihat, [3-5] vihat, [6-9] qhat, [10-12] accBias,
 *         [13-15] gyroBias, [16-18] gravity.
 */
VectorXf solaESKF::getState()
{
    VectorXf state = VectorXf::Zero(nState+1);
    state.segment<3>(0)  = pihat;
    state.segment<3>(3)  = vihat;
    state.segment<4>(6)  = qhat;
    state.segment<3>(10) = accBias;
    state.segment<3>(13) = gyroBias;
    state.segment<3>(16) = gravity;
    return state;
}

/// @return The first nState diagonal entries of Phat (per-state variances).
VectorXf solaESKF::getVariance()
{
    const VectorXf variances = Phat.diagonal().head(nState);
    return variances;
}

void solaESKF::setPhatPosition(float valNE,float valD)
{
    setBlockDiag(Phat, valNE, 0,2);
    Phat(2,2) = valD;
}
void solaESKF::setPhatVelocity(float valNE,float valD)
{
    setBlockDiag(Phat, valNE, 3,5);
    Phat(5,5) = valD;
}
void solaESKF::setPhatAngleError(float val)
{
    setBlockDiag(Phat, val, 6,8);
}
void solaESKF::setPhatAccBias(float val)
{
    setBlockDiag(Phat, val, 9,11);
}
void solaESKF::setPhatGyroBias(float val)
{
    setBlockDiag(Phat, val, 12,14);
}
void solaESKF::setPhatGravity(float val)
{
    setBlockDiag(Phat, val, 15,17);
}


void solaESKF::setQVelocity(float valNE,float valD)
{
    setBlockDiag(Q, valNE, 3, 5);
    Q(5,5) = valD;
}
void solaESKF::setQAngleError(float val)
{
    setBlockDiag(Q, val, 6, 8);
}
void solaESKF::setQAccBias(float val)
{
    setBlockDiag(Q, val, 9, 11);
}
void solaESKF::setQGyroBias(float val)
{
    setBlockDiag(Q, val, 12, 14);
}

/// Overwrite mat with val on the main diagonal and zeros everywhere else.
void solaESKF::setDiag(Matrix3f& mat, float val){
    mat.setZero();
    mat.diagonal().setConstant(val);
}

/**
 * @brief Overwrite mat with val on the main diagonal and zeros elsewhere.
 *
 * Works for any shape; the diagonal has min(rows, cols) entries. The
 * previous index loops used cols() as the row bound and wrote mat(i,i) for
 * every i < cols(), which indexed out of range for non-square matrices.
 */
void solaESKF::setDiag(MatrixXf& mat, float val){
    mat.setZero();
    mat.diagonal().setConstant(val);
}

/**
 * @brief Write val to mat(k,k) for every k in [startIndex, endIndex]
 *        (inclusive). Off-diagonal coefficients are left untouched; an empty
 *        range (endIndex < startIndex) does nothing.
 */
void solaESKF::setBlockDiag(MatrixXf& mat, float val, int startIndex, int endIndex){
    int k = startIndex;
    while (k <= endIndex)
    {
        mat(k, k) = val;
        ++k;
    }
}

/// @brief Overwrite the nominal position.
void solaESKF::setPihat(float pi_x, float pi_y, float pi_z)
{
    pihat << pi_x, pi_y, pi_z;
}

/**
 * @brief Skew-symmetric cross-product matrix [v]x, so [v]x * w == v x w.
 */
Matrix3f solaESKF::Matrixcross(Vector3f v)
{
    const float x = v(0);
    const float y = v(1);
    const float z = v(2);

    Matrix3f skew;
    skew << 0.0f,   -z,    y,
               z, 0.0f,   -x,
              -y,    x, 0.0f;
    return skew;
}
/*
void solaESKF::updateAccConstraints(Matrix acc,float palt,Matrix R)
{
    Matrix accm = acc - accBias;
    Matrix tdcm(3,3);
    computeDcm(tdcm, qhat);
    tdcm = MatrixMath::Transpose(tdcm);
    Matrix tdcm_g = tdcm*gravity;
    Matrix a2v = MatrixMath::Matrixcross(tdcm_g(1,1),tdcm_g(2,1),tdcm_g(3,1));
    
    Matrix H(4,nState);
    for (int i = 1; i < 4; i++){
        for (int j = 1; j < 4; j++){
            H(i,j+6) =  a2v(i,j);
            H(i,j+15) = tdcm(i,j);
        }
    }
    H(1,10) =  -1.0f;
    H(2,11) =  -1.0f;
    H(3,12) =  -1.0f;
    H(4,3) = 1.0f;
    
    //Matrix A = H*Phat*MatrixMath::Transpose(H)+R;
    //Matrix K = (Phat*MatrixMath::Transpose(H))*(MatrixMath::Inv(MatrixMath::Transpose(A)*A+10.0f*MatrixMath::Eye(3))*MatrixMath::Transpose(A));
    Matrix K = (Phat*MatrixMath::Transpose(H))*MatrixMath::Inv(H*Phat*MatrixMath::Transpose(H)+R);
    Matrix zacc = -accm-tdcm*gravity;
    Matrix z(4,1);
    z << zacc(1,1) << zacc(2,1) << zacc(3,1) << palt - pihat(3,1);
    errState =  K * z;
    Phat = (MatrixMath::Eye(nState)-K*H)*Phat;
    //Phat = Phat-K*(H*Phat*MatrixMath::Transpose(H)+R)*MatrixMath::Transpose(K);
    //Phat = (MatrixMath::Eye(nState)-K*H)*Phat*MatrixMath::Transpose(MatrixMath::Eye(nState)-K*H)+K*R*MatrixMath::Transpose(K);
    fuseErr2Nominal();
}

void solaESKF::updateGyroConstraints(Matrix gyro,Matrix R)
{
    Matrix gyrom = gyro - gyroBias;
    Matrix dcm(3,3);
    computeDcm(dcm, qhat);
    Matrix a2v = dcm*MatrixMath::Matrixcross(gyrom(1,1),gyrom(2,1),gyrom(3,1));
    
    Matrix H(2,nState);
    for (int i = 1; i < 3; i++){
        for (int j = 1; j < 4; j++){
            H(i,j+6) =  a2v(i,j);
            H(i,j+12) = -dcm(i,j);
        }
    }
    //Matrix A = H*Phat*MatrixMath::Transpose(H)+R;
    //Matrix K = (Phat*MatrixMath::Transpose(H))*(MatrixMath::Inv(MatrixMath::Transpose(A)*A+10.0f*MatrixMath::Eye(2))*MatrixMath::Transpose(A));
    Matrix K = (Phat*MatrixMath::Transpose(H))*MatrixMath::Inv(H*Phat*MatrixMath::Transpose(H)+R);
    
    Matrix z3 = -dcm*gyrom;
    Matrix z(2,1);
    z << z3(1,1) << z3(2,1);
    errState =  K * z;
    //Phat = (MatrixMath::Eye(nState)-K*H)*Phat;
    //Phat = Phat-K*(H*Phat*MatrixMath::Transpose(H)+R)*MatrixMath::Transpose(K);
    Phat = (MatrixMath::Eye(nState)-K*H)*Phat*MatrixMath::Transpose(MatrixMath::Eye(nState)-K*H)+K*R*MatrixMath::Transpose(K);
    fuseErr2Nominal();
}
void solaESKF::updateMag(Matrix mag, Matrix R)
{
    Matrix dcm(3,3);
    computeDcm(dcm, qhat);
    Matrix a2v = -dcm*MatrixMath::Matrixcross(mag(1,1),mag(2,1),mag(3,1));
    
    Matrix H(2,nState);
    for (int i = 1; i < 3; i++){
        for (int j = 1; j < 4; j++){
            H(i,j+6) =  a2v(i,j);
        }
    }
    H(1,19) = 1.0f;
    //H(3,20) = 1.0f;
    
    //Matrix A = H*Phat*MatrixMath::Transpose(H)+R;
    //Matrix K = (Phat*MatrixMath::Transpose(H))*(MatrixMath::Inv(MatrixMath::Transpose(A)*A+10.0f*MatrixMath::Eye(3))*MatrixMath::Transpose(A));
    Matrix K = (Phat*MatrixMath::Transpose(H))*MatrixMath::Inv(H*Phat*MatrixMath::Transpose(H)+R);
    Matrix zmag = -dcm*mag-magField;
    Matrix z(2,1);
    z << zmag(1,1) << zmag(2,1);
    errState =  K * z;
    Phat = (MatrixMath::Eye(nState)-K*H)*Phat;
    //Phat = Phat-K*(H*Phat*MatrixMath::Transpose(H)+R)*MatrixMath::Transpose(K);
    //Phat = (MatrixMath::Eye(nState)-K*H)*Phat*MatrixMath::Transpose(MatrixMath::Eye(nState)-K*H)+K*R*MatrixMath::Transpose(K);
    fuseErr2Nominal();
}
void solaESKF::updateMag(Matrix mag,float palt, Matrix R)
{
    Matrix dcm(3,3);
    computeDcm(dcm, qhat);
    Matrix a2v = -dcm*MatrixMath::Matrixcross(mag(1,1),mag(2,1),mag(3,1));
    
    Matrix H(3,nState);
    for (int i = 1; i < 3; i++){
        for (int j = 1; j < 4; j++){
            H(i,j+6) =  a2v(i,j);
        }
    }
    H(1,19) = 1.0f;
    H(3,3)  = 1.0f;
    //Matrix A = H*Phat*MatrixMath::Transpose(H)+R;
    //Matrix K = (Phat*MatrixMath::Transpose(H))*(MatrixMath::Inv(MatrixMath::Transpose(A)*A+10.0f*MatrixMath::Eye(3))*MatrixMath::Transpose(A));
    Matrix K = (Phat*MatrixMath::Transpose(H))*MatrixMath::Inv(H*Phat*MatrixMath::Transpose(H)+R);
    Matrix zmag = -dcm*mag-magField;
    Matrix z(3,1);
    z << zmag(1,1) << zmag(2,1) <<  palt - pihat(3,1);
    errState =  K * z;
    Phat = (MatrixMath::Eye(nState)-K*H)*Phat;
    //Phat = Phat-K*(H*Phat*MatrixMath::Transpose(H)+R)*MatrixMath::Transpose(K);
    //Phat = (MatrixMath::Eye(nState)-K*H)*Phat*MatrixMath::Transpose(MatrixMath::Eye(nState)-K*H)+K*R*MatrixMath::Transpose(K);
    
    fuseErr2Nominal();
}
void solaESKF::updateGPSVelocity(Matrix velgps,Matrix mag,Matrix R)
{
    
    Matrix dcm(3,3);
    computeDcm(dcm, qhat);
    Matrix a2v = -dcm*MatrixMath::Matrixcross(mag(1,1),mag(2,1),mag(3,1));
    
    
    Matrix H(3,nState);
    H(1,4)  = 1.0f;
    H(2,5)  = 1.0f;
    
    for (int j = 1; j < 4; j++){
        H(3,j+6) =  a2v(2,j);
    }
    //H(3,19) = 1.0f;
    
    Matrix zmag = -dcm*mag-magField;
    
    //Matrix A = H*Phat*MatrixMath::Transpose(H)+R;
    //Matrix K = (Phat*MatrixMath::Transpose(H))*(MatrixMath::Inv(MatrixMath::Transpose(A)*A+10.0f*MatrixMath::Eye(3))*MatrixMath::Transpose(A));
    Matrix K = (Phat*MatrixMath::Transpose(H))*MatrixMath::Inv(H*Phat*MatrixMath::Transpose(H)+R);
    Matrix z(3,1);
    z << velgps(1,1) - vihat(1,1) << velgps(2,1)-vihat(2,1) << zmag(2,1);
    errState =  K * z;
    Phat = (MatrixMath::Eye(nState)-K*H)*Phat;
    //Phat = Phat-K*(H*Phat*MatrixMath::Transpose(H)+R)*MatrixMath::Transpose(K);
    //Phat = (MatrixMath::Eye(nState)-K*H)*Phat*MatrixMath::Transpose(MatrixMath::Eye(nState)-K*H)+K*R*MatrixMath::Transpose(K);
    fuseErr2Nominal();
}

void solaESKF::updateGPSPosition(Matrix posgps,float palt,Matrix R)
{
    Matrix H(3,nState);
    H(1,1)  = 1.0f;
    H(2,2)  = 1.0f;
    H(3,3)  = 1.0f;
    
    //Matrix A = H*Phat*MatrixMath::Transpose(H)+R;
    //Matrix K = (Phat*MatrixMath::Transpose(H))*(MatrixMath::Inv(MatrixMath::Transpose(A)*A+1000.0f*MatrixMath::Eye(3))*MatrixMath::Transpose(A));
    Matrix K = (Phat*MatrixMath::Transpose(H))*MatrixMath::Inv(H*Phat*MatrixMath::Transpose(H)+R);
    Matrix z(3,1);
    z << posgps(1,1) - pihat(1,1) << posgps(2,1)-pihat(2,1) << palt - pihat(3,1);
    errState =  K * z;
    Phat = (MatrixMath::Eye(nState)-K*H)*Phat;
    //Phat = Phat-K*(H*Phat*MatrixMath::Transpose(H)+R)*MatrixMath::Transpose(K);
    //Phat = (MatrixMath::Eye(nState)-K*H)*Phat*MatrixMath::Transpose(MatrixMath::Eye(nState)-K*H)+K*R*MatrixMath::Transpose(K);
    fuseErr2Nominal();
}

void solaESKF::updateImuConstraints(Matrix acc,Matrix mag,Matrix R)
{
    Matrix accm = acc - accBias;
    Matrix magm = mag - magBias;
    Matrix dcm(3,3);
    computeDcm(dcm, qhat);
    Matrix tdcm = MatrixMath::Transpose(dcm);
    Matrix tdcm_g = tdcm*gravity;
    Matrix rotgrav = MatrixMath::Matrixcross(tdcm_g(1,1),tdcm_g(2,1),tdcm_g(3,1));
    Matrix rotmag = -dcm*MatrixMath::Matrixcross(magm(1,1),magm(2,1),magm(3,1));
    
    Matrix H(5,nState);
    for (int i = 1; i < 4; i++){
        for (int j = 1; j < 4; j++){
            H(i,j+6) =  rotgrav(i,j);
        }
        H(i,16) = tdcm(i,3);
    }

    H(1,10) =  -1.0f;
    H(2,11) =  -1.0f;
    H(3,12) =  -1.0f;
    
    Matrix magned = dcm*magm;
    float hx = sqrt(magned(1,1)*magned(1,1)+magned(2,1)*magned(2,1));
    
    for(int j = 1; j < 4; j++){
        H(4,j+6) =  rotmag(1,j)-(rotmag(1,j)+rotmag(2,j))/hx;
        H(4,j+16) =  -dcm(1,j)+(dcm(1,j)+dcm(2,j))/hx;
        H(5,j+6) =  rotmag(2,j);
        H(5,j+16) =  -dcm(2,j);
    }
    
    Matrix K = (Phat*MatrixMath::Transpose(H))*MatrixMath::Inv(H*Phat*MatrixMath::Transpose(H)+R);
    Matrix zacc = -accm-tdcm*gravity;
    Matrix zmag = dcm*magm;
    Matrix z(5,1);
    z << zacc(1,1) << zacc(2,1) << zacc(3,1) << -(zmag(1,1) - hx) << -zmag(2,1);
    twelite.printf("%f %f\r\n",hx,(zmag(1,1) - hx));
    errState =  K * z;
    Phat = (MatrixMath::Eye(nState)-K*H)*Phat;
    
    fuseErr2Nominal();
}
    float q0 = qhat(1,1);
    float q1 = qhat(2,1);
    float q2 = qhat(3,1);
    float q3 = qhat(4,1);
    
    float d0 = (-q3*q3-q2*q2+q1*q1+q0*q0);
    float q0q3q1q2 = (q0*q3+q1*q2);
    float h1lower = 2.0f*(4.0f*q0q3q1q2*q0q3q1q2/d0/d0+1.0f)*sqrt(4.0f*q0q3q1q2*q0q3q1q2/d0/d0+1.0f);
    
    float d1 = d0*sqrt(4.0f*q0q3q1q2*q0q3q1q2/d0/d0+1.0f);
    float d2 = d0*d0*sqrt(4.0f*q0q3q1q2*q0q3q1q2/d0/d0+1.0f);
    float d3 = d0*sqrt(4.0f*q0q3q1q2*q0q3q1q2/d0/d0+1.0f)*(4.0f*q0q3q1q2*q0q3q1q2/d0/d0+1.0f);
    
    
    
    Matrix Hh(2,4);
    Hh(1,1) = -(8.0f*q3*q0q3q1q2/d0/d0-16.0f*q0*q0q3q1q2*q0q3q1q2/d0/d0/d0)/h1lower;
    Hh(1,2) = -(8.0f*q2*q0q3q1q2/d0/d0-16.0f*q1*q0q3q1q2*q0q3q1q2/d0/d0/d0)/h1lower;
    Hh(1,3) = -(8.0f*q1*q0q3q1q2/d0/d0-16.0f*q2*q0q3q1q2*q0q3q1q2/d0/d0/d0)/h1lower;
    Hh(1,4) = -(8.0f*q0*q0q3q1q2/d0/d0-16.0f*q3*q0q3q1q2*q0q3q1q2/d0/d0/d0)/h1lower;
    
    Hh(2,1) = 2.0f*q3/d1-4.0f*q0*q0q3q1q2/d2-q0q3q1q2*(8.0f*q3*q0q3q1q2/d0/d0-16.0f*q0*q0q3q1q2*q0q3q1q2/d0/d0/d0)/d3;
    Hh(2,2) = 2.0f*q2/d1-4.0f*q1*q0q3q1q2/d2-q0q3q1q2*(8.0f*q2*q0q3q1q2/d0/d0-16.0f*q1*q0q3q1q2*q0q3q1q2/d0/d0/d0)/d3;
    Hh(2,3) = 2.0f*q1/d1+4.0f*q2*q0q3q1q2/d2-q0q3q1q2*(8.0f*q1*q0q3q1q2/d0/d0-16.0f*q2*q0q3q1q2*q0q3q1q2/d0/d0/d0)/d3;
    Hh(2,4) = 2.0f*q0/d1+4.0f*q3*q0q3q1q2/d2-q0q3q1q2*(8.0f*q0*q0q3q1q2/d0/d0-16.0f*q3*q0q3q1q2*q0q3q1q2/d0/d0/d0)/d3;
*/