//c++ script for filtering of measured EMG signals
#include "mbed.h" //Base library
#include "HIDScope.h" // to see if program is working and EMG is filtered properly
// #include "QEI.h"// is needed for the encoder
#include "MODSERIAL.h"// in order for connection with the pc
#include "BiQuad.h"
// #include "FastPWM.h"
// #include "Arduino.h" // might be useful since we use the EMG Arduino board (?)
// #include "EMGFilters.h"
#include <vector> // For easy array management

/*
------ DEFINE MBED CONNECTIONS ------
*/

// PC connections: HIDScope with 4 channels (scope.set() writes channels 0-3) and the
// MODSERIAL USB serial port used for printf diagnostics in main()
HIDScope        scope( 4 );
MODSERIAL pc(USBTX, USBRX);

// On-board RGB LED; the code writes 0 to switch a colour on and 1 to switch it off
DigitalOut      led_g(LED_GREEN);
DigitalOut      led_r(LED_RED);
DigitalOut      led_b(LED_BLUE);

// Buttons (interrupt inputs). button1/button2 callbacks are attached in the EMG wait state;
// button3 (SW3) is not referenced by the code shown here
InterruptIn button1(D11);
InterruptIn button2(D10);
InterruptIn button3(SW3);

// EMG Substates
enum EMG_States { emg_wait, emg_cal_MVC, emg_cal_rest, emg_operation }; // Wait for user, MVC calibration, rest calibration, normal operation
EMG_States emg_curr_state; // Current EMG substate; set to emg_wait in main() before the ticker starts
bool emg_state_changed = true; // true -> the current substate's entry actions still have to run (starts true so emg_wait's entry runs first)

bool sampleNow = false; // Enables reading + filtering of the EMG inputs in sampleSignal()
bool calibrateNow = false; // Enables appending envelopes to the emgN_cal vectors in sampleSignal()
bool emg_MVC_cal_done = false; // Set once an MVC calibration has completed
bool emg_rest_cal_done = false; // Set once a rest calibration has completed

// Button press flags: set by the InterruptIn callbacks, consumed (reset to false) in the EMG wait state
bool button1_pressed = false;
bool button2_pressed = false;

// Global variables for EMG reading.
// Every channel N (1..3) has the same set of variables:
//   emgN         raw AnalogIn reading, updated in sampleSignal()
//   emgN_env     envelope: notch -> highpass -> rectify -> lowpass
//   emgN_MVC     maximum envelope recorded during MVC calibration
//   emgN_rest    mean envelope recorded during rest calibration
//   emgN_factor  normalization factor derived from emgN_MVC in do_emg_operation()
//   emgN_th      normalized rest threshold (normalized rest level plus a 5 % margin)
//   emgN_norm    envelope multiplied by emgN_factor
//   emgN_out     final output signal; magnitude in [0, 1] (see do_emg_operation() for the sign)
//   emgN_cal     envelope samples collected while calibrateNow is true
//   emgN_cal_size  debug aid only; not written by any active code shown here
//   emgN_dir     direction sign of the channel, initialized to +1
AnalogIn emg1_in (A1); // Right biceps, x axis
AnalogIn emg2_in (A2); // Left biceps, y axis
AnalogIn emg3_in (A3); // Third muscle, TBD

double emg1;
double emg1_env;
double emg1_MVC;
double emg1_rest;
double emg1_factor;
double emg1_th;
double emg1_out;
double emg1_out_prev; // Channel 1 only: previous output, used for the derivative
double emg1_dt; // Channel 1 only: first time derivative of the output
double emg1_dt_prev; // Channel 1 only: previous first derivative
double emg1_dtdt; // Channel 1 only: second time derivative of the output
double emg1_norm;
vector<double> emg1_cal;
int emg1_cal_size;
int emg1_dir = 1;

double emg2;
double emg2_env;
double emg2_MVC;
double emg2_rest;
double emg2_factor;
double emg2_th;
double emg2_out;
double emg2_norm;
vector<double> emg2_cal;
int emg2_cal_size;
int emg2_dir = 1;

double emg3;
double emg3_env;
double emg3_MVC;
double emg3_rest;
double emg3_factor;
double emg3_th;
double emg3_out;
double emg3_norm;
vector<double> emg3_cal;
int emg3_cal_size;
int emg3_dir = 1;

