Easy Support Vector Machine

Dependents:   WeatherPredictor

Committer:
yukari_hinata
Date:
Sun Feb 15 09:27:08 2015 +0000
Revision:
2:c4a5251cee32
Parent:
0:3f38e74a4a77
Child:
5:792afbb0bcf3
modified

Who changed what in which revision?

UserRevisionLine numberNew contents of line
yukari_hinata 2:c4a5251cee32 1 /* Support Vector Machine by hinata_yukari */
yukari_hinata 2:c4a5251cee32 2
yukari_hinata 0:3f38e74a4a77 3 #ifndef SVM_H_INCLUDED
yukari_hinata 0:3f38e74a4a77 4 #define SVM_H_INCLUDED
yukari_hinata 0:3f38e74a4a77 5
yukari_hinata 0:3f38e74a4a77 6 #include "mbed.h"
yukari_hinata 0:3f38e74a4a77 7
yukari_hinata 0:3f38e74a4a77 8 #include "./util/util.hpp"
yukari_hinata 0:3f38e74a4a77 9
yukari_hinata 0:3f38e74a4a77 10 // SVMの学習状態
yukari_hinata 0:3f38e74a4a77 11 typedef enum {
yukari_hinata 0:3f38e74a4a77 12 SVM_NOT_LEARN, // 学習していない
yukari_hinata 0:3f38e74a4a77 13 SVM_LEARN_SUCCESS, // 正常終了(収束によって終了した)
yukari_hinata 0:3f38e74a4a77 14 SVM_NOT_CONVERGENCED, // 繰り返し上限到達
yukari_hinata 0:3f38e74a4a77 15 SVM_DETECT_BAD_VAL, // 係数で非数/無限を検知した
yukari_hinata 0:3f38e74a4a77 16 SVM_SET_ALPHA, // 学習していないが,最適化した係数がセットされている
yukari_hinata 0:3f38e74a4a77 17 } SVM_STATUS;
yukari_hinata 0:3f38e74a4a77 18
yukari_hinata 0:3f38e74a4a77 19 // Class SVM
yukari_hinata 0:3f38e74a4a77 20
yukari_hinata 0:3f38e74a4a77 21 class SVM
yukari_hinata 0:3f38e74a4a77 22 {
yukari_hinata 0:3f38e74a4a77 23 protected:
yukari_hinata 0:3f38e74a4a77 24
yukari_hinata 0:3f38e74a4a77 25 int dim_signal; // 入力データの次元
yukari_hinata 0:3f38e74a4a77 26 int n_sample; // サンプルの個数
yukari_hinata 0:3f38e74a4a77 27 float* sample_max; // 特徴の正規化用の,各次元の最大値
yukari_hinata 0:3f38e74a4a77 28 float* sample_min; // 特徴の正規化用の,各次元の最小値
yukari_hinata 0:3f38e74a4a77 29 float* alpha; // 双対係数のベクトル
yukari_hinata 0:3f38e74a4a77 30 // float* grammat; // グラム(カーネル)行列 -> 領域をn_sample^2食うので,廃止.時間はかかるけど逐次処理.
yukari_hinata 0:3f38e74a4a77 31 int maxIteration; // 学習の最大繰り返し回数
yukari_hinata 0:3f38e74a4a77 32 float epsilon; // 収束判定用の小さな値
yukari_hinata 0:3f38e74a4a77 33 float eta; // 学習係数
yukari_hinata 0:3f38e74a4a77 34 float learn_alpha; // 慣性項の係数
yukari_hinata 0:3f38e74a4a77 35 float C1; // 1ノルムソフトマージンンのスラック変数とハードマージンのトレード
yukari_hinata 0:3f38e74a4a77 36 // オフを与える係数 (無限でハードマージンに一致,FLT_MAXで近似)
yukari_hinata 0:3f38e74a4a77 37 float C2; // 2ノルムソフトマージンの〜,(無限でハードマージンに一致,FLT_MAXで近似)
yukari_hinata 0:3f38e74a4a77 38 // また,C1かC2どちらか一つのみを設定すること.
yukari_hinata 0:3f38e74a4a77 39
yukari_hinata 0:3f38e74a4a77 40 public:
yukari_hinata 0:3f38e74a4a77 41 int* label; // サンプルの2値ラベルの配列(-1 or 1)
yukari_hinata 0:3f38e74a4a77 42 float* sample; // n次元サンプルデータの配列.
yukari_hinata 0:3f38e74a4a77 43 int status; // SVMの状態
yukari_hinata 0:3f38e74a4a77 44
yukari_hinata 0:3f38e74a4a77 45 protected:
yukari_hinata 0:3f38e74a4a77 46 // カーネル関数. ここでは簡易なRBFカーネルをハードコーディング
yukari_hinata 0:3f38e74a4a77 47 inline float kernel_function(float *x, float *y, int n) {
yukari_hinata 0:3f38e74a4a77 48 register float inprod = 0;
yukari_hinata 0:3f38e74a4a77 49 for (int i=0;i < n;i++) {
yukari_hinata 0:3f38e74a4a77 50 inprod += powf(x[i] - y[i],2);
yukari_hinata 0:3f38e74a4a77 51 //printf("x[%d] : %f y[%d] : %f \n", i, x[i], i, y[i]);
yukari_hinata 0:3f38e74a4a77 52 }
yukari_hinata 2:c4a5251cee32 53 return expf(-inprod/0.1);
yukari_hinata 0:3f38e74a4a77 54 }
yukari_hinata 0:3f38e74a4a77 55
yukari_hinata 0:3f38e74a4a77 56 public:
yukari_hinata 0:3f38e74a4a77 57 // 最小の引数による初期化. 順次拡張予定.
yukari_hinata 0:3f38e74a4a77 58 SVM(int, // データ次元
yukari_hinata 0:3f38e74a4a77 59 int, // サンプル個数
yukari_hinata 0:3f38e74a4a77 60 float*, // サンプル
yukari_hinata 0:3f38e74a4a77 61 int*); // ラベル
yukari_hinata 0:3f38e74a4a77 62
yukari_hinata 0:3f38e74a4a77 63 ~SVM(void);
yukari_hinata 0:3f38e74a4a77 64
yukari_hinata 0:3f38e74a4a77 65 // 学習によりマージンを最大化し,サポートベクトルを確定させる.
yukari_hinata 0:3f38e74a4a77 66 virtual int learning(void);
yukari_hinata 0:3f38e74a4a77 67
yukari_hinata 0:3f38e74a4a77 68 // 未知データの識別.データを受け取り,識別ラベルを-1 or 1で返す.
yukari_hinata 0:3f38e74a4a77 69 virtual int predict_label(float*);
yukari_hinata 0:3f38e74a4a77 70
yukari_hinata 0:3f38e74a4a77 71 // 未知データのネットワーク値(負の値ならば0,即ち識別面の下半空間に,
yukari_hinata 0:3f38e74a4a77 72 // 正の値ならば1,識別面の上半空間に存在すると判定)を計算して返す.
yukari_hinata 0:3f38e74a4a77 73 float predict_net(float*);
yukari_hinata 0:3f38e74a4a77 74
yukari_hinata 0:3f38e74a4a77 75 // 未知データの正例の識別確率[0,1]を返す.
yukari_hinata 0:3f38e74a4a77 76 // 予測はシグモイド関数による.
yukari_hinata 0:3f38e74a4a77 77 // 1ならばマージンを超えて完全に正例領域に入っている.
yukari_hinata 0:3f38e74a4a77 78 // 0ならばマージンを超えて完全に負例領域に入っている.
yukari_hinata 0:3f38e74a4a77 79 virtual float predict_probability(float*);
yukari_hinata 0:3f38e74a4a77 80
yukari_hinata 0:3f38e74a4a77 81 // 双対係数のゲッター
yukari_hinata 0:3f38e74a4a77 82 virtual float* get_alpha(void);
yukari_hinata 0:3f38e74a4a77 83
yukari_hinata 0:3f38e74a4a77 84 // 双対係数のセッター
yukari_hinata 2:c4a5251cee32 85 void set_alpha(float*, int);
yukari_hinata 0:3f38e74a4a77 86
yukari_hinata 0:3f38e74a4a77 87 };
yukari_hinata 0:3f38e74a4a77 88
yukari_hinata 0:3f38e74a4a77 89 #endif /* SVM_H_INCLUDED */