MLP
main.cpp
- Committer:
- jcaro
- Date:
- 2018-12-12
- Revision:
- 0:681c3507f129
File content as of revision 0:681c3507f129:
#include <string>
#include <math.h>
#include <vector>
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdio>

/**
 * Previously trained multilayer perceptron (inference only).
 *
 * Topology: a configurable input layer, one configurable hidden layer with a
 * tansig (hyperbolic tangent) activation, and an output layer -- two neurons
 * by default -- followed by softmax, so the outputs are class scores that sum
 * to 1 (one-hot style classification).
 */
class ANN {
private:
    std::vector<double> x;  // input vector

    struct bias {
        std::vector<double> b1;  // hidden-layer bias, one entry per hidden neuron
        std::vector<double> b2;  // output-layer bias, one entry per output neuron
    };
    bias b;

    struct weight {
        // in_hidden[i][j]: weight from input j to hidden neuron i
        std::vector<std::vector<double> > in_hidden;
        // hidden_out[k][i]: weight from hidden neuron i to output neuron k
        std::vector<std::vector<double> > hidden_out;
    };
    weight w;

    // tansig activation: 2 / (1 + exp(-2v)) - 1, which is mathematically tanh(v).
    // std::tanh avoids the cancellation error that form suffers near v = 0.
    static double tansig(double v) { return std::tanh(v); }

    static std::vector<double> softmax(const std::vector<double>& v);

    // True when the containers have the sizes compute() indexes them with.
    static bool shapes_consistent(const std::vector<double>& data,
                                  const std::vector<double>& b_l1,
                                  const std::vector<double>& b_l2,
                                  const std::vector<std::vector<double> >& w_in,
                                  const std::vector<std::vector<double> >& w_hi);

public:
    /**
     * Builds a zero-initialised network.
     * @param input  number of input neurons (>= 0)
     * @param hidden number of hidden neurons (>= 0)
     * @param output number of output neurons (>= 0, defaults to 2)
     */
    ANN(const int input, const int hidden, const int output = 2);

    /**
     * Replaces the input vector and every trained parameter.
     * Expected shapes: w_in is hidden x data.size(), b_l1 has `hidden`
     * entries, w_hi is outputs x hidden, b_l2 has `outputs` entries.
     * Shapes are checked with assert() in debug builds.
     */
    void set_values(const std::vector<double>& data,
                    const std::vector<double>& b_l1,
                    const std::vector<double>& b_l2,
                    const std::vector<std::vector<double> >& w_in,
                    const std::vector<std::vector<double> >& w_hi);

    /**
     * Runs a forward pass on the current input.
     * @return one softmax probability per output neuron (empty if there are
     *         no output neurons).
     */
    std::vector<double> compute() const;
};

ANN::ANN(const int Num_input, const int Num_hidden, const int Num_output) {
    // Negative sizes would wrap to huge size_t values.
    assert(Num_input >= 0 && Num_hidden >= 0 && Num_output >= 0);
    const std::size_t n_in = static_cast<std::size_t>(Num_input);
    const std::size_t n_hid = static_cast<std::size_t>(Num_hidden);
    const std::size_t n_out = static_cast<std::size_t>(Num_output);

    x.assign(n_in, 0.0);
    b.b1.assign(n_hid, 0.0);
    b.b2.assign(n_out, 0.0);
    w.in_hidden.assign(n_hid, std::vector<double>(n_in, 0.0));
    w.hidden_out.assign(n_out, std::vector<double>(n_hid, 0.0));
}

bool ANN::shapes_consistent(const std::vector<double>& data,
                            const std::vector<double>& b_l1,
                            const std::vector<double>& b_l2,
                            const std::vector<std::vector<double> >& w_in,
                            const std::vector<std::vector<double> >& w_hi) {
    if (b_l1.size() != w_in.size() || b_l2.size() != w_hi.size()) {
        return false;
    }
    for (std::size_t i = 0; i < w_in.size(); ++i) {
        if (w_in[i].size() != data.size()) {
            return false;
        }
    }
    for (std::size_t k = 0; k < w_hi.size(); ++k) {
        if (w_hi[k].size() != w_in.size()) {
            return false;
        }
    }
    return true;
}