// Global ticker (runs tickGlobalFunc every Ts once attached in main) and the timer that ends a calibration after Tcal
Ticker tickGlobal; // Set global ticker
Timer timerCalibration;

/*
------ GLOBAL VARIABLES ------
*/
const double Fs = 500; // Sampling frequency (Hz)
const double Tcal = 10.0f; // Calibration duration (s)

// Derived constants
const double Ts = 1/Fs; // Sampling period (s); also the period of the global ticker

// NOTE(review): the filter coefficients below do not appear to match Fs = 500 Hz.
// Back-calculating from them suggests they were generated for Fs ~= 1024 Hz. For example,
// the notch zeros satisfy 2*cos(w0) = 1.90661, i.e. w0 = 2*pi*50/1024.
// At 500 Hz the notch would sit near 24.4 Hz and the highpass/lowpass cutoffs near
// 4.9 Hz / 2.4 Hz. TODO confirm in MATLAB and regenerate the coefficients for the actual Fs.

// Notch biquad filter coefficients (iirnotch Q factor 35 @50Hz) from MATLAB:
BiQuad bq1_notch( 0.995636295063941,  -1.89829218816065,   0.995636295063941,  1, -1.89829218816065,   0.991272590127882); // b01 b11 b21 a01 a11 a21
// Per-channel copies, so each channel's chain (assembled in main) steps its own BiQuad objects
BiQuad bq2_notch = bq1_notch;
BiQuad bq3_notch = bq1_notch;
BiQuadChain bqc1_notch;
BiQuadChain bqc2_notch;
BiQuadChain bqc3_notch;

// Highpass biquad filter coefficients (butter 4th order @10Hz cutoff, two second-order sections) from MATLAB
BiQuad bq1_H1(0.922946103200875, -1.84589220640175,  0.922946103200875,  1,  -1.88920703055163,  0.892769008131025); // b01 b11 b21 a01 a11 a21
BiQuad bq1_H2(1,                 -2,                 1,                  1,  -1.95046575793011,  0.954143234875078); // b02 b12 b22 a02 a12 a22
// Per-channel copies (see notch filter above)
BiQuad bq2_H1 = bq1_H1;
BiQuad bq2_H2 = bq1_H2;
BiQuad bq3_H1 = bq1_H1;
BiQuad bq3_H2 = bq1_H2;
BiQuadChain bqc1_high;
BiQuadChain bqc2_high;
BiQuadChain bqc3_high;

// Lowpass biquad filter coefficients (butter 4th order @5Hz cutoff, two second-order sections) from MATLAB:
BiQuad bq1_L1(5.32116245737504e-08,  1.06423249147501e-07,   5.32116245737504e-08,   1,  -1.94396715039462,  0.944882378004138); // b01 b11 b21 a01 a11 a21
BiQuad bq1_L2(1,                     2,                      1,                      1,  -1.97586467534468,  0.976794920438162); // b02 b12 b22 a02 a12 a22
// Per-channel copies (see notch filter above)
BiQuad bq2_L1 = bq1_L1;
BiQuad bq2_L2 = bq1_L2;
BiQuad bq3_L1 = bq1_L1;
BiQuad bq3_L2 = bq1_L2;
BiQuadChain bqc1_low;
BiQuadChain bqc2_low;
BiQuadChain bqc3_low;

/*
------ HELPER FUNCTIONS ------
*/

// Return the largest element of vect.
// The search starts from the first element (not from 0.0), so vectors holding only
// negative values report their real maximum. An empty vector returns 0.0.
double getMax(const std::vector<double> &vect)
{
    if ( vect.empty() ) {
        return 0.0;
    }

    double curr_max = vect[0];
    for ( std::size_t i = 1; i < vect.size(); i++ ) {
        if ( vect[i] > curr_max ) {
            curr_max = vect[i];
        }
    }
    return curr_max;
}

// Return the arithmetic mean of vect.
// An empty vector returns 0.0 rather than evaluating 0/0 (NaN).
double getMean(const std::vector<double> &vect)
{
    if ( vect.empty() ) {
        return 0.0;
    }

    double sum = 0.0;
    for ( std::size_t i = 0; i < vect.size(); i++ ) {
        sum += vect[i];
    }
    return sum / static_cast<double>( vect.size() );
}

