/*
 *  LELEC 2811 - BadmintonLogger - Group 5
 */

#include "mbed.h"
#include "FreescaleIAP.h"   // Library for Flash Access
#include "MMA8491Q_PG.h"    // Accelerometer
#include <cmath>

// I2C address 0x55 of the MMA8491Q shifted left by one (the shifted form is
// what gets passed to the MMA8491Q driver below).
#define MMA8491_I2C_ADDRESS (0x55<<1)
#define KL25Z_VDD 2.89      // Value of VDD in volts : To be measured on board KL25Z pin P3V3 (calibration)

// The on-board LEDs are lit by writing 0 (active low).
#define LED_ON                  0
#define LED_OFF                 1

// Map(x) = (x + OFFSET) / RANGE turns a raw acceleration in [-OFFSET, OFFSET]
// into a network input in [0, 1].
#define OFFSET                  10000 // OFFSET & RANGE to map inputs between 0 and 1
#define RANGE                   20000 // RANGE = 2*OFFSET

// Values of the global `mode` (also persisted in the parameter sector).
#define CONSOLE                 0 // print every detected set in the console, nothing written to flash
#define FLASH_MVT               1 // save in flash only the detected mvt (one 32-bit word per mvt)
#define FLASH_ALL               2 // save in flash the whole MvtSet (samples + mvt)

// Register addresses passed to my8491.getAccAxis() in ReadData().
#define REG_OUT_X_MSB           0x01
#define REG_OUT_Y_MSB           0x03
#define REG_OUT_Z_MSB           0x05
 
#define SECTOR_SIZE             1024   // Number of bytes per flash sector
#define RESERVED_SECTOR         32     // Sectors (32 * 1 KiB = 32 KiB) reserved for application code; data starts after them
 
#define ACQ_TIMER_PERIOD        0.005  // Time between 2 acquisitions, in seconds (here 5 ms)
#define N_PTS                   100    // Number of (x,y,z) samples kept per set and fed to the network
#define N_MVTS                  5      // Number of movement classes (network outputs)
#define THRESHOLD_MVT           0.4    // minimum softmax probability to validate a mvt
#define THRESHOLD_SHOCK         0.45   // threshold (V) on |Vout_FILT - VDD/2| to detect a shock

MMA8491Q my8491(PTE0, PTE1, MMA8491_I2C_ADDRESS); // Setup I2C for MMA8491

Serial Host_Comm(USBTX, USBRX); // Set Serial Port (console; baud rate set in Init())

// Analog outputs of the piezo conditioning chain, sampled by ReadData().
AnalogIn myPTE20(PTE20);        // read Vout_IF
AnalogIn myPTE21(PTE21);        // read Vout_FILT (used for shock detection in Log())
AnalogIn myPTE22(PTE22);        // read Vout_GAIN

Ticker myTick_Acq;              // Periodical timer for Acquisition (calls myTimer_Acq_Task every ACQ_TIMER_PERIOD)

DigitalOut Led_Red(LED1);       // Define I/O for LEDs (active low, see LED_ON/LED_OFF)
DigitalOut Led_Green(LED2);
DigitalOut Led_Blue(LED3);

DigitalOut Accel_Enable(PTA13); // Accelerometer enable: set to 1 to start a measure, 0 afterwards (see ReadData())

// Jumper detection: Check_Jumper() drives Start_Pulse_Out and expects
// Start_Pulse_In to follow it.
DigitalOut Start_Pulse_Out(PTE4);  // Used to enter/exit Acquisition mode 
DigitalIn  Start_Pulse_In(PTE5);   // ShortPin J1_15 and J1_16 to enter in Acq_Mode

// --------------------- Structure and Enumeration ---------------------
// One measurement returned by ReadData(): raw accelerometer readings and the
// three piezo-chain voltages converted to volts.
struct Data {
    int16_t accX, accY, accZ;
    float Vout_IF, Vout_FILT, Vout_GAIN;
};

// Detected movement. Undefined means no class passed THRESHOLD_MVT; network
// output j corresponds to enum value j+1 (see SelectMvt()). The commented-out
// line is an older 6-class version kept for reference.
//enum Mvt { Undefined = 0, Serve, ClearOverhead, DropOverhead, SmashShot, ClearUnderarm, DropUnderarm };
enum Mvt { Undefined = 0, Serve, ClearOverhead, SmashShot, ClearUnderarm, DropUnderarm };

