#include "mbed.h"
#include "Motor.h"
#include "LSM9DS1.h"
#include "math.h"

// Used for radian-to-degree conversion (x*180/PI) below.
#define PI 3.141592653589793
// Delay between IMU samples, in seconds (1 ms, i.e. 1000 Hz).
// NOTE(review): 1000 Hz exceeds the 952 Hz maximum noted below — confirm
// whether consecutive reads can return repeated samples at this rate.
#define PERIOD 0.001 //NOTE: Max sample rate is 952 Hz

Serial pc(USBTX, USBRX); // USB serial link: diagnostics output and gain-tuning input
// IMU; arguments presumably (SDA pin, SCL pin, accel/gyro address, magnetometer
// address) — confirm against the LSM9DS1 library constructor.
LSM9DS1 imu(p28, p27, 0xD6, 0x3C);
Motor L(p22, p5, p6); // pwm, fwd, rev
Motor R(p21, p8, p7); // pwm, fwd, rev
DigitalOut myled(LED1); // set to 1 on every drive() iteration

float angle;       // latest blended tilt estimate in degrees, written by drive()
float gOffset = 0; // negated mean raw gyro Y reading, set by myCalibrate()
float aOffset = 0; // negated mean accelerometer tilt in degrees, set by myCalibrate()

// Controller tuning, adjustable at runtime over serial by callback().
// ki, kd and factor start at 0 (static storage); drive() sets factor to 1000
// on entry. kp and speedCorrection hold the best values found so far.
// NOTE(review): these (and speed) are written from the serial callback and
// read in drive() without volatile — confirm this is safe for the toolchain.
int kp = 48; // Best performance so far achieved with this kp and speedCorrection
int ki;
int kd;
int factor;
float speedCorrection = 2.0; // multiplier applied to negative speed commands in drive()
float speed;                 // latest motor command (set by drive(); zeroed by callback())

void myCalibrate();
void control();
void drive();
void callback();

// Program entry point.
// Initializes the IMU, registers the serial tuning callback and measures the
// sensor offsets. It then runs the balance loop, which never returns.
int main() {
    //LSM9DS1 imu(p9, p10, 0x6B, 0x1E);
    // begin() is called exactly once and its result checked; a failure is
    // reported but execution continues, as before.
    if (!imu.begin()) {
        pc.printf("Failed to communicate with LSM9DS1.\n");
    }
    imu.calibrate();
    // NOTE(review): the measurement ranges are set after calibrate(); if the
    // library stores biases in raw counts, a range change here would
    // invalidate them — confirm, or move these two calls above calibrate().
    imu.setAccelScale(2);
    imu.setGyroScale(2);
    pc.attach(&callback);   // serial keystrokes invoke callback() for gain tuning
    myCalibrate();
    drive();
}

// Empty placeholder. It is declared and defined here but not called anywhere
// in the visible source; the control law currently lives in drive().
void control() {
}
    
// Balance loop: estimates the tilt angle and drives both motors. Never returns.
//
// Each step takes 10 IMU samples spaced PERIOD seconds apart, then:
//   - averages the accelerometer tilt, adds aOffset and applies a low-pass filter;
//   - integrates the bias-corrected gyro Y rate into a gyro tilt estimate;
//   - blends the two estimates (beta = weight on the accelerometer estimate);
//   - runs a PID controller on the blended angle and commands both motors.
// Reads kp, ki, kd, factor and speedCorrection, which callback() may change
// at any time. Writes the globals angle and speed, and sets factor to 1000
// on entry.
void drive() {
    const int samplesPerStep = 10;
    // Print diagnostics roughly every 0.5 s; +0.5 rounds instead of truncating.
    const int printInterval = (int) (0.5/(PERIOD*samplesPerStep) + 0.5);
    const int printEvery = printInterval > 0 ? printInterval : 1;
    const float alpha = 0.50; // Low-pass coefficient = T/(T+dt), T = response time
    // Weight on the accelerometer estimate; (1 - beta) weights the gyro
    // estimate. With beta = 1 the gyro estimate is printed but not used.
    const float beta = 1;
    const float fallenAngle = 80; // |angle| beyond this: robot treated as fallen

    float pastAccAngle = 0;
    float gyroAngle = 0;   // integrated gyro tilt; persists across steps
    float integral = 0;
    float lastAngle = 0;
    int counter = 0;
    factor = 1000;

    while (1) {
        float accelSum = 0;
        float gyroRateSum = 0;
        for (int i = 0; i < samplesPerStep; i++) {
            imu.readAccel();
            imu.readGyro();
            accelSum += ::atan2((float) imu.ax, (float) imu.az)*180/PI - 90;
            // gOffset is a raw-count bias, so remove it before scaling.
            // 0.5f keeps the original intended scale (written as (1/2), which
            // was integer division and always 0); units not confirmed.
            gyroRateSum += 0.5f*(imu.gy + gOffset);
            wait(PERIOD);
        }

        // Average the accelerometer samples to reduce noise, then low-pass filter.
        float accelAngle = accelSum/samplesPerStep + aOffset;
        accelAngle = (1 - alpha)*accelAngle + alpha*pastAccAngle;
        pastAccAngle = accelAngle;

        // Rectangle-rule integration: each rate sample covers PERIOD seconds.
        gyroAngle += gyroRateSum*PERIOD;

        angle = accelAngle*beta + gyroAngle*(1 - beta);
        angle -= 2; // Attempt to correct for weight distribution bias

        myled = 1;
        // Wrap the counter instead of letting it grow (avoids signed overflow).
        counter = (counter + 1) % printEvery;
        const bool printNow = (counter == 0);
        if (printNow) {
            pc.printf("Angle by acc: %.3f\n\r", accelAngle);
            pc.printf("Angle by gyro: %.3f\n\r", gyroAngle);
            pc.printf("Overall %.3f\n\r", angle);
        }

        // PID on the blended angle.
        // NOTE: the integral uses PERIOD as its time step although one step
        // lasts about samplesPerStep*PERIOD; kept so existing ki tunings hold.
        const float derivative = angle - lastAngle;
        lastAngle = angle;
        float cmd = 0;
        if (fabs(angle) > fallenAngle) {
            integral = 0;  // fallen: stop, and reset to prevent integral wind-up
        } else {
            integral += angle*PERIOD;
            const int divisor = factor;  // snapshot; callback() may change factor
            if (divisor != 0) {
                cmd = (kp*angle + ki*integral + kd*derivative)/divisor;
            }
        }

        if (cmd < 0) {
            cmd *= speedCorrection; // Speed Correction attempts to correct weight distribution issue
        }                           // by driving faster when falling on the heavier side
        if (cmd > 1) {
            cmd = 1;
        } else if (cmd < -1) {
            cmd = -1;
        }
        if (printNow) {
            pc.printf("speed: %.3f\n\n\r", cmd);
        }
        // Direction had to be inverted after moving the accelerometer to the opposite side.
        speed = -cmd;
        L.speed(speed);
        R.speed(speed);
    }
}
        