// Return the population standard deviation of vect (divides by N, not N - 1),
// measured around the caller-supplied mean vect_mean.
// An empty vector returns 0.0 rather than evaluating sqrt(0/0) (NaN).
double getStdev(const std::vector<double> &vect, const double vect_mean)
{
    if ( vect.empty() ) {
        return 0.0;
    }

    double sum_sq = 0.0;
    for ( std::size_t i = 0; i < vect.size(); i++ ) {
        const double dev = vect[i] - vect_mean;
        sum_sq += dev * dev; // Plain multiply instead of pow(dev, 2)
    }
    return sqrt( sum_sq / static_cast<double>( vect.size() ) );
}

// Linearly map input from the range [in_min, in_max] to [out_min, out_max]
// (based on MATLAB's rescale). Inputs outside the input range are extrapolated, not clamped.
// A degenerate input range (in_min == in_max) returns out_min instead of dividing by zero,
// matching MATLAB, which maps constant input to the lower bound of the output interval.
double rescale(double input, double out_min, double out_max, double in_min, double in_max)
{
    const double in_span = in_max - in_min;
    if ( in_span == 0.0 ) {
        return out_min;
    }
    return out_min + ((input - in_min) / in_span) * (out_max - out_min);
}

// Check filter stability
bool checkBQChainStable()
{
    bool n_stable = bqc1_notch.stable(); // Check stability of all BQ Chains
    bool hp_stable =  bqc1_high.stable();
    bool l_stable = bqc1_low.stable();

    if (n_stable && hp_stable && l_stable) {
        return true;
    } else {
        return false;
    }
}

/*
------ BUTTON FUNCTIONS ------
*/

// InterruptIn callback for button1 (falling edge, attached in the EMG wait state).
// It only raises a flag; the state machine reacts to it on its next tick.
void button1Press()
{
    button1_pressed = true; // Consumed and cleared by do_emg_wait()
}

// InterruptIn callback for button2 (falling edge, attached in the EMG wait state).
// It only raises a flag; the state machine reacts to it on its next tick.
void button2Press()
{
    button2_pressed = true; // Consumed and cleared by do_emg_wait()
}

// Flip the direction sign of EMG channel 1: +1 becomes -1 and -1 becomes +1.
// Any other value of emg1_dir is left unchanged.
void toggleEMG1Dir()
{
    if ( emg1_dir == 1 ) {
        emg1_dir = -1;
    } else if ( emg1_dir == -1 ) {
        emg1_dir = 1;
    }
}

// Toggle EMG direction
void toggleEMG2Dir()
{
    switch( emg1_dir ) {
        case -1:
            emg1_dir = 1;
            break;
        case 1:
            emg1_dir = -1;
            break;
    }
}

/*
------ TICKER FUNCTIONS ------
*/
void sampleSignal()
{
    if (sampleNow == true) { // This ticker only samples if the sample flag is true, to prevent unnecessary computations
        // Read EMG inputs
        emg1 = emg1_in.read();
        emg2 = emg2_in.read();
        emg3 = emg3_in.read();


        double emg1_n = bqc1_notch.step( emg1 );         // Filter notch
        double emg1_hp = bqc1_high.step( emg1_n );       // Filter highpass
        double emg1_rectify = fabs( emg1_hp );           // Rectify
        emg1_env = bqc1_low.step( emg1_rectify ); // Filter lowpass (completes envelope)

        double emg2_n = bqc2_notch.step( emg2 );         // Filter notch
        double emg2_hp = bqc2_high.step( emg2_n );       // Filter highpass
        double emg2_rectify = fabs( emg2_hp );           // Rectify
        emg2_env = bqc2_low.step( emg2_rectify ); // Filter lowpass (completes envelope)

        double emg3_n = bqc3_notch.step( emg3 );         // Filter notch
        double emg3_hp = bqc3_high.step( emg3_n );       // Filter highpass
        double emg3_rectify = fabs( emg3_hp );           // Rectify
        emg3_env = bqc3_low.step( emg3_rectify ); // Filter lowpass (completes envelope)

        if (calibrateNow == true) { // Only add values to EMG vectors if calibration flag is true
            emg1_cal.push_back(emg1_env); // Add values to calibration vector
            // emg1_cal_size = emg1_cal.size(); // Used for debugging
            emg2_cal.push_back(emg2_env); // Add values to calibration vector
            // emg2_cal_size = emg1_cal.size(); // Used for debugging
            emg3_cal.push_back(emg3_env); // Add values to calibration vector
            // emg3_cal_size = emg1_cal.size(); // Used for debugging
        }
    }
}

