Easy Support Vector Machine

Dependents:   WeatherPredictor

Embed: (wiki syntax)

« Back to documentation index

Show/hide line numbers MCSVM.cpp Source File

MCSVM.cpp

00001 #include "MCSVM.hpp"
00002 
00003 // コンストラクタ. 適宜追加予定
00004 MCSVM::MCSVM(int    class_n,
00005              int    sample_dim,
00006              int    sample_n,
00007              float* sample_data,
00008              int*   sample_mclabel)
00009     : SVM(sample_dim, sample_n, sample_data, sample_mclabel)
00010 {
00011     this->n_class = class_n;
00012     this->maxFailcount = 5;
00013 
00014     int n_kC2 = (n_class) * (n_class - 1) / 2;
00015 
00016     mc_alpha   = new float[n_sample * n_kC2];
00017     mc_label   = new int[n_sample * n_kC2];
00018 
00019     // ラベルの生成
00020     int tmp_lab;
00021     for (int ci = 0; ci < n_class; ci++) {
00022         for (int cj = ci + 1; cj < n_class; cj++) {
00023             // i < jであることから, ラベルiは負例, ラベルjは正例に割り当てる.
00024             // いずれのラベルにも該当しないデータは欠損とし,ラベル0とする.
00025             for (int l=0; l < n_sample; l++) {
00026                 if (this->label[l] == ci) {
00027                     tmp_lab = -1;
00028                 } else if (this->label[l] == cj) {
00029                     tmp_lab = 1;
00030                 } else {
00031                     tmp_lab = 0;
00032                 }
00033                 MATRIX_AT(mc_label,n_sample,INX_KSVM_IJ(n_class,ci,cj),l) = tmp_lab;
00034                 // printf("%d : %d -> %d \r\n", l, label[l], tmp_lab);
00035             }
00036 
00037         }
00038     }
00039 
00040 }
00041 
00042 // 領域開放
00043 MCSVM::~MCSVM(void)
00044 {
00045     delete [] mc_alpha;
00046     delete [] mc_label;
00047 }
00048 
00049 // 全SVMの学習.
00050 int MCSVM::learning(void)
00051 {
00052     int status, fail_count;
00053     int* tmp_label = new int[n_sample];
00054     // 元のラベルを退避
00055     memcpy(tmp_label,label, sizeof(int) * n_sample);
00056     for (int ci = 0; ci < n_class; ci++) {
00057         for (int cj = ci + 1; cj < n_class; cj++) {
00058             // 2値ラベルを取得する.
00059             memcpy(label, &(MATRIX_AT(mc_label, n_sample, INX_KSVM_IJ(n_class,ci,cj), 0)), sizeof(int) * n_sample);
00060             // 学習 - 学習失敗の場合はリトライする.
00061             fail_count = 0;
00062             do {
00063                 if (fail_count >= maxFailcount) {
00064                     fprintf(stderr, "Learning failed %d times at %d,%d classifier, give up \r\n", fail_count, ci, cj);
00065                     return MCSVM_NOT_LEARN;
00066                 }
00067                 status = SVM::learning();
00068 
00069                 if ( (status == SVM_NOT_CONVERGENCED)
00070                         || (status == SVM_DETECT_BAD_VAL) ) {
00071                     fail_count++;
00072                 }
00073             } while (status != SVM_LEARN_SUCCESS);
00074             // 学習結果の係数とSVラベルの取得
00075             memcpy(&(MATRIX_AT(mc_alpha, n_sample, INX_KSVM_IJ(n_class,ci,cj), 0)), this->alpha, sizeof(float) * n_sample);
00076         }
00077     }
00078 
00079     // 元のラベルを復帰
00080     memcpy(this->label, tmp_label, sizeof(int) * n_sample);
00081     delete [] tmp_label;
00082     return MCSVM_LEARN_SUCCESS;
00083 
00084 }
00085 
00086 // 未知データのラベルを推定する.
00087 int MCSVM::predict_label(float* data)
00088 {
00089     // 単位ステップ関数による決定的識別
00090     float net;
00091     int* result_label_count = new int[n_class];
00092 
00093     // 元のラベルを退避
00094     int* tmp_label = new int[n_sample];
00095     memcpy(tmp_label,label, sizeof(int) * n_sample);
00096 
00097     int tmp_ci, tmp_cj;
00098     memset(result_label_count, 0, sizeof(int) * n_class);
00099     for (int ci = 0; ci < n_class; ci++) {
00100         for (int cj = ci + 1; cj < n_class; cj++) {
00101 
00102             // インデックスをi < jに
00103             tmp_ci = ci;
00104             tmp_cj = cj;
00105             if ( ci > cj ) {
00106                 tmp_cj = ci;
00107                 tmp_ci = cj;
00108             }
00109             // 係数とラベルを取得し,ci,cjを識別するSVMを構成
00110             memcpy(this->alpha, &(MATRIX_AT(mc_alpha, n_sample, INX_KSVM_IJ(n_class,tmp_ci,tmp_cj), 0)), sizeof(float) * n_sample);
00111             memcpy(this->label, &(MATRIX_AT(mc_label, n_sample, INX_KSVM_IJ(n_class,tmp_ci,tmp_cj), 0)), sizeof(float) * n_sample);
00112 
00113             // 識別:識別されたクラスに投票.
00114             net = SVM::predict_net(data);
00115             //printf("ci:%d cj:%d >> net : %f \n", ci, cj, net);
00116             if ( net < 0 ) {
00117                 result_label_count[ci]++;
00118             } else if ( net >= 0 ) {
00119                 result_label_count[cj]++;
00120             }
00121 
00122         }
00123         //printf("sum_net[%d] : %f \n", ci, sum_net[ci]);
00124     }
00125 
00126     // 判定:最大頻度のクラスに判定する.
00127     int max,argmax;
00128     max = 0;
00129     for (int i = 0; i < n_class; i++) {
00130         //printf("%d : %d \n", i, result_label_count[i]);
00131         if ( result_label_count[i] > max ) {
00132             max = result_label_count[i];
00133             argmax = i;
00134         }
00135     }
00136     // 元のラベルを復帰
00137     memcpy(this->label, tmp_label, sizeof(int) * n_sample);
00138     delete [] tmp_label;
00139     delete [] result_label_count;
00140 
00141     return argmax;
00142 
00143 }
00144 
00145 // 未知データの識別確率を推定する.
00146 float MCSVM::predict_probability(float* data)
00147 {
00148     // シグモイド関数による確率的識別
00149     float prob;
00150     float* result_label_prob = new float[n_class];
00151     int tmp_ci, tmp_cj;
00152 
00153     memset(result_label_prob, 0, sizeof(float) * n_class);
00154 
00155     // 元のラベルを退避
00156     int* tmp_label = new int[n_sample];
00157     memcpy(tmp_label,label, sizeof(int) * n_sample);
00158 
00159     for (int ci = 0; ci < n_class; ci++) {
00160         for (int cj = ci + 1; cj < n_class; cj++) {
00161 
00162             // インデックスをci < cjに : 負例はci, 正例はcj
00163             tmp_ci = ci;
00164             tmp_cj = cj;
00165             if ( ci > cj ) {
00166                 tmp_cj = ci;
00167                 tmp_ci = cj;
00168             }
00169             // 係数とラベルを取得し,ci,cjを識別するSVMを構成
00170             memcpy(this->alpha, &(MATRIX_AT(mc_alpha, n_sample, INX_KSVM_IJ(n_class,tmp_ci,tmp_cj), 0)), sizeof(float) * n_sample);
00171             memcpy(this->label, &(MATRIX_AT(mc_label, n_sample, INX_KSVM_IJ(n_class,tmp_ci,tmp_cj), 0)), sizeof(float) * n_sample);
00172 
00173             // 確率識別:確率の足し上げ
00174             prob = SVM::predict_probability(data);
00175             if ( prob > float(0.5) ) {
00176                 result_label_prob[cj] += prob;
00177             } else {
00178                 result_label_prob[ci] += (1-prob);
00179             }
00180 
00181         }
00182         //printf("sum_net[%d] : %f \n", ci, sum_net[ci]);
00183     }
00184 
00185     // 判定:最大確率和
00186     // おそらくラベル識別との整合性は取れる...はず
00187     float max = 0;
00188     for (int i = 0; i < n_class; i++) {
00189         //printf("%d : %d \n", i, result_label_count[i]);
00190         if ( result_label_prob[i] > max ) {
00191             max = result_label_prob[i];
00192         }
00193     }
00194 
00195     // 元のラベルを復帰
00196     memcpy(this->label, tmp_label, sizeof(int) * n_sample);
00197     delete [] tmp_label;
00198     delete [] result_label_prob;
00199 
00200     // 平均確率を返す.
00201     return (max / (n_class-1));
00202 
00203 }
00204 
00205 // override
00206 float* MCSVM::get_alpha(void)
00207 {
00208     return (float *)mc_alpha;
00209 }
00210 
00211 // override
00212 void MCSVM::set_alpha(float* mcalpha_data, int nsample, int nclass)
00213 {
00214     if ( nsample != n_sample ) {
00215         fprintf( stderr, " set_alpha : number of sample isn't match : n_samle= %d, arg= %d \r\n", n_sample, nsample);
00216         return;
00217     } else if ( nclass != n_class ) {
00218         fprintf( stderr, " set_alpha : number of class isn't match : n_class= %d, nclass= %d \r\n", n_class, nclass);
00219         return;
00220     }
00221     memcpy(mc_alpha, mcalpha_data, sizeof(float) * n_sample * n_class * (n_class - 1) / 2);
00222     status = SVM_SET_ALPHA;
00223 }