#ifndef __AHRSMATHDSP_QUATERNION_
#define __AHRSMATHDSP_QUATERNION_

#include <math.h>
#include <stdint.h>

#include "Vector3.h"

// Pi as a single-precision constant. The literal carries enough digits (and
// the `f` suffix) to round to the float nearest to pi; the previous literal
// 3.1415926 rounded to the float one ULP below it.
const static float PI = 3.14159265358979323846f;

// Quaternion with a float scalar part `w` and a Vector3 vector part `v`.
//
// The class stores arbitrary (not necessarily unit-length) quaternions.
// Nothing here normalises implicitly except where stated; none of the
// normalising functions guard against a zero-length quaternion, so
// normalising the default-constructed (all-zero) value divides by zero.
class Quaternion
{
public:
    // Zero quaternion (w = 0, v = 0). Note: this is not the identity rotation.
    Quaternion() : w(0.0f), v(0.0f, 0.0f, 0.0f) {}

    // Build from the four components directly.
    Quaternion(float const _w, float const _x, float const _y, float const _z)
        : w(_w), v(_x, _y, _z) {}

    // Build from a scalar part and a vector part.
    Quaternion(float const _w, const Vector3 &_v) : w(_w), v(_v) {}

    // Build from the three rows of a 3x3 rotation matrix.
    //
    // Uses the branch with the largest divisor: the trace when it is
    // positive, otherwise the largest diagonal element. Picking the largest
    // diagonal element keeps the value under the square root positive and the
    // divisor S away from zero (the previous code compared with `<`, which
    // selected the wrong branch and produced S == 0 / NaN for e.g. 180 degree
    // rotations about X or Y).
    Quaternion(Vector3 const &row0, Vector3 const &row1, Vector3 const &row2) {
        float const m00 = row0.x;
        float const m01 = row0.y;
        float const m02 = row0.z;
        float const m10 = row1.x;
        float const m11 = row1.y;
        float const m12 = row1.z;
        float const m20 = row2.x;
        float const m21 = row2.y;
        float const m22 = row2.z;

        float const tr = m00 + m11 + m22;

        if (tr > 0.0f) {
            float const S = sqrtf(tr + 1.0f) * 2.0f; // S = 4 * w
            set(0.25f * S, (m21 - m12) / S, (m02 - m20) / S, (m10 - m01) / S);
        } else if ((m00 > m11) && (m00 > m22)) {
            float const S = sqrtf(1.0f + m00 - m11 - m22) * 2.0f; // S = 4 * x
            set((m21 - m12) / S, 0.25f * S, (m01 + m10) / S, (m02 + m20) / S);
        } else if (m11 > m22) {
            float const S = sqrtf(1.0f + m11 - m00 - m22) * 2.0f; // S = 4 * y
            set((m02 - m20) / S, (m01 + m10) / S, 0.25f * S, (m12 + m21) / S);
        } else {
            float const S = sqrtf(1.0f + m22 - m00 - m11) * 2.0f; // S = 4 * z
            set((m10 - m01) / S, (m02 + m20) / S, (m12 + m21) / S, 0.25f * S);
        }
    }

    // Build from three angles, each passed to cosf/sinf (i.e. radians).
    // NOTE: the parameter order is x, z, y.
    Quaternion(float const theta_x, float const theta_z, float const theta_y) {
        float const half_x = 0.5f * theta_x;
        float const half_y = 0.5f * theta_y;
        float const half_z = 0.5f * theta_z;

        float const cx = cosf(half_x);
        float const cy = cosf(half_y);
        float const cz = cosf(half_z);

        float const sx = sinf(half_x);
        float const sy = sinf(half_y);
        float const sz = sinf(half_z);

        w = cz * cy * cx + sz * sy * sx;
        v.x = cz * cy * sx - sz * sy * cx;
        v.y = cz * sy * cx + sz * cy * sx;
        v.z = sz * cy * cx - cz * sy * sx;
    }

    // Serialise w, x, y, z (in that order) into `buffer` as four big-endian
    // signed 32-bit fixed-point values scaled by 2^30.
    // `buffer` must have room for 16 bytes. Components outside the
    // representable range are saturated; NaN is written as 0.
    void encode(char *const buffer) const {
        encodeQ30(w, buffer);
        encodeQ30(v.x, buffer + 4);
        encodeQ30(v.y, buffer + 8);
        encodeQ30(v.z, buffer + 12);
    }

    // Inverse of encode(): read 16 bytes from `buffer` and overwrite this
    // quaternion. Bytes are read as unsigned, so the result does not depend
    // on whether plain `char` is signed on the target.
    void decode(const char *buffer) {
        set(decodeQ30(buffer),
            decodeQ30(buffer + 4),
            decodeQ30(buffer + 8),
            decodeQ30(buffer + 12));
    }

    // Overwrite all four components.
    void set(float const _w, float const _x, float const _y, float const _z) {
        w = _w;
        v.set(_x, _y, _z);
    }

