Easy Support Vector Machine
SVM.hpp@5:792afbb0bcf3, 2015-02-18 (annotated)
- Committer:
- yukari_hinata
- Date:
- Wed Feb 18 15:01:12 2015 +0000
- Revision:
- 5:792afbb0bcf3
- Parent:
- 2:c4a5251cee32
Who changed what in which revision?
User | Revision | Line number | New contents of line |
---|---|---|---|
yukari_hinata | 2:c4a5251cee32 | 1 | /* Support Vector Machine by hinata_yukari */ |
yukari_hinata | 2:c4a5251cee32 | 2 | |
yukari_hinata | 0:3f38e74a4a77 | 3 | #ifndef SVM_H_INCLUDED |
yukari_hinata | 0:3f38e74a4a77 | 4 | #define SVM_H_INCLUDED |
yukari_hinata | 0:3f38e74a4a77 | 5 | |
yukari_hinata | 0:3f38e74a4a77 | 6 | #include "mbed.h" |
yukari_hinata | 0:3f38e74a4a77 | 7 | |
yukari_hinata | 5:792afbb0bcf3 | 8 | #include "../ml_util/ml_util.hpp" |
yukari_hinata | 0:3f38e74a4a77 | 9 | |
yukari_hinata | 0:3f38e74a4a77 | 10 | // SVMの学習状態 |
yukari_hinata | 0:3f38e74a4a77 | 11 | typedef enum { |
yukari_hinata | 0:3f38e74a4a77 | 12 | SVM_NOT_LEARN, // 学習していない |
yukari_hinata | 0:3f38e74a4a77 | 13 | SVM_LEARN_SUCCESS, // 正常終了(収束によって終了した) |
yukari_hinata | 0:3f38e74a4a77 | 14 | SVM_NOT_CONVERGENCED, // 繰り返し上限到達 |
yukari_hinata | 0:3f38e74a4a77 | 15 | SVM_DETECT_BAD_VAL, // 係数で非数/無限を検知した |
yukari_hinata | 0:3f38e74a4a77 | 16 | SVM_SET_ALPHA, // 学習していないが,最適化した係数がセットされている |
yukari_hinata | 0:3f38e74a4a77 | 17 | } SVM_STATUS; |
yukari_hinata | 0:3f38e74a4a77 | 18 | |
yukari_hinata | 0:3f38e74a4a77 | 19 | // Class SVM |
yukari_hinata | 0:3f38e74a4a77 | 20 | |
yukari_hinata | 0:3f38e74a4a77 | 21 | class SVM |
yukari_hinata | 0:3f38e74a4a77 | 22 | { |
yukari_hinata | 0:3f38e74a4a77 | 23 | protected: |
yukari_hinata | 0:3f38e74a4a77 | 24 | |
yukari_hinata | 0:3f38e74a4a77 | 25 | int dim_signal; // 入力データの次元 |
yukari_hinata | 0:3f38e74a4a77 | 26 | int n_sample; // サンプルの個数 |
yukari_hinata | 0:3f38e74a4a77 | 27 | float* sample_max; // 特徴の正規化用の,各次元の最大値 |
yukari_hinata | 0:3f38e74a4a77 | 28 | float* sample_min; // 特徴の正規化用の,各次元の最小値 |
yukari_hinata | 0:3f38e74a4a77 | 29 | float* alpha; // 双対係数のベクトル |
yukari_hinata | 0:3f38e74a4a77 | 30 | // float* grammat; // グラム(カーネル)行列 -> 領域をn_sample^2食うので,廃止.時間はかかるけど逐次処理. |
yukari_hinata | 0:3f38e74a4a77 | 31 | int maxIteration; // 学習の最大繰り返し回数 |
yukari_hinata | 0:3f38e74a4a77 | 32 | float epsilon; // 収束判定用の小さな値 |
yukari_hinata | 0:3f38e74a4a77 | 33 | float eta; // 学習係数 |
yukari_hinata | 0:3f38e74a4a77 | 34 | float learn_alpha; // 慣性項の係数 |
yukari_hinata | 0:3f38e74a4a77 | 35 | float C1; // 1ノルムソフトマージンンのスラック変数とハードマージンのトレード |
yukari_hinata | 0:3f38e74a4a77 | 36 | // オフを与える係数 (無限でハードマージンに一致,FLT_MAXで近似) |
yukari_hinata | 0:3f38e74a4a77 | 37 | float C2; // 2ノルムソフトマージンの〜,(無限でハードマージンに一致,FLT_MAXで近似) |
yukari_hinata | 0:3f38e74a4a77 | 38 | // また,C1かC2どちらか一つのみを設定すること. |
yukari_hinata | 0:3f38e74a4a77 | 39 | |
yukari_hinata | 0:3f38e74a4a77 | 40 | public: |
yukari_hinata | 0:3f38e74a4a77 | 41 | int* label; // サンプルの2値ラベルの配列(-1 or 1) |
yukari_hinata | 0:3f38e74a4a77 | 42 | float* sample; // n次元サンプルデータの配列. |
yukari_hinata | 0:3f38e74a4a77 | 43 | int status; // SVMの状態 |
yukari_hinata | 0:3f38e74a4a77 | 44 | |
yukari_hinata | 0:3f38e74a4a77 | 45 | protected: |
yukari_hinata | 0:3f38e74a4a77 | 46 | // カーネル関数. ここでは簡易なRBFカーネルをハードコーディング |
yukari_hinata | 0:3f38e74a4a77 | 47 | inline float kernel_function(float *x, float *y, int n) { |
yukari_hinata | 0:3f38e74a4a77 | 48 | register float inprod = 0; |
yukari_hinata | 0:3f38e74a4a77 | 49 | for (int i=0;i < n;i++) { |
yukari_hinata | 0:3f38e74a4a77 | 50 | inprod += powf(x[i] - y[i],2); |
yukari_hinata | 0:3f38e74a4a77 | 51 | //printf("x[%d] : %f y[%d] : %f \n", i, x[i], i, y[i]); |
yukari_hinata | 0:3f38e74a4a77 | 52 | } |
yukari_hinata | 2:c4a5251cee32 | 53 | return expf(-inprod/0.1); |
yukari_hinata | 0:3f38e74a4a77 | 54 | } |
yukari_hinata | 0:3f38e74a4a77 | 55 | |
yukari_hinata | 0:3f38e74a4a77 | 56 | public: |
yukari_hinata | 0:3f38e74a4a77 | 57 | // 最小の引数による初期化. 順次拡張予定. |
yukari_hinata | 0:3f38e74a4a77 | 58 | SVM(int, // データ次元 |
yukari_hinata | 0:3f38e74a4a77 | 59 | int, // サンプル個数 |
yukari_hinata | 0:3f38e74a4a77 | 60 | float*, // サンプル |
yukari_hinata | 0:3f38e74a4a77 | 61 | int*); // ラベル |
yukari_hinata | 0:3f38e74a4a77 | 62 | |
yukari_hinata | 0:3f38e74a4a77 | 63 | ~SVM(void); |
yukari_hinata | 0:3f38e74a4a77 | 64 | |
yukari_hinata | 0:3f38e74a4a77 | 65 | // 学習によりマージンを最大化し,サポートベクトルを確定させる. |
yukari_hinata | 0:3f38e74a4a77 | 66 | virtual int learning(void); |
yukari_hinata | 0:3f38e74a4a77 | 67 | |
yukari_hinata | 0:3f38e74a4a77 | 68 | // 未知データの識別.データを受け取り,識別ラベルを-1 or 1で返す. |
yukari_hinata | 0:3f38e74a4a77 | 69 | virtual int predict_label(float*); |
yukari_hinata | 0:3f38e74a4a77 | 70 | |
yukari_hinata | 0:3f38e74a4a77 | 71 | // 未知データのネットワーク値(負の値ならば0,即ち識別面の下半空間に, |
yukari_hinata | 0:3f38e74a4a77 | 72 | // 正の値ならば1,識別面の上半空間に存在すると判定)を計算して返す. |
yukari_hinata | 0:3f38e74a4a77 | 73 | float predict_net(float*); |
yukari_hinata | 0:3f38e74a4a77 | 74 | |
yukari_hinata | 0:3f38e74a4a77 | 75 | // 未知データの正例の識別確率[0,1]を返す. |
yukari_hinata | 0:3f38e74a4a77 | 76 | // 予測はシグモイド関数による. |
yukari_hinata | 0:3f38e74a4a77 | 77 | // 1ならばマージンを超えて完全に正例領域に入っている. |
yukari_hinata | 0:3f38e74a4a77 | 78 | // 0ならばマージンを超えて完全に負例領域に入っている. |
yukari_hinata | 0:3f38e74a4a77 | 79 | virtual float predict_probability(float*); |
yukari_hinata | 0:3f38e74a4a77 | 80 | |
yukari_hinata | 0:3f38e74a4a77 | 81 | // 双対係数のゲッター |
yukari_hinata | 0:3f38e74a4a77 | 82 | virtual float* get_alpha(void); |
yukari_hinata | 0:3f38e74a4a77 | 83 | |
yukari_hinata | 0:3f38e74a4a77 | 84 | // 双対係数のセッター |
yukari_hinata | 2:c4a5251cee32 | 85 | void set_alpha(float*, int); |
yukari_hinata | 0:3f38e74a4a77 | 86 | |
yukari_hinata | 0:3f38e74a4a77 | 87 | }; |
yukari_hinata | 0:3f38e74a4a77 | 88 | |
yukari_hinata | 0:3f38e74a4a77 | 89 | #endif /* SVM_H_INCLUDED */ |