Simple Recurrent Neural Network Predictor
Diff: SRNN.cpp
- Revision:
- 2:d623e7ef4dca
- Parent:
- 1:da597cb284a2
- Child:
- 4:9d94330f380a
--- a/SRNN.cpp Sun Feb 15 04:05:35 2015 +0000 +++ b/SRNN.cpp Sun Feb 15 09:27:31 2015 +0000 @@ -51,12 +51,9 @@ SRNN::~SRNN(void) { - delete [] sample; - delete [] sample_maxmin; - delete [] predict_signal; - delete [] Win_mid; - delete [] Wmid_out; - delete [] expand_in_signal; + delete [] sample; delete [] sample_maxmin; + delete [] predict_signal; delete [] Win_mid; + delete [] Wmid_out; delete [] expand_in_signal; delete [] expand_mid_signal; } @@ -73,8 +70,7 @@ void SRNN::predict(float* input) { float *norm_input = new float[this->dim_signal]; - - + // output signal float* out_signal = new float[dim_signal]; // value of network in input->hidden layer @@ -184,15 +180,15 @@ // 前回の二乗誤差値:収束判定に用いる. float prevError; - + squareError = FLT_MAX; /* 学習ループ */ while (1) { // 終了条件を満たすか確認 if (!end_flag) { - end_flag = !(iteration < this->maxIteration - && (iteration <= this->len_seqence - || this->squareError > this->goalError) + end_flag = !(iteration < maxIteration + && (iteration <= len_seqence + || squareError > goalError) ); } @@ -260,7 +256,7 @@ MATRIX_AT(sample_maxmin,2,i,0), MATRIX_AT(sample_maxmin,2,i,1)); } - printf("predict : %f %f %f \r\n", predict_signal[0], predict_signal[1], predict_signal[2]); + // printf("predict : %f %f %f \r\n", predict_signal[0], predict_signal[1], predict_signal[2]); // print_mat(Wmid_out, row_mid_out, col_mid_out); @@ -366,3 +362,9 @@ return squareError; } + +// サンプルの(リ)セット +void SRNN::set_sample(float* sample_data) +{ + memcpy(sample, sample_data, sizeof(float) * len_seqence * dim_signal); +}