    // Squared norm: w^2 + v.v
    float lengthSquared() const {
        return w * w + (v * v);
    }

    // Norm of the quaternion.
    float length() const {
        return sqrtf(lengthSquared());
    }

    // Scale this quaternion in place to unit length.
    // Divides by length(); the length must be non-zero.
    void normalise() {
        float const magnitude = length();
        w /= magnitude;
        v.x /= magnitude;
        v.y /= magnitude;
        v.z /= magnitude;
    }

    // Unit-length copy of this quaternion (length must be non-zero).
    Quaternion normalised() const {
        return (*this) / length();
    }

    // Same scalar part, negated vector part.
    Quaternion conjugate() const {
        return Quaternion(w, -v);
    }

    // Multiplicative inverse: conjugate divided by the squared norm.
    Quaternion inverse() const {
        return conjugate() / lengthSquared();
    }

    // Four-component dot product with `q`.
    float dot_product(Quaternion const &q) const {
        return q.v * v + q.w * w;
    }

    // Returns the vector part of q * (0, v) * conjugate(q). This is the
    // rotated vector when this quaternion has unit length.
    // (The parameter shadows the member `v`; inside this function `v` is the
    // argument.)
    Vector3 rotate(Vector3 const &v) const {
        return ((*this) * Quaternion(0.0f, v) * conjugate()).v;
    }

    // Normalised linear interpolation from *this (t = 0) to q2 (t = 1).
    // `t` is clamped to [0, 1].
    Quaternion lerp(Quaternion const &q2, float t) const {
        t = clampTo(t, 0.0f, 1.0f);
        return ((*this) * (1 - t) + q2 * t).normalised();
    }

    // Spherical linear interpolation from *this (t = 0) to q2 (t = 1).
    // `t` is clamped to [0, 1]. q2 is negated when the dot product is
    // negative so that interpolation takes the shorter arc.
    Quaternion slerp(Quaternion const &q2, float t) const {
        t = clampTo(t, 0.0f, 1.0f);

        Quaternion target;
        float dot = dot_product(q2);

        if (dot < 0) {
            dot = -dot;
            target = -q2;
        } else {
            target = q2;
        }

        if (dot < 0.95f) {
            float const angle = acosf(dot);
            return ((*this) * sinf(angle * (1 - t)) + target * sinf(angle * t)) /
                   sinf(angle);
        }

        // The angle is small: sinf(angle) approaches zero, so fall back to
        // normalised linear interpolation.
        return lerp(target, t);
    }

    // Write the rows of the 3x3 rotation matrix for the normalised version
    // of this quaternion into row0..row2.
    void getRotationMatrix(Vector3 &row0, Vector3 &row1, Vector3 &row2) const {
        Quaternion const q(normalised());
        float const qw = q.w;
        float const qx = q.v.x;
        float const qy = q.v.y;
        float const qz = q.v.z;

        row0.x = 1.0f - (2.0f * (qy * qy)) - (2.0f * (qz * qz));
        row0.y = (2.0f * qx * qy) - (2.0f * qw * qz);
        row0.z = (2.0f * qx * qz) + (2.0f * qw * qy);

        row1.x = (2.0f * qx * qy) + (2.0f * qw * qz);
        row1.y = 1.0f - (2.0f * (qx * qx)) - (2.0f * (qz * qz));
        row1.z = (2.0f * (qy * qz)) - (2.0f * (qw * qx));

        row2.x = (2.0f * (qx * qz)) - (2.0f * qw * qy);
        row2.y = (2.0f * qy * qz) + (2.0f * qw * qx);
        row2.z = 1.0f - (2.0f * (qx * qx)) - (2.0f * (qy * qy));
    }

    // Axis-angle form of the normalised quaternion, packed into a Quaternion:
    // `w` holds the rotation angle in radians and `v` the unit axis.
    //
    // The angle was previously computed and then discarded, leaving
    // cos(angle / 2) in `w`; it is now stored as the name implies.
    Quaternion getAxisAngle() const {
        Quaternion q1(normalised());

        // Rounding can push |w| marginally above 1, which would make acosf
        // and sqrtf return NaN.
        float const cos_half = clampTo(q1.w, -1.0f, 1.0f);
        float const angle = 2.0f * acosf(cos_half);
        float const s = sqrtf(1.0f - cos_half * cos_half); // sin(angle / 2)

        if (s < 0.001f) {
            // Rotation is (almost) zero: the axis direction is irrelevant and
            // dividing by s would blow up, so pick an arbitrary unit axis.
            q1.v = Vector3(1.0f, 0.0f, 0.0f);
        } else {
            q1.v = q1.v / s; // normalise axis
        }
        q1.w = angle;
        return q1;
    }

