// MLP
//
// main.cpp
//
// Committer: jcaro
// Date:      2018-12-12
// Revision:  0:681c3507f129
//
// File content as of revision 0:681c3507f129:

#include <math.h>
#include <cstdio>
#include <string>
#include <utility>
#include <vector>

class ANN {
    // Pre-trained multilayer perceptron: one input layer [configurable],
    // one hidden layer [configurable] and two neurons in the output layer
    // (one-hot encoding). Weights and biases are loaded via set_values().
    private:
        std::vector<double> x; // input vector
        //std::vector <double> *y; // output vector (two neurons, one-hot encoding)
        struct bias{
            std::vector<double> b1; // hidden-layer biases, one per hidden neuron
            std::vector<double> b2; // output-layer biases, always 2 entries
        };
        bias b;
        // NOTE(review): "weigth" is a typo for "weight"; kept as-is because the
        // out-of-line member definitions below reference it.
        struct weigth{
            std::vector<std::vector<double> > in_hidden;  // [hidden][input] weights
            std::vector<std::vector<double> > hidden_out; // [2][hidden] weights
        };
        weigth w;
        // Hidden-layer activation: 2/(1+exp(-2x))-1, which is algebraically tanh(x).
        double sigmoid (double x){return (2/(1+exp(-2*x))-1);}
        std::vector<double>  softmax (std::vector <double> x); 
    public:
        ANN(const int input, const int hidden); //constructor
        ~ANN();
        void set_values (std::vector <double> data, std::vector<double> b_l1, std::vector<double> b_l2,
                        std::vector<std::vector<double> > w_in, std::vector<std::vector<double> > w_hi);
        std::vector <double> compute();
};

// Allocate zero-initialized state for a Num_input -> Num_hidden -> 2 network.
// Real weights and biases must be loaded afterwards with set_values().
ANN::ANN(const int Num_input, const int Num_hidden){
    x.assign(Num_input, 0.0);                                             // input vector
    w.in_hidden.assign(Num_hidden, std::vector<double>(Num_input, 0.0));  // input->hidden weights
    w.hidden_out.assign(2, std::vector<double>(Num_hidden, 0.0));         // hidden->output weights
    b.b1.assign(Num_hidden, 0.0);                                         // hidden biases
    b.b2.assign(2, 0.0);                                                  // output biases
}

// Default destructor: the std::vector members release their own storage via
// RAII, so the previous manual clear() calls were redundant busywork.
ANN::~ANN() = default;

// Load the current input sample plus the trained biases and weight matrices.
// Parameters are taken by value (matching the in-class declaration) and then
// moved into the members, avoiding a second deep copy of every vector.
void ANN::set_values(std::vector <double> data, std::vector<double> b_l1, std::vector<double> b_l2,
                     std::vector<std::vector<double> > w_in, std::vector<std::vector<double> > w_hi){
    x = std::move(data);              // input sample
    b.b1 = std::move(b_l1);           // hidden-layer biases
    b.b2 = std::move(b_l2);           // output-layer biases
    w.in_hidden = std::move(w_in);    // input->hidden weights
    w.hidden_out = std::move(w_hi);   // hidden->output weights
}

// Numerically stable softmax over the given logits.
// Subtracting the maximum before exponentiating prevents exp() overflow for
// large inputs without changing the mathematical result; also guards against
// an empty input and uses size_t indices (no signed/unsigned mismatch).
std::vector <double> ANN::softmax(std::vector <double> x){
    std::vector<double> soft(x.size(), 0.0);
    if (x.empty()) return soft; // nothing to normalize

    // Largest logit, used as the stabilizing shift.
    double max_v = x[0];
    for (size_t i = 1; i < x.size(); i++){
        if (x[i] > max_v) max_v = x[i];
    }

    // exp of shifted logits and their running sum.
    double total = 0.0;
    for (size_t i = 0; i < x.size(); i++){
        soft[i] = exp(x[i] - max_v);
        total += soft[i];
    }

    // Normalize so the outputs sum to 1.
    for (size_t j = 0; j < x.size(); j++){
        soft[j] /= total;
    }

    return soft;
}

// Forward pass: input -> hidden layer (sigmoid activation) -> output layer
// -> softmax. Returns the two-element class-probability vector.
// Accumulation order (weighted inputs first, then bias) matches the weights'
// storage layout: w.in_hidden is [hidden][input], w.hidden_out is [2][hidden].
std::vector <double> ANN::compute(){
    const size_t n_hidden = w.in_hidden.size();
    std::vector<double> hidden_act(n_hidden, 0.0);

    // Hidden layer: weighted sum over the inputs, plus bias, then activation.
    for (size_t i = 0; i < n_hidden; ++i){
        double acc = 0.0;
        for (size_t j = 0; j < x.size(); ++j){
            acc += w.in_hidden[i][j] * x[j]; // neuron i, input j
        }
        acc += b.b1[i];
        hidden_act[i] = sigmoid(acc);
    }

    // Output layer: two linear units over the hidden activations.
    std::vector<double> logits(2, 0.0);
    for (size_t k = 0; k < 2; ++k){
        double acc = 0.0;
        for (size_t i = 0; i < n_hidden; ++i){
            acc += w.hidden_out[k][i] * hidden_act[i];
        }
        acc += b.b2[k];
        logits[k] = acc;
    }

    return softmax(logits);
}

