/******************************************************************************
 *  NM500 NeuroShield board and MPU-6050 IMU test example (motion learning/recognition)
 *  revision 1.1.5, 2020/02/11
 *  Copyright (c) 2017 nepes inc.
 *
 *  Requires the NeuroShield library v1.1.4 or later
 ******************************************************************************/

#include "mbed.h"
#include <NeuroShield.h>
#include <NeuroShieldSPI.h>
#include <mpu6050.h>

// for NM500
#define MOTION_REPEAT_COUNT 3   // how many times the 8-byte feature set is repeated in each vector
#define MOTION_SIGNAL_COUNT 8   // d_ax, d_ay, d_az, d_gx, d_gy, d_gz, da, dg
#define MOTION_CAPTURE_COUNT 20 // samples read per feature extraction (after one seed sample)

#define DEFAULT_MAXIF 500       // NOTE(review): not referenced in the visible code; presumably a default NM500 MAXIF

NeuroShield hnn;
MPU6050 mpu(0x68, PB_9, PB_8);  // SDA:(D14=PB_9) SCL(D15=PB_8) <= SB143/SB138 must close for I2C on A4/A5 and SB147/SB157 must open!!!
Serial pc(USBTX, USBRX);

DigitalOut sdcard_ss(D6);       // SDCARD_SSn
DigitalOut arduino_con(D5);     // SPI_SEL

// Latest raw MPU-6050 sample; passed as output arguments to mpu.getMotion6().
int16_t ax, ay, az, gx, gy, gz;

uint8_t learn_cat = 0;     // category to learn
// Previously learned/recognized category. 16 bits wide (was uint8_t) so that
// `prev_cat = cat` keeps the full category, including the 0xFFFF value that
// main() treats as "no match"; an 8-bit variable could never equal 0xFFFF.
uint16_t prev_cat = 0;
uint16_t dist, cat, nid, nsr, ncount;  // response from the neurons
uint16_t prev_ncount = 0;
uint16_t fpga_version;

// Reset, or not, at each feature extraction. NOTE(review): not referenced in the
// visible code. The minima start at INT16_MAX: the previous 0xFFFF initializer is
// out of range for int16_t and silently became -1.
int16_t min_a = INT16_MAX, max_a = 0, min_g = INT16_MAX, max_g = 0, da = 0, dg = 0;

uint8_t vector[MOTION_REPEAT_COUNT * MOTION_SIGNAL_COUNT];       // vector holding the pattern to learn or recognize

/*
 * Program the MPU-6050 accelerometer and gyroscope offset registers from
 * readings taken while the board is held still.
 *
 * Procedure:
 *   1. Read and discard 100 samples (presumably to let the sensor settle).
 *   2. Run 5 passes of 100 samples each. After every pass all six offset
 *      registers are rewritten from `total / 100`:
 *        accel offset = -(total / 100) / 2   (scaling noted for MPU6050_ACCEL_FS_8)
 *        gyro  offset = -(total / 100)       (scaling noted for MPU6050_GYRO_FS_1000)
 *
 * The totals are deliberately NOT cleared between passes, so `total / 100` is
 * the sum of the per-pass means. Every pass after the first is measured with the
 * previous pass's offsets already programmed. Each pass therefore adds its
 * remaining error to the total, and the next offset corrects for it.
 * NOTE(review): this reading assumes the offset registers are absolute rather
 * than incremental; confirm against the MPU6050 driver before "fixing" it.
 *
 * Side effects: passes the globals ax..gz to mpu.getMotion6() as output
 * arguments and writes all six offset registers through mpu.set*Offset().
 */
void mpu6050Calibration()
{
    const int samples_per_pass = 100;
    const int pass_count = 5;

    // Accumulated across all passes on purpose (see header comment).
    long acc_x = 0, acc_y = 0, acc_z = 0;
    long gyr_x = 0, gyr_y = 0, gyr_z = 0;

    for (int n = 0; n < samples_per_pass; n++) {
        mpu.getMotion6(&ax, &ay, &az, &gx, &gy, &gz);   // readings discarded
    }

    for (int pass = 0; pass < pass_count; pass++) {
        for (int n = 0; n < samples_per_pass; n++) {
            mpu.getMotion6(&ax, &ay, &az, &gx, &gy, &gz);
            acc_x += ax;
            acc_y += ay;
            acc_z += az;
            gyr_x += gx;
            gyr_y += gy;
            gyr_z += gz;
        }

        // Narrow the per-pass estimate to int before negating/scaling.
        const int bias_ax = acc_x / samples_per_pass;
        const int bias_ay = acc_y / samples_per_pass;
        const int bias_az = acc_z / samples_per_pass;
        const int bias_gx = gyr_x / samples_per_pass;
        const int bias_gy = gyr_y / samples_per_pass;
        const int bias_gz = gyr_z / samples_per_pass;

        mpu.setXAccelOffset(-bias_ax / 2);
        mpu.setYAccelOffset(-bias_ay / 2);
        mpu.setZAccelOffset(-bias_az / 2);
        mpu.setXGyroOffset(-bias_gx);
        mpu.setYGyroOffset(-bias_gy);
        mpu.setZGyroOffset(-bias_gz);
    }
}