    // Angles about the x, y and z axes (radians), returned as
    // Vector3(angle_x, angle_y, angle_z). Formulas assume a unit quaternion.
    const Vector3 getEulerAngles() const {
        float const q0 = w;
        float const q1 = v.x;
        float const q2 = v.y;
        float const q3 = v.z;

        // Clamp so rounding error near +/-90 degrees cannot take asinf out
        // of its domain and yield NaN.
        float const sin_y = clampTo(2.0f * (q0 * q2 - q3 * q1), -1.0f, 1.0f);

        float const angle_y = asinf(sin_y);
        float const angle_x =
            atan2f(2.0f * (q0 * q1 + q2 * q3), 1.0f - 2.0f * (q1 * q1 + q2 * q2));
        float const angle_z =
            atan2f(2.0f * (q0 * q3 + q1 * q2), 1.0f - 2.0f * (q2 * q2 + q3 * q3));

        return Vector3(angle_x, angle_y, angle_z);
    }

    // Returns q2 * inverse(*this), i.e. the quaternion d with d * (*this) == q2.
    Quaternion difference(const Quaternion &q2) const {
        return q2 * inverse();
    }

    // operators
    Quaternion &operator=(const Quaternion &q) {
        w = q.w;
        v = q.v;
        return *this;
    }

    // Component-wise sum.
    const Quaternion operator+(const Quaternion &q) const {
        return Quaternion(w + q.w, v + q.v);
    }

    // Component-wise difference.
    const Quaternion operator-(const Quaternion &q) const {
        return Quaternion(w - q.w, v - q.v);
    }

    // Hamilton product (*this) * q. Not commutative.
    const Quaternion operator*(const Quaternion &q) const {
        return Quaternion(w * q.w - v * q.v,
                          v.y * q.v.z - v.z * q.v.y + w * q.v.x + v.x * q.w,
                          v.z * q.v.x - v.x * q.v.z + w * q.v.y + v.y * q.w,
                          v.x * q.v.y - v.y * q.v.x + w * q.v.z + v.z * q.w);
    }

    // Right division: (*this) * inverse(q).
    // (Previously returned only inverse(q), ignoring the left operand.)
    const Quaternion operator/(const Quaternion &q) const {
        return (*this) * q.inverse();
    }

    // Negate every component.
    const Quaternion operator-() const {
        return Quaternion(-w, -v);
    }

    // scalar operators
    const Quaternion operator*(float scaler) const {
        return Quaternion(w * scaler, v * scaler);
    }

    const Quaternion operator/(float scaler) const {
        return Quaternion(w / scaler, v / scaler);
    }

    float w;
    Vector3 v;

private:
    // Limit `value` to [lo, hi]. NaN is passed through unchanged.
    static float clampTo(float const value, float const lo, float const hi) {
        if (value > hi) {
            return hi;
        }
        if (value < lo) {
            return lo;
        }
        return value;
    }

    // Write `value` * 2^30, truncated toward zero, as a big-endian signed
    // 32-bit integer into out[0..3]. Uses shifts rather than reinterpreting
    // an int's storage, so the byte order does not depend on the host.
    static void encodeQ30(float const value, char *const out) {
        float scaled = value * 1073741824.0f; // 2^30

        // Converting an out-of-range or NaN float to an integer is undefined
        // behaviour, so saturate first.
        if (scaled >= 2147483648.0f) {
            scaled = 2147483520.0f; // largest float below 2^31
        } else if (scaled < -2147483648.0f) {
            scaled = -2147483648.0f;
        } else if (scaled != scaled) { // NaN
            scaled = 0.0f;
        }

        uint32_t const bits =
            static_cast<uint32_t>(static_cast<int32_t>(scaled));
        unsigned char *const dst = reinterpret_cast<unsigned char *>(out);
        dst[0] = static_cast<unsigned char>((bits >> 24) & 0xFFu);
        dst[1] = static_cast<unsigned char>((bits >> 16) & 0xFFu);
        dst[2] = static_cast<unsigned char>((bits >> 8) & 0xFFu);
        dst[3] = static_cast<unsigned char>(bits & 0xFFu);
    }

    // Read a big-endian signed 32-bit integer from in[0..3] and scale it by
    // 2^-30.
    static float decodeQ30(const char *const in) {
        const unsigned char *const src =
            reinterpret_cast<const unsigned char *>(in);
        uint32_t const bits = (static_cast<uint32_t>(src[0]) << 24) |
                              (static_cast<uint32_t>(src[1]) << 16) |
                              (static_cast<uint32_t>(src[2]) << 8) |
                              static_cast<uint32_t>(src[3]);

        // Two's-complement reinterpretation written so that no out-of-range
        // unsigned-to-signed conversion is needed.
        int32_t const fixed = (bits & 0x80000000u)
                                  ? -static_cast<int32_t>(~bits) - 1
                                  : static_cast<int32_t>(bits);

        return static_cast<float>(fixed) * (1.0f / 1073741824.0f);
    }
};

#endif