Easy Support Vector Machine

Dependents:   WeatherPredictor

Committer:
yukari_hinata
Date:
Sun Feb 15 09:27:08 2015 +0000
Revision:
2:c4a5251cee32
Parent:
1:1a0d5152d50b
Child:
4:01a20b89db32
modified

Who changed what in which revision?

User | Revision | Line number | New contents of line
yukari_hinata 0:3f38e74a4a77 1 #include "MCSVM.hpp"
yukari_hinata 0:3f38e74a4a77 2
yukari_hinata 0:3f38e74a4a77 3 // コンストラクタ. 適宜追加予定
yukari_hinata 0:3f38e74a4a77 4 MCSVM::MCSVM(int class_n,
yukari_hinata 0:3f38e74a4a77 5 int sample_dim,
yukari_hinata 0:3f38e74a4a77 6 int sample_n,
yukari_hinata 0:3f38e74a4a77 7 float* sample_data,
yukari_hinata 0:3f38e74a4a77 8 int* sample_mclabel)
yukari_hinata 0:3f38e74a4a77 9 : SVM(sample_dim, sample_n, sample_data, sample_mclabel)
yukari_hinata 0:3f38e74a4a77 10 {
yukari_hinata 0:3f38e74a4a77 11 this->n_class = class_n;
yukari_hinata 0:3f38e74a4a77 12 this->maxFailcount = 5;
yukari_hinata 0:3f38e74a4a77 13
yukari_hinata 0:3f38e74a4a77 14 int n_kC2 = (n_class) * (n_class - 1) / 2;
yukari_hinata 0:3f38e74a4a77 15
yukari_hinata 0:3f38e74a4a77 16 mc_alpha = new float[n_sample * n_kC2];
yukari_hinata 0:3f38e74a4a77 17 mc_label = new int[n_sample * n_kC2];
yukari_hinata 0:3f38e74a4a77 18
yukari_hinata 0:3f38e74a4a77 19 // ラベルの生成
yukari_hinata 0:3f38e74a4a77 20 int tmp_lab;
yukari_hinata 0:3f38e74a4a77 21 for (int ci = 0; ci < n_class; ci++) {
yukari_hinata 0:3f38e74a4a77 22 for (int cj = ci + 1; cj < n_class; cj++) {
yukari_hinata 0:3f38e74a4a77 23 // i < jであることから, ラベルiは負例, ラベルjは正例に割り当てる.
yukari_hinata 0:3f38e74a4a77 24 // いずれのラベルにも該当しないデータは欠損とし,ラベル0とする.
yukari_hinata 0:3f38e74a4a77 25 for (int l=0; l < n_sample; l++) {
yukari_hinata 0:3f38e74a4a77 26 if (this->label[l] == ci) {
yukari_hinata 0:3f38e74a4a77 27 tmp_lab = -1;
yukari_hinata 0:3f38e74a4a77 28 } else if (this->label[l] == cj) {
yukari_hinata 0:3f38e74a4a77 29 tmp_lab = 1;
yukari_hinata 0:3f38e74a4a77 30 } else {
yukari_hinata 0:3f38e74a4a77 31 tmp_lab = 0;
yukari_hinata 0:3f38e74a4a77 32 }
yukari_hinata 0:3f38e74a4a77 33 MATRIX_AT(mc_label,n_sample,INX_KSVM_IJ(n_class,ci,cj),l) = tmp_lab;
yukari_hinata 0:3f38e74a4a77 34 //printf("l %d : %d -> %d \n", l, label[l], tmp_lab);
yukari_hinata 0:3f38e74a4a77 35 }
yukari_hinata 0:3f38e74a4a77 36
yukari_hinata 0:3f38e74a4a77 37 }
yukari_hinata 0:3f38e74a4a77 38 }
yukari_hinata 0:3f38e74a4a77 39
yukari_hinata 0:3f38e74a4a77 40 }
yukari_hinata 0:3f38e74a4a77 41
yukari_hinata 0:3f38e74a4a77 42 // 領域開放
yukari_hinata 0:3f38e74a4a77 43 MCSVM::~MCSVM(void)
yukari_hinata 0:3f38e74a4a77 44 {
yukari_hinata 0:3f38e74a4a77 45 delete [] mc_alpha;
yukari_hinata 0:3f38e74a4a77 46 delete [] mc_label;
yukari_hinata 0:3f38e74a4a77 47 }
yukari_hinata 0:3f38e74a4a77 48
yukari_hinata 0:3f38e74a4a77 49 // 全SVMの学習.
yukari_hinata 0:3f38e74a4a77 50 int MCSVM::learning(void)
yukari_hinata 0:3f38e74a4a77 51 {
yukari_hinata 0:3f38e74a4a77 52 int status, fail_count;
yukari_hinata 0:3f38e74a4a77 53 int* tmp_label = new int[n_sample];
yukari_hinata 0:3f38e74a4a77 54 // 元のラベルを退避
yukari_hinata 0:3f38e74a4a77 55 memcpy(tmp_label,label, sizeof(int) * n_sample);
yukari_hinata 0:3f38e74a4a77 56 for (int ci = 0; ci < n_class; ci++) {
yukari_hinata 0:3f38e74a4a77 57 for (int cj = ci + 1; cj < n_class; cj++) {
yukari_hinata 0:3f38e74a4a77 58 // 2値ラベルを取得する.
yukari_hinata 0:3f38e74a4a77 59 memcpy(label, &(MATRIX_AT(mc_label, n_sample, INX_KSVM_IJ(n_class,ci,cj), 0)), sizeof(int) * n_sample);
yukari_hinata 0:3f38e74a4a77 60 // 学習 - 学習失敗の場合はリトライする.
yukari_hinata 0:3f38e74a4a77 61 fail_count = 0;
yukari_hinata 0:3f38e74a4a77 62 do {
yukari_hinata 0:3f38e74a4a77 63 if (fail_count >= maxFailcount) {
yukari_hinata 0:3f38e74a4a77 64 fprintf(stderr, "Learning failed %d times at %d,%d classifier, give up \r\n", fail_count, ci, cj);
yukari_hinata 0:3f38e74a4a77 65 return MCSVM_NOT_LEARN;
yukari_hinata 0:3f38e74a4a77 66 }
yukari_hinata 0:3f38e74a4a77 67 status = SVM::learning();
yukari_hinata 0:3f38e74a4a77 68
yukari_hinata 0:3f38e74a4a77 69 if ( (status == SVM_NOT_CONVERGENCED)
yukari_hinata 0:3f38e74a4a77 70 || (status == SVM_DETECT_BAD_VAL) ) {
yukari_hinata 0:3f38e74a4a77 71 fail_count++;
yukari_hinata 0:3f38e74a4a77 72 }
yukari_hinata 0:3f38e74a4a77 73 } while (status != SVM_LEARN_SUCCESS);
yukari_hinata 0:3f38e74a4a77 74 // 学習結果の係数とSVラベルの取得
yukari_hinata 0:3f38e74a4a77 75 memcpy(&(MATRIX_AT(mc_alpha, n_sample, INX_KSVM_IJ(n_class,ci,cj), 0)), this->alpha, sizeof(float) * n_sample);
yukari_hinata 0:3f38e74a4a77 76 }
yukari_hinata 0:3f38e74a4a77 77 }
yukari_hinata 0:3f38e74a4a77 78
yukari_hinata 0:3f38e74a4a77 79 // 元のラベルを復帰
yukari_hinata 0:3f38e74a4a77 80 memcpy(this->label, tmp_label, sizeof(int) * n_sample);
yukari_hinata 0:3f38e74a4a77 81 delete [] tmp_label;
yukari_hinata 0:3f38e74a4a77 82 return MCSVM_LEARN_SUCCESS;
yukari_hinata 0:3f38e74a4a77 83
yukari_hinata 0:3f38e74a4a77 84 }
yukari_hinata 0:3f38e74a4a77 85
yukari_hinata 0:3f38e74a4a77 86 // 未知データのラベルを推定する.
yukari_hinata 0:3f38e74a4a77 87 int MCSVM::predict_label(float* data)
yukari_hinata 0:3f38e74a4a77 88 {
yukari_hinata 0:3f38e74a4a77 89 // 単位ステップ関数による決定的識別
yukari_hinata 0:3f38e74a4a77 90 float net;
yukari_hinata 0:3f38e74a4a77 91 int* result_label_count = new int[n_class];
yukari_hinata 2:c4a5251cee32 92
yukari_hinata 2:c4a5251cee32 93 // 元のラベルを退避
yukari_hinata 2:c4a5251cee32 94 int* tmp_label = new int[n_sample];
yukari_hinata 2:c4a5251cee32 95 memcpy(tmp_label,label, sizeof(int) * n_sample);
yukari_hinata 0:3f38e74a4a77 96
yukari_hinata 0:3f38e74a4a77 97 int tmp_ci, tmp_cj;
yukari_hinata 0:3f38e74a4a77 98 memset(result_label_count, 0, sizeof(int) * n_class);
yukari_hinata 0:3f38e74a4a77 99 for (int ci = 0; ci < n_class; ci++) {
yukari_hinata 0:3f38e74a4a77 100 for (int cj = ci + 1; cj < n_class; cj++) {
yukari_hinata 0:3f38e74a4a77 101
yukari_hinata 0:3f38e74a4a77 102 // インデックスをi < jに
yukari_hinata 0:3f38e74a4a77 103 tmp_ci = ci; tmp_cj = cj;
yukari_hinata 0:3f38e74a4a77 104 if ( ci > cj ) {
yukari_hinata 0:3f38e74a4a77 105 tmp_cj = ci; tmp_ci = cj;
yukari_hinata 0:3f38e74a4a77 106 }
yukari_hinata 0:3f38e74a4a77 107 // 係数とラベルを取得し,ci,cjを識別するSVMを構成
yukari_hinata 0:3f38e74a4a77 108 memcpy(this->alpha, &(MATRIX_AT(mc_alpha, n_sample, INX_KSVM_IJ(n_class,tmp_ci,tmp_cj), 0)), sizeof(float) * n_sample);
yukari_hinata 0:3f38e74a4a77 109 memcpy(this->label, &(MATRIX_AT(mc_label, n_sample, INX_KSVM_IJ(n_class,tmp_ci,tmp_cj), 0)), sizeof(float) * n_sample);
yukari_hinata 0:3f38e74a4a77 110
yukari_hinata 0:3f38e74a4a77 111 // 識別:識別されたクラスに投票.
yukari_hinata 0:3f38e74a4a77 112 net = SVM::predict_net(data);
yukari_hinata 0:3f38e74a4a77 113 //printf("ci:%d cj:%d >> net : %f \n", ci, cj, net);
yukari_hinata 0:3f38e74a4a77 114 if ( net < 0 ) {
yukari_hinata 0:3f38e74a4a77 115 result_label_count[ci]++;
yukari_hinata 0:3f38e74a4a77 116 } else if ( net >= 0 ) {
yukari_hinata 0:3f38e74a4a77 117 result_label_count[cj]++;
yukari_hinata 0:3f38e74a4a77 118 }
yukari_hinata 0:3f38e74a4a77 119
yukari_hinata 0:3f38e74a4a77 120 }
yukari_hinata 0:3f38e74a4a77 121 //printf("sum_net[%d] : %f \n", ci, sum_net[ci]);
yukari_hinata 0:3f38e74a4a77 122 }
yukari_hinata 0:3f38e74a4a77 123
yukari_hinata 0:3f38e74a4a77 124 // 判定:最大頻度のクラスに判定する.
yukari_hinata 0:3f38e74a4a77 125 int max,argmax;
yukari_hinata 0:3f38e74a4a77 126 max = 0;
yukari_hinata 0:3f38e74a4a77 127 for (int i = 0; i < n_class; i++) {
yukari_hinata 0:3f38e74a4a77 128 //printf("%d : %d \n", i, result_label_count[i]);
yukari_hinata 0:3f38e74a4a77 129 if ( result_label_count[i] > max ) {
yukari_hinata 0:3f38e74a4a77 130 max = result_label_count[i];
yukari_hinata 0:3f38e74a4a77 131 argmax = i;
yukari_hinata 0:3f38e74a4a77 132 }
yukari_hinata 0:3f38e74a4a77 133 }
yukari_hinata 2:c4a5251cee32 134 // 元のラベルを復帰
yukari_hinata 2:c4a5251cee32 135 memcpy(this->label, tmp_label, sizeof(int) * n_sample);
yukari_hinata 2:c4a5251cee32 136 delete [] tmp_label; delete [] result_label_count;
yukari_hinata 2:c4a5251cee32 137
yukari_hinata 0:3f38e74a4a77 138 return argmax;
yukari_hinata 0:3f38e74a4a77 139
yukari_hinata 0:3f38e74a4a77 140 }
yukari_hinata 0:3f38e74a4a77 141
yukari_hinata 0:3f38e74a4a77 142 // 未知データの識別確率を推定する.
yukari_hinata 0:3f38e74a4a77 143 float MCSVM::predict_probability(float* data)
yukari_hinata 0:3f38e74a4a77 144 {
yukari_hinata 0:3f38e74a4a77 145 // シグモイド関数による確率的識別
yukari_hinata 0:3f38e74a4a77 146 float prob;
yukari_hinata 0:3f38e74a4a77 147 float* result_label_prob = new float[n_class];
yukari_hinata 0:3f38e74a4a77 148 int tmp_ci, tmp_cj;
yukari_hinata 2:c4a5251cee32 149
yukari_hinata 1:1a0d5152d50b 150 memset(result_label_prob, 0, sizeof(float) * n_class);
yukari_hinata 2:c4a5251cee32 151
yukari_hinata 2:c4a5251cee32 152 // 元のラベルを退避
yukari_hinata 2:c4a5251cee32 153 int* tmp_label = new int[n_sample];
yukari_hinata 2:c4a5251cee32 154 memcpy(tmp_label,label, sizeof(int) * n_sample);
yukari_hinata 2:c4a5251cee32 155
yukari_hinata 0:3f38e74a4a77 156 for (int ci = 0; ci < n_class; ci++) {
yukari_hinata 0:3f38e74a4a77 157 for (int cj = ci + 1; cj < n_class; cj++) {
yukari_hinata 0:3f38e74a4a77 158
yukari_hinata 0:3f38e74a4a77 159 // インデックスをci < cjに : 負例はci, 正例はcj
yukari_hinata 0:3f38e74a4a77 160 tmp_ci = ci; tmp_cj = cj;
yukari_hinata 0:3f38e74a4a77 161 if ( ci > cj ) {
yukari_hinata 0:3f38e74a4a77 162 tmp_cj = ci; tmp_ci = cj;
yukari_hinata 0:3f38e74a4a77 163 }
yukari_hinata 0:3f38e74a4a77 164 // 係数とラベルを取得し,ci,cjを識別するSVMを構成
yukari_hinata 0:3f38e74a4a77 165 memcpy(this->alpha, &(MATRIX_AT(mc_alpha, n_sample, INX_KSVM_IJ(n_class,tmp_ci,tmp_cj), 0)), sizeof(float) * n_sample);
yukari_hinata 0:3f38e74a4a77 166 memcpy(this->label, &(MATRIX_AT(mc_label, n_sample, INX_KSVM_IJ(n_class,tmp_ci,tmp_cj), 0)), sizeof(float) * n_sample);
yukari_hinata 0:3f38e74a4a77 167
yukari_hinata 0:3f38e74a4a77 168 // 確率識別:確率の足し上げ
yukari_hinata 0:3f38e74a4a77 169 prob = SVM::predict_probability(data);
yukari_hinata 0:3f38e74a4a77 170 if ( prob > float(0.5) ) {
yukari_hinata 1:1a0d5152d50b 171 result_label_prob[cj] += prob;
yukari_hinata 0:3f38e74a4a77 172 } else {
yukari_hinata 1:1a0d5152d50b 173 result_label_prob[ci] += (1-prob);
yukari_hinata 0:3f38e74a4a77 174 }
yukari_hinata 0:3f38e74a4a77 175
yukari_hinata 0:3f38e74a4a77 176 }
yukari_hinata 0:3f38e74a4a77 177 //printf("sum_net[%d] : %f \n", ci, sum_net[ci]);
yukari_hinata 0:3f38e74a4a77 178 }
yukari_hinata 0:3f38e74a4a77 179
yukari_hinata 0:3f38e74a4a77 180 // 判定:最大確率和
yukari_hinata 0:3f38e74a4a77 181 // おそらくラベル識別との整合性は取れる...はず
yukari_hinata 0:3f38e74a4a77 182 float max = 0;
yukari_hinata 0:3f38e74a4a77 183 for (int i = 0; i < n_class; i++) {
yukari_hinata 0:3f38e74a4a77 184 //printf("%d : %d \n", i, result_label_count[i]);
yukari_hinata 0:3f38e74a4a77 185 if ( result_label_prob[i] > max ) {
yukari_hinata 0:3f38e74a4a77 186 max = result_label_prob[i];
yukari_hinata 0:3f38e74a4a77 187 }
yukari_hinata 0:3f38e74a4a77 188 }
yukari_hinata 2:c4a5251cee32 189
yukari_hinata 2:c4a5251cee32 190 // 元のラベルを復帰
yukari_hinata 2:c4a5251cee32 191 memcpy(this->label, tmp_label, sizeof(int) * n_sample);
yukari_hinata 2:c4a5251cee32 192 delete [] tmp_label; delete [] result_label_prob;
yukari_hinata 2:c4a5251cee32 193
yukari_hinata 0:3f38e74a4a77 194 // 平均確率を返す.
yukari_hinata 1:1a0d5152d50b 195 return (max / (n_class-1));
yukari_hinata 0:3f38e74a4a77 196
yukari_hinata 0:3f38e74a4a77 197 }
yukari_hinata 0:3f38e74a4a77 198
yukari_hinata 0:3f38e74a4a77 199 // override
yukari_hinata 0:3f38e74a4a77 200 float* MCSVM::get_alpha(void) {
yukari_hinata 0:3f38e74a4a77 201 return (float *)mc_alpha;
yukari_hinata 0:3f38e74a4a77 202 }
yukari_hinata 0:3f38e74a4a77 203
yukari_hinata 0:3f38e74a4a77 204 // override
yukari_hinata 0:3f38e74a4a77 205 void MCSVM::set_alpha(float* mcalpha_data, int nsample, int nclass) {
yukari_hinata 0:3f38e74a4a77 206 if ( nsample != n_sample ) {
yukari_hinata 0:3f38e74a4a77 207 fprintf( stderr, " set_alpha : number of sample isn't match : n_samle= %d, arg= %d \r\n", n_sample, nsample);
yukari_hinata 0:3f38e74a4a77 208 return;
yukari_hinata 0:3f38e74a4a77 209 } else if ( nclass != n_class ) {
yukari_hinata 0:3f38e74a4a77 210 fprintf( stderr, " set_alpha : number of class isn't match : n_class= %d, nclass= %d \r\n", n_class, nclass);
yukari_hinata 0:3f38e74a4a77 211 return;
yukari_hinata 0:3f38e74a4a77 212 }
yukari_hinata 0:3f38e74a4a77 213 int nC2 = n_class * (n_class - 1)/2;
yukari_hinata 2:c4a5251cee32 214 memcpy(mc_alpha, mcalpha_data, sizeof(float) * n_sample * nC2);
yukari_hinata 0:3f38e74a4a77 215 status = SVM_SET_ALPHA;
yukari_hinata 0:3f38e74a4a77 216 }