Easy Support Vector Machine

Dependents:   WeatherPredictor

Committer:
yukari_hinata
Date:
Thu Jan 15 08:22:02 2015 +0000
Revision:
0:3f38e74a4a77
Child:
1:1a0d5152d50b
first commit

Who changed what in which revision?

User | Revision | Line number | New contents of line
yukari_hinata 0:3f38e74a4a77 1 #include "MCSVM.hpp"
yukari_hinata 0:3f38e74a4a77 2
yukari_hinata 0:3f38e74a4a77 3 // コンストラクタ. 適宜追加予定
yukari_hinata 0:3f38e74a4a77 4 MCSVM::MCSVM(int class_n,
yukari_hinata 0:3f38e74a4a77 5 int sample_dim,
yukari_hinata 0:3f38e74a4a77 6 int sample_n,
yukari_hinata 0:3f38e74a4a77 7 float* sample_data,
yukari_hinata 0:3f38e74a4a77 8 int* sample_mclabel)
yukari_hinata 0:3f38e74a4a77 9 : SVM(sample_dim, sample_n, sample_data, sample_mclabel)
yukari_hinata 0:3f38e74a4a77 10 {
yukari_hinata 0:3f38e74a4a77 11 this->n_class = class_n;
yukari_hinata 0:3f38e74a4a77 12 this->maxFailcount = 5;
yukari_hinata 0:3f38e74a4a77 13
yukari_hinata 0:3f38e74a4a77 14 int n_kC2 = (n_class) * (n_class - 1) / 2;
yukari_hinata 0:3f38e74a4a77 15
yukari_hinata 0:3f38e74a4a77 16 mc_alpha = new float[n_sample * n_kC2];
yukari_hinata 0:3f38e74a4a77 17 mc_label = new int[n_sample * n_kC2];
yukari_hinata 0:3f38e74a4a77 18
yukari_hinata 0:3f38e74a4a77 19 // ラベルの生成
yukari_hinata 0:3f38e74a4a77 20 int tmp_lab;
yukari_hinata 0:3f38e74a4a77 21 for (int ci = 0; ci < n_class; ci++) {
yukari_hinata 0:3f38e74a4a77 22 for (int cj = ci + 1; cj < n_class; cj++) {
yukari_hinata 0:3f38e74a4a77 23 // i < jであることから, ラベルiは負例, ラベルjは正例に割り当てる.
yukari_hinata 0:3f38e74a4a77 24 // いずれのラベルにも該当しないデータは欠損とし,ラベル0とする.
yukari_hinata 0:3f38e74a4a77 25 for (int l=0; l < n_sample; l++) {
yukari_hinata 0:3f38e74a4a77 26 if (this->label[l] == ci) {
yukari_hinata 0:3f38e74a4a77 27 tmp_lab = -1;
yukari_hinata 0:3f38e74a4a77 28 } else if (this->label[l] == cj) {
yukari_hinata 0:3f38e74a4a77 29 tmp_lab = 1;
yukari_hinata 0:3f38e74a4a77 30 } else {
yukari_hinata 0:3f38e74a4a77 31 tmp_lab = 0;
yukari_hinata 0:3f38e74a4a77 32 }
yukari_hinata 0:3f38e74a4a77 33 MATRIX_AT(mc_label,n_sample,INX_KSVM_IJ(n_class,ci,cj),l) = tmp_lab;
yukari_hinata 0:3f38e74a4a77 34 //printf("l %d : %d -> %d \n", l, label[l], tmp_lab);
yukari_hinata 0:3f38e74a4a77 35 }
yukari_hinata 0:3f38e74a4a77 36
yukari_hinata 0:3f38e74a4a77 37 }
yukari_hinata 0:3f38e74a4a77 38 }
yukari_hinata 0:3f38e74a4a77 39
yukari_hinata 0:3f38e74a4a77 40 }
yukari_hinata 0:3f38e74a4a77 41
yukari_hinata 0:3f38e74a4a77 42 // 領域開放
yukari_hinata 0:3f38e74a4a77 43 MCSVM::~MCSVM(void)
yukari_hinata 0:3f38e74a4a77 44 {
yukari_hinata 0:3f38e74a4a77 45 delete [] mc_alpha;
yukari_hinata 0:3f38e74a4a77 46 delete [] mc_label;
yukari_hinata 0:3f38e74a4a77 47 }
yukari_hinata 0:3f38e74a4a77 48
yukari_hinata 0:3f38e74a4a77 49 // 全SVMの学習.
yukari_hinata 0:3f38e74a4a77 50 int MCSVM::learning(void)
yukari_hinata 0:3f38e74a4a77 51 {
yukari_hinata 0:3f38e74a4a77 52 int status, fail_count;
yukari_hinata 0:3f38e74a4a77 53 int* tmp_label = new int[n_sample];
yukari_hinata 0:3f38e74a4a77 54 // 元のラベルを退避
yukari_hinata 0:3f38e74a4a77 55 memcpy(tmp_label,label, sizeof(int) * n_sample);
yukari_hinata 0:3f38e74a4a77 56 for (int ci = 0; ci < n_class; ci++) {
yukari_hinata 0:3f38e74a4a77 57 for (int cj = ci + 1; cj < n_class; cj++) {
yukari_hinata 0:3f38e74a4a77 58 // 2値ラベルを取得する.
yukari_hinata 0:3f38e74a4a77 59 memcpy(label, &(MATRIX_AT(mc_label, n_sample, INX_KSVM_IJ(n_class,ci,cj), 0)), sizeof(int) * n_sample);
yukari_hinata 0:3f38e74a4a77 60 // 学習 - 学習失敗の場合はリトライする.
yukari_hinata 0:3f38e74a4a77 61 fail_count = 0;
yukari_hinata 0:3f38e74a4a77 62 do {
yukari_hinata 0:3f38e74a4a77 63 if (fail_count >= maxFailcount) {
yukari_hinata 0:3f38e74a4a77 64 fprintf(stderr, "Learning failed %d times at %d,%d classifier, give up \r\n", fail_count, ci, cj);
yukari_hinata 0:3f38e74a4a77 65 return MCSVM_NOT_LEARN;
yukari_hinata 0:3f38e74a4a77 66 }
yukari_hinata 0:3f38e74a4a77 67 status = SVM::learning();
yukari_hinata 0:3f38e74a4a77 68
yukari_hinata 0:3f38e74a4a77 69 if ( (status == SVM_NOT_CONVERGENCED)
yukari_hinata 0:3f38e74a4a77 70 || (status == SVM_DETECT_BAD_VAL) ) {
yukari_hinata 0:3f38e74a4a77 71 fail_count++;
yukari_hinata 0:3f38e74a4a77 72 }
yukari_hinata 0:3f38e74a4a77 73 } while (status != SVM_LEARN_SUCCESS);
yukari_hinata 0:3f38e74a4a77 74 // 学習結果の係数とSVラベルの取得
yukari_hinata 0:3f38e74a4a77 75 memcpy(&(MATRIX_AT(mc_alpha, n_sample, INX_KSVM_IJ(n_class,ci,cj), 0)), this->alpha, sizeof(float) * n_sample);
yukari_hinata 0:3f38e74a4a77 76 }
yukari_hinata 0:3f38e74a4a77 77 }
yukari_hinata 0:3f38e74a4a77 78
yukari_hinata 0:3f38e74a4a77 79 // 元のラベルを復帰
yukari_hinata 0:3f38e74a4a77 80 memcpy(this->label, tmp_label, sizeof(int) * n_sample);
yukari_hinata 0:3f38e74a4a77 81 delete [] tmp_label;
yukari_hinata 0:3f38e74a4a77 82 return MCSVM_LEARN_SUCCESS;
yukari_hinata 0:3f38e74a4a77 83
yukari_hinata 0:3f38e74a4a77 84 }
yukari_hinata 0:3f38e74a4a77 85
yukari_hinata 0:3f38e74a4a77 86 // 未知データのラベルを推定する.
yukari_hinata 0:3f38e74a4a77 87 int MCSVM::predict_label(float* data)
yukari_hinata 0:3f38e74a4a77 88 {
yukari_hinata 0:3f38e74a4a77 89 // 単位ステップ関数による決定的識別
yukari_hinata 0:3f38e74a4a77 90 float net;
yukari_hinata 0:3f38e74a4a77 91 int* result_label_count = new int[n_class];
yukari_hinata 0:3f38e74a4a77 92
yukari_hinata 0:3f38e74a4a77 93 int tmp_ci, tmp_cj;
yukari_hinata 0:3f38e74a4a77 94 memset(result_label_count, 0, sizeof(int) * n_class);
yukari_hinata 0:3f38e74a4a77 95 for (int ci = 0; ci < n_class; ci++) {
yukari_hinata 0:3f38e74a4a77 96 for (int cj = ci + 1; cj < n_class; cj++) {
yukari_hinata 0:3f38e74a4a77 97
yukari_hinata 0:3f38e74a4a77 98 // インデックスをi < jに
yukari_hinata 0:3f38e74a4a77 99 tmp_ci = ci; tmp_cj = cj;
yukari_hinata 0:3f38e74a4a77 100 if ( ci > cj ) {
yukari_hinata 0:3f38e74a4a77 101 tmp_cj = ci; tmp_ci = cj;
yukari_hinata 0:3f38e74a4a77 102 }
yukari_hinata 0:3f38e74a4a77 103 // 係数とラベルを取得し,ci,cjを識別するSVMを構成
yukari_hinata 0:3f38e74a4a77 104 memcpy(this->alpha, &(MATRIX_AT(mc_alpha, n_sample, INX_KSVM_IJ(n_class,tmp_ci,tmp_cj), 0)), sizeof(float) * n_sample);
yukari_hinata 0:3f38e74a4a77 105 memcpy(this->label, &(MATRIX_AT(mc_label, n_sample, INX_KSVM_IJ(n_class,tmp_ci,tmp_cj), 0)), sizeof(float) * n_sample);
yukari_hinata 0:3f38e74a4a77 106
yukari_hinata 0:3f38e74a4a77 107 // 識別:識別されたクラスに投票.
yukari_hinata 0:3f38e74a4a77 108 net = SVM::predict_net(data);
yukari_hinata 0:3f38e74a4a77 109 //printf("ci:%d cj:%d >> net : %f \n", ci, cj, net);
yukari_hinata 0:3f38e74a4a77 110 if ( net < 0 ) {
yukari_hinata 0:3f38e74a4a77 111 result_label_count[ci]++;
yukari_hinata 0:3f38e74a4a77 112 } else if ( net >= 0 ) {
yukari_hinata 0:3f38e74a4a77 113 result_label_count[cj]++;
yukari_hinata 0:3f38e74a4a77 114 }
yukari_hinata 0:3f38e74a4a77 115
yukari_hinata 0:3f38e74a4a77 116 }
yukari_hinata 0:3f38e74a4a77 117 //printf("sum_net[%d] : %f \n", ci, sum_net[ci]);
yukari_hinata 0:3f38e74a4a77 118 }
yukari_hinata 0:3f38e74a4a77 119
yukari_hinata 0:3f38e74a4a77 120 // 判定:最大頻度のクラスに判定する.
yukari_hinata 0:3f38e74a4a77 121 int max,argmax;
yukari_hinata 0:3f38e74a4a77 122 max = 0;
yukari_hinata 0:3f38e74a4a77 123 for (int i = 0; i < n_class; i++) {
yukari_hinata 0:3f38e74a4a77 124 //printf("%d : %d \n", i, result_label_count[i]);
yukari_hinata 0:3f38e74a4a77 125 if ( result_label_count[i] > max ) {
yukari_hinata 0:3f38e74a4a77 126 max = result_label_count[i];
yukari_hinata 0:3f38e74a4a77 127 argmax = i;
yukari_hinata 0:3f38e74a4a77 128 }
yukari_hinata 0:3f38e74a4a77 129 }
yukari_hinata 0:3f38e74a4a77 130
yukari_hinata 0:3f38e74a4a77 131 delete [] result_label_count;
yukari_hinata 0:3f38e74a4a77 132 return argmax;
yukari_hinata 0:3f38e74a4a77 133
yukari_hinata 0:3f38e74a4a77 134 }
yukari_hinata 0:3f38e74a4a77 135
yukari_hinata 0:3f38e74a4a77 136 // 未知データの識別確率を推定する.
yukari_hinata 0:3f38e74a4a77 137 float MCSVM::predict_probability(float* data)
yukari_hinata 0:3f38e74a4a77 138 {
yukari_hinata 0:3f38e74a4a77 139 // シグモイド関数による確率的識別
yukari_hinata 0:3f38e74a4a77 140 float prob;
yukari_hinata 0:3f38e74a4a77 141 float* result_label_prob = new float[n_class];
yukari_hinata 0:3f38e74a4a77 142 int tmp_ci, tmp_cj;
yukari_hinata 0:3f38e74a4a77 143 memset(result_label_count, 0, sizeof(int) * n_class);
yukari_hinata 0:3f38e74a4a77 144 for (int ci = 0; ci < n_class; ci++) {
yukari_hinata 0:3f38e74a4a77 145 for (int cj = ci + 1; cj < n_class; cj++) {
yukari_hinata 0:3f38e74a4a77 146
yukari_hinata 0:3f38e74a4a77 147 // インデックスをci < cjに : 負例はci, 正例はcj
yukari_hinata 0:3f38e74a4a77 148 tmp_ci = ci; tmp_cj = cj;
yukari_hinata 0:3f38e74a4a77 149 if ( ci > cj ) {
yukari_hinata 0:3f38e74a4a77 150 tmp_cj = ci; tmp_ci = cj;
yukari_hinata 0:3f38e74a4a77 151 }
yukari_hinata 0:3f38e74a4a77 152 // 係数とラベルを取得し,ci,cjを識別するSVMを構成
yukari_hinata 0:3f38e74a4a77 153 memcpy(this->alpha, &(MATRIX_AT(mc_alpha, n_sample, INX_KSVM_IJ(n_class,tmp_ci,tmp_cj), 0)), sizeof(float) * n_sample);
yukari_hinata 0:3f38e74a4a77 154 memcpy(this->label, &(MATRIX_AT(mc_label, n_sample, INX_KSVM_IJ(n_class,tmp_ci,tmp_cj), 0)), sizeof(float) * n_sample);
yukari_hinata 0:3f38e74a4a77 155
yukari_hinata 0:3f38e74a4a77 156 // 確率識別:確率の足し上げ
yukari_hinata 0:3f38e74a4a77 157 prob = SVM::predict_probability(data);
yukari_hinata 0:3f38e74a4a77 158 if ( prob > float(0.5) ) {
yukari_hinata 0:3f38e74a4a77 159 result_label_count[cj] += prob;
yukari_hinata 0:3f38e74a4a77 160 } else {
yukari_hinata 0:3f38e74a4a77 161 result_label_count[ci] += (1-prob);
yukari_hinata 0:3f38e74a4a77 162 }
yukari_hinata 0:3f38e74a4a77 163
yukari_hinata 0:3f38e74a4a77 164 }
yukari_hinata 0:3f38e74a4a77 165 //printf("sum_net[%d] : %f \n", ci, sum_net[ci]);
yukari_hinata 0:3f38e74a4a77 166 }
yukari_hinata 0:3f38e74a4a77 167
yukari_hinata 0:3f38e74a4a77 168 // 判定:最大確率和
yukari_hinata 0:3f38e74a4a77 169 // おそらくラベル識別との整合性は取れる...はず
yukari_hinata 0:3f38e74a4a77 170 float max = 0;
yukari_hinata 0:3f38e74a4a77 171 for (int i = 0; i < n_class; i++) {
yukari_hinata 0:3f38e74a4a77 172 //printf("%d : %d \n", i, result_label_count[i]);
yukari_hinata 0:3f38e74a4a77 173 if ( result_label_prob[i] > max ) {
yukari_hinata 0:3f38e74a4a77 174 max = result_label_prob[i];
yukari_hinata 0:3f38e74a4a77 175 }
yukari_hinata 0:3f38e74a4a77 176 }
yukari_hinata 0:3f38e74a4a77 177
yukari_hinata 0:3f38e74a4a77 178 delete [] result_label_prob;
yukari_hinata 0:3f38e74a4a77 179 // 平均確率を返す.
yukari_hinata 0:3f38e74a4a77 180 return (max / n_class);
yukari_hinata 0:3f38e74a4a77 181
yukari_hinata 0:3f38e74a4a77 182 }
yukari_hinata 0:3f38e74a4a77 183
yukari_hinata 0:3f38e74a4a77 184 // override
yukari_hinata 0:3f38e74a4a77 185 float* MCSVM::get_alpha(void) {
yukari_hinata 0:3f38e74a4a77 186 return (float *)mc_alpha;
yukari_hinata 0:3f38e74a4a77 187 }
yukari_hinata 0:3f38e74a4a77 188
yukari_hinata 0:3f38e74a4a77 189 // override
yukari_hinata 0:3f38e74a4a77 190 void MCSVM::set_alpha(float* mcalpha_data, int nsample, int nclass) {
yukari_hinata 0:3f38e74a4a77 191 if ( nsample != n_sample ) {
yukari_hinata 0:3f38e74a4a77 192 fprintf( stderr, " set_alpha : number of sample isn't match : n_samle= %d, arg= %d \r\n", n_sample, nsample);
yukari_hinata 0:3f38e74a4a77 193 return;
yukari_hinata 0:3f38e74a4a77 194 } else if ( nclass != n_class ) {
yukari_hinata 0:3f38e74a4a77 195 fprintf( stderr, " set_alpha : number of class isn't match : n_class= %d, nclass= %d \r\n", n_class, nclass);
yukari_hinata 0:3f38e74a4a77 196 return;
yukari_hinata 0:3f38e74a4a77 197 }
yukari_hinata 0:3f38e74a4a77 198 int nC2 = n_class * (n_class - 1)/2;
yukari_hinata 0:3f38e74a4a77 199 memcpy(mc_alpha, mcalpha_data, sizeof(float) * nsample * (n_class * (n_class - 1) / 2));
yukari_hinata 0:3f38e74a4a77 200 status = SVM_SET_ALPHA;
yukari_hinata 0:3f38e74a4a77 201 }