/******************************************************************************
 *  NM500 NeuroShield Board SimpleScript
 *  Simple Test Script to understand how the neurons learn and recognize
 *  revision 1.1.5, 2020/02/11
 *  Copyright (c) 2017 nepes inc.
 *
 *  Please use the NeuroShield library v1.1.4 or later
 ******************************************************************************/

#include "mbed.h"
#include <NeuroShield.h>
#include <NeuroShieldSPI.h>

// Number of components used in every pattern learned or recognized by this demo.
#define VECTOR_LENGTH 4
// Maximum number of responding neurons read back per classify() call
// (also the size of the dists/cats/nids result arrays below).
#define READ_COUNT 3

// Handle to the NM500 on the NeuroShield board.
NeuroShield hnn;

// Board select lines driven at the start of main().
DigitalOut sdcard_ss(D6);       // SDCARD_SSn
DigitalOut arduino_con(D5);     // SPI_SEL

// Pattern passed to learn()/classify(); only the first VECTOR_LENGTH bytes are used here.
uint8_t vector[NEURON_SIZE];
// Receives a neuron's prototype components in displayNeurons().
uint16_t vector16[NEURON_SIZE];

// Parallel result arrays filled by classify(): distance, category and neuron id.
uint16_t dists[READ_COUNT];
uint16_t cats[READ_COUNT];
uint16_t nids[READ_COUNT];

// Count returned by the most recent classify() call.
uint16_t response_nbr = 0;
// Added to the GCR value written in main(); the commented-out 0x80 there is labelled "lsup mode".
uint16_t norm_lsup = 0;
// Raw value from hnn.fpgaVersion(), decoded into board/FPGA revisions in main().
uint16_t fpga_version = 0;

/*
 * Print every committed neuron of the NM500 to the console.
 *
 * The network status register is temporarily set to 0x0010 and the chain is
 * reset so that successive register reads walk the committed neurons one by
 * one (presumably the NM500 Save-and-Restore mode; confirm against the
 * datasheet). The caller's NSR value is restored before returning.
 *
 * For each neuron the first VECTOR_LENGTH prototype components are printed,
 * followed by its NCR, AIF and category. If bit 15 of the category is set,
 * the neuron is reported as "(degenerated)" and the bit is masked off.
 */
void displayNeurons()
{
    const uint16_t committed = hnn.getNcount();
    printf("Display the neurons, ncount = %d\n", committed);

    const uint16_t savedNsr = hnn.getNsr();
    hnn.setNsr(0x0010);
    hnn.resetChain();

    int neuronId = 1;
    while (neuronId <= committed) {
        // Keep these reads in exactly this order. They all access the current
        // neuron, and the category read is presumably what advances the chain
        // to the next neuron (NOTE(review): confirm in the NM500 manual).
        const uint16_t ncr = hnn.getNcr();
        hnn.readCompVector(vector16, VECTOR_LENGTH);
        const uint16_t aif = hnn.getAif();
        const uint16_t cat = hnn.getCat();

        printf("neuron#%d \tvector=", neuronId);
        for (int k = 0; k < VECTOR_LENGTH; ++k)
            printf("%d, ", vector16[k]);

        const bool degenerated = (cat & 0x8000) != 0;
        if (!degenerated) {
            printf(" \tncr=%d \taif=%d \tcat=%d\n", ncr, aif, cat);
        } else {
            printf(" \tncr=%d \taif=%d \tcat=%d (degenerated)\n", ncr, aif, (cat & 0x7FFF));
        }
        ++neuronId;
    }

    hnn.setNsr(savedNsr);
}

/* Set every component of the global test pattern `vector` to `value`. */
static void fillVector(uint8_t value)
{
    for (int i = 0; i < VECTOR_LENGTH; i++) {
        vector[i] = value;
    }
}

/* Print the first VECTOR_LENGTH components of `vector`, then a newline. */
static void printVector()
{
    for (int i = 0; i < VECTOR_LENGTH; i++) {
        printf("%d, ", vector[i]);
    }
    printf("\n");
}

/*
 * Print the first `count` entries of the classify() result arrays
 * (nids/cats/dists). Bit 15 of a category is reported as "(degenerated)".
 *
 * `count` should be the value returned by classify(). It is clamped to
 * READ_COUNT so the result arrays are never read out of bounds. When it is
 * zero, a line saying no neuron fired is printed instead of nothing.
 */
static void printFiringNeurons(uint16_t count)
{
    if (count > READ_COUNT) {
        count = READ_COUNT;
    }
    if (count == 0) {
        printf("No neuron fired (unknown pattern)\n");
        return;
    }
    for (int i = 0; i < count; i++) {
        if (cats[i] & 0x8000) {
            printf("Firing neuron#%d, category=%d (degenerated), distance=%d\n", nids[i], (cats[i] & 0x7FFF), dists[i]);
        }
        else {
            printf("Firing neuron#%d, category=%d, distance=%d\n", nids[i], (cats[i] & 0x7FFF), dists[i]);
        }
    }
}