/*
------ EMG CALIBRATION STATES ------
*/

/* ALL STATES HAVE THE FOLLOWING FORM:
void do_state_function() {
    // Entry function (runs once, on the first tick after entering the state)
    if ( emg_state_changed == true ) {
        emg_state_changed = false;
        // More functions
    }

    // Do stuff until end condition is met
    doStuff();

    // State transition guard
    if ( endCondition == true ) {
        emg_curr_state = next_state;
        emg_state_changed = true;
        // More functions
    }
}
*/
// EMG Waiting state
void do_emg_wait()
{
    // Entry function
    if ( emg_state_changed == true ) {
        emg_state_changed = false; // Disable entry functions

        button1.fall( &button1Press ); // Change to state MVC calibration on button1 press
        button2.fall( &button2Press ); // Change to state rest calibration on button2 press
    }

    // Do nothing until end condition is met

    // State transition guard. Possible next states:
    // 1. emg_cal_MVC   (button1 pressed)
    // 2. emg_cal_rest  (button2 pressed)
    // 3. emg_operation (both calibrations have run)
    if ( button1_pressed ) {
        button1_pressed = false; // Disable button pressed function until next button press
        button1.fall( NULL ); // Disable interrupt during calibration
        button2.fall( NULL ); // Disable interrupt during calibration
        emg_curr_state = emg_cal_MVC; // Set next state
        emg_state_changed = true; // Enable entry functions

    } else if ( button2_pressed ) {
        button2_pressed = false; // Disable button pressed function until next button press
        button1.fall( NULL ); // Disable interrupt during calibration
        button2.fall( NULL ); // Disable interrupt during calibration
        emg_curr_state = emg_cal_rest; // Set next state
        emg_state_changed = true; // Enable entry functions

    } else if ( emg_MVC_cal_done && emg_rest_cal_done ) {
        button1.fall( NULL ); // Disable interrupt during operation
        button2.fall( NULL ); // Disable interrupt during operation
        emg_curr_state = emg_operation; // Set next state
        emg_state_changed = true; // Enable entry functions
    }
}

