#include "SensorFusion.h"

#define DEBUG "SensorFusion"
#include "Logger.h"

#include "Utils.h"

/**
 * Builds the 6-axis (accelerometer + gyroscope) fusion object.
 *
 * The accelerometer and gyroscope members are both constructed from the same
 * I2C bus reference, and this object registers itself as the gyroscope's
 * delegate.
 *
 * @param i2c  Bus handle forwarded to the accel and gyro members.
 */
SensorFusion6::SensorFusion6(I2C &i2c) :
    SensorFusion(),
    accel(i2c),
    gyro(i2c),
    deltat(0.010), // integration time step used by updateFilter(), in seconds
    beta(50),      // gain applied to the normalised corrective step in updateFilter()
    lowpassX(0.96), // one low-pass filter per accelerometer axis, all built with the same argument
    lowpassY(0.96),
    lowpassZ(0.96)
{
    // Register with the gyro driver; presumably this is what triggers
    // sensorUpdate() on new gyro data -- confirm against the gyro driver.
    gyro.setDelegate(*this);
}

/**
 * Resets the accelerometer low-pass filters and brings up both sensors.
 *
 * The accelerometer is powered on and started first, the gyroscope last.
 *
 * @return Always true; results of the individual driver calls are not inspected.
 */
bool SensorFusion6::start()
{
    // Discard any filter state left over from a previous run.
    lowpassX.reset();
    lowpassY.reset();
    lowpassZ.reset();

    accel.powerOn();
    accel.start();

    // Since everything is synced to gyro interrupt, start it last
    gyro.powerOn();
    gyro.start();

    return true;
}

/**
 * Shuts both sensors down in the reverse order of start(): the gyroscope is
 * stopped and powered off first, then the accelerometer.
 */
void SensorFusion6::stop()
{
    gyro.stop();
    gyro.powerOff();

    accel.stop();
    accel.powerOff();
}

// Multiplier converting an angle in degrees to radians (pi / 180).
static float const deg_to_radian =  0.0174532925f;

/**
 * Handles one gyroscope sample: reads the accelerometer, low-pass filters it,
 * advances the orientation quaternion and reports the result to the delegate.
 *
 * @param gyro_degrees  Gyroscope sample in degree-based units; it is scaled
 *                      by deg_to_radian before being fed to the filter and
 *                      passed to the delegate unscaled.
 */
void SensorFusion6::sensorUpdate(Vector3 gyro_degrees)
{
    // The filter works in radians.
    Vector3 const gyro_rad = gyro_degrees * deg_to_radian;

    // Sample the accelerometer and smooth each axis with its own filter.
    Vector3 const raw_accel = accel.read();
    Vector3 const smooth_accel(lowpassX.filter(raw_accel.x),
                               lowpassY.filter(raw_accel.y),
                               lowpassZ.filter(raw_accel.z));

    // Advance q by one step of deltat.
    updateFilter(smooth_accel.x, smooth_accel.y, smooth_accel.z,
                 gyro_rad.x,     gyro_rad.y,     gyro_rad.z);

    // Publish: time step, Euler angles, filtered accel, raw accel, raw gyro, quaternion.
    delegate->sensorTick(deltat,
                         q.getEulerAngles(),
                         smooth_accel,
                         raw_accel,
                         gyro_degrees,
                         q);
}

