#include "mbed.h"
#include "AttitudeEstimator.h"

// Constructor: forwards A4 and A5 to the imu member's constructor.
// All other state is set up later by init().
AttitudeEstimator::AttitudeEstimator()
    : imu(A4, A5)
{
}

// Prepare the estimator for use: start the IMU and reset every matrix and
// time-step member to its starting value.
void AttitudeEstimator::init()
{
    // Bring the sensor up first.
    imu.init();

    // Time step used by estimate() and its precomputed half
    // (presumably seconds -- confirm against the caller's loop period).
    dt = 0.005;
    dt_half = dt/2.0;

    // Raw sensor vectors (3x1), filled in by read().
    m = zeros(3,1);
    a = zeros(3,1);
    g = zeros(3,1);

    // Intermediate results produced by get_output().
    BN = eye(3);
    q = zeros(4,1);

    // 4x4 filter matrices used by estimate().
    I = eye(4);
    H = eye(4);
    R = 10.0*eye(4);
    Q = 0.001*eye(4);
    P = eye(4);
    A = eye(4);   // estimate() only rewrites the off-diagonal entries

    // 4x1 filter state (eye(4,1) presumably yields [1 0 0 0]^T -- confirm
    // against the Matrix library).
    x = eye(4,1);
}

// Run one filter step on the 4x1 state x.
//
//  1. read() refreshes g, a and m; get_output() derives the measurement q.
//  2. Prediction: x <- A*x (then renormalized), P <- A*P*A' + Q, where the
//     off-diagonal entries of A are built from the gyro vector g scaled by
//     dt_half.
//  3. Correction: K = P*inverse(P+R), x <- x + K*(q - x) (then renormalized),
//     P <- P - K*P.  No H appears in these expressions; the measurement q is
//     compared with x directly.
void AttitudeEstimator::estimate()
{
    // Refresh sensor data and the measurement q.
    read();
    get_output();

    // Gyro components scaled by half the time step.
    const float half_wx = g(1,1)*dt_half;
    const float half_wy = g(2,1)*dt_half;
    const float half_wz = g(3,1)*dt_half;

    // Off-diagonal entries of A, grouped by gyro axis. The diagonal is not
    // touched here; init() sets A = eye(4).
    A(2,1) = half_wx;
    A(1,2) = -half_wx;
    A(3,4) = half_wx;
    A(4,3) = -half_wx;

    A(3,1) = half_wy;
    A(1,3) = -half_wy;
    A(4,2) = half_wy;
    A(2,4) = -half_wy;

    A(4,1) = half_wz;
    A(1,4) = -half_wz;
    A(2,3) = half_wz;
    A(3,2) = -half_wz;

    // Prediction
    x = A*x;
    x = x/norm(x);
    P = A*P*transpose(A)+Q;

    // Gain
    K = P*inverse(P+R);

    // Correction toward the measurement q
    x = x+K*(q-x);
    x = x/norm(x);
    P = P-K*P;
}

// Sample the IMU and store corrected readings in g, a and m.
// Each gyro component has a constant offset applied; each accelerometer and
// magnetometer component has an offset applied and is then multiplied by a
// per-axis scale factor. The constants are presumably calibration results for
// one specific sensor unit -- confirm before reusing on other hardware.
void AttitudeEstimator::read()
{
    // Gyro offsets (added to the raw reading).
    const double gyro_offset_x = -0.0099;
    const double gyro_offset_y = 0.0693;
    const double gyro_offset_z = -0.0339;

    // Accelerometer offsets (subtracted) and scale factors.
    const double accel_bias_x = 0.0975;
    const double accel_bias_y = 0.0600;
    const double accel_bias_z = 0.5156;
    const double accel_scale_x = 1.0077;
    const double accel_scale_y = 1.0061;
    const double accel_scale_z = 0.9926;

    // Magnetometer offsets (added to the raw reading) and scale factors.
    const double mag_offset_x = 21.0631;
    const double mag_offset_y = 8.9233;
    const double mag_offset_z = -11.8958;
    const double mag_scale_x = 0.8838;
    const double mag_scale_y = 1.1537;
    const double mag_scale_z = 0.9982;

    imu.read();

    g(1,1) = imu.gx+gyro_offset_x;
    g(2,1) = imu.gy+gyro_offset_y;
    g(3,1) = imu.gz+gyro_offset_z;

    a(1,1) = accel_scale_x*(imu.ax-accel_bias_x);
    a(2,1) = accel_scale_y*(imu.ay-accel_bias_y);
    a(3,1) = accel_scale_z*(imu.az-accel_bias_z);

    m(1,1) = mag_scale_x*(imu.mx+mag_offset_x);
    m(2,1) = mag_scale_y*(imu.my+mag_offset_y);
    m(3,1) = mag_scale_z*(imu.mz+mag_offset_z);
}

// Build the measurement q (4x1) from the accelerometer vector a and the
// magnetometer vector m.
//
// Steps:
//  1. Normalize a and m in place.
//  2. Form the triad t1 = a, t2 = (a x m)/|a x m|, t3 = t1 x t2.
//  3. Fill the columns of BN with -t3, -t2, -t1 and convert it with
//     dcm2quat().
//  4. Sign alignment: pick the component of the state x with the largest
//     magnitude (ties resolve to the higher index) and negate q if its
//     matching component has the opposite sign. The premise is that q and -q
//     describe the same rotation; the flip keeps q on the same side as x so
//     that the (q - x) term in estimate() stays small.
//
// Degenerate input: if |a|, |m| or |a x m| is zero or NaN, the function
// returns early and q keeps its previous value, rather than dividing by zero
// and feeding NaN into the filter state.
void AttitudeEstimator::get_output()
{
    // Reject zero-length or NaN vectors before dividing by their norm.
    // !(n > 0) is true for both 0 and NaN.
    const float accel_norm = norm(a);
    const float mag_norm = norm(m);
    if (!(accel_norm > 0.0f) || !(mag_norm > 0.0f)) {
        return;
    }
    a = a/accel_norm;
    m = m/mag_norm;

    // Triad built from the two normalized vectors.
    Matrix t1(3,1), t2(3,1), t3(3,1);
    t1 = a;
    t2 = cross(a,m);
    const float t2_norm = norm(t2);
    if (!(t2_norm > 0.0f)) {
        // a and m are parallel: no usable second axis.
        return;
    }
    t2 = t2/t2_norm;
    t3 = cross(t1,t2);

    // Columns of BN are -t3, -t2, -t1.
    for (int row = 1; row <= 3; ++row) {
        BN(row,1) = -t3(row,1);
        BN(row,2) = -t2(row,1);
        BN(row,3) = -t1(row,1);
    }

    q = dcm2quat(BN);

    // Find the largest-magnitude component of x. The magnitude is computed
    // explicitly instead of calling unqualified abs(), which can resolve to
    // int abs(int) depending on the headers in scope and would then truncate
    // every component in (-1, 1) to zero.
    int dominant = 1;
    float dominant_mag = -1.0f;
    for (int i = 1; i <= 4; ++i) {
        const float value = x(i,1);
        const float mag = (value < 0.0f) ? -value : value;
        if (mag >= dominant_mag) {   // >= : ties go to the higher index
            dominant = i;
            dominant_mag = mag;
        }
    }

    // Flip q when it disagrees in sign with x on that component.
    const float x_dom = x(dominant,1);
    const float q_dom = q(dominant,1);
    if (((x_dom > 0) && (q_dom < 0)) || ((x_dom < 0) && (q_dom > 0))) {
        q = -q;
    }
}
