#include "mbed.h"
#include "General.hpp"
#include "ros.h"
#include "rtos.h"
#include "ROS_Handler.hpp"
#include <motordriver.h>
#include "math.h"
#include "Motors.hpp"
#include "Battery_Monitor.hpp"
#include "Pins.h"

//Defines used for the encoder counting
//Each encoder state is a 2-bit value: bit 1 = channel 1, bit 0 = channel 2 (see callbackA/callbackB).
#define PREV_MASK 0x1 //Mask for the previous state in determining direction of rotation.
#define CURR_MASK 0x2 //Mask for the current state in determining direction of rotation.
#define INVALID   0x3 //XORing two states where both bits have changed.
//^^ a transition whose XOR equals INVALID (both channels flipped at once) is ignored by the counters.

Motor A(PB_13, PB_15, PC_6, 1); // pwm, fwd, rev, can brake
Motor B(PB_4, PC_7, PA_15, 1); // pwm, fwd, rev, can brake

//debug serial link used by the temp* helpers and driveMotors() for printf tracing
Serial pc(USBTX, USBRX);

//two quadrature channels per wheel; every rising and falling edge is routed to the wheel's callback in initialize()
InterruptIn encoderAchannel1(PB_12);
InterruptIn encoderAchannel2(PA_4);
InterruptIn encoderBchannel1(PB_5);
InterruptIn encoderBchannel2(PB_3);

float radius = 3; //wheel radius in cm
float bodyradius = 9.15; //radius of wheel centre from centre of bot, measured in vertical axis
float benchmark = 3576; //number of encoder pulses per rotation of a wheel
float pwm = -0.5; //the pwm constant that the wheels will use when moving
//^^ driveMotors() applies +pwm or -pwm depending on which way each wheel must travel
float brakedist = 105; //the encoder target offset where the brakes will engage to prevent overshoot
//^^ found using trial and error for a pwm value of 0.5 - will almost certainly vary if that is changed
//NOTE(review): the tuning note says 0.5 but pwm above is -0.5; presumably only the magnitude matters - confirm.
//brakedist is in encoder counts (it is compared directly against acurrent/bcurrent in the callbacks).

//---- Wheel A state ----
//NOTE(review): atarget and nextmoveflagA are written from the encoder interrupt (callbackA) and
//nextmoveflagA is polled by driveMotors(), yet neither is volatile. Consider making them volatile
//(and check that any extern declarations in headers match).
volatile long acurrent = 0; //wheel A's current position in encoder counts
int atarget = 0; //wheel A's current target position in encoder counts
int aprevtarget = 0; //wheel A's previous target position in encoder counts
float apwm = 0; //the pwm for wheel a, updated using the 'pwm' variable
//initial pin samples taken during static initialisation; initialize() re-reads them
volatile int channelAstate1 = encoderAchannel1.read(); //read the encoder channels for motor A and record their states
volatile int channelAstate2 = encoderAchannel2.read();
int currentStateA = (channelAstate1 << 1) | (channelAstate2); //store a 2-bit state to track the current encoder readings
int previousStateA = currentStateA;
//1 = wheel A idle / last move finished, 0 = move in progress.
//Cleared by the move/rotate commands, set by callbackA once within brakedist of atarget, polled by driveMotors().
int nextmoveflagA = 1;

//---- Wheel B state (same meaning as the wheel A fields above) ----
volatile long bcurrent = 0;
int btarget = 0;
int bprevtarget = 0;
float bpwm = 0;
volatile int channelBstate1 = encoderBchannel1.read();
volatile int channelBstate2 = encoderBchannel2.read();
int currentStateB = (channelBstate1 << 1) | (channelBstate2);
int previousStateB = currentStateB;
int nextmoveflagB = 1;

/*
 * Arms the quadrature decoders for both wheels.
 *
 * Routes every edge (rising and falling) on both channels of each encoder to
 * that wheel's callback, so all four state changes per quadrature cycle are
 * counted. It then re-samples the pins so each decoder starts from the real
 * current state rather than the one captured at static initialisation.
 * Presumably called once at start-up (the caller is not visible here).
 */
void initialize()
{
    // Wheel A: rising edges, then falling edges, on both channels.
    encoderAchannel1.rise(&callbackA);
    encoderAchannel2.rise(&callbackA);
    encoderAchannel1.fall(&callbackA);
    encoderAchannel2.fall(&callbackA);

    // Wheel B: rising edges, then falling edges, on both channels.
    encoderBchannel1.rise(&callbackB);
    encoderBchannel2.rise(&callbackB);
    encoderBchannel1.fall(&callbackB);
    encoderBchannel2.fall(&callbackB);

    // Re-sync wheel A's decoder with the live pin levels
    // (state = channel1 in bit 1, channel2 in bit 0).
    channelAstate1 = encoderAchannel1.read();
    channelAstate2 = encoderAchannel2.read();
    previousStateA = currentStateA = (channelAstate1 << 1) | (channelAstate2);

    // Same for wheel B.
    channelBstate1 = encoderBchannel1.read();
    channelBstate2 = encoderBchannel2.read();
    previousStateB = currentStateB = (channelBstate1 << 1) | (channelBstate2);
}

