/*
JAGBot main program
USMA Jumping Agile Ground Robot
Fall 2019

Authors: 
Dr. Daniel J. Gonzalez - daniel.gonzalez@westpoint.edu
CDT Andy Rodriguez - andres.rodriguez@westpoint.edu
CDT Josh Loyd - joshua.loyd@westpoint.edu 
*/
#include "main.h"

using namespace FastMath;

/*
 * Calibrate all four steering servos with the same range and sweep angle.
 * The range value is passed straight to Servo::calibrate() as its first
 * argument (presumably seconds of pulse width from center - TODO confirm
 * against the Servo library in use).
 */
void setupServo() {
    const float pulseRange = .000675;
    const float sweepDegrees = 90;

    myservo0.calibrate(pulseRange, sweepDegrees);
    myservo1.calibrate(pulseRange, sweepDegrees);
    myservo2.calibrate(pulseRange, sweepDegrees);
    myservo3.calibrate(pulseRange, sweepDegrees);
}

/*
 * One-time hardware initialisation, called from main() before the control loop.
 *
 * Order matters and is:
 *   1. host serial link at 115200 baud + start banner,
 *   2. PPM receiver decoder (channel data mapped to 0..1),
 *   3. servo calibration and IMU setup (IMU reading is interrupt driven),
 *   4. timebase: starts `timer` and seeds t / tPrev (seconds) so dt starts near 0,
 *   5. every servo and ODrive output parked at mid-scale (0.5).
 */
void setup(){
    const double microsPerSecond = 1000000.0;
    const double midScale = 0.5;

    pcSerial.baud(115200);
    pcSerial.printf("----------- Start! -----------\n");

    // PPM(interrupt pin, min output, max output, min pulse time, max pulse time,
    //     number of channels, throttle channel used for failsafe detection)
    ppmInputs = new PPM(PPMinterruptPin, 0, 1, 1000, 1900, 8, 3);

    setupServo();
    setupIMU(); // IMU Reading is handled through a Serial Interrupt

    timer.start();
    tPrev = timer.read_us() / microsPerSecond;
    t     = timer.read_us() / microsPerSecond;
    dt    = t - tPrev;

    // Steering servos to center.
    myservo0.write(midScale);
    myservo1.write(midScale);
    myservo2.write(midScale);
    myservo3.write(midScale);

    // ODrive command outputs to mid-scale.
    odrv0axis0.write(midScale);
    odrv0axis1.write(midScale);
    odrv1axis0.write(midScale);
    odrv1axis1.write(midScale);
}

/*
 * Firmware entry point.
 *
 * If any bench-test selector (CANTEST, SERVOTEST, RCTEST, IMUTEST) is set,
 * main() currently returns immediately: the matching test routines are
 * commented out. Otherwise it initialises the hardware, prints the telemetry
 * banner and calls loop() for as long as isRunning stays set.
 */
int main() {
    const bool benchTestSelected = CANTEST || SERVOTEST || RCTEST || IMUTEST;
    if (benchTestSelected) {
        // Dispatch priority when re-enabled: CANTEST, SERVOTEST, RCTEST, IMUTEST.
        //   CANTEST   -> CANTest();
        //   SERVOTEST -> calibMain();
        //   RCTEST    -> RCTest();
        //   IMUTEST   -> IMUTest();
        return 0;
    }

    wait(0.5);
    setup();
    pcSerial.printf("----------- Loop Start! -----------\n");
    wait(0.5);
    pcSerial.printf("roll, pitch, rollDot, pitchDot, T1, T2, T3, T4, acc, state\n");

    while (isRunning) {
        loop();
    }
    return 0;
}

/*
 * Read the PPM receiver and convert every channel into a control command.
 *
 * Channel index layout: 0 LUD, 1 RRL, 2 RUD, 3 LRL, 4 Pot1, 5 Pot2,
 * 6 LeftSwitch, 7 Right trigger.
 *
 * rcCommandInputsRaw[] receives PPM::GetChannelData() (mapped to 0..1 by the
 * PPM constructor arguments) and is scaled in place by 100. The throttle
 * channel reports -1 (before that scaling) when the transmitter signal is lost.
 *
 * rcCommandInputs[] then holds:
 *   - ch 0, 1, 3 : remapped from their calibrated *_MIN..*_MAX to -100..100,
 *   - ch 4, 5    : remapped to 0.0..1.0,
 *   - ch 2, 6, 7 : the scaled raw value.
 * Channels 0-3 (including the un-remapped ch 2) then go through a deadband:
 * |v| < DEADBAND becomes 0, and values outside are rescaled so the output
 * starts at 0 on the deadband edge.
 *
 * NOTE(review): the potentiometers (ch 4, 5) are remapped using LRL_MIN/LRL_MAX
 * rather than a pot-specific range - confirm this is intended.
 */
