Taiyo Mineo / SVM

Dependents:   WeatherPredictor

Committer:
yukari_hinata
Date:
Thu Jan 15 08:22:02 2015 +0000
Revision:
0:3f38e74a4a77
Child:
2:c4a5251cee32
first commit

Who changed what in which revision?

UserRevisionLine numberNew contents of line
yukari_hinata 0:3f38e74a4a77 1 #ifndef SVM_H_INCLUDED
yukari_hinata 0:3f38e74a4a77 2 #define SVM_H_INCLUDED
yukari_hinata 0:3f38e74a4a77 3
yukari_hinata 0:3f38e74a4a77 4 #include "mbed.h"
yukari_hinata 0:3f38e74a4a77 5
yukari_hinata 0:3f38e74a4a77 6 #include "./util/util.hpp"
yukari_hinata 0:3f38e74a4a77 7
yukari_hinata 0:3f38e74a4a77 8 // SVMの学習状態
yukari_hinata 0:3f38e74a4a77 9 typedef enum {
yukari_hinata 0:3f38e74a4a77 10 SVM_NOT_LEARN, // 学習していない
yukari_hinata 0:3f38e74a4a77 11 SVM_LEARN_SUCCESS, // 正常終了(収束によって終了した)
yukari_hinata 0:3f38e74a4a77 12 SVM_NOT_CONVERGENCED, // 繰り返し上限到達
yukari_hinata 0:3f38e74a4a77 13 SVM_DETECT_BAD_VAL, // 係数で非数/無限を検知した
yukari_hinata 0:3f38e74a4a77 14 SVM_SET_ALPHA, // 学習していないが,最適化した係数がセットされている
yukari_hinata 0:3f38e74a4a77 15 } SVM_STATUS;
yukari_hinata 0:3f38e74a4a77 16
yukari_hinata 0:3f38e74a4a77 17 // Class SVM
yukari_hinata 0:3f38e74a4a77 18
yukari_hinata 0:3f38e74a4a77 19 class SVM
yukari_hinata 0:3f38e74a4a77 20 {
yukari_hinata 0:3f38e74a4a77 21 protected:
yukari_hinata 0:3f38e74a4a77 22
yukari_hinata 0:3f38e74a4a77 23 int dim_signal; // 入力データの次元
yukari_hinata 0:3f38e74a4a77 24 int n_sample; // サンプルの個数
yukari_hinata 0:3f38e74a4a77 25 float* sample_max; // 特徴の正規化用の,各次元の最大値
yukari_hinata 0:3f38e74a4a77 26 float* sample_min; // 特徴の正規化用の,各次元の最小値
yukari_hinata 0:3f38e74a4a77 27 float* alpha; // 双対係数のベクトル
yukari_hinata 0:3f38e74a4a77 28 // float* grammat; // グラム(カーネル)行列 -> 領域をn_sample^2食うので,廃止.時間はかかるけど逐次処理.
yukari_hinata 0:3f38e74a4a77 29 int maxIteration; // 学習の最大繰り返し回数
yukari_hinata 0:3f38e74a4a77 30 float epsilon; // 収束判定用の小さな値
yukari_hinata 0:3f38e74a4a77 31 float eta; // 学習係数
yukari_hinata 0:3f38e74a4a77 32 float learn_alpha; // 慣性項の係数
yukari_hinata 0:3f38e74a4a77 33 float C1; // 1ノルムソフトマージンンのスラック変数とハードマージンのトレード
yukari_hinata 0:3f38e74a4a77 34 // オフを与える係数 (無限でハードマージンに一致,FLT_MAXで近似)
yukari_hinata 0:3f38e74a4a77 35 float C2; // 2ノルムソフトマージンの〜,(無限でハードマージンに一致,FLT_MAXで近似)
yukari_hinata 0:3f38e74a4a77 36 // また,C1かC2どちらか一つのみを設定すること.
yukari_hinata 0:3f38e74a4a77 37
yukari_hinata 0:3f38e74a4a77 38 public:
yukari_hinata 0:3f38e74a4a77 39 int* label; // サンプルの2値ラベルの配列(-1 or 1)
yukari_hinata 0:3f38e74a4a77 40 float* sample; // n次元サンプルデータの配列.
yukari_hinata 0:3f38e74a4a77 41 int status; // SVMの状態
yukari_hinata 0:3f38e74a4a77 42
yukari_hinata 0:3f38e74a4a77 43 protected:
yukari_hinata 0:3f38e74a4a77 44 // カーネル関数. ここでは簡易なRBFカーネルをハードコーディング
yukari_hinata 0:3f38e74a4a77 45 inline float kernel_function(float *x, float *y, int n) {
yukari_hinata 0:3f38e74a4a77 46 register float inprod = 0;
yukari_hinata 0:3f38e74a4a77 47 for (int i=0;i < n;i++) {
yukari_hinata 0:3f38e74a4a77 48 inprod += powf(x[i] - y[i],2);
yukari_hinata 0:3f38e74a4a77 49 //printf("x[%d] : %f y[%d] : %f \n", i, x[i], i, y[i]);
yukari_hinata 0:3f38e74a4a77 50 }
yukari_hinata 0:3f38e74a4a77 51 return expf(-inprod);
yukari_hinata 0:3f38e74a4a77 52 }
yukari_hinata 0:3f38e74a4a77 53
yukari_hinata 0:3f38e74a4a77 54 public:
yukari_hinata 0:3f38e74a4a77 55 // 最小の引数による初期化. 順次拡張予定.
yukari_hinata 0:3f38e74a4a77 56 SVM(int, // データ次元
yukari_hinata 0:3f38e74a4a77 57 int, // サンプル個数
yukari_hinata 0:3f38e74a4a77 58 float*, // サンプル
yukari_hinata 0:3f38e74a4a77 59 int*); // ラベル
yukari_hinata 0:3f38e74a4a77 60
yukari_hinata 0:3f38e74a4a77 61 ~SVM(void);
yukari_hinata 0:3f38e74a4a77 62
yukari_hinata 0:3f38e74a4a77 63 // 学習によりマージンを最大化し,サポートベクトルを確定させる.
yukari_hinata 0:3f38e74a4a77 64 virtual int learning(void);
yukari_hinata 0:3f38e74a4a77 65
yukari_hinata 0:3f38e74a4a77 66 // 未知データの識別.データを受け取り,識別ラベルを-1 or 1で返す.
yukari_hinata 0:3f38e74a4a77 67 virtual int predict_label(float*);
yukari_hinata 0:3f38e74a4a77 68
yukari_hinata 0:3f38e74a4a77 69 // 未知データのネットワーク値(負の値ならば0,即ち識別面の下半空間に,
yukari_hinata 0:3f38e74a4a77 70 // 正の値ならば1,識別面の上半空間に存在すると判定)を計算して返す.
yukari_hinata 0:3f38e74a4a77 71 float predict_net(float*);
yukari_hinata 0:3f38e74a4a77 72
yukari_hinata 0:3f38e74a4a77 73 // 未知データの正例の識別確率[0,1]を返す.
yukari_hinata 0:3f38e74a4a77 74 // 予測はシグモイド関数による.
yukari_hinata 0:3f38e74a4a77 75 // 1ならばマージンを超えて完全に正例領域に入っている.
yukari_hinata 0:3f38e74a4a77 76 // 0ならばマージンを超えて完全に負例領域に入っている.
yukari_hinata 0:3f38e74a4a77 77 virtual float predict_probability(float*);
yukari_hinata 0:3f38e74a4a77 78
yukari_hinata 0:3f38e74a4a77 79 // 双対係数のゲッター
yukari_hinata 0:3f38e74a4a77 80 virtual float* get_alpha(void);
yukari_hinata 0:3f38e74a4a77 81
yukari_hinata 0:3f38e74a4a77 82 // 双対係数のセッター
yukari_hinata 0:3f38e74a4a77 83 virtual void set_alpha(float*, int);
yukari_hinata 0:3f38e74a4a77 84
yukari_hinata 0:3f38e74a4a77 85 };
yukari_hinata 0:3f38e74a4a77 86
yukari_hinata 0:3f38e74a4a77 87 #endif /* SVM_H_INCLUDED */