// One classified set: N_PTS (x,y,z) triples, oldest first as produced by
// Rotate() in Log(), plus the movement selected for them.
struct MvtSet {
    int16_t inputs [N_PTS*3];
    Mvt mvt;
};

// -------------------------- Globale variable --------------------------
volatile bool bTimer; // 1 means a Timer tick is done (set by the Ticker callback, cleared by Log())

bool foundError = 0;  // set on any flash error; makes main() exit its loop
bool flashFull = 0;   // set when the data area has no room for another write
int mode = CONSOLE;   // CONSOLE, FLASH_MVT or FLASH_ALL

/* in flash :
    - [ 0x0 ; flash_base_address [ : code
    - [ flash_base_address ; flash_base_address_cmd [ : data
    - flash_base_address_cmd : flash_next_address (int)
    - flash_base_address_cmd+4 : mode (int)
    - flash_base_address_cmd+8 : flashFull (int)
   flash_base_address_cmd is the last flash sector (set in Init()). */
uint32_t KL25_Flash_Size;
int flash_base_address = RESERVED_SECTOR * SECTOR_SIZE ; // Store Flash Base Address
int flash_next_address; // next address for saving data in flash
int flash_base_address_cmd; // base address where the parameters are saved

// Network parameters, pasted verbatim from the included text files
// (presumably comma-separated float literals exported from training - confirm).
// First layer: w_s_l[3*i + k] weights component k of time step i, b_s_l[i] is its bias.
const float w_s_l [3*N_PTS] = {
    #include "weights_samples_layer.txt"
};
const float b_s_l [N_PTS] = {
    #include "biases_samples_layer.txt"
};
// Output layer: w_o_l[i*N_MVTS + j] weights hidden unit i for class j, b_o_l[j] is its bias.
const float w_o_l [N_PTS*N_MVTS] = {
    #include "weights_output_layer.txt"
};
const float b_o_l [N_MVTS] = {
    #include "biases_output_layer.txt"
};

// ------------------------ Function Declaration ------------------------
void Init(void);                // set up I/O, timer, serial and restore parameters from flash
void DisplayFlashInfos(void);   // display memory use
void DisplayInstructions(void); // display available cmds
void Clear_Led(void);           // switch off led's
bool Check_Jumper(void);        // if J1_15 & J1_16 connected together -> return 1
void Check_Console(void);       // detect input from user in console
void myTimer_Acq_Task(void);    // called by the timer

void EraseAllSectors(void);     // erase all sectors containing data (and the parameter sector)
void EraseSector(int address);  // erase one sector
void UpdateParamFlash(void);    // rewrite flash_next_address, mode and flashFull in the parameter sector
void WriteFlash(MvtSet mvtSet); // write only the mvt (FLASH_MVT) or the whole set (FLASH_ALL) depending on mode
void ReadFlash(void);           // print memory content in console
void LineHandler(int16_t *line, int16_t next); // handle display of data set (used in ReadFlash)

Data ReadData(void);            // read data from accelerometer and piezo
void Log(void);                 // acquisition loop: read data, detect shock, classify and print/save the set
void Rotate(int16_t *AccDataLog, int amount, int16_t *inputs); // inputs = AccDataLog rotated of amount
void PrintSet(MvtSet mvtSet);   // display set of data

float Map (float x);            // map [-OFFSET, OFFSET] to [0, 1]
float Sigmoid (float x);        // sigmoid function
void Softmax (float *inputs, float *result); // softmax function
Mvt SelectMvt(int16_t *inputs); // run the network on the inputs and return the selected mvt

// -------------------------------------------------------------------------------------------------------
// -------------------------------------------------------------------------------------------------------

