Easy Support Vector Machine
Embed:
(wiki syntax)
Show/hide line numbers
SVM.cpp
00001 #include "SVM.hpp" 00002 00003 SVM::SVM(int dim_sample, int n_sample, float* sample_data, int* sample_label) 00004 { 00005 this->dim_signal = dim_sample; 00006 this->n_sample = n_sample; 00007 00008 // 各配列(ベクトル)のアロケート 00009 alpha = new float[n_sample]; 00010 // grammat = new float[n_sample * n_sample]; 00011 label = new int[n_sample]; 00012 sample = new float[dim_signal * n_sample]; 00013 sample_max = new float[dim_signal]; 00014 sample_min = new float[dim_signal]; 00015 00016 // サンプルのコピー 00017 memcpy(this->sample, sample_data, 00018 sizeof(float) * dim_signal * n_sample); 00019 memcpy(this->label, sample_label, 00020 sizeof(int) * n_sample); 00021 00022 // 正規化のための最大最小値 00023 memset(sample_max, 0, sizeof(float) * dim_sample); 00024 memset(sample_min, 0, sizeof(float) * dim_sample); 00025 for (int i = 0; i < dim_signal; i++) { 00026 float value; 00027 sample_min[i] = FLT_MAX; 00028 for (int j = 0; j < n_sample; j++) { 00029 value = MATRIX_AT(this->sample, dim_signal, j, i); 00030 if ( value > sample_max[i] ) { 00031 sample_max[i] = value; 00032 } else if ( value < sample_min[i] ) { 00033 //printf("min[%d] : %f -> ", i, sample_min[i]); 00034 sample_min[i] = value; 00035 //printf("min[%d] : %f \dim_signal", i, value); 00036 } 00037 } 00038 } 00039 00040 // 信号の正規化 : 死ぬほど大事 00041 for (int i = 0; i < dim_signal; i++) { 00042 float max,min; 00043 max = sample_max[i]; 00044 min = sample_min[i]; 00045 for (int j = 0; j < n_sample; j++) { 00046 // printf("[%d,%d] %f -> ", i, j, MATRIX_AT(this->sample, dim_signal, j, i)); 00047 MATRIX_AT(this->sample, dim_signal, j, i) = ( MATRIX_AT(this->sample, dim_signal, j, i) - min ) / (max - min); 00048 // printf("%f \r\n", MATRIX_AT(this->sample, dim_signal, j, i)); 00049 } 00050 } 00051 00052 /* // グラム行列の計算 : メモリの制約上,廃止 00053 for (int i = 0; i < n_sample; i++) { 00054 for (int j = i; j < n_sample; j++) { 00055 MATRIX_AT(grammat,n_sample,i,j) = kernel_function(&(MATRIX_AT(this->sample,dim_signal,i,0)), &(MATRIX_AT(this->sample,dim_signal,j,0)), dim_signal); 00056 // グラム行列は対称 00057 if ( i != j ) { 00058 MATRIX_AT(grammat,n_sample,j,i) = MATRIX_AT(grammat,n_sample,i,j); 00059 } 00060 } 00061 } 00062 */ 00063 00064 // 学習関連の設定. 例によって経験則 00065 this->maxIteration = 5000; 00066 this->epsilon = float(0.00001); 00067 this->eta = float(0.05); 00068 this->learn_alpha = float(0.8) * this->eta; 00069 this->status = SVM_NOT_LEARN; 00070 00071 // ソフトマージンの係数. 両方ともFLT_MAXとすることでハードマージンと一致. 00072 // また, 設定するときはどちらか一方のみにすること. 00073 C1 = FLT_MAX; 00074 C2 = 5; 00075 00076 srand((unsigned int)time(NULL)); 00077 } 00078 00079 // 楽園追放 00080 SVM::~SVM(void) 00081 { 00082 delete [] alpha; 00083 delete [] label; 00084 delete [] sample; 00085 delete [] sample_max; 00086 delete [] sample_min; 00087 } 00088 00089 // 再急勾配法(サーセンwww)による学習 00090 int SVM::learning(void) 00091 { 00092 00093 int iteration; // 学習繰り返しカウント 00094 00095 float* diff_alpha; // 双対問題の勾配値 00096 float* pre_diff_alpha; // 双対問題の前回の勾配値(慣性項に用いる) 00097 float* pre_alpha; // 前回の双対係数ベクトル(収束判定に用いる) 00098 register float diff_sum; // 勾配計算用の小計 00099 register float kernel_val; // カーネル関数とC2を含めた項 00100 00101 //float plus_sum, minus_sum; // 正例と負例の係数和 00102 00103 // 配列(ベクトル)のアロケート 00104 diff_alpha = new float[n_sample]; 00105 pre_diff_alpha = new float[n_sample]; 00106 pre_alpha = new float[n_sample]; 00107 00108 status = SVM_NOT_LEARN; // 学習を未完了に 00109 iteration = 0; // 繰り返し回数を初期化 00110 00111 // 双対係数の初期化.乱択 00112 for (int i = 0; i < n_sample; i++ ) { 00113 // 欠損データの係数は0にして使用しない 00114 if ( label[i] == 0 ) { 00115 alpha[i] = 0; 00116 continue; 00117 } 00118 alpha[i] = uniform_rand(1.0) + 1.0; 00119 } 00120 00121 // 学習ループ 00122 while ( iteration < maxIteration ) { 00123 00124 printf("ite: %d diff_norm : %f alpha_dist : %f \r\n", iteration, two_norm(diff_alpha, n_sample), vec_dist(alpha, pre_alpha, n_sample)); 00125 // 前回の更新値の記録 00126 memcpy(pre_alpha, alpha, sizeof(float) * n_sample); 00127 if ( iteration >= 1 ) { 00128 memcpy(pre_diff_alpha, diff_alpha, sizeof(float) * n_sample); 00129 } else { 00130 // 初回は0埋めで初期化 00131 memset(diff_alpha, 0, sizeof(float) * n_sample); 00132 memset(pre_diff_alpha, 0, sizeof(float) * n_sample); 00133 } 00134 00135 // 勾配値の計算 00136 for (int i=0; i < n_sample; i++) { 00137 diff_sum = 0; 00138 for (int j=0; j < n_sample; j++) { 00139 // C2を踏まえたカーネル関数値 00140 kernel_val = kernel_function(&(MATRIX_AT(sample,dim_signal,i,0)), &(MATRIX_AT(sample,dim_signal,j,0)), dim_signal); 00141 // kernel_val = MATRIX_AT(grammat,n_sample,i,j); // via Gram matrix 00142 if (i == j) { 00143 kernel_val += (1/C2); 00144 } 00145 diff_sum += alpha[j] * label[j] * kernel_val; 00146 } 00147 diff_sum *= label[i]; 00148 diff_alpha[i] = 1 - diff_sum; 00149 } 00150 00151 // 双対変数の更新 00152 for (int i=0; i < n_sample; i++) { 00153 if ( label[i] == 0 ) { 00154 continue; 00155 } 00156 //printf("alpha[%d] : %f -> ", i, alpha[i]); 00157 alpha[i] = pre_alpha[i] 00158 + eta * diff_alpha[i] 00159 + learn_alpha * pre_diff_alpha[i]; 00160 //printf("%f \dim_signal", alpha[i]); 00161 00162 // 非数/無限チェック 00163 if ( isnan(alpha[i]) || isinf(alpha[i]) ) { 00164 fprintf(stderr, "Detected NaN or Inf Dual-Coffience : pre_alhpa[%d]=%f -> alpha[%d]=%f", i, pre_alpha[i], i, alpha[i]); 00165 return SVM_DETECT_BAD_VAL; 00166 } 00167 00168 } 00169 00170 // 係数の制約条件1:正例と負例の双対係数和を等しくする. 00171 // 手法:標本平均に寄せる 00172 float norm_sum = 0; 00173 for (int i = 0; i < n_sample; i++ ) { 00174 norm_sum += (label[i] * alpha[i]); 00175 } 00176 norm_sum /= n_sample; 00177 00178 for (int i = 0; i < n_sample; i++ ) { 00179 if ( label[i] == 0 ) { 00180 continue; 00181 } 00182 alpha[i] -= (norm_sum / label[i]); 00183 } 00184 00185 // 係数の制約条件2:双対係数は非負 00186 for (int i = 0; i < n_sample; i++ ) { 00187 if ( alpha[i] < 0 ) { 00188 alpha[i] = 0; 00189 } else if ( alpha[i] > C1 ) { 00190 // C1を踏まえると,係数の上限はC1となる. 00191 alpha[i] = C1; 00192 } 00193 } 00194 00195 // 収束判定 : 凸計画問題なので,収束時は大域最適が 00196 // 保証されている. 00197 if ( (vec_dist(alpha, pre_alpha, n_sample) < epsilon) 00198 || (two_norm(diff_alpha, n_sample) < epsilon) ) { 00199 // 学習の正常完了 00200 status = SVM_LEARN_SUCCESS; 00201 break; 00202 } 00203 00204 // 学習繰り返し回数のインクリメント 00205 iteration++; 00206 } 00207 00208 if (iteration >= maxIteration) { 00209 fprintf(stderr, "Learning is not convergenced. (iteration count > maxIteration) \r\n"); 00210 status = SVM_NOT_CONVERGENCED; 00211 } else if ( status != SVM_LEARN_SUCCESS ) { 00212 status = SVM_NOT_LEARN; 00213 } 00214 00215 // 領域開放 00216 delete [] diff_alpha; 00217 delete [] pre_diff_alpha; 00218 delete [] pre_alpha; 00219 00220 return status; 00221 00222 } 00223 00224 // 未知データのネットワーク値を計算 00225 float SVM::predict_net(float* data) 00226 { 00227 // 学習の終了を確認 00228 if (status != SVM_LEARN_SUCCESS && status != SVM_SET_ALPHA) { 00229 fprintf(stderr, "Learning is not completed yet."); 00230 //exit(1); 00231 return SVM_NOT_LEARN; 00232 } 00233 00234 float* norm_data = new float[dim_signal]; 00235 00236 // 信号の正規化 00237 for (int i = 0; i < dim_signal; i++) { 00238 norm_data[i] = ( data[i] - sample_min[i] ) / ( sample_max[i] - sample_min[i] ); 00239 } 00240 00241 // ネットワーク値の計算 00242 float net = 0; 00243 for (int l=0; l < n_sample; l++) { 00244 // **係数が正に相当するサンプルはサポートベクトル** 00245 if(alpha[l] > 0) { 00246 net += label[l] * alpha[l] 00247 * kernel_function(&(MATRIX_AT(sample,dim_signal,l,0)), norm_data, dim_signal); 00248 } 00249 } 00250 00251 delete [] norm_data; 00252 00253 return net; 00254 00255 } 00256 00257 // 未知データの識別確率を計算 00258 float SVM::predict_probability(float* data) 00259 { 00260 float net, probability; 00261 float* optimal_w = new float[dim_signal]; // 最適時の係数(not 双対係数) 00262 float sigmoid_param; // シグモイド関数の温度パラメタ 00263 float norm_w; // 係数の2乗ノルム 00264 00265 net = SVM::predict_net(data); 00266 00267 // 最適時の係数を計算 00268 for (int n = 0; n < dim_signal; n++ ) { 00269 optimal_w[n] = 0; 00270 for (int l = 0; l < n_sample; l++ ) { 00271 optimal_w[n] += alpha[l] * label[l] * MATRIX_AT(sample, dim_signal, l, n); 00272 } 00273 } 00274 norm_w = two_norm(optimal_w, dim_signal); 00275 sigmoid_param = 1 / ( norm_w * logf( (1 - epsilon) / epsilon ) ); 00276 00277 probability = sigmoid_func(net/sigmoid_param); 00278 00279 // 打ち切り:誤差epsilon以内ならば, 1 or 0に打ち切る. 00280 if ( probability > (1 - epsilon) ) { 00281 return float(1); 00282 } else if ( probability < epsilon ) { 00283 return float(0); 00284 } 00285 00286 delete [] optimal_w; 00287 00288 return probability; 00289 00290 } 00291 00292 // 未知データの識別 00293 int SVM::predict_label(float* data) 00294 { 00295 return (predict_net(data) >= 0) ? 1 : (-1); 00296 } 00297 00298 // 双対係数のゲッター 00299 float* SVM::get_alpha(void) 00300 { 00301 return (float *)alpha; 00302 } 00303 00304 // 双対係数のセッター 00305 void SVM::set_alpha(float* alpha_data, int nsample) 00306 { 00307 if ( nsample != n_sample ) { 00308 fprintf( stderr, " set_alpha : number of sample isn't match with arg. n_samle= %d, arg= %d \r\n", n_sample, nsample); 00309 return; 00310 } 00311 memcpy(alpha, alpha_data, sizeof(float) * nsample); 00312 status = SVM_SET_ALPHA; 00313 }
Generated on Thu Jul 21 2022 16:34:03 by 1.7.2