Easy Support Vector Machine
MCSVM.cpp@1:1a0d5152d50b, 2015-01-28 (annotated)
- Committer:
- yukari_hinata
- Date:
- Wed Jan 28 15:22:10 2015 +0000
- Revision:
- 1:1a0d5152d50b
- Parent:
- 0:3f38e74a4a77
- Child:
- 2:c4a5251cee32
modified.
Who changed what in which revision?
User | Revision | Line number | New contents of line |
---|---|---|---|
yukari_hinata | 0:3f38e74a4a77 | 1 | #include "MCSVM.hpp" |
yukari_hinata | 0:3f38e74a4a77 | 2 | |
yukari_hinata | 0:3f38e74a4a77 | 3 | // コンストラクタ. 適宜追加予定 |
yukari_hinata | 0:3f38e74a4a77 | 4 | MCSVM::MCSVM(int class_n, |
yukari_hinata | 0:3f38e74a4a77 | 5 | int sample_dim, |
yukari_hinata | 0:3f38e74a4a77 | 6 | int sample_n, |
yukari_hinata | 0:3f38e74a4a77 | 7 | float* sample_data, |
yukari_hinata | 0:3f38e74a4a77 | 8 | int* sample_mclabel) |
yukari_hinata | 0:3f38e74a4a77 | 9 | : SVM(sample_dim, sample_n, sample_data, sample_mclabel) |
yukari_hinata | 0:3f38e74a4a77 | 10 | { |
yukari_hinata | 0:3f38e74a4a77 | 11 | this->n_class = class_n; |
yukari_hinata | 0:3f38e74a4a77 | 12 | this->maxFailcount = 5; |
yukari_hinata | 0:3f38e74a4a77 | 13 | |
yukari_hinata | 0:3f38e74a4a77 | 14 | int n_kC2 = (n_class) * (n_class - 1) / 2; |
yukari_hinata | 0:3f38e74a4a77 | 15 | |
yukari_hinata | 0:3f38e74a4a77 | 16 | mc_alpha = new float[n_sample * n_kC2]; |
yukari_hinata | 0:3f38e74a4a77 | 17 | mc_label = new int[n_sample * n_kC2]; |
yukari_hinata | 0:3f38e74a4a77 | 18 | |
yukari_hinata | 0:3f38e74a4a77 | 19 | // ラベルの生成 |
yukari_hinata | 0:3f38e74a4a77 | 20 | int tmp_lab; |
yukari_hinata | 0:3f38e74a4a77 | 21 | for (int ci = 0; ci < n_class; ci++) { |
yukari_hinata | 0:3f38e74a4a77 | 22 | for (int cj = ci + 1; cj < n_class; cj++) { |
yukari_hinata | 0:3f38e74a4a77 | 23 | // i < jであることから, ラベルiは負例, ラベルjは正例に割り当てる. |
yukari_hinata | 0:3f38e74a4a77 | 24 | // いずれのラベルにも該当しないデータは欠損とし,ラベル0とする. |
yukari_hinata | 0:3f38e74a4a77 | 25 | for (int l=0; l < n_sample; l++) { |
yukari_hinata | 0:3f38e74a4a77 | 26 | if (this->label[l] == ci) { |
yukari_hinata | 0:3f38e74a4a77 | 27 | tmp_lab = -1; |
yukari_hinata | 0:3f38e74a4a77 | 28 | } else if (this->label[l] == cj) { |
yukari_hinata | 0:3f38e74a4a77 | 29 | tmp_lab = 1; |
yukari_hinata | 0:3f38e74a4a77 | 30 | } else { |
yukari_hinata | 0:3f38e74a4a77 | 31 | tmp_lab = 0; |
yukari_hinata | 0:3f38e74a4a77 | 32 | } |
yukari_hinata | 0:3f38e74a4a77 | 33 | MATRIX_AT(mc_label,n_sample,INX_KSVM_IJ(n_class,ci,cj),l) = tmp_lab; |
yukari_hinata | 0:3f38e74a4a77 | 34 | //printf("l %d : %d -> %d \n", l, label[l], tmp_lab); |
yukari_hinata | 0:3f38e74a4a77 | 35 | } |
yukari_hinata | 0:3f38e74a4a77 | 36 | |
yukari_hinata | 0:3f38e74a4a77 | 37 | } |
yukari_hinata | 0:3f38e74a4a77 | 38 | } |
yukari_hinata | 0:3f38e74a4a77 | 39 | |
yukari_hinata | 0:3f38e74a4a77 | 40 | } |
yukari_hinata | 0:3f38e74a4a77 | 41 | |
yukari_hinata | 0:3f38e74a4a77 | 42 | // 領域開放 |
yukari_hinata | 0:3f38e74a4a77 | 43 | MCSVM::~MCSVM(void) |
yukari_hinata | 0:3f38e74a4a77 | 44 | { |
yukari_hinata | 0:3f38e74a4a77 | 45 | delete [] mc_alpha; |
yukari_hinata | 0:3f38e74a4a77 | 46 | delete [] mc_label; |
yukari_hinata | 0:3f38e74a4a77 | 47 | } |
yukari_hinata | 0:3f38e74a4a77 | 48 | |
yukari_hinata | 0:3f38e74a4a77 | 49 | // 全SVMの学習. |
yukari_hinata | 0:3f38e74a4a77 | 50 | int MCSVM::learning(void) |
yukari_hinata | 0:3f38e74a4a77 | 51 | { |
yukari_hinata | 0:3f38e74a4a77 | 52 | int status, fail_count; |
yukari_hinata | 0:3f38e74a4a77 | 53 | int* tmp_label = new int[n_sample]; |
yukari_hinata | 0:3f38e74a4a77 | 54 | // 元のラベルを退避 |
yukari_hinata | 0:3f38e74a4a77 | 55 | memcpy(tmp_label,label, sizeof(int) * n_sample); |
yukari_hinata | 0:3f38e74a4a77 | 56 | for (int ci = 0; ci < n_class; ci++) { |
yukari_hinata | 0:3f38e74a4a77 | 57 | for (int cj = ci + 1; cj < n_class; cj++) { |
yukari_hinata | 0:3f38e74a4a77 | 58 | // 2値ラベルを取得する. |
yukari_hinata | 0:3f38e74a4a77 | 59 | memcpy(label, &(MATRIX_AT(mc_label, n_sample, INX_KSVM_IJ(n_class,ci,cj), 0)), sizeof(int) * n_sample); |
yukari_hinata | 0:3f38e74a4a77 | 60 | // 学習 - 学習失敗の場合はリトライする. |
yukari_hinata | 0:3f38e74a4a77 | 61 | fail_count = 0; |
yukari_hinata | 0:3f38e74a4a77 | 62 | do { |
yukari_hinata | 0:3f38e74a4a77 | 63 | if (fail_count >= maxFailcount) { |
yukari_hinata | 0:3f38e74a4a77 | 64 | fprintf(stderr, "Learning failed %d times at %d,%d classifier, give up \r\n", fail_count, ci, cj); |
yukari_hinata | 0:3f38e74a4a77 | 65 | return MCSVM_NOT_LEARN; |
yukari_hinata | 0:3f38e74a4a77 | 66 | } |
yukari_hinata | 0:3f38e74a4a77 | 67 | status = SVM::learning(); |
yukari_hinata | 0:3f38e74a4a77 | 68 | |
yukari_hinata | 0:3f38e74a4a77 | 69 | if ( (status == SVM_NOT_CONVERGENCED) |
yukari_hinata | 0:3f38e74a4a77 | 70 | || (status == SVM_DETECT_BAD_VAL) ) { |
yukari_hinata | 0:3f38e74a4a77 | 71 | fail_count++; |
yukari_hinata | 0:3f38e74a4a77 | 72 | } |
yukari_hinata | 0:3f38e74a4a77 | 73 | } while (status != SVM_LEARN_SUCCESS); |
yukari_hinata | 0:3f38e74a4a77 | 74 | // 学習結果の係数とSVラベルの取得 |
yukari_hinata | 0:3f38e74a4a77 | 75 | memcpy(&(MATRIX_AT(mc_alpha, n_sample, INX_KSVM_IJ(n_class,ci,cj), 0)), this->alpha, sizeof(float) * n_sample); |
yukari_hinata | 0:3f38e74a4a77 | 76 | } |
yukari_hinata | 0:3f38e74a4a77 | 77 | } |
yukari_hinata | 0:3f38e74a4a77 | 78 | |
yukari_hinata | 0:3f38e74a4a77 | 79 | // 元のラベルを復帰 |
yukari_hinata | 0:3f38e74a4a77 | 80 | memcpy(this->label, tmp_label, sizeof(int) * n_sample); |
yukari_hinata | 0:3f38e74a4a77 | 81 | delete [] tmp_label; |
yukari_hinata | 0:3f38e74a4a77 | 82 | return MCSVM_LEARN_SUCCESS; |
yukari_hinata | 0:3f38e74a4a77 | 83 | |
yukari_hinata | 0:3f38e74a4a77 | 84 | } |
yukari_hinata | 0:3f38e74a4a77 | 85 | |
yukari_hinata | 0:3f38e74a4a77 | 86 | // 未知データのラベルを推定する. |
yukari_hinata | 0:3f38e74a4a77 | 87 | int MCSVM::predict_label(float* data) |
yukari_hinata | 0:3f38e74a4a77 | 88 | { |
yukari_hinata | 0:3f38e74a4a77 | 89 | // 単位ステップ関数による決定的識別 |
yukari_hinata | 0:3f38e74a4a77 | 90 | float net; |
yukari_hinata | 0:3f38e74a4a77 | 91 | int* result_label_count = new int[n_class]; |
yukari_hinata | 0:3f38e74a4a77 | 92 | |
yukari_hinata | 0:3f38e74a4a77 | 93 | int tmp_ci, tmp_cj; |
yukari_hinata | 0:3f38e74a4a77 | 94 | memset(result_label_count, 0, sizeof(int) * n_class); |
yukari_hinata | 0:3f38e74a4a77 | 95 | for (int ci = 0; ci < n_class; ci++) { |
yukari_hinata | 0:3f38e74a4a77 | 96 | for (int cj = ci + 1; cj < n_class; cj++) { |
yukari_hinata | 0:3f38e74a4a77 | 97 | |
yukari_hinata | 0:3f38e74a4a77 | 98 | // インデックスをi < jに |
yukari_hinata | 0:3f38e74a4a77 | 99 | tmp_ci = ci; tmp_cj = cj; |
yukari_hinata | 0:3f38e74a4a77 | 100 | if ( ci > cj ) { |
yukari_hinata | 0:3f38e74a4a77 | 101 | tmp_cj = ci; tmp_ci = cj; |
yukari_hinata | 0:3f38e74a4a77 | 102 | } |
yukari_hinata | 0:3f38e74a4a77 | 103 | // 係数とラベルを取得し,ci,cjを識別するSVMを構成 |
yukari_hinata | 0:3f38e74a4a77 | 104 | memcpy(this->alpha, &(MATRIX_AT(mc_alpha, n_sample, INX_KSVM_IJ(n_class,tmp_ci,tmp_cj), 0)), sizeof(float) * n_sample); |
yukari_hinata | 0:3f38e74a4a77 | 105 | memcpy(this->label, &(MATRIX_AT(mc_label, n_sample, INX_KSVM_IJ(n_class,tmp_ci,tmp_cj), 0)), sizeof(float) * n_sample); |
yukari_hinata | 0:3f38e74a4a77 | 106 | |
yukari_hinata | 0:3f38e74a4a77 | 107 | // 識別:識別されたクラスに投票. |
yukari_hinata | 0:3f38e74a4a77 | 108 | net = SVM::predict_net(data); |
yukari_hinata | 0:3f38e74a4a77 | 109 | //printf("ci:%d cj:%d >> net : %f \n", ci, cj, net); |
yukari_hinata | 0:3f38e74a4a77 | 110 | if ( net < 0 ) { |
yukari_hinata | 0:3f38e74a4a77 | 111 | result_label_count[ci]++; |
yukari_hinata | 0:3f38e74a4a77 | 112 | } else if ( net >= 0 ) { |
yukari_hinata | 0:3f38e74a4a77 | 113 | result_label_count[cj]++; |
yukari_hinata | 0:3f38e74a4a77 | 114 | } |
yukari_hinata | 0:3f38e74a4a77 | 115 | |
yukari_hinata | 0:3f38e74a4a77 | 116 | } |
yukari_hinata | 0:3f38e74a4a77 | 117 | //printf("sum_net[%d] : %f \n", ci, sum_net[ci]); |
yukari_hinata | 0:3f38e74a4a77 | 118 | } |
yukari_hinata | 0:3f38e74a4a77 | 119 | |
yukari_hinata | 0:3f38e74a4a77 | 120 | // 判定:最大頻度のクラスに判定する. |
yukari_hinata | 0:3f38e74a4a77 | 121 | int max,argmax; |
yukari_hinata | 0:3f38e74a4a77 | 122 | max = 0; |
yukari_hinata | 0:3f38e74a4a77 | 123 | for (int i = 0; i < n_class; i++) { |
yukari_hinata | 0:3f38e74a4a77 | 124 | //printf("%d : %d \n", i, result_label_count[i]); |
yukari_hinata | 0:3f38e74a4a77 | 125 | if ( result_label_count[i] > max ) { |
yukari_hinata | 0:3f38e74a4a77 | 126 | max = result_label_count[i]; |
yukari_hinata | 0:3f38e74a4a77 | 127 | argmax = i; |
yukari_hinata | 0:3f38e74a4a77 | 128 | } |
yukari_hinata | 0:3f38e74a4a77 | 129 | } |
yukari_hinata | 0:3f38e74a4a77 | 130 | |
yukari_hinata | 0:3f38e74a4a77 | 131 | delete [] result_label_count; |
yukari_hinata | 0:3f38e74a4a77 | 132 | return argmax; |
yukari_hinata | 0:3f38e74a4a77 | 133 | |
yukari_hinata | 0:3f38e74a4a77 | 134 | } |
yukari_hinata | 0:3f38e74a4a77 | 135 | |
yukari_hinata | 0:3f38e74a4a77 | 136 | // 未知データの識別確率を推定する. |
yukari_hinata | 0:3f38e74a4a77 | 137 | float MCSVM::predict_probability(float* data) |
yukari_hinata | 0:3f38e74a4a77 | 138 | { |
yukari_hinata | 0:3f38e74a4a77 | 139 | // シグモイド関数による確率的識別 |
yukari_hinata | 0:3f38e74a4a77 | 140 | float prob; |
yukari_hinata | 0:3f38e74a4a77 | 141 | float* result_label_prob = new float[n_class]; |
yukari_hinata | 0:3f38e74a4a77 | 142 | int tmp_ci, tmp_cj; |
yukari_hinata | 1:1a0d5152d50b | 143 | memset(result_label_prob, 0, sizeof(float) * n_class); |
yukari_hinata | 0:3f38e74a4a77 | 144 | for (int ci = 0; ci < n_class; ci++) { |
yukari_hinata | 0:3f38e74a4a77 | 145 | for (int cj = ci + 1; cj < n_class; cj++) { |
yukari_hinata | 0:3f38e74a4a77 | 146 | |
yukari_hinata | 0:3f38e74a4a77 | 147 | // インデックスをci < cjに : 負例はci, 正例はcj |
yukari_hinata | 0:3f38e74a4a77 | 148 | tmp_ci = ci; tmp_cj = cj; |
yukari_hinata | 0:3f38e74a4a77 | 149 | if ( ci > cj ) { |
yukari_hinata | 0:3f38e74a4a77 | 150 | tmp_cj = ci; tmp_ci = cj; |
yukari_hinata | 0:3f38e74a4a77 | 151 | } |
yukari_hinata | 0:3f38e74a4a77 | 152 | // 係数とラベルを取得し,ci,cjを識別するSVMを構成 |
yukari_hinata | 0:3f38e74a4a77 | 153 | memcpy(this->alpha, &(MATRIX_AT(mc_alpha, n_sample, INX_KSVM_IJ(n_class,tmp_ci,tmp_cj), 0)), sizeof(float) * n_sample); |
yukari_hinata | 0:3f38e74a4a77 | 154 | memcpy(this->label, &(MATRIX_AT(mc_label, n_sample, INX_KSVM_IJ(n_class,tmp_ci,tmp_cj), 0)), sizeof(float) * n_sample); |
yukari_hinata | 0:3f38e74a4a77 | 155 | |
yukari_hinata | 0:3f38e74a4a77 | 156 | // 確率識別:確率の足し上げ |
yukari_hinata | 0:3f38e74a4a77 | 157 | prob = SVM::predict_probability(data); |
yukari_hinata | 0:3f38e74a4a77 | 158 | if ( prob > float(0.5) ) { |
yukari_hinata | 1:1a0d5152d50b | 159 | result_label_prob[cj] += prob; |
yukari_hinata | 0:3f38e74a4a77 | 160 | } else { |
yukari_hinata | 1:1a0d5152d50b | 161 | result_label_prob[ci] += (1-prob); |
yukari_hinata | 0:3f38e74a4a77 | 162 | } |
yukari_hinata | 0:3f38e74a4a77 | 163 | |
yukari_hinata | 0:3f38e74a4a77 | 164 | } |
yukari_hinata | 0:3f38e74a4a77 | 165 | //printf("sum_net[%d] : %f \n", ci, sum_net[ci]); |
yukari_hinata | 0:3f38e74a4a77 | 166 | } |
yukari_hinata | 0:3f38e74a4a77 | 167 | |
yukari_hinata | 0:3f38e74a4a77 | 168 | // 判定:最大確率和 |
yukari_hinata | 0:3f38e74a4a77 | 169 | // おそらくラベル識別との整合性は取れる...はず |
yukari_hinata | 0:3f38e74a4a77 | 170 | float max = 0; |
yukari_hinata | 0:3f38e74a4a77 | 171 | for (int i = 0; i < n_class; i++) { |
yukari_hinata | 0:3f38e74a4a77 | 172 | //printf("%d : %d \n", i, result_label_count[i]); |
yukari_hinata | 0:3f38e74a4a77 | 173 | if ( result_label_prob[i] > max ) { |
yukari_hinata | 0:3f38e74a4a77 | 174 | max = result_label_prob[i]; |
yukari_hinata | 0:3f38e74a4a77 | 175 | } |
yukari_hinata | 0:3f38e74a4a77 | 176 | } |
yukari_hinata | 0:3f38e74a4a77 | 177 | |
yukari_hinata | 0:3f38e74a4a77 | 178 | delete [] result_label_prob; |
yukari_hinata | 0:3f38e74a4a77 | 179 | // 平均確率を返す. |
yukari_hinata | 1:1a0d5152d50b | 180 | return (max / (n_class-1)); |
yukari_hinata | 0:3f38e74a4a77 | 181 | |
yukari_hinata | 0:3f38e74a4a77 | 182 | } |
yukari_hinata | 0:3f38e74a4a77 | 183 | |
yukari_hinata | 0:3f38e74a4a77 | 184 | // override |
yukari_hinata | 0:3f38e74a4a77 | 185 | float* MCSVM::get_alpha(void) { |
yukari_hinata | 0:3f38e74a4a77 | 186 | return (float *)mc_alpha; |
yukari_hinata | 0:3f38e74a4a77 | 187 | } |
yukari_hinata | 0:3f38e74a4a77 | 188 | |
yukari_hinata | 0:3f38e74a4a77 | 189 | // override |
yukari_hinata | 0:3f38e74a4a77 | 190 | void MCSVM::set_alpha(float* mcalpha_data, int nsample, int nclass) { |
yukari_hinata | 0:3f38e74a4a77 | 191 | if ( nsample != n_sample ) { |
yukari_hinata | 0:3f38e74a4a77 | 192 | fprintf( stderr, " set_alpha : number of sample isn't match : n_samle= %d, arg= %d \r\n", n_sample, nsample); |
yukari_hinata | 0:3f38e74a4a77 | 193 | return; |
yukari_hinata | 0:3f38e74a4a77 | 194 | } else if ( nclass != n_class ) { |
yukari_hinata | 0:3f38e74a4a77 | 195 | fprintf( stderr, " set_alpha : number of class isn't match : n_class= %d, nclass= %d \r\n", n_class, nclass); |
yukari_hinata | 0:3f38e74a4a77 | 196 | return; |
yukari_hinata | 0:3f38e74a4a77 | 197 | } |
yukari_hinata | 0:3f38e74a4a77 | 198 | int nC2 = n_class * (n_class - 1)/2; |
yukari_hinata | 1:1a0d5152d50b | 199 | memcpy(mc_alpha, mcalpha_data, sizeof(float) * n_sample * (n_class * (n_class - 1) / 2)); |
yukari_hinata | 0:3f38e74a4a77 | 200 | status = SVM_SET_ALPHA; |
yukari_hinata | 0:3f38e74a4a77 | 201 | } |