// -------------------------------- main --------------------------------
int main()
{
    Init ();
    
    int Count;
    
    while(!foundError)
    {
        if (Check_Jumper() && !flashFull)
        {
            Clear_Led();
            Count = 5;
            while (Count !=0)
            {
                if (Check_Jumper())
                {
                    Led_Green = LED_ON; // Blink to alert user "Enter in Logging mode"
                    wait_ms(750);
                    Led_Green = LED_OFF;
                    wait_ms(250);
                    Count --;
                    if (Count == 0)
                        Log();
                }
                else
                    Count = 0;
            }
        }
        if (flashFull)
            Led_Red = !Led_Red;
        else
            Led_Blue = !Led_Blue;
        Check_Console();
        wait_ms(100);        
    }
    
    Host_Comm.printf("\n\rProgram is exiting due to error...\n\r");
    Clear_Led();
    Led_Red = LED_ON;
}

// -------------------------------- Init --------------------------------
/* Initialise I/O, the acquisition ticker and the serial link, then restore the
 * logger parameters (flash_next_address, mode, flashFull) from the parameter
 * sector, i.e. the last flash sector (flash_base_address_cmd).
 * The stored write pointer is rejected (data area erased, parameters reset)
 * when it lies outside [flash_base_address, flash_base_address_cmd[ (e.g. on
 * the very first run) or is not a whole number of 32-bit words past
 * flash_base_address: WriteFlash() only ever advances it by 4 bytes, and a
 * misaligned address is presumably refused by program_flash() (alignment
 * check in FreescaleIAP), which would stop the logger at the first write.
 * An invalid stored mode falls back to CONSOLE. */
void Init()
{
    Start_Pulse_In.mode(PullNone);  // Input Pin is programmed as floating
    Accel_Enable = 0;               // Turn Accel Enable to disabled state
    Clear_Led();
    
    myTick_Acq.attach(&myTimer_Acq_Task, ACQ_TIMER_PERIOD); // Timer for acquisition
    
    Host_Comm.baud(115200);         // Baud rate setting
    Host_Comm.printf("\n\r*****\n\rLELEC2811 - Badminton Logger - Group 5\n\r*****\n\n\r");
    
    KL25_Flash_Size = flash_size(); // Get Size of KL25 Embedded Flash
    flash_base_address_cmd = KL25_Flash_Size-SECTOR_SIZE; // parameters live in the last sector
    
    /* TO CHECK SIZE OF PROGRAM IN FLASH :
    int *ptr;
    int n_sectors = KL25_Flash_Size/SECTOR_SIZE;
    Host_Comm.printf("Number of sectors : %d\n\r",n_sectors);
    for (int i = 0; i < n_sectors; i++){
        ptr = (int*) (i*SECTOR_SIZE);
        Host_Comm.printf("Sector %d : %d\n\r",i, ptr[0]);
    }
    */
    
    int *base_address_ptr = (int*)flash_base_address_cmd;
    flash_next_address = base_address_ptr[0];
    
    // Range test first: the alignment test then only sees in-range values,
    // so the subtraction cannot overflow.
    bool invalidPointer = flash_next_address >= flash_base_address_cmd
                       || flash_next_address < flash_base_address
                       || ((flash_next_address - flash_base_address) % 4) != 0;
    if (invalidPointer)
    {
        Host_Comm.printf("First run (or error with previous flash_next_address).\n\r");
        EraseAllSectors();
        flash_next_address = flash_base_address;
        mode = CONSOLE;
        UpdateParamFlash();
    }
    else {
        mode = base_address_ptr[1];
        flashFull = base_address_ptr[2];
        if (mode != CONSOLE && mode != FLASH_MVT && mode != FLASH_ALL) {
            mode = CONSOLE;
            UpdateParamFlash();
        }
    }
    DisplayFlashInfos();
    
    Host_Comm.printf("Initialization done.\n\n\r");
    DisplayInstructions();
}

// -------------------------- DisplayFlashInfos -------------------------
/* Print the write pointer, the active mode and the percentage of the data
 * area [flash_base_address, flash_base_address_cmd[ already used. */
void DisplayFlashInfos()
{
    const float used     = static_cast<float>(flash_next_address) - flash_base_address;
    const float capacity = static_cast<float>(flash_base_address_cmd) - flash_base_address;
    const float percent  = used / capacity * 100;

    Host_Comm.printf("flash_next_address = %d\n\r", flash_next_address);
    Host_Comm.printf("mode = %d\n\r", mode);
    Host_Comm.printf("Memory used = %f %%\n\r", percent);
}

