#include "mbed.h"
#include "rtos.h"
#include "EthernetInterface.h"
#include "ExperimentServer.h"
#include "QEI.h"
#include "BezierCurve.h"
#include "MotorShield.h" 
#include "HardwareSetup.h"

// Order of the Bezier curve used for each segment of the gait.
// A curve of order N has N+1 control points, each with an x and a y value.
#define BEZIER_ORDER_JUMP    7
#define BEZIER_ORDER_FLIGHT  7
#define BEZIER_ORDER_LAND    7

// Number of floats received from MATLAB per experiment:
//   12 scalar parameters,
//   then the jump, flight and land control points (2 floats per point),
//   then one further block sized like the flight control points.
#define NUM_INPUTS ( 12                            \
                   + 2*(BEZIER_ORDER_JUMP+1)       \
                   + 2*(BEZIER_ORDER_FLIGHT+1)     \
                   + 2*(BEZIER_ORDER_LAND+1)       \
                   + 2*(BEZIER_ORDER_FLIGHT+1) )

// Number of floats sent back to MATLAB on every control iteration.
#define NUM_OUTPUTS 19

// Encoder counts -> radians (the encoders are configured for 1200 counts per revolution).
#define PULSE_TO_RAD ((2.0f*3.14159f) / 1200.0f)

// Hardware and communication objects shared by main() and the current-loop interrupt.
Serial pc(USBTX, USBRX);    // USB serial terminal (handed to the experiment server in main)
ExperimentServer server;    // Object that lets us communicate with MATLAB
Timer t;                    // Measures elapsed time within an experiment / hop cycle

// Quadrature encoders: no index channel (NC), 1200 counts/rev, X4 decoding.
// Only encoders A and B are read by the controllers in this file; C and D are just reset.
// NOTE(review): PC_6 is given to encoderC here and is also opened as a DigitalIn
// (foot contact sensor) inside main() - confirm the intended wiring.
QEI encoderA(PE_9,PE_11, NC, 1200, QEI::X4_ENCODING);  // MOTOR A ENCODER (no index, 1200 counts/rev, Quadrature encoding)
QEI encoderB(PA_5, PB_3, NC, 1200, QEI::X4_ENCODING);  // MOTOR B ENCODER (no index, 1200 counts/rev, Quadrature encoding)
QEI encoderC(PC_6, PC_7, NC, 1200, QEI::X4_ENCODING);  // MOTOR C ENCODER (no index, 1200 counts/rev, Quadrature encoding)
QEI encoderD(PD_12, PD_13, NC, 1200, QEI::X4_ENCODING);// MOTOR D ENCODER (no index, 1200 counts/rev, Quadrature encoding)

MotorShield motorShield(12000); // Motor shield constructed with a period of 12000 ticks (~20 kHz per the original author)
Ticker currentLoop;             // Periodic ticker that main() attaches to CurrentLoop()

// ---- Joint q1 (motor A) ----
float current1;                    // measured current
float current_des1      = 0.0f;    // commanded current (written by main, read by CurrentLoop)
float prev_current_des1 = 0.0f;    // commanded current seen on the previous CurrentLoop pass
float current_int1      = 0.0f;    // integrator state of the PI current controller
float angle1;                      // joint angle (rad)
float velocity1;                   // joint velocity
float duty_cycle1;                 // signed duty-cycle command
float angle1_init;                 // initial angle for q1 (rad), supplied by MATLAB

// ---- Joint q2 (motor B) ----
float current2;                    // measured current
float current_des2      = 0.0f;    // commanded current (written by main, read by CurrentLoop)
float prev_current_des2 = 0.0f;    // commanded current seen on the previous CurrentLoop pass
float current_int2      = 0.0f;    // integrator state of the PI current controller
float angle2;                      // joint angle (rad)
float velocity2;                   // joint velocity
float duty_cycle2;                 // signed duty-cycle command
float angle2_init;                 // initial angle for q2 (rad), supplied by MATLAB

// ---- Fixed kinematic parameters (link lengths) ----
const float l_OA = 0.03673;
const float l_OB = 0.06773;
const float l_AC = 0.096;
const float l_DE = 0.127 - (l_OB - l_OA);

// ---- Timing parameters ----
float current_control_period_us   = 200.0f;    // 5kHz current control loop
float impedance_control_period_us = 1000.0f;   // 1kHz impedance control loop
float start_period, traj_period, end_period;   // fixed buffer / trajectory durations set in main
float jump_period, flight_period, land_period; // duration of each gait segment