void extractFeatureVector()
{
    int i;
    int16_t min_ax, min_ay, min_az, max_ax, max_ay, max_az;
    int16_t min_gx, min_gy, min_gz, max_gx, max_gy, max_gz;
    uint32_t norm_ax, norm_ay, norm_az, norm_gx, norm_gy, norm_gz;
    int32_t d_ax, d_ay, d_az, d_gx, d_gy, d_gz;
    int32_t da_local, dg_local;

    mpu.getMotion6(&ax, &ay, &az, &gx, &gy, &gz);

    max_ax = min_ax = ax;
    max_ay = min_ay = ay;
    max_az = min_az = az;
    max_gx = min_gx = gx;
    max_gy = min_gy = gy;
    max_gz = min_gz = gz;
  
    for (i = 0; i < MOTION_CAPTURE_COUNT; i++) {
        mpu.getMotion6(&ax, &ay, &az, &gx, &gy, &gz);
    
        if (ax < min_ax)
            min_ax = ax;
        else if (ax > max_ax)
            max_ax = ax;

        if (ay < min_ay)
            min_ay = ay;
        else if(ay > max_ay)
            max_ay = ay;

        if (az < min_az)
            min_az = az;
        else if (az > max_az)
            max_az = az;

        if (gx < min_gx)
            min_gx = gx;
        else if (gx > max_gx)
            max_gx = gx;

        if (gy < min_gy)
            min_gy = gy;
        else if (gy > max_gy)
            max_gy = gy;

        if (gz < min_gz)
            min_gz = gz;
        else if (gz > max_gz)
            max_gz = gz;
    }

    d_ax = max_ax - min_ax;
    d_ay = max_ay - min_ay;
    d_az = max_az - min_az;

    d_gx = max_gx - min_gx;
    d_gy = max_gy - min_gy;
    d_gz = max_gz - min_gz;

    da_local = d_ax;
    if (d_ay > da_local)
        da_local = d_ay;
    if (d_az > da_local)
        da_local = d_az;

    dg_local = d_gx;
    if (d_gy > dg_local)
        dg_local = d_gy;
    if (d_gz > dg_local)
        dg_local = d_gz;

    norm_ax = d_ax; norm_ax = norm_ax * 255 / da_local;
    norm_ay = d_ay; norm_ay = norm_ay * 255 / da_local;
    norm_az = d_az; norm_az = norm_az * 255 / da_local;

    norm_gx = d_gx; norm_gx = norm_gx * 255 / dg_local;
    norm_gy = d_gy; norm_gy = norm_gy * 255 / dg_local;
    norm_gz = d_gz; norm_gz = norm_gz * 255 / dg_local;

    for (i = 0; i < MOTION_REPEAT_COUNT; i++) {
        vector[i * MOTION_SIGNAL_COUNT] = norm_ax & 0x00ff;
        vector[(i * MOTION_SIGNAL_COUNT) + 1] = norm_ay & 0x00ff;
        vector[(i * MOTION_SIGNAL_COUNT) + 2] = norm_az & 0x00ff;
        vector[(i * MOTION_SIGNAL_COUNT) + 3] = norm_gx & 0x00ff;
        vector[(i * MOTION_SIGNAL_COUNT) + 4] = norm_gy & 0x00ff;
        vector[(i * MOTION_SIGNAL_COUNT) + 5] = norm_gz & 0x00ff;
        if (da_local >= 4096)
            vector[(i * MOTION_SIGNAL_COUNT) + 6] = 0xff;
        else
            vector[(i * MOTION_SIGNAL_COUNT) + 6] = ((da_local >> 4) & 0x00ff);
        if (dg_local >= 4096)
            vector[(i * MOTION_SIGNAL_COUNT) + 7] = 0xff;
        else
            vector[(i * MOTION_SIGNAL_COUNT) + 7] = ((dg_local >> 4) & 0x00ff);
    }
}

/* Spin forever after an unrecoverable initialization failure (reboot required). */
static void haltForever()
{
    while (1) {
    }
}