// ------------------------- DisplayInstructions ------------------------
/* Print the console commands handled by Check_Console(), which main() only
 * polls outside Log(), i.e. while the jumper is removed. */
void DisplayInstructions()
{
    Host_Comm.printf("When the jumper is removed, use the keyboard :\n\r");
    Host_Comm.printf("- to erase flash : 'E' = erase flash\n\r");
    Host_Comm.printf("- to read flash : 'R' = read flash ; 'S' = stop reading\n\r");
    Host_Comm.printf("- to change mode : 'C' = console mode ; 'M' = write_mvt mode ; 'A' = write_all mode\n\r");
    // ReadFlash() decodes the stored words according to the *current* mode,
    // so data written under another mode is misread until the flash is erased.
    Host_Comm.printf("!!! If mode change, ReadFlash() will fail => also press 'E' !!!\n\n\r");
}

// ----------------------------- Clear_Led ------------------------------
void Clear_Led()
{
    Led_Red = LED_OFF;
    Led_Green = LED_OFF;
    Led_Blue = LED_OFF ;
}

// ---------------------------- Check_Jumper ----------------------------
/* Detect the J1_15-J1_16 jumper: drive Start_Pulse_Out with the pattern
 * 1,0,1,0 (1 ms settle each) and require Start_Pulse_In to follow every
 * level. Returns true only if all four levels are read back. */
bool Check_Jumper()
{
    for (int step = 0; step < 4; step++)
    {
        const int level = (step % 2 == 0) ? 1 : 0;
        Start_Pulse_Out = level;
        wait_ms(1);
        if (Start_Pulse_In != level)
            return false;
    }
    return true;
}

// --------------------------- Check_Console ----------------------------
void Check_Console()
{
    if(Host_Comm.readable()) 
    {
        char cmd = Host_Comm.getc();
        if ((cmd == 'E') || (cmd == 'e')) {
            Host_Comm.printf("Press 'E' again to confirm.\n\r");
            wait_ms(1000);
            if(Host_Comm.readable()) {
                cmd = Host_Comm.getc();
                if ((cmd == 'E') || (cmd == 'e')) {
                    EraseAllSectors();
                    flash_next_address = flash_base_address;
                    Host_Comm.printf("Erase done.\n\r");
                }
            }
            else
                Host_Comm.printf("Erase aborded.\n\r");
        }
        else if ((cmd == 'C') || (cmd == 'c')) {
            mode = CONSOLE;
            Host_Comm.printf("Mode console (0) actived.\n\r");
        }
        else if ((cmd == 'M') || (cmd == 'm')) {
            mode = FLASH_MVT;
            Host_Comm.printf("Mode flash_mvt (1) actived.\n\r");
        }
        else if ((cmd == 'A') || (cmd == 'a')) {
            mode = FLASH_ALL;
            Host_Comm.printf("Mode flash_all (2) actived.\n\r");
        }
        else if ((cmd == 'R') || (cmd == 'r'))
            ReadFlash();
        
        UpdateParamFlash();
    }
}

// -------------------------- myTimer_Acq_Task --------------------------
/* Ticker callback (every ACQ_TIMER_PERIOD): only raises bTimer, which Log() polls and clears. */
void myTimer_Acq_Task()
{
    bTimer = true;
}

// -------------------------------------------------------------------------------------------------------
// -------------------------------------------------------------------------------------------------------

// -------------------------- EraseAllSectors ---------------------------
/* Erase every sector from flash_base_address up to the end of flash, which
 * includes the parameter sector at flash_base_address_cmd. Stops at the first
 * failure (foundError set, flashFull left unchanged); on success clears
 * flashFull. Does not rewrite the parameter sector itself. */
void EraseAllSectors(void)
{
    for (uint32_t sector = flash_base_address; sector < KL25_Flash_Size; sector += SECTOR_SIZE)
    {
        EraseSector(static_cast<int>(sector));
        if (foundError)
            return;
    }
    flashFull = 0;
}