// Run calibration of EMG. Shared by the emg_cal_MVC and emg_cal_rest substates.
// Both record Tcal seconds of envelope samples. When the time is up, MVC calibration
// stores each channel's maximum envelope and rest calibration stores each channel's mean
// envelope. The machine then returns to emg_wait.
void do_emg_cal()
{
    // Entry functions
    if ( emg_state_changed == true ) {
        emg_state_changed = false; // Disable entry functions
        led_b = 0; // Turn on calibration led

        timerCalibration.reset();
        timerCalibration.start(); // Sets up timer to stop calibration after Tcal seconds
        sampleNow = true; // Enable signal sampling in sampleSignal()
        calibrateNow = true; // Enable calibration vector functionality in sampleSignal()

        // Reserve room for Tcal seconds of samples plus a 0.1 s margin. The number of ticks
        // before the timer reaches Tcal can exceed Fs * Tcal by a sample (e.g. ticker jitter).
        // With an exact reservation, that extra push_back would reallocate and double the
        // vector's capacity, which is the memory growth this reservation is meant to avoid.
        const std::size_t cal_capacity = static_cast<std::size_t>( Fs * Tcal + Fs * 0.1 );
        emg1_cal.reserve( cal_capacity );
        emg2_cal.reserve( cal_capacity );
        emg3_cal.reserve( cal_capacity );
    }

    // Do stuff until end condition is met
    // Set HIDScope outputs
    scope.set(0, emg1 );
    scope.set(1, emg1_env );
    //scope.set(2, emg2_env );
    //scope.set(3, emg3_env );
    scope.send();

    // State transition guard
    if ( timerCalibration.read() >= Tcal ) { // After interval Tcal the calibration step is finished
        timerCalibration.stop(); // Not needed until the next calibration (reset on entry)
        sampleNow = false; // Disable signal sampling in sampleSignal()
        calibrateNow = false; // Disable calibration sampling

        switch( emg_curr_state ) {
            case emg_cal_MVC:
                emg1_MVC = getMax(emg1_cal); // Store max value of MVC globally
                emg2_MVC = getMax(emg2_cal); // Store max value of MVC globally
                emg3_MVC = getMax(emg3_cal); // Store max value of MVC globally

                emg_MVC_cal_done = true; // To set up transition guard to operation mode
                break;
            case emg_cal_rest:
                emg1_rest = getMean(emg1_cal); // Store rest EMG globally
                emg2_rest = getMean(emg2_cal); // Store rest EMG globally
                emg3_rest = getMean(emg3_cal); // Store rest EMG globally
                emg_rest_cal_done = true; // To set up transition guard to operation mode
                break;
            default:
                break; // do_emg_cal() only runs in the two calibration substates
        }

        // Release the calibration buffers. Swapping with an empty vector frees the capacity,
        // which clear() would keep.
        std::vector<double>().swap(emg1_cal);
        std::vector<double>().swap(emg2_cal);
        std::vector<double>().swap(emg3_cal);

        led_b = 1; // Turn off calibration led

        emg_curr_state = emg_wait; // Set next state
        emg_state_changed = true; // State has changed (to run the entry function of emg_wait)
    }
}

void do_emg_operation()
{
    // Entry function
    if ( emg_state_changed == true ) {
        emg_state_changed = false; // Disable entry functions
        double margin_percentage = 5; // Set up % margin for rest

        emg1_factor = 1 / emg1_MVC; // Factor to normalize MVC
        emg1_th = emg1_rest * emg1_factor + margin_percentage/100; // Set normalized rest threshold
        emg2_factor = 1 / emg2_MVC; // Factor to normalize MVC
        emg2_th = emg2_rest * emg2_factor + margin_percentage/100; // Set normalized rest threshold
        emg3_factor = 1 / emg3_MVC; // Factor to normalize MVC
        emg3_th = emg3_rest * emg3_factor + margin_percentage/100; // Set normalized rest threshold


        // ------- TO DO: MAKE SURE THESE BUTTONS DO NOT BOUNCE (e.g. with button1.rise() ) ------
        //button1.fall( &toggleEMG1Dir ); // Change to state MVC calibration on button1 press
        //button2.fall( &toggleEMG2Dir ); // Change to state rest calibration on button2 press

        sampleNow = true; // Enable signal sampling in sampleSignal()
        calibrateNow = false; // Disable calibration vector functionality in sampleSignal()
    }

    // Do stuff until end condition is met
    emg1_norm = emg1_env * emg1_factor; // Normalize EMG signal with calibrated factor
    emg2_norm = emg2_env * emg2_factor; // Idem
    emg3_norm = emg3_env * emg3_factor; // Idem

    emg1_out_prev = emg1_out; // Set previous emg_out signal
    emg1_dt_prev = emg1_dt; // Set previous emg_out_dt signal
    // Set normalized EMG output signal (CAN BE MOVED TO EXTERNAL FUNCTION BECAUSE IT IS REPEATED 3 TIMES)
    if ( emg1_norm < emg1_th ) { // If below threshold, emg_out = 0 (ignored)
        emg1_out = 0.0;
    } else if ( emg1_norm > 1.0f ) { // If above MVC (e.g. due to filtering), emg_out = 1 (max value)
        emg1_out = 1.0;
    } else { // If in between threshold and MVC, scale EMG signal accordingly
        // Inputs may be in range       [emg_th, 1]
        // Outputs are scaled to range  [0,      1]
        emg1_out = rescale(emg1_norm, 0, 1, emg1_th, 1);
    }
    emg1_dt = (emg1_out - emg1_out_prev) / Ts; // Calculate derivative of filtered normalized output signal
    emg1_dtdt = (emg1_dt - emg1_dt_prev) / Ts; // Calculate acceleration of filtered normalized output signal
    emg1_out = emg1_out * emg1_dir; // Set direction of EMG output

    // Idem for emg2
    if ( emg2_norm < emg2_th ) {
        emg2_out = 0.0;
    } else if ( emg2_norm > 1.0f ) {
        emg2_out = 1.0;
    } else {
        emg2_out = rescale(emg2_norm, 0, 1, emg2_th, 1);
    }
    emg2_out = emg2_out * emg2_dir; // Set direction of EMG output

    // Idem for emg3
    if ( emg3_norm < emg3_th ) {
        emg3_out = 0.0;
    } else if ( emg3_norm > 1.0f ) {
        emg3_out = 1.0;
    } else {
        emg3_out = rescale(emg3_norm, 0, 1, emg3_th, 1);
    }

    // Set HIDScope outputs
    scope.set(0, emg1 );
    scope.set(1, emg1_out );
    scope.set(2, emg1_dt );
    scope.set(3, emg1_dtdt );
    //scope.set(2, emg2_out );
    //scope.set(3, emg3_out );
    scope.send();

    led_g = !led_g;


    // State transition guard
    if ( false ) {
        emg_curr_state = emg_wait; // Set next state
        emg_state_changed = true; // Enable entry function
    }
}