/*
 * Encoder interrupt handler for wheel A (attached to every edge of both channels).
 *
 * Decodes one quadrature step into acurrent. Once acurrent is within
 * +/-brakedist of atarget, it brakes motor A, flags the move as finished
 * (nextmoveflagA = 1) and snaps atarget to the actual count, so that stopping
 * short of the exact target does not accumulate into later moves.
 */
void callbackA()
{
    ////////
    //insert the 'if X, carry out the rest of this code' debouncing check here, ensure the delay is minimal enough to avoid hindering 3576 measurements per rotation at max speed
    ////////

    // Sample both channels and pack them into a 2-bit state: bit 1 = channel 1, bit 0 = channel 2.
    channelAstate1 = encoderAchannel1.read();
    channelAstate2 = encoderAchannel2.read();
    currentStateA = (channelAstate1 << 1) | (channelAstate2);

    // A legal step flips exactly one bit. No change, or both bits flipped (INVALID), is not counted.
    const int flippedBits = currentStateA ^ previousStateA;
    if (flippedBits != 0 && flippedBits != INVALID) {
        // Old channel-2 bit XOR new channel-1 bit gives the direction: 0 => count up, 1 => count down.
        const int directionBit = (previousStateA & PREV_MASK) ^ ((currentStateA & CURR_MASK) >> 1);
        if (directionBit) {
            --acurrent;
        } else {
            ++acurrent;
        }
    }
    previousStateA = currentStateA;

    // Brake window around the target (float bounds, as brakedist is a float).
    const float upperBound = atarget + brakedist;
    const float lowerBound = atarget - brakedist;
    if (acurrent <= upperBound && acurrent >= lowerBound) {
        A.stop(1);
        nextmoveflagA = 1;
        // Re-base the target on where the wheel actually is.
        atarget = acurrent;
    }
}

/*
 * Encoder interrupt handler for wheel B (attached to every edge of both channels).
 *
 * Decodes one quadrature step into bcurrent. Once bcurrent is within
 * +/-brakedist of btarget, it brakes motor B, flags the move as finished
 * (nextmoveflagB = 1) and snaps btarget to the actual count, so that stopping
 * short of the exact target does not accumulate into later moves.
 */
void callbackB()
{
    ////////
    //insert the 'if X, carry out the rest of this code' debouncing check here, ensure the delay is minimal enough to avoid hindering 3576 measurements per rotation at max speed
    ////////

    // Sample both channels and pack them into a 2-bit state: bit 1 = channel 1, bit 0 = channel 2.
    channelBstate1 = encoderBchannel1.read();
    channelBstate2 = encoderBchannel2.read();
    currentStateB = (channelBstate1 << 1) | (channelBstate2);

    // A legal step flips exactly one bit. No change, or both bits flipped (INVALID), is not counted.
    const int flippedBits = currentStateB ^ previousStateB;
    if (flippedBits != 0 && flippedBits != INVALID) {
        // Old channel-2 bit XOR new channel-1 bit gives the direction: 0 => count up, 1 => count down.
        const int directionBit = (previousStateB & PREV_MASK) ^ ((currentStateB & CURR_MASK) >> 1);
        if (directionBit) {
            --bcurrent;
        } else {
            ++bcurrent;
        }
    }
    previousStateB = currentStateB;

    // Brake window around the target (float bounds, as brakedist is a float).
    const float upperBound = btarget + brakedist;
    const float lowerBound = btarget - brakedist;
    if (bcurrent <= upperBound && bcurrent >= lowerBound) {
        B.stop(1);
        nextmoveflagB = 1;
        // Re-base the target on where the wheel actually is.
        btarget = bcurrent;
    }
}

/*
 * ROS callback: drive both wheels straight by dinput.data.
 *
 * The distance is assumed to be in the same unit as `radius` (cm) - TODO confirm
 * against the publisher. It is converted to encoder counts: distance divided by
 * wheel circumference gives revolutions, and multiplying by benchmark gives counts.
 * Both wheels get the same relative target. The command is ignored while a
 * previous move is still running (either nextmoveflag is 0). Otherwise
 * driveMotors() is called, which blocks until both wheels report completion.
 */
void move(const std_msgs::Int32& dinput)
{
    if(nextmoveflagA == 1 && nextmoveflagB == 1) {
        // Full-precision pi: the former 3.142 approximation added a small
        // systematic distance error to every move.
        const float kPi = 3.14159265f;
        //import the relevant ROS message and convert it to a usable encoder target
        float distance = dinput.data;
        float mtarget = distance/(2*kPi*radius)*benchmark;
        //set the encoder target as the interval for the wheels to move - must be an integer to match encoder readings
        atarget = ceil(mtarget + aprevtarget);
        btarget = ceil(mtarget + bprevtarget);
        //mark both wheels busy, then move the motors to their respective targets
        nextmoveflagA = 0;
        nextmoveflagB = 0;
        driveMotors();
    }
}