/*
 * Demo entry point: initializes the NM500, learns a few constant patterns,
 * and shows how recognition behaves under RBF/KNN classifiers, conflicting
 * categories and different contexts.
 *
 * If hnn.begin() reports failure, an error message is printed and the
 * program halts in an infinite loop.
 */
int main()
{
    // Drive SPI_SEL low and SDCARD_SSn high before talking to the chip
    // (presumably routes the shield's SPI bus to the NM500; confirm against the board schematic).
    arduino_con = LOW;
    sdcard_ss = HIGH;
    wait(0.5);

    if (hnn.begin() == 0) {
        printf("\n\nStart NM500 initialization...\n");
        printf("  NM500 is not connected properly!!\n");
        printf("  Please check the connection and reboot!\n");
        while (1) {
        }
    }

    // Bits 15..8 of the FPGA version select the board family; bits 7..4 and 3..0
    // are printed as the board and FPGA revisions.
    fpga_version = hnn.fpgaVersion();
    if ((fpga_version & 0xFF00) == 0x0000) {
        printf("\n\n#### NeuroShield Board (Board v%d.0 / FPGA v%d.0) ####\n", ((fpga_version >> 4) & 0x000F), (fpga_version & 0x000F));
    }
    else if ((fpga_version & 0xFF00) == 0x0100) {
        printf("\n\n#### Prodigy Board (Board v%d.0 / FPGA v%d.0) ####\n", ((fpga_version >> 4) & 0x000F), (fpga_version & 0x000F));
    }
    else {
        printf("\n\n#### Unknown Board (Board v%d.0 / FPGA v%d.0) ####\n", ((fpga_version >> 4) & 0x000F), (fpga_version & 0x000F));
    }
    printf("\nStart NM500 initialization...\n");
    printf("  NM500 is initialized!\n");
    printf("  There are %d neurons\n", hnn.total_neurons);

    // Write the GCR with context 1 plus the norm selection; to run in LSUP mode,
    // uncomment the next line.
    //norm_lsup = 0x80;
    hnn.setGcr(1 + norm_lsup);

    // Build knowledge by learning three constant patterns (11, 15 and 20).
    printf("\nLearning three patterns...\n");
    fillVector(11);
    hnn.learn(vector, VECTOR_LENGTH, 55);
    fillVector(15);
    hnn.learn(vector, VECTOR_LENGTH, 33);
    fillVector(20);
    hnn.learn(vector, VECTOR_LENGTH, 100);
    displayNeurons();

    // Recognize patterns 12..15, which lie between the learned prototypes.
    for (uint8_t value = 12; value < 16; value++) {
        fillVector(value);
        printf("\nRecognizing a new pattern: ");
        printVector();
        response_nbr = hnn.classify(vector, VECTOR_LENGTH, READ_COUNT, dists, cats, nids);
        printFiringNeurons(response_nbr);
    }

    // Same recognition with the KNN classifier, then switch back to RBF.
    fillVector(20);
    printf("\nRecognizing a new pattern using KNN classifier: ");
    printVector();
    hnn.setKnnClassifier();
    response_nbr = hnn.classify(vector, VECTOR_LENGTH, READ_COUNT, dists, cats, nids);
    hnn.setRbfClassifier();
    // Only the entries classify() reported are valid; do not assume READ_COUNT of them.
    printFiringNeurons(response_nbr);

    printf("\nLearning a new example (13) falling between neuron1 and neuron2\n");
    fillVector(13);
    hnn.learn(vector, VECTOR_LENGTH, 100);
    displayNeurons();
    printf("=> Notice the addition of neuron 4 and the shrinking of the influence fields of neuron1 and 2\n");

    printf("\nLearning a same example (13) using a different category 77\n");
    fillVector(13);
    hnn.learn(vector, VECTOR_LENGTH, 77);
    displayNeurons();
    printf("=> Notice if the AIF of a neuron reaches the MINIF, the neuron will be degenerated\n");

    // NOTE(review): if setContext() rewrites the whole GCR, these calls drop the
    // norm bit written by setGcr() above. Confirm in the NeuroShield library.
    printf("\nLearning a new example (12) using context 5, category 200\n");
    fillVector(12);
    hnn.setContext(5);
    hnn.learn(vector, VECTOR_LENGTH, 200);
    hnn.setContext(1);
    displayNeurons();

    fillVector(20);
    printf("\nRecognizing a new pattern using context 5: ");
    printVector();
    hnn.setContext(5);
    response_nbr = hnn.classify(vector, VECTOR_LENGTH, READ_COUNT, dists, cats, nids);
    hnn.setContext(1);
    printFiringNeurons(response_nbr);
    printf("=> Notice the neurons will not be recognize and shrink if the value of context is not equal\n");
}