/*
------ EMG SUBSTATE MACHINE ------
*/
void emg_state_machine()
{
    switch(emg_curr_state) {
        case emg_wait:
            do_emg_wait();
            break;
        case emg_cal_MVC:
            do_emg_cal();
            break;
        case emg_cal_rest:
            do_emg_cal();
            break;
        case emg_operation:
            do_emg_operation();
            break;
    }
}

// Body of the global ticker, run once per Ts after main() attaches it.
// Order matters: sampling/filtering happens first, so the state machine works with this
// tick's envelopes.
void tickGlobalFunc()
{
    sampleSignal();      // Read and filter the EMG inputs
    emg_state_machine(); // Advance the EMG substate machine

    // Planned stages, not implemented yet:
    // controller();
    // outputToMotors();
}

// Program entry point.
// Sets up the serial link and the per-channel filter chains, reports filter stability on
// the red LED, starts the global ticker, and then prints EMG diagnostics every 0.5 s.
// main must return int in C++: `void main()` is ill-formed and rejected by GCC.
int main()
{
    pc.baud(115200); // MODSERIAL rate
    pc.printf("Starting\r\n");

    // Assemble the BQ chains. Each channel gets its own BiQuad objects so the channels keep
    // independent filter state.
    bqc1_notch.add( &bq1_notch );
    bqc1_high.add( &bq1_H1 ).add( &bq1_H2 );
    bqc1_low.add( &bq1_L1 ).add( &bq1_L2 );

    bqc2_notch.add( &bq2_notch );
    bqc2_high.add( &bq2_H1 ).add( &bq2_H2 );
    bqc2_low.add( &bq2_L1 ).add( &bq2_L2 );

    bqc3_notch.add( &bq3_notch );
    bqc3_high.add( &bq3_H1 ).add( &bq3_H2 );
    bqc3_low.add( &bq3_L1 ).add( &bq3_L2 );

    led_b = 1; // Turn blue led off at startup
    led_g = 1; // Turn green led off at startup
    led_r = 1; // Turn red led off at startup

    // Red LED (0 = on): stays off when every filter chain is stable, lights up otherwise
    led_r = checkBQChainStable() ? 1 : 0;

    emg_curr_state = emg_wait; // Start off in EMG Wait state
    tickGlobal.attach( &tickGlobalFunc, Ts ); // Start global ticker

    // Diagnostics loop. NOTE(review): these globals are written from the ticker callback,
    // so a printed value may occasionally be inconsistent; acceptable for debug output.
    while(true) {
        pc.printf("emg_state: %i   emg1_env: %f   emg1_out:  %f   emg1_th: %f   emg1_factor: %f\r\n", emg_curr_state, emg1_env, emg1_out, emg1_th, emg1_factor);
        pc.printf("               emg1_MVC: %f   emg1_rest: %f \r\n", emg1_MVC, emg1_rest);
        wait(0.5f);
    }
}