Easy Support Vector Machine
Embed:
(wiki syntax)
Show/hide line numbers
SVM.hpp
00001 /* Support Vector Machine by hinata_yukari */ 00002 00003 #ifndef SVM_H_INCLUDED 00004 #define SVM_H_INCLUDED 00005 00006 #include "mbed.h" 00007 00008 #include "../ml_util/ml_util.hpp" 00009 00010 // SVMの学習状態 00011 typedef enum { 00012 SVM_NOT_LEARN, // 学習していない 00013 SVM_LEARN_SUCCESS, // 正常終了(収束によって終了した) 00014 SVM_NOT_CONVERGENCED, // 繰り返し上限到達 00015 SVM_DETECT_BAD_VAL, // 係数で非数/無限を検知した 00016 SVM_SET_ALPHA, // 学習していないが,最適化した係数がセットされている 00017 } SVM_STATUS; 00018 00019 // Class SVM 00020 00021 class SVM 00022 { 00023 protected: 00024 00025 int dim_signal; // 入力データの次元 00026 int n_sample; // サンプルの個数 00027 float* sample_max; // 特徴の正規化用の,各次元の最大値 00028 float* sample_min; // 特徴の正規化用の,各次元の最小値 00029 float* alpha; // 双対係数のベクトル 00030 // float* grammat; // グラム(カーネル)行列 -> 領域をn_sample^2食うので,廃止.時間はかかるけど逐次処理. 00031 int maxIteration; // 学習の最大繰り返し回数 00032 float epsilon; // 収束判定用の小さな値 00033 float eta; // 学習係数 00034 float learn_alpha; // 慣性項の係数 00035 float C1; // 1ノルムソフトマージンンのスラック変数とハードマージンのトレード 00036 // オフを与える係数 (無限でハードマージンに一致,FLT_MAXで近似) 00037 float C2; // 2ノルムソフトマージンの〜,(無限でハードマージンに一致,FLT_MAXで近似) 00038 // また,C1かC2どちらか一つのみを設定すること. 00039 00040 public: 00041 int* label; // サンプルの2値ラベルの配列(-1 or 1) 00042 float* sample; // n次元サンプルデータの配列. 00043 int status; // SVMの状態 00044 00045 protected: 00046 // カーネル関数. ここでは簡易なRBFカーネルをハードコーディング 00047 inline float kernel_function(float *x, float *y, int n) { 00048 register float inprod = 0; 00049 for (int i=0;i < n;i++) { 00050 inprod += powf(x[i] - y[i],2); 00051 //printf("x[%d] : %f y[%d] : %f \n", i, x[i], i, y[i]); 00052 } 00053 return expf(-inprod/0.1); 00054 } 00055 00056 public: 00057 // 最小の引数による初期化. 順次拡張予定. 00058 SVM(int, // データ次元 00059 int, // サンプル個数 00060 float*, // サンプル 00061 int*); // ラベル 00062 00063 ~SVM(void); 00064 00065 // 学習によりマージンを最大化し,サポートベクトルを確定させる. 00066 virtual int learning(void); 00067 00068 // 未知データの識別.データを受け取り,識別ラベルを-1 or 1で返す. 00069 virtual int predict_label(float*); 00070 00071 // 未知データのネットワーク値(負の値ならば0,即ち識別面の下半空間に, 00072 // 正の値ならば1,識別面の上半空間に存在すると判定)を計算して返す. 00073 float predict_net(float*); 00074 00075 // 未知データの正例の識別確率[0,1]を返す. 00076 // 予測はシグモイド関数による. 00077 // 1ならばマージンを超えて完全に正例領域に入っている. 00078 // 0ならばマージンを超えて完全に負例領域に入っている. 00079 virtual float predict_probability(float*); 00080 00081 // 双対係数のゲッター 00082 virtual float* get_alpha(void); 00083 00084 // 双対係数のセッター 00085 void set_alpha(float*, int); 00086 00087 }; 00088 00089 #endif /* SVM_H_INCLUDED */
Generated on Thu Jul 21 2022 16:34:03 by 1.7.2