// ---- Control parameters ----
float current_Kp      = 4.0f;      // proportional gain of the current loop
float current_Ki      = 0.4f;      // integral gain of the current loop
float current_int_max = 3.0f;      // clamp applied to the current-loop integrator
float duty_max;                    // maximum duty-cycle magnitude, supplied by MATLAB
float K_xx, K_yy, K_xy;            // foot stiffness N/m
float D_xx, D_xy, D_yy;            // foot damping N/(m/s)

// ---- Motor model parameters ----
float supply_voltage = 12.0f;      // motor supply voltage
float R              = 2.0f;       // motor resistance
float k_t            = 0.18f;      // motor torque constant
float nu             = 0.0005;     // motor viscous friction

// ---- Foot sensor and gait bookkeeping ----
bool    airborne    = false;       // true while the foot contact input reads high
bool    landCheck   = false;       // set once the flight phase has ended
uint8_t gaitState   = 0;           // current phase of the gait state machine in main
uint8_t hopCount    = 0;           // counter advanced each time the hop cycle restarts
uint8_t hopCountLim = 4;           // 0 means 1 hop

// Current control interrupt function
void CurrentLoop()
{
    // This loop sets the motor voltage commands using PI current controllers with feedforward terms.
    //use the motor shield as follows:
    //motorShield.motorAWrite(DUTY CYCLE, DIRECTION), DIRECTION = 0 is forward, DIRECTION =1 is backwards.
        
    current1 = -(((float(motorShield.readCurrentA())/65536.0f)*30.0f)-15.0f);           // measure current
    velocity1 = encoderA.getVelocity() * PULSE_TO_RAD;                                  // measure velocity        
    float err_c1 = current_des1 - current1;                                             // current errror
    current_int1 += err_c1;                                                             // integrate error
    current_int1 = fmaxf( fminf(current_int1, current_int_max), -current_int_max);      // anti-windup
    float ff1 = R*current_des1 + k_t*velocity1;                                         // feedforward terms
    duty_cycle1 = (ff1 + current_Kp*err_c1 + current_Ki*current_int1)/supply_voltage;   // PI current controller
    
    float absDuty1 = abs(duty_cycle1);
    if (absDuty1 > duty_max) {
        duty_cycle1 *= duty_max / absDuty1;
        absDuty1 = duty_max;
    }    
    if (duty_cycle1 < 0) { // backwards
        motorShield.motorAWrite(absDuty1, 1);
    } else { // forwards
        motorShield.motorAWrite(absDuty1, 0);
    }             
    prev_current_des1 = current_des1; 
    
    current2     = -(((float(motorShield.readCurrentB())/65536.0f)*30.0f)-15.0f);       // measure current
    velocity2 = encoderB.getVelocity() * PULSE_TO_RAD;                                  // measure velocity  
    float err_c2 = current_des2 - current2;                                             // current error
    current_int2 += err_c2;                                                             // integrate error
    current_int2 = fmaxf( fminf(current_int2, current_int_max), -current_int_max);      // anti-windup   
    float ff2 = R*current_des2 + k_t*velocity2;                                         // feedforward terms
    duty_cycle2 = (ff2 + current_Kp*err_c2 + current_Ki*current_int2)/supply_voltage;   // PI current controller
    
    float absDuty2 = abs(duty_cycle2);
    if (absDuty2 > duty_max) {
        duty_cycle2 *= duty_max / absDuty2;
        absDuty2 = duty_max;
    }    
    if (duty_cycle2 < 0) { // backwards
        motorShield.motorBWrite(absDuty2, 1);
    } else { // forwards
        motorShield.motorBWrite(absDuty2, 0);
    }             
    prev_current_des2 = current_des2; 
    
}

