Easy Support Vector Machine

Dependents:   WeatherPredictor

Committer:
yukari_hinata
Date:
Wed Feb 18 15:01:12 2015 +0000
Revision:
5:792afbb0bcf3
Parent:
4:01a20b89db32

        

Who changed what in which revision?

User | Revision | Line number | New contents of line
yukari_hinata 0:3f38e74a4a77 1 #include "SVM.hpp"
yukari_hinata 0:3f38e74a4a77 2
yukari_hinata 0:3f38e74a4a77 3 SVM::SVM(int dim_sample, int n_sample, float* sample_data, int* sample_label)
yukari_hinata 0:3f38e74a4a77 4 {
yukari_hinata 5:792afbb0bcf3 5 this->dim_signal = dim_sample;
yukari_hinata 5:792afbb0bcf3 6 this->n_sample = n_sample;
yukari_hinata 0:3f38e74a4a77 7
yukari_hinata 5:792afbb0bcf3 8 // 各配列(ベクトル)のアロケート
yukari_hinata 5:792afbb0bcf3 9 alpha = new float[n_sample];
yukari_hinata 5:792afbb0bcf3 10 // grammat = new float[n_sample * n_sample];
yukari_hinata 5:792afbb0bcf3 11 label = new int[n_sample];
yukari_hinata 5:792afbb0bcf3 12 sample = new float[dim_signal * n_sample];
yukari_hinata 5:792afbb0bcf3 13 sample_max = new float[dim_signal];
yukari_hinata 5:792afbb0bcf3 14 sample_min = new float[dim_signal];
yukari_hinata 5:792afbb0bcf3 15
yukari_hinata 5:792afbb0bcf3 16 // サンプルのコピー
yukari_hinata 5:792afbb0bcf3 17 memcpy(this->sample, sample_data,
yukari_hinata 5:792afbb0bcf3 18 sizeof(float) * dim_signal * n_sample);
yukari_hinata 5:792afbb0bcf3 19 memcpy(this->label, sample_label,
yukari_hinata 5:792afbb0bcf3 20 sizeof(int) * n_sample);
yukari_hinata 0:3f38e74a4a77 21
yukari_hinata 5:792afbb0bcf3 22 // 正規化のための最大最小値
yukari_hinata 5:792afbb0bcf3 23 memset(sample_max, 0, sizeof(float) * dim_sample);
yukari_hinata 5:792afbb0bcf3 24 memset(sample_min, 0, sizeof(float) * dim_sample);
yukari_hinata 5:792afbb0bcf3 25 for (int i = 0; i < dim_signal; i++) {
yukari_hinata 5:792afbb0bcf3 26 float value;
yukari_hinata 5:792afbb0bcf3 27 sample_min[i] = FLT_MAX;
yukari_hinata 5:792afbb0bcf3 28 for (int j = 0; j < n_sample; j++) {
yukari_hinata 5:792afbb0bcf3 29 value = MATRIX_AT(this->sample, dim_signal, j, i);
yukari_hinata 5:792afbb0bcf3 30 if ( value > sample_max[i] ) {
yukari_hinata 5:792afbb0bcf3 31 sample_max[i] = value;
yukari_hinata 5:792afbb0bcf3 32 } else if ( value < sample_min[i] ) {
yukari_hinata 5:792afbb0bcf3 33 //printf("min[%d] : %f -> ", i, sample_min[i]);
yukari_hinata 5:792afbb0bcf3 34 sample_min[i] = value;
yukari_hinata 5:792afbb0bcf3 35 //printf("min[%d] : %f \dim_signal", i, value);
yukari_hinata 5:792afbb0bcf3 36 }
yukari_hinata 5:792afbb0bcf3 37 }
yukari_hinata 5:792afbb0bcf3 38 }
yukari_hinata 0:3f38e74a4a77 39
yukari_hinata 5:792afbb0bcf3 40 // 信号の正規化 : 死ぬほど大事
yukari_hinata 5:792afbb0bcf3 41 for (int i = 0; i < dim_signal; i++) {
yukari_hinata 5:792afbb0bcf3 42 float max,min;
yukari_hinata 5:792afbb0bcf3 43 max = sample_max[i];
yukari_hinata 5:792afbb0bcf3 44 min = sample_min[i];
yukari_hinata 5:792afbb0bcf3 45 for (int j = 0; j < n_sample; j++) {
yukari_hinata 5:792afbb0bcf3 46 // printf("[%d,%d] %f -> ", i, j, MATRIX_AT(this->sample, dim_signal, j, i));
yukari_hinata 5:792afbb0bcf3 47 MATRIX_AT(this->sample, dim_signal, j, i) = ( MATRIX_AT(this->sample, dim_signal, j, i) - min ) / (max - min);
yukari_hinata 5:792afbb0bcf3 48 // printf("%f \r\n", MATRIX_AT(this->sample, dim_signal, j, i));
yukari_hinata 5:792afbb0bcf3 49 }
yukari_hinata 5:792afbb0bcf3 50 }
yukari_hinata 5:792afbb0bcf3 51
yukari_hinata 5:792afbb0bcf3 52 /* // グラム行列の計算 : メモリの制約上,廃止
yukari_hinata 5:792afbb0bcf3 53 for (int i = 0; i < n_sample; i++) {
yukari_hinata 5:792afbb0bcf3 54 for (int j = i; j < n_sample; j++) {
yukari_hinata 5:792afbb0bcf3 55 MATRIX_AT(grammat,n_sample,i,j) = kernel_function(&(MATRIX_AT(this->sample,dim_signal,i,0)), &(MATRIX_AT(this->sample,dim_signal,j,0)), dim_signal);
yukari_hinata 5:792afbb0bcf3 56 // グラム行列は対称
yukari_hinata 5:792afbb0bcf3 57 if ( i != j ) {
yukari_hinata 5:792afbb0bcf3 58 MATRIX_AT(grammat,n_sample,j,i) = MATRIX_AT(grammat,n_sample,i,j);
yukari_hinata 5:792afbb0bcf3 59 }
yukari_hinata 0:3f38e74a4a77 60 }
yukari_hinata 0:3f38e74a4a77 61 }
yukari_hinata 5:792afbb0bcf3 62 */
yukari_hinata 0:3f38e74a4a77 63
yukari_hinata 5:792afbb0bcf3 64 // 学習関連の設定. 例によって経験則
yukari_hinata 5:792afbb0bcf3 65 this->maxIteration = 5000;
yukari_hinata 5:792afbb0bcf3 66 this->epsilon = float(0.00001);
yukari_hinata 5:792afbb0bcf3 67 this->eta = float(0.05);
yukari_hinata 5:792afbb0bcf3 68 this->learn_alpha = float(0.8) * this->eta;
yukari_hinata 5:792afbb0bcf3 69 this->status = SVM_NOT_LEARN;
yukari_hinata 0:3f38e74a4a77 70
yukari_hinata 5:792afbb0bcf3 71 // ソフトマージンの係数. 両方ともFLT_MAXとすることでハードマージンと一致.
yukari_hinata 5:792afbb0bcf3 72 // また, 設定するときはどちらか一方のみにすること.
yukari_hinata 5:792afbb0bcf3 73 C1 = FLT_MAX;
yukari_hinata 5:792afbb0bcf3 74 C2 = 5;
yukari_hinata 0:3f38e74a4a77 75
yukari_hinata 5:792afbb0bcf3 76 srand((unsigned int)time(NULL));
yukari_hinata 0:3f38e74a4a77 77 }
yukari_hinata 0:3f38e74a4a77 78
yukari_hinata 0:3f38e74a4a77 79 // 楽園追放
yukari_hinata 5:792afbb0bcf3 80 SVM::~SVM(void)
yukari_hinata 0:3f38e74a4a77 81 {
yukari_hinata 5:792afbb0bcf3 82 delete [] alpha;
yukari_hinata 5:792afbb0bcf3 83 delete [] label;
yukari_hinata 0:3f38e74a4a77 84 delete [] sample;
yukari_hinata 5:792afbb0bcf3 85 delete [] sample_max;
yukari_hinata 5:792afbb0bcf3 86 delete [] sample_min;
yukari_hinata 0:3f38e74a4a77 87 }
yukari_hinata 0:3f38e74a4a77 88
yukari_hinata 0:3f38e74a4a77 89 // 再急勾配法(サーセンwww)による学習
yukari_hinata 0:3f38e74a4a77 90 int SVM::learning(void)
yukari_hinata 0:3f38e74a4a77 91 {
yukari_hinata 0:3f38e74a4a77 92
yukari_hinata 5:792afbb0bcf3 93 int iteration; // 学習繰り返しカウント
yukari_hinata 0:3f38e74a4a77 94
yukari_hinata 5:792afbb0bcf3 95 float* diff_alpha; // 双対問題の勾配値
yukari_hinata 5:792afbb0bcf3 96 float* pre_diff_alpha; // 双対問題の前回の勾配値(慣性項に用いる)
yukari_hinata 5:792afbb0bcf3 97 float* pre_alpha; // 前回の双対係数ベクトル(収束判定に用いる)
yukari_hinata 5:792afbb0bcf3 98 register float diff_sum; // 勾配計算用の小計
yukari_hinata 5:792afbb0bcf3 99 register float kernel_val; // カーネル関数とC2を含めた項
yukari_hinata 0:3f38e74a4a77 100
yukari_hinata 5:792afbb0bcf3 101 //float plus_sum, minus_sum; // 正例と負例の係数和
yukari_hinata 0:3f38e74a4a77 102
yukari_hinata 5:792afbb0bcf3 103 // 配列(ベクトル)のアロケート
yukari_hinata 5:792afbb0bcf3 104 diff_alpha = new float[n_sample];
yukari_hinata 5:792afbb0bcf3 105 pre_diff_alpha = new float[n_sample];
yukari_hinata 5:792afbb0bcf3 106 pre_alpha = new float[n_sample];
yukari_hinata 0:3f38e74a4a77 107
yukari_hinata 5:792afbb0bcf3 108 status = SVM_NOT_LEARN; // 学習を未完了に
yukari_hinata 5:792afbb0bcf3 109 iteration = 0; // 繰り返し回数を初期化
yukari_hinata 0:3f38e74a4a77 110
yukari_hinata 5:792afbb0bcf3 111 // 双対係数の初期化.乱択
yukari_hinata 5:792afbb0bcf3 112 for (int i = 0; i < n_sample; i++ ) {
yukari_hinata 5:792afbb0bcf3 113 // 欠損データの係数は0にして使用しない
yukari_hinata 5:792afbb0bcf3 114 if ( label[i] == 0 ) {
yukari_hinata 5:792afbb0bcf3 115 alpha[i] = 0;
yukari_hinata 5:792afbb0bcf3 116 continue;
yukari_hinata 0:3f38e74a4a77 117 }
yukari_hinata 5:792afbb0bcf3 118 alpha[i] = uniform_rand(1.0) + 1.0;
yukari_hinata 0:3f38e74a4a77 119 }
yukari_hinata 0:3f38e74a4a77 120
yukari_hinata 5:792afbb0bcf3 121 // 学習ループ
yukari_hinata 5:792afbb0bcf3 122 while ( iteration < maxIteration ) {
yukari_hinata 5:792afbb0bcf3 123
yukari_hinata 5:792afbb0bcf3 124 printf("ite: %d diff_norm : %f alpha_dist : %f \r\n", iteration, two_norm(diff_alpha, n_sample), vec_dist(alpha, pre_alpha, n_sample));
yukari_hinata 5:792afbb0bcf3 125 // 前回の更新値の記録
yukari_hinata 5:792afbb0bcf3 126 memcpy(pre_alpha, alpha, sizeof(float) * n_sample);
yukari_hinata 5:792afbb0bcf3 127 if ( iteration >= 1 ) {
yukari_hinata 5:792afbb0bcf3 128 memcpy(pre_diff_alpha, diff_alpha, sizeof(float) * n_sample);
yukari_hinata 5:792afbb0bcf3 129 } else {
yukari_hinata 5:792afbb0bcf3 130 // 初回は0埋めで初期化
yukari_hinata 5:792afbb0bcf3 131 memset(diff_alpha, 0, sizeof(float) * n_sample);
yukari_hinata 5:792afbb0bcf3 132 memset(pre_diff_alpha, 0, sizeof(float) * n_sample);
yukari_hinata 5:792afbb0bcf3 133 }
yukari_hinata 5:792afbb0bcf3 134
yukari_hinata 5:792afbb0bcf3 135 // 勾配値の計算
yukari_hinata 5:792afbb0bcf3 136 for (int i=0; i < n_sample; i++) {
yukari_hinata 5:792afbb0bcf3 137 diff_sum = 0;
yukari_hinata 5:792afbb0bcf3 138 for (int j=0; j < n_sample; j++) {
yukari_hinata 5:792afbb0bcf3 139 // C2を踏まえたカーネル関数値
yukari_hinata 5:792afbb0bcf3 140 kernel_val = kernel_function(&(MATRIX_AT(sample,dim_signal,i,0)), &(MATRIX_AT(sample,dim_signal,j,0)), dim_signal);
yukari_hinata 5:792afbb0bcf3 141 // kernel_val = MATRIX_AT(grammat,n_sample,i,j); // via Gram matrix
yukari_hinata 5:792afbb0bcf3 142 if (i == j) {
yukari_hinata 5:792afbb0bcf3 143 kernel_val += (1/C2);
yukari_hinata 5:792afbb0bcf3 144 }
yukari_hinata 5:792afbb0bcf3 145 diff_sum += alpha[j] * label[j] * kernel_val;
yukari_hinata 5:792afbb0bcf3 146 }
yukari_hinata 5:792afbb0bcf3 147 diff_sum *= label[i];
yukari_hinata 5:792afbb0bcf3 148 diff_alpha[i] = 1 - diff_sum;
yukari_hinata 5:792afbb0bcf3 149 }
yukari_hinata 5:792afbb0bcf3 150
yukari_hinata 5:792afbb0bcf3 151 // 双対変数の更新
yukari_hinata 5:792afbb0bcf3 152 for (int i=0; i < n_sample; i++) {
yukari_hinata 5:792afbb0bcf3 153 if ( label[i] == 0 ) {
yukari_hinata 5:792afbb0bcf3 154 continue;
yukari_hinata 5:792afbb0bcf3 155 }
yukari_hinata 5:792afbb0bcf3 156 //printf("alpha[%d] : %f -> ", i, alpha[i]);
yukari_hinata 5:792afbb0bcf3 157 alpha[i] = pre_alpha[i]
yukari_hinata 5:792afbb0bcf3 158 + eta * diff_alpha[i]
yukari_hinata 5:792afbb0bcf3 159 + learn_alpha * pre_diff_alpha[i];
yukari_hinata 5:792afbb0bcf3 160 //printf("%f \dim_signal", alpha[i]);
yukari_hinata 0:3f38e74a4a77 161
yukari_hinata 5:792afbb0bcf3 162 // 非数/無限チェック
yukari_hinata 5:792afbb0bcf3 163 if ( isnan(alpha[i]) || isinf(alpha[i]) ) {
yukari_hinata 5:792afbb0bcf3 164 fprintf(stderr, "Detected NaN or Inf Dual-Coffience : pre_alhpa[%d]=%f -> alpha[%d]=%f", i, pre_alpha[i], i, alpha[i]);
yukari_hinata 5:792afbb0bcf3 165 return SVM_DETECT_BAD_VAL;
yukari_hinata 5:792afbb0bcf3 166 }
yukari_hinata 5:792afbb0bcf3 167
yukari_hinata 5:792afbb0bcf3 168 }
yukari_hinata 0:3f38e74a4a77 169
yukari_hinata 5:792afbb0bcf3 170 // 係数の制約条件1:正例と負例の双対係数和を等しくする.
yukari_hinata 5:792afbb0bcf3 171 // 手法:標本平均に寄せる
yukari_hinata 5:792afbb0bcf3 172 float norm_sum = 0;
yukari_hinata 5:792afbb0bcf3 173 for (int i = 0; i < n_sample; i++ ) {
yukari_hinata 5:792afbb0bcf3 174 norm_sum += (label[i] * alpha[i]);
yukari_hinata 5:792afbb0bcf3 175 }
yukari_hinata 5:792afbb0bcf3 176 norm_sum /= n_sample;
yukari_hinata 0:3f38e74a4a77 177
yukari_hinata 5:792afbb0bcf3 178 for (int i = 0; i < n_sample; i++ ) {
yukari_hinata 5:792afbb0bcf3 179 if ( label[i] == 0 ) {
yukari_hinata 5:792afbb0bcf3 180 continue;
yukari_hinata 5:792afbb0bcf3 181 }
yukari_hinata 5:792afbb0bcf3 182 alpha[i] -= (norm_sum / label[i]);
yukari_hinata 5:792afbb0bcf3 183 }
yukari_hinata 0:3f38e74a4a77 184
yukari_hinata 5:792afbb0bcf3 185 // 係数の制約条件2:双対係数は非負
yukari_hinata 5:792afbb0bcf3 186 for (int i = 0; i < n_sample; i++ ) {
yukari_hinata 5:792afbb0bcf3 187 if ( alpha[i] < 0 ) {
yukari_hinata 5:792afbb0bcf3 188 alpha[i] = 0;
yukari_hinata 5:792afbb0bcf3 189 } else if ( alpha[i] > C1 ) {
yukari_hinata 5:792afbb0bcf3 190 // C1を踏まえると,係数の上限はC1となる.
yukari_hinata 5:792afbb0bcf3 191 alpha[i] = C1;
yukari_hinata 5:792afbb0bcf3 192 }
yukari_hinata 5:792afbb0bcf3 193 }
yukari_hinata 5:792afbb0bcf3 194
yukari_hinata 5:792afbb0bcf3 195 // 収束判定 : 凸計画問題なので,収束時は大域最適が
yukari_hinata 5:792afbb0bcf3 196 // 保証されている.
yukari_hinata 5:792afbb0bcf3 197 if ( (vec_dist(alpha, pre_alpha, n_sample) < epsilon)
yukari_hinata 5:792afbb0bcf3 198 || (two_norm(diff_alpha, n_sample) < epsilon) ) {
yukari_hinata 5:792afbb0bcf3 199 // 学習の正常完了
yukari_hinata 5:792afbb0bcf3 200 status = SVM_LEARN_SUCCESS;
yukari_hinata 5:792afbb0bcf3 201 break;
yukari_hinata 5:792afbb0bcf3 202 }
yukari_hinata 5:792afbb0bcf3 203
yukari_hinata 5:792afbb0bcf3 204 // 学習繰り返し回数のインクリメント
yukari_hinata 5:792afbb0bcf3 205 iteration++;
yukari_hinata 0:3f38e74a4a77 206 }
yukari_hinata 0:3f38e74a4a77 207
yukari_hinata 5:792afbb0bcf3 208 if (iteration >= maxIteration) {
yukari_hinata 5:792afbb0bcf3 209 fprintf(stderr, "Learning is not convergenced. (iteration count > maxIteration) \r\n");
yukari_hinata 5:792afbb0bcf3 210 status = SVM_NOT_CONVERGENCED;
yukari_hinata 5:792afbb0bcf3 211 } else if ( status != SVM_LEARN_SUCCESS ) {
yukari_hinata 5:792afbb0bcf3 212 status = SVM_NOT_LEARN;
yukari_hinata 0:3f38e74a4a77 213 }
yukari_hinata 0:3f38e74a4a77 214
yukari_hinata 5:792afbb0bcf3 215 // 領域開放
yukari_hinata 5:792afbb0bcf3 216 delete [] diff_alpha;
yukari_hinata 5:792afbb0bcf3 217 delete [] pre_diff_alpha;
yukari_hinata 5:792afbb0bcf3 218 delete [] pre_alpha;
yukari_hinata 0:3f38e74a4a77 219
yukari_hinata 5:792afbb0bcf3 220 return status;
yukari_hinata 0:3f38e74a4a77 221
yukari_hinata 0:3f38e74a4a77 222 }
yukari_hinata 0:3f38e74a4a77 223
yukari_hinata 0:3f38e74a4a77 224 // 未知データのネットワーク値を計算
yukari_hinata 0:3f38e74a4a77 225 float SVM::predict_net(float* data)
yukari_hinata 0:3f38e74a4a77 226 {
yukari_hinata 5:792afbb0bcf3 227 // 学習の終了を確認
yukari_hinata 5:792afbb0bcf3 228 if (status != SVM_LEARN_SUCCESS && status != SVM_SET_ALPHA) {
yukari_hinata 5:792afbb0bcf3 229 fprintf(stderr, "Learning is not completed yet.");
yukari_hinata 5:792afbb0bcf3 230 //exit(1);
yukari_hinata 5:792afbb0bcf3 231 return SVM_NOT_LEARN;
yukari_hinata 5:792afbb0bcf3 232 }
yukari_hinata 0:3f38e74a4a77 233
yukari_hinata 5:792afbb0bcf3 234 float* norm_data = new float[dim_signal];
yukari_hinata 5:792afbb0bcf3 235
yukari_hinata 5:792afbb0bcf3 236 // 信号の正規化
yukari_hinata 5:792afbb0bcf3 237 for (int i = 0; i < dim_signal; i++) {
yukari_hinata 5:792afbb0bcf3 238 norm_data[i] = ( data[i] - sample_min[i] ) / ( sample_max[i] - sample_min[i] );
yukari_hinata 5:792afbb0bcf3 239 }
yukari_hinata 0:3f38e74a4a77 240
yukari_hinata 5:792afbb0bcf3 241 // ネットワーク値の計算
yukari_hinata 5:792afbb0bcf3 242 float net = 0;
yukari_hinata 5:792afbb0bcf3 243 for (int l=0; l < n_sample; l++) {
yukari_hinata 5:792afbb0bcf3 244 // **係数が正に相当するサンプルはサポートベクトル**
yukari_hinata 5:792afbb0bcf3 245 if(alpha[l] > 0) {
yukari_hinata 5:792afbb0bcf3 246 net += label[l] * alpha[l]
yukari_hinata 5:792afbb0bcf3 247 * kernel_function(&(MATRIX_AT(sample,dim_signal,l,0)), norm_data, dim_signal);
yukari_hinata 5:792afbb0bcf3 248 }
yukari_hinata 5:792afbb0bcf3 249 }
yukari_hinata 5:792afbb0bcf3 250
yukari_hinata 5:792afbb0bcf3 251 delete [] norm_data;
yukari_hinata 0:3f38e74a4a77 252
yukari_hinata 5:792afbb0bcf3 253 return net;
yukari_hinata 0:3f38e74a4a77 254
yukari_hinata 0:3f38e74a4a77 255 }
yukari_hinata 0:3f38e74a4a77 256
yukari_hinata 0:3f38e74a4a77 257 // 未知データの識別確率を計算
yukari_hinata 0:3f38e74a4a77 258 float SVM::predict_probability(float* data)
yukari_hinata 0:3f38e74a4a77 259 {
yukari_hinata 0:3f38e74a4a77 260 float net, probability;
yukari_hinata 0:3f38e74a4a77 261 float* optimal_w = new float[dim_signal]; // 最適時の係数(not 双対係数)
yukari_hinata 0:3f38e74a4a77 262 float sigmoid_param; // シグモイド関数の温度パラメタ
yukari_hinata 0:3f38e74a4a77 263 float norm_w; // 係数の2乗ノルム
yukari_hinata 5:792afbb0bcf3 264
yukari_hinata 0:3f38e74a4a77 265 net = SVM::predict_net(data);
yukari_hinata 5:792afbb0bcf3 266
yukari_hinata 0:3f38e74a4a77 267 // 最適時の係数を計算
yukari_hinata 0:3f38e74a4a77 268 for (int n = 0; n < dim_signal; n++ ) {
yukari_hinata 0:3f38e74a4a77 269 optimal_w[n] = 0;
yukari_hinata 0:3f38e74a4a77 270 for (int l = 0; l < n_sample; l++ ) {
yukari_hinata 0:3f38e74a4a77 271 optimal_w[n] += alpha[l] * label[l] * MATRIX_AT(sample, dim_signal, l, n);
yukari_hinata 0:3f38e74a4a77 272 }
yukari_hinata 0:3f38e74a4a77 273 }
yukari_hinata 0:3f38e74a4a77 274 norm_w = two_norm(optimal_w, dim_signal);
yukari_hinata 0:3f38e74a4a77 275 sigmoid_param = 1 / ( norm_w * logf( (1 - epsilon) / epsilon ) );
yukari_hinata 5:792afbb0bcf3 276
yukari_hinata 0:3f38e74a4a77 277 probability = sigmoid_func(net/sigmoid_param);
yukari_hinata 5:792afbb0bcf3 278
yukari_hinata 0:3f38e74a4a77 279 // 打ち切り:誤差epsilon以内ならば, 1 or 0に打ち切る.
yukari_hinata 0:3f38e74a4a77 280 if ( probability > (1 - epsilon) ) {
yukari_hinata 0:3f38e74a4a77 281 return float(1);
yukari_hinata 0:3f38e74a4a77 282 } else if ( probability < epsilon ) {
yukari_hinata 0:3f38e74a4a77 283 return float(0);
yukari_hinata 0:3f38e74a4a77 284 }
yukari_hinata 0:3f38e74a4a77 285
yukari_hinata 5:792afbb0bcf3 286 delete [] optimal_w;
yukari_hinata 5:792afbb0bcf3 287
yukari_hinata 0:3f38e74a4a77 288 return probability;
yukari_hinata 5:792afbb0bcf3 289
yukari_hinata 0:3f38e74a4a77 290 }
yukari_hinata 0:3f38e74a4a77 291
yukari_hinata 0:3f38e74a4a77 292 // 未知データの識別
yukari_hinata 0:3f38e74a4a77 293 int SVM::predict_label(float* data)
yukari_hinata 0:3f38e74a4a77 294 {
yukari_hinata 5:792afbb0bcf3 295 return (predict_net(data) >= 0) ? 1 : (-1);
yukari_hinata 0:3f38e74a4a77 296 }
yukari_hinata 0:3f38e74a4a77 297
yukari_hinata 0:3f38e74a4a77 298 // 双対係数のゲッター
yukari_hinata 5:792afbb0bcf3 299 float* SVM::get_alpha(void)
yukari_hinata 5:792afbb0bcf3 300 {
yukari_hinata 0:3f38e74a4a77 301 return (float *)alpha;
yukari_hinata 0:3f38e74a4a77 302 }
yukari_hinata 0:3f38e74a4a77 303
yukari_hinata 0:3f38e74a4a77 304 // 双対係数のセッター
yukari_hinata 5:792afbb0bcf3 305 void SVM::set_alpha(float* alpha_data, int nsample)
yukari_hinata 5:792afbb0bcf3 306 {
yukari_hinata 0:3f38e74a4a77 307 if ( nsample != n_sample ) {
yukari_hinata 0:3f38e74a4a77 308 fprintf( stderr, " set_alpha : number of sample isn't match with arg. n_samle= %d, arg= %d \r\n", n_sample, nsample);
yukari_hinata 0:3f38e74a4a77 309 return;
yukari_hinata 0:3f38e74a4a77 310 }
yukari_hinata 0:3f38e74a4a77 311 memcpy(alpha, alpha_data, sizeof(float) * nsample);
yukari_hinata 0:3f38e74a4a77 312 status = SVM_SET_ALPHA;
yukari_hinata 0:3f38e74a4a77 313 }