Easy Support Vector Machine

Dependents:   WeatherPredictor

Revision:
0:3f38e74a4a77
Child:
2:c4a5251cee32
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/SVM.hpp	Thu Jan 15 08:22:02 2015 +0000
@@ -0,0 +1,87 @@
+#ifndef SVM_H_INCLUDED
+#define SVM_H_INCLUDED
+
+#include "mbed.h"
+
+#include "./util/util.hpp"
+
+// SVMの学習状態
+typedef enum {
+  SVM_NOT_LEARN,        // 学習していない
+  SVM_LEARN_SUCCESS,    // 正常終了(収束によって終了した)
+  SVM_NOT_CONVERGENCED, // 繰り返し上限到達
+  SVM_DETECT_BAD_VAL,   // 係数で非数/無限を検知した
+  SVM_SET_ALPHA,        // 学習していないが,最適化した係数がセットされている
+} SVM_STATUS;
+
+// Class SVM
+
+class SVM
+{
+  protected:
+
+    int dim_signal;     // 入力データの次元
+    int n_sample;       // サンプルの個数
+    float* sample_max;  // 特徴の正規化用の,各次元の最大値
+    float* sample_min;  // 特徴の正規化用の,各次元の最小値
+    float* alpha;       // 双対係数のベクトル
+    // float* grammat;     // グラム(カーネル)行列 -> 領域をn_sample^2食うので,廃止.時間はかかるけど逐次処理. 
+    int maxIteration;   // 学習の最大繰り返し回数
+    float epsilon;      // 収束判定用の小さな値
+    float eta;          // 学習係数
+    float learn_alpha;  // 慣性項の係数
+    float C1;           // 1ノルムソフトマージンンのスラック変数とハードマージンのトレード
+                        // オフを与える係数 (無限でハードマージンに一致,FLT_MAXで近似)
+    float C2;           // 2ノルムソフトマージンの〜,(無限でハードマージンに一致,FLT_MAXで近似)
+                        // また,C1かC2どちらか一つのみを設定すること.
+
+  public:
+    int*   label;       // サンプルの2値ラベルの配列(-1 or 1)
+    float* sample;      // n次元サンプルデータの配列.
+    int    status;      // SVMの状態
+
+  protected:
+    // カーネル関数. ここでは簡易なRBFカーネルをハードコーディング
+    inline float kernel_function(float *x, float *y, int n) {
+      register float inprod = 0;
+      for (int i=0;i < n;i++) {
+        inprod += powf(x[i] - y[i],2);
+        //printf("x[%d] : %f y[%d] : %f \n", i, x[i], i, y[i]);
+      }
+      return expf(-inprod);
+    }
+
+  public:
+    // 最小の引数による初期化. 順次拡張予定.
+    SVM(int,      // データ次元
+        int,      // サンプル個数
+        float*,   // サンプル
+        int*);    // ラベル
+        
+    ~SVM(void);
+
+    // 学習によりマージンを最大化し,サポートベクトルを確定させる.
+    virtual int learning(void);
+
+    // 未知データの識別.データを受け取り,識別ラベルを-1 or 1で返す.
+    virtual int predict_label(float*);
+
+    // 未知データのネットワーク値(負の値ならば0,即ち識別面の下半空間に,
+    // 正の値ならば1,識別面の上半空間に存在すると判定)を計算して返す.
+    float predict_net(float*);
+    
+    // 未知データの正例の識別確率[0,1]を返す.
+    // 予測はシグモイド関数による.
+    // 1ならばマージンを超えて完全に正例領域に入っている.
+    // 0ならばマージンを超えて完全に負例領域に入っている.
+    virtual float predict_probability(float*);
+    
+    // 双対係数のゲッター
+    virtual float* get_alpha(void);
+    
+    // 双対係数のセッター
+    virtual void set_alpha(float*, int);
+
+};
+
+#endif /* SVM_H_INCLUDED */
\ No newline at end of file