Easy Support Vector Machine

Dependents:   WeatherPredictor

Committer:
yukari_hinata
Date:
Thu Feb 19 19:14:55 2015 +0000
Revision:
6:e7aa8d270f8b
Parent:
5:792afbb0bcf3
(may have a memory leak...)

Who changed what in which revision?

User | Revision | Line number | New contents of line
yukari_hinata 0:3f38e74a4a77 1 #include "MCSVM.hpp"
yukari_hinata 0:3f38e74a4a77 2
yukari_hinata 0:3f38e74a4a77 3 // コンストラクタ. 適宜追加予定
yukari_hinata 0:3f38e74a4a77 4 MCSVM::MCSVM(int class_n,
yukari_hinata 0:3f38e74a4a77 5 int sample_dim,
yukari_hinata 5:792afbb0bcf3 6 int sample_n,
yukari_hinata 0:3f38e74a4a77 7 float* sample_data,
yukari_hinata 0:3f38e74a4a77 8 int* sample_mclabel)
yukari_hinata 5:792afbb0bcf3 9 : SVM(sample_dim, sample_n, sample_data, sample_mclabel)
yukari_hinata 0:3f38e74a4a77 10 {
yukari_hinata 5:792afbb0bcf3 11 this->n_class = class_n;
yukari_hinata 5:792afbb0bcf3 12 this->maxFailcount = 5;
yukari_hinata 0:3f38e74a4a77 13
yukari_hinata 5:792afbb0bcf3 14 int n_kC2 = (n_class) * (n_class - 1) / 2;
yukari_hinata 5:792afbb0bcf3 15
yukari_hinata 5:792afbb0bcf3 16 mc_alpha = new float[n_sample * n_kC2];
yukari_hinata 5:792afbb0bcf3 17 mc_label = new int[n_sample * n_kC2];
yukari_hinata 0:3f38e74a4a77 18
yukari_hinata 5:792afbb0bcf3 19 // ラベルの生成
yukari_hinata 5:792afbb0bcf3 20 int tmp_lab;
yukari_hinata 5:792afbb0bcf3 21 for (int ci = 0; ci < n_class; ci++) {
yukari_hinata 5:792afbb0bcf3 22 for (int cj = ci + 1; cj < n_class; cj++) {
yukari_hinata 5:792afbb0bcf3 23 // i < jであることから, ラベルiは負例, ラベルjは正例に割り当てる.
yukari_hinata 5:792afbb0bcf3 24 // いずれのラベルにも該当しないデータは欠損とし,ラベル0とする.
yukari_hinata 5:792afbb0bcf3 25 for (int l=0; l < n_sample; l++) {
yukari_hinata 5:792afbb0bcf3 26 if (this->label[l] == ci) {
yukari_hinata 5:792afbb0bcf3 27 tmp_lab = -1;
yukari_hinata 5:792afbb0bcf3 28 } else if (this->label[l] == cj) {
yukari_hinata 5:792afbb0bcf3 29 tmp_lab = 1;
yukari_hinata 5:792afbb0bcf3 30 } else {
yukari_hinata 5:792afbb0bcf3 31 tmp_lab = 0;
yukari_hinata 5:792afbb0bcf3 32 }
yukari_hinata 5:792afbb0bcf3 33 MATRIX_AT(mc_label,n_sample,INX_KSVM_IJ(n_class,ci,cj),l) = tmp_lab;
yukari_hinata 5:792afbb0bcf3 34 // printf("%d : %d -> %d \r\n", l, label[l], tmp_lab);
yukari_hinata 5:792afbb0bcf3 35 }
yukari_hinata 5:792afbb0bcf3 36
yukari_hinata 0:3f38e74a4a77 37 }
yukari_hinata 0:3f38e74a4a77 38 }
yukari_hinata 0:3f38e74a4a77 39
yukari_hinata 0:3f38e74a4a77 40 }
yukari_hinata 0:3f38e74a4a77 41
yukari_hinata 0:3f38e74a4a77 42 // 領域開放
yukari_hinata 0:3f38e74a4a77 43 MCSVM::~MCSVM(void)
yukari_hinata 0:3f38e74a4a77 44 {
yukari_hinata 0:3f38e74a4a77 45 delete [] mc_alpha;
yukari_hinata 0:3f38e74a4a77 46 delete [] mc_label;
yukari_hinata 0:3f38e74a4a77 47 }
yukari_hinata 0:3f38e74a4a77 48
yukari_hinata 0:3f38e74a4a77 49 // 全SVMの学習.
yukari_hinata 0:3f38e74a4a77 50 int MCSVM::learning(void)
yukari_hinata 0:3f38e74a4a77 51 {
yukari_hinata 5:792afbb0bcf3 52 int status, fail_count;
yukari_hinata 5:792afbb0bcf3 53 int* tmp_label = new int[n_sample];
yukari_hinata 5:792afbb0bcf3 54 // 元のラベルを退避
yukari_hinata 5:792afbb0bcf3 55 memcpy(tmp_label,label, sizeof(int) * n_sample);
yukari_hinata 5:792afbb0bcf3 56 for (int ci = 0; ci < n_class; ci++) {
yukari_hinata 5:792afbb0bcf3 57 for (int cj = ci + 1; cj < n_class; cj++) {
yukari_hinata 5:792afbb0bcf3 58 // 2値ラベルを取得する.
yukari_hinata 5:792afbb0bcf3 59 memcpy(label, &(MATRIX_AT(mc_label, n_sample, INX_KSVM_IJ(n_class,ci,cj), 0)), sizeof(int) * n_sample);
yukari_hinata 5:792afbb0bcf3 60 // 学習 - 学習失敗の場合はリトライする.
yukari_hinata 5:792afbb0bcf3 61 fail_count = 0;
yukari_hinata 5:792afbb0bcf3 62 do {
yukari_hinata 5:792afbb0bcf3 63 if (fail_count >= maxFailcount) {
yukari_hinata 5:792afbb0bcf3 64 fprintf(stderr, "Learning failed %d times at %d,%d classifier, give up \r\n", fail_count, ci, cj);
yukari_hinata 5:792afbb0bcf3 65 return MCSVM_NOT_LEARN;
yukari_hinata 5:792afbb0bcf3 66 }
yukari_hinata 5:792afbb0bcf3 67 status = SVM::learning();
yukari_hinata 5:792afbb0bcf3 68
yukari_hinata 5:792afbb0bcf3 69 if ( (status == SVM_NOT_CONVERGENCED)
yukari_hinata 5:792afbb0bcf3 70 || (status == SVM_DETECT_BAD_VAL) ) {
yukari_hinata 5:792afbb0bcf3 71 fail_count++;
yukari_hinata 5:792afbb0bcf3 72 }
yukari_hinata 5:792afbb0bcf3 73 } while (status != SVM_LEARN_SUCCESS);
yukari_hinata 5:792afbb0bcf3 74 // 学習結果の係数とSVラベルの取得
yukari_hinata 5:792afbb0bcf3 75 memcpy(&(MATRIX_AT(mc_alpha, n_sample, INX_KSVM_IJ(n_class,ci,cj), 0)), this->alpha, sizeof(float) * n_sample);
yukari_hinata 0:3f38e74a4a77 76 }
yukari_hinata 5:792afbb0bcf3 77 }
yukari_hinata 0:3f38e74a4a77 78
yukari_hinata 5:792afbb0bcf3 79 // 元のラベルを復帰
yukari_hinata 5:792afbb0bcf3 80 memcpy(this->label, tmp_label, sizeof(int) * n_sample);
yukari_hinata 5:792afbb0bcf3 81 delete [] tmp_label;
yukari_hinata 5:792afbb0bcf3 82 return MCSVM_LEARN_SUCCESS;
yukari_hinata 0:3f38e74a4a77 83
yukari_hinata 0:3f38e74a4a77 84 }
yukari_hinata 0:3f38e74a4a77 85
yukari_hinata 0:3f38e74a4a77 86 // 未知データのラベルを推定する.
yukari_hinata 5:792afbb0bcf3 87 int MCSVM::predict_label(float* data)
yukari_hinata 0:3f38e74a4a77 88 {
yukari_hinata 5:792afbb0bcf3 89 // 単位ステップ関数による決定的識別
yukari_hinata 5:792afbb0bcf3 90 float net;
yukari_hinata 5:792afbb0bcf3 91 int* result_label_count = new int[n_class];
yukari_hinata 5:792afbb0bcf3 92
yukari_hinata 5:792afbb0bcf3 93 // 元のラベルを退避
yukari_hinata 5:792afbb0bcf3 94 int* tmp_label = new int[n_sample];
yukari_hinata 5:792afbb0bcf3 95 memcpy(tmp_label,label, sizeof(int) * n_sample);
yukari_hinata 0:3f38e74a4a77 96
yukari_hinata 5:792afbb0bcf3 97 int tmp_ci, tmp_cj;
yukari_hinata 5:792afbb0bcf3 98 memset(result_label_count, 0, sizeof(int) * n_class);
yukari_hinata 5:792afbb0bcf3 99 for (int ci = 0; ci < n_class; ci++) {
yukari_hinata 5:792afbb0bcf3 100 for (int cj = ci + 1; cj < n_class; cj++) {
yukari_hinata 0:3f38e74a4a77 101
yukari_hinata 5:792afbb0bcf3 102 // インデックスをi < jに
yukari_hinata 5:792afbb0bcf3 103 tmp_ci = ci;
yukari_hinata 5:792afbb0bcf3 104 tmp_cj = cj;
yukari_hinata 5:792afbb0bcf3 105 if ( ci > cj ) {
yukari_hinata 5:792afbb0bcf3 106 tmp_cj = ci;
yukari_hinata 5:792afbb0bcf3 107 tmp_ci = cj;
yukari_hinata 5:792afbb0bcf3 108 }
yukari_hinata 5:792afbb0bcf3 109 // 係数とラベルを取得し,ci,cjを識別するSVMを構成
yukari_hinata 5:792afbb0bcf3 110 memcpy(this->alpha, &(MATRIX_AT(mc_alpha, n_sample, INX_KSVM_IJ(n_class,tmp_ci,tmp_cj), 0)), sizeof(float) * n_sample);
yukari_hinata 5:792afbb0bcf3 111 memcpy(this->label, &(MATRIX_AT(mc_label, n_sample, INX_KSVM_IJ(n_class,tmp_ci,tmp_cj), 0)), sizeof(float) * n_sample);
yukari_hinata 0:3f38e74a4a77 112
yukari_hinata 5:792afbb0bcf3 113 // 識別:識別されたクラスに投票.
yukari_hinata 5:792afbb0bcf3 114 net = SVM::predict_net(data);
yukari_hinata 5:792afbb0bcf3 115 //printf("ci:%d cj:%d >> net : %f \n", ci, cj, net);
yukari_hinata 5:792afbb0bcf3 116 if ( net < 0 ) {
yukari_hinata 5:792afbb0bcf3 117 result_label_count[ci]++;
yukari_hinata 5:792afbb0bcf3 118 } else if ( net >= 0 ) {
yukari_hinata 5:792afbb0bcf3 119 result_label_count[cj]++;
yukari_hinata 5:792afbb0bcf3 120 }
yukari_hinata 0:3f38e74a4a77 121
yukari_hinata 5:792afbb0bcf3 122 }
yukari_hinata 5:792afbb0bcf3 123 //printf("sum_net[%d] : %f \n", ci, sum_net[ci]);
yukari_hinata 0:3f38e74a4a77 124 }
yukari_hinata 0:3f38e74a4a77 125
yukari_hinata 5:792afbb0bcf3 126 // 判定:最大頻度のクラスに判定する.
yukari_hinata 5:792afbb0bcf3 127 int max,argmax;
yukari_hinata 5:792afbb0bcf3 128 max = 0;
yukari_hinata 5:792afbb0bcf3 129 for (int i = 0; i < n_class; i++) {
yukari_hinata 5:792afbb0bcf3 130 //printf("%d : %d \n", i, result_label_count[i]);
yukari_hinata 5:792afbb0bcf3 131 if ( result_label_count[i] > max ) {
yukari_hinata 5:792afbb0bcf3 132 max = result_label_count[i];
yukari_hinata 5:792afbb0bcf3 133 argmax = i;
yukari_hinata 5:792afbb0bcf3 134 }
yukari_hinata 0:3f38e74a4a77 135 }
yukari_hinata 2:c4a5251cee32 136 // 元のラベルを復帰
yukari_hinata 5:792afbb0bcf3 137 memcpy(this->label, tmp_label, sizeof(int) * n_sample);
yukari_hinata 5:792afbb0bcf3 138 delete [] tmp_label;
yukari_hinata 5:792afbb0bcf3 139 delete [] result_label_count;
yukari_hinata 5:792afbb0bcf3 140
yukari_hinata 5:792afbb0bcf3 141 return argmax;
yukari_hinata 0:3f38e74a4a77 142
yukari_hinata 0:3f38e74a4a77 143 }
yukari_hinata 0:3f38e74a4a77 144
yukari_hinata 0:3f38e74a4a77 145 // 未知データの識別確率を推定する.
yukari_hinata 5:792afbb0bcf3 146 float MCSVM::predict_probability(float* data)
yukari_hinata 0:3f38e74a4a77 147 {
yukari_hinata 5:792afbb0bcf3 148 // シグモイド関数による確率的識別
yukari_hinata 5:792afbb0bcf3 149 float prob;
yukari_hinata 5:792afbb0bcf3 150 float* result_label_prob = new float[n_class];
yukari_hinata 5:792afbb0bcf3 151 int tmp_ci, tmp_cj;
yukari_hinata 5:792afbb0bcf3 152
yukari_hinata 5:792afbb0bcf3 153 memset(result_label_prob, 0, sizeof(float) * n_class);
yukari_hinata 5:792afbb0bcf3 154
yukari_hinata 5:792afbb0bcf3 155 // 元のラベルを退避
yukari_hinata 5:792afbb0bcf3 156 int* tmp_label = new int[n_sample];
yukari_hinata 5:792afbb0bcf3 157 memcpy(tmp_label,label, sizeof(int) * n_sample);
yukari_hinata 2:c4a5251cee32 158
yukari_hinata 5:792afbb0bcf3 159 for (int ci = 0; ci < n_class; ci++) {
yukari_hinata 5:792afbb0bcf3 160 for (int cj = ci + 1; cj < n_class; cj++) {
yukari_hinata 0:3f38e74a4a77 161
yukari_hinata 5:792afbb0bcf3 162 // インデックスをci < cjに : 負例はci, 正例はcj
yukari_hinata 5:792afbb0bcf3 163 tmp_ci = ci;
yukari_hinata 5:792afbb0bcf3 164 tmp_cj = cj;
yukari_hinata 5:792afbb0bcf3 165 if ( ci > cj ) {
yukari_hinata 5:792afbb0bcf3 166 tmp_cj = ci;
yukari_hinata 5:792afbb0bcf3 167 tmp_ci = cj;
yukari_hinata 5:792afbb0bcf3 168 }
yukari_hinata 5:792afbb0bcf3 169 // 係数とラベルを取得し,ci,cjを識別するSVMを構成
yukari_hinata 5:792afbb0bcf3 170 memcpy(this->alpha, &(MATRIX_AT(mc_alpha, n_sample, INX_KSVM_IJ(n_class,tmp_ci,tmp_cj), 0)), sizeof(float) * n_sample);
yukari_hinata 5:792afbb0bcf3 171 memcpy(this->label, &(MATRIX_AT(mc_label, n_sample, INX_KSVM_IJ(n_class,tmp_ci,tmp_cj), 0)), sizeof(float) * n_sample);
yukari_hinata 0:3f38e74a4a77 172
yukari_hinata 5:792afbb0bcf3 173 // 確率識別:確率の足し上げ
yukari_hinata 5:792afbb0bcf3 174 prob = SVM::predict_probability(data);
yukari_hinata 5:792afbb0bcf3 175 if ( prob > float(0.5) ) {
yukari_hinata 5:792afbb0bcf3 176 result_label_prob[cj] += prob;
yukari_hinata 5:792afbb0bcf3 177 } else {
yukari_hinata 5:792afbb0bcf3 178 result_label_prob[ci] += (1-prob);
yukari_hinata 5:792afbb0bcf3 179 }
yukari_hinata 0:3f38e74a4a77 180
yukari_hinata 5:792afbb0bcf3 181 }
yukari_hinata 5:792afbb0bcf3 182 //printf("sum_net[%d] : %f \n", ci, sum_net[ci]);
yukari_hinata 0:3f38e74a4a77 183 }
yukari_hinata 0:3f38e74a4a77 184
yukari_hinata 5:792afbb0bcf3 185 // 判定:最大確率和
yukari_hinata 5:792afbb0bcf3 186 // おそらくラベル識別との整合性は取れる...はず
yukari_hinata 5:792afbb0bcf3 187 float max = 0;
yukari_hinata 5:792afbb0bcf3 188 for (int i = 0; i < n_class; i++) {
yukari_hinata 5:792afbb0bcf3 189 //printf("%d : %d \n", i, result_label_count[i]);
yukari_hinata 5:792afbb0bcf3 190 if ( result_label_prob[i] > max ) {
yukari_hinata 5:792afbb0bcf3 191 max = result_label_prob[i];
yukari_hinata 5:792afbb0bcf3 192 }
yukari_hinata 0:3f38e74a4a77 193 }
yukari_hinata 5:792afbb0bcf3 194
yukari_hinata 5:792afbb0bcf3 195 // 元のラベルを復帰
yukari_hinata 5:792afbb0bcf3 196 memcpy(this->label, tmp_label, sizeof(int) * n_sample);
yukari_hinata 5:792afbb0bcf3 197 delete [] tmp_label;
yukari_hinata 5:792afbb0bcf3 198 delete [] result_label_prob;
yukari_hinata 5:792afbb0bcf3 199
yukari_hinata 5:792afbb0bcf3 200 // 平均確率を返す.
yukari_hinata 5:792afbb0bcf3 201 return (max / (n_class-1));
yukari_hinata 0:3f38e74a4a77 202
yukari_hinata 0:3f38e74a4a77 203 }
yukari_hinata 0:3f38e74a4a77 204
yukari_hinata 0:3f38e74a4a77 205 // override
yukari_hinata 5:792afbb0bcf3 206 float* MCSVM::get_alpha(void)
yukari_hinata 5:792afbb0bcf3 207 {
yukari_hinata 0:3f38e74a4a77 208 return (float *)mc_alpha;
yukari_hinata 0:3f38e74a4a77 209 }
yukari_hinata 0:3f38e74a4a77 210
yukari_hinata 0:3f38e74a4a77 211 // override
yukari_hinata 5:792afbb0bcf3 212 void MCSVM::set_alpha(float* mcalpha_data, int nsample, int nclass)
yukari_hinata 5:792afbb0bcf3 213 {
yukari_hinata 0:3f38e74a4a77 214 if ( nsample != n_sample ) {
yukari_hinata 0:3f38e74a4a77 215 fprintf( stderr, " set_alpha : number of sample isn't match : n_samle= %d, arg= %d \r\n", n_sample, nsample);
yukari_hinata 0:3f38e74a4a77 216 return;
yukari_hinata 0:3f38e74a4a77 217 } else if ( nclass != n_class ) {
yukari_hinata 0:3f38e74a4a77 218 fprintf( stderr, " set_alpha : number of class isn't match : n_class= %d, nclass= %d \r\n", n_class, nclass);
yukari_hinata 0:3f38e74a4a77 219 return;
yukari_hinata 0:3f38e74a4a77 220 }
yukari_hinata 5:792afbb0bcf3 221 memcpy(mc_alpha, mcalpha_data, sizeof(float) * n_sample * n_class * (n_class - 1) / 2);
yukari_hinata 0:3f38e74a4a77 222 status = SVM_SET_ALPHA;
yukari_hinata 0:3f38e74a4a77 223 }