Easy Support Vector Machine
MCSVM.cpp@5:792afbb0bcf3, 2015-02-18 (annotated)
- Committer:
- yukari_hinata
- Date:
- Wed Feb 18 15:01:12 2015 +0000
- Revision:
- 5:792afbb0bcf3
- Parent:
- 4:01a20b89db32
Who changed what in which revision?
User | Revision | Line number | New contents of line |
---|---|---|---|
yukari_hinata | 0:3f38e74a4a77 | 1 | #include "MCSVM.hpp" |
yukari_hinata | 0:3f38e74a4a77 | 2 | |
yukari_hinata | 0:3f38e74a4a77 | 3 | // コンストラクタ. 適宜追加予定 |
yukari_hinata | 0:3f38e74a4a77 | 4 | MCSVM::MCSVM(int class_n, |
yukari_hinata | 0:3f38e74a4a77 | 5 | int sample_dim, |
yukari_hinata | 5:792afbb0bcf3 | 6 | int sample_n, |
yukari_hinata | 0:3f38e74a4a77 | 7 | float* sample_data, |
yukari_hinata | 0:3f38e74a4a77 | 8 | int* sample_mclabel) |
yukari_hinata | 5:792afbb0bcf3 | 9 | : SVM(sample_dim, sample_n, sample_data, sample_mclabel) |
yukari_hinata | 0:3f38e74a4a77 | 10 | { |
yukari_hinata | 5:792afbb0bcf3 | 11 | this->n_class = class_n; |
yukari_hinata | 5:792afbb0bcf3 | 12 | this->maxFailcount = 5; |
yukari_hinata | 0:3f38e74a4a77 | 13 | |
yukari_hinata | 5:792afbb0bcf3 | 14 | int n_kC2 = (n_class) * (n_class - 1) / 2; |
yukari_hinata | 5:792afbb0bcf3 | 15 | |
yukari_hinata | 5:792afbb0bcf3 | 16 | mc_alpha = new float[n_sample * n_kC2]; |
yukari_hinata | 5:792afbb0bcf3 | 17 | mc_label = new int[n_sample * n_kC2]; |
yukari_hinata | 0:3f38e74a4a77 | 18 | |
yukari_hinata | 5:792afbb0bcf3 | 19 | // ラベルの生成 |
yukari_hinata | 5:792afbb0bcf3 | 20 | int tmp_lab; |
yukari_hinata | 5:792afbb0bcf3 | 21 | for (int ci = 0; ci < n_class; ci++) { |
yukari_hinata | 5:792afbb0bcf3 | 22 | for (int cj = ci + 1; cj < n_class; cj++) { |
yukari_hinata | 5:792afbb0bcf3 | 23 | // i < jであることから, ラベルiは負例, ラベルjは正例に割り当てる. |
yukari_hinata | 5:792afbb0bcf3 | 24 | // いずれのラベルにも該当しないデータは欠損とし,ラベル0とする. |
yukari_hinata | 5:792afbb0bcf3 | 25 | for (int l=0; l < n_sample; l++) { |
yukari_hinata | 5:792afbb0bcf3 | 26 | if (this->label[l] == ci) { |
yukari_hinata | 5:792afbb0bcf3 | 27 | tmp_lab = -1; |
yukari_hinata | 5:792afbb0bcf3 | 28 | } else if (this->label[l] == cj) { |
yukari_hinata | 5:792afbb0bcf3 | 29 | tmp_lab = 1; |
yukari_hinata | 5:792afbb0bcf3 | 30 | } else { |
yukari_hinata | 5:792afbb0bcf3 | 31 | tmp_lab = 0; |
yukari_hinata | 5:792afbb0bcf3 | 32 | } |
yukari_hinata | 5:792afbb0bcf3 | 33 | MATRIX_AT(mc_label,n_sample,INX_KSVM_IJ(n_class,ci,cj),l) = tmp_lab; |
yukari_hinata | 5:792afbb0bcf3 | 34 | // printf("%d : %d -> %d \r\n", l, label[l], tmp_lab); |
yukari_hinata | 5:792afbb0bcf3 | 35 | } |
yukari_hinata | 5:792afbb0bcf3 | 36 | |
yukari_hinata | 0:3f38e74a4a77 | 37 | } |
yukari_hinata | 0:3f38e74a4a77 | 38 | } |
yukari_hinata | 0:3f38e74a4a77 | 39 | |
yukari_hinata | 0:3f38e74a4a77 | 40 | } |
yukari_hinata | 0:3f38e74a4a77 | 41 | |
yukari_hinata | 0:3f38e74a4a77 | 42 | // 領域開放 |
yukari_hinata | 0:3f38e74a4a77 | 43 | MCSVM::~MCSVM(void) |
yukari_hinata | 0:3f38e74a4a77 | 44 | { |
yukari_hinata | 0:3f38e74a4a77 | 45 | delete [] mc_alpha; |
yukari_hinata | 0:3f38e74a4a77 | 46 | delete [] mc_label; |
yukari_hinata | 0:3f38e74a4a77 | 47 | } |
yukari_hinata | 0:3f38e74a4a77 | 48 | |
yukari_hinata | 0:3f38e74a4a77 | 49 | // 全SVMの学習. |
yukari_hinata | 0:3f38e74a4a77 | 50 | int MCSVM::learning(void) |
yukari_hinata | 0:3f38e74a4a77 | 51 | { |
yukari_hinata | 5:792afbb0bcf3 | 52 | int status, fail_count; |
yukari_hinata | 5:792afbb0bcf3 | 53 | int* tmp_label = new int[n_sample]; |
yukari_hinata | 5:792afbb0bcf3 | 54 | // 元のラベルを退避 |
yukari_hinata | 5:792afbb0bcf3 | 55 | memcpy(tmp_label,label, sizeof(int) * n_sample); |
yukari_hinata | 5:792afbb0bcf3 | 56 | for (int ci = 0; ci < n_class; ci++) { |
yukari_hinata | 5:792afbb0bcf3 | 57 | for (int cj = ci + 1; cj < n_class; cj++) { |
yukari_hinata | 5:792afbb0bcf3 | 58 | // 2値ラベルを取得する. |
yukari_hinata | 5:792afbb0bcf3 | 59 | memcpy(label, &(MATRIX_AT(mc_label, n_sample, INX_KSVM_IJ(n_class,ci,cj), 0)), sizeof(int) * n_sample); |
yukari_hinata | 5:792afbb0bcf3 | 60 | // 学習 - 学習失敗の場合はリトライする. |
yukari_hinata | 5:792afbb0bcf3 | 61 | fail_count = 0; |
yukari_hinata | 5:792afbb0bcf3 | 62 | do { |
yukari_hinata | 5:792afbb0bcf3 | 63 | if (fail_count >= maxFailcount) { |
yukari_hinata | 5:792afbb0bcf3 | 64 | fprintf(stderr, "Learning failed %d times at %d,%d classifier, give up \r\n", fail_count, ci, cj); |
yukari_hinata | 5:792afbb0bcf3 | 65 | return MCSVM_NOT_LEARN; |
yukari_hinata | 5:792afbb0bcf3 | 66 | } |
yukari_hinata | 5:792afbb0bcf3 | 67 | status = SVM::learning(); |
yukari_hinata | 5:792afbb0bcf3 | 68 | |
yukari_hinata | 5:792afbb0bcf3 | 69 | if ( (status == SVM_NOT_CONVERGENCED) |
yukari_hinata | 5:792afbb0bcf3 | 70 | || (status == SVM_DETECT_BAD_VAL) ) { |
yukari_hinata | 5:792afbb0bcf3 | 71 | fail_count++; |
yukari_hinata | 5:792afbb0bcf3 | 72 | } |
yukari_hinata | 5:792afbb0bcf3 | 73 | } while (status != SVM_LEARN_SUCCESS); |
yukari_hinata | 5:792afbb0bcf3 | 74 | // 学習結果の係数とSVラベルの取得 |
yukari_hinata | 5:792afbb0bcf3 | 75 | memcpy(&(MATRIX_AT(mc_alpha, n_sample, INX_KSVM_IJ(n_class,ci,cj), 0)), this->alpha, sizeof(float) * n_sample); |
yukari_hinata | 0:3f38e74a4a77 | 76 | } |
yukari_hinata | 5:792afbb0bcf3 | 77 | } |
yukari_hinata | 0:3f38e74a4a77 | 78 | |
yukari_hinata | 5:792afbb0bcf3 | 79 | // 元のラベルを復帰 |
yukari_hinata | 5:792afbb0bcf3 | 80 | memcpy(this->label, tmp_label, sizeof(int) * n_sample); |
yukari_hinata | 5:792afbb0bcf3 | 81 | delete [] tmp_label; |
yukari_hinata | 5:792afbb0bcf3 | 82 | return MCSVM_LEARN_SUCCESS; |
yukari_hinata | 0:3f38e74a4a77 | 83 | |
yukari_hinata | 0:3f38e74a4a77 | 84 | } |
yukari_hinata | 0:3f38e74a4a77 | 85 | |
yukari_hinata | 0:3f38e74a4a77 | 86 | // 未知データのラベルを推定する. |
yukari_hinata | 5:792afbb0bcf3 | 87 | int MCSVM::predict_label(float* data) |
yukari_hinata | 0:3f38e74a4a77 | 88 | { |
yukari_hinata | 5:792afbb0bcf3 | 89 | // 単位ステップ関数による決定的識別 |
yukari_hinata | 5:792afbb0bcf3 | 90 | float net; |
yukari_hinata | 5:792afbb0bcf3 | 91 | int* result_label_count = new int[n_class]; |
yukari_hinata | 5:792afbb0bcf3 | 92 | |
yukari_hinata | 5:792afbb0bcf3 | 93 | // 元のラベルを退避 |
yukari_hinata | 5:792afbb0bcf3 | 94 | int* tmp_label = new int[n_sample]; |
yukari_hinata | 5:792afbb0bcf3 | 95 | memcpy(tmp_label,label, sizeof(int) * n_sample); |
yukari_hinata | 0:3f38e74a4a77 | 96 | |
yukari_hinata | 5:792afbb0bcf3 | 97 | int tmp_ci, tmp_cj; |
yukari_hinata | 5:792afbb0bcf3 | 98 | memset(result_label_count, 0, sizeof(int) * n_class); |
yukari_hinata | 5:792afbb0bcf3 | 99 | for (int ci = 0; ci < n_class; ci++) { |
yukari_hinata | 5:792afbb0bcf3 | 100 | for (int cj = ci + 1; cj < n_class; cj++) { |
yukari_hinata | 0:3f38e74a4a77 | 101 | |
yukari_hinata | 5:792afbb0bcf3 | 102 | // インデックスをi < jに |
yukari_hinata | 5:792afbb0bcf3 | 103 | tmp_ci = ci; |
yukari_hinata | 5:792afbb0bcf3 | 104 | tmp_cj = cj; |
yukari_hinata | 5:792afbb0bcf3 | 105 | if ( ci > cj ) { |
yukari_hinata | 5:792afbb0bcf3 | 106 | tmp_cj = ci; |
yukari_hinata | 5:792afbb0bcf3 | 107 | tmp_ci = cj; |
yukari_hinata | 5:792afbb0bcf3 | 108 | } |
yukari_hinata | 5:792afbb0bcf3 | 109 | // 係数とラベルを取得し,ci,cjを識別するSVMを構成 |
yukari_hinata | 5:792afbb0bcf3 | 110 | memcpy(this->alpha, &(MATRIX_AT(mc_alpha, n_sample, INX_KSVM_IJ(n_class,tmp_ci,tmp_cj), 0)), sizeof(float) * n_sample); |
yukari_hinata | 5:792afbb0bcf3 | 111 | memcpy(this->label, &(MATRIX_AT(mc_label, n_sample, INX_KSVM_IJ(n_class,tmp_ci,tmp_cj), 0)), sizeof(float) * n_sample); |
yukari_hinata | 0:3f38e74a4a77 | 112 | |
yukari_hinata | 5:792afbb0bcf3 | 113 | // 識別:識別されたクラスに投票. |
yukari_hinata | 5:792afbb0bcf3 | 114 | net = SVM::predict_net(data); |
yukari_hinata | 5:792afbb0bcf3 | 115 | //printf("ci:%d cj:%d >> net : %f \n", ci, cj, net); |
yukari_hinata | 5:792afbb0bcf3 | 116 | if ( net < 0 ) { |
yukari_hinata | 5:792afbb0bcf3 | 117 | result_label_count[ci]++; |
yukari_hinata | 5:792afbb0bcf3 | 118 | } else if ( net >= 0 ) { |
yukari_hinata | 5:792afbb0bcf3 | 119 | result_label_count[cj]++; |
yukari_hinata | 5:792afbb0bcf3 | 120 | } |
yukari_hinata | 0:3f38e74a4a77 | 121 | |
yukari_hinata | 5:792afbb0bcf3 | 122 | } |
yukari_hinata | 5:792afbb0bcf3 | 123 | //printf("sum_net[%d] : %f \n", ci, sum_net[ci]); |
yukari_hinata | 0:3f38e74a4a77 | 124 | } |
yukari_hinata | 0:3f38e74a4a77 | 125 | |
yukari_hinata | 5:792afbb0bcf3 | 126 | // 判定:最大頻度のクラスに判定する. |
yukari_hinata | 5:792afbb0bcf3 | 127 | int max,argmax; |
yukari_hinata | 5:792afbb0bcf3 | 128 | max = 0; |
yukari_hinata | 5:792afbb0bcf3 | 129 | for (int i = 0; i < n_class; i++) { |
yukari_hinata | 5:792afbb0bcf3 | 130 | //printf("%d : %d \n", i, result_label_count[i]); |
yukari_hinata | 5:792afbb0bcf3 | 131 | if ( result_label_count[i] > max ) { |
yukari_hinata | 5:792afbb0bcf3 | 132 | max = result_label_count[i]; |
yukari_hinata | 5:792afbb0bcf3 | 133 | argmax = i; |
yukari_hinata | 5:792afbb0bcf3 | 134 | } |
yukari_hinata | 0:3f38e74a4a77 | 135 | } |
yukari_hinata | 2:c4a5251cee32 | 136 | // 元のラベルを復帰 |
yukari_hinata | 5:792afbb0bcf3 | 137 | memcpy(this->label, tmp_label, sizeof(int) * n_sample); |
yukari_hinata | 5:792afbb0bcf3 | 138 | delete [] tmp_label; |
yukari_hinata | 5:792afbb0bcf3 | 139 | delete [] result_label_count; |
yukari_hinata | 5:792afbb0bcf3 | 140 | |
yukari_hinata | 5:792afbb0bcf3 | 141 | return argmax; |
yukari_hinata | 0:3f38e74a4a77 | 142 | |
yukari_hinata | 0:3f38e74a4a77 | 143 | } |
yukari_hinata | 0:3f38e74a4a77 | 144 | |
yukari_hinata | 0:3f38e74a4a77 | 145 | // 未知データの識別確率を推定する. |
yukari_hinata | 5:792afbb0bcf3 | 146 | float MCSVM::predict_probability(float* data) |
yukari_hinata | 0:3f38e74a4a77 | 147 | { |
yukari_hinata | 5:792afbb0bcf3 | 148 | // シグモイド関数による確率的識別 |
yukari_hinata | 5:792afbb0bcf3 | 149 | float prob; |
yukari_hinata | 5:792afbb0bcf3 | 150 | float* result_label_prob = new float[n_class]; |
yukari_hinata | 5:792afbb0bcf3 | 151 | int tmp_ci, tmp_cj; |
yukari_hinata | 5:792afbb0bcf3 | 152 | |
yukari_hinata | 5:792afbb0bcf3 | 153 | memset(result_label_prob, 0, sizeof(float) * n_class); |
yukari_hinata | 5:792afbb0bcf3 | 154 | |
yukari_hinata | 5:792afbb0bcf3 | 155 | // 元のラベルを退避 |
yukari_hinata | 5:792afbb0bcf3 | 156 | int* tmp_label = new int[n_sample]; |
yukari_hinata | 5:792afbb0bcf3 | 157 | memcpy(tmp_label,label, sizeof(int) * n_sample); |
yukari_hinata | 2:c4a5251cee32 | 158 | |
yukari_hinata | 5:792afbb0bcf3 | 159 | for (int ci = 0; ci < n_class; ci++) { |
yukari_hinata | 5:792afbb0bcf3 | 160 | for (int cj = ci + 1; cj < n_class; cj++) { |
yukari_hinata | 0:3f38e74a4a77 | 161 | |
yukari_hinata | 5:792afbb0bcf3 | 162 | // インデックスをci < cjに : 負例はci, 正例はcj |
yukari_hinata | 5:792afbb0bcf3 | 163 | tmp_ci = ci; |
yukari_hinata | 5:792afbb0bcf3 | 164 | tmp_cj = cj; |
yukari_hinata | 5:792afbb0bcf3 | 165 | if ( ci > cj ) { |
yukari_hinata | 5:792afbb0bcf3 | 166 | tmp_cj = ci; |
yukari_hinata | 5:792afbb0bcf3 | 167 | tmp_ci = cj; |
yukari_hinata | 5:792afbb0bcf3 | 168 | } |
yukari_hinata | 5:792afbb0bcf3 | 169 | // 係数とラベルを取得し,ci,cjを識別するSVMを構成 |
yukari_hinata | 5:792afbb0bcf3 | 170 | memcpy(this->alpha, &(MATRIX_AT(mc_alpha, n_sample, INX_KSVM_IJ(n_class,tmp_ci,tmp_cj), 0)), sizeof(float) * n_sample); |
yukari_hinata | 5:792afbb0bcf3 | 171 | memcpy(this->label, &(MATRIX_AT(mc_label, n_sample, INX_KSVM_IJ(n_class,tmp_ci,tmp_cj), 0)), sizeof(float) * n_sample); |
yukari_hinata | 0:3f38e74a4a77 | 172 | |
yukari_hinata | 5:792afbb0bcf3 | 173 | // 確率識別:確率の足し上げ |
yukari_hinata | 5:792afbb0bcf3 | 174 | prob = SVM::predict_probability(data); |
yukari_hinata | 5:792afbb0bcf3 | 175 | if ( prob > float(0.5) ) { |
yukari_hinata | 5:792afbb0bcf3 | 176 | result_label_prob[cj] += prob; |
yukari_hinata | 5:792afbb0bcf3 | 177 | } else { |
yukari_hinata | 5:792afbb0bcf3 | 178 | result_label_prob[ci] += (1-prob); |
yukari_hinata | 5:792afbb0bcf3 | 179 | } |
yukari_hinata | 0:3f38e74a4a77 | 180 | |
yukari_hinata | 5:792afbb0bcf3 | 181 | } |
yukari_hinata | 5:792afbb0bcf3 | 182 | //printf("sum_net[%d] : %f \n", ci, sum_net[ci]); |
yukari_hinata | 0:3f38e74a4a77 | 183 | } |
yukari_hinata | 0:3f38e74a4a77 | 184 | |
yukari_hinata | 5:792afbb0bcf3 | 185 | // 判定:最大確率和 |
yukari_hinata | 5:792afbb0bcf3 | 186 | // おそらくラベル識別との整合性は取れる...はず |
yukari_hinata | 5:792afbb0bcf3 | 187 | float max = 0; |
yukari_hinata | 5:792afbb0bcf3 | 188 | for (int i = 0; i < n_class; i++) { |
yukari_hinata | 5:792afbb0bcf3 | 189 | //printf("%d : %d \n", i, result_label_count[i]); |
yukari_hinata | 5:792afbb0bcf3 | 190 | if ( result_label_prob[i] > max ) { |
yukari_hinata | 5:792afbb0bcf3 | 191 | max = result_label_prob[i]; |
yukari_hinata | 5:792afbb0bcf3 | 192 | } |
yukari_hinata | 0:3f38e74a4a77 | 193 | } |
yukari_hinata | 5:792afbb0bcf3 | 194 | |
yukari_hinata | 5:792afbb0bcf3 | 195 | // 元のラベルを復帰 |
yukari_hinata | 5:792afbb0bcf3 | 196 | memcpy(this->label, tmp_label, sizeof(int) * n_sample); |
yukari_hinata | 5:792afbb0bcf3 | 197 | delete [] tmp_label; |
yukari_hinata | 5:792afbb0bcf3 | 198 | delete [] result_label_prob; |
yukari_hinata | 5:792afbb0bcf3 | 199 | |
yukari_hinata | 5:792afbb0bcf3 | 200 | // 平均確率を返す. |
yukari_hinata | 5:792afbb0bcf3 | 201 | return (max / (n_class-1)); |
yukari_hinata | 0:3f38e74a4a77 | 202 | |
yukari_hinata | 0:3f38e74a4a77 | 203 | } |
yukari_hinata | 0:3f38e74a4a77 | 204 | |
yukari_hinata | 0:3f38e74a4a77 | 205 | // override |
yukari_hinata | 5:792afbb0bcf3 | 206 | float* MCSVM::get_alpha(void) |
yukari_hinata | 5:792afbb0bcf3 | 207 | { |
yukari_hinata | 0:3f38e74a4a77 | 208 | return (float *)mc_alpha; |
yukari_hinata | 0:3f38e74a4a77 | 209 | } |
yukari_hinata | 0:3f38e74a4a77 | 210 | |
yukari_hinata | 0:3f38e74a4a77 | 211 | // override |
yukari_hinata | 5:792afbb0bcf3 | 212 | void MCSVM::set_alpha(float* mcalpha_data, int nsample, int nclass) |
yukari_hinata | 5:792afbb0bcf3 | 213 | { |
yukari_hinata | 0:3f38e74a4a77 | 214 | if ( nsample != n_sample ) { |
yukari_hinata | 0:3f38e74a4a77 | 215 | fprintf( stderr, " set_alpha : number of sample isn't match : n_samle= %d, arg= %d \r\n", n_sample, nsample); |
yukari_hinata | 0:3f38e74a4a77 | 216 | return; |
yukari_hinata | 0:3f38e74a4a77 | 217 | } else if ( nclass != n_class ) { |
yukari_hinata | 0:3f38e74a4a77 | 218 | fprintf( stderr, " set_alpha : number of class isn't match : n_class= %d, nclass= %d \r\n", n_class, nclass); |
yukari_hinata | 0:3f38e74a4a77 | 219 | return; |
yukari_hinata | 0:3f38e74a4a77 | 220 | } |
yukari_hinata | 5:792afbb0bcf3 | 221 | memcpy(mc_alpha, mcalpha_data, sizeof(float) * n_sample * n_class * (n_class - 1) / 2); |
yukari_hinata | 0:3f38e74a4a77 | 222 | status = SVM_SET_ALPHA; |
yukari_hinata | 0:3f38e74a4a77 | 223 | } |