#include "IMUCalc.h"

// Sets up the initial orientation estimate: heading along +x, "top" along +z,
// zero gyro offset, and an absolute-reference blend gain of 0.01.
// initialRun is set so the first runCalc() call replaces heading/top directly
// with the sensor-derived references instead of blending into them.
IMUCalc::IMUCalc(void)
{
    heading = Vector3(1, 0, 0);
    top = Vector3(0, 0, 1);
    gyroOffset = Vector3(0, 0, 0);

    absGain = 0.01;
    initialRun = true;
}


void IMUCalc::runCalc(float *accdat, float *gyrodat, float *magdat, float timestep)
{


    //Variables
    Vector3 accdata;
    Vector3 gyrodata;
    Vector3 magdata;

    Vector3 heading_abs;
    Vector3 top_abs;

    //Change data to vector3
    for (int i = 0; i<3; i++) {
        accdata.af[i]=accdat[i];
        gyrodata.af[i]=gyrodat[i];
        magdata.af[i]=magdat[i];
    }

    gyrodata = gyrodata-gyroOffset; 

    heading = rotateVector(heading, gyrodata, -gyrodata.Length()*timestep);
    top = rotateVector(top, gyrodata, -gyrodata.Length()*timestep);



    

    //Rotate the magnetic data to be in plain with the earth
    heading_abs  = -1 * rotateMag(magdata, accdata);
    top_abs = -1 * accdata;

    heading_abs = heading_abs.Normalize();
    top_abs = top_abs.Normalize();

    //Calculate offset
    Vector3 currentGyroOffset, weightTop, weightHeading, tempVector;



    //make tempvector in X direction, do crossproduct, calculate length of result
    tempVector.x = 1;
    tempVector.y=0;
    tempVector.z=0;
    weightTop.x = top.CrossP(tempVector).Length();
    weightHeading.x = heading.CrossP(tempVector).Length();

    tempVector.x = 0;
    tempVector.y=1;
    tempVector.z=0;
    weightTop.y = top.CrossP(tempVector).Length();
    weightHeading.y = heading.CrossP(tempVector).Length();

    tempVector.x = 0;
    tempVector.y=0;
    tempVector.z=1;
    weightTop.z = top.CrossP(tempVector).Length();
    weightHeading.z = heading.CrossP(tempVector).Length();


    //Use weightfactors, then divide by their sum
    currentGyroOffset = weightTop * angleBetween(top_abs, top) + weightHeading * angleBetween(heading_abs, heading);
    currentGyroOffset = currentGyroOffset / (weightTop + weightHeading);


    if (currentGyroOffset.x > timestep * 0.1)
        currentGyroOffset.x = timestep * 0.1;
    if (currentGyroOffset.x < -timestep * 0.1)
        currentGyroOffset.x = -timestep * 0.1;

    if (currentGyroOffset.y > timestep * 0.1)
        currentGyroOffset.y = timestep * 0.1;
    if (currentGyroOffset.y < -timestep * 0.1)
        currentGyroOffset.y = -timestep * 0.1;

    if (currentGyroOffset.z > timestep * 0.1)
        currentGyroOffset.z = timestep * 0.1;
    if (currentGyroOffset.z < -timestep * 0.1)
        currentGyroOffset.z = -timestep * 0.1;

    gyroOffset -= 0.01 * currentGyroOffset;


    //Take average value of heading/heading_abs with different gains to get current estimate
    if (initialRun) {
        heading = heading_abs;
        top = top_abs;
        gyroOffset *= 0;
        initialRun=false;
    } else {
        heading = heading*(1-absGain) + heading_abs*absGain;
        top = top * (1-absGain) + top_abs * absGain;
    }
}

//Calculates the yaw: the angle (radians, in [-pi, pi]) of the heading
//vector's projection onto the x/y plane, measured from +x toward +y.
//Returns 0 when that projection has zero length.
float IMUCalc::getYaw( void )
{
    Vector2 yawVector(heading.x, heading.y);

    if (yawVector.Length() > 0) {
        // atan2 resolves the quadrant itself and needs no normalization.
        // The previous asin()-of-normalized-component approach lost precision
        // near +-pi/2, and returned NaN if normalization rounded past 1.
        return atan2(yawVector.y, yawVector.x);
    }
    return 0;
}

//Calculates the pitch: the angle (radians, in [-pi/2, pi/2]) of the top
//vector's x/z projection, measured from the z axis toward +x. If the top
//vector points to negative z it is mirrored first, so the result always
//describes an upward-pointing top. Returns 0 when the projection has zero length.
float IMUCalc::getPitch( void )
{
    Vector2 pitchVector(top.x, top.z);

    if (pitchVector.Length() > 0) {
        //if the top is at the bottom, invert the vector
        if (pitchVector.y < 0)
            pitchVector = -pitchVector;

        // With the second component >= 0, atan2 stays in [-pi/2, pi/2] and
        // equals asin of the normalized x component. It avoids asin's
        // precision loss (and NaN risk) near +-pi/2.
        return atan2(pitchVector.x, pitchVector.y);
    }
    return 0;
}