// ---------------------------- EraseSector -----------------------------
void EraseSector(int address)
{
    IAPCode status = erase_sector(address); 
    if (status != Success) {
        Host_Comm.printf("\n\rError in EraseSector() : status = %d\n\r", status);
        foundError = 1;
    }
}

// -------------------------- UpdateParamFlash --------------------------
void UpdateParamFlash()
{
    EraseSector(flash_base_address_cmd);
    if(foundError)
        return;
    
    int toWrite[3] = {flash_next_address, mode, flashFull};
    IAPCode status = program_flash(flash_base_address_cmd, (char *) &toWrite, 12);
    if (status != Success) {
        Host_Comm.printf("\n\rError in UpdateParamFlash() : status = %d\n\r", status);
        foundError = 1;
    }
}

// ----------------------------- WriteFlash -----------------------------
/* Pack two signed 16-bit samples into one 32-bit flash word, `high` in the
 * upper half-word. Going through uint16_t avoids left-shifting a negative int
 * (undefined behaviour before C++20). */
static uint32_t PackInt16Pair(int16_t high, int16_t low)
{
    return ((uint32_t)(uint16_t)high << 16) | (uint16_t)low;
}

/* Append one detected movement to the data area and persist the new write
 * pointer with UpdateParamFlash().
 *  - FLASH_ALL: the 3*N_PTS samples are packed two per 32-bit word (first
 *    sample of each pair in the upper half-word), followed by one word holding
 *    the mvt in its lower half and, if the sample count is odd, the last sample
 *    in its upper half. This is the layout ReadFlash() decodes.
 *  - otherwise (FLASH_MVT): a single word holding the mvt.
 * If the record does not fit before flash_base_address_cmd, no data is
 * written and flashFull is set AND persisted, so Init() restores it after a
 * reset. Any program_flash() failure sets foundError. */
void WriteFlash(MvtSet mvtSet)
{
    IAPCode status;
    uint32_t toWrite;
    
    if (mode == FLASH_ALL)
    {
        const int remainder = (3*N_PTS) % 2;   // odd count -> last sample shares the mvt word
        const int even = 3*N_PTS - remainder;  // samples stored as full pairs
        const int setSize = 4 * (even/2 + 1);  // bytes per record (pair words + mvt word)
        
        // check if enough place
        if (flash_next_address + setSize > flash_base_address_cmd) {
            Host_Comm.printf("\n\rFlash is full.\n\r");
            flashFull = 1;
            UpdateParamFlash();
            return;
        }
        
        for (int i = 0; i < even; i+=2)
        {
            toWrite = PackInt16Pair(mvtSet.inputs[i], mvtSet.inputs[i+1]);
            status = program_flash(flash_next_address, (char*) &toWrite, 4);
            if (status != Success) {
                Host_Comm.printf("\n\rError in WriteFlash() (0) : status = %d\n\r", status);
                foundError = 1; return;
            }
            flash_next_address += 4;
        }
        
        toWrite = (uint32_t)mvtSet.mvt;
        if (remainder == 1)
            toWrite |= (uint32_t)(uint16_t)mvtSet.inputs[3*N_PTS-1] << 16;
        status = program_flash(flash_next_address, (char*) &toWrite, 4);
    }
    else
    {
        // check if enough place
        if (flash_next_address+4 > flash_base_address_cmd) {
            Host_Comm.printf("\n\rFlash is full.\n\r");
            flashFull = 1;
            UpdateParamFlash();
            return;
        }
        
        toWrite = (uint32_t)mvtSet.mvt;
        status = program_flash(flash_next_address, (char*) &toWrite, 4);
    }
    
    if (status != Success) {
        Host_Comm.printf("\n\rError in WriteFlash() (1) : status = %d\n\r", status);
        foundError = 1; return;
    }
    
    flash_next_address += 4;
    UpdateParamFlash();
}

// ----------------------------- ReadFlash ------------------------------
/* Dump the data area [flash_base_address, flash_next_address[ to the console,
 * decoding it according to the CURRENT mode (layout written by WriteFlash()).
 * Typing 'S' stops the dump immediately. In FLASH_ALL mode the dump ends
 * with a message when fewer words than one full record remain, so it never
 * decodes words beyond flash_next_address (this happens when data written
 * under another mode is still in flash). */
