Simple Recurrent Neural Network Predictor

Dependents:   WeatherPredictor

Committer:
yukari_hinata
Date:
Thu Feb 19 19:15:04 2015 +0000
Revision:
7:92ea6cefc6a5
Parent:
5:026d42b4455f
reviced

Who changed what in which revision?

UserRevisionLine numberNew contents of line
yukari_hinata 0:0d42047e140c 1 #ifndef SRNN_H_INCLUDED
yukari_hinata 0:0d42047e140c 2 #define SRNN_H_INCLUDED
yukari_hinata 0:0d42047e140c 3
yukari_hinata 0:0d42047e140c 4 /*
yukari_hinata 0:0d42047e140c 5 #include <math.h>
yukari_hinata 0:0d42047e140c 6 #include <time.h>
yukari_hinata 0:0d42047e140c 7 #include <stdlib.h>
yukari_hinata 0:0d42047e140c 8 #include <stdio.h>
yukari_hinata 0:0d42047e140c 9 #include <string.h>
yukari_hinata 0:0d42047e140c 10 */
yukari_hinata 0:0d42047e140c 11
yukari_hinata 0:0d42047e140c 12 #include "mbed.h"
yukari_hinata 0:0d42047e140c 13
yukari_hinata 4:9d94330f380a 14 #include "../ml_util/ml_util.hpp"
yukari_hinata 5:026d42b4455f 15 #include "../debug/debug.hpp"
yukari_hinata 0:0d42047e140c 16
yukari_hinata 0:0d42047e140c 17 class SRNN
yukari_hinata 0:0d42047e140c 18 {
yukari_hinata 0:0d42047e140c 19 private:
yukari_hinata 0:0d42047e140c 20 int dim_signal; // 入力次元=出力次元数
yukari_hinata 0:0d42047e140c 21 int num_mid_neuron; // 中間層のニューロン数
yukari_hinata 0:0d42047e140c 22 float width_initW; // 係数行列初期値の乱数幅:[-width_initW,+width_initW]
yukari_hinata 0:0d42047e140c 23 float goalError; // 二乗誤差の目標値
yukari_hinata 0:0d42047e140c 24 float epsilon; // 収束判定用の小さい値
yukari_hinata 0:0d42047e140c 25 int maxIteration; // 最大学習繰り返し回数
yukari_hinata 0:0d42047e140c 26 float learnRate; // 学習係数(must be in [0,1])
yukari_hinata 0:0d42047e140c 27 float alpha; // 慣性項の係数(0.8 * learnRateぐらいに設定する予定)
yukari_hinata 0:0d42047e140c 28 float alpha_context; // コンテキスト層の重み付け[0,1]
yukari_hinata 0:0d42047e140c 29 int len_seqence; // サンプルの系列長
yukari_hinata 1:da597cb284a2 30 int len_predict; // 予測系列長
yukari_hinata 0:0d42047e140c 31 float* Win_mid; // 入力<->中間層の係数行列
yukari_hinata 0:0d42047e140c 32 float* Wmid_out; // 中間<->出力層の係数行列
yukari_hinata 0:0d42047e140c 33 float* expand_in_signal; // コンテキスト層も含めた入力層の出力信号.SRNN特有
yukari_hinata 0:0d42047e140c 34 float* expand_mid_signal; // 中間層の出力信号
yukari_hinata 0:0d42047e140c 35 public:
yukari_hinata 0:0d42047e140c 36 float squareError; // 二乗誤差(経験誤差)
yukari_hinata 0:0d42047e140c 37 float* sample; // 系列長len_seqenceに渡る次元dim_signalのサンプル
yukari_hinata 0:0d42047e140c 38 float* sample_maxmin; // サンプルの取りうる最大/最小値信号を並べたベクトル
yukari_hinata 0:0d42047e140c 39 float* predict_signal; // 予測出力
yukari_hinata 0:0d42047e140c 40
yukari_hinata 5:026d42b4455f 41 // for BUG fix....
yukari_hinata 5:026d42b4455f 42 float* norm_sample;
yukari_hinata 5:026d42b4455f 43 float* dWin_mid;
yukari_hinata 5:026d42b4455f 44 float* dWmid_out;
yukari_hinata 5:026d42b4455f 45 float* prevdWin_mid;
yukari_hinata 5:026d42b4455f 46 float* prevdWmid_out;
yukari_hinata 5:026d42b4455f 47 float* out_signal;
yukari_hinata 5:026d42b4455f 48 float* in_mid_net;
yukari_hinata 5:026d42b4455f 49 float* mid_out_net;
yukari_hinata 5:026d42b4455f 50 float* sigma;
yukari_hinata 5:026d42b4455f 51
yukari_hinata 5:026d42b4455f 52
yukari_hinata 0:0d42047e140c 53 private:
yukari_hinata 0:0d42047e140c 54 // サイズn*1のベクトルの要素をそれぞれシグモイド関数に通して,
yukari_hinata 0:0d42047e140c 55 // 結果をoutにセットする.(ニューロンのユニット動作を一括で)
yukari_hinata 0:0d42047e140c 56 void sigmoid_vec(float*, // n * 1 入力ベクトル
yukari_hinata 0:0d42047e140c 57 float*, // n * 1 出力ベクトル
yukari_hinata 0:0d42047e140c 58 int); // n
yukari_hinata 0:0d42047e140c 59
yukari_hinata 0:0d42047e140c 60 public:
yukari_hinata 0:0d42047e140c 61 // 最小限の初期化パラメタによるコンストラクタ.配列(ベクトル)のアロケートを行う.
yukari_hinata 0:0d42047e140c 62 // 適宜追加する予定
yukari_hinata 0:0d42047e140c 63 SRNN(int, // 信号の次元dim_signal
yukari_hinata 0:0d42047e140c 64 int, // 中間層の数num_mid_neuron
yukari_hinata 0:0d42047e140c 65 int, // 系列長len_seqence
yukari_hinata 1:da597cb284a2 66 int, // 予測系列長len_predict
yukari_hinata 0:0d42047e140c 67 float*, // サンプル
yukari_hinata 0:0d42047e140c 68 float*); // サンプルの最大値/最小値ベクトル
yukari_hinata 0:0d42047e140c 69
yukari_hinata 0:0d42047e140c 70 ~SRNN(void);
yukari_hinata 0:0d42047e140c 71
yukari_hinata 0:0d42047e140c 72 // 逆誤差伝搬法による学習を行い,経験誤差が目標値goalErrorに達するか,
yukari_hinata 0:0d42047e140c 73 // 最大繰り返し回数maxIterationに到達したら,その時の二乗誤差を出力する.
yukari_hinata 0:0d42047e140c 74 float learning(void);
yukari_hinata 0:0d42047e140c 75
yukari_hinata 0:0d42047e140c 76 // 予測結果predict_signalにセット
yukari_hinata 0:0d42047e140c 77 void predict(float *input);
yukari_hinata 2:d623e7ef4dca 78
yukari_hinata 2:d623e7ef4dca 79 // サンプルのセット
yukari_hinata 2:d623e7ef4dca 80 void set_sample(float *sample_data);
yukari_hinata 0:0d42047e140c 81 };
yukari_hinata 0:0d42047e140c 82
yukari_hinata 0:0d42047e140c 83 #endif /* SRNN_H_INCLUDED */