// THE FIGHTING BANANA SLUGS!!!
 
#include <cmath>        // std::fabs for floating-point magnitudes

#include "mbed.h"       // mbed library
#include "QEI.h"        // quadrature encoder library to count encoder ticks

// Setup: pin assignments and peripheral objects shared by the whole program.
// Motor 1: two direction outputs, one PWM output and a quadrature encoder.
DigitalOut mDir1_A(p5);
DigitalOut mDir1_B(p6);
PwmOut pwmOut1(p21);
QEI encoder1(p23, p24, NC, 1200, QEI::X4_ENCODING);  // no index pin (NC); presumably 1200 pulses per revolution -- confirm against the encoder/QEI docs

// Motor 2: same arrangement as motor 1, on different pins.
DigitalOut mDir2_A(p11);
DigitalOut mDir2_B(p12);
PwmOut pwmOut2(p22);
QEI encoder2(p9, p10, NC, 1200, QEI::X4_ENCODING);  // same settings as encoder1

DigitalIn button(p7);  // polled by checkBtn(); a reading of 1 requests a stop

// Declare other objects
Ticker ctrlTicker;                // periodic timer that main() attaches to pdControl()
Ticker trajTicker;                // periodic timer that main() attaches to setTraj()
Ticker buttonTicker;              // periodic timer that main() attaches to checkBtn()
Serial mySerial(USBTX, USBRX);  // serial connection to the computer over the tx/rx pins (only referenced by commented-out debug prints)

DigitalOut dOut1(p17);  // main() drives this to 1 while the controller runs and to 0 once it stops

// Controller state shared between the ticker callbacks and main().
// Angles are whatever pdControl() derives from the encoders
// (pulses * 2*pi / 1200, i.e. radians if 1200 pulses make one revolution).
float a1_t0 = 0;    //motor angle 1 from previous time step
float a1_t1 = 0;    //motor angle 1 from current time step
float a2_t0 = 0;    //motor angle 2 from previous time step
float a2_t1 = 0;    //motor angle 2 from current time step
float w1 = 0;       //motor 1 speed magnitude, updated by pdControl()
float w2 = 0;       //motor 2 speed magnitude, updated by pdControl()
float fPWM = 1000;  //rate in Hz at which main() schedules pdControl()
int nTraj = 0; //trajectory index
float aD[2] = {0.0, 0.0};  //desired angles {motor 1, motor 2}, relative to the starting position
//knee
float traj2[] = {0.0};
//hip
float traj1[] = {0.0};
int numPoints = sizeof(traj1)/sizeof(traj1[0]);
//time frequency of trajectory commands: numPoints set-points are stepped
//through in 0.5 s (main() schedules setTraj() every 1/fTraj seconds)
float fTraj = numPoints/0.5;
//starting motor positions
float a1_0 = traj1[0];
float a2_0 = traj2[0];
float e1 = 0;   //motor 1 position error (measured - desired)
float e2 = 0;   //motor 2 position error (measured - desired)
//controller gains
float kp1 = 1.7;
float kd1 = 0.0;
float kp2 = 1.5;
float kd2 = 0.0;

// Stop request: written from a ticker callback (interrupt context) and polled
// by main() in an empty loop. It must be volatile, otherwise the compiler is
// free to read it once and spin forever.
volatile bool done = false;

void pdControl() { 
    float in1 = aD[0];
    float in2 = aD[1];
    //get motor position
    a1_t1 = encoder1.getPulses()*2*3.14/1200.0;
    a2_t1 = encoder2.getPulses()*2*3.14/1200.0;
    //calculate error
    e1 = a1_t1-in1;
    e2 = a2_t1-in2;
    //calculate motor speed
    w1 = abs(a1_t1-a1_t0)*fTraj;
    w2 = abs(a2_t1-a2_t0)*fTraj;
    //set motor direction
    mDir1_A = (e1<0);
    mDir1_B = !(e1<0);
    mDir2_A = (e2>0);
    mDir2_B = !(e2>0);
    //command motor speed
    pwmOut1.period(.0001); //set pwm frequency to 10kHz
    pwmOut2.period(.0001); //set pwm frequency to 10kHz
    pwmOut1.write(abs(kp1*e1)+abs(w1*kd1));
    pwmOut2.write(abs(kp2*e2)+abs(w2*kd2));
    
    a1_t0 = a1_t1; //save encoder position for next step (to find angular velocity)
    a2_t0 = a2_t1;
}

/**
 * Advances the trajectory by one set-point.
 *
 * Intended to run periodically from trajTicker. Writes the next desired
 * angles into aD[], expressed relative to the starting positions a1_0/a2_0.
 * When the end of the table is reached the index wraps to 0 so the
 * trajectory repeats.
 *
 * This function deliberately does not touch `done`: that flag starts false
 * and is only raised by checkBtn(). Clearing it here could erase a stop
 * request raised in the same interrupt pass before main() observes it.
 */
void setTraj() {
    if (nTraj >= numPoints){
        // End of the table: restart from the first point on the next tick.
        // NOTE(review): this tick issues no new set-point, so the last point
        // is held for one extra period -- confirm that is intended.
        nTraj = 0;
        //done = true;   // enable to stop after a single pass instead of looping
    } else {
        aD[0] = traj1[nTraj]-a1_0;
        aD[1] = traj2[nTraj]-a2_0;
        nTraj++;
    }
}

// Polls the stop button; once it reads 1 the global `done` flag is raised.
// The flag is never lowered here.
void checkBtn() {
    const bool pressed = (button.read() == 1);
    if (!pressed) {
        return;
    }
    done = true;
}

int main() {
    wait(5.0);
    mDir1_A = 1;
    mDir1_B = 0;
    mDir2_A = 1;
    mDir2_B = 0;
    
    //mySerial.printf("numPoints: %d, fTraj: %f, fPWM: %f\n\r", numPoints, fTraj, fPWM);
    
    //get initial position
    
    dOut1.write(1);
    
    if (sizeof(traj1) == sizeof(traj2)) {
        trajTicker.attach(setTraj, 1/fTraj);
        ctrlTicker.attach(pdControl, 1/fPWM);
        buttonTicker.attach(checkBtn, 0.1);
        while (!done) {
            //mySerial.printf("motor 2: %f \n\r", a2_t1);
        }
        //mySerial.printf("Done\n\r");
        dOut1.write(0);
        mDir1_A = 0;
        mDir1_B = 0;
        mDir2_A = 0;
        mDir2_B = 0;
    } else {
        //mySerial.printf("Input error\n\r");
    }
}