void ReadFlash()
{
    Host_Comm.printf("\n\r------ Begin Read Flash ------\n\r");
    DisplayFlashInfos();
    
    char cmd;
    // Unsigned words: half-word extraction below uses logical shifts/masks.
    const uint32_t *currAddress_ptr = (const uint32_t*)flash_base_address;
    const uint32_t *stopAddress_ptr = (const uint32_t*)flash_next_address;
    
    Host_Comm.printf("ReadFlash : size = %d\n\r",(int)(stopAddress_ptr-currAddress_ptr));
    
    if (mode == FLASH_ALL)
    {
        int remainder = (3*N_PTS) % 2;
        int n_words_by_set = (3*N_PTS - remainder)/2+1;
        int word;
        int16_t line [4] = {}; // 3 components of acc (line[3] = position in line)
        
        while (currAddress_ptr < stopAddress_ptr)
        {
            // check for user input
            if(Host_Comm.readable()) {
                cmd = Host_Comm.getc();
                if ((cmd == 'S') || (cmd == 's'))
                    return;
            }
            
            if (stopAddress_ptr - currAddress_ptr < n_words_by_set) {
                Host_Comm.printf("Incomplete set at end of flash (data written in another mode ?).\n\r");
                break;
            }
            
            // print all set: upper half-word first, then lower half-word
            for (word = 0; word < n_words_by_set-1; word++) {
                LineHandler(line,(int16_t)(currAddress_ptr[word] >> 16));
                LineHandler(line,(int16_t)(currAddress_ptr[word] & 0xFFFF));
            }
            if (remainder == 1)
                LineHandler(line,(int16_t)(currAddress_ptr[word] >> 16));
            Host_Comm.printf("Mvt = %d\n\r",(int16_t)(currAddress_ptr[word] & 0xFFFF));
            currAddress_ptr += n_words_by_set;
        }
    }
    else
    {
        while (currAddress_ptr < stopAddress_ptr)
        {
            // check for user input
            if(Host_Comm.readable()) {
                cmd = Host_Comm.getc();
                if ((cmd == 'S') || (cmd == 's'))
                    return;
            }
            
            // read mvt (one word per movement)
            Host_Comm.printf("Mvt = %d\n\r",(int)currAddress_ptr[0]);
            currAddress_ptr ++;
        }
    }
    
    Host_Comm.printf("\n\r------- End Read Flash -------\n\n\r");
}

// ----------------------------- LineHandler ----------------------------
/* Accumulate one value into `line` and print "x y z" once three values are
 * collected. line[0..2] hold the pending values; line[3] counts how many are
 * already stored, and this state persists between calls. */
void LineHandler(int16_t *line, int16_t next)
{
    int16_t &filled = line[3];
    line[filled] = next;
    filled++;
    if (filled != 3)
        return;

    Host_Comm.printf("%d %d %d\n\r", line[0], line[1], line[2]);
    filled = 0;
}

// -------------------------------------------------------------------------------------------------------
// -------------------------------------------------------------------------------------------------------

// ------------------------------ ReadData ------------------------------
/* Take one measurement: set Accel_Enable to start the MMA8491Q, poll its
 * status, read the three acceleration axes, disable it again, then sample
 * the three piezo-chain outputs (PTE20/21/22) and convert them to volts
 * using KL25Z_VDD. */
Data ReadData()
{
    Data data;
    
    // Get Accelerometer data's
    Accel_Enable = 1; // Rising Edge -> Start measure
    
    int ready = 0;
    // NOTE(review): `&&` is a logical AND, so this loop exits as soon as
    // Read_Status() returns any non-zero value, not specifically when bit 0x10
    // is set. A bitwise `&` was probably intended, but confirm the data-ready
    // bit of the MMA8491Q STATUS register first: masking a bit that is never
    // set would hang here. The loop also has no timeout.
    while((ready && 0x10) == 0) // Wait for accelerometer to have new data's
        ready = my8491.Read_Status();

    data.accX = my8491.getAccAxis(REG_OUT_X_MSB);
    data.accY = my8491.getAccAxis(REG_OUT_Y_MSB);
    data.accZ = my8491.getAccAxis(REG_OUT_Z_MSB);
        
    Accel_Enable = 0;
    
    // Get Piezo Data's: read_u16() scaled by 0xFFFF to a fraction of VDD
    data.Vout_IF = ((float) myPTE20.read_u16() / 0XFFFF) * KL25Z_VDD; // convert in volt
    data.Vout_FILT = ((float) myPTE21.read_u16() / 0XFFFF) * KL25Z_VDD;
    data.Vout_GAIN = ((float) myPTE22.read_u16() / 0XFFFF) * KL25Z_VDD;
    
    return data;
}