void ANN::set_values(const std::vector<double>& data,
                     const std::vector<double>& b_l1,
                     const std::vector<double>& b_l2,
                     const std::vector<std::vector<double> >& w_in,
                     const std::vector<std::vector<double> >& w_hi) {
    // compute() indexes these containers by each other's sizes; a mismatch
    // would read out of bounds.
    assert(shapes_consistent(data, b_l1, b_l2, w_in, w_hi));
    x = data;
    b.b1 = b_l1;
    b.b2 = b_l2;
    w.in_hidden = w_in;
    w.hidden_out = w_hi;
}

std::vector<double> ANN::softmax(const std::vector<double>& v) {
    std::vector<double> soft(v.size(), 0.0);
    if (v.empty()) {
        return soft;
    }
    // Shifting by the maximum does not change the result mathematically, but
    // keeps exp() from overflowing to inf (inf / inf would give NaN).
    const double peak = *std::max_element(v.begin(), v.end());
    double sum = 0.0;
    for (std::size_t i = 0; i < v.size(); ++i) {
        soft[i] = std::exp(v[i] - peak);
        sum += soft[i];
    }
    for (std::size_t i = 0; i < soft.size(); ++i) {
        soft[i] /= sum;
    }
    return soft;
}

std::vector<double> ANN::compute() const {
    const std::size_t n_hidden = w.in_hidden.size();
    const std::size_t n_output = w.hidden_out.size();

    // Hidden layer: h[i] = tansig(sum_j in_hidden[i][j] * x[j] + b1[i])
    std::vector<double> hidden(n_hidden, 0.0);
    for (std::size_t i = 0; i < n_hidden; ++i) {
        double acc = 0.0;
        for (std::size_t j = 0; j < x.size(); ++j) {
            acc += w.in_hidden[i][j] * x[j];  // hidden neuron i, input j
        }
        hidden[i] = tansig(acc + b.b1[i]);
    }

    // Output layer: logits[k] = sum_i hidden_out[k][i] * h[i] + b2[k]
    std::vector<double> logits(n_output, 0.0);
    for (std::size_t k = 0; k < n_output; ++k) {
        double acc = 0.0;
        for (std::size_t i = 0; i < n_hidden; ++i) {
            acc += w.hidden_out[k][i] * hidden[i];
        }
        logits[k] = acc + b.b2[k];
    }
    return softmax(logits);
}

/**
 * Input preprocessing in the form of MATLAB's mapminmax "apply" step:
 *   y = (x - xoffset) * gain + ymin
 * which maps the training range [xmin, xmax] onto [ymin, 1].
 * NOTE(review): the offset/gain names and values in main() look like MATLAB
 * mapminmax settings, so the network was presumably trained on inputs
 * normalised this way -- confirm against the training script. MATLAB emits
 * one offset/gain per input; here a single scalar is applied to all inputs.
 */
static std::vector<double> mapminmax_apply(const std::vector<double>& in,
                                           double xoffset, double gain,
                                           double ymin = -1.0) {
    std::vector<double> out(in.size(), 0.0);
    for (std::size_t i = 0; i < in.size(); ++i) {
        out[i] = (in[i] - xoffset) * gain + ymin;
    }
    return out;
}

