Easy Support Vector Machine

Dependents:   WeatherPredictor

Embed: (wiki syntax)

« Back to documentation index

Show/hide line numbers SVM.cpp Source File

SVM.cpp

00001 #include "SVM.hpp"
00002 
00003 SVM::SVM(int dim_sample, int n_sample, float* sample_data, int* sample_label)
00004 {
00005     this->dim_signal = dim_sample;
00006     this->n_sample = n_sample;
00007 
00008     // 各配列(ベクトル)のアロケート
00009     alpha   = new float[n_sample];
00010     // grammat = new float[n_sample * n_sample];
00011     label   = new int[n_sample];
00012     sample  = new float[dim_signal * n_sample];
00013     sample_max = new float[dim_signal];
00014     sample_min = new float[dim_signal];
00015 
00016     // サンプルのコピー
00017     memcpy(this->sample, sample_data,
00018            sizeof(float) * dim_signal * n_sample);
00019     memcpy(this->label, sample_label,
00020            sizeof(int) * n_sample);
00021 
00022     // 正規化のための最大最小値
00023     memset(sample_max, 0, sizeof(float) * dim_sample);
00024     memset(sample_min, 0, sizeof(float) * dim_sample);
00025     for (int i = 0; i < dim_signal; i++) {
00026         float value;
00027         sample_min[i] = FLT_MAX;
00028         for (int j = 0; j < n_sample; j++) {
00029             value = MATRIX_AT(this->sample, dim_signal, j, i);
00030             if ( value > sample_max[i] ) {
00031                 sample_max[i] = value;
00032             } else if ( value < sample_min[i] ) {
00033                 //printf("min[%d] : %f -> ", i, sample_min[i]);
00034                 sample_min[i] = value;
00035                 //printf("min[%d] : %f \dim_signal", i, value);
00036             }
00037         }
00038     }
00039 
00040     // 信号の正規化 : 死ぬほど大事
00041     for (int i = 0; i < dim_signal; i++) {
00042         float max,min;
00043         max = sample_max[i];
00044         min = sample_min[i];
00045         for (int j = 0; j < n_sample; j++) {
00046             // printf("[%d,%d] %f -> ", i, j, MATRIX_AT(this->sample, dim_signal, j, i));
00047             MATRIX_AT(this->sample, dim_signal, j, i) = ( MATRIX_AT(this->sample, dim_signal, j, i) - min ) / (max - min);
00048             // printf("%f \r\n", MATRIX_AT(this->sample, dim_signal, j, i));
00049         }
00050     }
00051 
00052     /* // グラム行列の計算 : メモリの制約上,廃止
00053     for (int i = 0; i < n_sample; i++) {
00054       for (int j = i; j < n_sample; j++) {
00055         MATRIX_AT(grammat,n_sample,i,j) = kernel_function(&(MATRIX_AT(this->sample,dim_signal,i,0)), &(MATRIX_AT(this->sample,dim_signal,j,0)), dim_signal);
00056         // グラム行列は対称
00057         if ( i != j ) {
00058           MATRIX_AT(grammat,n_sample,j,i) = MATRIX_AT(grammat,n_sample,i,j);
00059         }
00060       }
00061     }
00062     */
00063 
00064     // 学習関連の設定. 例によって経験則
00065     this->maxIteration = 5000;
00066     this->epsilon      = float(0.00001);
00067     this->eta          = float(0.05);
00068     this->learn_alpha  = float(0.8) * this->eta;
00069     this->status       = SVM_NOT_LEARN;
00070 
00071     // ソフトマージンの係数. 両方ともFLT_MAXとすることでハードマージンと一致.
00072     // また, 設定するときはどちらか一方のみにすること.
00073     C1 = FLT_MAX;
00074     C2 = 5;
00075 
00076     srand((unsigned int)time(NULL));
00077 }
00078 
00079 // 楽園追放
00080 SVM::~SVM(void)
00081 {
00082     delete [] alpha;
00083     delete [] label;
00084     delete [] sample;
00085     delete [] sample_max;
00086     delete [] sample_min;
00087 }
00088 
00089 // 再急勾配法(サーセンwww)による学習
00090 int SVM::learning(void)
00091 {
00092 
00093     int iteration;              // 学習繰り返しカウント
00094 
00095     float* diff_alpha;          // 双対問題の勾配値
00096     float* pre_diff_alpha;      // 双対問題の前回の勾配値(慣性項に用いる)
00097     float* pre_alpha;           // 前回の双対係数ベクトル(収束判定に用いる)
00098     register float diff_sum;    // 勾配計算用の小計
00099     register float kernel_val;  // カーネル関数とC2を含めた項
00100 
00101     //float plus_sum, minus_sum;  // 正例と負例の係数和
00102 
00103     // 配列(ベクトル)のアロケート
00104     diff_alpha     = new float[n_sample];
00105     pre_diff_alpha = new float[n_sample];
00106     pre_alpha      = new float[n_sample];
00107 
00108     status = SVM_NOT_LEARN;       // 学習を未完了に
00109     iteration  = 0;       // 繰り返し回数を初期化
00110 
00111     // 双対係数の初期化.乱択
00112     for (int i = 0; i < n_sample; i++ ) {
00113         // 欠損データの係数は0にして使用しない
00114         if ( label[i] == 0 ) {
00115             alpha[i] = 0;
00116             continue;
00117         }
00118         alpha[i] = uniform_rand(1.0) + 1.0;
00119     }
00120 
00121     // 学習ループ
00122     while ( iteration < maxIteration ) {
00123 
00124         printf("ite: %d diff_norm : %f alpha_dist : %f \r\n", iteration, two_norm(diff_alpha, n_sample), vec_dist(alpha, pre_alpha, n_sample));
00125         // 前回の更新値の記録
00126         memcpy(pre_alpha, alpha, sizeof(float) * n_sample);
00127         if ( iteration >= 1 ) {
00128             memcpy(pre_diff_alpha, diff_alpha, sizeof(float) * n_sample);
00129         } else {
00130             // 初回は0埋めで初期化
00131             memset(diff_alpha, 0, sizeof(float) * n_sample);
00132             memset(pre_diff_alpha, 0, sizeof(float) * n_sample);
00133         }
00134 
00135         // 勾配値の計算
00136         for (int i=0; i < n_sample; i++) {
00137             diff_sum = 0;
00138             for (int j=0; j < n_sample; j++) {
00139                 // C2を踏まえたカーネル関数値
00140                 kernel_val = kernel_function(&(MATRIX_AT(sample,dim_signal,i,0)), &(MATRIX_AT(sample,dim_signal,j,0)), dim_signal);
00141                 // kernel_val = MATRIX_AT(grammat,n_sample,i,j); // via Gram matrix
00142                 if (i == j) {
00143                     kernel_val += (1/C2);
00144                 }
00145                 diff_sum += alpha[j] * label[j] * kernel_val;
00146             }
00147             diff_sum *= label[i];
00148             diff_alpha[i] = 1 - diff_sum;
00149         }
00150 
00151         // 双対変数の更新
00152         for (int i=0; i < n_sample; i++) {
00153             if ( label[i] == 0 ) {
00154                 continue;
00155             }
00156             //printf("alpha[%d] : %f -> ", i, alpha[i]);
00157             alpha[i] = pre_alpha[i]
00158                        + eta * diff_alpha[i]
00159                        + learn_alpha * pre_diff_alpha[i];
00160             //printf("%f \dim_signal", alpha[i]);
00161 
00162             // 非数/無限チェック
00163             if ( isnan(alpha[i]) || isinf(alpha[i]) ) {
00164                 fprintf(stderr, "Detected NaN or Inf Dual-Coffience : pre_alhpa[%d]=%f -> alpha[%d]=%f", i, pre_alpha[i], i, alpha[i]);
00165                 return SVM_DETECT_BAD_VAL;
00166             }
00167 
00168         }
00169 
00170         // 係数の制約条件1:正例と負例の双対係数和を等しくする.
00171         //                 手法:標本平均に寄せる
00172         float norm_sum = 0;
00173         for (int i = 0; i < n_sample; i++ ) {
00174             norm_sum += (label[i] * alpha[i]);
00175         }
00176         norm_sum /= n_sample;
00177 
00178         for (int i = 0; i < n_sample; i++ ) {
00179             if ( label[i] == 0 ) {
00180                 continue;
00181             }
00182             alpha[i] -= (norm_sum / label[i]);
00183         }
00184 
00185         // 係数の制約条件2:双対係数は非負
00186         for (int i = 0; i < n_sample; i++ ) {
00187             if ( alpha[i] < 0 ) {
00188                 alpha[i] = 0;
00189             } else if ( alpha[i] > C1 ) {
00190                 // C1を踏まえると,係数の上限はC1となる.
00191                 alpha[i] = C1;
00192             }
00193         }
00194 
00195         // 収束判定 : 凸計画問題なので,収束時は大域最適が
00196         //            保証されている.
00197         if ( (vec_dist(alpha, pre_alpha, n_sample) < epsilon)
00198                 || (two_norm(diff_alpha, n_sample) < epsilon) ) {
00199             // 学習の正常完了
00200             status = SVM_LEARN_SUCCESS;
00201             break;
00202         }
00203 
00204         // 学習繰り返し回数のインクリメント
00205         iteration++;
00206     }
00207 
00208     if (iteration >= maxIteration) {
00209         fprintf(stderr, "Learning is not convergenced. (iteration count > maxIteration) \r\n");
00210         status = SVM_NOT_CONVERGENCED;
00211     } else if ( status != SVM_LEARN_SUCCESS ) {
00212         status = SVM_NOT_LEARN;
00213     }
00214 
00215     // 領域開放
00216     delete [] diff_alpha;
00217     delete [] pre_diff_alpha;
00218     delete [] pre_alpha;
00219 
00220     return status;
00221 
00222 }
00223 
00224 // 未知データのネットワーク値を計算
00225 float SVM::predict_net(float* data)
00226 {
00227     // 学習の終了を確認
00228     if (status != SVM_LEARN_SUCCESS && status != SVM_SET_ALPHA) {
00229         fprintf(stderr, "Learning is not completed yet.");
00230         //exit(1);
00231         return SVM_NOT_LEARN;
00232     }
00233 
00234     float* norm_data = new float[dim_signal];
00235 
00236     // 信号の正規化
00237     for (int i = 0; i < dim_signal; i++) {
00238         norm_data[i] = ( data[i] - sample_min[i] ) / ( sample_max[i] - sample_min[i] );
00239     }
00240 
00241     // ネットワーク値の計算
00242     float net = 0;
00243     for (int l=0; l < n_sample; l++) {
00244         // **係数が正に相当するサンプルはサポートベクトル**
00245         if(alpha[l] > 0) {
00246             net += label[l] * alpha[l]
00247                    * kernel_function(&(MATRIX_AT(sample,dim_signal,l,0)), norm_data, dim_signal);
00248         }
00249     }
00250     
00251     delete [] norm_data;
00252 
00253     return net;
00254 
00255 }
00256 
00257 // 未知データの識別確率を計算
00258 float SVM::predict_probability(float* data)
00259 {
00260     float net, probability;
00261     float* optimal_w = new float[dim_signal];   // 最適時の係数(not 双対係数)
00262     float sigmoid_param;                        // シグモイド関数の温度パラメタ
00263     float norm_w;                               // 係数の2乗ノルム
00264 
00265     net = SVM::predict_net(data);
00266 
00267     // 最適時の係数を計算
00268     for (int n = 0; n < dim_signal; n++ ) {
00269         optimal_w[n] = 0;
00270         for (int l = 0; l < n_sample; l++ ) {
00271             optimal_w[n] += alpha[l] * label[l] * MATRIX_AT(sample, dim_signal, l, n);
00272         }
00273     }
00274     norm_w = two_norm(optimal_w, dim_signal);
00275     sigmoid_param = 1 / ( norm_w * logf( (1 - epsilon) / epsilon ) );
00276 
00277     probability = sigmoid_func(net/sigmoid_param);
00278 
00279     // 打ち切り:誤差epsilon以内ならば, 1 or 0に打ち切る.
00280     if ( probability > (1 - epsilon) ) {
00281         return float(1);
00282     } else if ( probability < epsilon ) {
00283         return float(0);
00284     }
00285     
00286     delete [] optimal_w;
00287 
00288     return probability;
00289 
00290 }
00291 
00292 // 未知データの識別
00293 int SVM::predict_label(float* data)
00294 {
00295     return (predict_net(data) >= 0) ? 1 : (-1);
00296 }
00297 
00298 // 双対係数のゲッター
00299 float* SVM::get_alpha(void)
00300 {
00301     return (float *)alpha;
00302 }
00303 
00304 // 双対係数のセッター
00305 void SVM::set_alpha(float* alpha_data, int nsample)
00306 {
00307     if ( nsample != n_sample ) {
00308         fprintf( stderr, " set_alpha : number of sample isn't match with arg. n_samle= %d, arg= %d \r\n", n_sample, nsample);
00309         return;
00310     }
00311     memcpy(alpha, alpha_data, sizeof(float) * nsample);
00312     status = SVM_SET_ALPHA;
00313 }