// Demo driver: runs the pre-trained 16-8-2 MLP over a sweep of 30 input
// windows and prints the two softmax outputs for each one.
int main(void)
{
    // Trained input->hidden weights: 8 hidden neurons x 16 inputs.
    std::vector<std::vector<double> > weight_1 = {            
        {-0.776369181647380, 0.322017193421995, -0.718395481227090, 0.207597170048144, -0.0407073740709331, -0.0246338638368738, -0.0238415052679444, -0.0271093910868941, 0.182134466569427, -0.172975431241752, 0.496230063966937, 0.462437639022830, -0.459419378323299, 0.505263918878570, 0.746072610663429, -3.55052711602702},
        {1.17111526798509, 1.20937547369396, 0.114393343058135, -0.0638126643462502, -0.0398797882458655, 0.638081443938190, 1.35286862375869, 0.833921758394330, 0.367161343307387, -2.35130232502543, 1.21377473736041, 1.04738625394362, 0.600303898512392, 0.311808311752524, 0.850025642696570, 1.04463370664129},
        {-3.77410735531982, 4.34823465393789, -1.25103385490538, 2.71357696355065, -2.29300936941041, 2.12527386768870, 0.0385831901721302, 0.580727344848123, -7.01162003821432, -7.76192363548358, -5.89732697247955, -3.47241506503105, -8.03760047575429, 0.304854738412657, 7.13622265468984, 11.2933859407335},
        {2.11335788332427, -0.252552108025933, -4.00643608044967, 1.15587150564774, -0.123766981519109, -0.204160623423807, 0.344479260358756, 7.36116077063920, -4.62278770281617, -3.55107132529450, 6.01335554896328, 6.68413920197416, 3.82818197204398, 6.34347116504797, -1.71180947889233, -28.1367052333630},
        {0.166325126245893, 0.666184064853938, 0.247691432658770, 0.323518040204145, 0.491502114220709, -0.128062160290439, 0.688987813030598, 0.850896886454144, -0.117068128321992, -0.175564449460198, 0.195474754034279, -2.89156037538277, -2.22968778891023, -0.318750615545913, -0.683787967392067, -0.345374714156862},
        {-3.70605069078019, -2.44490993779135, -2.29491540620121, -1.64772178648293, -0.500747206634293, 0.491318304930630, 0.376730487409559, 0.406197183424080, -0.150626197058819, -0.307776393410606, 0.789082304094745, -0.753671734521512, 0.264467065071219, 1.30695623988984, 4.80119692349812, 5.97425717899110},
        {0.488093749751451, 0.683563753287076, -0.465876528150416, 0.495127959685572, 0.636059685512091, 0.0506475471144673, -0.120422773768957, 0.451914669322343, -1.00825604223428, 3.02763482594512, 3.74258428768249, 0.991894890924456, 0.443005747114071, 0.480042618907100, -0.315170629621497, -0.393498306328231},
        {-1.40479087419449, -1.75501176437853, -0.365083903093547, -1.56638081038246, -1.49129820574439, -0.505339827237711, -0.548928928849952, -0.00273742039398381, -0.531209624415765, -0.155190102313166, 0.145013182057714, -0.869678364582962, -0.294451488285369, -0.832694564109877, -0.230306629916176, 0.00685058825907157}
    };
    // Trained hidden->output weights: 2 output neurons x 8 hidden neurons.
    std::vector<std::vector<double> > weight_2 = {
        {-4.87609701608860, 2.30685948555474, -12.0236141168916, -14.2381971205111, 0.325740842013551, 5.49212798310212, 2.67931975729815, -1.09861352834029},
        {4.99597478225889, -1.62273664177434, 10.9149140065480, 12.8014121987086, 0.572252747991793, -6.10249798294726, -1.72914725888471, 0.827017369857357},
    };
    std::vector<double> bias_1 = {-6.82318568418225, -1.99136366317936, -9.92350541544209, -10.4543985678487, 1.52559047977300, 3.31243955581835, -0.251312804814257, 1.76005710490262};
    std::vector<double> bias_2 = {-0.0625352184830208, -0.981915330635622};
    // Normalization constants from training, kept for reference; this demo's
    // data is already normalized, so they are not applied here:
    //   xoffset = 0.0488758553274682, gain = 2.64341085271318
    
    // Input data: 30 windows of 16 samples each. The first 24 windows are
    // flat at the baseline level; the last 6 append a rising edge of length
    // 1..6 at the tail of the window (simulating an event sweeping in).
    const double baseline = 0.00391006842619746;
    const std::vector<double> edge = {0.0410557184750733, 0.299120234604106,
                                      0.829912023460411, 0.974584555229717,
                                      1.0, 1.0};
    std::vector<std::vector<double> > data(30, std::vector<double>(16, baseline));
    for (size_t r = 24; r < 30; ++r){
        const size_t len = r - 23;               // edge length: 1..6
        for (size_t j = 0; j < len; ++j){
            data[r][16 - len + j] = edge[j];     // edge occupies the window tail
        }
    }
    
    ANN multi_layer(16, 8); // 16 input neurons, 8 hidden, 2 output
    
    for (size_t k = 0; k < data.size(); ++k){
        multi_layer.set_values(data[k], bias_1, bias_2, weight_1, weight_2);    
        std::vector<double> y = multi_layer.compute();
        // The original format string "%f.*e" was malformed: it printed each
        // value with %f followed by a literal ".*e". Use scientific notation.
        printf("%e  |  %e\r\n", y[0], y[1]);
    }
    
    return 0;
}