Easy Support Vector Machine

Dependents:   WeatherPredictor

Embed: (wiki syntax)

« Back to documentation index

Show/hide line numbers SVM.hpp Source File

SVM.hpp

00001 /* Support Vector Machine by hinata_yukari */
00002 
00003 #ifndef SVM_H_INCLUDED
00004 #define SVM_H_INCLUDED
00005 
00006 #include "mbed.h"
00007 
00008 #include "../ml_util/ml_util.hpp"
00009 
00010 // SVMの学習状態
00011 typedef enum {
00012   SVM_NOT_LEARN,        // 学習していない
00013   SVM_LEARN_SUCCESS,    // 正常終了(収束によって終了した)
00014   SVM_NOT_CONVERGENCED, // 繰り返し上限到達
00015   SVM_DETECT_BAD_VAL,   // 係数で非数/無限を検知した
00016   SVM_SET_ALPHA,        // 学習していないが,最適化した係数がセットされている
00017 } SVM_STATUS;
00018 
00019 // Class SVM
00020 
00021 class SVM
00022 {
00023   protected:
00024 
00025     int dim_signal;     // 入力データの次元
00026     int n_sample;       // サンプルの個数
00027     float* sample_max;  // 特徴の正規化用の,各次元の最大値
00028     float* sample_min;  // 特徴の正規化用の,各次元の最小値
00029     float* alpha;       // 双対係数のベクトル
00030     // float* grammat;     // グラム(カーネル)行列 -> 領域をn_sample^2食うので,廃止.時間はかかるけど逐次処理. 
00031     int maxIteration;   // 学習の最大繰り返し回数
00032     float epsilon;      // 収束判定用の小さな値
00033     float eta;          // 学習係数
00034     float learn_alpha;  // 慣性項の係数
00035     float C1;           // 1ノルムソフトマージンンのスラック変数とハードマージンのトレード
00036                         // オフを与える係数 (無限でハードマージンに一致,FLT_MAXで近似)
00037     float C2;           // 2ノルムソフトマージンの〜,(無限でハードマージンに一致,FLT_MAXで近似)
00038                         // また,C1かC2どちらか一つのみを設定すること.
00039 
00040   public:
00041     int*   label;       // サンプルの2値ラベルの配列(-1 or 1)
00042     float* sample;      // n次元サンプルデータの配列.
00043     int    status;      // SVMの状態
00044 
00045   protected:
00046     // カーネル関数. ここでは簡易なRBFカーネルをハードコーディング
00047     inline float kernel_function(float *x, float *y, int n) {
00048       register float inprod = 0;
00049       for (int i=0;i < n;i++) {
00050         inprod += powf(x[i] - y[i],2);
00051         //printf("x[%d] : %f y[%d] : %f \n", i, x[i], i, y[i]);
00052       }
00053       return expf(-inprod/0.1);
00054     }
00055 
00056   public:
00057     // 最小の引数による初期化. 順次拡張予定.
00058     SVM(int,      // データ次元
00059         int,      // サンプル個数
00060         float*,   // サンプル
00061         int*);    // ラベル
00062         
00063     ~SVM(void);
00064 
00065     // 学習によりマージンを最大化し,サポートベクトルを確定させる.
00066     virtual int learning(void);
00067 
00068     // 未知データの識別.データを受け取り,識別ラベルを-1 or 1で返す.
00069     virtual int predict_label(float*);
00070 
00071     // 未知データのネットワーク値(負の値ならば0,即ち識別面の下半空間に,
00072     // 正の値ならば1,識別面の上半空間に存在すると判定)を計算して返す.
00073     float predict_net(float*);
00074     
00075     // 未知データの正例の識別確率[0,1]を返す.
00076     // 予測はシグモイド関数による.
00077     // 1ならばマージンを超えて完全に正例領域に入っている.
00078     // 0ならばマージンを超えて完全に負例領域に入っている.
00079     virtual float predict_probability(float*);
00080     
00081     // 双対係数のゲッター
00082     virtual float* get_alpha(void);
00083     
00084     // 双対係数のセッター
00085     void set_alpha(float*, int);
00086 
00087 };
00088 
00089 #endif /* SVM_H_INCLUDED */