Easy Support Vector Machine

Dependents:   WeatherPredictor

Committer:
yukari_hinata
Date:
Thu Jan 15 08:22:02 2015 +0000
Revision:
0:3f38e74a4a77
Child:
1:1a0d5152d50b
first commit

Who changed what in which revision?

User | Revision | Line number | New contents of line
yukari_hinata 0:3f38e74a4a77 1 #include "SVM.hpp"
yukari_hinata 0:3f38e74a4a77 2
yukari_hinata 0:3f38e74a4a77 3 SVM::SVM(int dim_sample, int n_sample, float* sample_data, int* sample_label)
yukari_hinata 0:3f38e74a4a77 4 {
yukari_hinata 0:3f38e74a4a77 5 this->dim_signal = dim_sample;
yukari_hinata 0:3f38e74a4a77 6 this->n_sample = n_sample;
yukari_hinata 0:3f38e74a4a77 7
yukari_hinata 0:3f38e74a4a77 8 // 各配列(ベクトル)のアロケート
yukari_hinata 0:3f38e74a4a77 9 alpha = new float[n_sample];
yukari_hinata 0:3f38e74a4a77 10 // grammat = new float[n_sample * n_sample];
yukari_hinata 0:3f38e74a4a77 11 label = new int[n_sample];
yukari_hinata 0:3f38e74a4a77 12 sample = new float[dim_signal * n_sample];
yukari_hinata 0:3f38e74a4a77 13 sample_max = new float[dim_signal];
yukari_hinata 0:3f38e74a4a77 14 sample_min = new float[dim_signal];
yukari_hinata 0:3f38e74a4a77 15
yukari_hinata 0:3f38e74a4a77 16 // サンプルのコピー
yukari_hinata 0:3f38e74a4a77 17 memcpy(this->sample, sample_data,
yukari_hinata 0:3f38e74a4a77 18 sizeof(float) * dim_signal * n_sample);
yukari_hinata 0:3f38e74a4a77 19 memcpy(this->label, sample_label,
yukari_hinata 0:3f38e74a4a77 20 sizeof(int) * n_sample);
yukari_hinata 0:3f38e74a4a77 21
yukari_hinata 0:3f38e74a4a77 22 // 正規化のための最大最小値
yukari_hinata 0:3f38e74a4a77 23 memset(sample_max, 0, sizeof(float) * dim_sample);
yukari_hinata 0:3f38e74a4a77 24 memset(sample_min, 0, sizeof(float) * dim_sample);
yukari_hinata 0:3f38e74a4a77 25 for (int i = 0; i < dim_signal; i++) {
yukari_hinata 0:3f38e74a4a77 26 float value;
yukari_hinata 0:3f38e74a4a77 27 sample_min[i] = FLT_MAX;
yukari_hinata 0:3f38e74a4a77 28 for (int j = 0; j < n_sample; j++) {
yukari_hinata 0:3f38e74a4a77 29 value = MATRIX_AT(this->sample, dim_signal, j, i);
yukari_hinata 0:3f38e74a4a77 30 if ( value > sample_max[i] ) {
yukari_hinata 0:3f38e74a4a77 31 sample_max[i] = value;
yukari_hinata 0:3f38e74a4a77 32 } else if ( value < sample_min[i] ) {
yukari_hinata 0:3f38e74a4a77 33 //printf("min[%d] : %f -> ", i, sample_min[i]);
yukari_hinata 0:3f38e74a4a77 34 sample_min[i] = value;
yukari_hinata 0:3f38e74a4a77 35 //printf("min[%d] : %f \dim_signal", i, value);
yukari_hinata 0:3f38e74a4a77 36 }
yukari_hinata 0:3f38e74a4a77 37 }
yukari_hinata 0:3f38e74a4a77 38 }
yukari_hinata 0:3f38e74a4a77 39
yukari_hinata 0:3f38e74a4a77 40 // 信号の正規化 : 死ぬほど大事
yukari_hinata 0:3f38e74a4a77 41 for (int i = 0; i < dim_signal; i++) {
yukari_hinata 0:3f38e74a4a77 42 float max,min;
yukari_hinata 0:3f38e74a4a77 43 max = sample_max[i];
yukari_hinata 0:3f38e74a4a77 44 min = sample_min[i];
yukari_hinata 0:3f38e74a4a77 45 for (int j = 0; j < n_sample; j++) {
yukari_hinata 0:3f38e74a4a77 46 //printf("[%d,%d] %f -> ", i, j, MATRIX_AT(this->sample, dim_signal, j, i));
yukari_hinata 0:3f38e74a4a77 47 MATRIX_AT(this->sample, dim_signal, j, i) = ( MATRIX_AT(this->sample, dim_signal, j, i) - min ) / (max - min);
yukari_hinata 0:3f38e74a4a77 48 //printf("%f\dim_signal", MATRIX_AT(this->sample, dim_signal, j, i));
yukari_hinata 0:3f38e74a4a77 49 }
yukari_hinata 0:3f38e74a4a77 50 }
yukari_hinata 0:3f38e74a4a77 51
yukari_hinata 0:3f38e74a4a77 52 /* // グラム行列の計算 : メモリの制約上,廃止
yukari_hinata 0:3f38e74a4a77 53 for (int i = 0; i < n_sample; i++) {
yukari_hinata 0:3f38e74a4a77 54 for (int j = i; j < n_sample; j++) {
yukari_hinata 0:3f38e74a4a77 55 MATRIX_AT(grammat,n_sample,i,j) = kernel_function(&(MATRIX_AT(this->sample,dim_signal,i,0)), &(MATRIX_AT(this->sample,dim_signal,j,0)), dim_signal);
yukari_hinata 0:3f38e74a4a77 56 // グラム行列は対称
yukari_hinata 0:3f38e74a4a77 57 if ( i != j ) {
yukari_hinata 0:3f38e74a4a77 58 MATRIX_AT(grammat,n_sample,j,i) = MATRIX_AT(grammat,n_sample,i,j);
yukari_hinata 0:3f38e74a4a77 59 }
yukari_hinata 0:3f38e74a4a77 60 }
yukari_hinata 0:3f38e74a4a77 61 }
yukari_hinata 0:3f38e74a4a77 62 */
yukari_hinata 0:3f38e74a4a77 63
yukari_hinata 0:3f38e74a4a77 64 // 学習関連の設定. 例によって経験則
yukari_hinata 0:3f38e74a4a77 65 this->maxIteration = 5000;
yukari_hinata 0:3f38e74a4a77 66 this->epsilon = float(0.00001);
yukari_hinata 0:3f38e74a4a77 67 this->eta = float(0.05);
yukari_hinata 0:3f38e74a4a77 68 this->learn_alpha = float(0.8) * this->eta;
yukari_hinata 0:3f38e74a4a77 69 this->status = SVM_NOT_LEARN;
yukari_hinata 0:3f38e74a4a77 70
yukari_hinata 0:3f38e74a4a77 71 // ソフトマージンの係数. 両方ともFLT_MAXとすることでハードマージンと一致.
yukari_hinata 0:3f38e74a4a77 72 // また, 設定するときはどちらか一方のみにすること.
yukari_hinata 0:3f38e74a4a77 73 C1 = FLT_MAX;
yukari_hinata 0:3f38e74a4a77 74 C2 = 10;
yukari_hinata 0:3f38e74a4a77 75
yukari_hinata 0:3f38e74a4a77 76 srand((unsigned int)time(NULL));
yukari_hinata 0:3f38e74a4a77 77
yukari_hinata 0:3f38e74a4a77 78 }
yukari_hinata 0:3f38e74a4a77 79
yukari_hinata 0:3f38e74a4a77 80 // 楽園追放
yukari_hinata 0:3f38e74a4a77 81 SVM::~SVM(void)
yukari_hinata 0:3f38e74a4a77 82 {
yukari_hinata 0:3f38e74a4a77 83 delete [] alpha; delete [] label;
yukari_hinata 0:3f38e74a4a77 84 delete [] sample;
yukari_hinata 0:3f38e74a4a77 85 delete [] sample_max; delete [] sample_min;
yukari_hinata 0:3f38e74a4a77 86 }
yukari_hinata 0:3f38e74a4a77 87
yukari_hinata 0:3f38e74a4a77 88 // 再急勾配法(サーセンwww)による学習
yukari_hinata 0:3f38e74a4a77 89 int SVM::learning(void)
yukari_hinata 0:3f38e74a4a77 90 {
yukari_hinata 0:3f38e74a4a77 91
yukari_hinata 0:3f38e74a4a77 92 int iteration; // 学習繰り返しカウント
yukari_hinata 0:3f38e74a4a77 93
yukari_hinata 0:3f38e74a4a77 94 float* diff_alpha; // 双対問題の勾配値
yukari_hinata 0:3f38e74a4a77 95 float* pre_diff_alpha; // 双対問題の前回の勾配値(慣性項に用いる)
yukari_hinata 0:3f38e74a4a77 96 float* pre_alpha; // 前回の双対係数ベクトル(収束判定に用いる)
yukari_hinata 0:3f38e74a4a77 97 register float diff_sum; // 勾配計算用の小計
yukari_hinata 0:3f38e74a4a77 98 register float kernel_val; // カーネル関数とC2を含めた項
yukari_hinata 0:3f38e74a4a77 99
yukari_hinata 0:3f38e74a4a77 100 //float plus_sum, minus_sum; // 正例と負例の係数和
yukari_hinata 0:3f38e74a4a77 101
yukari_hinata 0:3f38e74a4a77 102 // 配列(ベクトル)のアロケート
yukari_hinata 0:3f38e74a4a77 103 diff_alpha = new float[n_sample];
yukari_hinata 0:3f38e74a4a77 104 pre_diff_alpha = new float[n_sample];
yukari_hinata 0:3f38e74a4a77 105 pre_alpha = new float[n_sample];
yukari_hinata 0:3f38e74a4a77 106
yukari_hinata 0:3f38e74a4a77 107 status = SVM_NOT_LEARN; // 学習を未完了に
yukari_hinata 0:3f38e74a4a77 108 iteration = 0; // 繰り返し回数を初期化
yukari_hinata 0:3f38e74a4a77 109
yukari_hinata 0:3f38e74a4a77 110 // 双対係数の初期化.乱択
yukari_hinata 0:3f38e74a4a77 111 for (int i = 0; i < n_sample; i++ ) {
yukari_hinata 0:3f38e74a4a77 112 // 欠損データの係数は0にして使用しない
yukari_hinata 0:3f38e74a4a77 113 if ( label[i] == 0 ) {
yukari_hinata 0:3f38e74a4a77 114 alpha[i] = 0;
yukari_hinata 0:3f38e74a4a77 115 continue;
yukari_hinata 0:3f38e74a4a77 116 }
yukari_hinata 0:3f38e74a4a77 117 alpha[i] = uniform_rand(1.0) + 1.0;
yukari_hinata 0:3f38e74a4a77 118 }
yukari_hinata 0:3f38e74a4a77 119
yukari_hinata 0:3f38e74a4a77 120 // 学習ループ
yukari_hinata 0:3f38e74a4a77 121 while ( iteration < maxIteration ) {
yukari_hinata 0:3f38e74a4a77 122
yukari_hinata 0:3f38e74a4a77 123 printf("ite: %d diff_norm : %f alpha_dist : %f \r\n", iteration, two_norm(diff_alpha, n_sample), vec_dist(alpha, pre_alpha, n_sample));
yukari_hinata 0:3f38e74a4a77 124 // 前回の更新値の記録
yukari_hinata 0:3f38e74a4a77 125 memcpy(pre_alpha, alpha, sizeof(float) * n_sample);
yukari_hinata 0:3f38e74a4a77 126 if ( iteration >= 1 ) {
yukari_hinata 0:3f38e74a4a77 127 memcpy(pre_diff_alpha, diff_alpha, sizeof(float) * n_sample);
yukari_hinata 0:3f38e74a4a77 128 } else {
yukari_hinata 0:3f38e74a4a77 129 // 初回は0埋めで初期化
yukari_hinata 0:3f38e74a4a77 130 memset(diff_alpha, 0, sizeof(float) * n_sample);
yukari_hinata 0:3f38e74a4a77 131 memset(pre_diff_alpha, 0, sizeof(float) * n_sample);
yukari_hinata 0:3f38e74a4a77 132 }
yukari_hinata 0:3f38e74a4a77 133
yukari_hinata 0:3f38e74a4a77 134 // 勾配値の計算
yukari_hinata 0:3f38e74a4a77 135 for (int i=0; i < n_sample; i++) {
yukari_hinata 0:3f38e74a4a77 136 diff_sum = 0;
yukari_hinata 0:3f38e74a4a77 137 for (int j=0; j < n_sample;j++) {
yukari_hinata 0:3f38e74a4a77 138 // C2を踏まえたカーネル関数値
yukari_hinata 0:3f38e74a4a77 139 kernel_val = kernel_function(&(MATRIX_AT(sample,dim_signal,i,0)), &(MATRIX_AT(sample,dim_signal,j,0)), dim_signal);
yukari_hinata 0:3f38e74a4a77 140 // kernel_val = MATRIX_AT(grammat,n_sample,i,j); // via Gram matrix
yukari_hinata 0:3f38e74a4a77 141 if (i == j) {
yukari_hinata 0:3f38e74a4a77 142 kernel_val += (1/C2);
yukari_hinata 0:3f38e74a4a77 143 }
yukari_hinata 0:3f38e74a4a77 144 diff_sum += alpha[j] * label[j] * kernel_val;
yukari_hinata 0:3f38e74a4a77 145 }
yukari_hinata 0:3f38e74a4a77 146 diff_sum *= label[i];
yukari_hinata 0:3f38e74a4a77 147 diff_alpha[i] = 1 - diff_sum;
yukari_hinata 0:3f38e74a4a77 148 }
yukari_hinata 0:3f38e74a4a77 149
yukari_hinata 0:3f38e74a4a77 150 // 双対変数の更新
yukari_hinata 0:3f38e74a4a77 151 for (int i=0; i < n_sample; i++) {
yukari_hinata 0:3f38e74a4a77 152 if ( label[i] == 0 ) {
yukari_hinata 0:3f38e74a4a77 153 continue;
yukari_hinata 0:3f38e74a4a77 154 }
yukari_hinata 0:3f38e74a4a77 155 //printf("alpha[%d] : %f -> ", i, alpha[i]);
yukari_hinata 0:3f38e74a4a77 156 alpha[i] = pre_alpha[i]
yukari_hinata 0:3f38e74a4a77 157 + eta * diff_alpha[i]
yukari_hinata 0:3f38e74a4a77 158 + learn_alpha * pre_diff_alpha[i];
yukari_hinata 0:3f38e74a4a77 159 //printf("%f \dim_signal", alpha[i]);
yukari_hinata 0:3f38e74a4a77 160
yukari_hinata 0:3f38e74a4a77 161 // 非数/無限チェック
yukari_hinata 0:3f38e74a4a77 162 if ( isnan(alpha[i]) || isinf(alpha[i]) ) {
yukari_hinata 0:3f38e74a4a77 163 fprintf(stderr, "Detected NaN or Inf Dual-Coffience : pre_alhpa[%d]=%f -> alpha[%d]=%f", i, pre_alpha[i], i, alpha[i]);
yukari_hinata 0:3f38e74a4a77 164 return SVM_DETECT_BAD_VAL;
yukari_hinata 0:3f38e74a4a77 165 }
yukari_hinata 0:3f38e74a4a77 166
yukari_hinata 0:3f38e74a4a77 167 }
yukari_hinata 0:3f38e74a4a77 168
yukari_hinata 0:3f38e74a4a77 169 // 係数の制約条件1:正例と負例の双対係数和を等しくする.
yukari_hinata 0:3f38e74a4a77 170 // 手法:標本平均に寄せる
yukari_hinata 0:3f38e74a4a77 171 float norm_sum = 0;
yukari_hinata 0:3f38e74a4a77 172 for (int i = 0; i < n_sample; i++ ) {
yukari_hinata 0:3f38e74a4a77 173 norm_sum += (label[i] * alpha[i]);
yukari_hinata 0:3f38e74a4a77 174 }
yukari_hinata 0:3f38e74a4a77 175 norm_sum /= n_sample;
yukari_hinata 0:3f38e74a4a77 176
yukari_hinata 0:3f38e74a4a77 177 for (int i = 0; i < n_sample; i++ ) {
yukari_hinata 0:3f38e74a4a77 178 if ( label[i] == 0 ) {
yukari_hinata 0:3f38e74a4a77 179 continue;
yukari_hinata 0:3f38e74a4a77 180 }
yukari_hinata 0:3f38e74a4a77 181 alpha[i] -= (norm_sum / label[i]);
yukari_hinata 0:3f38e74a4a77 182 }
yukari_hinata 0:3f38e74a4a77 183
yukari_hinata 0:3f38e74a4a77 184 // 係数の制約条件2:双対係数は非負
yukari_hinata 0:3f38e74a4a77 185 for (int i = 0; i < n_sample; i++ ) {
yukari_hinata 0:3f38e74a4a77 186 if ( alpha[i] < 0 ) {
yukari_hinata 0:3f38e74a4a77 187 alpha[i] = 0;
yukari_hinata 0:3f38e74a4a77 188 } else if ( alpha[i] > C1 ) {
yukari_hinata 0:3f38e74a4a77 189 // C1を踏まえると,係数の上限はC1となる.
yukari_hinata 0:3f38e74a4a77 190 alpha[i] = C1;
yukari_hinata 0:3f38e74a4a77 191 }
yukari_hinata 0:3f38e74a4a77 192 }
yukari_hinata 0:3f38e74a4a77 193
yukari_hinata 0:3f38e74a4a77 194 // 収束判定 : 凸計画問題なので,収束時は大域最適が
yukari_hinata 0:3f38e74a4a77 195 // 保証されている.
yukari_hinata 0:3f38e74a4a77 196 if ( (vec_dist(alpha, pre_alpha, n_sample) < epsilon)
yukari_hinata 0:3f38e74a4a77 197 || (two_norm(diff_alpha, n_sample) < epsilon) ) {
yukari_hinata 0:3f38e74a4a77 198 // 学習の正常完了
yukari_hinata 0:3f38e74a4a77 199 status = SVM_LEARN_SUCCESS;
yukari_hinata 0:3f38e74a4a77 200 break;
yukari_hinata 0:3f38e74a4a77 201 }
yukari_hinata 0:3f38e74a4a77 202
yukari_hinata 0:3f38e74a4a77 203 // 学習繰り返し回数のインクリメント
yukari_hinata 0:3f38e74a4a77 204 iteration++;
yukari_hinata 0:3f38e74a4a77 205 }
yukari_hinata 0:3f38e74a4a77 206
yukari_hinata 0:3f38e74a4a77 207 if (iteration >= maxIteration) {
yukari_hinata 0:3f38e74a4a77 208 fprintf(stderr, "Learning is not convergenced. (iteration count > maxIteration) \n");
yukari_hinata 0:3f38e74a4a77 209 status = SVM_NOT_CONVERGENCED;
yukari_hinata 0:3f38e74a4a77 210 } else if ( status != SVM_LEARN_SUCCESS ) {
yukari_hinata 0:3f38e74a4a77 211 status = SVM_NOT_LEARN;
yukari_hinata 0:3f38e74a4a77 212 }
yukari_hinata 0:3f38e74a4a77 213
yukari_hinata 0:3f38e74a4a77 214 // 領域開放
yukari_hinata 0:3f38e74a4a77 215 delete [] diff_alpha;
yukari_hinata 0:3f38e74a4a77 216 delete [] pre_diff_alpha;
yukari_hinata 0:3f38e74a4a77 217 delete [] pre_alpha;
yukari_hinata 0:3f38e74a4a77 218
yukari_hinata 0:3f38e74a4a77 219 return status;
yukari_hinata 0:3f38e74a4a77 220
yukari_hinata 0:3f38e74a4a77 221 }
yukari_hinata 0:3f38e74a4a77 222
yukari_hinata 0:3f38e74a4a77 223 // 未知データのネットワーク値を計算
yukari_hinata 0:3f38e74a4a77 224 float SVM::predict_net(float* data)
yukari_hinata 0:3f38e74a4a77 225 {
yukari_hinata 0:3f38e74a4a77 226 // 学習の終了を確認
yukari_hinata 0:3f38e74a4a77 227 if (status != SVM_LEARN_SUCCESS || status != SVM_SET_ALPHA) {
yukari_hinata 0:3f38e74a4a77 228 fprintf(stderr, "Learning is not completed yet.");
yukari_hinata 0:3f38e74a4a77 229 //exit(1);
yukari_hinata 0:3f38e74a4a77 230 return SVM_NOT_LEARN;
yukari_hinata 0:3f38e74a4a77 231 }
yukari_hinata 0:3f38e74a4a77 232
yukari_hinata 0:3f38e74a4a77 233 float* norm_data = new float[dim_signal];
yukari_hinata 0:3f38e74a4a77 234
yukari_hinata 0:3f38e74a4a77 235 // 信号の正規化
yukari_hinata 0:3f38e74a4a77 236 for (int i = 0; i < dim_signal; i++) {
yukari_hinata 0:3f38e74a4a77 237 norm_data[i] = ( data[i] - sample_min[i] ) / ( sample_max[i] - sample_min[i] );
yukari_hinata 0:3f38e74a4a77 238 }
yukari_hinata 0:3f38e74a4a77 239
yukari_hinata 0:3f38e74a4a77 240 // ネットワーク値の計算
yukari_hinata 0:3f38e74a4a77 241 float net = 0;
yukari_hinata 0:3f38e74a4a77 242 for (int l=0; l < n_sample; l++) {
yukari_hinata 0:3f38e74a4a77 243 // **係数が正に相当するサンプルはサポートベクトル**
yukari_hinata 0:3f38e74a4a77 244 if(alpha[l] > 0) {
yukari_hinata 0:3f38e74a4a77 245 net += label[l] * alpha[l]
yukari_hinata 0:3f38e74a4a77 246 * kernel_function(&(MATRIX_AT(sample,dim_signal,l,0)), norm_data, dim_signal);
yukari_hinata 0:3f38e74a4a77 247 }
yukari_hinata 0:3f38e74a4a77 248 }
yukari_hinata 0:3f38e74a4a77 249
yukari_hinata 0:3f38e74a4a77 250 return net;
yukari_hinata 0:3f38e74a4a77 251
yukari_hinata 0:3f38e74a4a77 252 }
yukari_hinata 0:3f38e74a4a77 253
yukari_hinata 0:3f38e74a4a77 254 // 未知データの識別確率を計算
yukari_hinata 0:3f38e74a4a77 255 float SVM::predict_probability(float* data)
yukari_hinata 0:3f38e74a4a77 256 {
yukari_hinata 0:3f38e74a4a77 257 float net, probability;
yukari_hinata 0:3f38e74a4a77 258 float* optimal_w = new float[dim_signal]; // 最適時の係数(not 双対係数)
yukari_hinata 0:3f38e74a4a77 259 float sigmoid_param; // シグモイド関数の温度パラメタ
yukari_hinata 0:3f38e74a4a77 260 float norm_w; // 係数の2乗ノルム
yukari_hinata 0:3f38e74a4a77 261
yukari_hinata 0:3f38e74a4a77 262 net = SVM::predict_net(data);
yukari_hinata 0:3f38e74a4a77 263
yukari_hinata 0:3f38e74a4a77 264 // 最適時の係数を計算
yukari_hinata 0:3f38e74a4a77 265 for (int n = 0; n < dim_signal; n++ ) {
yukari_hinata 0:3f38e74a4a77 266 optimal_w[n] = 0;
yukari_hinata 0:3f38e74a4a77 267 for (int l = 0; l < n_sample; l++ ) {
yukari_hinata 0:3f38e74a4a77 268 optimal_w[n] += alpha[l] * label[l] * MATRIX_AT(sample, dim_signal, l, n);
yukari_hinata 0:3f38e74a4a77 269 }
yukari_hinata 0:3f38e74a4a77 270 }
yukari_hinata 0:3f38e74a4a77 271 norm_w = two_norm(optimal_w, dim_signal);
yukari_hinata 0:3f38e74a4a77 272 sigmoid_param = 1 / ( norm_w * logf( (1 - epsilon) / epsilon ) );
yukari_hinata 0:3f38e74a4a77 273
yukari_hinata 0:3f38e74a4a77 274 probability = sigmoid_func(net/sigmoid_param);
yukari_hinata 0:3f38e74a4a77 275
yukari_hinata 0:3f38e74a4a77 276 delete [] optimal_w;
yukari_hinata 0:3f38e74a4a77 277
yukari_hinata 0:3f38e74a4a77 278 // 打ち切り:誤差epsilon以内ならば, 1 or 0に打ち切る.
yukari_hinata 0:3f38e74a4a77 279 if ( probability > (1 - epsilon) ) {
yukari_hinata 0:3f38e74a4a77 280 return float(1);
yukari_hinata 0:3f38e74a4a77 281 } else if ( probability < epsilon ) {
yukari_hinata 0:3f38e74a4a77 282 return float(0);
yukari_hinata 0:3f38e74a4a77 283 }
yukari_hinata 0:3f38e74a4a77 284
yukari_hinata 0:3f38e74a4a77 285 return probability;
yukari_hinata 0:3f38e74a4a77 286
yukari_hinata 0:3f38e74a4a77 287 }
yukari_hinata 0:3f38e74a4a77 288
yukari_hinata 0:3f38e74a4a77 289 // 未知データの識別
yukari_hinata 0:3f38e74a4a77 290 int SVM::predict_label(float* data)
yukari_hinata 0:3f38e74a4a77 291 {
yukari_hinata 0:3f38e74a4a77 292 return (predict_net(data) >= 0) ? 1 : (-1);
yukari_hinata 0:3f38e74a4a77 293 }
yukari_hinata 0:3f38e74a4a77 294
yukari_hinata 0:3f38e74a4a77 295 // 双対係数のゲッター
yukari_hinata 0:3f38e74a4a77 296 float* SVM::get_alpha(void) {
yukari_hinata 0:3f38e74a4a77 297 return (float *)alpha;
yukari_hinata 0:3f38e74a4a77 298 }
yukari_hinata 0:3f38e74a4a77 299
yukari_hinata 0:3f38e74a4a77 300 // 双対係数のセッター
yukari_hinata 0:3f38e74a4a77 301 void SVM::set_alpha(float* alpha_data, int nsample) {
yukari_hinata 0:3f38e74a4a77 302 if ( nsample != n_sample ) {
yukari_hinata 0:3f38e74a4a77 303 fprintf( stderr, " set_alpha : number of sample isn't match with arg. n_samle= %d, arg= %d \r\n", n_sample, nsample);
yukari_hinata 0:3f38e74a4a77 304 return;
yukari_hinata 0:3f38e74a4a77 305 }
yukari_hinata 0:3f38e74a4a77 306 memcpy(alpha, alpha_data, sizeof(float) * nsample);
yukari_hinata 0:3f38e74a4a77 307 status = SVM_SET_ALPHA;
yukari_hinata 0:3f38e74a4a77 308 }