// -------------------------------- Log ---------------------------------
/* Acquisition loop, run while the jumper is present, no error occurred and
 * flash is not full. On every ticker period one sample is stored in a circular
 * buffer of N_PTS (x,y,z) triples. A shock is declared when the filtered piezo
 * voltage deviates from VDD/2 by at least THRESHOLD_SHOCK, but only once the
 * buffer has been filled at least once. N_PTS/2 samples after the last shock,
 * the buffer (re-ordered oldest to newest) is classified with SelectMvt() and
 * either printed (CONSOLE) or written to flash. */
void Log()
{    
    Data currData;
    int16_t AccDataLog [N_PTS*3] = {}; // circular buffer of the latest N_PTS samples
    int index_write = 0;               // current position to write data in AccDataLog
    bool enoughData = 0;               // set once the buffer has been completely filled
    bool shockDetected = 0;            // if shock detected
    int n_sinceShock = 0;              // number of ReadData() since last shock
    
    while(Check_Jumper() && !foundError && !flashFull)
    {        
        while (bTimer == 0) {} // Wait Acq Tick Timer
        bTimer = 0;
        
        currData = ReadData();
        //Host_Comm.printf("%d ; %d ; %d ; %f\n\r", currData.accX, currData.accY, currData.accZ, currData.Vout_FILT);
        AccDataLog[index_write*3] = currData.accX;
        AccDataLog[index_write*3+1] = currData.accY;
        AccDataLog[index_write*3+2] = currData.accZ;
        
        // std::fabs: an unqualified abs() on a float may bind to the C
        // abs(int) from <stdlib.h> and truncate the deviation to an integer.
        float amplitude = std::fabs(currData.Vout_FILT - KL25Z_VDD/2.0);
        //Host_Comm.printf("amplitude = %f\n\r",amplitude);
        if (amplitude >= THRESHOLD_SHOCK && enoughData)
        {
            shockDetected = 1;
            n_sinceShock = 0;
        }
        if (n_sinceShock == N_PTS/2 && shockDetected == 1)
        {
            Led_Green = LED_ON;
            
            // Put the newest sample (index_write) last, the oldest first.
            MvtSet mvtSet;
            Rotate(AccDataLog, N_PTS-1-index_write, mvtSet.inputs);
            mvtSet.mvt = SelectMvt(mvtSet.inputs);
            
            if (mode == CONSOLE)
                PrintSet(mvtSet);
            else
                WriteFlash(mvtSet);
            
            shockDetected = 0;
            wait_ms(100);
            Led_Green = LED_OFF;
        }
        
        index_write ++;
        // Counting only while a shock is pending keeps the counter bounded
        // (it is reset on the next shock anyway).
        if (shockDetected)
            n_sinceShock ++;
        if (index_write == N_PTS)
        {
            enoughData = 1;
            index_write = 0;
        }
    }
    Clear_Led();
}

// ------------------------------- Rotate -------------------------------
/* Copy the N_PTS (x,y,z) triples of AccDataLog into inputs, moving triple i
 * to position (i + amount) % N_PTS. Assumes amount >= 0. */
void Rotate(int16_t *AccDataLog, int amount, int16_t *inputs)
{
    for (int row = 0; row < N_PTS; ++row)
    {
        const int16_t *src = &AccDataLog[row * 3];
        int16_t *dst = &inputs[((row + amount) % N_PTS) * 3];
        dst[0] = src[0];
        dst[1] = src[1];
        dst[2] = src[2];
    }
}

// ------------------------------ PrintSet ------------------------------
/* Dump one set to the console: N_PTS lines "x y z", then the movement,
 * framed by begin/end markers. */