void getRCInputs(){
    ppmInputs->GetChannelData(rcCommandInputsRaw);

    for(int ch = 0; ch < 8; ++ch){
        rcCommandInputsRaw[ch] *= 100;
        rcCommandInputs[ch] = rcCommandInputsRaw[ch];

        switch(ch){
            case 0: // LUD
                rcCommandInputs[ch] = map(rcCommandInputs[ch], LUD_MIN, LUD_MAX, -100, 100);
                break;
            case 1: // RRL
                rcCommandInputs[ch] = map(rcCommandInputs[ch], RRL_MIN, RRL_MAX, -100, 100);
                break;
            case 3: // LRL
                rcCommandInputs[ch] = map(rcCommandInputs[ch], LRL_MIN, LRL_MAX, -100, 100);
                break;
            case 4: // Pot1
            case 5: // Pot2 - scaling factor for potentiometers
                rcCommandInputs[ch] = map(rcCommandInputs[ch], LRL_MIN, LRL_MAX, 0.0, 1.0);
                break;
            default:
                break;
        }

        if(ch > 3){
            continue; // deadband only applies to the stick channels 0-3
        }

        if(rcCommandInputs[ch] >= DEADBAND){
            rcCommandInputs[ch] = map(rcCommandInputs[ch], DEADBAND, 100, 0, 100);
        }else if(rcCommandInputs[ch] <= -DEADBAND){
            rcCommandInputs[ch] = map(rcCommandInputs[ch], -100, -DEADBAND, -100, 0);
        }else if(rcCommandInputs[ch] < DEADBAND && rcCommandInputs[ch] > -DEADBAND){
            // Inside the band. Spelled out so NaN readings fall through unchanged.
            rcCommandInputs[ch] = 0;
        }
    }
}

/*
 * Keep the four wheel commands in odrvCmds[] within +/-100.
 *
 * Finds the largest |command|. If it exceeds 100, every command is multiplied
 * by 100/peak, so the ratios between wheels are preserved and the largest one
 * lands exactly on +/-100. Otherwise odrvCmds[] is left untouched.
 */
void scaleSkidSteer(){
    double peak = 0;
    for(int wheel = 0; wheel < 4; ++wheel){
        const double magnitude = fabs(odrvCmds[wheel]);
        if(magnitude > peak){
            peak = magnitude;
        }
    }

    if(peak <= 100){
        return;
    }

    for(int wheel = 0; wheel < 4; ++wheel){
        odrvCmds[wheel] = odrvCmds[wheel]*100.0/peak;
    }
}