int main(void) {
    // Trained parameters: 16 inputs -> 8 hidden neurons -> 2 outputs.
    const std::vector<std::vector<double> > weight_1 = {
        {-0.776369181647380, 0.322017193421995, -0.718395481227090, 0.207597170048144,
         -0.0407073740709331, -0.0246338638368738, -0.0238415052679444, -0.0271093910868941,
         0.182134466569427, -0.172975431241752, 0.496230063966937, 0.462437639022830,
         -0.459419378323299, 0.505263918878570, 0.746072610663429, -3.55052711602702},
        {1.17111526798509, 1.20937547369396, 0.114393343058135, -0.0638126643462502,
         -0.0398797882458655, 0.638081443938190, 1.35286862375869, 0.833921758394330,
         0.367161343307387, -2.35130232502543, 1.21377473736041, 1.04738625394362,
         0.600303898512392, 0.311808311752524, 0.850025642696570, 1.04463370664129},
        {-3.77410735531982, 4.34823465393789, -1.25103385490538, 2.71357696355065,
         -2.29300936941041, 2.12527386768870, 0.0385831901721302, 0.580727344848123,
         -7.01162003821432, -7.76192363548358, -5.89732697247955, -3.47241506503105,
         -8.03760047575429, 0.304854738412657, 7.13622265468984, 11.2933859407335},
        {2.11335788332427, -0.252552108025933, -4.00643608044967, 1.15587150564774,
         -0.123766981519109, -0.204160623423807, 0.344479260358756, 7.36116077063920,
         -4.62278770281617, -3.55107132529450, 6.01335554896328, 6.68413920197416,
         3.82818197204398, 6.34347116504797, -1.71180947889233, -28.1367052333630},
        {0.166325126245893, 0.666184064853938, 0.247691432658770, 0.323518040204145,
         0.491502114220709, -0.128062160290439, 0.688987813030598, 0.850896886454144,
         -0.117068128321992, -0.175564449460198, 0.195474754034279, -2.89156037538277,
         -2.22968778891023, -0.318750615545913, -0.683787967392067, -0.345374714156862},
        {-3.70605069078019, -2.44490993779135, -2.29491540620121, -1.64772178648293,
         -0.500747206634293, 0.491318304930630, 0.376730487409559, 0.406197183424080,
         -0.150626197058819, -0.307776393410606, 0.789082304094745, -0.753671734521512,
         0.264467065071219, 1.30695623988984, 4.80119692349812, 5.97425717899110},
        {0.488093749751451, 0.683563753287076, -0.465876528150416, 0.495127959685572,
         0.636059685512091, 0.0506475471144673, -0.120422773768957, 0.451914669322343,
         -1.00825604223428, 3.02763482594512, 3.74258428768249, 0.991894890924456,
         0.443005747114071, 0.480042618907100, -0.315170629621497, -0.393498306328231},
        {-1.40479087419449, -1.75501176437853, -0.365083903093547, -1.56638081038246,
         -1.49129820574439, -0.505339827237711, -0.548928928849952, -0.00273742039398381,
         -0.531209624415765, -0.155190102313166, 0.145013182057714, -0.869678364582962,
         -0.294451488285369, -0.832694564109877, -0.230306629916176, 0.00685058825907157}
    };
    const std::vector<std::vector<double> > weight_2 = {
        {-4.87609701608860, 2.30685948555474, -12.0236141168916, -14.2381971205111,
         0.325740842013551, 5.49212798310212, 2.67931975729815, -1.09861352834029},
        {4.99597478225889, -1.62273664177434, 10.9149140065480, 12.8014121987086,
         0.572252747991793, -6.10249798294726, -1.72914725888471, 0.827017369857357}
    };
    const std::vector<double> bias_1 = {
        -6.82318568418225, -1.99136366317936, -9.92350541544209, -10.4543985678487,
        1.52559047977300, 3.31243955581835, -0.251312804814257, 1.76005710490262
    };
    const std::vector<double> bias_2 = {-0.0625352184830208, -0.981915330635622};

    // Input normalisation settings (see mapminmax_apply).
    const double xoffset = 0.0488758553274682;
    const double gain = 2.64341085271318;

    // Test inputs: 30 windows of 16 samples. Each row is the previous row
    // shifted left by one sample with a new sample appended: 24 windows of a
    // flat signal, then 6 windows in which a rising edge slides in.
    const std::size_t kWindow = 16;
    const std::size_t kFlatWindows = 24;
    const double kLow = 0.00391006842619746;
    const double kEdge[] = {0.0410557184750733, 0.299120234604106,
                            0.829912023460411, 0.974584555229717, 1.0, 1.0};
    const std::size_t kEdgeLen = sizeof(kEdge) / sizeof(kEdge[0]);

    std::vector<std::vector<double> > data(kFlatWindows,
                                           std::vector<double>(kWindow, kLow));
    for (std::size_t shift = 1; shift <= kEdgeLen; ++shift) {
        std::vector<double> window(kWindow, kLow);
        for (std::size_t j = 0; j < shift; ++j) {
            window[kWindow - shift + j] = kEdge[j];
        }
        data.push_back(window);
    }

    ANN multi_layer(16, 8);  // 16 input neurons, 8 hidden, 2 outputs (default)
    for (std::size_t k = 0; k < data.size(); ++k) {
        multi_layer.set_values(mapminmax_apply(data[k], xoffset, gain),
                               bias_1, bias_2, weight_1, weight_2);
        const std::vector<double> y = multi_layer.compute();
        // Scientific notation keeps very small class probabilities visible.
        std::printf("%e | %e\r\n", y[0], y[1]);
    }
    return 0;
}