void SensorFusion6::updateFilter(float ax, float ay, float az, float gx, float gy, float gz)
{
    float q0 = q.w, q1 = q.v.x, q2 = q.v.y, q3 = q.v.z;   // short name local variable for readability

    float recipNorm;
    float s0, s1, s2, s3;
    float qDot1, qDot2, qDot3, qDot4;
    float _2q0, _2q1, _2q2, _2q3, _4q0, _4q1, _4q2 ,_8q1, _8q2, q0q0, q1q1, q2q2, q3q3;

    // Rate of change of quaternion from gyroscope
    qDot1 = 0.5 * (-q1 * gx - q2 * gy - q3 * gz);
    qDot2 = 0.5 * (q0 * gx + q2 * gz - q3 * gy);
    qDot3 = 0.5 * (q0 * gy - q1 * gz + q3 * gx);
    qDot4 = 0.5 * (q0 * gz + q1 * gy - q2 * gx);

    // Compute feedback only if accelerometer measurement valid (avoids NaN in accelerometer normalisation)
    if(!((ax == 0.0) && (ay == 0.0) && (az == 0.0))) {

        // Normalise accelerometer measurement
        recipNorm = 1.0 / sqrt(ax * ax + ay * ay + az * az);
        ax *= recipNorm;
        ay *= recipNorm;
        az *= recipNorm;

        // Auxiliary variables to avoid repeated arithmetic
        _2q0 = 2.0 * q0;
        _2q1 = 2.0 * q1;
        _2q2 = 2.0 * q2;
        _2q3 = 2.0 * q3;
        _4q0 = 4.0 * q0;
        _4q1 = 4.0 * q1;
        _4q2 = 4.0 * q2;
        _8q1 = 8.0 * q1;
        _8q2 = 8.0 * q2;
        q0q0 = q0 * q0;
        q1q1 = q1 * q1;
        q2q2 = q2 * q2;
        q3q3 = q3 * q3;

        // Gradient decent algorithm corrective step
        s0 = _4q0 * q2q2 + _2q2 * ax + _4q0 * q1q1 - _2q1 * ay;
        s1 = _4q1 * q3q3 - _2q3 * ax + 4.0 * q0q0 * q1 - _2q0 * ay - _4q1 + _8q1 * q1q1 + _8q1 * q2q2 + _4q1 * az;
        s2 = 4.0 * q0q0 * q2 + _2q0 * ax + _4q2 * q3q3 - _2q3 * ay - _4q2 + _8q2 * q1q1 + _8q2 * q2q2 + _4q2 * az;
        s3 = 4.0 * q1q1 * q3 - _2q1 * ax + 4.0 * q2q2 * q3 - _2q2 * ay;
        recipNorm = 1.0 / sqrt(s0 * s0 + s1 * s1 + s2 * s2 + s3 * s3); // normalise step magnitude
        s0 *= recipNorm;
        s1 *= recipNorm;
        s2 *= recipNorm;
        s3 *= recipNorm;

        // Apply feedback step
        qDot1 -= beta * s0;
        qDot2 -= beta * s1;
        qDot3 -= beta * s2;
        qDot4 -= beta * s3;
    }

    // Integrate rate of change of quaternion to yield quaternion
    q0 += qDot1 * deltat;
    q1 += qDot2 * deltat;
    q2 += qDot3 * deltat;
    q3 += qDot4 * deltat;

    // Normalise quaternion
    recipNorm = 1.0 / sqrt(q0 * q0 + q1 * q1 + q2 * q2 + q3 * q3);
    q0 *= recipNorm;
    q1 *= recipNorm;
    q2 *= recipNorm;
    q3 *= recipNorm;

    // return
    q.w = q0;
    q.v.x = q1;
    q.v.y = q2;
    q.v.z = q3;
}

/**
 * Builds the 9-axis fusion object: everything SensorFusion6 sets up, plus a
 * magnetometer constructed from the same I2C bus reference.
 *
 * @param i2c  Bus handle forwarded to SensorFusion6 and to the magnetometer member.
 */
SensorFusion9::SensorFusion9(I2C &i2c) : SensorFusion6(i2c), magneto(i2c)
{
    // Register again with the gyro, now from the derived constructor.
    // NOTE(review): the base constructor already makes this call; whether the
    // repeat is required depends on setDelegate()'s parameter type -- confirm.
    gyro.setDelegate(*this);
}

/**
 * Powers on and starts the magnetometer, then runs the 6-axis start sequence
 * (which brings the gyroscope up last).
 *
 * @return Whatever SensorFusion6::start() returns.
 */
bool SensorFusion9::start()
{
    magneto.powerOn();
    magneto.start();

    return SensorFusion6::start();
}

/**
 * Runs the 6-axis stop sequence first (gyroscope, then accelerometer), then
 * stops and powers off the magnetometer.
 */
void SensorFusion9::stop()
{
    SensorFusion6::stop();
    magneto.stop();
    magneto.powerOff();
}