void loop(){
    t = timer.read_us()/1000000.0;
    loopCounter++;
    
    //Take Average of IMU Readings (Handled by interrupts)
    pitchAvg+=pitch;
    rollAvg+=roll;
    yawAvg+=yaw;
    rollDotAvg+=rollDot;
    pitchDotAvg+=pitchDot;
    yawDotAvg+=yawDot;
    accXAvg+=accXAvg;
    accYAvg+=accYAvg;
    accZAvg+=accZAvg;
    
    if((t-tPrev)>=PERIOD){
        dt = t - tPrev;
        tPrev = t;
        badTime = (dt>1.20*PERIOD); //If more than 20% off
        getRCInputs();
        
        pitchAvg/=loopCounter;
        rollAvg/=loopCounter;
        yawAvg/=loopCounter;
        rollDotAvg/=loopCounter;
        pitchDotAvg/=loopCounter;
        yawDotAvg/=loopCounter;
        accXAvg/=loopCounter;
        accYAvg/=loopCounter;
        accZAvg/=loopCounter;
        
        if(CONTROL_MODE == 0){ //Do nothing
            //----    Servo Out
            myservo0.write(0.5);
            myservo1.write(0.5);
            myservo2.write(0.5);
            myservo3.write(0.5);
            //----    ODrive Out
            odrv0axis0.write(0.5);
            odrv0axis1.write(0.5);
            odrv1axis0.write(0.5);
            odrv1axis1.write(0.5);
        }else if(CONTROL_MODE==1){ //Pure Skid Steer
            //----    Servo Out
            myservo0.write(0.5);
            myservo1.write(0.5);
            myservo2.write(0.5);
            myservo3.write(0.5);
            //----    ODrive Out
            odrvCmds[0] = rcCommandInputs[3] + rcCommandInputs[0]; // LRL+LUD
            odrvCmds[1] = -rcCommandInputs[3] + rcCommandInputs[0];
            odrvCmds[2] = rcCommandInputs[3] + rcCommandInputs[0];
            odrvCmds[3] = -rcCommandInputs[3] + rcCommandInputs[0];
            
            //Skid Steer Scaling
            scaleSkidSteer();
            
            //Final Scaling with Potentiometer
            for(int i=0; i<4; i++){
                odrvCmds[i]*=rcCommandInputs[4];
            }
            
            odrv0axis0.write(map(odrvCmds[0],-100,100,0,1));
            odrv0axis1.write(map(-odrvCmds[1],-100,100,0,1));
            odrv1axis0.write(map(-odrvCmds[2],-100,100,0,1));
            odrv1axis1.write(map(odrvCmds[3],-100,100,0,1));
        }else if(CONTROL_MODE==2){ //OmniSteer
            // Channels: LUD, RRL, RUD, LRL, Pot1, Pot2, LeftSwitch, Right trigger
                // Channels: LUD, RRL, RUD, LRL, Pot1, Pot2, LeftSwitch, Right trigger
                if(rcCommandInputs[0]==0 && rcCommandInputs[3]==0 && rcCommandInputs[1]!=0 ){  //Turn on Dime
                    myservo0.write(0.75);
                    myservo1.write(0.25);
                    myservo2.write(0.75);
                    myservo3.write(0.25);            
                    odrvCmds[0] = -rcCommandInputs[1]; 
                    odrvCmds[1] = rcCommandInputs[1];
                    odrvCmds[2] = rcCommandInputs[1];
                    odrvCmds[3] = -rcCommandInputs[1];
                }else if(rcCommandInputs[1]==0 && rcCommandInputs[3]!=0 && rcCommandInputs[0]==0){  //Strafe Left-Right
                    myservo0.write(1);
                    myservo1.write(0);
                    myservo2.write(1);
                    myservo3.write(0);
                    odrvCmds[0] = -rcCommandInputs[3]; 
                    odrvCmds[1] = rcCommandInputs[3];
                    odrvCmds[2] = -rcCommandInputs[3];
                    odrvCmds[3] = rcCommandInputs[3];
                }else{  // Steer while driving
                    myservo0.write(0.5 - map(rcCommandInputs[1],-100,100,-0.5,0.5));
                    myservo1.write(0.5 - map(rcCommandInputs[1],-100,100,-0.5,0.5));
                    myservo2.write(0.5 + map(rcCommandInputs[1],-100,100,-0.5,0.5));
                    myservo3.write(0.5 + map(rcCommandInputs[1],-100,100,-0.5,0.5));
                    odrvCmds[0] = -0.1*rcCommandInputs[1] + rcCommandInputs[0]; //LRL+LUD, 3 and 0
                    odrvCmds[1] = 0.1*rcCommandInputs[1] + rcCommandInputs[0];
                    odrvCmds[2] = 0.1*rcCommandInputs[1] + rcCommandInputs[0];
                    odrvCmds[3] = -0.1*rcCommandInputs[1] + rcCommandInputs[0];
                }
                
                //Skid Steer Scaling
                scaleSkidSteer();
                
                //Final Scaling with Potentiometer
                for(int i=0; i<4; i++){
                    odrvCmds[i]*=rcCommandInputs[4];
                    map(odrvCmds[i],-25,25,-25,25);   
                }
                
                odrv0axis0.write(map(odrvCmds[0],-100,100,0,1));
                odrv0axis1.write(map(-odrvCmds[1],-100,100,0,1));
                odrv1axis0.write(map(-odrvCmds[2],-100,100,0,1));
                odrv1axis1.write(map(odrvCmds[3],-100,100,0,1));
        }else if(CONTROL_MODE==3){ //Only Balance
            double kP1 = 10;
            double kD1 = 25;
            double kP2 = 10;
            double kD2 = 25;
            Tx = (kP1*(0 - roll) + kD1*(0 - rollDot))*DEG2RAD;
            Ty = (kP2*(0 - pitch) + kD2*(0 - pitchDot))*DEG2RAD;
        
            myservo0.write(0.75);
            myservo1.write(0.25);
            myservo2.write(0.75);
            myservo3.write(0.25);            
            odrvCmds[0] = KV*HSIN*(Ty + Tx)/2;
            odrvCmds[1] = KV*HSIN*(Ty - Tx)/2;
            odrvCmds[2] = KV*HSIN*(Ty + Tx)/2;
            odrvCmds[3] = KV*HSIN*(Ty - Tx)/2;
            
            //Final Scaling with Potentiometer
            for(int i=0; i<4; i++){
                odrvCmds[i]*=rcCommandInputs[4];
                map(odrvCmds[i],-25,25,-25,25);
            }
            
            odrv0axis0.write(map(odrvCmds[0],-25,25,0,1));
            odrv0axis1.write(map(-odrvCmds[1],-25,25,0,1));
            odrv1axis0.write(map(-odrvCmds[2],-25,25,0,1));
            odrv1axis1.write(map(odrvCmds[3],-25,25,0,1));
        }else if(CONTROL_MODE==4){ //State Machine
            if(state==0){ //Omnisteer
                // Channels: LUD, RRL, RUD, LRL, Pot1, Pot2, LeftSwitch, Right trigger
                if(rcCommandInputs[0]==0 && rcCommandInputs[3]==0 && rcCommandInputs[1]!=0 ){  //Turn on Dime
                    myservo0.write(0.75);
                    myservo1.write(0.25);
                    myservo2.write(0.75);
                    myservo3.write(0.25);            
                    odrvCmds[0] = -rcCommandInputs[1]; 
                    odrvCmds[1] = rcCommandInputs[1];
                    odrvCmds[2] = rcCommandInputs[1];
                    odrvCmds[3] = -rcCommandInputs[1];
                }else if(rcCommandInputs[1]==0 && rcCommandInputs[3]!=0 && rcCommandInputs[0]==0){  //Strafe Left-Right
                    myservo0.write(1);
                    myservo1.write(0);
                    myservo2.write(1);
                    myservo3.write(0);
                    odrvCmds[0] = -rcCommandInputs[3]; 
                    odrvCmds[1] = rcCommandInputs[3];
                    odrvCmds[2] = -rcCommandInputs[3];
                    odrvCmds[3] = rcCommandInputs[3];
                }else{  // Steer while driving
                    myservo0.write(0.5 - map(rcCommandInputs[1],-100,100,-0.5,0.5));
                    myservo1.write(0.5 - map(rcCommandInputs[1],-100,100,-0.5,0.5));
                    myservo2.write(0.5 + map(rcCommandInputs[1],-100,100,-0.5,0.5));
                    myservo3.write(0.5 + map(rcCommandInputs[1],-100,100,-0.5,0.5));
                    odrvCmds[0] = -0.1*rcCommandInputs[1] + rcCommandInputs[0]; //LRL+LUD, 3 and 0
                    odrvCmds[1] = 0.1*rcCommandInputs[1] + rcCommandInputs[0];
                    odrvCmds[2] = 0.1*rcCommandInputs[1] + rcCommandInputs[0];
                    odrvCmds[3] = -0.1*rcCommandInputs[1] + rcCommandInputs[0];
                }
                
                //Skid Steer Scaling
                scaleSkidSteer();
                
                //Final Scaling with Potentiometer
                for(int i=0; i<4; i++){
                    odrvCmds[i]*=rcCommandInputs[4];
                    map(odrvCmds[i],-25,25,-25,25);   
                }
                
                odrv0axis0.write(map(odrvCmds[0],-100,100,0,1));
                odrv0axis1.write(map(-odrvCmds[1],-100,100,0,1));
                odrv1axis0.write(map(-odrvCmds[2],-100,100,0,1));
                odrv1axis1.write(map(odrvCmds[3],-100,100,0,1));
                
                if(rcCommandInputs[5]*sqrt((accX*accX + accY*accY + accZ*accZ))<0.5){ //units are gs
                    state = 1; //If we are in freefall, transition to aerial balance state.
                    tAirStart = t;
                }
            }else if(state==1){ //Aerial Balance
                double kP1 = 75;
                double kD1 = 12;
                double kP2 = 75;
                double kD2 = 12;
                Tx = (kP1*(-1.3 - roll) + kD1*(0 - rollDot))*DEG2RAD;
                Ty = (kP2*(1.5 - pitch) + kD2*(0 - pitchDot))*DEG2RAD;
            
                myservo0.write(0.75);
                myservo1.write(0.25);
                myservo2.write(0.75);
                myservo3.write(0.25);            
                odrvCmds[0] = KV*HSIN*(Ty + Tx)/2;
                odrvCmds[1] = KV*HSIN*(Ty - Tx)/2;
                odrvCmds[2] = KV*HSIN*(Ty + Tx)/2;
                odrvCmds[3] = KV*HSIN*(Ty - Tx)/2;
                
                //Final Scaling with Potentiometer
                for(int i=0; i<4; i++){
                    odrvCmds[i]*=rcCommandInputs[4];
                    odrvCmds[i] = map(odrvCmds[i],-25,25,-25,25);
                }
                
                odrv0axis0.write(map(odrvCmds[0],-25,25,0,1));
                odrv0axis1.write(map(-odrvCmds[1],-25,25,0,1));
                odrv1axis0.write(map(-odrvCmds[2],-25,25,0,1));
                odrv1axis1.write(map(odrvCmds[3],-25,25,0,1));
                 
//                if(rcCommandInputs[5]*sqrt((accX*accX + accY*accY + accZ*accZ))>4){ //units are m/(s^2)   
                if((rcCommandInputs[5]*sqrt((accX*accX + accY*accY + accZ*accZ))>2)||(t-tAirStart>1.5)){ //units are m/(s^2)
                    state = 2; //If we land, transition back to omnidirectional control
                }
            }else if(state==2){ //Land
                myservo0.write(0.75);
                myservo1.write(0.25);
                myservo2.write(0.75);
                myservo3.write(0.25);
                odrvCmds[0] = 0;
                odrvCmds[1] = 0;
                odrvCmds[2] = 0;
                odrvCmds[3] = 0;
                odrv0axis0.write(map(0,-100,100,0,1));
                odrv0axis1.write(map(0,-100,100,0,1));
                odrv1axis0.write(map(0,-100,100,0,1));
                odrv1axis1.write(map(0,-100,100,0,1));
            }
        }
        
        //Telemetry
        teleCounter++;
        if(teleCounter > SERIAL_RATIO and USESERIAL and !inInt){ //and pcSerial.writable()
            //inWrite = 1;
            if(badTime){
                pcSerial.printf("b%f\n",dt-1.25*PERIOD); //(dt-PERIOD)
                badTime = 0;
            }
            teleCounter = 0;            
//            pcSerial.printf("%f, %f, %f, %f, %f, %f, %f, %f\r\n", rcCommandInputsRaw[0], rcCommandInputsRaw[1], rcCommandInputsRaw[2], rcCommandInputsRaw[3], rcCommandInputsRaw[4], rcCommandInputsRaw[5], rcCommandInputsRaw[6], rcCommandInputsRaw[7]);
//            pcSerial.printf("%f, %f, %f, %f, %f, %f, %f, %f, %f, \n", rcCommandInputs[0], rcCommandInputs[1], rcCommandInputs[3], rcCommandInputs[4], rcCommandInputs[5], odrvCmds[0], odrvCmds[1], odrvCmds[2], odrvCmds[3]);
//            pcSerial.printf("%f, %f, %f, \n", roll, pitch, yaw);
//            pcSerial.printf("%f, %f, %f, ", rollDot, pitchDot, yawDot);
//            pcSerial.printf("%f, %f, %f, %f, %f, %f, %f, %f, %i\n",
//                            accX, accY, accZ, accXAvg, accYAvg, accZAvg,
//                            sqrt(accX*accX + accY*accY + accZ*accZ), sqrt(accXAvg*accXAvg + accYAvg*accYAvg + accZAvg*accZAvg), state);
            //            pcSerial.printf("%f, %f, %f, \n", roll, pitch, yaw);
            pcSerial.printf("%f, ",t);
            pcSerial.printf("%f, %f, ", roll, pitch);
            pcSerial.printf("%f, %f, ", rollDot, pitchDot);
            pcSerial.printf("%f, %f, %f, %f ", odrvCmds[0], odrvCmds[1], odrvCmds[2], odrvCmds[3]);
            pcSerial.printf("%f, %i \n", sqrt(accX*accX + accY*accY + accZ*accZ), state);
            
//            pcSerial.printf("%f, %f, %f, %f, %f, %f, %f, %f, %i\n",
//                            accX, accY, accZ, accXAvg, accYAvg, accZAvg,
//                            sqrt(accX*accX + accY*accY + accZ*accZ), sqrt(accXAvg*accXAvg + accYAvg*accYAvg + accZAvg*accZAvg), state);
        }
        pitchAvg=0;
        rollAvg=0;
        yawAvg=0;
        rollDotAvg=0;
        pitchDotAvg=0;
        yawDotAvg=0;
        accXAvg=0;
        accYAvg=0;
        accZAvg=0;
        loopCounter = 0;
    }
}