void PrintSet(MvtSet mvtSet)
{
    Host_Comm.printf("------ Begin Set ------\n\r");
    const int16_t *sample = mvtSet.inputs;
    for (int row = 0; row < N_PTS; row++, sample += 3)
        Host_Comm.printf("%d %d %d\n\r", sample[0], sample[1], sample[2]);
    Host_Comm.printf("Mvt = %d\n\r", mvtSet.mvt);
    Host_Comm.printf("------- End Set -------\n\n\r");
}

// -------------------------------------------------------------------------------------------------------
// -------------------------------------------------------------------------------------------------------

// -------------------------------- Map ---------------------------------
float Map (float x)
{
    // Shift by OFFSET, then scale by RANGE: [-OFFSET, OFFSET] -> [0, 1].
    const float shifted = x + OFFSET;
    return shifted / RANGE;
}

// ------------------------------ Sigmoid -------------------------------
float Sigmoid (float x)
{
    // Logistic function 1 / (1 + e^-x).
    return 1 / (exp(-x) + 1);
}

// ------------------------------ Softmax -------------------------------
/* Numerically stable softmax over N_MVTS values:
 *   result[i] = exp(inputs[i]) / sum_j exp(inputs[j]).
 * The largest input is subtracted before exponentiation (this does not change
 * the result mathematically), so large logits cannot overflow exp() to inf
 * and turn every output into NaN.
 * inputs and result may be the same array (SelectMvt() does this): all
 * exponentials are buffered locally before result is written. */
void Softmax (float *inputs, float *result)
{
    float maxInput = inputs[0];
    for (int i = 1; i < N_MVTS; i++)
        if (inputs[i] > maxInput)
            maxInput = inputs[i];

    float exps [N_MVTS];
    float sum = 0;   // >= 1, since the maximum contributes exp(0)
    for (int i = 0; i < N_MVTS; i++)
    {
        exps[i] = std::exp(inputs[i] - maxInput);
        sum += exps[i];
    }
    
    for (int i = 0; i < N_MVTS; i++)
        result[i] = exps[i] / sum;
}

// ----------------------------- SelectMvt ------------------------------
/* Classify one set with the embedded two-layer network and return the
 * detected movement.
 *  - layer 1: samples[i] = Sigmoid(sum_k Map(inputs[3i+k]) * w_s_l[3i+k] + b_s_l[i])
 *  - layer 2: logit j = sum_i samples[i] * w_o_l[i*N_MVTS+j] + b_o_l[j], then Softmax
 * The probabilities are printed on the console. Output j maps to Mvt value
 * j+1. The MOST probable class is returned if it exceeds THRESHOLD_MVT,
 * otherwise Undefined. Two classes can both exceed 0.4, so the maximum is
 * taken explicitly instead of keeping the last class above the threshold. */
Mvt SelectMvt(int16_t *inputs)
{
    int i, j;
    
    float samples [N_PTS];
    for (i = 0; i < N_PTS; i++)
        samples[i] = Sigmoid( Map(inputs[i*3])*w_s_l[i*3] + Map(inputs[i*3+1])*w_s_l[i*3+1] + Map(inputs[i*3+2])*w_s_l[i*3+2] + b_s_l[i] );
    
    float probabilities [N_MVTS] = {};
    for (j = 0; j < N_MVTS; j++) {
        for (i = 0; i < N_PTS; i++)
            probabilities[j] += samples[i] * w_o_l[i*N_MVTS+j];
        probabilities[j] = probabilities[j] + b_o_l[j];
    }
    Softmax(probabilities,probabilities); // in place: Softmax buffers its exponentials
    
    int best = -1; // index of the most probable class above THRESHOLD_MVT
    Host_Comm.printf("Proba mvt : ");
    for (i = 0; i < N_MVTS; i++) {
        if (probabilities[i] > THRESHOLD_MVT && (best < 0 || probabilities[i] > probabilities[best]))
            best = i;
        Host_Comm.printf("%f ",probabilities[i]);
    }
    Host_Comm.printf("\n\r");
    
    return (best < 0) ? Undefined : static_cast<Mvt>(best + 1);
}