// Measures linear offsets for the accelerometer tilt and the raw gyroscope
// Y reading, presumably while the robot is held still (TODO confirm the
// intended pose).
// Averages 100 samples spaced PERIOD seconds apart and stores the negated
// means in aOffset (degrees) and gOffset (raw units). It then prints both
// offsets and waits 2 s.
// NOTE(review): drive() subtracts 90 degrees from each tilt sample but this
// measurement does not — confirm the two are meant to differ.
void myCalibrate() {
    const int samples = 100;
    float gSum = 0;
    float aSum = 0;  // local sum: aOffset is overwritten, never added to, so
                     // calling this again does not build on a stale offset
    for (int i = 0; i < samples; i++) {
        imu.readAccel();
        imu.readGyro();
        gSum += imu.gy;
        aSum += ::atan2((float) imu.ax, (float) imu.az)*180/PI;
        wait(PERIOD);
    }
    aOffset = -aSum/samples;
    gOffset = -gSum/samples;
    pc.printf("Accelerometer offset: %.3f \n\r", aOffset);
    pc.printf("Gyroscope offset: %.3f \n\r", gOffset);
    wait(2);
}

//This function is a revised version of that from:
// https://os.mbed.com/teams/ECE-4180-Spring-15/code/balance2/
void callback()
{
    speed = 0;
    char val;                                                   // Needed for Serial communication (need to be global?)
    val = pc.getc();                                            // Reat the value from Serial
    pc.printf("%c\n", val);                                     // Print it back to the screen
    if( val =='p') {                                            // If character was a 'p'
        pc.printf("enter kp \n");                               // Adjust kp
        val = pc.getc();                                        // Wait for kp value
        if(val == 0x2b) {                                       // If character is a plus sign
            kp++;                                           // Increase kp
        } else if (val == 0x2d) {                               // If recieved character is the minus sign
            kp--;
        } else if (val == '(') {
            kp-= 10;
        } else if (val == ')') {
            kp += 10;                                           // Decrease kp
        } else {
            kp = val - 48;                                      // Cast char to float
        }
        pc.printf(" kp = %d \n",kp);                            // Print current kp value to screen
    } else if( val == 'd') {                                    // Adjust kd
        pc.printf("enter kd \n");                               // Wait for kd
        val= pc.getc();                                         // Read value from serial
        if(val == '+') {                                        // If given plus sign
            kd++;                                               // Increase kd
        } else if (val == '-') {                                // If given negative sign
            kd--;                                               // Decrease kd
        } else {                                                // If given some other ascii (a number?)
            kd = val - 48;                                      // Set derivative gain
        }
        pc.printf(" kd = %d \n",kd);                            // Print kd back to screen
    } else if( val == 'i') {                                    // If given i - integral gain
        pc.printf("enter ki \n");                               // Prompt on screen to ask for ascii
        val= pc.getc();                                         // Get the input
        if(val == '+') {                                        // If given the plus sign
            ki++;                                               // Increase ki
        } else if (val == '-') {                                // If given the minus sign
            ki--;                                               // Decrease ki
        } else {                                                // If given some other ascii
            ki = val - 48;                                      // Set ki to that number
        }
        pc.printf(" ki = %d \n",ki);
    } else if( val == 'o') {
        pc.printf("enter factor \n");
        val= pc.getc();
        if(val == '+') {
            factor=factor+1000;
        } else if (val == '-') {
            factor=factor-1000;
        } else {
            factor=(val-48)*1000;;
        }
        pc.printf(" factor = %d \n",factor);
    } else if (val == 'a') {
        pc.printf("enter speed correction \n");
        val= pc.getc();
        if(val == '+') {
            speedCorrection += 0.1;
        } else if (val == '-') {
            speedCorrection -= 0.1;
        }
         pc.printf("speedCorrect = %f \n",speedCorrection);
    }
    wait(1);
}