int main (void)
{
    // Object for 7th order Cartesian foot trajectory
    BezierCurve rDesJump_bez(2,BEZIER_ORDER_JUMP);
    BezierCurve rDesFlight_bez(2,BEZIER_ORDER_FLIGHT);
    BezierCurve rDesLand_bez(2,BEZIER_ORDER_LAND);
    
    // Link the terminal with our server and start it up
    server.attachTerminal(pc);
    server.init();
    
    // Continually get input from MATLAB and run experiments
    float input_params[NUM_INPUTS];
    pc.printf("%f",input_params[0]);
    
    // Setup our foot sensor input
    DigitalIn  footContactPin(PC_6);
    
    // NOW WE STAY IN THIS WHILE STATEMENT FOREVER -------------------------------------------------------------------
    while(1) {
        // EVERYTHING IS NOW WAITING FOR SERIAL COMMUNICATION TO RUN BELOW, OTHERWISE WE ARE LOOPING THIS IF STATEMENT, this still only runs once per jump
        if (server.getParams(input_params,NUM_INPUTS)) {
            
            // Get inputs from MATLAB, note that these are formatted in Experiment_trajectory        
            gaitState = 0;
            landCheck = 0;
            start_period                = 1.5; // removed it as a parameter but we have a 1 second delay before starting the jump
            traj_period                 = 5.0;
            end_period                  = 1.0;
            jump_period                 = input_params[0];    // First buffer time, before trajectory
            flight_period               = input_params[1];    // Trajectory time/length
            land_period                 = input_params[2];    // Second buffer time, after trajectory
    
            angle1_init                 = input_params[3];    // Initial angle for q1 (rad)
            angle2_init                 = input_params[4];    // Initial angle for q2 (rad)

            K_xx                        = input_params[5];    // Foot stiffness N/m
            K_yy                        = input_params[6];    // Foot stiffness N/m
            K_xy                        = input_params[7];    // Foot stiffness N/m
            D_xx                        = input_params[8];    // Foot damping N/(m/s)
            D_yy                        = input_params[9];    // Foot damping N/(m/s)
            D_xy                        = input_params[10];   // Foot damping N/(m/s)
            duty_max                    = input_params[11];   // Maximum duty factor
          
            // Get foot trajectory points - start by initializing an empty array for each curve
            float jump_pts[2*(BEZIER_ORDER_JUMP+1)]; // foot_pts is a matrix sized for the bezier order
            float flight_pts[2*(BEZIER_ORDER_FLIGHT+1)]; 
            float land_pts[2*(BEZIER_ORDER_LAND+1)]; 
            
            // Now fill each array up with its respective points
            for(int i = 0; i<2*(BEZIER_ORDER_JUMP+1);i++) { 
                jump_pts[i] = input_params[12+i];    // assign the proper input parameters to each foot position
                flight_pts[i] = input_params[12+BEZIER_ORDER_JUMP+1 +i];
                land_pts[i] = input_params[12+BEZIER_ORDER_JUMP+1 + BEZIER_ORDER_FLIGHT+1 +i];
                // Adding this line to try and get the proper flight in
                //flight_pts[i] = input_params[12+BEZIER_ORDER_JUMP+1 + BEZIER_ORDER_FLIGHT+1 + BEZIER_ORDER_FLIGHT+1 +i];
            }
            //pc.printf("%f" " = flight points \n\r",flight_pts[1]);
            //float flight_pts[2][8] = {{0.0651000,0.0650000,0.0650000,0.0417000,-0.0169000,-0.0700000,-0.0700000,-0.0700000},
            //{-0.215000,-0.215000,-0.215000,-0.159700,-0.147000,-0.175000,-0.175000,-0.175100}};
            
            flight_pts[0] = .065000;
            flight_pts[1] = -.215000;
            flight_pts[2] = .065000;
            flight_pts[3] = -.215000;
            flight_pts[4] = .065000;
            flight_pts[5] = -.215000;
            flight_pts[6] = .0417;
            flight_pts[7] = -.1597;
            flight_pts[8] = -.0169;
            flight_pts[9] = -.147;
            flight_pts[10] = -.0700000;
            flight_pts[11] = -.175;
            flight_pts[12] = -.0700000;
            flight_pts[13] = -.175;
            flight_pts[14] = -.0700000;
            flight_pts[15] = -.175;
            
            
            //pc.printf("%f" " = flight points \n\r",flight_pts[1]);
            
            rDesJump_bez.setPoints(jump_pts);
            rDesFlight_bez.setPoints(flight_pts);
            rDesLand_bez.setPoints(land_pts);
            
            // Attach current loop interrupt
            currentLoop.attach_us(CurrentLoop,current_control_period_us);
                        
            // Setup experiment
            t.reset();
            t.start();
            encoderA.reset();
            encoderB.reset();
            encoderC.reset();
            encoderD.reset();

            motorShield.motorAWrite(0, 0); //turn motor A off
            motorShield.motorBWrite(0, 0); //turn motor B off
            
            hopCount = 0;
                         
            // Run experiment HERE's THE REAL MAIN LOOP ---------------------------------------------------------------------------------------------------------------------------
            while(gaitState < 5) { 
            //while( t.read() < 8.0) { 
                // Check to see if we are airborne or not, where airborne = 1 means we are not touching the ground
                if(footContactPin == true) {
                    airborne = true;}
                else {airborne = false;}
                // pc.printf("%d" " = Airborne? \n\r",airborne); // confirmed to be working
                
                // Read encoders to get motor states
                angle1 = encoderA.getPulses() *PULSE_TO_RAD + angle1_init;       
                velocity1 = encoderA.getVelocity() * PULSE_TO_RAD;
                angle2 = encoderB.getPulses() * PULSE_TO_RAD + angle2_init;       
                velocity2 = encoderB.getVelocity() * PULSE_TO_RAD;           
                
                const float th1 = angle1;
                const float th2 = angle2;
                const float dth1= velocity1;
                const float dth2= velocity2;
 
                // Calculate the Jacobian
                float Jx_th1 = l_AC*cos(th1 + th2) + l_DE*cos(th1) + l_OB*cos(th1);
                float Jx_th2 = l_AC*cos(th1 + th2);
                float Jy_th1 = l_AC*sin(th1 + th2) + l_DE*sin(th1) + l_OB*sin(th1);
                float Jy_th2 = l_AC*sin(th1 + th2);
                                
                // Calculate the forward kinematics (position and velocity)
                float xFoot = l_AC*sin(th1 + th2) + l_DE*sin(th1) + l_OB*sin(th1);
                float yFoot = - l_AC*cos(th1 + th2) - l_DE*cos(th1) - l_OB*cos(th1);
                float dxFoot = dth1*(l_AC*cos(th1 + th2) + l_DE*cos(th1) + l_OB*cos(th1)) + dth2*l_AC*cos(th1 + th2);
                float dyFoot = dth1*(l_AC*sin(th1 + th2) + l_DE*sin(th1) + l_OB*sin(th1)) + dth2*l_AC*sin(th1 + th2);       

                // Set gains based on buffer and traj times, then calculate desired x,y from Bezier trajectory at current time if necessary
                float teff  = 0;
                float vMult = 0;
                
                
                // HERE IS WHERE WE CHECK WHERE WE ARE IN THE GAIT ----------------------------------------------------------------------------------------------
                // gaitState 0 is just a buffer time before the jump holding the intial position
                if( t < start_period) {
                    if (K_xx > 0 || K_yy > 0) {
                        K_xx = input_params[5];  // Foot stiffness N/m
                        K_yy = input_params[6];  // Foot stiffness N/m
                        K_xy = input_params[7];  // Foot stiffness N/m
                        D_xx = input_params[8];  // Foot damping N/(m/s)
                        D_yy = input_params[9];  // Foot damping N/(m/s)
                        D_xy = input_params[10]; // Foot damping N/(m/s)
                        gaitState = 0;
                    }
                    teff = 0;
                    //pc.printf("%f" " Start Phase \n\r",teff);
                }
                // Jump Phase = gait 1
                else if (t < start_period + jump_period) {
                    // Lets only update the gains if we need to, this will be the last one for now
                    K_xx = input_params[5];  // Foot stiffness N/m
                    K_yy = input_params[6];  // Foot stiffness N/m
                    K_xy = input_params[7];  // Foot stiffness N/m
                    D_xx = input_params[8];  // Foot damping N/(m/s)
                    D_yy = input_params[9];  // Foot damping N/(m/s)
                    D_xy = input_params[10]; // Foot damping N/(m/s)
                    vMult = 1;
                    teff = (t - start_period);
                    gaitState = 1;
//                    pc.printf("%f" " Jump Phase \n\r",teff);
                }
                // Flight Phase = gait 2, we should not be touching the ground here
                // If t is lower than expected 
                else if (t < start_period + jump_period + land_period + 1.6 && landCheck == false) {
                    teff = (t - start_period - jump_period);
                    gaitState = 2;
//                    pc.printf("%f" " Flight Phase \n\r",teff);
                    
                    // If we land, or if we are 1 second past the expected flight phase, continue to the land gait
                    if (airborne == 0 || t > start_period + jump_period + flight_period + 1.5) { 
//                        pc.printf("%f" " < FLIGHT PHASE OVER, AIRBORNE FOR THIS LONG \n\r",teff);
                        landCheck = true;
                        flight_period = (t - start_period - jump_period);
                        gaitState = 3;
                        teff = (t - start_period - jump_period - flight_period);
                    }
                }
                // Land Phase = gait 3
                else if (gaitState == 3) {
                    teff = (t - start_period - jump_period - flight_period);
                    gaitState = 3;
//                    pc.printf("%f" " Land Phase \n\r",teff);
                    if (t > start_period + jump_period + flight_period + land_period) {
                        gaitState = 4;
                        vMult = 0;
                    }
                }
                else {
                    if (hopCount <= hopCountLim) {
                        hopCount = hopCount + 1;
                        t.reset();
                        t.start();
                    }
                    else {
                        gaitState = 5;
                        vMult = 0;
                    }
                    hopCount = hopCount + 1;
                }
                
                // Get desired foot positions and velocities                
                float rDesFoot[2] , vDesFoot[2];
                // Evaluate the correct gait (bezier curve) at the % detemined by how far into teff we are. There are 4 gaitStates...
                if (gaitState == 0){
                    rDesFoot[0] = jump_pts[0];
                    rDesFoot[1] = jump_pts[1];
                    vDesFoot[0] = 0.0;
                    vDesFoot[1] = 0.0;
                }
                else if (gaitState == 1){
//                    pc.printf("%d" " 1=JUMP \n\r",gaitState);
                    rDesJump_bez.evaluate(teff/jump_period, rDesFoot);
                    rDesJump_bez.evaluateDerivative(teff/jump_period, vDesFoot);
                    vDesFoot[0] /= jump_period;
                    vDesFoot[1] /= jump_period;
                }
                else if (gaitState == 2){
//                   pc.printf("%d" " 2=FLIGHT \n\r",gaitState);
                    rDesFlight_bez.evaluate(teff/flight_period, rDesFoot);
                    rDesFlight_bez.evaluateDerivative(teff/flight_period, vDesFoot);
                    vDesFoot[0] /= flight_period;
                    vDesFoot[1] /= flight_period;
                    if (teff > jump_period) {
                        rDesFoot[0] = flight_pts[14];
                        rDesFoot[1] = flight_pts[15];
                        vDesFoot[0] = 0.0;
                        vDesFoot[1] = 0.0;
                    }
                }
                else if (gaitState == 3){
//                    pc.printf("%d" " 3=LAND \n\r",gaitState);
                    rDesLand_bez.evaluate(teff/land_period, rDesFoot);
                    rDesLand_bez.evaluateDerivative(teff/land_period, vDesFoot);
                    vDesFoot[0] /= land_period;
                    vDesFoot[1] /= land_period;
                }
                // if we are past the trajectory period multiply by 0
                vDesFoot[0]*=vMult; 
                vDesFoot[1]*=vMult;
                
                // Calculate the inverse kinematics (joint positions and velocities) for desired joint angles              
                float xFoot_inv = -rDesFoot[0];
                float yFoot_inv = rDesFoot[1];                
                float l_OE = sqrt( (pow(xFoot_inv,2) + pow(yFoot_inv,2)) );
                float alpha = abs(acos( (pow(l_OE,2) - pow(l_AC,2) - pow((l_OB+l_DE),2))/(-2.0f*l_AC*(l_OB+l_DE)) ));
                float th2_des = -(3.14159f - alpha); 
                float th1_des = -((3.14159f/2.0f) + atan2(yFoot_inv,xFoot_inv) - abs(asin( (l_AC/l_OE)*sin(alpha) )));
                
                float dd = (Jx_th1*Jy_th2 - Jx_th2*Jy_th1);
                float dth1_des = (1.0f/dd) * (  Jy_th2*vDesFoot[0] - Jx_th2*vDesFoot[1] );
                float dth2_des = (1.0f/dd) * ( -Jy_th1*vDesFoot[0] + Jx_th1*vDesFoot[1] );
        
                // Calculate error variables
                float e_x = xFoot_inv - xFoot;
                float e_y = yFoot_inv - yFoot;
                float de_x = vDesFoot[0] - dxFoot;
                float de_y = vDesFoot[1] - dyFoot;
        
                // Calculate virtual force on foot
                float fx = K_xx*(e_x) + K_xy*(e_y) + D_xx*(de_x) + D_xy*(de_y);
                float fy = K_xy*(e_x) + K_yy*(e_y) + D_xy*(de_x) + D_yy*(de_y);
                
                float tau1 = Jx_th1*fx + Jy_th1*fy;
                float tau2 = Jx_th2*fx + Jy_th2*fy;
                
                current_des1 = tau1/k_t;
                current_des2 = tau2/k_t;
                
                // All of the above was to set these current_des values which will get fed into the CurrentLoop function running on an interrupt
                                                
                // Joint impedance
                // sub Kxx for K1, Dxx for D1, Kyy for K2, Dyy for D2
                // Note: Be careful with signs now that you have non-zero desired angles!
                // Your equations should be of the form i_d = K1*(q1_d - q1) + D1*(dq1_d - dq1)
                //current_des1 = (K_xx*(angle1_init - angle1) + D_xx*(0.0 - velocity1))/k_t;           
                //current_des2 = (K_yy*(angle2_init - angle2) + D_yy*(0.0 - velocity2))/k_t;   
                //current_des1 = 0;
                //current_des2 = 0;
                     
                // Cartesian impedance  
                // Note: As with the joint space laws, be careful with signs!              
                //current_des1 = (K_xx*(th1_des - angle1) + D_xx*(th1_des - velocity1))/k_t;           
                //current_des2 = (K_yy*(th2_des - angle2) + D_yy*(th2_des - velocity2))/k_t;    
                
                
                // Form output to send to MATLAB     
                float output_data[NUM_OUTPUTS];
                
                // current time
                output_data[0] = t.read();
                // motor 1 state
                output_data[1] = angle1;
                output_data[2] = velocity1;  
                output_data[3] = current1;
                output_data[4] = current_des1;
                output_data[5] = duty_cycle1;
                // motor 2 state
                output_data[6] = angle2;
                output_data[7] = velocity2;
                output_data[8] = current2;
                output_data[9] = current_des2;
                output_data[10]= duty_cycle2;
                // foot state
                output_data[11] = xFoot;
                output_data[12] = yFoot;
                output_data[13] = dxFoot;
                output_data[14] = dyFoot;
                output_data[15] = rDesFoot[0];
                output_data[16] = rDesFoot[1];
                output_data[17] = vDesFoot[0];
                output_data[18] = vDesFoot[1];
                
                // Send data to MATLAB
                server.sendData(output_data,NUM_OUTPUTS);

                wait_us(impedance_control_period_us);   
            }
            
            // Cleanup after experiment
            server.setExperimentComplete();
            currentLoop.detach();
            motorShield.motorAWrite(0, 0); //turn motor A off
            motorShield.motorBWrite(0, 0); //turn motor B off
        
        } // end if
        
    } // end while
    
} // end main