/**
 * Handles one gyroscope sample for the 9-axis filter: reads accelerometer and
 * magnetometer, low-pass filters the accelerometer, advances the orientation
 * quaternion and reports the result to the delegate.
 *
 * @param gyro_degrees  Gyroscope sample in degree-based units; it is scaled
 *                      by deg_to_radian before being fed to the filter and
 *                      passed to the delegate unscaled.
 */
void SensorFusion9::sensorUpdate(Vector3 gyro_degrees)
{
    // The filter works in radians.
    Vector3 const gyro_rad = gyro_degrees * deg_to_radian;

    // Sample the other two sensors: accelerometer first, then magnetometer.
    Vector3 const raw_accel = accel.read();
    Vector3 const mag = magneto.read();

    // Only the accelerometer is smoothed; each axis has its own filter.
    Vector3 const smooth_accel(lowpassX.filter(raw_accel.x),
                               lowpassY.filter(raw_accel.y),
                               lowpassZ.filter(raw_accel.z));

    // Advance q by one step of deltat.
    updateFilter(smooth_accel.x, smooth_accel.y, smooth_accel.z,
                 gyro_rad.x,     gyro_rad.y,     gyro_rad.z,
                 mag.x,          mag.y,          mag.z);

    // Publish: time step, Euler angles, filtered accel, magnetometer, raw gyro, quaternion.
    delegate->sensorTick(deltat,
                         q.getEulerAngles(),
                         smooth_accel,
                         mag,
                         gyro_degrees,
                         q);
}

