Easy Support Vector Machine

Dependents:   WeatherPredictor

Revision:
0:3f38e74a4a77
Child:
2:c4a5251cee32
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/MCSVM.hpp	Thu Jan 15 08:22:02 2015 +0000
@@ -0,0 +1,57 @@
+#ifndef MCSVM_H_INCLUDED
+#define MCSVM_H_INCLUDED
+
+#include "SVM.hpp"
+
+// MCSVMの学習状態
+typedef enum {
+  MCSVM_NOT_LEARN,        // 学習していない
+  MCSVM_LEARN_SUCCESS,    // 正常に全SVMを学習した.
+} MCSVM_STATUS;
+
+// クラスi,jを識別するSVMのインデックスを返す
+#define INX_KSVM_IJ(n_class,i,j) ((((i) * ((2 * (n_class)) - (3) - (i)))/2) + ((j) - (1)))
+
+// Class Multi-Class SVM
+// One-vs-One法(うーんこの)によるSVM.
+
+class MCSVM : public SVM
+{
+  private:
+    int    n_class;      // 識別クラス数
+                         // 異なる2クラスi,j(i < j)を識別するSVMをインデックスi*(2k-3-i)/2 + j-1で参照する
+    int    maxFailcount; // 学習失敗を許容する最大回数
+    
+    // マルチクラス用の拡張:サンプルは共有し, 各SVMのパラメタを個別に保持
+    float* mc_alpha;     // 各識別用の双対係数
+    int*   mc_label;     // 各識別用の2値(-1,1)ラベル, 識別に関係しないデータにはラベル0が付与される.
+                         // マルチクラス識別の場合,SVM::labelには0,...,n_class-1までのラベルが付いている
+
+  public:
+    MCSVM(int,      // クラス個数
+          int,      // データ次元
+          int,      // サンプル個数
+          float*,   // サンプルデータ
+          int*);    // マルチクラスラベル
+          
+    ~MCSVM(void);
+          
+    // 未知データのラベルを推定する.返り値はマルチクラスラベル0,...,n_class-1
+    int predict_label(float*);
+    
+    // 未知データの識別確率を推定する.
+    // ラベル識別predict_label結果の整合性を考えない.
+    float predict_probability(float*);
+
+    // 全てのSVMの学習する.
+    int learning(void);
+    
+    // 双対係数のゲッター
+    float* get_alpha(void);
+    
+    // 双対係数のセッター
+    void set_alpha(float*, int, int);
+                        
+};
+
+#endif /* MCSVM_H_INCLUDED */