/*
Disabled debug helper: prints every entry of output_data on the serial terminal.
pc.printf("%f" " OUTPUT DATA 0 \n\r",output_data[0]);
pc.printf("%f" " OUTPUT DATA 1 \n\r",output_data[1]);
pc.printf("%f" " OUTPUT DATA 2 \n\r",output_data[2]);
pc.printf("%f" " OUTPUT DATA 3 \n\r",output_data[3]);
pc.printf("%f" " OUTPUT DATA 4 \n\r",output_data[4]);
pc.printf("%f" " OUTPUT DATA 5 \n\r",output_data[5]);
pc.printf("%f" " OUTPUT DATA 6 \n\r",output_data[6]);
pc.printf("%f" " OUTPUT DATA 7 \n\r",output_data[7]);
pc.printf("%f" " OUTPUT DATA 8 \n\r",output_data[8]);
pc.printf("%f" " OUTPUT DATA 9 \n\r",output_data[9]);
pc.printf("%f" " OUTPUT DATA 10 \n\r",output_data[10]);
pc.printf("%f" " OUTPUT DATA 11 \n\r",output_data[11]);
pc.printf("%f" " OUTPUT DATA 12 \n\r",output_data[12]);
pc.printf("%f" " OUTPUT DATA 13 \n\r",output_data[13]);
pc.printf("%f" " OUTPUT DATA 14 \n\r",output_data[14]);
pc.printf("%f" " OUTPUT DATA 15 \n\r",output_data[15]);
pc.printf("%f" " OUTPUT DATA 16 \n\r",output_data[16]);
pc.printf("%f" " OUTPUT DATA 17 \n\r",output_data[17]);
pc.printf("%f" " OUTPUT DATA 18 \n\r",output_data[18]);
*/