/* Start the NM500 and print the board/FPGA version; halts if hnn.begin() returns 0. */
static void initNeuroShield()
{
    if (hnn.begin() == 0) {
        printf("\n\nStart NM500 initialzation...\n");
        printf("  NM500 is not connected properly!!\n");
        printf("  Please check the connection and reboot!\n");
        haltForever();
    }

    fpga_version = hnn.fpgaVersion();

    // High byte of the version word selects the board family.
    const char *board_name;
    switch (fpga_version & 0xFF00) {
    case 0x0000:
        board_name = "NeuroShield";
        break;
    case 0x0100:
        board_name = "Prodigy";
        break;
    default:
        board_name = "Unknown";
        break;
    }
    printf("\n\n#### %s Board (Board v%d.0 / FPGA v%d.0) ####\n", board_name,
           ((fpga_version >> 4) & 0x000F), (fpga_version & 0x000F));
    printf("\nStart NM500 initialzation...\n");
    printf("  NM500 is initialized!\n");
    printf("  There are %d neurons\n", hnn.total_neurons);
}

/*
 * Set the MPU-6050 ranges, verify the connection (up to 10 attempts, 0.1 s
 * apart), clear the offsets and calibrate. Halts if the connection test never
 * succeeds.
 */
static void initMpu6050()
{
    printf("\nStart MPU-6050 initialization...\n");
    mpu.initialize();
    mpu.setFullScaleGyroRange(MPU6050_GYRO_FS_1000);
    mpu.setFullScaleAccelRange(MPU6050_ACCEL_FS_8);

    for (int i = 0; i < 10; i++) {
        if (mpu.testConnection()) {
            printf("  MPU-6050 is connected successfully\n");
            break;
        }
        if (i == 9) {
            printf("  MPU-6050 connection failed\n");
            printf("  Please check the connection and reboot!\n");
            haltForever();
        }
        wait(0.1);
    }

    printf("  Trying to calibrate. Make sure the board is stable and upright\n");
    // Start calibration from zero offsets.
    mpu.setXAccelOffset(0);
    mpu.setYAccelOffset(0);
    mpu.setZAccelOffset(0);
    mpu.setXGyroOffset(0);
    mpu.setYGyroOffset(0);
    mpu.setZGyroOffset(0);
    mpu6050Calibration();
    printf("  MPU-6050 calibration is complete!!\n\n");
}

/* Capture and learn 5 motion vectors as `category`, then print the neuron count. */
static void learnMotion(uint8_t category)
{
    printf("Learning motion category %d\n", category);
    for (int i = 0; i < 5; i++) {
        extractFeatureVector();
        ncount = hnn.learn(vector, MOTION_REPEAT_COUNT * MOTION_SIGNAL_COUNT, category);
        if (ncount != prev_ncount) {
            prev_cat = category;
            prev_ncount = ncount;
        }
    }
    printf("Neurons=%d\n", ncount);
}

/* Capture one motion vector, classify it, and print the category if there is one. */
static void recognizeMotion()
{
    extractFeatureVector();
    hnn.classify(vector, MOTION_REPEAT_COUNT * MOTION_SIGNAL_COUNT, &dist, &cat, &nid);

    // The previous two-branch update always ended with prev_cat == cat.
    prev_cat = cat;

    if (cat == 0xFFFF) {        // no match: nothing to print
        return;
    }
    if (cat & 0x8000) {         // top bit flags a degenerated neuron
        printf("Motion #%d (degenerated)\n", (cat & 0x7FFF));
    } else {
        printf("Motion #%d \n", cat);
    }
}

/*
 * Entry point. Initializes the NM500 and the MPU-6050, then loops forever.
 * With no serial input pending it recognizes the current motion. A digit
 * '0'..'2' followed by Enter (CR or LF) learns the current motion under that
 * category.
 */
int main()
{
    // Last two characters received. Zero-initialized so that an Enter typed as
    // the very first key does not read an indeterminate value.
    int input_key[2] = {0, 0};

    arduino_con = LOW;
    sdcard_ss = HIGH;
    wait(0.5);

    initNeuroShield();
    initMpu6050();

    printf("Move the board horizontally or vertically...\n");
    printf("Type '1' and enter, to learn up <-> down motion\n");
    printf("Type '2' and enter, to learn left <-> right motion\n");
    printf("Type '0' and enter, to learn by category 0\n");

    while (1) {
        if (!pc.readable()) {
            recognizeMotion();
            continue;
        }

        input_key[0] = input_key[1];
        input_key[1] = pc.getc();

        // Accept CR or LF as Enter. With CRLF, the LF sees '\r' as the previous
        // key; '\r' - '0' wraps to >= 3 in uint8_t and is ignored below.
        if (input_key[1] == '\r' || input_key[1] == '\n') {
            learn_cat = input_key[0] - '0';
            if (learn_cat < 3) {
                learnMotion(learn_cat);
            }
        }
    }
}