Easy Support Vector Machine
SVM.cpp@6:e7aa8d270f8b, 2015-02-19 (annotated)
- Committer:
- yukari_hinata
- Date:
- Thu Feb 19 19:14:55 2015 +0000
- Revision:
- 6:e7aa8d270f8b
- Parent:
- 5:792afbb0bcf3
(There may be a memory leak...)
Who changed what in which revision?
User | Revision | Line number | New contents of line |
---|---|---|---|
yukari_hinata | 0:3f38e74a4a77 | 1 | #include "SVM.hpp" |
yukari_hinata | 0:3f38e74a4a77 | 2 | |
yukari_hinata | 0:3f38e74a4a77 | 3 | SVM::SVM(int dim_sample, int n_sample, float* sample_data, int* sample_label) |
yukari_hinata | 0:3f38e74a4a77 | 4 | { |
yukari_hinata | 5:792afbb0bcf3 | 5 | this->dim_signal = dim_sample; |
yukari_hinata | 5:792afbb0bcf3 | 6 | this->n_sample = n_sample; |
yukari_hinata | 0:3f38e74a4a77 | 7 | |
yukari_hinata | 5:792afbb0bcf3 | 8 | // 各配列(ベクトル)のアロケート |
yukari_hinata | 5:792afbb0bcf3 | 9 | alpha = new float[n_sample]; |
yukari_hinata | 5:792afbb0bcf3 | 10 | // grammat = new float[n_sample * n_sample]; |
yukari_hinata | 5:792afbb0bcf3 | 11 | label = new int[n_sample]; |
yukari_hinata | 5:792afbb0bcf3 | 12 | sample = new float[dim_signal * n_sample]; |
yukari_hinata | 5:792afbb0bcf3 | 13 | sample_max = new float[dim_signal]; |
yukari_hinata | 5:792afbb0bcf3 | 14 | sample_min = new float[dim_signal]; |
yukari_hinata | 5:792afbb0bcf3 | 15 | |
yukari_hinata | 5:792afbb0bcf3 | 16 | // サンプルのコピー |
yukari_hinata | 5:792afbb0bcf3 | 17 | memcpy(this->sample, sample_data, |
yukari_hinata | 5:792afbb0bcf3 | 18 | sizeof(float) * dim_signal * n_sample); |
yukari_hinata | 5:792afbb0bcf3 | 19 | memcpy(this->label, sample_label, |
yukari_hinata | 5:792afbb0bcf3 | 20 | sizeof(int) * n_sample); |
yukari_hinata | 0:3f38e74a4a77 | 21 | |
yukari_hinata | 5:792afbb0bcf3 | 22 | // 正規化のための最大最小値 |
yukari_hinata | 5:792afbb0bcf3 | 23 | memset(sample_max, 0, sizeof(float) * dim_sample); |
yukari_hinata | 5:792afbb0bcf3 | 24 | memset(sample_min, 0, sizeof(float) * dim_sample); |
yukari_hinata | 5:792afbb0bcf3 | 25 | for (int i = 0; i < dim_signal; i++) { |
yukari_hinata | 5:792afbb0bcf3 | 26 | float value; |
yukari_hinata | 5:792afbb0bcf3 | 27 | sample_min[i] = FLT_MAX; |
yukari_hinata | 5:792afbb0bcf3 | 28 | for (int j = 0; j < n_sample; j++) { |
yukari_hinata | 5:792afbb0bcf3 | 29 | value = MATRIX_AT(this->sample, dim_signal, j, i); |
yukari_hinata | 5:792afbb0bcf3 | 30 | if ( value > sample_max[i] ) { |
yukari_hinata | 5:792afbb0bcf3 | 31 | sample_max[i] = value; |
yukari_hinata | 5:792afbb0bcf3 | 32 | } else if ( value < sample_min[i] ) { |
yukari_hinata | 5:792afbb0bcf3 | 33 | //printf("min[%d] : %f -> ", i, sample_min[i]); |
yukari_hinata | 5:792afbb0bcf3 | 34 | sample_min[i] = value; |
yukari_hinata | 5:792afbb0bcf3 | 35 | //printf("min[%d] : %f \dim_signal", i, value); |
yukari_hinata | 5:792afbb0bcf3 | 36 | } |
yukari_hinata | 5:792afbb0bcf3 | 37 | } |
yukari_hinata | 5:792afbb0bcf3 | 38 | } |
yukari_hinata | 0:3f38e74a4a77 | 39 | |
yukari_hinata | 5:792afbb0bcf3 | 40 | // 信号の正規化 : 死ぬほど大事 |
yukari_hinata | 5:792afbb0bcf3 | 41 | for (int i = 0; i < dim_signal; i++) { |
yukari_hinata | 5:792afbb0bcf3 | 42 | float max,min; |
yukari_hinata | 5:792afbb0bcf3 | 43 | max = sample_max[i]; |
yukari_hinata | 5:792afbb0bcf3 | 44 | min = sample_min[i]; |
yukari_hinata | 5:792afbb0bcf3 | 45 | for (int j = 0; j < n_sample; j++) { |
yukari_hinata | 5:792afbb0bcf3 | 46 | // printf("[%d,%d] %f -> ", i, j, MATRIX_AT(this->sample, dim_signal, j, i)); |
yukari_hinata | 5:792afbb0bcf3 | 47 | MATRIX_AT(this->sample, dim_signal, j, i) = ( MATRIX_AT(this->sample, dim_signal, j, i) - min ) / (max - min); |
yukari_hinata | 5:792afbb0bcf3 | 48 | // printf("%f \r\n", MATRIX_AT(this->sample, dim_signal, j, i)); |
yukari_hinata | 5:792afbb0bcf3 | 49 | } |
yukari_hinata | 5:792afbb0bcf3 | 50 | } |
yukari_hinata | 5:792afbb0bcf3 | 51 | |
yukari_hinata | 5:792afbb0bcf3 | 52 | /* // グラム行列の計算 : メモリの制約上,廃止 |
yukari_hinata | 5:792afbb0bcf3 | 53 | for (int i = 0; i < n_sample; i++) { |
yukari_hinata | 5:792afbb0bcf3 | 54 | for (int j = i; j < n_sample; j++) { |
yukari_hinata | 5:792afbb0bcf3 | 55 | MATRIX_AT(grammat,n_sample,i,j) = kernel_function(&(MATRIX_AT(this->sample,dim_signal,i,0)), &(MATRIX_AT(this->sample,dim_signal,j,0)), dim_signal); |
yukari_hinata | 5:792afbb0bcf3 | 56 | // グラム行列は対称 |
yukari_hinata | 5:792afbb0bcf3 | 57 | if ( i != j ) { |
yukari_hinata | 5:792afbb0bcf3 | 58 | MATRIX_AT(grammat,n_sample,j,i) = MATRIX_AT(grammat,n_sample,i,j); |
yukari_hinata | 5:792afbb0bcf3 | 59 | } |
yukari_hinata | 0:3f38e74a4a77 | 60 | } |
yukari_hinata | 0:3f38e74a4a77 | 61 | } |
yukari_hinata | 5:792afbb0bcf3 | 62 | */ |
yukari_hinata | 0:3f38e74a4a77 | 63 | |
yukari_hinata | 5:792afbb0bcf3 | 64 | // 学習関連の設定. 例によって経験則 |
yukari_hinata | 5:792afbb0bcf3 | 65 | this->maxIteration = 5000; |
yukari_hinata | 5:792afbb0bcf3 | 66 | this->epsilon = float(0.00001); |
yukari_hinata | 5:792afbb0bcf3 | 67 | this->eta = float(0.05); |
yukari_hinata | 5:792afbb0bcf3 | 68 | this->learn_alpha = float(0.8) * this->eta; |
yukari_hinata | 5:792afbb0bcf3 | 69 | this->status = SVM_NOT_LEARN; |
yukari_hinata | 0:3f38e74a4a77 | 70 | |
yukari_hinata | 5:792afbb0bcf3 | 71 | // ソフトマージンの係数. 両方ともFLT_MAXとすることでハードマージンと一致. |
yukari_hinata | 5:792afbb0bcf3 | 72 | // また, 設定するときはどちらか一方のみにすること. |
yukari_hinata | 5:792afbb0bcf3 | 73 | C1 = FLT_MAX; |
yukari_hinata | 5:792afbb0bcf3 | 74 | C2 = 5; |
yukari_hinata | 0:3f38e74a4a77 | 75 | |
yukari_hinata | 5:792afbb0bcf3 | 76 | srand((unsigned int)time(NULL)); |
yukari_hinata | 0:3f38e74a4a77 | 77 | } |
yukari_hinata | 0:3f38e74a4a77 | 78 | |
yukari_hinata | 0:3f38e74a4a77 | 79 | // 楽園追放 |
yukari_hinata | 5:792afbb0bcf3 | 80 | SVM::~SVM(void) |
yukari_hinata | 0:3f38e74a4a77 | 81 | { |
yukari_hinata | 5:792afbb0bcf3 | 82 | delete [] alpha; |
yukari_hinata | 5:792afbb0bcf3 | 83 | delete [] label; |
yukari_hinata | 0:3f38e74a4a77 | 84 | delete [] sample; |
yukari_hinata | 5:792afbb0bcf3 | 85 | delete [] sample_max; |
yukari_hinata | 5:792afbb0bcf3 | 86 | delete [] sample_min; |
yukari_hinata | 0:3f38e74a4a77 | 87 | } |
yukari_hinata | 0:3f38e74a4a77 | 88 | |
yukari_hinata | 0:3f38e74a4a77 | 89 | // 再急勾配法(サーセンwww)による学習 |
yukari_hinata | 0:3f38e74a4a77 | 90 | int SVM::learning(void) |
yukari_hinata | 0:3f38e74a4a77 | 91 | { |
yukari_hinata | 0:3f38e74a4a77 | 92 | |
yukari_hinata | 5:792afbb0bcf3 | 93 | int iteration; // 学習繰り返しカウント |
yukari_hinata | 0:3f38e74a4a77 | 94 | |
yukari_hinata | 5:792afbb0bcf3 | 95 | float* diff_alpha; // 双対問題の勾配値 |
yukari_hinata | 5:792afbb0bcf3 | 96 | float* pre_diff_alpha; // 双対問題の前回の勾配値(慣性項に用いる) |
yukari_hinata | 5:792afbb0bcf3 | 97 | float* pre_alpha; // 前回の双対係数ベクトル(収束判定に用いる) |
yukari_hinata | 5:792afbb0bcf3 | 98 | register float diff_sum; // 勾配計算用の小計 |
yukari_hinata | 5:792afbb0bcf3 | 99 | register float kernel_val; // カーネル関数とC2を含めた項 |
yukari_hinata | 0:3f38e74a4a77 | 100 | |
yukari_hinata | 5:792afbb0bcf3 | 101 | //float plus_sum, minus_sum; // 正例と負例の係数和 |
yukari_hinata | 0:3f38e74a4a77 | 102 | |
yukari_hinata | 5:792afbb0bcf3 | 103 | // 配列(ベクトル)のアロケート |
yukari_hinata | 5:792afbb0bcf3 | 104 | diff_alpha = new float[n_sample]; |
yukari_hinata | 5:792afbb0bcf3 | 105 | pre_diff_alpha = new float[n_sample]; |
yukari_hinata | 5:792afbb0bcf3 | 106 | pre_alpha = new float[n_sample]; |
yukari_hinata | 0:3f38e74a4a77 | 107 | |
yukari_hinata | 5:792afbb0bcf3 | 108 | status = SVM_NOT_LEARN; // 学習を未完了に |
yukari_hinata | 5:792afbb0bcf3 | 109 | iteration = 0; // 繰り返し回数を初期化 |
yukari_hinata | 0:3f38e74a4a77 | 110 | |
yukari_hinata | 5:792afbb0bcf3 | 111 | // 双対係数の初期化.乱択 |
yukari_hinata | 5:792afbb0bcf3 | 112 | for (int i = 0; i < n_sample; i++ ) { |
yukari_hinata | 5:792afbb0bcf3 | 113 | // 欠損データの係数は0にして使用しない |
yukari_hinata | 5:792afbb0bcf3 | 114 | if ( label[i] == 0 ) { |
yukari_hinata | 5:792afbb0bcf3 | 115 | alpha[i] = 0; |
yukari_hinata | 5:792afbb0bcf3 | 116 | continue; |
yukari_hinata | 0:3f38e74a4a77 | 117 | } |
yukari_hinata | 5:792afbb0bcf3 | 118 | alpha[i] = uniform_rand(1.0) + 1.0; |
yukari_hinata | 0:3f38e74a4a77 | 119 | } |
yukari_hinata | 0:3f38e74a4a77 | 120 | |
yukari_hinata | 5:792afbb0bcf3 | 121 | // 学習ループ |
yukari_hinata | 5:792afbb0bcf3 | 122 | while ( iteration < maxIteration ) { |
yukari_hinata | 5:792afbb0bcf3 | 123 | |
yukari_hinata | 5:792afbb0bcf3 | 124 | printf("ite: %d diff_norm : %f alpha_dist : %f \r\n", iteration, two_norm(diff_alpha, n_sample), vec_dist(alpha, pre_alpha, n_sample)); |
yukari_hinata | 5:792afbb0bcf3 | 125 | // 前回の更新値の記録 |
yukari_hinata | 5:792afbb0bcf3 | 126 | memcpy(pre_alpha, alpha, sizeof(float) * n_sample); |
yukari_hinata | 5:792afbb0bcf3 | 127 | if ( iteration >= 1 ) { |
yukari_hinata | 5:792afbb0bcf3 | 128 | memcpy(pre_diff_alpha, diff_alpha, sizeof(float) * n_sample); |
yukari_hinata | 5:792afbb0bcf3 | 129 | } else { |
yukari_hinata | 5:792afbb0bcf3 | 130 | // 初回は0埋めで初期化 |
yukari_hinata | 5:792afbb0bcf3 | 131 | memset(diff_alpha, 0, sizeof(float) * n_sample); |
yukari_hinata | 5:792afbb0bcf3 | 132 | memset(pre_diff_alpha, 0, sizeof(float) * n_sample); |
yukari_hinata | 5:792afbb0bcf3 | 133 | } |
yukari_hinata | 5:792afbb0bcf3 | 134 | |
yukari_hinata | 5:792afbb0bcf3 | 135 | // 勾配値の計算 |
yukari_hinata | 5:792afbb0bcf3 | 136 | for (int i=0; i < n_sample; i++) { |
yukari_hinata | 5:792afbb0bcf3 | 137 | diff_sum = 0; |
yukari_hinata | 5:792afbb0bcf3 | 138 | for (int j=0; j < n_sample; j++) { |
yukari_hinata | 5:792afbb0bcf3 | 139 | // C2を踏まえたカーネル関数値 |
yukari_hinata | 5:792afbb0bcf3 | 140 | kernel_val = kernel_function(&(MATRIX_AT(sample,dim_signal,i,0)), &(MATRIX_AT(sample,dim_signal,j,0)), dim_signal); |
yukari_hinata | 5:792afbb0bcf3 | 141 | // kernel_val = MATRIX_AT(grammat,n_sample,i,j); // via Gram matrix |
yukari_hinata | 5:792afbb0bcf3 | 142 | if (i == j) { |
yukari_hinata | 5:792afbb0bcf3 | 143 | kernel_val += (1/C2); |
yukari_hinata | 5:792afbb0bcf3 | 144 | } |
yukari_hinata | 5:792afbb0bcf3 | 145 | diff_sum += alpha[j] * label[j] * kernel_val; |
yukari_hinata | 5:792afbb0bcf3 | 146 | } |
yukari_hinata | 5:792afbb0bcf3 | 147 | diff_sum *= label[i]; |
yukari_hinata | 5:792afbb0bcf3 | 148 | diff_alpha[i] = 1 - diff_sum; |
yukari_hinata | 5:792afbb0bcf3 | 149 | } |
yukari_hinata | 5:792afbb0bcf3 | 150 | |
yukari_hinata | 5:792afbb0bcf3 | 151 | // 双対変数の更新 |
yukari_hinata | 5:792afbb0bcf3 | 152 | for (int i=0; i < n_sample; i++) { |
yukari_hinata | 5:792afbb0bcf3 | 153 | if ( label[i] == 0 ) { |
yukari_hinata | 5:792afbb0bcf3 | 154 | continue; |
yukari_hinata | 5:792afbb0bcf3 | 155 | } |
yukari_hinata | 5:792afbb0bcf3 | 156 | //printf("alpha[%d] : %f -> ", i, alpha[i]); |
yukari_hinata | 5:792afbb0bcf3 | 157 | alpha[i] = pre_alpha[i] |
yukari_hinata | 5:792afbb0bcf3 | 158 | + eta * diff_alpha[i] |
yukari_hinata | 5:792afbb0bcf3 | 159 | + learn_alpha * pre_diff_alpha[i]; |
yukari_hinata | 5:792afbb0bcf3 | 160 | //printf("%f \dim_signal", alpha[i]); |
yukari_hinata | 0:3f38e74a4a77 | 161 | |
yukari_hinata | 5:792afbb0bcf3 | 162 | // 非数/無限チェック |
yukari_hinata | 5:792afbb0bcf3 | 163 | if ( isnan(alpha[i]) || isinf(alpha[i]) ) { |
yukari_hinata | 5:792afbb0bcf3 | 164 | fprintf(stderr, "Detected NaN or Inf Dual-Coffience : pre_alhpa[%d]=%f -> alpha[%d]=%f", i, pre_alpha[i], i, alpha[i]); |
yukari_hinata | 5:792afbb0bcf3 | 165 | return SVM_DETECT_BAD_VAL; |
yukari_hinata | 5:792afbb0bcf3 | 166 | } |
yukari_hinata | 5:792afbb0bcf3 | 167 | |
yukari_hinata | 5:792afbb0bcf3 | 168 | } |
yukari_hinata | 0:3f38e74a4a77 | 169 | |
yukari_hinata | 5:792afbb0bcf3 | 170 | // 係数の制約条件1:正例と負例の双対係数和を等しくする. |
yukari_hinata | 5:792afbb0bcf3 | 171 | // 手法:標本平均に寄せる |
yukari_hinata | 5:792afbb0bcf3 | 172 | float norm_sum = 0; |
yukari_hinata | 5:792afbb0bcf3 | 173 | for (int i = 0; i < n_sample; i++ ) { |
yukari_hinata | 5:792afbb0bcf3 | 174 | norm_sum += (label[i] * alpha[i]); |
yukari_hinata | 5:792afbb0bcf3 | 175 | } |
yukari_hinata | 5:792afbb0bcf3 | 176 | norm_sum /= n_sample; |
yukari_hinata | 0:3f38e74a4a77 | 177 | |
yukari_hinata | 5:792afbb0bcf3 | 178 | for (int i = 0; i < n_sample; i++ ) { |
yukari_hinata | 5:792afbb0bcf3 | 179 | if ( label[i] == 0 ) { |
yukari_hinata | 5:792afbb0bcf3 | 180 | continue; |
yukari_hinata | 5:792afbb0bcf3 | 181 | } |
yukari_hinata | 5:792afbb0bcf3 | 182 | alpha[i] -= (norm_sum / label[i]); |
yukari_hinata | 5:792afbb0bcf3 | 183 | } |
yukari_hinata | 0:3f38e74a4a77 | 184 | |
yukari_hinata | 5:792afbb0bcf3 | 185 | // 係数の制約条件2:双対係数は非負 |
yukari_hinata | 5:792afbb0bcf3 | 186 | for (int i = 0; i < n_sample; i++ ) { |
yukari_hinata | 5:792afbb0bcf3 | 187 | if ( alpha[i] < 0 ) { |
yukari_hinata | 5:792afbb0bcf3 | 188 | alpha[i] = 0; |
yukari_hinata | 5:792afbb0bcf3 | 189 | } else if ( alpha[i] > C1 ) { |
yukari_hinata | 5:792afbb0bcf3 | 190 | // C1を踏まえると,係数の上限はC1となる. |
yukari_hinata | 5:792afbb0bcf3 | 191 | alpha[i] = C1; |
yukari_hinata | 5:792afbb0bcf3 | 192 | } |
yukari_hinata | 5:792afbb0bcf3 | 193 | } |
yukari_hinata | 5:792afbb0bcf3 | 194 | |
yukari_hinata | 5:792afbb0bcf3 | 195 | // 収束判定 : 凸計画問題なので,収束時は大域最適が |
yukari_hinata | 5:792afbb0bcf3 | 196 | // 保証されている. |
yukari_hinata | 5:792afbb0bcf3 | 197 | if ( (vec_dist(alpha, pre_alpha, n_sample) < epsilon) |
yukari_hinata | 5:792afbb0bcf3 | 198 | || (two_norm(diff_alpha, n_sample) < epsilon) ) { |
yukari_hinata | 5:792afbb0bcf3 | 199 | // 学習の正常完了 |
yukari_hinata | 5:792afbb0bcf3 | 200 | status = SVM_LEARN_SUCCESS; |
yukari_hinata | 5:792afbb0bcf3 | 201 | break; |
yukari_hinata | 5:792afbb0bcf3 | 202 | } |
yukari_hinata | 5:792afbb0bcf3 | 203 | |
yukari_hinata | 5:792afbb0bcf3 | 204 | // 学習繰り返し回数のインクリメント |
yukari_hinata | 5:792afbb0bcf3 | 205 | iteration++; |
yukari_hinata | 0:3f38e74a4a77 | 206 | } |
yukari_hinata | 0:3f38e74a4a77 | 207 | |
yukari_hinata | 5:792afbb0bcf3 | 208 | if (iteration >= maxIteration) { |
yukari_hinata | 5:792afbb0bcf3 | 209 | fprintf(stderr, "Learning is not convergenced. (iteration count > maxIteration) \r\n"); |
yukari_hinata | 5:792afbb0bcf3 | 210 | status = SVM_NOT_CONVERGENCED; |
yukari_hinata | 5:792afbb0bcf3 | 211 | } else if ( status != SVM_LEARN_SUCCESS ) { |
yukari_hinata | 5:792afbb0bcf3 | 212 | status = SVM_NOT_LEARN; |
yukari_hinata | 0:3f38e74a4a77 | 213 | } |
yukari_hinata | 0:3f38e74a4a77 | 214 | |
yukari_hinata | 5:792afbb0bcf3 | 215 | // 領域開放 |
yukari_hinata | 5:792afbb0bcf3 | 216 | delete [] diff_alpha; |
yukari_hinata | 5:792afbb0bcf3 | 217 | delete [] pre_diff_alpha; |
yukari_hinata | 5:792afbb0bcf3 | 218 | delete [] pre_alpha; |
yukari_hinata | 0:3f38e74a4a77 | 219 | |
yukari_hinata | 5:792afbb0bcf3 | 220 | return status; |
yukari_hinata | 0:3f38e74a4a77 | 221 | |
yukari_hinata | 0:3f38e74a4a77 | 222 | } |
yukari_hinata | 0:3f38e74a4a77 | 223 | |
yukari_hinata | 0:3f38e74a4a77 | 224 | // 未知データのネットワーク値を計算 |
yukari_hinata | 0:3f38e74a4a77 | 225 | float SVM::predict_net(float* data) |
yukari_hinata | 0:3f38e74a4a77 | 226 | { |
yukari_hinata | 5:792afbb0bcf3 | 227 | // 学習の終了を確認 |
yukari_hinata | 5:792afbb0bcf3 | 228 | if (status != SVM_LEARN_SUCCESS && status != SVM_SET_ALPHA) { |
yukari_hinata | 5:792afbb0bcf3 | 229 | fprintf(stderr, "Learning is not completed yet."); |
yukari_hinata | 5:792afbb0bcf3 | 230 | //exit(1); |
yukari_hinata | 5:792afbb0bcf3 | 231 | return SVM_NOT_LEARN; |
yukari_hinata | 5:792afbb0bcf3 | 232 | } |
yukari_hinata | 0:3f38e74a4a77 | 233 | |
yukari_hinata | 5:792afbb0bcf3 | 234 | float* norm_data = new float[dim_signal]; |
yukari_hinata | 5:792afbb0bcf3 | 235 | |
yukari_hinata | 5:792afbb0bcf3 | 236 | // 信号の正規化 |
yukari_hinata | 5:792afbb0bcf3 | 237 | for (int i = 0; i < dim_signal; i++) { |
yukari_hinata | 5:792afbb0bcf3 | 238 | norm_data[i] = ( data[i] - sample_min[i] ) / ( sample_max[i] - sample_min[i] ); |
yukari_hinata | 5:792afbb0bcf3 | 239 | } |
yukari_hinata | 0:3f38e74a4a77 | 240 | |
yukari_hinata | 5:792afbb0bcf3 | 241 | // ネットワーク値の計算 |
yukari_hinata | 5:792afbb0bcf3 | 242 | float net = 0; |
yukari_hinata | 5:792afbb0bcf3 | 243 | for (int l=0; l < n_sample; l++) { |
yukari_hinata | 5:792afbb0bcf3 | 244 | // **係数が正に相当するサンプルはサポートベクトル** |
yukari_hinata | 5:792afbb0bcf3 | 245 | if(alpha[l] > 0) { |
yukari_hinata | 5:792afbb0bcf3 | 246 | net += label[l] * alpha[l] |
yukari_hinata | 5:792afbb0bcf3 | 247 | * kernel_function(&(MATRIX_AT(sample,dim_signal,l,0)), norm_data, dim_signal); |
yukari_hinata | 5:792afbb0bcf3 | 248 | } |
yukari_hinata | 5:792afbb0bcf3 | 249 | } |
yukari_hinata | 5:792afbb0bcf3 | 250 | |
yukari_hinata | 5:792afbb0bcf3 | 251 | delete [] norm_data; |
yukari_hinata | 0:3f38e74a4a77 | 252 | |
yukari_hinata | 5:792afbb0bcf3 | 253 | return net; |
yukari_hinata | 0:3f38e74a4a77 | 254 | |
yukari_hinata | 0:3f38e74a4a77 | 255 | } |
yukari_hinata | 0:3f38e74a4a77 | 256 | |
yukari_hinata | 0:3f38e74a4a77 | 257 | // 未知データの識別確率を計算 |
yukari_hinata | 0:3f38e74a4a77 | 258 | float SVM::predict_probability(float* data) |
yukari_hinata | 0:3f38e74a4a77 | 259 | { |
yukari_hinata | 0:3f38e74a4a77 | 260 | float net, probability; |
yukari_hinata | 0:3f38e74a4a77 | 261 | float* optimal_w = new float[dim_signal]; // 最適時の係数(not 双対係数) |
yukari_hinata | 0:3f38e74a4a77 | 262 | float sigmoid_param; // シグモイド関数の温度パラメタ |
yukari_hinata | 0:3f38e74a4a77 | 263 | float norm_w; // 係数の2乗ノルム |
yukari_hinata | 5:792afbb0bcf3 | 264 | |
yukari_hinata | 0:3f38e74a4a77 | 265 | net = SVM::predict_net(data); |
yukari_hinata | 5:792afbb0bcf3 | 266 | |
yukari_hinata | 0:3f38e74a4a77 | 267 | // 最適時の係数を計算 |
yukari_hinata | 0:3f38e74a4a77 | 268 | for (int n = 0; n < dim_signal; n++ ) { |
yukari_hinata | 0:3f38e74a4a77 | 269 | optimal_w[n] = 0; |
yukari_hinata | 0:3f38e74a4a77 | 270 | for (int l = 0; l < n_sample; l++ ) { |
yukari_hinata | 0:3f38e74a4a77 | 271 | optimal_w[n] += alpha[l] * label[l] * MATRIX_AT(sample, dim_signal, l, n); |
yukari_hinata | 0:3f38e74a4a77 | 272 | } |
yukari_hinata | 0:3f38e74a4a77 | 273 | } |
yukari_hinata | 0:3f38e74a4a77 | 274 | norm_w = two_norm(optimal_w, dim_signal); |
yukari_hinata | 0:3f38e74a4a77 | 275 | sigmoid_param = 1 / ( norm_w * logf( (1 - epsilon) / epsilon ) ); |
yukari_hinata | 5:792afbb0bcf3 | 276 | |
yukari_hinata | 0:3f38e74a4a77 | 277 | probability = sigmoid_func(net/sigmoid_param); |
yukari_hinata | 5:792afbb0bcf3 | 278 | |
yukari_hinata | 0:3f38e74a4a77 | 279 | // 打ち切り:誤差epsilon以内ならば, 1 or 0に打ち切る. |
yukari_hinata | 0:3f38e74a4a77 | 280 | if ( probability > (1 - epsilon) ) { |
yukari_hinata | 0:3f38e74a4a77 | 281 | return float(1); |
yukari_hinata | 0:3f38e74a4a77 | 282 | } else if ( probability < epsilon ) { |
yukari_hinata | 0:3f38e74a4a77 | 283 | return float(0); |
yukari_hinata | 0:3f38e74a4a77 | 284 | } |
yukari_hinata | 0:3f38e74a4a77 | 285 | |
yukari_hinata | 5:792afbb0bcf3 | 286 | delete [] optimal_w; |
yukari_hinata | 5:792afbb0bcf3 | 287 | |
yukari_hinata | 0:3f38e74a4a77 | 288 | return probability; |
yukari_hinata | 5:792afbb0bcf3 | 289 | |
yukari_hinata | 0:3f38e74a4a77 | 290 | } |
yukari_hinata | 0:3f38e74a4a77 | 291 | |
yukari_hinata | 0:3f38e74a4a77 | 292 | // 未知データの識別 |
yukari_hinata | 0:3f38e74a4a77 | 293 | int SVM::predict_label(float* data) |
yukari_hinata | 0:3f38e74a4a77 | 294 | { |
yukari_hinata | 5:792afbb0bcf3 | 295 | return (predict_net(data) >= 0) ? 1 : (-1); |
yukari_hinata | 0:3f38e74a4a77 | 296 | } |
yukari_hinata | 0:3f38e74a4a77 | 297 | |
yukari_hinata | 0:3f38e74a4a77 | 298 | // 双対係数のゲッター |
yukari_hinata | 5:792afbb0bcf3 | 299 | float* SVM::get_alpha(void) |
yukari_hinata | 5:792afbb0bcf3 | 300 | { |
yukari_hinata | 0:3f38e74a4a77 | 301 | return (float *)alpha; |
yukari_hinata | 0:3f38e74a4a77 | 302 | } |
yukari_hinata | 0:3f38e74a4a77 | 303 | |
yukari_hinata | 0:3f38e74a4a77 | 304 | // 双対係数のセッター |
yukari_hinata | 5:792afbb0bcf3 | 305 | void SVM::set_alpha(float* alpha_data, int nsample) |
yukari_hinata | 5:792afbb0bcf3 | 306 | { |
yukari_hinata | 0:3f38e74a4a77 | 307 | if ( nsample != n_sample ) { |
yukari_hinata | 0:3f38e74a4a77 | 308 | fprintf( stderr, " set_alpha : number of sample isn't match with arg. n_samle= %d, arg= %d \r\n", n_sample, nsample); |
yukari_hinata | 0:3f38e74a4a77 | 309 | return; |
yukari_hinata | 0:3f38e74a4a77 | 310 | } |
yukari_hinata | 0:3f38e74a4a77 | 311 | memcpy(alpha, alpha_data, sizeof(float) * nsample); |
yukari_hinata | 0:3f38e74a4a77 | 312 | status = SVM_SET_ALPHA; |
yukari_hinata | 0:3f38e74a4a77 | 313 | } |