/*
 * Debug/manual variant of move(): drive both wheels straight by `distance`
 * (same unit as `radius`, i.e. cm) and print the computed targets over `pc`.
 * It is ignored while a previous move is still running. Otherwise it blocks in
 * driveMotors() until both wheels report completion.
 */
void tempMove(float distance)
{
    if(nextmoveflagA == 1 && nextmoveflagB == 1) {
        // Full-precision pi (previously 3.142, a small systematic error).
        const float kPi = 3.14159265f;
        // distance / wheel circumference = revolutions; * benchmark = encoder counts
        float mtarget = distance/(2*kPi*radius)*benchmark;
        atarget = ceil(mtarget + aprevtarget);
        btarget = ceil(mtarget + bprevtarget);
        pc.printf("Moving: A target is: %d B target is: %d M target is %f\n\r", atarget, btarget, mtarget);
        nextmoveflagA = 0;
        nextmoveflagB = 0;
        driveMotors();
    }
}

/*
 * ROS callback: rotate the robot in place by rinput.data degrees.
 *
 * Each wheel sits `bodyradius` from the centre, so it travels an arc of
 * 2*pi*bodyradius*(degrees/360). Dividing by the wheel circumference
 * 2*pi*radius gives (degrees/360)*(bodyradius/radius) revolutions, and
 * multiplying by benchmark gives encoder counts. Wheel A is driven by +counts
 * and wheel B by -counts. The command is ignored while a previous move is
 * running. Otherwise it blocks in driveMotors() until both wheels finish.
 */
void rotate(const std_msgs::Int32& rinput)
{
    if(nextmoveflagA == 1 && nextmoveflagB == 1) {
        //import the relevant ROS message and convert it to usable encoder targets
        float degrees = rinput.data;
        // Use the geometry globals rather than hard-coded copies (previously 9.15 and 3),
        // so retuning radius/bodyradius also corrects rotations.
        float rtarget = (degrees/360.0f)*(bodyradius/radius)*benchmark;
        //set the encoder targets as the intervals for the wheels to move - must be integers to match encoder readings
        atarget = ceil(rtarget + aprevtarget);
        btarget = ceil(-rtarget + bprevtarget);
        //mark both wheels busy, then move the motors to their respective targets
        nextmoveflagA = 0;
        nextmoveflagB = 0;
        driveMotors();
    }
}

void tempRotate(float degrees)
{
    if(nextmoveflagA == 1 && nextmoveflagB == 1) {
        float rtarget = (degrees/(360*3))*9.15*benchmark;
        atarget = ceil(rtarget + aprevtarget);
        btarget = ceil(-rtarget + bprevtarget);
        pc.printf("Rotating: A target is: %d B target is: %d R target is: %f\n\r", atarget, btarget, rtarget);
        nextmoveflagA = 0;
        nextmoveflagB = 0;
        driveMotors();
    }
}

/*
 * Starts both motors towards atarget/btarget and blocks until both finish.
 *
 * Each wheel's direction comes from the sign of (target - previous target).
 * Completion is signalled by the encoder callbacks setting nextmoveflagA/B to 1
 * once within brakedist of the target. This function polls those flags every
 * 500 ms, then records the (callback-adjusted) targets as the new baselines.
 *
 * A wheel with zero displacement is not driven. Because it produces no encoder
 * edges, its callback would never set its flag, and the wait loop would spin
 * forever (e.g. on a 0 cm move or a 0 degree rotation). Such a wheel is
 * therefore marked finished here.
 */
void driveMotors()
{
    //depending on the direction the wheels have to move, set their pwms to either positive or negative so each motor moves correctly
    if(atarget > aprevtarget) {
        apwm = pwm;
    } else if(atarget < aprevtarget) {
        apwm = -pwm;
    } else {
        // Nothing to do for wheel A: don't wait for an encoder edge that will never come.
        apwm = 0;
        nextmoveflagA = 1;
    }
    if(btarget > bprevtarget) {
        bpwm = pwm;
    } else if(btarget < bprevtarget) {
        bpwm = -pwm;
    } else {
        // Nothing to do for wheel B: don't wait for an encoder edge that will never come.
        bpwm = 0;
        nextmoveflagB = 1;
    }
    //apply the pwm output to the motors
    pc.printf("Driving: A pwm is: %f B pwm is: %f\n\r", apwm, bpwm);
    A.speed(apwm);
    B.speed(bpwm);
    pc.printf("Driving: A target is: %d B target is: %d\n\r", atarget, btarget);
    pc.printf("Driving: A target was: %d B target was: %d\n\r", aprevtarget, bprevtarget);
    /**************************************************************************/
    //Note that this code prevents consecutive movement commands from being executed until the
    //previous one is done, but has the downside of keeping the thread busy during movement.
    //Feel free to replace it with a different method of waiting, if you want.
    while(nextmoveflagA == 0 || nextmoveflagB == 0) {
        wait_ms(500);
    }
    /**************************************************************************/
    //update the recordings of the previous command's movement targets
    //(the callbacks snap atarget/btarget to the actual encoder counts when they brake)
    aprevtarget = atarget;
    bprevtarget = btarget;
}