//Calculates the roll: the angle (radians, in [-pi, pi]) of the top vector's
//y/z projection, measured from +z toward +y.
//Returns 0 when that projection has zero length.
float IMUCalc::getRoll( void )
{
    Vector2 rollVector(top.y, top.z);

    if (rollVector.Length() > 0) {
        // atan2 handles all four quadrants directly; no normalization needed,
        // and no asin precision loss near +-pi/2.
        return atan2(rollVector.x, rollVector.y);
    }
    return 0;
}

// Returns (by value) the current gyro offset estimate, i.e. the vector that
// runCalc() subtracts from each raw gyro reading.
Vector3 IMUCalc::getGyroOffset( void )
{
    return this->gyroOffset;
}

//Rotates the magnetic vector until it is perpendicular (90 degrees, 0.5 pi) to
//the ground vector. The rotation happens about the axis returned by
//angleBetween(magdat, ground), so the vector stays in the plane it spans with
//ground and keeps its length. If magdat and ground are (nearly) parallel or
//either is zero, angleBetween yields a zero axis and magdat is returned unchanged.
Vector3 IMUCalc::rotateMag(Vector3 magdat, Vector3 ground)
{
    //Rotation vector (axis * angle) between the magnetic and ground vectors
    Vector3 rotVector = angleBetween(magdat, ground);
    const float currentAngle = rotVector.Length();

    //How far the current angle is from 90 degrees
    const float angle = 0.5 * M_PI - currentAngle;

    //Rotate by that remaining amount about the same axis
    return rotateVector(magdat, rotVector, -angle);
}


// Vector calculations not included in GTMath

// Returns the rotation between vectorA and vectorB as a rotation vector: a unit
// axis along vectorA x vectorB (by the right-hand rule, the sense that turns
// vectorA toward vectorB), scaled by the angle from vectorA.Angle(vectorB).
// Returns the zero vector when either input has zero length or the angle is
// below 0.0001. For (nearly) opposite vectors the axis is an arbitrary
// direction perpendicular to vectorA.
Vector3 IMUCalc::angleBetween(Vector3 vectorA, Vector3 vectorB)
{
    const float lengthA = vectorA.Length();
    const float lengthB = vectorB.Length();

    float angle = 0;
    if ((lengthA != 0) && (lengthB != 0))
        angle = vectorA.Angle(vectorB);

    // If no noticeable rotation is available, return zero rotation; this
    // avoids cross-product artifacts. fabs (not abs) so the float is never
    // routed through the integer abs(int) overload, which would truncate it.
    if (fabs(angle) < 0.0001)
        return Vector3(0, 0, 0);

    Vector3 axis;
    if (fabs(angle - M_PI) < 0.0001) {
        // Opposite directions: every axis perpendicular to vectorA works, so
        // pick one by crossing with vectorB's components cyclically permuted.
        Vector3 permuted(vectorB.z, vectorB.x, vectorB.y);
        axis = vectorA.CrossP(permuted);

        // When vectorB lies (almost) along the x == y == z diagonal the
        // permutation leaves it unchanged and the cross product collapses.
        // Fall back to crossing with the coordinate axis least aligned with
        // vectorA, which can never be parallel to it.
        if (axis.Length() < 0.001 * lengthA * lengthB) {
            const float ax = fabs(vectorA.x);
            const float ay = fabs(vectorA.y);
            const float az = fabs(vectorA.z);
            Vector3 basis(0, 0, 0);
            if ((ax <= ay) && (ax <= az))
                basis.x = 1;
            else if (ay <= az)
                basis.y = 1;
            else
                basis.z = 1;
            axis = vectorA.CrossP(basis);
        }
    } else {
        axis = vectorA.CrossP(vectorB);
    }

    axis = axis.Normalize();
    axis *= angle;

    return axis;
}


// Rotates `vector` by `angle` about `axis` via Matrix3x3::RotateAxis.
// If the axis length is 0.0001 or less (e.g. a zero rotation vector), the
// input vector is returned unchanged.
Vector3 IMUCalc::rotateVector(Vector3 vector, Vector3 axis, float angle)
{
    const bool usableAxis = axis.Length() > 0.0001;
    if (!usableAxis)
        return vector;

    return Matrix3x3::RotateAxis(axis, angle).Transform(vector);
}