void SensorFusion9::updateFilter(float ax, float ay, float az, float gx, float gy, float gz, float mx, float my, float mz)
{
    float q1 = q.w, q2 = q.v.x, q3 = q.v.y, q4 = q.v.z;   // short name local variable for readability
    float norm;
    float s1, s2, s3, s4;

    // Auxiliary variables to avoid repeated arithmetic
    const float _2q1 = 2.0f * q1;
    const float _2q2 = 2.0f * q2;
    const float _2q3 = 2.0f * q3;
    const float _2q4 = 2.0f * q4;
    const float _2q1q3 = 2.0f * q1 * q3;
    const float _2q3q4 = 2.0f * q3 * q4;
    const float q1q1 = q1 * q1;
    const float q1q2 = q1 * q2;
    const float q1q3 = q1 * q3;
    const float q1q4 = q1 * q4;
    const float q2q2 = q2 * q2;
    const float q2q3 = q2 * q3;
    const float q2q4 = q2 * q4;
    const float q3q3 = q3 * q3;
    const float q3q4 = q3 * q4;
    const float q4q4 = q4 * q4;

    // Normalise accelerometer measurement
    norm = sqrt(ax * ax + ay * ay + az * az);
    if (norm == 0.0f) return; // handle NaN
    norm = 1.0f/norm;
    ax *= norm;
    ay *= norm;
    az *= norm;

    // Normalise magnetometer measurement
    norm = sqrt(mx * mx + my * my + mz * mz);
    if (norm == 0.0f) return; // handle NaN
    norm = 1.0f/norm;
    mx *= norm;
    my *= norm;
    mz *= norm;

    // Reference direction of Earth's magnetic field
    const float _2q1mx = 2.0f * q1 * mx;
    const float _2q1my = 2.0f * q1 * my;
    const float _2q1mz = 2.0f * q1 * mz;
    const float _2q2mx = 2.0f * q2 * mx;
    const float hx = mx * q1q1 - _2q1my * q4 + _2q1mz * q3 + mx * q2q2 + _2q2 * my * q3 + _2q2 * mz * q4 - mx * q3q3 - mx * q4q4;
    const float hy = _2q1mx * q4 + my * q1q1 - _2q1mz * q2 + _2q2mx * q3 - my * q2q2 + my * q3q3 + _2q3 * mz * q4 - my * q4q4;
    const float _2bx = sqrt(hx * hx + hy * hy);
    const float _2bz = -_2q1mx * q3 + _2q1my * q2 + mz * q1q1 + _2q2mx * q4 - mz * q2q2 + _2q3 * my * q4 - mz * q3q3 + mz * q4q4;
    const float _4bx = 2.0f * _2bx;
    const float _4bz = 2.0f * _2bz;

    // Gradient decent algorithm corrective step
    s1 = -_2q3 * (2.0f * q2q4 - _2q1q3 - ax) + _2q2 * (2.0f * q1q2 + _2q3q4 - ay) - _2bz * q3 * (_2bx * (0.5f - q3q3 - q4q4) + _2bz * (q2q4 - q1q3) - mx) + (-_2bx * q4 + _2bz * q2) * (_2bx * (q2q3 - q1q4) + _2bz * (q1q2 + q3q4) - my) + _2bx * q3 * (_2bx * (q1q3 + q2q4) + _2bz * (0.5f - q2q2 - q3q3) - mz);
    s2 = _2q4 * (2.0f * q2q4 - _2q1q3 - ax) + _2q1 * (2.0f * q1q2 + _2q3q4 - ay) - 4.0f * q2 * (1.0f - 2.0f * q2q2 - 2.0f * q3q3 - az) + _2bz * q4 * (_2bx * (0.5f - q3q3 - q4q4) + _2bz * (q2q4 - q1q3) - mx) + (_2bx * q3 + _2bz * q1) * (_2bx * (q2q3 - q1q4) + _2bz * (q1q2 + q3q4) - my) + (_2bx * q4 - _4bz * q2) * (_2bx * (q1q3 + q2q4) + _2bz * (0.5f - q2q2 - q3q3) - mz);
    s3 = -_2q1 * (2.0f * q2q4 - _2q1q3 - ax) + _2q4 * (2.0f * q1q2 + _2q3q4 - ay) - 4.0f * q3 * (1.0f - 2.0f * q2q2 - 2.0f * q3q3 - az) + (-_4bx * q3 - _2bz * q1) * (_2bx * (0.5f - q3q3 - q4q4) + _2bz * (q2q4 - q1q3) - mx) + (_2bx * q2 + _2bz * q4) * (_2bx * (q2q3 - q1q4) + _2bz * (q1q2 + q3q4) - my) + (_2bx * q1 - _4bz * q3) * (_2bx * (q1q3 + q2q4) + _2bz * (0.5f - q2q2 - q3q3) - mz);
    s4 = _2q2 * (2.0f * q2q4 - _2q1q3 - ax) + _2q3 * (2.0f * q1q2 + _2q3q4 - ay) + (-_4bx * q4 + _2bz * q2) * (_2bx * (0.5f - q3q3 - q4q4) + _2bz * (q2q4 - q1q3) - mx) + (-_2bx * q1 + _2bz * q3) * (_2bx * (q2q3 - q1q4) + _2bz * (q1q2 + q3q4) - my) + _2bx * q2 * (_2bx * (q1q3 + q2q4) + _2bz * (0.5f - q2q2 - q3q3) - mz);
    norm = sqrt(s1 * s1 + s2 * s2 + s3 * s3 + s4 * s4);    // normalise step magnitude
    norm = 1.0f/norm;
    s1 *= norm;
    s2 *= norm;
    s3 *= norm;
    s4 *= norm;

    // Compute rate of change of quaternion
    const float qDot1 = 0.5f * (-q2 * gx - q3 * gy - q4 * gz) - beta * s1;
    const float qDot2 = 0.5f * (q1 * gx + q3 * gz - q4 * gy) - beta * s2;
    const float qDot3 = 0.5f * (q1 * gy - q2 * gz + q4 * gx) - beta * s3;
    const float qDot4 = 0.5f * (q1 * gz + q2 * gy - q3 * gx) - beta * s4;

    // Integrate to yield quaternion
    q1 += qDot1 * deltat;
    q2 += qDot2 * deltat;
    q3 += qDot3 * deltat;
    q4 += qDot4 * deltat;
    norm = sqrt(q1 * q1 + q2 * q2 + q3 * q3 + q4 * q4);    // normalise quaternion
    norm = 1.0f/norm;
    q.w = q1 * norm;
    q.v.x = q2 * norm;
    q.v.y = q3 * norm;
    q.v.z = q4 * norm;
}

