Easy Support Vector Machine

Dependents:   WeatherPredictor

SVM.hpp

Committer:
yukari_hinata
Date:
2015-02-18
Revision:
5:792afbb0bcf3
Parent:
2:c4a5251cee32

File content as of revision 5:792afbb0bcf3:

/* Support Vector Machine by hinata_yukari */

#ifndef SVM_H_INCLUDED
#define SVM_H_INCLUDED

#include "mbed.h"

#include "../ml_util/ml_util.hpp"

// SVMの学習状態
typedef enum {
  SVM_NOT_LEARN,        // 学習していない
  SVM_LEARN_SUCCESS,    // 正常終了(収束によって終了した)
  SVM_NOT_CONVERGENCED, // 繰り返し上限到達
  SVM_DETECT_BAD_VAL,   // 係数で非数/無限を検知した
  SVM_SET_ALPHA,        // 学習していないが,最適化した係数がセットされている
} SVM_STATUS;

// Class SVM

class SVM
{
  protected:

    int dim_signal;     // 入力データの次元
    int n_sample;       // サンプルの個数
    float* sample_max;  // 特徴の正規化用の,各次元の最大値
    float* sample_min;  // 特徴の正規化用の,各次元の最小値
    float* alpha;       // 双対係数のベクトル
    // float* grammat;     // グラム(カーネル)行列 -> 領域をn_sample^2食うので,廃止.時間はかかるけど逐次処理. 
    int maxIteration;   // 学習の最大繰り返し回数
    float epsilon;      // 収束判定用の小さな値
    float eta;          // 学習係数
    float learn_alpha;  // 慣性項の係数
    float C1;           // 1ノルムソフトマージンンのスラック変数とハードマージンのトレード
                        // オフを与える係数 (無限でハードマージンに一致,FLT_MAXで近似)
    float C2;           // 2ノルムソフトマージンの〜,(無限でハードマージンに一致,FLT_MAXで近似)
                        // また,C1かC2どちらか一つのみを設定すること.

  public:
    int*   label;       // サンプルの2値ラベルの配列(-1 or 1)
    float* sample;      // n次元サンプルデータの配列.
    int    status;      // SVMの状態

  protected:
    // カーネル関数. ここでは簡易なRBFカーネルをハードコーディング
    inline float kernel_function(float *x, float *y, int n) {
      register float inprod = 0;
      for (int i=0;i < n;i++) {
        inprod += powf(x[i] - y[i],2);
        //printf("x[%d] : %f y[%d] : %f \n", i, x[i], i, y[i]);
      }
      return expf(-inprod/0.1);
    }

  public:
    // 最小の引数による初期化. 順次拡張予定.
    SVM(int,      // データ次元
        int,      // サンプル個数
        float*,   // サンプル
        int*);    // ラベル
        
    ~SVM(void);

    // 学習によりマージンを最大化し,サポートベクトルを確定させる.
    virtual int learning(void);

    // 未知データの識別.データを受け取り,識別ラベルを-1 or 1で返す.
    virtual int predict_label(float*);

    // 未知データのネットワーク値(負の値ならば0,即ち識別面の下半空間に,
    // 正の値ならば1,識別面の上半空間に存在すると判定)を計算して返す.
    float predict_net(float*);
    
    // 未知データの正例の識別確率[0,1]を返す.
    // 予測はシグモイド関数による.
    // 1ならばマージンを超えて完全に正例領域に入っている.
    // 0ならばマージンを超えて完全に負例領域に入っている.
    virtual float predict_probability(float*);
    
    // 双対係数のゲッター
    virtual float* get_alpha(void);
    
    // 双対係数のセッター
    void set_alpha(float*, int);

};

#endif /* SVM_H_INCLUDED */