Easy Support Vector Machine
Embed:
(wiki syntax)
Show/hide line numbers
MCSVM.cpp
00001 #include "MCSVM.hpp" 00002 00003 // コンストラクタ. 適宜追加予定 00004 MCSVM::MCSVM(int class_n, 00005 int sample_dim, 00006 int sample_n, 00007 float* sample_data, 00008 int* sample_mclabel) 00009 : SVM(sample_dim, sample_n, sample_data, sample_mclabel) 00010 { 00011 this->n_class = class_n; 00012 this->maxFailcount = 5; 00013 00014 int n_kC2 = (n_class) * (n_class - 1) / 2; 00015 00016 mc_alpha = new float[n_sample * n_kC2]; 00017 mc_label = new int[n_sample * n_kC2]; 00018 00019 // ラベルの生成 00020 int tmp_lab; 00021 for (int ci = 0; ci < n_class; ci++) { 00022 for (int cj = ci + 1; cj < n_class; cj++) { 00023 // i < jであることから, ラベルiは負例, ラベルjは正例に割り当てる. 00024 // いずれのラベルにも該当しないデータは欠損とし,ラベル0とする. 00025 for (int l=0; l < n_sample; l++) { 00026 if (this->label[l] == ci) { 00027 tmp_lab = -1; 00028 } else if (this->label[l] == cj) { 00029 tmp_lab = 1; 00030 } else { 00031 tmp_lab = 0; 00032 } 00033 MATRIX_AT(mc_label,n_sample,INX_KSVM_IJ(n_class,ci,cj),l) = tmp_lab; 00034 // printf("%d : %d -> %d \r\n", l, label[l], tmp_lab); 00035 } 00036 00037 } 00038 } 00039 00040 } 00041 00042 // 領域開放 00043 MCSVM::~MCSVM(void) 00044 { 00045 delete [] mc_alpha; 00046 delete [] mc_label; 00047 } 00048 00049 // 全SVMの学習. 00050 int MCSVM::learning(void) 00051 { 00052 int status, fail_count; 00053 int* tmp_label = new int[n_sample]; 00054 // 元のラベルを退避 00055 memcpy(tmp_label,label, sizeof(int) * n_sample); 00056 for (int ci = 0; ci < n_class; ci++) { 00057 for (int cj = ci + 1; cj < n_class; cj++) { 00058 // 2値ラベルを取得する. 00059 memcpy(label, &(MATRIX_AT(mc_label, n_sample, INX_KSVM_IJ(n_class,ci,cj), 0)), sizeof(int) * n_sample); 00060 // 学習 - 学習失敗の場合はリトライする. 00061 fail_count = 0; 00062 do { 00063 if (fail_count >= maxFailcount) { 00064 fprintf(stderr, "Learning failed %d times at %d,%d classifier, give up \r\n", fail_count, ci, cj); 00065 return MCSVM_NOT_LEARN; 00066 } 00067 status = SVM::learning(); 00068 00069 if ( (status == SVM_NOT_CONVERGENCED) 00070 || (status == SVM_DETECT_BAD_VAL) ) { 00071 fail_count++; 00072 } 00073 } while (status != SVM_LEARN_SUCCESS); 00074 // 学習結果の係数とSVラベルの取得 00075 memcpy(&(MATRIX_AT(mc_alpha, n_sample, INX_KSVM_IJ(n_class,ci,cj), 0)), this->alpha, sizeof(float) * n_sample); 00076 } 00077 } 00078 00079 // 元のラベルを復帰 00080 memcpy(this->label, tmp_label, sizeof(int) * n_sample); 00081 delete [] tmp_label; 00082 return MCSVM_LEARN_SUCCESS; 00083 00084 } 00085 00086 // 未知データのラベルを推定する. 00087 int MCSVM::predict_label(float* data) 00088 { 00089 // 単位ステップ関数による決定的識別 00090 float net; 00091 int* result_label_count = new int[n_class]; 00092 00093 // 元のラベルを退避 00094 int* tmp_label = new int[n_sample]; 00095 memcpy(tmp_label,label, sizeof(int) * n_sample); 00096 00097 int tmp_ci, tmp_cj; 00098 memset(result_label_count, 0, sizeof(int) * n_class); 00099 for (int ci = 0; ci < n_class; ci++) { 00100 for (int cj = ci + 1; cj < n_class; cj++) { 00101 00102 // インデックスをi < jに 00103 tmp_ci = ci; 00104 tmp_cj = cj; 00105 if ( ci > cj ) { 00106 tmp_cj = ci; 00107 tmp_ci = cj; 00108 } 00109 // 係数とラベルを取得し,ci,cjを識別するSVMを構成 00110 memcpy(this->alpha, &(MATRIX_AT(mc_alpha, n_sample, INX_KSVM_IJ(n_class,tmp_ci,tmp_cj), 0)), sizeof(float) * n_sample); 00111 memcpy(this->label, &(MATRIX_AT(mc_label, n_sample, INX_KSVM_IJ(n_class,tmp_ci,tmp_cj), 0)), sizeof(float) * n_sample); 00112 00113 // 識別:識別されたクラスに投票. 00114 net = SVM::predict_net(data); 00115 //printf("ci:%d cj:%d >> net : %f \n", ci, cj, net); 00116 if ( net < 0 ) { 00117 result_label_count[ci]++; 00118 } else if ( net >= 0 ) { 00119 result_label_count[cj]++; 00120 } 00121 00122 } 00123 //printf("sum_net[%d] : %f \n", ci, sum_net[ci]); 00124 } 00125 00126 // 判定:最大頻度のクラスに判定する. 00127 int max,argmax; 00128 max = 0; 00129 for (int i = 0; i < n_class; i++) { 00130 //printf("%d : %d \n", i, result_label_count[i]); 00131 if ( result_label_count[i] > max ) { 00132 max = result_label_count[i]; 00133 argmax = i; 00134 } 00135 } 00136 // 元のラベルを復帰 00137 memcpy(this->label, tmp_label, sizeof(int) * n_sample); 00138 delete [] tmp_label; 00139 delete [] result_label_count; 00140 00141 return argmax; 00142 00143 } 00144 00145 // 未知データの識別確率を推定する. 00146 float MCSVM::predict_probability(float* data) 00147 { 00148 // シグモイド関数による確率的識別 00149 float prob; 00150 float* result_label_prob = new float[n_class]; 00151 int tmp_ci, tmp_cj; 00152 00153 memset(result_label_prob, 0, sizeof(float) * n_class); 00154 00155 // 元のラベルを退避 00156 int* tmp_label = new int[n_sample]; 00157 memcpy(tmp_label,label, sizeof(int) * n_sample); 00158 00159 for (int ci = 0; ci < n_class; ci++) { 00160 for (int cj = ci + 1; cj < n_class; cj++) { 00161 00162 // インデックスをci < cjに : 負例はci, 正例はcj 00163 tmp_ci = ci; 00164 tmp_cj = cj; 00165 if ( ci > cj ) { 00166 tmp_cj = ci; 00167 tmp_ci = cj; 00168 } 00169 // 係数とラベルを取得し,ci,cjを識別するSVMを構成 00170 memcpy(this->alpha, &(MATRIX_AT(mc_alpha, n_sample, INX_KSVM_IJ(n_class,tmp_ci,tmp_cj), 0)), sizeof(float) * n_sample); 00171 memcpy(this->label, &(MATRIX_AT(mc_label, n_sample, INX_KSVM_IJ(n_class,tmp_ci,tmp_cj), 0)), sizeof(float) * n_sample); 00172 00173 // 確率識別:確率の足し上げ 00174 prob = SVM::predict_probability(data); 00175 if ( prob > float(0.5) ) { 00176 result_label_prob[cj] += prob; 00177 } else { 00178 result_label_prob[ci] += (1-prob); 00179 } 00180 00181 } 00182 //printf("sum_net[%d] : %f \n", ci, sum_net[ci]); 00183 } 00184 00185 // 判定:最大確率和 00186 // おそらくラベル識別との整合性は取れる...はず 00187 float max = 0; 00188 for (int i = 0; i < n_class; i++) { 00189 //printf("%d : %d \n", i, result_label_count[i]); 00190 if ( result_label_prob[i] > max ) { 00191 max = result_label_prob[i]; 00192 } 00193 } 00194 00195 // 元のラベルを復帰 00196 memcpy(this->label, tmp_label, sizeof(int) * n_sample); 00197 delete [] tmp_label; 00198 delete [] result_label_prob; 00199 00200 // 平均確率を返す. 00201 return (max / (n_class-1)); 00202 00203 } 00204 00205 // override 00206 float* MCSVM::get_alpha(void) 00207 { 00208 return (float *)mc_alpha; 00209 } 00210 00211 // override 00212 void MCSVM::set_alpha(float* mcalpha_data, int nsample, int nclass) 00213 { 00214 if ( nsample != n_sample ) { 00215 fprintf( stderr, " set_alpha : number of sample isn't match : n_samle= %d, arg= %d \r\n", n_sample, nsample); 00216 return; 00217 } else if ( nclass != n_class ) { 00218 fprintf( stderr, " set_alpha : number of class isn't match : n_class= %d, nclass= %d \r\n", n_class, nclass); 00219 return; 00220 } 00221 memcpy(mc_alpha, mcalpha_data, sizeof(float) * n_sample * n_class * (n_class - 1) / 2); 00222 status = SVM_SET_ALPHA; 00223 }
Generated on Thu Jul 21 2022 16:34:03 by 1.7.2