Kentarou Shimatani
/
Theremi
action recognizer with theremin
svm/svm.cpp@0:b9ac53c439ed, 2011-09-14 (annotated)
- Committer:
- peccu
- Date:
- Wed Sep 14 13:42:46 2011 +0000
- Revision:
- 0:b9ac53c439ed
Who changed what in which revision?
User | Revision | Line number | New contents of line |
---|---|---|---|
peccu | 0:b9ac53c439ed | 1 | #include <math.h> |
peccu | 0:b9ac53c439ed | 2 | #include <stdio.h> |
peccu | 0:b9ac53c439ed | 3 | #include <stdlib.h> |
peccu | 0:b9ac53c439ed | 4 | #include <ctype.h> |
peccu | 0:b9ac53c439ed | 5 | #include <float.h> |
peccu | 0:b9ac53c439ed | 6 | #include <string.h> |
peccu | 0:b9ac53c439ed | 7 | #include <stdarg.h> |
peccu | 0:b9ac53c439ed | 8 | #include "svm.h" |
peccu | 0:b9ac53c439ed | 9 | int libsvm_version = LIBSVM_VERSION; |
peccu | 0:b9ac53c439ed | 10 | typedef float Qfloat; |
peccu | 0:b9ac53c439ed | 11 | typedef signed char schar; |
// Small value helpers.  min/max are only defined when no macro of the same
// name exists (some platform headers define them as macros).
#ifndef min
template <typename T> static inline T min(T lhs, T rhs)
{
	if (lhs < rhs)
		return lhs;
	return rhs;
}
#endif
#ifndef max
template <typename T> static inline T max(T lhs, T rhs)
{
	if (lhs > rhs)
		return lhs;
	return rhs;
}
#endif
// Exchange two values through a temporary copy.
template <typename T> static inline void swap(T& a, T& b)
{
	T saved = a;
	a = b;
	b = saved;
}
// Allocate a fresh array of n elements with new[] and fill it with a bitwise
// copy of src.  The caller owns dst and must release it with delete[].
// Only suitable for trivially copyable element types (memcpy is used).
template <class S, class T> static inline void clone(T*& dst, S* src, int n)
{
	T *fresh = new T[n];
	memcpy((void *)fresh, (const void *)src, sizeof(T) * n);
	dst = fresh;
}
// base raised to the integer power `times`, by square-and-multiply
// (O(log times) multiplications).  A zero or negative exponent yields 1.0.
static inline double powi(double base, int times)
{
	double result = 1.0;
	double square = base;
	int remaining = times;
	while (remaining > 0)
	{
		if (remaining % 2 == 1)
			result *= square;
		square *= square;
		remaining /= 2;
	}
	return result;
}
// Positive infinity (HUGE_VAL from <math.h>).
#define INF HUGE_VAL
// Tiny positive value the SMO update substitutes for a non-positive
// quadratic coefficient, so the step never divides by zero or a negative.
#define TAU 1e-12
// Typed malloc of n elements (release with free()).  The whole expansion is
// parenthesized so the result can be indexed or dereferenced directly,
// e.g. Malloc(double,n)[0]; without the outer parentheses the subscript
// would bind to malloc(...) before the cast.
#define Malloc(type,n) ((type *)malloc((n)*sizeof(type)))
peccu | 0:b9ac53c439ed | 38 | |
// Default message sink: write to stdout and flush immediately so progress
// output ("." / "*") appears even when stdout is buffered.
static void print_string_stdout(const char *s)
{
	fputs(s,stdout);
	fflush(stdout);
}
// All library messages are routed through this pointer; it starts out
// pointing at print_string_stdout.
static void (*svm_print_string) (const char *) = &print_string_stdout;
#if 1
// printf-style logging helper: formats into a BUFSIZ stack buffer and hands
// the result to svm_print_string.
// vsnprintf bounds the write, so a message longer than BUFSIZ-1 characters is
// truncated instead of overflowing the buffer (plain vsprintf has no bound).
static void info(const char *fmt,...)
{
	char buf[BUFSIZ];
	va_list ap;
	va_start(ap,fmt);
	if(vsnprintf(buf,sizeof(buf),fmt,ap) < 0)
		buf[0] = '\0';	// formatting error: emit an empty message
	va_end(ap);
	(*svm_print_string)(buf);
}
#else
// Logging compiled out: info() does nothing.
static void info(const char *fmt,...) {}
#endif
peccu | 0:b9ac53c439ed | 58 | |
peccu | 0:b9ac53c439ed | 59 | // |
peccu | 0:b9ac53c439ed | 60 | // Kernel Cache |
peccu | 0:b9ac53c439ed | 61 | // |
peccu | 0:b9ac53c439ed | 62 | // l is the number of total data items |
peccu | 0:b9ac53c439ed | 63 | // size is the cache size limit in bytes |
peccu | 0:b9ac53c439ed | 64 | // |
peccu | 0:b9ac53c439ed | 65 | class Cache |
peccu | 0:b9ac53c439ed | 66 | { |
peccu | 0:b9ac53c439ed | 67 | public: |
peccu | 0:b9ac53c439ed | 68 | Cache(int l,long int size); |
peccu | 0:b9ac53c439ed | 69 | ~Cache(); |
peccu | 0:b9ac53c439ed | 70 | |
peccu | 0:b9ac53c439ed | 71 | // request data [0,len) |
peccu | 0:b9ac53c439ed | 72 | // return some position p where [p,len) need to be filled |
peccu | 0:b9ac53c439ed | 73 | // (p >= len if nothing needs to be filled) |
peccu | 0:b9ac53c439ed | 74 | int get_data(const int index, Qfloat **data, int len); |
peccu | 0:b9ac53c439ed | 75 | void swap_index(int i, int j); |
peccu | 0:b9ac53c439ed | 76 | private: |
peccu | 0:b9ac53c439ed | 77 | int l; |
peccu | 0:b9ac53c439ed | 78 | long int size; |
peccu | 0:b9ac53c439ed | 79 | struct head_t |
peccu | 0:b9ac53c439ed | 80 | { |
peccu | 0:b9ac53c439ed | 81 | head_t *prev, *next; // a circular list |
peccu | 0:b9ac53c439ed | 82 | Qfloat *data; |
peccu | 0:b9ac53c439ed | 83 | int len; // data[0,len) is cached in this entry |
peccu | 0:b9ac53c439ed | 84 | }; |
peccu | 0:b9ac53c439ed | 85 | |
peccu | 0:b9ac53c439ed | 86 | head_t *head; |
peccu | 0:b9ac53c439ed | 87 | head_t lru_head; |
peccu | 0:b9ac53c439ed | 88 | void lru_delete(head_t *h); |
peccu | 0:b9ac53c439ed | 89 | void lru_insert(head_t *h); |
peccu | 0:b9ac53c439ed | 90 | }; |
peccu | 0:b9ac53c439ed | 91 | |
peccu | 0:b9ac53c439ed | 92 | Cache::Cache(int l_,long int size_):l(l_),size(size_) |
peccu | 0:b9ac53c439ed | 93 | { |
peccu | 0:b9ac53c439ed | 94 | head = (head_t *)calloc(l,sizeof(head_t)); // initialized to 0 |
peccu | 0:b9ac53c439ed | 95 | size /= sizeof(Qfloat); |
peccu | 0:b9ac53c439ed | 96 | size -= l * sizeof(head_t) / sizeof(Qfloat); |
peccu | 0:b9ac53c439ed | 97 | size = max(size, 2 * (long int) l); // cache must be large enough for two columns |
peccu | 0:b9ac53c439ed | 98 | lru_head.next = lru_head.prev = &lru_head; |
peccu | 0:b9ac53c439ed | 99 | } |
peccu | 0:b9ac53c439ed | 100 | |
peccu | 0:b9ac53c439ed | 101 | Cache::~Cache() |
peccu | 0:b9ac53c439ed | 102 | { |
peccu | 0:b9ac53c439ed | 103 | for(head_t *h = lru_head.next; h != &lru_head; h=h->next) |
peccu | 0:b9ac53c439ed | 104 | free(h->data); |
peccu | 0:b9ac53c439ed | 105 | free(head); |
peccu | 0:b9ac53c439ed | 106 | } |
peccu | 0:b9ac53c439ed | 107 | |
peccu | 0:b9ac53c439ed | 108 | void Cache::lru_delete(head_t *h) |
peccu | 0:b9ac53c439ed | 109 | { |
peccu | 0:b9ac53c439ed | 110 | // delete from current location |
peccu | 0:b9ac53c439ed | 111 | h->prev->next = h->next; |
peccu | 0:b9ac53c439ed | 112 | h->next->prev = h->prev; |
peccu | 0:b9ac53c439ed | 113 | } |
peccu | 0:b9ac53c439ed | 114 | |
peccu | 0:b9ac53c439ed | 115 | void Cache::lru_insert(head_t *h) |
peccu | 0:b9ac53c439ed | 116 | { |
peccu | 0:b9ac53c439ed | 117 | // insert to last position |
peccu | 0:b9ac53c439ed | 118 | h->next = &lru_head; |
peccu | 0:b9ac53c439ed | 119 | h->prev = lru_head.prev; |
peccu | 0:b9ac53c439ed | 120 | h->prev->next = h; |
peccu | 0:b9ac53c439ed | 121 | h->next->prev = h; |
peccu | 0:b9ac53c439ed | 122 | } |
peccu | 0:b9ac53c439ed | 123 | |
peccu | 0:b9ac53c439ed | 124 | int Cache::get_data(const int index, Qfloat **data, int len) |
peccu | 0:b9ac53c439ed | 125 | { |
peccu | 0:b9ac53c439ed | 126 | head_t *h = &head[index]; |
peccu | 0:b9ac53c439ed | 127 | if(h->len) lru_delete(h); |
peccu | 0:b9ac53c439ed | 128 | int more = len - h->len; |
peccu | 0:b9ac53c439ed | 129 | |
peccu | 0:b9ac53c439ed | 130 | if(more > 0) |
peccu | 0:b9ac53c439ed | 131 | { |
peccu | 0:b9ac53c439ed | 132 | // free old space |
peccu | 0:b9ac53c439ed | 133 | while(size < more) |
peccu | 0:b9ac53c439ed | 134 | { |
peccu | 0:b9ac53c439ed | 135 | head_t *old = lru_head.next; |
peccu | 0:b9ac53c439ed | 136 | lru_delete(old); |
peccu | 0:b9ac53c439ed | 137 | free(old->data); |
peccu | 0:b9ac53c439ed | 138 | size += old->len; |
peccu | 0:b9ac53c439ed | 139 | old->data = 0; |
peccu | 0:b9ac53c439ed | 140 | old->len = 0; |
peccu | 0:b9ac53c439ed | 141 | } |
peccu | 0:b9ac53c439ed | 142 | |
peccu | 0:b9ac53c439ed | 143 | // allocate new space |
peccu | 0:b9ac53c439ed | 144 | h->data = (Qfloat *)realloc(h->data,sizeof(Qfloat)*len); |
peccu | 0:b9ac53c439ed | 145 | size -= more; |
peccu | 0:b9ac53c439ed | 146 | swap(h->len,len); |
peccu | 0:b9ac53c439ed | 147 | } |
peccu | 0:b9ac53c439ed | 148 | |
peccu | 0:b9ac53c439ed | 149 | lru_insert(h); |
peccu | 0:b9ac53c439ed | 150 | *data = h->data; |
peccu | 0:b9ac53c439ed | 151 | return len; |
peccu | 0:b9ac53c439ed | 152 | } |
peccu | 0:b9ac53c439ed | 153 | |
peccu | 0:b9ac53c439ed | 154 | void Cache::swap_index(int i, int j) |
peccu | 0:b9ac53c439ed | 155 | { |
peccu | 0:b9ac53c439ed | 156 | if(i==j) return; |
peccu | 0:b9ac53c439ed | 157 | |
peccu | 0:b9ac53c439ed | 158 | if(head[i].len) lru_delete(&head[i]); |
peccu | 0:b9ac53c439ed | 159 | if(head[j].len) lru_delete(&head[j]); |
peccu | 0:b9ac53c439ed | 160 | swap(head[i].data,head[j].data); |
peccu | 0:b9ac53c439ed | 161 | swap(head[i].len,head[j].len); |
peccu | 0:b9ac53c439ed | 162 | if(head[i].len) lru_insert(&head[i]); |
peccu | 0:b9ac53c439ed | 163 | if(head[j].len) lru_insert(&head[j]); |
peccu | 0:b9ac53c439ed | 164 | |
peccu | 0:b9ac53c439ed | 165 | if(i>j) swap(i,j); |
peccu | 0:b9ac53c439ed | 166 | for(head_t *h = lru_head.next; h!=&lru_head; h=h->next) |
peccu | 0:b9ac53c439ed | 167 | { |
peccu | 0:b9ac53c439ed | 168 | if(h->len > i) |
peccu | 0:b9ac53c439ed | 169 | { |
peccu | 0:b9ac53c439ed | 170 | if(h->len > j) |
peccu | 0:b9ac53c439ed | 171 | swap(h->data[i],h->data[j]); |
peccu | 0:b9ac53c439ed | 172 | else |
peccu | 0:b9ac53c439ed | 173 | { |
peccu | 0:b9ac53c439ed | 174 | // give up |
peccu | 0:b9ac53c439ed | 175 | lru_delete(h); |
peccu | 0:b9ac53c439ed | 176 | free(h->data); |
peccu | 0:b9ac53c439ed | 177 | size += h->len; |
peccu | 0:b9ac53c439ed | 178 | h->data = 0; |
peccu | 0:b9ac53c439ed | 179 | h->len = 0; |
peccu | 0:b9ac53c439ed | 180 | } |
peccu | 0:b9ac53c439ed | 181 | } |
peccu | 0:b9ac53c439ed | 182 | } |
peccu | 0:b9ac53c439ed | 183 | } |
peccu | 0:b9ac53c439ed | 184 | |
peccu | 0:b9ac53c439ed | 185 | // |
peccu | 0:b9ac53c439ed | 186 | // Kernel evaluation |
peccu | 0:b9ac53c439ed | 187 | // |
peccu | 0:b9ac53c439ed | 188 | // the static method k_function is for doing single kernel evaluation |
peccu | 0:b9ac53c439ed | 189 | // the constructor of Kernel prepares to calculate the l*l kernel matrix |
peccu | 0:b9ac53c439ed | 190 | // the member function get_Q is for getting one column from the Q Matrix |
peccu | 0:b9ac53c439ed | 191 | // |
peccu | 0:b9ac53c439ed | 192 | class QMatrix { |
peccu | 0:b9ac53c439ed | 193 | public: |
peccu | 0:b9ac53c439ed | 194 | virtual Qfloat *get_Q(int column, int len) const = 0; |
peccu | 0:b9ac53c439ed | 195 | virtual double *get_QD() const = 0; |
peccu | 0:b9ac53c439ed | 196 | virtual void swap_index(int i, int j) const = 0; |
peccu | 0:b9ac53c439ed | 197 | virtual ~QMatrix() {} |
peccu | 0:b9ac53c439ed | 198 | }; |
peccu | 0:b9ac53c439ed | 199 | |
peccu | 0:b9ac53c439ed | 200 | class Kernel: public QMatrix { |
peccu | 0:b9ac53c439ed | 201 | public: |
peccu | 0:b9ac53c439ed | 202 | Kernel(int l, svm_node * const * x, const svm_parameter& param); |
peccu | 0:b9ac53c439ed | 203 | virtual ~Kernel(); |
peccu | 0:b9ac53c439ed | 204 | |
peccu | 0:b9ac53c439ed | 205 | static double k_function(const svm_node *x, const svm_node *y, |
peccu | 0:b9ac53c439ed | 206 | const svm_parameter& param); |
peccu | 0:b9ac53c439ed | 207 | virtual Qfloat *get_Q(int column, int len) const = 0; |
peccu | 0:b9ac53c439ed | 208 | virtual double *get_QD() const = 0; |
peccu | 0:b9ac53c439ed | 209 | virtual void swap_index(int i, int j) const // no so const... |
peccu | 0:b9ac53c439ed | 210 | { |
peccu | 0:b9ac53c439ed | 211 | swap(x[i],x[j]); |
peccu | 0:b9ac53c439ed | 212 | if(x_square) swap(x_square[i],x_square[j]); |
peccu | 0:b9ac53c439ed | 213 | } |
peccu | 0:b9ac53c439ed | 214 | protected: |
peccu | 0:b9ac53c439ed | 215 | |
peccu | 0:b9ac53c439ed | 216 | double (Kernel::*kernel_function)(int i, int j) const; |
peccu | 0:b9ac53c439ed | 217 | |
peccu | 0:b9ac53c439ed | 218 | private: |
peccu | 0:b9ac53c439ed | 219 | const svm_node **x; |
peccu | 0:b9ac53c439ed | 220 | double *x_square; |
peccu | 0:b9ac53c439ed | 221 | |
peccu | 0:b9ac53c439ed | 222 | // svm_parameter |
peccu | 0:b9ac53c439ed | 223 | const int kernel_type; |
peccu | 0:b9ac53c439ed | 224 | const int degree; |
peccu | 0:b9ac53c439ed | 225 | const double gamma; |
peccu | 0:b9ac53c439ed | 226 | const double coef0; |
peccu | 0:b9ac53c439ed | 227 | |
peccu | 0:b9ac53c439ed | 228 | static double dot(const svm_node *px, const svm_node *py); |
peccu | 0:b9ac53c439ed | 229 | double kernel_linear(int i, int j) const |
peccu | 0:b9ac53c439ed | 230 | { |
peccu | 0:b9ac53c439ed | 231 | return dot(x[i],x[j]); |
peccu | 0:b9ac53c439ed | 232 | } |
peccu | 0:b9ac53c439ed | 233 | double kernel_poly(int i, int j) const |
peccu | 0:b9ac53c439ed | 234 | { |
peccu | 0:b9ac53c439ed | 235 | return powi(gamma*dot(x[i],x[j])+coef0,degree); |
peccu | 0:b9ac53c439ed | 236 | } |
peccu | 0:b9ac53c439ed | 237 | double kernel_rbf(int i, int j) const |
peccu | 0:b9ac53c439ed | 238 | { |
peccu | 0:b9ac53c439ed | 239 | return exp(-gamma*(x_square[i]+x_square[j]-2*dot(x[i],x[j]))); |
peccu | 0:b9ac53c439ed | 240 | } |
peccu | 0:b9ac53c439ed | 241 | double kernel_sigmoid(int i, int j) const |
peccu | 0:b9ac53c439ed | 242 | { |
peccu | 0:b9ac53c439ed | 243 | return tanh(gamma*dot(x[i],x[j])+coef0); |
peccu | 0:b9ac53c439ed | 244 | } |
peccu | 0:b9ac53c439ed | 245 | double kernel_precomputed(int i, int j) const |
peccu | 0:b9ac53c439ed | 246 | { |
peccu | 0:b9ac53c439ed | 247 | return x[i][(int)(x[j][0].value)].value; |
peccu | 0:b9ac53c439ed | 248 | } |
peccu | 0:b9ac53c439ed | 249 | }; |
peccu | 0:b9ac53c439ed | 250 | |
peccu | 0:b9ac53c439ed | 251 | Kernel::Kernel(int l, svm_node * const * x_, const svm_parameter& param) |
peccu | 0:b9ac53c439ed | 252 | :kernel_type(param.kernel_type), degree(param.degree), |
peccu | 0:b9ac53c439ed | 253 | gamma(param.gamma), coef0(param.coef0) |
peccu | 0:b9ac53c439ed | 254 | { |
peccu | 0:b9ac53c439ed | 255 | switch(kernel_type) |
peccu | 0:b9ac53c439ed | 256 | { |
peccu | 0:b9ac53c439ed | 257 | case LINEAR: |
peccu | 0:b9ac53c439ed | 258 | kernel_function = &Kernel::kernel_linear; |
peccu | 0:b9ac53c439ed | 259 | break; |
peccu | 0:b9ac53c439ed | 260 | case POLY: |
peccu | 0:b9ac53c439ed | 261 | kernel_function = &Kernel::kernel_poly; |
peccu | 0:b9ac53c439ed | 262 | break; |
peccu | 0:b9ac53c439ed | 263 | case RBF: |
peccu | 0:b9ac53c439ed | 264 | kernel_function = &Kernel::kernel_rbf; |
peccu | 0:b9ac53c439ed | 265 | break; |
peccu | 0:b9ac53c439ed | 266 | case SIGMOID: |
peccu | 0:b9ac53c439ed | 267 | kernel_function = &Kernel::kernel_sigmoid; |
peccu | 0:b9ac53c439ed | 268 | break; |
peccu | 0:b9ac53c439ed | 269 | case PRECOMPUTED: |
peccu | 0:b9ac53c439ed | 270 | kernel_function = &Kernel::kernel_precomputed; |
peccu | 0:b9ac53c439ed | 271 | break; |
peccu | 0:b9ac53c439ed | 272 | } |
peccu | 0:b9ac53c439ed | 273 | |
peccu | 0:b9ac53c439ed | 274 | clone(x,x_,l); |
peccu | 0:b9ac53c439ed | 275 | |
peccu | 0:b9ac53c439ed | 276 | if(kernel_type == RBF) |
peccu | 0:b9ac53c439ed | 277 | { |
peccu | 0:b9ac53c439ed | 278 | x_square = new double[l]; |
peccu | 0:b9ac53c439ed | 279 | for(int i=0;i<l;i++) |
peccu | 0:b9ac53c439ed | 280 | x_square[i] = dot(x[i],x[i]); |
peccu | 0:b9ac53c439ed | 281 | } |
peccu | 0:b9ac53c439ed | 282 | else |
peccu | 0:b9ac53c439ed | 283 | x_square = 0; |
peccu | 0:b9ac53c439ed | 284 | } |
peccu | 0:b9ac53c439ed | 285 | |
peccu | 0:b9ac53c439ed | 286 | Kernel::~Kernel() |
peccu | 0:b9ac53c439ed | 287 | { |
peccu | 0:b9ac53c439ed | 288 | delete[] x; |
peccu | 0:b9ac53c439ed | 289 | delete[] x_square; |
peccu | 0:b9ac53c439ed | 290 | } |
peccu | 0:b9ac53c439ed | 291 | |
peccu | 0:b9ac53c439ed | 292 | double Kernel::dot(const svm_node *px, const svm_node *py) |
peccu | 0:b9ac53c439ed | 293 | { |
peccu | 0:b9ac53c439ed | 294 | double sum = 0; |
peccu | 0:b9ac53c439ed | 295 | while(px->index != -1 && py->index != -1) |
peccu | 0:b9ac53c439ed | 296 | { |
peccu | 0:b9ac53c439ed | 297 | if(px->index == py->index) |
peccu | 0:b9ac53c439ed | 298 | { |
peccu | 0:b9ac53c439ed | 299 | sum += px->value * py->value; |
peccu | 0:b9ac53c439ed | 300 | ++px; |
peccu | 0:b9ac53c439ed | 301 | ++py; |
peccu | 0:b9ac53c439ed | 302 | } |
peccu | 0:b9ac53c439ed | 303 | else |
peccu | 0:b9ac53c439ed | 304 | { |
peccu | 0:b9ac53c439ed | 305 | if(px->index > py->index) |
peccu | 0:b9ac53c439ed | 306 | ++py; |
peccu | 0:b9ac53c439ed | 307 | else |
peccu | 0:b9ac53c439ed | 308 | ++px; |
peccu | 0:b9ac53c439ed | 309 | } |
peccu | 0:b9ac53c439ed | 310 | } |
peccu | 0:b9ac53c439ed | 311 | return sum; |
peccu | 0:b9ac53c439ed | 312 | } |
peccu | 0:b9ac53c439ed | 313 | |
peccu | 0:b9ac53c439ed | 314 | double Kernel::k_function(const svm_node *x, const svm_node *y, |
peccu | 0:b9ac53c439ed | 315 | const svm_parameter& param) |
peccu | 0:b9ac53c439ed | 316 | { |
peccu | 0:b9ac53c439ed | 317 | switch(param.kernel_type) |
peccu | 0:b9ac53c439ed | 318 | { |
peccu | 0:b9ac53c439ed | 319 | case LINEAR: |
peccu | 0:b9ac53c439ed | 320 | return dot(x,y); |
peccu | 0:b9ac53c439ed | 321 | case POLY: |
peccu | 0:b9ac53c439ed | 322 | return powi(param.gamma*dot(x,y)+param.coef0,param.degree); |
peccu | 0:b9ac53c439ed | 323 | case RBF: |
peccu | 0:b9ac53c439ed | 324 | { |
peccu | 0:b9ac53c439ed | 325 | double sum = 0; |
peccu | 0:b9ac53c439ed | 326 | while(x->index != -1 && y->index !=-1) |
peccu | 0:b9ac53c439ed | 327 | { |
peccu | 0:b9ac53c439ed | 328 | if(x->index == y->index) |
peccu | 0:b9ac53c439ed | 329 | { |
peccu | 0:b9ac53c439ed | 330 | double d = x->value - y->value; |
peccu | 0:b9ac53c439ed | 331 | sum += d*d; |
peccu | 0:b9ac53c439ed | 332 | ++x; |
peccu | 0:b9ac53c439ed | 333 | ++y; |
peccu | 0:b9ac53c439ed | 334 | } |
peccu | 0:b9ac53c439ed | 335 | else |
peccu | 0:b9ac53c439ed | 336 | { |
peccu | 0:b9ac53c439ed | 337 | if(x->index > y->index) |
peccu | 0:b9ac53c439ed | 338 | { |
peccu | 0:b9ac53c439ed | 339 | sum += y->value * y->value; |
peccu | 0:b9ac53c439ed | 340 | ++y; |
peccu | 0:b9ac53c439ed | 341 | } |
peccu | 0:b9ac53c439ed | 342 | else |
peccu | 0:b9ac53c439ed | 343 | { |
peccu | 0:b9ac53c439ed | 344 | sum += x->value * x->value; |
peccu | 0:b9ac53c439ed | 345 | ++x; |
peccu | 0:b9ac53c439ed | 346 | } |
peccu | 0:b9ac53c439ed | 347 | } |
peccu | 0:b9ac53c439ed | 348 | } |
peccu | 0:b9ac53c439ed | 349 | |
peccu | 0:b9ac53c439ed | 350 | while(x->index != -1) |
peccu | 0:b9ac53c439ed | 351 | { |
peccu | 0:b9ac53c439ed | 352 | sum += x->value * x->value; |
peccu | 0:b9ac53c439ed | 353 | ++x; |
peccu | 0:b9ac53c439ed | 354 | } |
peccu | 0:b9ac53c439ed | 355 | |
peccu | 0:b9ac53c439ed | 356 | while(y->index != -1) |
peccu | 0:b9ac53c439ed | 357 | { |
peccu | 0:b9ac53c439ed | 358 | sum += y->value * y->value; |
peccu | 0:b9ac53c439ed | 359 | ++y; |
peccu | 0:b9ac53c439ed | 360 | } |
peccu | 0:b9ac53c439ed | 361 | |
peccu | 0:b9ac53c439ed | 362 | return exp(-param.gamma*sum); |
peccu | 0:b9ac53c439ed | 363 | } |
peccu | 0:b9ac53c439ed | 364 | case SIGMOID: |
peccu | 0:b9ac53c439ed | 365 | return tanh(param.gamma*dot(x,y)+param.coef0); |
peccu | 0:b9ac53c439ed | 366 | case PRECOMPUTED: //x: test (validation), y: SV |
peccu | 0:b9ac53c439ed | 367 | return x[(int)(y->value)].value; |
peccu | 0:b9ac53c439ed | 368 | default: |
peccu | 0:b9ac53c439ed | 369 | return 0; // Unreachable |
peccu | 0:b9ac53c439ed | 370 | } |
peccu | 0:b9ac53c439ed | 371 | } |
peccu | 0:b9ac53c439ed | 372 | |
peccu | 0:b9ac53c439ed | 373 | // An SMO algorithm in Fan et al., JMLR 6(2005), p. 1889--1918 |
peccu | 0:b9ac53c439ed | 374 | // Solves: |
peccu | 0:b9ac53c439ed | 375 | // |
peccu | 0:b9ac53c439ed | 376 | // min 0.5(\alpha^T Q \alpha) + p^T \alpha |
peccu | 0:b9ac53c439ed | 377 | // |
peccu | 0:b9ac53c439ed | 378 | // y^T \alpha = \delta |
peccu | 0:b9ac53c439ed | 379 | // y_i = +1 or -1 |
peccu | 0:b9ac53c439ed | 380 | // 0 <= alpha_i <= Cp for y_i = 1 |
peccu | 0:b9ac53c439ed | 381 | // 0 <= alpha_i <= Cn for y_i = -1 |
peccu | 0:b9ac53c439ed | 382 | // |
peccu | 0:b9ac53c439ed | 383 | // Given: |
peccu | 0:b9ac53c439ed | 384 | // |
peccu | 0:b9ac53c439ed | 385 | // Q, p, y, Cp, Cn, and an initial feasible point \alpha |
peccu | 0:b9ac53c439ed | 386 | // l is the size of vectors and matrices |
peccu | 0:b9ac53c439ed | 387 | // eps is the stopping tolerance |
peccu | 0:b9ac53c439ed | 388 | // |
peccu | 0:b9ac53c439ed | 389 | // solution will be put in \alpha, objective value will be put in obj |
peccu | 0:b9ac53c439ed | 390 | // |
peccu | 0:b9ac53c439ed | 391 | class Solver { |
peccu | 0:b9ac53c439ed | 392 | public: |
peccu | 0:b9ac53c439ed | 393 | Solver() {}; |
peccu | 0:b9ac53c439ed | 394 | virtual ~Solver() {}; |
peccu | 0:b9ac53c439ed | 395 | |
peccu | 0:b9ac53c439ed | 396 | struct SolutionInfo { |
peccu | 0:b9ac53c439ed | 397 | double obj; |
peccu | 0:b9ac53c439ed | 398 | double rho; |
peccu | 0:b9ac53c439ed | 399 | double upper_bound_p; |
peccu | 0:b9ac53c439ed | 400 | double upper_bound_n; |
peccu | 0:b9ac53c439ed | 401 | double r; // for Solver_NU |
peccu | 0:b9ac53c439ed | 402 | }; |
peccu | 0:b9ac53c439ed | 403 | |
peccu | 0:b9ac53c439ed | 404 | void Solve(int l, const QMatrix& Q, const double *p_, const schar *y_, |
peccu | 0:b9ac53c439ed | 405 | double *alpha_, double Cp, double Cn, double eps, |
peccu | 0:b9ac53c439ed | 406 | SolutionInfo* si, int shrinking); |
peccu | 0:b9ac53c439ed | 407 | protected: |
peccu | 0:b9ac53c439ed | 408 | int active_size; |
peccu | 0:b9ac53c439ed | 409 | schar *y; |
peccu | 0:b9ac53c439ed | 410 | double *G; // gradient of objective function |
peccu | 0:b9ac53c439ed | 411 | enum { LOWER_BOUND, UPPER_BOUND, FREE }; |
peccu | 0:b9ac53c439ed | 412 | char *alpha_status; // LOWER_BOUND, UPPER_BOUND, FREE |
peccu | 0:b9ac53c439ed | 413 | double *alpha; |
peccu | 0:b9ac53c439ed | 414 | const QMatrix *Q; |
peccu | 0:b9ac53c439ed | 415 | const double *QD; |
peccu | 0:b9ac53c439ed | 416 | double eps; |
peccu | 0:b9ac53c439ed | 417 | double Cp,Cn; |
peccu | 0:b9ac53c439ed | 418 | double *p; |
peccu | 0:b9ac53c439ed | 419 | int *active_set; |
peccu | 0:b9ac53c439ed | 420 | double *G_bar; // gradient, if we treat free variables as 0 |
peccu | 0:b9ac53c439ed | 421 | int l; |
peccu | 0:b9ac53c439ed | 422 | bool unshrink; // XXX |
peccu | 0:b9ac53c439ed | 423 | |
peccu | 0:b9ac53c439ed | 424 | double get_C(int i) |
peccu | 0:b9ac53c439ed | 425 | { |
peccu | 0:b9ac53c439ed | 426 | return (y[i] > 0)? Cp : Cn; |
peccu | 0:b9ac53c439ed | 427 | } |
peccu | 0:b9ac53c439ed | 428 | void update_alpha_status(int i) |
peccu | 0:b9ac53c439ed | 429 | { |
peccu | 0:b9ac53c439ed | 430 | if(alpha[i] >= get_C(i)) |
peccu | 0:b9ac53c439ed | 431 | alpha_status[i] = UPPER_BOUND; |
peccu | 0:b9ac53c439ed | 432 | else if(alpha[i] <= 0) |
peccu | 0:b9ac53c439ed | 433 | alpha_status[i] = LOWER_BOUND; |
peccu | 0:b9ac53c439ed | 434 | else alpha_status[i] = FREE; |
peccu | 0:b9ac53c439ed | 435 | } |
peccu | 0:b9ac53c439ed | 436 | bool is_upper_bound(int i) { return alpha_status[i] == UPPER_BOUND; } |
peccu | 0:b9ac53c439ed | 437 | bool is_lower_bound(int i) { return alpha_status[i] == LOWER_BOUND; } |
peccu | 0:b9ac53c439ed | 438 | bool is_free(int i) { return alpha_status[i] == FREE; } |
peccu | 0:b9ac53c439ed | 439 | void swap_index(int i, int j); |
peccu | 0:b9ac53c439ed | 440 | void reconstruct_gradient(); |
peccu | 0:b9ac53c439ed | 441 | virtual int select_working_set(int &i, int &j); |
peccu | 0:b9ac53c439ed | 442 | virtual double calculate_rho(); |
peccu | 0:b9ac53c439ed | 443 | virtual void do_shrinking(); |
peccu | 0:b9ac53c439ed | 444 | private: |
peccu | 0:b9ac53c439ed | 445 | bool be_shrunk(int i, double Gmax1, double Gmax2); |
peccu | 0:b9ac53c439ed | 446 | }; |
peccu | 0:b9ac53c439ed | 447 | |
peccu | 0:b9ac53c439ed | 448 | void Solver::swap_index(int i, int j) |
peccu | 0:b9ac53c439ed | 449 | { |
peccu | 0:b9ac53c439ed | 450 | Q->swap_index(i,j); |
peccu | 0:b9ac53c439ed | 451 | swap(y[i],y[j]); |
peccu | 0:b9ac53c439ed | 452 | swap(G[i],G[j]); |
peccu | 0:b9ac53c439ed | 453 | swap(alpha_status[i],alpha_status[j]); |
peccu | 0:b9ac53c439ed | 454 | swap(alpha[i],alpha[j]); |
peccu | 0:b9ac53c439ed | 455 | swap(p[i],p[j]); |
peccu | 0:b9ac53c439ed | 456 | swap(active_set[i],active_set[j]); |
peccu | 0:b9ac53c439ed | 457 | swap(G_bar[i],G_bar[j]); |
peccu | 0:b9ac53c439ed | 458 | } |
peccu | 0:b9ac53c439ed | 459 | |
peccu | 0:b9ac53c439ed | 460 | void Solver::reconstruct_gradient() |
peccu | 0:b9ac53c439ed | 461 | { |
peccu | 0:b9ac53c439ed | 462 | // reconstruct inactive elements of G from G_bar and free variables |
peccu | 0:b9ac53c439ed | 463 | |
peccu | 0:b9ac53c439ed | 464 | if(active_size == l) return; |
peccu | 0:b9ac53c439ed | 465 | |
peccu | 0:b9ac53c439ed | 466 | int i,j; |
peccu | 0:b9ac53c439ed | 467 | int nr_free = 0; |
peccu | 0:b9ac53c439ed | 468 | |
peccu | 0:b9ac53c439ed | 469 | for(j=active_size;j<l;j++) |
peccu | 0:b9ac53c439ed | 470 | G[j] = G_bar[j] + p[j]; |
peccu | 0:b9ac53c439ed | 471 | |
peccu | 0:b9ac53c439ed | 472 | for(j=0;j<active_size;j++) |
peccu | 0:b9ac53c439ed | 473 | if(is_free(j)) |
peccu | 0:b9ac53c439ed | 474 | nr_free++; |
peccu | 0:b9ac53c439ed | 475 | |
peccu | 0:b9ac53c439ed | 476 | if(2*nr_free < active_size) |
peccu | 0:b9ac53c439ed | 477 | info("\nWarning: using -h 0 may be faster\n"); |
peccu | 0:b9ac53c439ed | 478 | |
peccu | 0:b9ac53c439ed | 479 | if (nr_free*l > 2*active_size*(l-active_size)) |
peccu | 0:b9ac53c439ed | 480 | { |
peccu | 0:b9ac53c439ed | 481 | for(i=active_size;i<l;i++) |
peccu | 0:b9ac53c439ed | 482 | { |
peccu | 0:b9ac53c439ed | 483 | const Qfloat *Q_i = Q->get_Q(i,active_size); |
peccu | 0:b9ac53c439ed | 484 | for(j=0;j<active_size;j++) |
peccu | 0:b9ac53c439ed | 485 | if(is_free(j)) |
peccu | 0:b9ac53c439ed | 486 | G[i] += alpha[j] * Q_i[j]; |
peccu | 0:b9ac53c439ed | 487 | } |
peccu | 0:b9ac53c439ed | 488 | } |
peccu | 0:b9ac53c439ed | 489 | else |
peccu | 0:b9ac53c439ed | 490 | { |
peccu | 0:b9ac53c439ed | 491 | for(i=0;i<active_size;i++) |
peccu | 0:b9ac53c439ed | 492 | if(is_free(i)) |
peccu | 0:b9ac53c439ed | 493 | { |
peccu | 0:b9ac53c439ed | 494 | const Qfloat *Q_i = Q->get_Q(i,l); |
peccu | 0:b9ac53c439ed | 495 | double alpha_i = alpha[i]; |
peccu | 0:b9ac53c439ed | 496 | for(j=active_size;j<l;j++) |
peccu | 0:b9ac53c439ed | 497 | G[j] += alpha_i * Q_i[j]; |
peccu | 0:b9ac53c439ed | 498 | } |
peccu | 0:b9ac53c439ed | 499 | } |
peccu | 0:b9ac53c439ed | 500 | } |
peccu | 0:b9ac53c439ed | 501 | |
peccu | 0:b9ac53c439ed | 502 | void Solver::Solve(int l, const QMatrix& Q, const double *p_, const schar *y_, |
peccu | 0:b9ac53c439ed | 503 | double *alpha_, double Cp, double Cn, double eps, |
peccu | 0:b9ac53c439ed | 504 | SolutionInfo* si, int shrinking) |
peccu | 0:b9ac53c439ed | 505 | { |
peccu | 0:b9ac53c439ed | 506 | this->l = l; |
peccu | 0:b9ac53c439ed | 507 | this->Q = &Q; |
peccu | 0:b9ac53c439ed | 508 | QD=Q.get_QD(); |
peccu | 0:b9ac53c439ed | 509 | clone(p, p_,l); |
peccu | 0:b9ac53c439ed | 510 | clone(y, y_,l); |
peccu | 0:b9ac53c439ed | 511 | clone(alpha,alpha_,l); |
peccu | 0:b9ac53c439ed | 512 | this->Cp = Cp; |
peccu | 0:b9ac53c439ed | 513 | this->Cn = Cn; |
peccu | 0:b9ac53c439ed | 514 | this->eps = eps; |
peccu | 0:b9ac53c439ed | 515 | unshrink = false; |
peccu | 0:b9ac53c439ed | 516 | |
peccu | 0:b9ac53c439ed | 517 | // initialize alpha_status |
peccu | 0:b9ac53c439ed | 518 | { |
peccu | 0:b9ac53c439ed | 519 | alpha_status = new char[l]; |
peccu | 0:b9ac53c439ed | 520 | for(int i=0;i<l;i++) |
peccu | 0:b9ac53c439ed | 521 | update_alpha_status(i); |
peccu | 0:b9ac53c439ed | 522 | } |
peccu | 0:b9ac53c439ed | 523 | |
peccu | 0:b9ac53c439ed | 524 | // initialize active set (for shrinking) |
peccu | 0:b9ac53c439ed | 525 | { |
peccu | 0:b9ac53c439ed | 526 | active_set = new int[l]; |
peccu | 0:b9ac53c439ed | 527 | for(int i=0;i<l;i++) |
peccu | 0:b9ac53c439ed | 528 | active_set[i] = i; |
peccu | 0:b9ac53c439ed | 529 | active_size = l; |
peccu | 0:b9ac53c439ed | 530 | } |
peccu | 0:b9ac53c439ed | 531 | |
peccu | 0:b9ac53c439ed | 532 | // initialize gradient |
peccu | 0:b9ac53c439ed | 533 | { |
peccu | 0:b9ac53c439ed | 534 | G = new double[l]; |
peccu | 0:b9ac53c439ed | 535 | G_bar = new double[l]; |
peccu | 0:b9ac53c439ed | 536 | int i; |
peccu | 0:b9ac53c439ed | 537 | for(i=0;i<l;i++) |
peccu | 0:b9ac53c439ed | 538 | { |
peccu | 0:b9ac53c439ed | 539 | G[i] = p[i]; |
peccu | 0:b9ac53c439ed | 540 | G_bar[i] = 0; |
peccu | 0:b9ac53c439ed | 541 | } |
peccu | 0:b9ac53c439ed | 542 | for(i=0;i<l;i++) |
peccu | 0:b9ac53c439ed | 543 | if(!is_lower_bound(i)) |
peccu | 0:b9ac53c439ed | 544 | { |
peccu | 0:b9ac53c439ed | 545 | const Qfloat *Q_i = Q.get_Q(i,l); |
peccu | 0:b9ac53c439ed | 546 | double alpha_i = alpha[i]; |
peccu | 0:b9ac53c439ed | 547 | int j; |
peccu | 0:b9ac53c439ed | 548 | for(j=0;j<l;j++) |
peccu | 0:b9ac53c439ed | 549 | G[j] += alpha_i*Q_i[j]; |
peccu | 0:b9ac53c439ed | 550 | if(is_upper_bound(i)) |
peccu | 0:b9ac53c439ed | 551 | for(j=0;j<l;j++) |
peccu | 0:b9ac53c439ed | 552 | G_bar[j] += get_C(i) * Q_i[j]; |
peccu | 0:b9ac53c439ed | 553 | } |
peccu | 0:b9ac53c439ed | 554 | } |
peccu | 0:b9ac53c439ed | 555 | |
peccu | 0:b9ac53c439ed | 556 | // optimization step |
peccu | 0:b9ac53c439ed | 557 | |
peccu | 0:b9ac53c439ed | 558 | int iter = 0; |
peccu | 0:b9ac53c439ed | 559 | int counter = min(l,1000)+1; |
peccu | 0:b9ac53c439ed | 560 | |
peccu | 0:b9ac53c439ed | 561 | while(1) |
peccu | 0:b9ac53c439ed | 562 | { |
peccu | 0:b9ac53c439ed | 563 | // show progress and do shrinking |
peccu | 0:b9ac53c439ed | 564 | |
peccu | 0:b9ac53c439ed | 565 | if(--counter == 0) |
peccu | 0:b9ac53c439ed | 566 | { |
peccu | 0:b9ac53c439ed | 567 | counter = min(l,1000); |
peccu | 0:b9ac53c439ed | 568 | if(shrinking) do_shrinking(); |
peccu | 0:b9ac53c439ed | 569 | info("."); |
peccu | 0:b9ac53c439ed | 570 | } |
peccu | 0:b9ac53c439ed | 571 | |
peccu | 0:b9ac53c439ed | 572 | int i,j; |
peccu | 0:b9ac53c439ed | 573 | if(select_working_set(i,j)!=0) |
peccu | 0:b9ac53c439ed | 574 | { |
peccu | 0:b9ac53c439ed | 575 | // reconstruct the whole gradient |
peccu | 0:b9ac53c439ed | 576 | reconstruct_gradient(); |
peccu | 0:b9ac53c439ed | 577 | // reset active set size and check |
peccu | 0:b9ac53c439ed | 578 | active_size = l; |
peccu | 0:b9ac53c439ed | 579 | info("*"); |
peccu | 0:b9ac53c439ed | 580 | if(select_working_set(i,j)!=0) |
peccu | 0:b9ac53c439ed | 581 | break; |
peccu | 0:b9ac53c439ed | 582 | else |
peccu | 0:b9ac53c439ed | 583 | counter = 1; // do shrinking next iteration |
peccu | 0:b9ac53c439ed | 584 | } |
peccu | 0:b9ac53c439ed | 585 | |
peccu | 0:b9ac53c439ed | 586 | ++iter; |
peccu | 0:b9ac53c439ed | 587 | |
peccu | 0:b9ac53c439ed | 588 | // update alpha[i] and alpha[j], handle bounds carefully |
peccu | 0:b9ac53c439ed | 589 | |
peccu | 0:b9ac53c439ed | 590 | const Qfloat *Q_i = Q.get_Q(i,active_size); |
peccu | 0:b9ac53c439ed | 591 | const Qfloat *Q_j = Q.get_Q(j,active_size); |
peccu | 0:b9ac53c439ed | 592 | |
peccu | 0:b9ac53c439ed | 593 | double C_i = get_C(i); |
peccu | 0:b9ac53c439ed | 594 | double C_j = get_C(j); |
peccu | 0:b9ac53c439ed | 595 | |
peccu | 0:b9ac53c439ed | 596 | double old_alpha_i = alpha[i]; |
peccu | 0:b9ac53c439ed | 597 | double old_alpha_j = alpha[j]; |
peccu | 0:b9ac53c439ed | 598 | |
peccu | 0:b9ac53c439ed | 599 | if(y[i]!=y[j]) |
peccu | 0:b9ac53c439ed | 600 | { |
peccu | 0:b9ac53c439ed | 601 | double quad_coef = QD[i]+QD[j]+2*Q_i[j]; |
peccu | 0:b9ac53c439ed | 602 | if (quad_coef <= 0) |
peccu | 0:b9ac53c439ed | 603 | quad_coef = TAU; |
peccu | 0:b9ac53c439ed | 604 | double delta = (-G[i]-G[j])/quad_coef; |
peccu | 0:b9ac53c439ed | 605 | double diff = alpha[i] - alpha[j]; |
peccu | 0:b9ac53c439ed | 606 | alpha[i] += delta; |
peccu | 0:b9ac53c439ed | 607 | alpha[j] += delta; |
peccu | 0:b9ac53c439ed | 608 | |
peccu | 0:b9ac53c439ed | 609 | if(diff > 0) |
peccu | 0:b9ac53c439ed | 610 | { |
peccu | 0:b9ac53c439ed | 611 | if(alpha[j] < 0) |
peccu | 0:b9ac53c439ed | 612 | { |
peccu | 0:b9ac53c439ed | 613 | alpha[j] = 0; |
peccu | 0:b9ac53c439ed | 614 | alpha[i] = diff; |
peccu | 0:b9ac53c439ed | 615 | } |
peccu | 0:b9ac53c439ed | 616 | } |
peccu | 0:b9ac53c439ed | 617 | else |
peccu | 0:b9ac53c439ed | 618 | { |
peccu | 0:b9ac53c439ed | 619 | if(alpha[i] < 0) |
peccu | 0:b9ac53c439ed | 620 | { |
peccu | 0:b9ac53c439ed | 621 | alpha[i] = 0; |
peccu | 0:b9ac53c439ed | 622 | alpha[j] = -diff; |
peccu | 0:b9ac53c439ed | 623 | } |
peccu | 0:b9ac53c439ed | 624 | } |
peccu | 0:b9ac53c439ed | 625 | if(diff > C_i - C_j) |
peccu | 0:b9ac53c439ed | 626 | { |
peccu | 0:b9ac53c439ed | 627 | if(alpha[i] > C_i) |
peccu | 0:b9ac53c439ed | 628 | { |
peccu | 0:b9ac53c439ed | 629 | alpha[i] = C_i; |
peccu | 0:b9ac53c439ed | 630 | alpha[j] = C_i - diff; |
peccu | 0:b9ac53c439ed | 631 | } |
peccu | 0:b9ac53c439ed | 632 | } |
peccu | 0:b9ac53c439ed | 633 | else |
peccu | 0:b9ac53c439ed | 634 | { |
peccu | 0:b9ac53c439ed | 635 | if(alpha[j] > C_j) |
peccu | 0:b9ac53c439ed | 636 | { |
peccu | 0:b9ac53c439ed | 637 | alpha[j] = C_j; |
peccu | 0:b9ac53c439ed | 638 | alpha[i] = C_j + diff; |
peccu | 0:b9ac53c439ed | 639 | } |
peccu | 0:b9ac53c439ed | 640 | } |
peccu | 0:b9ac53c439ed | 641 | } |
peccu | 0:b9ac53c439ed | 642 | else |
peccu | 0:b9ac53c439ed | 643 | { |
peccu | 0:b9ac53c439ed | 644 | double quad_coef = QD[i]+QD[j]-2*Q_i[j]; |
peccu | 0:b9ac53c439ed | 645 | if (quad_coef <= 0) |
peccu | 0:b9ac53c439ed | 646 | quad_coef = TAU; |
peccu | 0:b9ac53c439ed | 647 | double delta = (G[i]-G[j])/quad_coef; |
peccu | 0:b9ac53c439ed | 648 | double sum = alpha[i] + alpha[j]; |
peccu | 0:b9ac53c439ed | 649 | alpha[i] -= delta; |
peccu | 0:b9ac53c439ed | 650 | alpha[j] += delta; |
peccu | 0:b9ac53c439ed | 651 | |
peccu | 0:b9ac53c439ed | 652 | if(sum > C_i) |
peccu | 0:b9ac53c439ed | 653 | { |
peccu | 0:b9ac53c439ed | 654 | if(alpha[i] > C_i) |
peccu | 0:b9ac53c439ed | 655 | { |
peccu | 0:b9ac53c439ed | 656 | alpha[i] = C_i; |
peccu | 0:b9ac53c439ed | 657 | alpha[j] = sum - C_i; |
peccu | 0:b9ac53c439ed | 658 | } |
peccu | 0:b9ac53c439ed | 659 | } |
peccu | 0:b9ac53c439ed | 660 | else |
peccu | 0:b9ac53c439ed | 661 | { |
peccu | 0:b9ac53c439ed | 662 | if(alpha[j] < 0) |
peccu | 0:b9ac53c439ed | 663 | { |
peccu | 0:b9ac53c439ed | 664 | alpha[j] = 0; |
peccu | 0:b9ac53c439ed | 665 | alpha[i] = sum; |
peccu | 0:b9ac53c439ed | 666 | } |
peccu | 0:b9ac53c439ed | 667 | } |
peccu | 0:b9ac53c439ed | 668 | if(sum > C_j) |
peccu | 0:b9ac53c439ed | 669 | { |
peccu | 0:b9ac53c439ed | 670 | if(alpha[j] > C_j) |
peccu | 0:b9ac53c439ed | 671 | { |
peccu | 0:b9ac53c439ed | 672 | alpha[j] = C_j; |
peccu | 0:b9ac53c439ed | 673 | alpha[i] = sum - C_j; |
peccu | 0:b9ac53c439ed | 674 | } |
peccu | 0:b9ac53c439ed | 675 | } |
peccu | 0:b9ac53c439ed | 676 | else |
peccu | 0:b9ac53c439ed | 677 | { |
peccu | 0:b9ac53c439ed | 678 | if(alpha[i] < 0) |
peccu | 0:b9ac53c439ed | 679 | { |
peccu | 0:b9ac53c439ed | 680 | alpha[i] = 0; |
peccu | 0:b9ac53c439ed | 681 | alpha[j] = sum; |
peccu | 0:b9ac53c439ed | 682 | } |
peccu | 0:b9ac53c439ed | 683 | } |
peccu | 0:b9ac53c439ed | 684 | } |
peccu | 0:b9ac53c439ed | 685 | |
peccu | 0:b9ac53c439ed | 686 | // update G |
peccu | 0:b9ac53c439ed | 687 | |
peccu | 0:b9ac53c439ed | 688 | double delta_alpha_i = alpha[i] - old_alpha_i; |
peccu | 0:b9ac53c439ed | 689 | double delta_alpha_j = alpha[j] - old_alpha_j; |
peccu | 0:b9ac53c439ed | 690 | |
peccu | 0:b9ac53c439ed | 691 | for(int k=0;k<active_size;k++) |
peccu | 0:b9ac53c439ed | 692 | { |
peccu | 0:b9ac53c439ed | 693 | G[k] += Q_i[k]*delta_alpha_i + Q_j[k]*delta_alpha_j; |
peccu | 0:b9ac53c439ed | 694 | } |
peccu | 0:b9ac53c439ed | 695 | |
peccu | 0:b9ac53c439ed | 696 | // update alpha_status and G_bar |
peccu | 0:b9ac53c439ed | 697 | |
peccu | 0:b9ac53c439ed | 698 | { |
peccu | 0:b9ac53c439ed | 699 | bool ui = is_upper_bound(i); |
peccu | 0:b9ac53c439ed | 700 | bool uj = is_upper_bound(j); |
peccu | 0:b9ac53c439ed | 701 | update_alpha_status(i); |
peccu | 0:b9ac53c439ed | 702 | update_alpha_status(j); |
peccu | 0:b9ac53c439ed | 703 | int k; |
peccu | 0:b9ac53c439ed | 704 | if(ui != is_upper_bound(i)) |
peccu | 0:b9ac53c439ed | 705 | { |
peccu | 0:b9ac53c439ed | 706 | Q_i = Q.get_Q(i,l); |
peccu | 0:b9ac53c439ed | 707 | if(ui) |
peccu | 0:b9ac53c439ed | 708 | for(k=0;k<l;k++) |
peccu | 0:b9ac53c439ed | 709 | G_bar[k] -= C_i * Q_i[k]; |
peccu | 0:b9ac53c439ed | 710 | else |
peccu | 0:b9ac53c439ed | 711 | for(k=0;k<l;k++) |
peccu | 0:b9ac53c439ed | 712 | G_bar[k] += C_i * Q_i[k]; |
peccu | 0:b9ac53c439ed | 713 | } |
peccu | 0:b9ac53c439ed | 714 | |
peccu | 0:b9ac53c439ed | 715 | if(uj != is_upper_bound(j)) |
peccu | 0:b9ac53c439ed | 716 | { |
peccu | 0:b9ac53c439ed | 717 | Q_j = Q.get_Q(j,l); |
peccu | 0:b9ac53c439ed | 718 | if(uj) |
peccu | 0:b9ac53c439ed | 719 | for(k=0;k<l;k++) |
peccu | 0:b9ac53c439ed | 720 | G_bar[k] -= C_j * Q_j[k]; |
peccu | 0:b9ac53c439ed | 721 | else |
peccu | 0:b9ac53c439ed | 722 | for(k=0;k<l;k++) |
peccu | 0:b9ac53c439ed | 723 | G_bar[k] += C_j * Q_j[k]; |
peccu | 0:b9ac53c439ed | 724 | } |
peccu | 0:b9ac53c439ed | 725 | } |
peccu | 0:b9ac53c439ed | 726 | } |
peccu | 0:b9ac53c439ed | 727 | |
peccu | 0:b9ac53c439ed | 728 | // calculate rho |
peccu | 0:b9ac53c439ed | 729 | |
peccu | 0:b9ac53c439ed | 730 | si->rho = calculate_rho(); |
peccu | 0:b9ac53c439ed | 731 | |
peccu | 0:b9ac53c439ed | 732 | // calculate objective value |
peccu | 0:b9ac53c439ed | 733 | { |
peccu | 0:b9ac53c439ed | 734 | double v = 0; |
peccu | 0:b9ac53c439ed | 735 | int i; |
peccu | 0:b9ac53c439ed | 736 | for(i=0;i<l;i++) |
peccu | 0:b9ac53c439ed | 737 | v += alpha[i] * (G[i] + p[i]); |
peccu | 0:b9ac53c439ed | 738 | |
peccu | 0:b9ac53c439ed | 739 | si->obj = v/2; |
peccu | 0:b9ac53c439ed | 740 | } |
peccu | 0:b9ac53c439ed | 741 | |
peccu | 0:b9ac53c439ed | 742 | // put back the solution |
peccu | 0:b9ac53c439ed | 743 | { |
peccu | 0:b9ac53c439ed | 744 | for(int i=0;i<l;i++) |
peccu | 0:b9ac53c439ed | 745 | alpha_[active_set[i]] = alpha[i]; |
peccu | 0:b9ac53c439ed | 746 | } |
peccu | 0:b9ac53c439ed | 747 | |
peccu | 0:b9ac53c439ed | 748 | // juggle everything back |
peccu | 0:b9ac53c439ed | 749 | /*{ |
peccu | 0:b9ac53c439ed | 750 | for(int i=0;i<l;i++) |
peccu | 0:b9ac53c439ed | 751 | while(active_set[i] != i) |
peccu | 0:b9ac53c439ed | 752 | swap_index(i,active_set[i]); |
peccu | 0:b9ac53c439ed | 753 | // or Q.swap_index(i,active_set[i]); |
peccu | 0:b9ac53c439ed | 754 | }*/ |
peccu | 0:b9ac53c439ed | 755 | |
peccu | 0:b9ac53c439ed | 756 | si->upper_bound_p = Cp; |
peccu | 0:b9ac53c439ed | 757 | si->upper_bound_n = Cn; |
peccu | 0:b9ac53c439ed | 758 | |
peccu | 0:b9ac53c439ed | 759 | info("\noptimization finished, #iter = %d\n",iter); |
peccu | 0:b9ac53c439ed | 760 | |
peccu | 0:b9ac53c439ed | 761 | delete[] p; |
peccu | 0:b9ac53c439ed | 762 | delete[] y; |
peccu | 0:b9ac53c439ed | 763 | delete[] alpha; |
peccu | 0:b9ac53c439ed | 764 | delete[] alpha_status; |
peccu | 0:b9ac53c439ed | 765 | delete[] active_set; |
peccu | 0:b9ac53c439ed | 766 | delete[] G; |
peccu | 0:b9ac53c439ed | 767 | delete[] G_bar; |
peccu | 0:b9ac53c439ed | 768 | } |
peccu | 0:b9ac53c439ed | 769 | |
peccu | 0:b9ac53c439ed | 770 | // return 1 if already optimal, return 0 otherwise |
peccu | 0:b9ac53c439ed | 771 | int Solver::select_working_set(int &out_i, int &out_j) |
peccu | 0:b9ac53c439ed | 772 | { |
peccu | 0:b9ac53c439ed | 773 | // return i,j such that |
peccu | 0:b9ac53c439ed | 774 | // i: maximizes -y_i * grad(f)_i, i in I_up(\alpha) |
peccu | 0:b9ac53c439ed | 775 | // j: minimizes the decrease of obj value |
peccu | 0:b9ac53c439ed | 776 | // (if quadratic coefficeint <= 0, replace it with tau) |
peccu | 0:b9ac53c439ed | 777 | // -y_j*grad(f)_j < -y_i*grad(f)_i, j in I_low(\alpha) |
peccu | 0:b9ac53c439ed | 778 | |
peccu | 0:b9ac53c439ed | 779 | double Gmax = -INF; |
peccu | 0:b9ac53c439ed | 780 | double Gmax2 = -INF; |
peccu | 0:b9ac53c439ed | 781 | int Gmax_idx = -1; |
peccu | 0:b9ac53c439ed | 782 | int Gmin_idx = -1; |
peccu | 0:b9ac53c439ed | 783 | double obj_diff_min = INF; |
peccu | 0:b9ac53c439ed | 784 | |
peccu | 0:b9ac53c439ed | 785 | for(int t=0;t<active_size;t++) |
peccu | 0:b9ac53c439ed | 786 | if(y[t]==+1) |
peccu | 0:b9ac53c439ed | 787 | { |
peccu | 0:b9ac53c439ed | 788 | if(!is_upper_bound(t)) |
peccu | 0:b9ac53c439ed | 789 | if(-G[t] >= Gmax) |
peccu | 0:b9ac53c439ed | 790 | { |
peccu | 0:b9ac53c439ed | 791 | Gmax = -G[t]; |
peccu | 0:b9ac53c439ed | 792 | Gmax_idx = t; |
peccu | 0:b9ac53c439ed | 793 | } |
peccu | 0:b9ac53c439ed | 794 | } |
peccu | 0:b9ac53c439ed | 795 | else |
peccu | 0:b9ac53c439ed | 796 | { |
peccu | 0:b9ac53c439ed | 797 | if(!is_lower_bound(t)) |
peccu | 0:b9ac53c439ed | 798 | if(G[t] >= Gmax) |
peccu | 0:b9ac53c439ed | 799 | { |
peccu | 0:b9ac53c439ed | 800 | Gmax = G[t]; |
peccu | 0:b9ac53c439ed | 801 | Gmax_idx = t; |
peccu | 0:b9ac53c439ed | 802 | } |
peccu | 0:b9ac53c439ed | 803 | } |
peccu | 0:b9ac53c439ed | 804 | |
peccu | 0:b9ac53c439ed | 805 | int i = Gmax_idx; |
peccu | 0:b9ac53c439ed | 806 | const Qfloat *Q_i = NULL; |
peccu | 0:b9ac53c439ed | 807 | if(i != -1) // NULL Q_i not accessed: Gmax=-INF if i=-1 |
peccu | 0:b9ac53c439ed | 808 | Q_i = Q->get_Q(i,active_size); |
peccu | 0:b9ac53c439ed | 809 | |
peccu | 0:b9ac53c439ed | 810 | for(int j=0;j<active_size;j++) |
peccu | 0:b9ac53c439ed | 811 | { |
peccu | 0:b9ac53c439ed | 812 | if(y[j]==+1) |
peccu | 0:b9ac53c439ed | 813 | { |
peccu | 0:b9ac53c439ed | 814 | if (!is_lower_bound(j)) |
peccu | 0:b9ac53c439ed | 815 | { |
peccu | 0:b9ac53c439ed | 816 | double grad_diff=Gmax+G[j]; |
peccu | 0:b9ac53c439ed | 817 | if (G[j] >= Gmax2) |
peccu | 0:b9ac53c439ed | 818 | Gmax2 = G[j]; |
peccu | 0:b9ac53c439ed | 819 | if (grad_diff > 0) |
peccu | 0:b9ac53c439ed | 820 | { |
peccu | 0:b9ac53c439ed | 821 | double obj_diff; |
peccu | 0:b9ac53c439ed | 822 | double quad_coef = QD[i]+QD[j]-2.0*y[i]*Q_i[j]; |
peccu | 0:b9ac53c439ed | 823 | if (quad_coef > 0) |
peccu | 0:b9ac53c439ed | 824 | obj_diff = -(grad_diff*grad_diff)/quad_coef; |
peccu | 0:b9ac53c439ed | 825 | else |
peccu | 0:b9ac53c439ed | 826 | obj_diff = -(grad_diff*grad_diff)/TAU; |
peccu | 0:b9ac53c439ed | 827 | |
peccu | 0:b9ac53c439ed | 828 | if (obj_diff <= obj_diff_min) |
peccu | 0:b9ac53c439ed | 829 | { |
peccu | 0:b9ac53c439ed | 830 | Gmin_idx=j; |
peccu | 0:b9ac53c439ed | 831 | obj_diff_min = obj_diff; |
peccu | 0:b9ac53c439ed | 832 | } |
peccu | 0:b9ac53c439ed | 833 | } |
peccu | 0:b9ac53c439ed | 834 | } |
peccu | 0:b9ac53c439ed | 835 | } |
peccu | 0:b9ac53c439ed | 836 | else |
peccu | 0:b9ac53c439ed | 837 | { |
peccu | 0:b9ac53c439ed | 838 | if (!is_upper_bound(j)) |
peccu | 0:b9ac53c439ed | 839 | { |
peccu | 0:b9ac53c439ed | 840 | double grad_diff= Gmax-G[j]; |
peccu | 0:b9ac53c439ed | 841 | if (-G[j] >= Gmax2) |
peccu | 0:b9ac53c439ed | 842 | Gmax2 = -G[j]; |
peccu | 0:b9ac53c439ed | 843 | if (grad_diff > 0) |
peccu | 0:b9ac53c439ed | 844 | { |
peccu | 0:b9ac53c439ed | 845 | double obj_diff; |
peccu | 0:b9ac53c439ed | 846 | double quad_coef = QD[i]+QD[j]+2.0*y[i]*Q_i[j]; |
peccu | 0:b9ac53c439ed | 847 | if (quad_coef > 0) |
peccu | 0:b9ac53c439ed | 848 | obj_diff = -(grad_diff*grad_diff)/quad_coef; |
peccu | 0:b9ac53c439ed | 849 | else |
peccu | 0:b9ac53c439ed | 850 | obj_diff = -(grad_diff*grad_diff)/TAU; |
peccu | 0:b9ac53c439ed | 851 | |
peccu | 0:b9ac53c439ed | 852 | if (obj_diff <= obj_diff_min) |
peccu | 0:b9ac53c439ed | 853 | { |
peccu | 0:b9ac53c439ed | 854 | Gmin_idx=j; |
peccu | 0:b9ac53c439ed | 855 | obj_diff_min = obj_diff; |
peccu | 0:b9ac53c439ed | 856 | } |
peccu | 0:b9ac53c439ed | 857 | } |
peccu | 0:b9ac53c439ed | 858 | } |
peccu | 0:b9ac53c439ed | 859 | } |
peccu | 0:b9ac53c439ed | 860 | } |
peccu | 0:b9ac53c439ed | 861 | |
peccu | 0:b9ac53c439ed | 862 | if(Gmax+Gmax2 < eps) |
peccu | 0:b9ac53c439ed | 863 | return 1; |
peccu | 0:b9ac53c439ed | 864 | |
peccu | 0:b9ac53c439ed | 865 | out_i = Gmax_idx; |
peccu | 0:b9ac53c439ed | 866 | out_j = Gmin_idx; |
peccu | 0:b9ac53c439ed | 867 | return 0; |
peccu | 0:b9ac53c439ed | 868 | } |
peccu | 0:b9ac53c439ed | 869 | |
peccu | 0:b9ac53c439ed | 870 | bool Solver::be_shrunk(int i, double Gmax1, double Gmax2) |
peccu | 0:b9ac53c439ed | 871 | { |
peccu | 0:b9ac53c439ed | 872 | if(is_upper_bound(i)) |
peccu | 0:b9ac53c439ed | 873 | { |
peccu | 0:b9ac53c439ed | 874 | if(y[i]==+1) |
peccu | 0:b9ac53c439ed | 875 | return(-G[i] > Gmax1); |
peccu | 0:b9ac53c439ed | 876 | else |
peccu | 0:b9ac53c439ed | 877 | return(-G[i] > Gmax2); |
peccu | 0:b9ac53c439ed | 878 | } |
peccu | 0:b9ac53c439ed | 879 | else if(is_lower_bound(i)) |
peccu | 0:b9ac53c439ed | 880 | { |
peccu | 0:b9ac53c439ed | 881 | if(y[i]==+1) |
peccu | 0:b9ac53c439ed | 882 | return(G[i] > Gmax2); |
peccu | 0:b9ac53c439ed | 883 | else |
peccu | 0:b9ac53c439ed | 884 | return(G[i] > Gmax1); |
peccu | 0:b9ac53c439ed | 885 | } |
peccu | 0:b9ac53c439ed | 886 | else |
peccu | 0:b9ac53c439ed | 887 | return(false); |
peccu | 0:b9ac53c439ed | 888 | } |
peccu | 0:b9ac53c439ed | 889 | |
peccu | 0:b9ac53c439ed | 890 | void Solver::do_shrinking() |
peccu | 0:b9ac53c439ed | 891 | { |
peccu | 0:b9ac53c439ed | 892 | int i; |
peccu | 0:b9ac53c439ed | 893 | double Gmax1 = -INF; // max { -y_i * grad(f)_i | i in I_up(\alpha) } |
peccu | 0:b9ac53c439ed | 894 | double Gmax2 = -INF; // max { y_i * grad(f)_i | i in I_low(\alpha) } |
peccu | 0:b9ac53c439ed | 895 | |
peccu | 0:b9ac53c439ed | 896 | // find maximal violating pair first |
peccu | 0:b9ac53c439ed | 897 | for(i=0;i<active_size;i++) |
peccu | 0:b9ac53c439ed | 898 | { |
peccu | 0:b9ac53c439ed | 899 | if(y[i]==+1) |
peccu | 0:b9ac53c439ed | 900 | { |
peccu | 0:b9ac53c439ed | 901 | if(!is_upper_bound(i)) |
peccu | 0:b9ac53c439ed | 902 | { |
peccu | 0:b9ac53c439ed | 903 | if(-G[i] >= Gmax1) |
peccu | 0:b9ac53c439ed | 904 | Gmax1 = -G[i]; |
peccu | 0:b9ac53c439ed | 905 | } |
peccu | 0:b9ac53c439ed | 906 | if(!is_lower_bound(i)) |
peccu | 0:b9ac53c439ed | 907 | { |
peccu | 0:b9ac53c439ed | 908 | if(G[i] >= Gmax2) |
peccu | 0:b9ac53c439ed | 909 | Gmax2 = G[i]; |
peccu | 0:b9ac53c439ed | 910 | } |
peccu | 0:b9ac53c439ed | 911 | } |
peccu | 0:b9ac53c439ed | 912 | else |
peccu | 0:b9ac53c439ed | 913 | { |
peccu | 0:b9ac53c439ed | 914 | if(!is_upper_bound(i)) |
peccu | 0:b9ac53c439ed | 915 | { |
peccu | 0:b9ac53c439ed | 916 | if(-G[i] >= Gmax2) |
peccu | 0:b9ac53c439ed | 917 | Gmax2 = -G[i]; |
peccu | 0:b9ac53c439ed | 918 | } |
peccu | 0:b9ac53c439ed | 919 | if(!is_lower_bound(i)) |
peccu | 0:b9ac53c439ed | 920 | { |
peccu | 0:b9ac53c439ed | 921 | if(G[i] >= Gmax1) |
peccu | 0:b9ac53c439ed | 922 | Gmax1 = G[i]; |
peccu | 0:b9ac53c439ed | 923 | } |
peccu | 0:b9ac53c439ed | 924 | } |
peccu | 0:b9ac53c439ed | 925 | } |
peccu | 0:b9ac53c439ed | 926 | |
peccu | 0:b9ac53c439ed | 927 | if(unshrink == false && Gmax1 + Gmax2 <= eps*10) |
peccu | 0:b9ac53c439ed | 928 | { |
peccu | 0:b9ac53c439ed | 929 | unshrink = true; |
peccu | 0:b9ac53c439ed | 930 | reconstruct_gradient(); |
peccu | 0:b9ac53c439ed | 931 | active_size = l; |
peccu | 0:b9ac53c439ed | 932 | info("*"); |
peccu | 0:b9ac53c439ed | 933 | } |
peccu | 0:b9ac53c439ed | 934 | |
peccu | 0:b9ac53c439ed | 935 | for(i=0;i<active_size;i++) |
peccu | 0:b9ac53c439ed | 936 | if (be_shrunk(i, Gmax1, Gmax2)) |
peccu | 0:b9ac53c439ed | 937 | { |
peccu | 0:b9ac53c439ed | 938 | active_size--; |
peccu | 0:b9ac53c439ed | 939 | while (active_size > i) |
peccu | 0:b9ac53c439ed | 940 | { |
peccu | 0:b9ac53c439ed | 941 | if (!be_shrunk(active_size, Gmax1, Gmax2)) |
peccu | 0:b9ac53c439ed | 942 | { |
peccu | 0:b9ac53c439ed | 943 | swap_index(i,active_size); |
peccu | 0:b9ac53c439ed | 944 | break; |
peccu | 0:b9ac53c439ed | 945 | } |
peccu | 0:b9ac53c439ed | 946 | active_size--; |
peccu | 0:b9ac53c439ed | 947 | } |
peccu | 0:b9ac53c439ed | 948 | } |
peccu | 0:b9ac53c439ed | 949 | } |
peccu | 0:b9ac53c439ed | 950 | |
peccu | 0:b9ac53c439ed | 951 | double Solver::calculate_rho() |
peccu | 0:b9ac53c439ed | 952 | { |
peccu | 0:b9ac53c439ed | 953 | double r; |
peccu | 0:b9ac53c439ed | 954 | int nr_free = 0; |
peccu | 0:b9ac53c439ed | 955 | double ub = INF, lb = -INF, sum_free = 0; |
peccu | 0:b9ac53c439ed | 956 | for(int i=0;i<active_size;i++) |
peccu | 0:b9ac53c439ed | 957 | { |
peccu | 0:b9ac53c439ed | 958 | double yG = y[i]*G[i]; |
peccu | 0:b9ac53c439ed | 959 | |
peccu | 0:b9ac53c439ed | 960 | if(is_upper_bound(i)) |
peccu | 0:b9ac53c439ed | 961 | { |
peccu | 0:b9ac53c439ed | 962 | if(y[i]==-1) |
peccu | 0:b9ac53c439ed | 963 | ub = min(ub,yG); |
peccu | 0:b9ac53c439ed | 964 | else |
peccu | 0:b9ac53c439ed | 965 | lb = max(lb,yG); |
peccu | 0:b9ac53c439ed | 966 | } |
peccu | 0:b9ac53c439ed | 967 | else if(is_lower_bound(i)) |
peccu | 0:b9ac53c439ed | 968 | { |
peccu | 0:b9ac53c439ed | 969 | if(y[i]==+1) |
peccu | 0:b9ac53c439ed | 970 | ub = min(ub,yG); |
peccu | 0:b9ac53c439ed | 971 | else |
peccu | 0:b9ac53c439ed | 972 | lb = max(lb,yG); |
peccu | 0:b9ac53c439ed | 973 | } |
peccu | 0:b9ac53c439ed | 974 | else |
peccu | 0:b9ac53c439ed | 975 | { |
peccu | 0:b9ac53c439ed | 976 | ++nr_free; |
peccu | 0:b9ac53c439ed | 977 | sum_free += yG; |
peccu | 0:b9ac53c439ed | 978 | } |
peccu | 0:b9ac53c439ed | 979 | } |
peccu | 0:b9ac53c439ed | 980 | |
peccu | 0:b9ac53c439ed | 981 | if(nr_free>0) |
peccu | 0:b9ac53c439ed | 982 | r = sum_free/nr_free; |
peccu | 0:b9ac53c439ed | 983 | else |
peccu | 0:b9ac53c439ed | 984 | r = (ub+lb)/2; |
peccu | 0:b9ac53c439ed | 985 | |
peccu | 0:b9ac53c439ed | 986 | return r; |
peccu | 0:b9ac53c439ed | 987 | } |
peccu | 0:b9ac53c439ed | 988 | |
peccu | 0:b9ac53c439ed | 989 | // |
peccu | 0:b9ac53c439ed | 990 | // Solver for nu-svm classification and regression |
peccu | 0:b9ac53c439ed | 991 | // |
peccu | 0:b9ac53c439ed | 992 | // additional constraint: e^T \alpha = constant |
peccu | 0:b9ac53c439ed | 993 | // |
peccu | 0:b9ac53c439ed | 994 | class Solver_NU : public Solver |
peccu | 0:b9ac53c439ed | 995 | { |
peccu | 0:b9ac53c439ed | 996 | public: |
peccu | 0:b9ac53c439ed | 997 | Solver_NU() {} |
peccu | 0:b9ac53c439ed | 998 | void Solve(int l, const QMatrix& Q, const double *p, const schar *y, |
peccu | 0:b9ac53c439ed | 999 | double *alpha, double Cp, double Cn, double eps, |
peccu | 0:b9ac53c439ed | 1000 | SolutionInfo* si, int shrinking) |
peccu | 0:b9ac53c439ed | 1001 | { |
peccu | 0:b9ac53c439ed | 1002 | this->si = si; |
peccu | 0:b9ac53c439ed | 1003 | Solver::Solve(l,Q,p,y,alpha,Cp,Cn,eps,si,shrinking); |
peccu | 0:b9ac53c439ed | 1004 | } |
peccu | 0:b9ac53c439ed | 1005 | private: |
peccu | 0:b9ac53c439ed | 1006 | SolutionInfo *si; |
peccu | 0:b9ac53c439ed | 1007 | int select_working_set(int &i, int &j); |
peccu | 0:b9ac53c439ed | 1008 | double calculate_rho(); |
peccu | 0:b9ac53c439ed | 1009 | bool be_shrunk(int i, double Gmax1, double Gmax2, double Gmax3, double Gmax4); |
peccu | 0:b9ac53c439ed | 1010 | void do_shrinking(); |
peccu | 0:b9ac53c439ed | 1011 | }; |
peccu | 0:b9ac53c439ed | 1012 | |
peccu | 0:b9ac53c439ed | 1013 | // return 1 if already optimal, return 0 otherwise |
peccu | 0:b9ac53c439ed | 1014 | int Solver_NU::select_working_set(int &out_i, int &out_j) |
peccu | 0:b9ac53c439ed | 1015 | { |
peccu | 0:b9ac53c439ed | 1016 | // return i,j such that y_i = y_j and |
peccu | 0:b9ac53c439ed | 1017 | // i: maximizes -y_i * grad(f)_i, i in I_up(\alpha) |
peccu | 0:b9ac53c439ed | 1018 | // j: minimizes the decrease of obj value |
peccu | 0:b9ac53c439ed | 1019 | // (if quadratic coefficeint <= 0, replace it with tau) |
peccu | 0:b9ac53c439ed | 1020 | // -y_j*grad(f)_j < -y_i*grad(f)_i, j in I_low(\alpha) |
peccu | 0:b9ac53c439ed | 1021 | |
peccu | 0:b9ac53c439ed | 1022 | double Gmaxp = -INF; |
peccu | 0:b9ac53c439ed | 1023 | double Gmaxp2 = -INF; |
peccu | 0:b9ac53c439ed | 1024 | int Gmaxp_idx = -1; |
peccu | 0:b9ac53c439ed | 1025 | |
peccu | 0:b9ac53c439ed | 1026 | double Gmaxn = -INF; |
peccu | 0:b9ac53c439ed | 1027 | double Gmaxn2 = -INF; |
peccu | 0:b9ac53c439ed | 1028 | int Gmaxn_idx = -1; |
peccu | 0:b9ac53c439ed | 1029 | |
peccu | 0:b9ac53c439ed | 1030 | int Gmin_idx = -1; |
peccu | 0:b9ac53c439ed | 1031 | double obj_diff_min = INF; |
peccu | 0:b9ac53c439ed | 1032 | |
peccu | 0:b9ac53c439ed | 1033 | for(int t=0;t<active_size;t++) |
peccu | 0:b9ac53c439ed | 1034 | if(y[t]==+1) |
peccu | 0:b9ac53c439ed | 1035 | { |
peccu | 0:b9ac53c439ed | 1036 | if(!is_upper_bound(t)) |
peccu | 0:b9ac53c439ed | 1037 | if(-G[t] >= Gmaxp) |
peccu | 0:b9ac53c439ed | 1038 | { |
peccu | 0:b9ac53c439ed | 1039 | Gmaxp = -G[t]; |
peccu | 0:b9ac53c439ed | 1040 | Gmaxp_idx = t; |
peccu | 0:b9ac53c439ed | 1041 | } |
peccu | 0:b9ac53c439ed | 1042 | } |
peccu | 0:b9ac53c439ed | 1043 | else |
peccu | 0:b9ac53c439ed | 1044 | { |
peccu | 0:b9ac53c439ed | 1045 | if(!is_lower_bound(t)) |
peccu | 0:b9ac53c439ed | 1046 | if(G[t] >= Gmaxn) |
peccu | 0:b9ac53c439ed | 1047 | { |
peccu | 0:b9ac53c439ed | 1048 | Gmaxn = G[t]; |
peccu | 0:b9ac53c439ed | 1049 | Gmaxn_idx = t; |
peccu | 0:b9ac53c439ed | 1050 | } |
peccu | 0:b9ac53c439ed | 1051 | } |
peccu | 0:b9ac53c439ed | 1052 | |
peccu | 0:b9ac53c439ed | 1053 | int ip = Gmaxp_idx; |
peccu | 0:b9ac53c439ed | 1054 | int in = Gmaxn_idx; |
peccu | 0:b9ac53c439ed | 1055 | const Qfloat *Q_ip = NULL; |
peccu | 0:b9ac53c439ed | 1056 | const Qfloat *Q_in = NULL; |
peccu | 0:b9ac53c439ed | 1057 | if(ip != -1) // NULL Q_ip not accessed: Gmaxp=-INF if ip=-1 |
peccu | 0:b9ac53c439ed | 1058 | Q_ip = Q->get_Q(ip,active_size); |
peccu | 0:b9ac53c439ed | 1059 | if(in != -1) |
peccu | 0:b9ac53c439ed | 1060 | Q_in = Q->get_Q(in,active_size); |
peccu | 0:b9ac53c439ed | 1061 | |
peccu | 0:b9ac53c439ed | 1062 | for(int j=0;j<active_size;j++) |
peccu | 0:b9ac53c439ed | 1063 | { |
peccu | 0:b9ac53c439ed | 1064 | if(y[j]==+1) |
peccu | 0:b9ac53c439ed | 1065 | { |
peccu | 0:b9ac53c439ed | 1066 | if (!is_lower_bound(j)) |
peccu | 0:b9ac53c439ed | 1067 | { |
peccu | 0:b9ac53c439ed | 1068 | double grad_diff=Gmaxp+G[j]; |
peccu | 0:b9ac53c439ed | 1069 | if (G[j] >= Gmaxp2) |
peccu | 0:b9ac53c439ed | 1070 | Gmaxp2 = G[j]; |
peccu | 0:b9ac53c439ed | 1071 | if (grad_diff > 0) |
peccu | 0:b9ac53c439ed | 1072 | { |
peccu | 0:b9ac53c439ed | 1073 | double obj_diff; |
peccu | 0:b9ac53c439ed | 1074 | double quad_coef = QD[ip]+QD[j]-2*Q_ip[j]; |
peccu | 0:b9ac53c439ed | 1075 | if (quad_coef > 0) |
peccu | 0:b9ac53c439ed | 1076 | obj_diff = -(grad_diff*grad_diff)/quad_coef; |
peccu | 0:b9ac53c439ed | 1077 | else |
peccu | 0:b9ac53c439ed | 1078 | obj_diff = -(grad_diff*grad_diff)/TAU; |
peccu | 0:b9ac53c439ed | 1079 | |
peccu | 0:b9ac53c439ed | 1080 | if (obj_diff <= obj_diff_min) |
peccu | 0:b9ac53c439ed | 1081 | { |
peccu | 0:b9ac53c439ed | 1082 | Gmin_idx=j; |
peccu | 0:b9ac53c439ed | 1083 | obj_diff_min = obj_diff; |
peccu | 0:b9ac53c439ed | 1084 | } |
peccu | 0:b9ac53c439ed | 1085 | } |
peccu | 0:b9ac53c439ed | 1086 | } |
peccu | 0:b9ac53c439ed | 1087 | } |
peccu | 0:b9ac53c439ed | 1088 | else |
peccu | 0:b9ac53c439ed | 1089 | { |
peccu | 0:b9ac53c439ed | 1090 | if (!is_upper_bound(j)) |
peccu | 0:b9ac53c439ed | 1091 | { |
peccu | 0:b9ac53c439ed | 1092 | double grad_diff=Gmaxn-G[j]; |
peccu | 0:b9ac53c439ed | 1093 | if (-G[j] >= Gmaxn2) |
peccu | 0:b9ac53c439ed | 1094 | Gmaxn2 = -G[j]; |
peccu | 0:b9ac53c439ed | 1095 | if (grad_diff > 0) |
peccu | 0:b9ac53c439ed | 1096 | { |
peccu | 0:b9ac53c439ed | 1097 | double obj_diff; |
peccu | 0:b9ac53c439ed | 1098 | double quad_coef = QD[in]+QD[j]-2*Q_in[j]; |
peccu | 0:b9ac53c439ed | 1099 | if (quad_coef > 0) |
peccu | 0:b9ac53c439ed | 1100 | obj_diff = -(grad_diff*grad_diff)/quad_coef; |
peccu | 0:b9ac53c439ed | 1101 | else |
peccu | 0:b9ac53c439ed | 1102 | obj_diff = -(grad_diff*grad_diff)/TAU; |
peccu | 0:b9ac53c439ed | 1103 | |
peccu | 0:b9ac53c439ed | 1104 | if (obj_diff <= obj_diff_min) |
peccu | 0:b9ac53c439ed | 1105 | { |
peccu | 0:b9ac53c439ed | 1106 | Gmin_idx=j; |
peccu | 0:b9ac53c439ed | 1107 | obj_diff_min = obj_diff; |
peccu | 0:b9ac53c439ed | 1108 | } |
peccu | 0:b9ac53c439ed | 1109 | } |
peccu | 0:b9ac53c439ed | 1110 | } |
peccu | 0:b9ac53c439ed | 1111 | } |
peccu | 0:b9ac53c439ed | 1112 | } |
peccu | 0:b9ac53c439ed | 1113 | |
peccu | 0:b9ac53c439ed | 1114 | if(max(Gmaxp+Gmaxp2,Gmaxn+Gmaxn2) < eps) |
peccu | 0:b9ac53c439ed | 1115 | return 1; |
peccu | 0:b9ac53c439ed | 1116 | |
peccu | 0:b9ac53c439ed | 1117 | if (y[Gmin_idx] == +1) |
peccu | 0:b9ac53c439ed | 1118 | out_i = Gmaxp_idx; |
peccu | 0:b9ac53c439ed | 1119 | else |
peccu | 0:b9ac53c439ed | 1120 | out_i = Gmaxn_idx; |
peccu | 0:b9ac53c439ed | 1121 | out_j = Gmin_idx; |
peccu | 0:b9ac53c439ed | 1122 | |
peccu | 0:b9ac53c439ed | 1123 | return 0; |
peccu | 0:b9ac53c439ed | 1124 | } |
peccu | 0:b9ac53c439ed | 1125 | |
peccu | 0:b9ac53c439ed | 1126 | bool Solver_NU::be_shrunk(int i, double Gmax1, double Gmax2, double Gmax3, double Gmax4) |
peccu | 0:b9ac53c439ed | 1127 | { |
peccu | 0:b9ac53c439ed | 1128 | if(is_upper_bound(i)) |
peccu | 0:b9ac53c439ed | 1129 | { |
peccu | 0:b9ac53c439ed | 1130 | if(y[i]==+1) |
peccu | 0:b9ac53c439ed | 1131 | return(-G[i] > Gmax1); |
peccu | 0:b9ac53c439ed | 1132 | else |
peccu | 0:b9ac53c439ed | 1133 | return(-G[i] > Gmax4); |
peccu | 0:b9ac53c439ed | 1134 | } |
peccu | 0:b9ac53c439ed | 1135 | else if(is_lower_bound(i)) |
peccu | 0:b9ac53c439ed | 1136 | { |
peccu | 0:b9ac53c439ed | 1137 | if(y[i]==+1) |
peccu | 0:b9ac53c439ed | 1138 | return(G[i] > Gmax2); |
peccu | 0:b9ac53c439ed | 1139 | else |
peccu | 0:b9ac53c439ed | 1140 | return(G[i] > Gmax3); |
peccu | 0:b9ac53c439ed | 1141 | } |
peccu | 0:b9ac53c439ed | 1142 | else |
peccu | 0:b9ac53c439ed | 1143 | return(false); |
peccu | 0:b9ac53c439ed | 1144 | } |
peccu | 0:b9ac53c439ed | 1145 | |
peccu | 0:b9ac53c439ed | 1146 | void Solver_NU::do_shrinking() |
peccu | 0:b9ac53c439ed | 1147 | { |
peccu | 0:b9ac53c439ed | 1148 | double Gmax1 = -INF; // max { -y_i * grad(f)_i | y_i = +1, i in I_up(\alpha) } |
peccu | 0:b9ac53c439ed | 1149 | double Gmax2 = -INF; // max { y_i * grad(f)_i | y_i = +1, i in I_low(\alpha) } |
peccu | 0:b9ac53c439ed | 1150 | double Gmax3 = -INF; // max { -y_i * grad(f)_i | y_i = -1, i in I_up(\alpha) } |
peccu | 0:b9ac53c439ed | 1151 | double Gmax4 = -INF; // max { y_i * grad(f)_i | y_i = -1, i in I_low(\alpha) } |
peccu | 0:b9ac53c439ed | 1152 | |
peccu | 0:b9ac53c439ed | 1153 | // find maximal violating pair first |
peccu | 0:b9ac53c439ed | 1154 | int i; |
peccu | 0:b9ac53c439ed | 1155 | for(i=0;i<active_size;i++) |
peccu | 0:b9ac53c439ed | 1156 | { |
peccu | 0:b9ac53c439ed | 1157 | if(!is_upper_bound(i)) |
peccu | 0:b9ac53c439ed | 1158 | { |
peccu | 0:b9ac53c439ed | 1159 | if(y[i]==+1) |
peccu | 0:b9ac53c439ed | 1160 | { |
peccu | 0:b9ac53c439ed | 1161 | if(-G[i] > Gmax1) Gmax1 = -G[i]; |
peccu | 0:b9ac53c439ed | 1162 | } |
peccu | 0:b9ac53c439ed | 1163 | else if(-G[i] > Gmax4) Gmax4 = -G[i]; |
peccu | 0:b9ac53c439ed | 1164 | } |
peccu | 0:b9ac53c439ed | 1165 | if(!is_lower_bound(i)) |
peccu | 0:b9ac53c439ed | 1166 | { |
peccu | 0:b9ac53c439ed | 1167 | if(y[i]==+1) |
peccu | 0:b9ac53c439ed | 1168 | { |
peccu | 0:b9ac53c439ed | 1169 | if(G[i] > Gmax2) Gmax2 = G[i]; |
peccu | 0:b9ac53c439ed | 1170 | } |
peccu | 0:b9ac53c439ed | 1171 | else if(G[i] > Gmax3) Gmax3 = G[i]; |
peccu | 0:b9ac53c439ed | 1172 | } |
peccu | 0:b9ac53c439ed | 1173 | } |
peccu | 0:b9ac53c439ed | 1174 | |
peccu | 0:b9ac53c439ed | 1175 | if(unshrink == false && max(Gmax1+Gmax2,Gmax3+Gmax4) <= eps*10) |
peccu | 0:b9ac53c439ed | 1176 | { |
peccu | 0:b9ac53c439ed | 1177 | unshrink = true; |
peccu | 0:b9ac53c439ed | 1178 | reconstruct_gradient(); |
peccu | 0:b9ac53c439ed | 1179 | active_size = l; |
peccu | 0:b9ac53c439ed | 1180 | } |
peccu | 0:b9ac53c439ed | 1181 | |
peccu | 0:b9ac53c439ed | 1182 | for(i=0;i<active_size;i++) |
peccu | 0:b9ac53c439ed | 1183 | if (be_shrunk(i, Gmax1, Gmax2, Gmax3, Gmax4)) |
peccu | 0:b9ac53c439ed | 1184 | { |
peccu | 0:b9ac53c439ed | 1185 | active_size--; |
peccu | 0:b9ac53c439ed | 1186 | while (active_size > i) |
peccu | 0:b9ac53c439ed | 1187 | { |
peccu | 0:b9ac53c439ed | 1188 | if (!be_shrunk(active_size, Gmax1, Gmax2, Gmax3, Gmax4)) |
peccu | 0:b9ac53c439ed | 1189 | { |
peccu | 0:b9ac53c439ed | 1190 | swap_index(i,active_size); |
peccu | 0:b9ac53c439ed | 1191 | break; |
peccu | 0:b9ac53c439ed | 1192 | } |
peccu | 0:b9ac53c439ed | 1193 | active_size--; |
peccu | 0:b9ac53c439ed | 1194 | } |
peccu | 0:b9ac53c439ed | 1195 | } |
peccu | 0:b9ac53c439ed | 1196 | } |
peccu | 0:b9ac53c439ed | 1197 | |
peccu | 0:b9ac53c439ed | 1198 | double Solver_NU::calculate_rho() |
peccu | 0:b9ac53c439ed | 1199 | { |
peccu | 0:b9ac53c439ed | 1200 | int nr_free1 = 0,nr_free2 = 0; |
peccu | 0:b9ac53c439ed | 1201 | double ub1 = INF, ub2 = INF; |
peccu | 0:b9ac53c439ed | 1202 | double lb1 = -INF, lb2 = -INF; |
peccu | 0:b9ac53c439ed | 1203 | double sum_free1 = 0, sum_free2 = 0; |
peccu | 0:b9ac53c439ed | 1204 | |
peccu | 0:b9ac53c439ed | 1205 | for(int i=0;i<active_size;i++) |
peccu | 0:b9ac53c439ed | 1206 | { |
peccu | 0:b9ac53c439ed | 1207 | if(y[i]==+1) |
peccu | 0:b9ac53c439ed | 1208 | { |
peccu | 0:b9ac53c439ed | 1209 | if(is_upper_bound(i)) |
peccu | 0:b9ac53c439ed | 1210 | lb1 = max(lb1,G[i]); |
peccu | 0:b9ac53c439ed | 1211 | else if(is_lower_bound(i)) |
peccu | 0:b9ac53c439ed | 1212 | ub1 = min(ub1,G[i]); |
peccu | 0:b9ac53c439ed | 1213 | else |
peccu | 0:b9ac53c439ed | 1214 | { |
peccu | 0:b9ac53c439ed | 1215 | ++nr_free1; |
peccu | 0:b9ac53c439ed | 1216 | sum_free1 += G[i]; |
peccu | 0:b9ac53c439ed | 1217 | } |
peccu | 0:b9ac53c439ed | 1218 | } |
peccu | 0:b9ac53c439ed | 1219 | else |
peccu | 0:b9ac53c439ed | 1220 | { |
peccu | 0:b9ac53c439ed | 1221 | if(is_upper_bound(i)) |
peccu | 0:b9ac53c439ed | 1222 | lb2 = max(lb2,G[i]); |
peccu | 0:b9ac53c439ed | 1223 | else if(is_lower_bound(i)) |
peccu | 0:b9ac53c439ed | 1224 | ub2 = min(ub2,G[i]); |
peccu | 0:b9ac53c439ed | 1225 | else |
peccu | 0:b9ac53c439ed | 1226 | { |
peccu | 0:b9ac53c439ed | 1227 | ++nr_free2; |
peccu | 0:b9ac53c439ed | 1228 | sum_free2 += G[i]; |
peccu | 0:b9ac53c439ed | 1229 | } |
peccu | 0:b9ac53c439ed | 1230 | } |
peccu | 0:b9ac53c439ed | 1231 | } |
peccu | 0:b9ac53c439ed | 1232 | |
peccu | 0:b9ac53c439ed | 1233 | double r1,r2; |
peccu | 0:b9ac53c439ed | 1234 | if(nr_free1 > 0) |
peccu | 0:b9ac53c439ed | 1235 | r1 = sum_free1/nr_free1; |
peccu | 0:b9ac53c439ed | 1236 | else |
peccu | 0:b9ac53c439ed | 1237 | r1 = (ub1+lb1)/2; |
peccu | 0:b9ac53c439ed | 1238 | |
peccu | 0:b9ac53c439ed | 1239 | if(nr_free2 > 0) |
peccu | 0:b9ac53c439ed | 1240 | r2 = sum_free2/nr_free2; |
peccu | 0:b9ac53c439ed | 1241 | else |
peccu | 0:b9ac53c439ed | 1242 | r2 = (ub2+lb2)/2; |
peccu | 0:b9ac53c439ed | 1243 | |
peccu | 0:b9ac53c439ed | 1244 | si->r = (r1+r2)/2; |
peccu | 0:b9ac53c439ed | 1245 | return (r1-r2)/2; |
peccu | 0:b9ac53c439ed | 1246 | } |
peccu | 0:b9ac53c439ed | 1247 | |
peccu | 0:b9ac53c439ed | 1248 | // |
peccu | 0:b9ac53c439ed | 1249 | // Q matrices for various formulations |
peccu | 0:b9ac53c439ed | 1250 | // |
peccu | 0:b9ac53c439ed | 1251 | class SVC_Q: public Kernel |
peccu | 0:b9ac53c439ed | 1252 | { |
peccu | 0:b9ac53c439ed | 1253 | public: |
peccu | 0:b9ac53c439ed | 1254 | SVC_Q(const svm_problem& prob, const svm_parameter& param, const schar *y_) |
peccu | 0:b9ac53c439ed | 1255 | :Kernel(prob.l, prob.x, param) |
peccu | 0:b9ac53c439ed | 1256 | { |
peccu | 0:b9ac53c439ed | 1257 | clone(y,y_,prob.l); |
peccu | 0:b9ac53c439ed | 1258 | cache = new Cache(prob.l,(long int)(param.cache_size*(1<<20))); |
peccu | 0:b9ac53c439ed | 1259 | QD = new double[prob.l]; |
peccu | 0:b9ac53c439ed | 1260 | for(int i=0;i<prob.l;i++) |
peccu | 0:b9ac53c439ed | 1261 | QD[i] = (this->*kernel_function)(i,i); |
peccu | 0:b9ac53c439ed | 1262 | } |
peccu | 0:b9ac53c439ed | 1263 | |
peccu | 0:b9ac53c439ed | 1264 | Qfloat *get_Q(int i, int len) const |
peccu | 0:b9ac53c439ed | 1265 | { |
peccu | 0:b9ac53c439ed | 1266 | Qfloat *data; |
peccu | 0:b9ac53c439ed | 1267 | int start, j; |
peccu | 0:b9ac53c439ed | 1268 | if((start = cache->get_data(i,&data,len)) < len) |
peccu | 0:b9ac53c439ed | 1269 | { |
peccu | 0:b9ac53c439ed | 1270 | for(j=start;j<len;j++) |
peccu | 0:b9ac53c439ed | 1271 | data[j] = (Qfloat)(y[i]*y[j]*(this->*kernel_function)(i,j)); |
peccu | 0:b9ac53c439ed | 1272 | } |
peccu | 0:b9ac53c439ed | 1273 | return data; |
peccu | 0:b9ac53c439ed | 1274 | } |
peccu | 0:b9ac53c439ed | 1275 | |
peccu | 0:b9ac53c439ed | 1276 | double *get_QD() const |
peccu | 0:b9ac53c439ed | 1277 | { |
peccu | 0:b9ac53c439ed | 1278 | return QD; |
peccu | 0:b9ac53c439ed | 1279 | } |
peccu | 0:b9ac53c439ed | 1280 | |
peccu | 0:b9ac53c439ed | 1281 | void swap_index(int i, int j) const |
peccu | 0:b9ac53c439ed | 1282 | { |
peccu | 0:b9ac53c439ed | 1283 | cache->swap_index(i,j); |
peccu | 0:b9ac53c439ed | 1284 | Kernel::swap_index(i,j); |
peccu | 0:b9ac53c439ed | 1285 | swap(y[i],y[j]); |
peccu | 0:b9ac53c439ed | 1286 | swap(QD[i],QD[j]); |
peccu | 0:b9ac53c439ed | 1287 | } |
peccu | 0:b9ac53c439ed | 1288 | |
peccu | 0:b9ac53c439ed | 1289 | ~SVC_Q() |
peccu | 0:b9ac53c439ed | 1290 | { |
peccu | 0:b9ac53c439ed | 1291 | delete[] y; |
peccu | 0:b9ac53c439ed | 1292 | delete cache; |
peccu | 0:b9ac53c439ed | 1293 | delete[] QD; |
peccu | 0:b9ac53c439ed | 1294 | } |
peccu | 0:b9ac53c439ed | 1295 | private: |
peccu | 0:b9ac53c439ed | 1296 | schar *y; |
peccu | 0:b9ac53c439ed | 1297 | Cache *cache; |
peccu | 0:b9ac53c439ed | 1298 | double *QD; |
peccu | 0:b9ac53c439ed | 1299 | }; |
peccu | 0:b9ac53c439ed | 1300 | |
peccu | 0:b9ac53c439ed | 1301 | class ONE_CLASS_Q: public Kernel |
peccu | 0:b9ac53c439ed | 1302 | { |
peccu | 0:b9ac53c439ed | 1303 | public: |
peccu | 0:b9ac53c439ed | 1304 | ONE_CLASS_Q(const svm_problem& prob, const svm_parameter& param) |
peccu | 0:b9ac53c439ed | 1305 | :Kernel(prob.l, prob.x, param) |
peccu | 0:b9ac53c439ed | 1306 | { |
peccu | 0:b9ac53c439ed | 1307 | cache = new Cache(prob.l,(long int)(param.cache_size*(1<<20))); |
peccu | 0:b9ac53c439ed | 1308 | QD = new double[prob.l]; |
peccu | 0:b9ac53c439ed | 1309 | for(int i=0;i<prob.l;i++) |
peccu | 0:b9ac53c439ed | 1310 | QD[i] = (this->*kernel_function)(i,i); |
peccu | 0:b9ac53c439ed | 1311 | } |
peccu | 0:b9ac53c439ed | 1312 | |
peccu | 0:b9ac53c439ed | 1313 | Qfloat *get_Q(int i, int len) const |
peccu | 0:b9ac53c439ed | 1314 | { |
peccu | 0:b9ac53c439ed | 1315 | Qfloat *data; |
peccu | 0:b9ac53c439ed | 1316 | int start, j; |
peccu | 0:b9ac53c439ed | 1317 | if((start = cache->get_data(i,&data,len)) < len) |
peccu | 0:b9ac53c439ed | 1318 | { |
peccu | 0:b9ac53c439ed | 1319 | for(j=start;j<len;j++) |
peccu | 0:b9ac53c439ed | 1320 | data[j] = (Qfloat)(this->*kernel_function)(i,j); |
peccu | 0:b9ac53c439ed | 1321 | } |
peccu | 0:b9ac53c439ed | 1322 | return data; |
peccu | 0:b9ac53c439ed | 1323 | } |
peccu | 0:b9ac53c439ed | 1324 | |
peccu | 0:b9ac53c439ed | 1325 | double *get_QD() const |
peccu | 0:b9ac53c439ed | 1326 | { |
peccu | 0:b9ac53c439ed | 1327 | return QD; |
peccu | 0:b9ac53c439ed | 1328 | } |
peccu | 0:b9ac53c439ed | 1329 | |
peccu | 0:b9ac53c439ed | 1330 | void swap_index(int i, int j) const |
peccu | 0:b9ac53c439ed | 1331 | { |
peccu | 0:b9ac53c439ed | 1332 | cache->swap_index(i,j); |
peccu | 0:b9ac53c439ed | 1333 | Kernel::swap_index(i,j); |
peccu | 0:b9ac53c439ed | 1334 | swap(QD[i],QD[j]); |
peccu | 0:b9ac53c439ed | 1335 | } |
peccu | 0:b9ac53c439ed | 1336 | |
peccu | 0:b9ac53c439ed | 1337 | ~ONE_CLASS_Q() |
peccu | 0:b9ac53c439ed | 1338 | { |
peccu | 0:b9ac53c439ed | 1339 | delete cache; |
peccu | 0:b9ac53c439ed | 1340 | delete[] QD; |
peccu | 0:b9ac53c439ed | 1341 | } |
peccu | 0:b9ac53c439ed | 1342 | private: |
peccu | 0:b9ac53c439ed | 1343 | Cache *cache; |
peccu | 0:b9ac53c439ed | 1344 | double *QD; |
peccu | 0:b9ac53c439ed | 1345 | }; |
peccu | 0:b9ac53c439ed | 1346 | |
peccu | 0:b9ac53c439ed | 1347 | class SVR_Q: public Kernel |
peccu | 0:b9ac53c439ed | 1348 | { |
peccu | 0:b9ac53c439ed | 1349 | public: |
peccu | 0:b9ac53c439ed | 1350 | SVR_Q(const svm_problem& prob, const svm_parameter& param) |
peccu | 0:b9ac53c439ed | 1351 | :Kernel(prob.l, prob.x, param) |
peccu | 0:b9ac53c439ed | 1352 | { |
peccu | 0:b9ac53c439ed | 1353 | l = prob.l; |
peccu | 0:b9ac53c439ed | 1354 | cache = new Cache(l,(long int)(param.cache_size*(1<<20))); |
peccu | 0:b9ac53c439ed | 1355 | QD = new double[2*l]; |
peccu | 0:b9ac53c439ed | 1356 | sign = new schar[2*l]; |
peccu | 0:b9ac53c439ed | 1357 | index = new int[2*l]; |
peccu | 0:b9ac53c439ed | 1358 | for(int k=0;k<l;k++) |
peccu | 0:b9ac53c439ed | 1359 | { |
peccu | 0:b9ac53c439ed | 1360 | sign[k] = 1; |
peccu | 0:b9ac53c439ed | 1361 | sign[k+l] = -1; |
peccu | 0:b9ac53c439ed | 1362 | index[k] = k; |
peccu | 0:b9ac53c439ed | 1363 | index[k+l] = k; |
peccu | 0:b9ac53c439ed | 1364 | QD[k] = (this->*kernel_function)(k,k); |
peccu | 0:b9ac53c439ed | 1365 | QD[k+l] = QD[k]; |
peccu | 0:b9ac53c439ed | 1366 | } |
peccu | 0:b9ac53c439ed | 1367 | buffer[0] = new Qfloat[2*l]; |
peccu | 0:b9ac53c439ed | 1368 | buffer[1] = new Qfloat[2*l]; |
peccu | 0:b9ac53c439ed | 1369 | next_buffer = 0; |
peccu | 0:b9ac53c439ed | 1370 | } |
peccu | 0:b9ac53c439ed | 1371 | |
peccu | 0:b9ac53c439ed | 1372 | void swap_index(int i, int j) const |
peccu | 0:b9ac53c439ed | 1373 | { |
peccu | 0:b9ac53c439ed | 1374 | swap(sign[i],sign[j]); |
peccu | 0:b9ac53c439ed | 1375 | swap(index[i],index[j]); |
peccu | 0:b9ac53c439ed | 1376 | swap(QD[i],QD[j]); |
peccu | 0:b9ac53c439ed | 1377 | } |
peccu | 0:b9ac53c439ed | 1378 | |
peccu | 0:b9ac53c439ed | 1379 | Qfloat *get_Q(int i, int len) const |
peccu | 0:b9ac53c439ed | 1380 | { |
peccu | 0:b9ac53c439ed | 1381 | Qfloat *data; |
peccu | 0:b9ac53c439ed | 1382 | int j, real_i = index[i]; |
peccu | 0:b9ac53c439ed | 1383 | if(cache->get_data(real_i,&data,l) < l) |
peccu | 0:b9ac53c439ed | 1384 | { |
peccu | 0:b9ac53c439ed | 1385 | for(j=0;j<l;j++) |
peccu | 0:b9ac53c439ed | 1386 | data[j] = (Qfloat)(this->*kernel_function)(real_i,j); |
peccu | 0:b9ac53c439ed | 1387 | } |
peccu | 0:b9ac53c439ed | 1388 | |
peccu | 0:b9ac53c439ed | 1389 | // reorder and copy |
peccu | 0:b9ac53c439ed | 1390 | Qfloat *buf = buffer[next_buffer]; |
peccu | 0:b9ac53c439ed | 1391 | next_buffer = 1 - next_buffer; |
peccu | 0:b9ac53c439ed | 1392 | schar si = sign[i]; |
peccu | 0:b9ac53c439ed | 1393 | for(j=0;j<len;j++) |
peccu | 0:b9ac53c439ed | 1394 | buf[j] = (Qfloat) si * (Qfloat) sign[j] * data[index[j]]; |
peccu | 0:b9ac53c439ed | 1395 | return buf; |
peccu | 0:b9ac53c439ed | 1396 | } |
peccu | 0:b9ac53c439ed | 1397 | |
peccu | 0:b9ac53c439ed | 1398 | double *get_QD() const |
peccu | 0:b9ac53c439ed | 1399 | { |
peccu | 0:b9ac53c439ed | 1400 | return QD; |
peccu | 0:b9ac53c439ed | 1401 | } |
peccu | 0:b9ac53c439ed | 1402 | |
peccu | 0:b9ac53c439ed | 1403 | ~SVR_Q() |
peccu | 0:b9ac53c439ed | 1404 | { |
peccu | 0:b9ac53c439ed | 1405 | delete cache; |
peccu | 0:b9ac53c439ed | 1406 | delete[] sign; |
peccu | 0:b9ac53c439ed | 1407 | delete[] index; |
peccu | 0:b9ac53c439ed | 1408 | delete[] buffer[0]; |
peccu | 0:b9ac53c439ed | 1409 | delete[] buffer[1]; |
peccu | 0:b9ac53c439ed | 1410 | delete[] QD; |
peccu | 0:b9ac53c439ed | 1411 | } |
peccu | 0:b9ac53c439ed | 1412 | private: |
peccu | 0:b9ac53c439ed | 1413 | int l; |
peccu | 0:b9ac53c439ed | 1414 | Cache *cache; |
peccu | 0:b9ac53c439ed | 1415 | schar *sign; |
peccu | 0:b9ac53c439ed | 1416 | int *index; |
peccu | 0:b9ac53c439ed | 1417 | mutable int next_buffer; |
peccu | 0:b9ac53c439ed | 1418 | Qfloat *buffer[2]; |
peccu | 0:b9ac53c439ed | 1419 | double *QD; |
peccu | 0:b9ac53c439ed | 1420 | }; |
peccu | 0:b9ac53c439ed | 1421 | |
peccu | 0:b9ac53c439ed | 1422 | // |
peccu | 0:b9ac53c439ed | 1423 | // construct and solve various formulations |
peccu | 0:b9ac53c439ed | 1424 | // |
peccu | 0:b9ac53c439ed | 1425 | static void solve_c_svc( |
peccu | 0:b9ac53c439ed | 1426 | const svm_problem *prob, const svm_parameter* param, |
peccu | 0:b9ac53c439ed | 1427 | double *alpha, Solver::SolutionInfo* si, double Cp, double Cn) |
peccu | 0:b9ac53c439ed | 1428 | { |
peccu | 0:b9ac53c439ed | 1429 | int l = prob->l; |
peccu | 0:b9ac53c439ed | 1430 | double *minus_ones = new double[l]; |
peccu | 0:b9ac53c439ed | 1431 | schar *y = new schar[l]; |
peccu | 0:b9ac53c439ed | 1432 | |
peccu | 0:b9ac53c439ed | 1433 | int i; |
peccu | 0:b9ac53c439ed | 1434 | |
peccu | 0:b9ac53c439ed | 1435 | for(i=0;i<l;i++) |
peccu | 0:b9ac53c439ed | 1436 | { |
peccu | 0:b9ac53c439ed | 1437 | alpha[i] = 0; |
peccu | 0:b9ac53c439ed | 1438 | minus_ones[i] = -1; |
peccu | 0:b9ac53c439ed | 1439 | if(prob->y[i] > 0) y[i] = +1; else y[i] = -1; |
peccu | 0:b9ac53c439ed | 1440 | } |
peccu | 0:b9ac53c439ed | 1441 | |
peccu | 0:b9ac53c439ed | 1442 | Solver s; |
peccu | 0:b9ac53c439ed | 1443 | s.Solve(l, SVC_Q(*prob,*param,y), minus_ones, y, |
peccu | 0:b9ac53c439ed | 1444 | alpha, Cp, Cn, param->eps, si, param->shrinking); |
peccu | 0:b9ac53c439ed | 1445 | |
peccu | 0:b9ac53c439ed | 1446 | double sum_alpha=0; |
peccu | 0:b9ac53c439ed | 1447 | for(i=0;i<l;i++) |
peccu | 0:b9ac53c439ed | 1448 | sum_alpha += alpha[i]; |
peccu | 0:b9ac53c439ed | 1449 | |
peccu | 0:b9ac53c439ed | 1450 | if (Cp==Cn) |
peccu | 0:b9ac53c439ed | 1451 | info("nu = %f\n", sum_alpha/(Cp*prob->l)); |
peccu | 0:b9ac53c439ed | 1452 | |
peccu | 0:b9ac53c439ed | 1453 | for(i=0;i<l;i++) |
peccu | 0:b9ac53c439ed | 1454 | alpha[i] *= y[i]; |
peccu | 0:b9ac53c439ed | 1455 | |
peccu | 0:b9ac53c439ed | 1456 | delete[] minus_ones; |
peccu | 0:b9ac53c439ed | 1457 | delete[] y; |
peccu | 0:b9ac53c439ed | 1458 | } |
peccu | 0:b9ac53c439ed | 1459 | |
peccu | 0:b9ac53c439ed | 1460 | static void solve_nu_svc( |
peccu | 0:b9ac53c439ed | 1461 | const svm_problem *prob, const svm_parameter *param, |
peccu | 0:b9ac53c439ed | 1462 | double *alpha, Solver::SolutionInfo* si) |
peccu | 0:b9ac53c439ed | 1463 | { |
peccu | 0:b9ac53c439ed | 1464 | int i; |
peccu | 0:b9ac53c439ed | 1465 | int l = prob->l; |
peccu | 0:b9ac53c439ed | 1466 | double nu = param->nu; |
peccu | 0:b9ac53c439ed | 1467 | |
peccu | 0:b9ac53c439ed | 1468 | schar *y = new schar[l]; |
peccu | 0:b9ac53c439ed | 1469 | |
peccu | 0:b9ac53c439ed | 1470 | for(i=0;i<l;i++) |
peccu | 0:b9ac53c439ed | 1471 | if(prob->y[i]>0) |
peccu | 0:b9ac53c439ed | 1472 | y[i] = +1; |
peccu | 0:b9ac53c439ed | 1473 | else |
peccu | 0:b9ac53c439ed | 1474 | y[i] = -1; |
peccu | 0:b9ac53c439ed | 1475 | |
peccu | 0:b9ac53c439ed | 1476 | double sum_pos = nu*l/2; |
peccu | 0:b9ac53c439ed | 1477 | double sum_neg = nu*l/2; |
peccu | 0:b9ac53c439ed | 1478 | |
peccu | 0:b9ac53c439ed | 1479 | for(i=0;i<l;i++) |
peccu | 0:b9ac53c439ed | 1480 | if(y[i] == +1) |
peccu | 0:b9ac53c439ed | 1481 | { |
peccu | 0:b9ac53c439ed | 1482 | alpha[i] = min(1.0,sum_pos); |
peccu | 0:b9ac53c439ed | 1483 | sum_pos -= alpha[i]; |
peccu | 0:b9ac53c439ed | 1484 | } |
peccu | 0:b9ac53c439ed | 1485 | else |
peccu | 0:b9ac53c439ed | 1486 | { |
peccu | 0:b9ac53c439ed | 1487 | alpha[i] = min(1.0,sum_neg); |
peccu | 0:b9ac53c439ed | 1488 | sum_neg -= alpha[i]; |
peccu | 0:b9ac53c439ed | 1489 | } |
peccu | 0:b9ac53c439ed | 1490 | |
peccu | 0:b9ac53c439ed | 1491 | double *zeros = new double[l]; |
peccu | 0:b9ac53c439ed | 1492 | |
peccu | 0:b9ac53c439ed | 1493 | for(i=0;i<l;i++) |
peccu | 0:b9ac53c439ed | 1494 | zeros[i] = 0; |
peccu | 0:b9ac53c439ed | 1495 | |
peccu | 0:b9ac53c439ed | 1496 | Solver_NU s; |
peccu | 0:b9ac53c439ed | 1497 | s.Solve(l, SVC_Q(*prob,*param,y), zeros, y, |
peccu | 0:b9ac53c439ed | 1498 | alpha, 1.0, 1.0, param->eps, si, param->shrinking); |
peccu | 0:b9ac53c439ed | 1499 | double r = si->r; |
peccu | 0:b9ac53c439ed | 1500 | |
peccu | 0:b9ac53c439ed | 1501 | info("C = %f\n",1/r); |
peccu | 0:b9ac53c439ed | 1502 | |
peccu | 0:b9ac53c439ed | 1503 | for(i=0;i<l;i++) |
peccu | 0:b9ac53c439ed | 1504 | alpha[i] *= y[i]/r; |
peccu | 0:b9ac53c439ed | 1505 | |
peccu | 0:b9ac53c439ed | 1506 | si->rho /= r; |
peccu | 0:b9ac53c439ed | 1507 | si->obj /= (r*r); |
peccu | 0:b9ac53c439ed | 1508 | si->upper_bound_p = 1/r; |
peccu | 0:b9ac53c439ed | 1509 | si->upper_bound_n = 1/r; |
peccu | 0:b9ac53c439ed | 1510 | |
peccu | 0:b9ac53c439ed | 1511 | delete[] y; |
peccu | 0:b9ac53c439ed | 1512 | delete[] zeros; |
peccu | 0:b9ac53c439ed | 1513 | } |
peccu | 0:b9ac53c439ed | 1514 | |
peccu | 0:b9ac53c439ed | 1515 | static void solve_one_class( |
peccu | 0:b9ac53c439ed | 1516 | const svm_problem *prob, const svm_parameter *param, |
peccu | 0:b9ac53c439ed | 1517 | double *alpha, Solver::SolutionInfo* si) |
peccu | 0:b9ac53c439ed | 1518 | { |
peccu | 0:b9ac53c439ed | 1519 | int l = prob->l; |
peccu | 0:b9ac53c439ed | 1520 | double *zeros = new double[l]; |
peccu | 0:b9ac53c439ed | 1521 | schar *ones = new schar[l]; |
peccu | 0:b9ac53c439ed | 1522 | int i; |
peccu | 0:b9ac53c439ed | 1523 | |
peccu | 0:b9ac53c439ed | 1524 | int n = (int)(param->nu*prob->l); // # of alpha's at upper bound |
peccu | 0:b9ac53c439ed | 1525 | |
peccu | 0:b9ac53c439ed | 1526 | for(i=0;i<n;i++) |
peccu | 0:b9ac53c439ed | 1527 | alpha[i] = 1; |
peccu | 0:b9ac53c439ed | 1528 | if(n<prob->l) |
peccu | 0:b9ac53c439ed | 1529 | alpha[n] = param->nu * prob->l - n; |
peccu | 0:b9ac53c439ed | 1530 | for(i=n+1;i<l;i++) |
peccu | 0:b9ac53c439ed | 1531 | alpha[i] = 0; |
peccu | 0:b9ac53c439ed | 1532 | |
peccu | 0:b9ac53c439ed | 1533 | for(i=0;i<l;i++) |
peccu | 0:b9ac53c439ed | 1534 | { |
peccu | 0:b9ac53c439ed | 1535 | zeros[i] = 0; |
peccu | 0:b9ac53c439ed | 1536 | ones[i] = 1; |
peccu | 0:b9ac53c439ed | 1537 | } |
peccu | 0:b9ac53c439ed | 1538 | |
peccu | 0:b9ac53c439ed | 1539 | Solver s; |
peccu | 0:b9ac53c439ed | 1540 | s.Solve(l, ONE_CLASS_Q(*prob,*param), zeros, ones, |
peccu | 0:b9ac53c439ed | 1541 | alpha, 1.0, 1.0, param->eps, si, param->shrinking); |
peccu | 0:b9ac53c439ed | 1542 | |
peccu | 0:b9ac53c439ed | 1543 | delete[] zeros; |
peccu | 0:b9ac53c439ed | 1544 | delete[] ones; |
peccu | 0:b9ac53c439ed | 1545 | } |
peccu | 0:b9ac53c439ed | 1546 | |
peccu | 0:b9ac53c439ed | 1547 | static void solve_epsilon_svr( |
peccu | 0:b9ac53c439ed | 1548 | const svm_problem *prob, const svm_parameter *param, |
peccu | 0:b9ac53c439ed | 1549 | double *alpha, Solver::SolutionInfo* si) |
peccu | 0:b9ac53c439ed | 1550 | { |
peccu | 0:b9ac53c439ed | 1551 | int l = prob->l; |
peccu | 0:b9ac53c439ed | 1552 | double *alpha2 = new double[2*l]; |
peccu | 0:b9ac53c439ed | 1553 | double *linear_term = new double[2*l]; |
peccu | 0:b9ac53c439ed | 1554 | schar *y = new schar[2*l]; |
peccu | 0:b9ac53c439ed | 1555 | int i; |
peccu | 0:b9ac53c439ed | 1556 | |
peccu | 0:b9ac53c439ed | 1557 | for(i=0;i<l;i++) |
peccu | 0:b9ac53c439ed | 1558 | { |
peccu | 0:b9ac53c439ed | 1559 | alpha2[i] = 0; |
peccu | 0:b9ac53c439ed | 1560 | linear_term[i] = param->p - prob->y[i]; |
peccu | 0:b9ac53c439ed | 1561 | y[i] = 1; |
peccu | 0:b9ac53c439ed | 1562 | |
peccu | 0:b9ac53c439ed | 1563 | alpha2[i+l] = 0; |
peccu | 0:b9ac53c439ed | 1564 | linear_term[i+l] = param->p + prob->y[i]; |
peccu | 0:b9ac53c439ed | 1565 | y[i+l] = -1; |
peccu | 0:b9ac53c439ed | 1566 | } |
peccu | 0:b9ac53c439ed | 1567 | |
peccu | 0:b9ac53c439ed | 1568 | Solver s; |
peccu | 0:b9ac53c439ed | 1569 | s.Solve(2*l, SVR_Q(*prob,*param), linear_term, y, |
peccu | 0:b9ac53c439ed | 1570 | alpha2, param->C, param->C, param->eps, si, param->shrinking); |
peccu | 0:b9ac53c439ed | 1571 | |
peccu | 0:b9ac53c439ed | 1572 | double sum_alpha = 0; |
peccu | 0:b9ac53c439ed | 1573 | for(i=0;i<l;i++) |
peccu | 0:b9ac53c439ed | 1574 | { |
peccu | 0:b9ac53c439ed | 1575 | alpha[i] = alpha2[i] - alpha2[i+l]; |
peccu | 0:b9ac53c439ed | 1576 | sum_alpha += fabs(alpha[i]); |
peccu | 0:b9ac53c439ed | 1577 | } |
peccu | 0:b9ac53c439ed | 1578 | info("nu = %f\n",sum_alpha/(param->C*l)); |
peccu | 0:b9ac53c439ed | 1579 | |
peccu | 0:b9ac53c439ed | 1580 | delete[] alpha2; |
peccu | 0:b9ac53c439ed | 1581 | delete[] linear_term; |
peccu | 0:b9ac53c439ed | 1582 | delete[] y; |
peccu | 0:b9ac53c439ed | 1583 | } |
peccu | 0:b9ac53c439ed | 1584 | |
peccu | 0:b9ac53c439ed | 1585 | static void solve_nu_svr( |
peccu | 0:b9ac53c439ed | 1586 | const svm_problem *prob, const svm_parameter *param, |
peccu | 0:b9ac53c439ed | 1587 | double *alpha, Solver::SolutionInfo* si) |
peccu | 0:b9ac53c439ed | 1588 | { |
peccu | 0:b9ac53c439ed | 1589 | int l = prob->l; |
peccu | 0:b9ac53c439ed | 1590 | double C = param->C; |
peccu | 0:b9ac53c439ed | 1591 | double *alpha2 = new double[2*l]; |
peccu | 0:b9ac53c439ed | 1592 | double *linear_term = new double[2*l]; |
peccu | 0:b9ac53c439ed | 1593 | schar *y = new schar[2*l]; |
peccu | 0:b9ac53c439ed | 1594 | int i; |
peccu | 0:b9ac53c439ed | 1595 | |
peccu | 0:b9ac53c439ed | 1596 | double sum = C * param->nu * l / 2; |
peccu | 0:b9ac53c439ed | 1597 | for(i=0;i<l;i++) |
peccu | 0:b9ac53c439ed | 1598 | { |
peccu | 0:b9ac53c439ed | 1599 | alpha2[i] = alpha2[i+l] = min(sum,C); |
peccu | 0:b9ac53c439ed | 1600 | sum -= alpha2[i]; |
peccu | 0:b9ac53c439ed | 1601 | |
peccu | 0:b9ac53c439ed | 1602 | linear_term[i] = - prob->y[i]; |
peccu | 0:b9ac53c439ed | 1603 | y[i] = 1; |
peccu | 0:b9ac53c439ed | 1604 | |
peccu | 0:b9ac53c439ed | 1605 | linear_term[i+l] = prob->y[i]; |
peccu | 0:b9ac53c439ed | 1606 | y[i+l] = -1; |
peccu | 0:b9ac53c439ed | 1607 | } |
peccu | 0:b9ac53c439ed | 1608 | |
peccu | 0:b9ac53c439ed | 1609 | Solver_NU s; |
peccu | 0:b9ac53c439ed | 1610 | s.Solve(2*l, SVR_Q(*prob,*param), linear_term, y, |
peccu | 0:b9ac53c439ed | 1611 | alpha2, C, C, param->eps, si, param->shrinking); |
peccu | 0:b9ac53c439ed | 1612 | |
peccu | 0:b9ac53c439ed | 1613 | info("epsilon = %f\n",-si->r); |
peccu | 0:b9ac53c439ed | 1614 | |
peccu | 0:b9ac53c439ed | 1615 | for(i=0;i<l;i++) |
peccu | 0:b9ac53c439ed | 1616 | alpha[i] = alpha2[i] - alpha2[i+l]; |
peccu | 0:b9ac53c439ed | 1617 | |
peccu | 0:b9ac53c439ed | 1618 | delete[] alpha2; |
peccu | 0:b9ac53c439ed | 1619 | delete[] linear_term; |
peccu | 0:b9ac53c439ed | 1620 | delete[] y; |
peccu | 0:b9ac53c439ed | 1621 | } |
peccu | 0:b9ac53c439ed | 1622 | |
peccu | 0:b9ac53c439ed | 1623 | // |
peccu | 0:b9ac53c439ed | 1624 | // decision_function |
peccu | 0:b9ac53c439ed | 1625 | // |
peccu | 0:b9ac53c439ed | 1626 | struct decision_function |
peccu | 0:b9ac53c439ed | 1627 | { |
peccu | 0:b9ac53c439ed | 1628 | double *alpha; |
peccu | 0:b9ac53c439ed | 1629 | double rho; |
peccu | 0:b9ac53c439ed | 1630 | }; |
peccu | 0:b9ac53c439ed | 1631 | |
peccu | 0:b9ac53c439ed | 1632 | static decision_function svm_train_one( |
peccu | 0:b9ac53c439ed | 1633 | const svm_problem *prob, const svm_parameter *param, |
peccu | 0:b9ac53c439ed | 1634 | double Cp, double Cn) |
peccu | 0:b9ac53c439ed | 1635 | { |
peccu | 0:b9ac53c439ed | 1636 | double *alpha = Malloc(double,prob->l); |
peccu | 0:b9ac53c439ed | 1637 | Solver::SolutionInfo si; |
peccu | 0:b9ac53c439ed | 1638 | switch(param->svm_type) |
peccu | 0:b9ac53c439ed | 1639 | { |
peccu | 0:b9ac53c439ed | 1640 | case C_SVC: |
peccu | 0:b9ac53c439ed | 1641 | solve_c_svc(prob,param,alpha,&si,Cp,Cn); |
peccu | 0:b9ac53c439ed | 1642 | break; |
peccu | 0:b9ac53c439ed | 1643 | case NU_SVC: |
peccu | 0:b9ac53c439ed | 1644 | solve_nu_svc(prob,param,alpha,&si); |
peccu | 0:b9ac53c439ed | 1645 | break; |
peccu | 0:b9ac53c439ed | 1646 | case ONE_CLASS: |
peccu | 0:b9ac53c439ed | 1647 | solve_one_class(prob,param,alpha,&si); |
peccu | 0:b9ac53c439ed | 1648 | break; |
peccu | 0:b9ac53c439ed | 1649 | case EPSILON_SVR: |
peccu | 0:b9ac53c439ed | 1650 | solve_epsilon_svr(prob,param,alpha,&si); |
peccu | 0:b9ac53c439ed | 1651 | break; |
peccu | 0:b9ac53c439ed | 1652 | case NU_SVR: |
peccu | 0:b9ac53c439ed | 1653 | solve_nu_svr(prob,param,alpha,&si); |
peccu | 0:b9ac53c439ed | 1654 | break; |
peccu | 0:b9ac53c439ed | 1655 | } |
peccu | 0:b9ac53c439ed | 1656 | |
peccu | 0:b9ac53c439ed | 1657 | info("obj = %f, rho = %f\n",si.obj,si.rho); |
peccu | 0:b9ac53c439ed | 1658 | |
peccu | 0:b9ac53c439ed | 1659 | // output SVs |
peccu | 0:b9ac53c439ed | 1660 | |
peccu | 0:b9ac53c439ed | 1661 | int nSV = 0; |
peccu | 0:b9ac53c439ed | 1662 | int nBSV = 0; |
peccu | 0:b9ac53c439ed | 1663 | for(int i=0;i<prob->l;i++) |
peccu | 0:b9ac53c439ed | 1664 | { |
peccu | 0:b9ac53c439ed | 1665 | if(fabs(alpha[i]) > 0) |
peccu | 0:b9ac53c439ed | 1666 | { |
peccu | 0:b9ac53c439ed | 1667 | ++nSV; |
peccu | 0:b9ac53c439ed | 1668 | if(prob->y[i] > 0) |
peccu | 0:b9ac53c439ed | 1669 | { |
peccu | 0:b9ac53c439ed | 1670 | if(fabs(alpha[i]) >= si.upper_bound_p) |
peccu | 0:b9ac53c439ed | 1671 | ++nBSV; |
peccu | 0:b9ac53c439ed | 1672 | } |
peccu | 0:b9ac53c439ed | 1673 | else |
peccu | 0:b9ac53c439ed | 1674 | { |
peccu | 0:b9ac53c439ed | 1675 | if(fabs(alpha[i]) >= si.upper_bound_n) |
peccu | 0:b9ac53c439ed | 1676 | ++nBSV; |
peccu | 0:b9ac53c439ed | 1677 | } |
peccu | 0:b9ac53c439ed | 1678 | } |
peccu | 0:b9ac53c439ed | 1679 | } |
peccu | 0:b9ac53c439ed | 1680 | |
peccu | 0:b9ac53c439ed | 1681 | info("nSV = %d, nBSV = %d\n",nSV,nBSV); |
peccu | 0:b9ac53c439ed | 1682 | |
peccu | 0:b9ac53c439ed | 1683 | decision_function f; |
peccu | 0:b9ac53c439ed | 1684 | f.alpha = alpha; |
peccu | 0:b9ac53c439ed | 1685 | f.rho = si.rho; |
peccu | 0:b9ac53c439ed | 1686 | return f; |
peccu | 0:b9ac53c439ed | 1687 | } |
peccu | 0:b9ac53c439ed | 1688 | |
peccu | 0:b9ac53c439ed | 1689 | // Platt's binary SVM Probablistic Output: an improvement from Lin et al. |
peccu | 0:b9ac53c439ed | 1690 | static void sigmoid_train( |
peccu | 0:b9ac53c439ed | 1691 | int l, const double *dec_values, const double *labels, |
peccu | 0:b9ac53c439ed | 1692 | double& A, double& B) |
peccu | 0:b9ac53c439ed | 1693 | { |
peccu | 0:b9ac53c439ed | 1694 | double prior1=0, prior0 = 0; |
peccu | 0:b9ac53c439ed | 1695 | int i; |
peccu | 0:b9ac53c439ed | 1696 | |
peccu | 0:b9ac53c439ed | 1697 | for (i=0;i<l;i++) |
peccu | 0:b9ac53c439ed | 1698 | if (labels[i] > 0) prior1+=1; |
peccu | 0:b9ac53c439ed | 1699 | else prior0+=1; |
peccu | 0:b9ac53c439ed | 1700 | |
peccu | 0:b9ac53c439ed | 1701 | int max_iter=100; // Maximal number of iterations |
peccu | 0:b9ac53c439ed | 1702 | double min_step=1e-10; // Minimal step taken in line search |
peccu | 0:b9ac53c439ed | 1703 | double sigma=1e-12; // For numerically strict PD of Hessian |
peccu | 0:b9ac53c439ed | 1704 | double eps=1e-5; |
peccu | 0:b9ac53c439ed | 1705 | double hiTarget=(prior1+1.0)/(prior1+2.0); |
peccu | 0:b9ac53c439ed | 1706 | double loTarget=1/(prior0+2.0); |
peccu | 0:b9ac53c439ed | 1707 | double *t=Malloc(double,l); |
peccu | 0:b9ac53c439ed | 1708 | double fApB,p,q,h11,h22,h21,g1,g2,det,dA,dB,gd,stepsize; |
peccu | 0:b9ac53c439ed | 1709 | double newA,newB,newf,d1,d2; |
peccu | 0:b9ac53c439ed | 1710 | int iter; |
peccu | 0:b9ac53c439ed | 1711 | |
peccu | 0:b9ac53c439ed | 1712 | // Initial Point and Initial Fun Value |
peccu | 0:b9ac53c439ed | 1713 | A=0.0; B=log((prior0+1.0)/(prior1+1.0)); |
peccu | 0:b9ac53c439ed | 1714 | double fval = 0.0; |
peccu | 0:b9ac53c439ed | 1715 | |
peccu | 0:b9ac53c439ed | 1716 | for (i=0;i<l;i++) |
peccu | 0:b9ac53c439ed | 1717 | { |
peccu | 0:b9ac53c439ed | 1718 | if (labels[i]>0) t[i]=hiTarget; |
peccu | 0:b9ac53c439ed | 1719 | else t[i]=loTarget; |
peccu | 0:b9ac53c439ed | 1720 | fApB = dec_values[i]*A+B; |
peccu | 0:b9ac53c439ed | 1721 | if (fApB>=0) |
peccu | 0:b9ac53c439ed | 1722 | fval += t[i]*fApB + log(1+exp(-fApB)); |
peccu | 0:b9ac53c439ed | 1723 | else |
peccu | 0:b9ac53c439ed | 1724 | fval += (t[i] - 1)*fApB +log(1+exp(fApB)); |
peccu | 0:b9ac53c439ed | 1725 | } |
peccu | 0:b9ac53c439ed | 1726 | for (iter=0;iter<max_iter;iter++) |
peccu | 0:b9ac53c439ed | 1727 | { |
peccu | 0:b9ac53c439ed | 1728 | // Update Gradient and Hessian (use H' = H + sigma I) |
peccu | 0:b9ac53c439ed | 1729 | h11=sigma; // numerically ensures strict PD |
peccu | 0:b9ac53c439ed | 1730 | h22=sigma; |
peccu | 0:b9ac53c439ed | 1731 | h21=0.0;g1=0.0;g2=0.0; |
peccu | 0:b9ac53c439ed | 1732 | for (i=0;i<l;i++) |
peccu | 0:b9ac53c439ed | 1733 | { |
peccu | 0:b9ac53c439ed | 1734 | fApB = dec_values[i]*A+B; |
peccu | 0:b9ac53c439ed | 1735 | if (fApB >= 0) |
peccu | 0:b9ac53c439ed | 1736 | { |
peccu | 0:b9ac53c439ed | 1737 | p=exp(-fApB)/(1.0+exp(-fApB)); |
peccu | 0:b9ac53c439ed | 1738 | q=1.0/(1.0+exp(-fApB)); |
peccu | 0:b9ac53c439ed | 1739 | } |
peccu | 0:b9ac53c439ed | 1740 | else |
peccu | 0:b9ac53c439ed | 1741 | { |
peccu | 0:b9ac53c439ed | 1742 | p=1.0/(1.0+exp(fApB)); |
peccu | 0:b9ac53c439ed | 1743 | q=exp(fApB)/(1.0+exp(fApB)); |
peccu | 0:b9ac53c439ed | 1744 | } |
peccu | 0:b9ac53c439ed | 1745 | d2=p*q; |
peccu | 0:b9ac53c439ed | 1746 | h11+=dec_values[i]*dec_values[i]*d2; |
peccu | 0:b9ac53c439ed | 1747 | h22+=d2; |
peccu | 0:b9ac53c439ed | 1748 | h21+=dec_values[i]*d2; |
peccu | 0:b9ac53c439ed | 1749 | d1=t[i]-p; |
peccu | 0:b9ac53c439ed | 1750 | g1+=dec_values[i]*d1; |
peccu | 0:b9ac53c439ed | 1751 | g2+=d1; |
peccu | 0:b9ac53c439ed | 1752 | } |
peccu | 0:b9ac53c439ed | 1753 | |
peccu | 0:b9ac53c439ed | 1754 | // Stopping Criteria |
peccu | 0:b9ac53c439ed | 1755 | if (fabs(g1)<eps && fabs(g2)<eps) |
peccu | 0:b9ac53c439ed | 1756 | break; |
peccu | 0:b9ac53c439ed | 1757 | |
peccu | 0:b9ac53c439ed | 1758 | // Finding Newton direction: -inv(H') * g |
peccu | 0:b9ac53c439ed | 1759 | det=h11*h22-h21*h21; |
peccu | 0:b9ac53c439ed | 1760 | dA=-(h22*g1 - h21 * g2) / det; |
peccu | 0:b9ac53c439ed | 1761 | dB=-(-h21*g1+ h11 * g2) / det; |
peccu | 0:b9ac53c439ed | 1762 | gd=g1*dA+g2*dB; |
peccu | 0:b9ac53c439ed | 1763 | |
peccu | 0:b9ac53c439ed | 1764 | |
peccu | 0:b9ac53c439ed | 1765 | stepsize = 1; // Line Search |
peccu | 0:b9ac53c439ed | 1766 | while (stepsize >= min_step) |
peccu | 0:b9ac53c439ed | 1767 | { |
peccu | 0:b9ac53c439ed | 1768 | newA = A + stepsize * dA; |
peccu | 0:b9ac53c439ed | 1769 | newB = B + stepsize * dB; |
peccu | 0:b9ac53c439ed | 1770 | |
peccu | 0:b9ac53c439ed | 1771 | // New function value |
peccu | 0:b9ac53c439ed | 1772 | newf = 0.0; |
peccu | 0:b9ac53c439ed | 1773 | for (i=0;i<l;i++) |
peccu | 0:b9ac53c439ed | 1774 | { |
peccu | 0:b9ac53c439ed | 1775 | fApB = dec_values[i]*newA+newB; |
peccu | 0:b9ac53c439ed | 1776 | if (fApB >= 0) |
peccu | 0:b9ac53c439ed | 1777 | newf += t[i]*fApB + log(1+exp(-fApB)); |
peccu | 0:b9ac53c439ed | 1778 | else |
peccu | 0:b9ac53c439ed | 1779 | newf += (t[i] - 1)*fApB +log(1+exp(fApB)); |
peccu | 0:b9ac53c439ed | 1780 | } |
peccu | 0:b9ac53c439ed | 1781 | // Check sufficient decrease |
peccu | 0:b9ac53c439ed | 1782 | if (newf<fval+0.0001*stepsize*gd) |
peccu | 0:b9ac53c439ed | 1783 | { |
peccu | 0:b9ac53c439ed | 1784 | A=newA;B=newB;fval=newf; |
peccu | 0:b9ac53c439ed | 1785 | break; |
peccu | 0:b9ac53c439ed | 1786 | } |
peccu | 0:b9ac53c439ed | 1787 | else |
peccu | 0:b9ac53c439ed | 1788 | stepsize = stepsize / 2.0; |
peccu | 0:b9ac53c439ed | 1789 | } |
peccu | 0:b9ac53c439ed | 1790 | |
peccu | 0:b9ac53c439ed | 1791 | if (stepsize < min_step) |
peccu | 0:b9ac53c439ed | 1792 | { |
peccu | 0:b9ac53c439ed | 1793 | info("Line search fails in two-class probability estimates\n"); |
peccu | 0:b9ac53c439ed | 1794 | break; |
peccu | 0:b9ac53c439ed | 1795 | } |
peccu | 0:b9ac53c439ed | 1796 | } |
peccu | 0:b9ac53c439ed | 1797 | |
peccu | 0:b9ac53c439ed | 1798 | if (iter>=max_iter) |
peccu | 0:b9ac53c439ed | 1799 | info("Reaching maximal iterations in two-class probability estimates\n"); |
peccu | 0:b9ac53c439ed | 1800 | free(t); |
peccu | 0:b9ac53c439ed | 1801 | } |
peccu | 0:b9ac53c439ed | 1802 | |
peccu | 0:b9ac53c439ed | 1803 | static double sigmoid_predict(double decision_value, double A, double B) |
peccu | 0:b9ac53c439ed | 1804 | { |
peccu | 0:b9ac53c439ed | 1805 | double fApB = decision_value*A+B; |
peccu | 0:b9ac53c439ed | 1806 | if (fApB >= 0) |
peccu | 0:b9ac53c439ed | 1807 | return exp(-fApB)/(1.0+exp(-fApB)); |
peccu | 0:b9ac53c439ed | 1808 | else |
peccu | 0:b9ac53c439ed | 1809 | return 1.0/(1+exp(fApB)) ; |
peccu | 0:b9ac53c439ed | 1810 | } |
peccu | 0:b9ac53c439ed | 1811 | |
peccu | 0:b9ac53c439ed | 1812 | // Method 2 from the multiclass_prob paper by Wu, Lin, and Weng |
peccu | 0:b9ac53c439ed | 1813 | static void multiclass_probability(int k, double **r, double *p) |
peccu | 0:b9ac53c439ed | 1814 | { |
peccu | 0:b9ac53c439ed | 1815 | int t,j; |
peccu | 0:b9ac53c439ed | 1816 | int iter = 0, max_iter=max(100,k); |
peccu | 0:b9ac53c439ed | 1817 | double **Q=Malloc(double *,k); |
peccu | 0:b9ac53c439ed | 1818 | double *Qp=Malloc(double,k); |
peccu | 0:b9ac53c439ed | 1819 | double pQp, eps=0.005/k; |
peccu | 0:b9ac53c439ed | 1820 | |
peccu | 0:b9ac53c439ed | 1821 | for (t=0;t<k;t++) |
peccu | 0:b9ac53c439ed | 1822 | { |
peccu | 0:b9ac53c439ed | 1823 | p[t]=1.0/k; // Valid if k = 1 |
peccu | 0:b9ac53c439ed | 1824 | Q[t]=Malloc(double,k); |
peccu | 0:b9ac53c439ed | 1825 | Q[t][t]=0; |
peccu | 0:b9ac53c439ed | 1826 | for (j=0;j<t;j++) |
peccu | 0:b9ac53c439ed | 1827 | { |
peccu | 0:b9ac53c439ed | 1828 | Q[t][t]+=r[j][t]*r[j][t]; |
peccu | 0:b9ac53c439ed | 1829 | Q[t][j]=Q[j][t]; |
peccu | 0:b9ac53c439ed | 1830 | } |
peccu | 0:b9ac53c439ed | 1831 | for (j=t+1;j<k;j++) |
peccu | 0:b9ac53c439ed | 1832 | { |
peccu | 0:b9ac53c439ed | 1833 | Q[t][t]+=r[j][t]*r[j][t]; |
peccu | 0:b9ac53c439ed | 1834 | Q[t][j]=-r[j][t]*r[t][j]; |
peccu | 0:b9ac53c439ed | 1835 | } |
peccu | 0:b9ac53c439ed | 1836 | } |
peccu | 0:b9ac53c439ed | 1837 | for (iter=0;iter<max_iter;iter++) |
peccu | 0:b9ac53c439ed | 1838 | { |
peccu | 0:b9ac53c439ed | 1839 | // stopping condition, recalculate QP,pQP for numerical accuracy |
peccu | 0:b9ac53c439ed | 1840 | pQp=0; |
peccu | 0:b9ac53c439ed | 1841 | for (t=0;t<k;t++) |
peccu | 0:b9ac53c439ed | 1842 | { |
peccu | 0:b9ac53c439ed | 1843 | Qp[t]=0; |
peccu | 0:b9ac53c439ed | 1844 | for (j=0;j<k;j++) |
peccu | 0:b9ac53c439ed | 1845 | Qp[t]+=Q[t][j]*p[j]; |
peccu | 0:b9ac53c439ed | 1846 | pQp+=p[t]*Qp[t]; |
peccu | 0:b9ac53c439ed | 1847 | } |
peccu | 0:b9ac53c439ed | 1848 | double max_error=0; |
peccu | 0:b9ac53c439ed | 1849 | for (t=0;t<k;t++) |
peccu | 0:b9ac53c439ed | 1850 | { |
peccu | 0:b9ac53c439ed | 1851 | double error=fabs(Qp[t]-pQp); |
peccu | 0:b9ac53c439ed | 1852 | if (error>max_error) |
peccu | 0:b9ac53c439ed | 1853 | max_error=error; |
peccu | 0:b9ac53c439ed | 1854 | } |
peccu | 0:b9ac53c439ed | 1855 | if (max_error<eps) break; |
peccu | 0:b9ac53c439ed | 1856 | |
peccu | 0:b9ac53c439ed | 1857 | for (t=0;t<k;t++) |
peccu | 0:b9ac53c439ed | 1858 | { |
peccu | 0:b9ac53c439ed | 1859 | double diff=(-Qp[t]+pQp)/Q[t][t]; |
peccu | 0:b9ac53c439ed | 1860 | p[t]+=diff; |
peccu | 0:b9ac53c439ed | 1861 | pQp=(pQp+diff*(diff*Q[t][t]+2*Qp[t]))/(1+diff)/(1+diff); |
peccu | 0:b9ac53c439ed | 1862 | for (j=0;j<k;j++) |
peccu | 0:b9ac53c439ed | 1863 | { |
peccu | 0:b9ac53c439ed | 1864 | Qp[j]=(Qp[j]+diff*Q[t][j])/(1+diff); |
peccu | 0:b9ac53c439ed | 1865 | p[j]/=(1+diff); |
peccu | 0:b9ac53c439ed | 1866 | } |
peccu | 0:b9ac53c439ed | 1867 | } |
peccu | 0:b9ac53c439ed | 1868 | } |
peccu | 0:b9ac53c439ed | 1869 | if (iter>=max_iter) |
peccu | 0:b9ac53c439ed | 1870 | info("Exceeds max_iter in multiclass_prob\n"); |
peccu | 0:b9ac53c439ed | 1871 | for(t=0;t<k;t++) free(Q[t]); |
peccu | 0:b9ac53c439ed | 1872 | free(Q); |
peccu | 0:b9ac53c439ed | 1873 | free(Qp); |
peccu | 0:b9ac53c439ed | 1874 | } |
peccu | 0:b9ac53c439ed | 1875 | |
peccu | 0:b9ac53c439ed | 1876 | // Cross-validation decision values for probability estimates |
peccu | 0:b9ac53c439ed | 1877 | static void svm_binary_svc_probability( |
peccu | 0:b9ac53c439ed | 1878 | const svm_problem *prob, const svm_parameter *param, |
peccu | 0:b9ac53c439ed | 1879 | double Cp, double Cn, double& probA, double& probB) |
peccu | 0:b9ac53c439ed | 1880 | { |
peccu | 0:b9ac53c439ed | 1881 | int i; |
peccu | 0:b9ac53c439ed | 1882 | int nr_fold = 5; |
peccu | 0:b9ac53c439ed | 1883 | int *perm = Malloc(int,prob->l); |
peccu | 0:b9ac53c439ed | 1884 | double *dec_values = Malloc(double,prob->l); |
peccu | 0:b9ac53c439ed | 1885 | |
peccu | 0:b9ac53c439ed | 1886 | // random shuffle |
peccu | 0:b9ac53c439ed | 1887 | for(i=0;i<prob->l;i++) perm[i]=i; |
peccu | 0:b9ac53c439ed | 1888 | for(i=0;i<prob->l;i++) |
peccu | 0:b9ac53c439ed | 1889 | { |
peccu | 0:b9ac53c439ed | 1890 | int j = i+rand()%(prob->l-i); |
peccu | 0:b9ac53c439ed | 1891 | swap(perm[i],perm[j]); |
peccu | 0:b9ac53c439ed | 1892 | } |
peccu | 0:b9ac53c439ed | 1893 | for(i=0;i<nr_fold;i++) |
peccu | 0:b9ac53c439ed | 1894 | { |
peccu | 0:b9ac53c439ed | 1895 | int begin = i*prob->l/nr_fold; |
peccu | 0:b9ac53c439ed | 1896 | int end = (i+1)*prob->l/nr_fold; |
peccu | 0:b9ac53c439ed | 1897 | int j,k; |
peccu | 0:b9ac53c439ed | 1898 | struct svm_problem subprob; |
peccu | 0:b9ac53c439ed | 1899 | |
peccu | 0:b9ac53c439ed | 1900 | subprob.l = prob->l-(end-begin); |
peccu | 0:b9ac53c439ed | 1901 | subprob.x = Malloc(struct svm_node*,subprob.l); |
peccu | 0:b9ac53c439ed | 1902 | subprob.y = Malloc(double,subprob.l); |
peccu | 0:b9ac53c439ed | 1903 | |
peccu | 0:b9ac53c439ed | 1904 | k=0; |
peccu | 0:b9ac53c439ed | 1905 | for(j=0;j<begin;j++) |
peccu | 0:b9ac53c439ed | 1906 | { |
peccu | 0:b9ac53c439ed | 1907 | subprob.x[k] = prob->x[perm[j]]; |
peccu | 0:b9ac53c439ed | 1908 | subprob.y[k] = prob->y[perm[j]]; |
peccu | 0:b9ac53c439ed | 1909 | ++k; |
peccu | 0:b9ac53c439ed | 1910 | } |
peccu | 0:b9ac53c439ed | 1911 | for(j=end;j<prob->l;j++) |
peccu | 0:b9ac53c439ed | 1912 | { |
peccu | 0:b9ac53c439ed | 1913 | subprob.x[k] = prob->x[perm[j]]; |
peccu | 0:b9ac53c439ed | 1914 | subprob.y[k] = prob->y[perm[j]]; |
peccu | 0:b9ac53c439ed | 1915 | ++k; |
peccu | 0:b9ac53c439ed | 1916 | } |
peccu | 0:b9ac53c439ed | 1917 | int p_count=0,n_count=0; |
peccu | 0:b9ac53c439ed | 1918 | for(j=0;j<k;j++) |
peccu | 0:b9ac53c439ed | 1919 | if(subprob.y[j]>0) |
peccu | 0:b9ac53c439ed | 1920 | p_count++; |
peccu | 0:b9ac53c439ed | 1921 | else |
peccu | 0:b9ac53c439ed | 1922 | n_count++; |
peccu | 0:b9ac53c439ed | 1923 | |
peccu | 0:b9ac53c439ed | 1924 | if(p_count==0 && n_count==0) |
peccu | 0:b9ac53c439ed | 1925 | for(j=begin;j<end;j++) |
peccu | 0:b9ac53c439ed | 1926 | dec_values[perm[j]] = 0; |
peccu | 0:b9ac53c439ed | 1927 | else if(p_count > 0 && n_count == 0) |
peccu | 0:b9ac53c439ed | 1928 | for(j=begin;j<end;j++) |
peccu | 0:b9ac53c439ed | 1929 | dec_values[perm[j]] = 1; |
peccu | 0:b9ac53c439ed | 1930 | else if(p_count == 0 && n_count > 0) |
peccu | 0:b9ac53c439ed | 1931 | for(j=begin;j<end;j++) |
peccu | 0:b9ac53c439ed | 1932 | dec_values[perm[j]] = -1; |
peccu | 0:b9ac53c439ed | 1933 | else |
peccu | 0:b9ac53c439ed | 1934 | { |
peccu | 0:b9ac53c439ed | 1935 | svm_parameter subparam = *param; |
peccu | 0:b9ac53c439ed | 1936 | subparam.probability=0; |
peccu | 0:b9ac53c439ed | 1937 | subparam.C=1.0; |
peccu | 0:b9ac53c439ed | 1938 | subparam.nr_weight=2; |
peccu | 0:b9ac53c439ed | 1939 | subparam.weight_label = Malloc(int,2); |
peccu | 0:b9ac53c439ed | 1940 | subparam.weight = Malloc(double,2); |
peccu | 0:b9ac53c439ed | 1941 | subparam.weight_label[0]=+1; |
peccu | 0:b9ac53c439ed | 1942 | subparam.weight_label[1]=-1; |
peccu | 0:b9ac53c439ed | 1943 | subparam.weight[0]=Cp; |
peccu | 0:b9ac53c439ed | 1944 | subparam.weight[1]=Cn; |
peccu | 0:b9ac53c439ed | 1945 | struct svm_model *submodel = svm_train(&subprob,&subparam); |
peccu | 0:b9ac53c439ed | 1946 | for(j=begin;j<end;j++) |
peccu | 0:b9ac53c439ed | 1947 | { |
peccu | 0:b9ac53c439ed | 1948 | svm_predict_values(submodel,prob->x[perm[j]],&(dec_values[perm[j]])); |
peccu | 0:b9ac53c439ed | 1949 | // ensure +1 -1 order; reason not using CV subroutine |
peccu | 0:b9ac53c439ed | 1950 | dec_values[perm[j]] *= submodel->label[0]; |
peccu | 0:b9ac53c439ed | 1951 | } |
peccu | 0:b9ac53c439ed | 1952 | svm_free_and_destroy_model(&submodel); |
peccu | 0:b9ac53c439ed | 1953 | svm_destroy_param(&subparam); |
peccu | 0:b9ac53c439ed | 1954 | } |
peccu | 0:b9ac53c439ed | 1955 | free(subprob.x); |
peccu | 0:b9ac53c439ed | 1956 | free(subprob.y); |
peccu | 0:b9ac53c439ed | 1957 | } |
peccu | 0:b9ac53c439ed | 1958 | sigmoid_train(prob->l,dec_values,prob->y,probA,probB); |
peccu | 0:b9ac53c439ed | 1959 | free(dec_values); |
peccu | 0:b9ac53c439ed | 1960 | free(perm); |
peccu | 0:b9ac53c439ed | 1961 | } |
peccu | 0:b9ac53c439ed | 1962 | |
peccu | 0:b9ac53c439ed | 1963 | // Return parameter of a Laplace distribution |
peccu | 0:b9ac53c439ed | 1964 | static double svm_svr_probability( |
peccu | 0:b9ac53c439ed | 1965 | const svm_problem *prob, const svm_parameter *param) |
peccu | 0:b9ac53c439ed | 1966 | { |
peccu | 0:b9ac53c439ed | 1967 | int i; |
peccu | 0:b9ac53c439ed | 1968 | int nr_fold = 5; |
peccu | 0:b9ac53c439ed | 1969 | double *ymv = Malloc(double,prob->l); |
peccu | 0:b9ac53c439ed | 1970 | double mae = 0; |
peccu | 0:b9ac53c439ed | 1971 | |
peccu | 0:b9ac53c439ed | 1972 | svm_parameter newparam = *param; |
peccu | 0:b9ac53c439ed | 1973 | newparam.probability = 0; |
peccu | 0:b9ac53c439ed | 1974 | svm_cross_validation(prob,&newparam,nr_fold,ymv); |
peccu | 0:b9ac53c439ed | 1975 | for(i=0;i<prob->l;i++) |
peccu | 0:b9ac53c439ed | 1976 | { |
peccu | 0:b9ac53c439ed | 1977 | ymv[i]=prob->y[i]-ymv[i]; |
peccu | 0:b9ac53c439ed | 1978 | mae += fabs(ymv[i]); |
peccu | 0:b9ac53c439ed | 1979 | } |
peccu | 0:b9ac53c439ed | 1980 | mae /= prob->l; |
peccu | 0:b9ac53c439ed | 1981 | double std=sqrt(2*mae*mae); |
peccu | 0:b9ac53c439ed | 1982 | int count=0; |
peccu | 0:b9ac53c439ed | 1983 | mae=0; |
peccu | 0:b9ac53c439ed | 1984 | for(i=0;i<prob->l;i++) |
peccu | 0:b9ac53c439ed | 1985 | if (fabs(ymv[i]) > 5*std) |
peccu | 0:b9ac53c439ed | 1986 | count=count+1; |
peccu | 0:b9ac53c439ed | 1987 | else |
peccu | 0:b9ac53c439ed | 1988 | mae+=fabs(ymv[i]); |
peccu | 0:b9ac53c439ed | 1989 | mae /= (prob->l-count); |
peccu | 0:b9ac53c439ed | 1990 | info("Prob. model for test data: target value = predicted value + z,\nz: Laplace distribution e^(-|z|/sigma)/(2sigma),sigma= %g\n",mae); |
peccu | 0:b9ac53c439ed | 1991 | free(ymv); |
peccu | 0:b9ac53c439ed | 1992 | return mae; |
peccu | 0:b9ac53c439ed | 1993 | } |
peccu | 0:b9ac53c439ed | 1994 | |
peccu | 0:b9ac53c439ed | 1995 | |
peccu | 0:b9ac53c439ed | 1996 | // label: label name, start: begin of each class, count: #data of classes, perm: indices to the original data |
peccu | 0:b9ac53c439ed | 1997 | // perm, length l, must be allocated before calling this subroutine |
peccu | 0:b9ac53c439ed | 1998 | static void svm_group_classes(const svm_problem *prob, int *nr_class_ret, int **label_ret, int **start_ret, int **count_ret, int *perm) |
peccu | 0:b9ac53c439ed | 1999 | { |
peccu | 0:b9ac53c439ed | 2000 | int l = prob->l; |
peccu | 0:b9ac53c439ed | 2001 | int max_nr_class = 16; |
peccu | 0:b9ac53c439ed | 2002 | int nr_class = 0; |
peccu | 0:b9ac53c439ed | 2003 | int *label = Malloc(int,max_nr_class); |
peccu | 0:b9ac53c439ed | 2004 | int *count = Malloc(int,max_nr_class); |
peccu | 0:b9ac53c439ed | 2005 | int *data_label = Malloc(int,l); |
peccu | 0:b9ac53c439ed | 2006 | int i; |
peccu | 0:b9ac53c439ed | 2007 | |
peccu | 0:b9ac53c439ed | 2008 | for(i=0;i<l;i++) |
peccu | 0:b9ac53c439ed | 2009 | { |
peccu | 0:b9ac53c439ed | 2010 | int this_label = (int)prob->y[i]; |
peccu | 0:b9ac53c439ed | 2011 | int j; |
peccu | 0:b9ac53c439ed | 2012 | for(j=0;j<nr_class;j++) |
peccu | 0:b9ac53c439ed | 2013 | { |
peccu | 0:b9ac53c439ed | 2014 | if(this_label == label[j]) |
peccu | 0:b9ac53c439ed | 2015 | { |
peccu | 0:b9ac53c439ed | 2016 | ++count[j]; |
peccu | 0:b9ac53c439ed | 2017 | break; |
peccu | 0:b9ac53c439ed | 2018 | } |
peccu | 0:b9ac53c439ed | 2019 | } |
peccu | 0:b9ac53c439ed | 2020 | data_label[i] = j; |
peccu | 0:b9ac53c439ed | 2021 | if(j == nr_class) |
peccu | 0:b9ac53c439ed | 2022 | { |
peccu | 0:b9ac53c439ed | 2023 | if(nr_class == max_nr_class) |
peccu | 0:b9ac53c439ed | 2024 | { |
peccu | 0:b9ac53c439ed | 2025 | max_nr_class *= 2; |
peccu | 0:b9ac53c439ed | 2026 | label = (int *)realloc(label,max_nr_class*sizeof(int)); |
peccu | 0:b9ac53c439ed | 2027 | count = (int *)realloc(count,max_nr_class*sizeof(int)); |
peccu | 0:b9ac53c439ed | 2028 | } |
peccu | 0:b9ac53c439ed | 2029 | label[nr_class] = this_label; |
peccu | 0:b9ac53c439ed | 2030 | count[nr_class] = 1; |
peccu | 0:b9ac53c439ed | 2031 | ++nr_class; |
peccu | 0:b9ac53c439ed | 2032 | } |
peccu | 0:b9ac53c439ed | 2033 | } |
peccu | 0:b9ac53c439ed | 2034 | |
peccu | 0:b9ac53c439ed | 2035 | int *start = Malloc(int,nr_class); |
peccu | 0:b9ac53c439ed | 2036 | start[0] = 0; |
peccu | 0:b9ac53c439ed | 2037 | for(i=1;i<nr_class;i++) |
peccu | 0:b9ac53c439ed | 2038 | start[i] = start[i-1]+count[i-1]; |
peccu | 0:b9ac53c439ed | 2039 | for(i=0;i<l;i++) |
peccu | 0:b9ac53c439ed | 2040 | { |
peccu | 0:b9ac53c439ed | 2041 | perm[start[data_label[i]]] = i; |
peccu | 0:b9ac53c439ed | 2042 | ++start[data_label[i]]; |
peccu | 0:b9ac53c439ed | 2043 | } |
peccu | 0:b9ac53c439ed | 2044 | start[0] = 0; |
peccu | 0:b9ac53c439ed | 2045 | for(i=1;i<nr_class;i++) |
peccu | 0:b9ac53c439ed | 2046 | start[i] = start[i-1]+count[i-1]; |
peccu | 0:b9ac53c439ed | 2047 | |
peccu | 0:b9ac53c439ed | 2048 | *nr_class_ret = nr_class; |
peccu | 0:b9ac53c439ed | 2049 | *label_ret = label; |
peccu | 0:b9ac53c439ed | 2050 | *start_ret = start; |
peccu | 0:b9ac53c439ed | 2051 | *count_ret = count; |
peccu | 0:b9ac53c439ed | 2052 | free(data_label); |
peccu | 0:b9ac53c439ed | 2053 | } |
peccu | 0:b9ac53c439ed | 2054 | |
//
// Interface functions
//

// Train an SVM model on `prob` with the settings in `param`.
//
// The returned model is allocated with Malloc and has free_sv = 0.
// model->SV stores pointers taken from prob->x; the nodes are not copied.
// NOTE(review): prob->x therefore presumably has to outlive the model.
//
// ONE_CLASS / EPSILON_SVR / NU_SVR:
//   A single decision function is trained. nr_class is set to 2 and
//   label/nSV are NULL. For the SVR types with param->probability set,
//   probA[0] holds the value returned by svm_svr_probability().
//
// Classification (every other svm_type):
//   Samples are grouped by class and one binary classifier is trained per
//   pair of classes (k*(k-1)/2 models). The support vectors of all pairs are
//   stored once, grouped by class. With param->probability set, per-pair
//   probA/probB come from svm_binary_svc_probability().
svm_model *svm_train(const svm_problem *prob, const svm_parameter *param)
{
	svm_model *model = Malloc(svm_model,1);
	model->param = *param;
	model->free_sv = 0; // XXX

	if(param->svm_type == ONE_CLASS ||
	   param->svm_type == EPSILON_SVR ||
	   param->svm_type == NU_SVR)
	{
		// regression or one-class-svm
		model->nr_class = 2;
		model->label = NULL;
		model->nSV = NULL;
		model->probA = NULL; model->probB = NULL;
		model->sv_coef = Malloc(double *,1);

		// SVR only: estimate the noise scale before the final training run.
		if(param->probability &&
		   (param->svm_type == EPSILON_SVR ||
		    param->svm_type == NU_SVR))
		{
			model->probA = Malloc(double,1);
			model->probA[0] = svm_svr_probability(prob,param);
		}

		decision_function f = svm_train_one(prob,param,0,0);
		model->rho = Malloc(double,1);
		model->rho[0] = f.rho;

		// Keep only samples with a nonzero coefficient as support vectors:
		// first count them, then copy pointer and coefficient.
		int nSV = 0;
		int i;
		for(i=0;i<prob->l;i++)
			if(fabs(f.alpha[i]) > 0) ++nSV;
		model->l = nSV;
		model->SV = Malloc(svm_node *,nSV);
		model->sv_coef[0] = Malloc(double,nSV);
		int j = 0;
		for(i=0;i<prob->l;i++)
			if(fabs(f.alpha[i]) > 0)
			{
				model->SV[j] = prob->x[i];
				model->sv_coef[0][j] = f.alpha[i];
				++j;
			}

		free(f.alpha);
	}
	else
	{
		// classification
		int l = prob->l;
		int nr_class;
		int *label = NULL;
		int *start = NULL;
		int *count = NULL;
		int *perm = Malloc(int,l);

		// group training data of the same class
		svm_group_classes(prob,&nr_class,&label,&start,&count,perm);
		// x holds the sample pointers reordered so each class is contiguous
		svm_node **x = Malloc(svm_node *,l);
		int i;
		for(i=0;i<l;i++)
			x[i] = prob->x[perm[i]];

		// calculate weighted C

		// Each class starts with param->C, multiplied by the matching entry of
		// param->weight; unknown weight labels only produce a warning.
		double *weighted_C = Malloc(double, nr_class);
		for(i=0;i<nr_class;i++)
			weighted_C[i] = param->C;
		for(i=0;i<param->nr_weight;i++)
		{
			int j;
			for(j=0;j<nr_class;j++)
				if(param->weight_label[i] == label[j])
					break;
			if(j == nr_class)
				fprintf(stderr,"warning: class label %d specified in weight is not found\n", param->weight_label[i]);
			else
				weighted_C[j] *= param->weight[i];
		}

		// train k*(k-1)/2 models

		// nonzero[i] becomes true once x[i] has a nonzero coefficient in any
		// pairwise model, i.e. it has to be kept as a support vector.
		bool *nonzero = Malloc(bool,l);
		for(i=0;i<l;i++)
			nonzero[i] = false;
		decision_function *f = Malloc(decision_function,nr_class*(nr_class-1)/2);

		double *probA=NULL,*probB=NULL;
		if (param->probability)
		{
			probA=Malloc(double,nr_class*(nr_class-1)/2);
			probB=Malloc(double,nr_class*(nr_class-1)/2);
		}

		// Pair index p enumerates (i,j), i<j, in row-major order; the same
		// order is used for rho, probA/probB and the sv_coef layout below.
		int p = 0;
		for(i=0;i<nr_class;i++)
			for(int j=i+1;j<nr_class;j++)
			{
				// Sub-problem: class i samples labelled +1, then class j samples labelled -1.
				svm_problem sub_prob;
				int si = start[i], sj = start[j];
				int ci = count[i], cj = count[j];
				sub_prob.l = ci+cj;
				sub_prob.x = Malloc(svm_node *,sub_prob.l);
				sub_prob.y = Malloc(double,sub_prob.l);
				int k;
				for(k=0;k<ci;k++)
				{
					sub_prob.x[k] = x[si+k];
					sub_prob.y[k] = +1;
				}
				for(k=0;k<cj;k++)
				{
					sub_prob.x[ci+k] = x[sj+k];
					sub_prob.y[ci+k] = -1;
				}

				if(param->probability)
					svm_binary_svc_probability(&sub_prob,param,weighted_C[i],weighted_C[j],probA[p],probB[p]);

				f[p] = svm_train_one(&sub_prob,param,weighted_C[i],weighted_C[j]);
				for(k=0;k<ci;k++)
					if(!nonzero[si+k] && fabs(f[p].alpha[k]) > 0)
						nonzero[si+k] = true;
				for(k=0;k<cj;k++)
					if(!nonzero[sj+k] && fabs(f[p].alpha[ci+k]) > 0)
						nonzero[sj+k] = true;
				free(sub_prob.x);
				free(sub_prob.y);
				++p;
			}

		// build output

		model->nr_class = nr_class;

		model->label = Malloc(int,nr_class);
		for(i=0;i<nr_class;i++)
			model->label[i] = label[i];

		model->rho = Malloc(double,nr_class*(nr_class-1)/2);
		for(i=0;i<nr_class*(nr_class-1)/2;i++)
			model->rho[i] = f[i].rho;

		if(param->probability)
		{
			model->probA = Malloc(double,nr_class*(nr_class-1)/2);
			model->probB = Malloc(double,nr_class*(nr_class-1)/2);
			for(i=0;i<nr_class*(nr_class-1)/2;i++)
			{
				model->probA[i] = probA[i];
				model->probB[i] = probB[i];
			}
		}
		else
		{
			model->probA=NULL;
			model->probB=NULL;
		}

		// Per-class support-vector counts (model->nSV) and the overall total.
		int total_sv = 0;
		int *nz_count = Malloc(int,nr_class);
		model->nSV = Malloc(int,nr_class);
		for(i=0;i<nr_class;i++)
		{
			int nSV = 0;
			for(int j=0;j<count[i];j++)
				if(nonzero[start[i]+j])
				{
					++nSV;
					++total_sv;
				}
			model->nSV[i] = nSV;
			nz_count[i] = nSV;
		}

		info("Total nSV = %d\n",total_sv);

		// Support vectors stored in class-grouped order (x is class-sorted).
		model->l = total_sv;
		model->SV = Malloc(svm_node *,total_sv);
		p = 0;
		for(i=0;i<l;i++)
			if(nonzero[i]) model->SV[p++] = x[i];

		// nz_start[c] = offset of class c's support vectors inside model->SV.
		int *nz_start = Malloc(int,nr_class);
		nz_start[0] = 0;
		for(i=1;i<nr_class;i++)
			nz_start[i] = nz_start[i-1]+nz_count[i-1];

		// nr_class-1 coefficient rows, each indexed like model->SV.
		model->sv_coef = Malloc(double *,nr_class-1);
		for(i=0;i<nr_class-1;i++)
			model->sv_coef[i] = Malloc(double,total_sv);

		p = 0;
		for(i=0;i<nr_class;i++)
			for(int j=i+1;j<nr_class;j++)
			{
				// classifier (i,j): coefficients with
				// i are in sv_coef[j-1][nz_start[i]...],
				// j are in sv_coef[i][nz_start[j]...]

				int si = start[i];
				int sj = start[j];
				int ci = count[i];
				int cj = count[j];

				int q = nz_start[i];
				int k;
				for(k=0;k<ci;k++)
					if(nonzero[si+k])
						model->sv_coef[j-1][q++] = f[p].alpha[k];
				q = nz_start[j];
				for(k=0;k<cj;k++)
					if(nonzero[sj+k])
						model->sv_coef[i][q++] = f[p].alpha[ci+k];
				++p;
			}

		// Release the temporaries; the model keeps its own copies/arrays.
		free(label);
		free(probA);
		free(probB);
		free(count);
		free(perm);
		free(start);
		free(x);
		free(weighted_C);
		free(nonzero);
		for(i=0;i<nr_class*(nr_class-1)/2;i++)
			free(f[i].alpha);
		free(f);
		free(nz_count);
		free(nz_start);
	}
	return model;
}
peccu | 0:b9ac53c439ed | 2293 | |
peccu | 0:b9ac53c439ed | 2294 | // Stratified cross validation |
peccu | 0:b9ac53c439ed | 2295 | void svm_cross_validation(const svm_problem *prob, const svm_parameter *param, int nr_fold, double *target) |
peccu | 0:b9ac53c439ed | 2296 | { |
peccu | 0:b9ac53c439ed | 2297 | int i; |
peccu | 0:b9ac53c439ed | 2298 | int *fold_start = Malloc(int,nr_fold+1); |
peccu | 0:b9ac53c439ed | 2299 | int l = prob->l; |
peccu | 0:b9ac53c439ed | 2300 | int *perm = Malloc(int,l); |
peccu | 0:b9ac53c439ed | 2301 | int nr_class; |
peccu | 0:b9ac53c439ed | 2302 | |
peccu | 0:b9ac53c439ed | 2303 | // stratified cv may not give leave-one-out rate |
peccu | 0:b9ac53c439ed | 2304 | // Each class to l folds -> some folds may have zero elements |
peccu | 0:b9ac53c439ed | 2305 | if((param->svm_type == C_SVC || |
peccu | 0:b9ac53c439ed | 2306 | param->svm_type == NU_SVC) && nr_fold < l) |
peccu | 0:b9ac53c439ed | 2307 | { |
peccu | 0:b9ac53c439ed | 2308 | int *start = NULL; |
peccu | 0:b9ac53c439ed | 2309 | int *label = NULL; |
peccu | 0:b9ac53c439ed | 2310 | int *count = NULL; |
peccu | 0:b9ac53c439ed | 2311 | svm_group_classes(prob,&nr_class,&label,&start,&count,perm); |
peccu | 0:b9ac53c439ed | 2312 | |
peccu | 0:b9ac53c439ed | 2313 | // random shuffle and then data grouped by fold using the array perm |
peccu | 0:b9ac53c439ed | 2314 | int *fold_count = Malloc(int,nr_fold); |
peccu | 0:b9ac53c439ed | 2315 | int c; |
peccu | 0:b9ac53c439ed | 2316 | int *index = Malloc(int,l); |
peccu | 0:b9ac53c439ed | 2317 | for(i=0;i<l;i++) |
peccu | 0:b9ac53c439ed | 2318 | index[i]=perm[i]; |
peccu | 0:b9ac53c439ed | 2319 | for (c=0; c<nr_class; c++) |
peccu | 0:b9ac53c439ed | 2320 | for(i=0;i<count[c];i++) |
peccu | 0:b9ac53c439ed | 2321 | { |
peccu | 0:b9ac53c439ed | 2322 | int j = i+rand()%(count[c]-i); |
peccu | 0:b9ac53c439ed | 2323 | swap(index[start[c]+j],index[start[c]+i]); |
peccu | 0:b9ac53c439ed | 2324 | } |
peccu | 0:b9ac53c439ed | 2325 | for(i=0;i<nr_fold;i++) |
peccu | 0:b9ac53c439ed | 2326 | { |
peccu | 0:b9ac53c439ed | 2327 | fold_count[i] = 0; |
peccu | 0:b9ac53c439ed | 2328 | for (c=0; c<nr_class;c++) |
peccu | 0:b9ac53c439ed | 2329 | fold_count[i]+=(i+1)*count[c]/nr_fold-i*count[c]/nr_fold; |
peccu | 0:b9ac53c439ed | 2330 | } |
peccu | 0:b9ac53c439ed | 2331 | fold_start[0]=0; |
peccu | 0:b9ac53c439ed | 2332 | for (i=1;i<=nr_fold;i++) |
peccu | 0:b9ac53c439ed | 2333 | fold_start[i] = fold_start[i-1]+fold_count[i-1]; |
peccu | 0:b9ac53c439ed | 2334 | for (c=0; c<nr_class;c++) |
peccu | 0:b9ac53c439ed | 2335 | for(i=0;i<nr_fold;i++) |
peccu | 0:b9ac53c439ed | 2336 | { |
peccu | 0:b9ac53c439ed | 2337 | int begin = start[c]+i*count[c]/nr_fold; |
peccu | 0:b9ac53c439ed | 2338 | int end = start[c]+(i+1)*count[c]/nr_fold; |
peccu | 0:b9ac53c439ed | 2339 | for(int j=begin;j<end;j++) |
peccu | 0:b9ac53c439ed | 2340 | { |
peccu | 0:b9ac53c439ed | 2341 | perm[fold_start[i]] = index[j]; |
peccu | 0:b9ac53c439ed | 2342 | fold_start[i]++; |
peccu | 0:b9ac53c439ed | 2343 | } |
peccu | 0:b9ac53c439ed | 2344 | } |
peccu | 0:b9ac53c439ed | 2345 | fold_start[0]=0; |
peccu | 0:b9ac53c439ed | 2346 | for (i=1;i<=nr_fold;i++) |
peccu | 0:b9ac53c439ed | 2347 | fold_start[i] = fold_start[i-1]+fold_count[i-1]; |
peccu | 0:b9ac53c439ed | 2348 | free(start); |
peccu | 0:b9ac53c439ed | 2349 | free(label); |
peccu | 0:b9ac53c439ed | 2350 | free(count); |
peccu | 0:b9ac53c439ed | 2351 | free(index); |
peccu | 0:b9ac53c439ed | 2352 | free(fold_count); |
peccu | 0:b9ac53c439ed | 2353 | } |
peccu | 0:b9ac53c439ed | 2354 | else |
peccu | 0:b9ac53c439ed | 2355 | { |
peccu | 0:b9ac53c439ed | 2356 | for(i=0;i<l;i++) perm[i]=i; |
peccu | 0:b9ac53c439ed | 2357 | for(i=0;i<l;i++) |
peccu | 0:b9ac53c439ed | 2358 | { |
peccu | 0:b9ac53c439ed | 2359 | int j = i+rand()%(l-i); |
peccu | 0:b9ac53c439ed | 2360 | swap(perm[i],perm[j]); |
peccu | 0:b9ac53c439ed | 2361 | } |
peccu | 0:b9ac53c439ed | 2362 | for(i=0;i<=nr_fold;i++) |
peccu | 0:b9ac53c439ed | 2363 | fold_start[i]=i*l/nr_fold; |
peccu | 0:b9ac53c439ed | 2364 | } |
peccu | 0:b9ac53c439ed | 2365 | |
peccu | 0:b9ac53c439ed | 2366 | for(i=0;i<nr_fold;i++) |
peccu | 0:b9ac53c439ed | 2367 | { |
peccu | 0:b9ac53c439ed | 2368 | int begin = fold_start[i]; |
peccu | 0:b9ac53c439ed | 2369 | int end = fold_start[i+1]; |
peccu | 0:b9ac53c439ed | 2370 | int j,k; |
peccu | 0:b9ac53c439ed | 2371 | struct svm_problem subprob; |
peccu | 0:b9ac53c439ed | 2372 | |
peccu | 0:b9ac53c439ed | 2373 | subprob.l = l-(end-begin); |
peccu | 0:b9ac53c439ed | 2374 | subprob.x = Malloc(struct svm_node*,subprob.l); |
peccu | 0:b9ac53c439ed | 2375 | subprob.y = Malloc(double,subprob.l); |
peccu | 0:b9ac53c439ed | 2376 | |
peccu | 0:b9ac53c439ed | 2377 | k=0; |
peccu | 0:b9ac53c439ed | 2378 | for(j=0;j<begin;j++) |
peccu | 0:b9ac53c439ed | 2379 | { |
peccu | 0:b9ac53c439ed | 2380 | subprob.x[k] = prob->x[perm[j]]; |
peccu | 0:b9ac53c439ed | 2381 | subprob.y[k] = prob->y[perm[j]]; |
peccu | 0:b9ac53c439ed | 2382 | ++k; |
peccu | 0:b9ac53c439ed | 2383 | } |
peccu | 0:b9ac53c439ed | 2384 | for(j=end;j<l;j++) |
peccu | 0:b9ac53c439ed | 2385 | { |
peccu | 0:b9ac53c439ed | 2386 | subprob.x[k] = prob->x[perm[j]]; |
peccu | 0:b9ac53c439ed | 2387 | subprob.y[k] = prob->y[perm[j]]; |
peccu | 0:b9ac53c439ed | 2388 | ++k; |
peccu | 0:b9ac53c439ed | 2389 | } |
peccu | 0:b9ac53c439ed | 2390 | struct svm_model *submodel = svm_train(&subprob,param); |
peccu | 0:b9ac53c439ed | 2391 | if(param->probability && |
peccu | 0:b9ac53c439ed | 2392 | (param->svm_type == C_SVC || param->svm_type == NU_SVC)) |
peccu | 0:b9ac53c439ed | 2393 | { |
peccu | 0:b9ac53c439ed | 2394 | double *prob_estimates=Malloc(double,svm_get_nr_class(submodel)); |
peccu | 0:b9ac53c439ed | 2395 | for(j=begin;j<end;j++) |
peccu | 0:b9ac53c439ed | 2396 | target[perm[j]] = svm_predict_probability(submodel,prob->x[perm[j]],prob_estimates); |
peccu | 0:b9ac53c439ed | 2397 | free(prob_estimates); |
peccu | 0:b9ac53c439ed | 2398 | } |
peccu | 0:b9ac53c439ed | 2399 | else |
peccu | 0:b9ac53c439ed | 2400 | for(j=begin;j<end;j++) |
peccu | 0:b9ac53c439ed | 2401 | target[perm[j]] = svm_predict(submodel,prob->x[perm[j]]); |
peccu | 0:b9ac53c439ed | 2402 | svm_free_and_destroy_model(&submodel); |
peccu | 0:b9ac53c439ed | 2403 | free(subprob.x); |
peccu | 0:b9ac53c439ed | 2404 | free(subprob.y); |
peccu | 0:b9ac53c439ed | 2405 | } |
peccu | 0:b9ac53c439ed | 2406 | free(fold_start); |
peccu | 0:b9ac53c439ed | 2407 | free(perm); |
peccu | 0:b9ac53c439ed | 2408 | } |
peccu | 0:b9ac53c439ed | 2409 | |
peccu | 0:b9ac53c439ed | 2410 | |
peccu | 0:b9ac53c439ed | 2411 | int svm_get_svm_type(const svm_model *model) |
peccu | 0:b9ac53c439ed | 2412 | { |
peccu | 0:b9ac53c439ed | 2413 | return model->param.svm_type; |
peccu | 0:b9ac53c439ed | 2414 | } |
peccu | 0:b9ac53c439ed | 2415 | |
peccu | 0:b9ac53c439ed | 2416 | int svm_get_nr_class(const svm_model *model) |
peccu | 0:b9ac53c439ed | 2417 | { |
peccu | 0:b9ac53c439ed | 2418 | return model->nr_class; |
peccu | 0:b9ac53c439ed | 2419 | } |
peccu | 0:b9ac53c439ed | 2420 | |
peccu | 0:b9ac53c439ed | 2421 | void svm_get_labels(const svm_model *model, int* label) |
peccu | 0:b9ac53c439ed | 2422 | { |
peccu | 0:b9ac53c439ed | 2423 | if (model->label != NULL) |
peccu | 0:b9ac53c439ed | 2424 | for(int i=0;i<model->nr_class;i++) |
peccu | 0:b9ac53c439ed | 2425 | label[i] = model->label[i]; |
peccu | 0:b9ac53c439ed | 2426 | } |
peccu | 0:b9ac53c439ed | 2427 | |
peccu | 0:b9ac53c439ed | 2428 | double svm_get_svr_probability(const svm_model *model) |
peccu | 0:b9ac53c439ed | 2429 | { |
peccu | 0:b9ac53c439ed | 2430 | if ((model->param.svm_type == EPSILON_SVR || model->param.svm_type == NU_SVR) && |
peccu | 0:b9ac53c439ed | 2431 | model->probA!=NULL) |
peccu | 0:b9ac53c439ed | 2432 | return model->probA[0]; |
peccu | 0:b9ac53c439ed | 2433 | else |
peccu | 0:b9ac53c439ed | 2434 | { |
peccu | 0:b9ac53c439ed | 2435 | fprintf(stderr,"Model doesn't contain information for SVR probability inference\n"); |
peccu | 0:b9ac53c439ed | 2436 | return 0; |
peccu | 0:b9ac53c439ed | 2437 | } |
peccu | 0:b9ac53c439ed | 2438 | } |
peccu | 0:b9ac53c439ed | 2439 | |
peccu | 0:b9ac53c439ed | 2440 | double svm_predict_values(const svm_model *model, const svm_node *x, double* dec_values) |
peccu | 0:b9ac53c439ed | 2441 | { |
peccu | 0:b9ac53c439ed | 2442 | if(model->param.svm_type == ONE_CLASS || |
peccu | 0:b9ac53c439ed | 2443 | model->param.svm_type == EPSILON_SVR || |
peccu | 0:b9ac53c439ed | 2444 | model->param.svm_type == NU_SVR) |
peccu | 0:b9ac53c439ed | 2445 | { |
peccu | 0:b9ac53c439ed | 2446 | double *sv_coef = model->sv_coef[0]; |
peccu | 0:b9ac53c439ed | 2447 | double sum = 0; |
peccu | 0:b9ac53c439ed | 2448 | for(int i=0;i<model->l;i++) |
peccu | 0:b9ac53c439ed | 2449 | sum += sv_coef[i] * Kernel::k_function(x,model->SV[i],model->param); |
peccu | 0:b9ac53c439ed | 2450 | sum -= model->rho[0]; |
peccu | 0:b9ac53c439ed | 2451 | *dec_values = sum; |
peccu | 0:b9ac53c439ed | 2452 | |
peccu | 0:b9ac53c439ed | 2453 | if(model->param.svm_type == ONE_CLASS) |
peccu | 0:b9ac53c439ed | 2454 | return (sum>0)?1:-1; |
peccu | 0:b9ac53c439ed | 2455 | else |
peccu | 0:b9ac53c439ed | 2456 | return sum; |
peccu | 0:b9ac53c439ed | 2457 | } |
peccu | 0:b9ac53c439ed | 2458 | else |
peccu | 0:b9ac53c439ed | 2459 | { |
peccu | 0:b9ac53c439ed | 2460 | int i; |
peccu | 0:b9ac53c439ed | 2461 | int nr_class = model->nr_class; |
peccu | 0:b9ac53c439ed | 2462 | int l = model->l; |
peccu | 0:b9ac53c439ed | 2463 | |
peccu | 0:b9ac53c439ed | 2464 | double *kvalue = Malloc(double,l); |
peccu | 0:b9ac53c439ed | 2465 | for(i=0;i<l;i++) |
peccu | 0:b9ac53c439ed | 2466 | kvalue[i] = Kernel::k_function(x,model->SV[i],model->param); |
peccu | 0:b9ac53c439ed | 2467 | |
peccu | 0:b9ac53c439ed | 2468 | int *start = Malloc(int,nr_class); |
peccu | 0:b9ac53c439ed | 2469 | start[0] = 0; |
peccu | 0:b9ac53c439ed | 2470 | for(i=1;i<nr_class;i++) |
peccu | 0:b9ac53c439ed | 2471 | start[i] = start[i-1]+model->nSV[i-1]; |
peccu | 0:b9ac53c439ed | 2472 | |
peccu | 0:b9ac53c439ed | 2473 | int *vote = Malloc(int,nr_class); |
peccu | 0:b9ac53c439ed | 2474 | for(i=0;i<nr_class;i++) |
peccu | 0:b9ac53c439ed | 2475 | vote[i] = 0; |
peccu | 0:b9ac53c439ed | 2476 | |
peccu | 0:b9ac53c439ed | 2477 | int p=0; |
peccu | 0:b9ac53c439ed | 2478 | for(i=0;i<nr_class;i++) |
peccu | 0:b9ac53c439ed | 2479 | for(int j=i+1;j<nr_class;j++) |
peccu | 0:b9ac53c439ed | 2480 | { |
peccu | 0:b9ac53c439ed | 2481 | double sum = 0; |
peccu | 0:b9ac53c439ed | 2482 | int si = start[i]; |
peccu | 0:b9ac53c439ed | 2483 | int sj = start[j]; |
peccu | 0:b9ac53c439ed | 2484 | int ci = model->nSV[i]; |
peccu | 0:b9ac53c439ed | 2485 | int cj = model->nSV[j]; |
peccu | 0:b9ac53c439ed | 2486 | |
peccu | 0:b9ac53c439ed | 2487 | int k; |
peccu | 0:b9ac53c439ed | 2488 | double *coef1 = model->sv_coef[j-1]; |
peccu | 0:b9ac53c439ed | 2489 | double *coef2 = model->sv_coef[i]; |
peccu | 0:b9ac53c439ed | 2490 | for(k=0;k<ci;k++) |
peccu | 0:b9ac53c439ed | 2491 | sum += coef1[si+k] * kvalue[si+k]; |
peccu | 0:b9ac53c439ed | 2492 | for(k=0;k<cj;k++) |
peccu | 0:b9ac53c439ed | 2493 | sum += coef2[sj+k] * kvalue[sj+k]; |
peccu | 0:b9ac53c439ed | 2494 | sum -= model->rho[p]; |
peccu | 0:b9ac53c439ed | 2495 | dec_values[p] = sum; |
peccu | 0:b9ac53c439ed | 2496 | |
peccu | 0:b9ac53c439ed | 2497 | if(dec_values[p] > 0) |
peccu | 0:b9ac53c439ed | 2498 | ++vote[i]; |
peccu | 0:b9ac53c439ed | 2499 | else |
peccu | 0:b9ac53c439ed | 2500 | ++vote[j]; |
peccu | 0:b9ac53c439ed | 2501 | p++; |
peccu | 0:b9ac53c439ed | 2502 | } |
peccu | 0:b9ac53c439ed | 2503 | |
peccu | 0:b9ac53c439ed | 2504 | int vote_max_idx = 0; |
peccu | 0:b9ac53c439ed | 2505 | for(i=1;i<nr_class;i++) |
peccu | 0:b9ac53c439ed | 2506 | if(vote[i] > vote[vote_max_idx]) |
peccu | 0:b9ac53c439ed | 2507 | vote_max_idx = i; |
peccu | 0:b9ac53c439ed | 2508 | |
peccu | 0:b9ac53c439ed | 2509 | free(kvalue); |
peccu | 0:b9ac53c439ed | 2510 | free(start); |
peccu | 0:b9ac53c439ed | 2511 | free(vote); |
peccu | 0:b9ac53c439ed | 2512 | return model->label[vote_max_idx]; |
peccu | 0:b9ac53c439ed | 2513 | } |
peccu | 0:b9ac53c439ed | 2514 | } |
peccu | 0:b9ac53c439ed | 2515 | |
peccu | 0:b9ac53c439ed | 2516 | double svm_predict(const svm_model *model, const svm_node *x) |
peccu | 0:b9ac53c439ed | 2517 | { |
peccu | 0:b9ac53c439ed | 2518 | int nr_class = model->nr_class; |
peccu | 0:b9ac53c439ed | 2519 | double *dec_values; |
peccu | 0:b9ac53c439ed | 2520 | if(model->param.svm_type == ONE_CLASS || |
peccu | 0:b9ac53c439ed | 2521 | model->param.svm_type == EPSILON_SVR || |
peccu | 0:b9ac53c439ed | 2522 | model->param.svm_type == NU_SVR) |
peccu | 0:b9ac53c439ed | 2523 | dec_values = Malloc(double, 1); |
peccu | 0:b9ac53c439ed | 2524 | else |
peccu | 0:b9ac53c439ed | 2525 | dec_values = Malloc(double, nr_class*(nr_class-1)/2); |
peccu | 0:b9ac53c439ed | 2526 | double pred_result = svm_predict_values(model, x, dec_values); |
peccu | 0:b9ac53c439ed | 2527 | free(dec_values); |
peccu | 0:b9ac53c439ed | 2528 | return pred_result; |
peccu | 0:b9ac53c439ed | 2529 | } |
peccu | 0:b9ac53c439ed | 2530 | |
peccu | 0:b9ac53c439ed | 2531 | double svm_predict_probability( |
peccu | 0:b9ac53c439ed | 2532 | const svm_model *model, const svm_node *x, double *prob_estimates) |
peccu | 0:b9ac53c439ed | 2533 | { |
peccu | 0:b9ac53c439ed | 2534 | if ((model->param.svm_type == C_SVC || model->param.svm_type == NU_SVC) && |
peccu | 0:b9ac53c439ed | 2535 | model->probA!=NULL && model->probB!=NULL) |
peccu | 0:b9ac53c439ed | 2536 | { |
peccu | 0:b9ac53c439ed | 2537 | int i; |
peccu | 0:b9ac53c439ed | 2538 | int nr_class = model->nr_class; |
peccu | 0:b9ac53c439ed | 2539 | double *dec_values = Malloc(double, nr_class*(nr_class-1)/2); |
peccu | 0:b9ac53c439ed | 2540 | svm_predict_values(model, x, dec_values); |
peccu | 0:b9ac53c439ed | 2541 | |
peccu | 0:b9ac53c439ed | 2542 | double min_prob=1e-7; |
peccu | 0:b9ac53c439ed | 2543 | double **pairwise_prob=Malloc(double *,nr_class); |
peccu | 0:b9ac53c439ed | 2544 | for(i=0;i<nr_class;i++) |
peccu | 0:b9ac53c439ed | 2545 | pairwise_prob[i]=Malloc(double,nr_class); |
peccu | 0:b9ac53c439ed | 2546 | int k=0; |
peccu | 0:b9ac53c439ed | 2547 | for(i=0;i<nr_class;i++) |
peccu | 0:b9ac53c439ed | 2548 | for(int j=i+1;j<nr_class;j++) |
peccu | 0:b9ac53c439ed | 2549 | { |
peccu | 0:b9ac53c439ed | 2550 | pairwise_prob[i][j]=min(max(sigmoid_predict(dec_values[k],model->probA[k],model->probB[k]),min_prob),1-min_prob); |
peccu | 0:b9ac53c439ed | 2551 | pairwise_prob[j][i]=1-pairwise_prob[i][j]; |
peccu | 0:b9ac53c439ed | 2552 | k++; |
peccu | 0:b9ac53c439ed | 2553 | } |
peccu | 0:b9ac53c439ed | 2554 | multiclass_probability(nr_class,pairwise_prob,prob_estimates); |
peccu | 0:b9ac53c439ed | 2555 | |
peccu | 0:b9ac53c439ed | 2556 | int prob_max_idx = 0; |
peccu | 0:b9ac53c439ed | 2557 | for(i=1;i<nr_class;i++) |
peccu | 0:b9ac53c439ed | 2558 | if(prob_estimates[i] > prob_estimates[prob_max_idx]) |
peccu | 0:b9ac53c439ed | 2559 | prob_max_idx = i; |
peccu | 0:b9ac53c439ed | 2560 | for(i=0;i<nr_class;i++) |
peccu | 0:b9ac53c439ed | 2561 | free(pairwise_prob[i]); |
peccu | 0:b9ac53c439ed | 2562 | free(dec_values); |
peccu | 0:b9ac53c439ed | 2563 | free(pairwise_prob); |
peccu | 0:b9ac53c439ed | 2564 | return model->label[prob_max_idx]; |
peccu | 0:b9ac53c439ed | 2565 | } |
peccu | 0:b9ac53c439ed | 2566 | else |
peccu | 0:b9ac53c439ed | 2567 | return svm_predict(model, x); |
peccu | 0:b9ac53c439ed | 2568 | } |
peccu | 0:b9ac53c439ed | 2569 | |
// Model-file names for each svm_type value (array index == svm_type).
// svm_save_model writes svm_type_table[param.svm_type]; the loader searches the
// table and stores the matching index. NULL terminates the list.
static const char *svm_type_table[] =
{
    "c_svc",        // 0
    "nu_svc",       // 1
    "one_class",    // 2
    "epsilon_svr",  // 3
    "nu_svr",       // 4
    NULL
};

// Model-file names for each kernel_type value (array index == kernel_type),
// used the same way as svm_type_table. NULL terminates the list.
static const char *kernel_type_table[]=
{
    "linear",       // 0
    "polynomial",   // 1
    "rbf",          // 2
    "sigmoid",      // 3
    "precomputed",  // 4
    NULL
};
peccu | 0:b9ac53c439ed | 2579 | |
peccu | 0:b9ac53c439ed | 2580 | int svm_save_model(const char *model_file_name, const svm_model *model) |
peccu | 0:b9ac53c439ed | 2581 | { |
peccu | 0:b9ac53c439ed | 2582 | FILE *fp = fopen(model_file_name,"w"); |
peccu | 0:b9ac53c439ed | 2583 | if(fp==NULL) return -1; |
peccu | 0:b9ac53c439ed | 2584 | |
peccu | 0:b9ac53c439ed | 2585 | const svm_parameter& param = model->param; |
peccu | 0:b9ac53c439ed | 2586 | |
peccu | 0:b9ac53c439ed | 2587 | fprintf(fp,"svm_type %s\n", svm_type_table[param.svm_type]); |
peccu | 0:b9ac53c439ed | 2588 | fprintf(fp,"kernel_type %s\n", kernel_type_table[param.kernel_type]); |
peccu | 0:b9ac53c439ed | 2589 | |
peccu | 0:b9ac53c439ed | 2590 | if(param.kernel_type == POLY) |
peccu | 0:b9ac53c439ed | 2591 | fprintf(fp,"degree %d\n", param.degree); |
peccu | 0:b9ac53c439ed | 2592 | |
peccu | 0:b9ac53c439ed | 2593 | if(param.kernel_type == POLY || param.kernel_type == RBF || param.kernel_type == SIGMOID) |
peccu | 0:b9ac53c439ed | 2594 | fprintf(fp,"gamma %g\n", param.gamma); |
peccu | 0:b9ac53c439ed | 2595 | |
peccu | 0:b9ac53c439ed | 2596 | if(param.kernel_type == POLY || param.kernel_type == SIGMOID) |
peccu | 0:b9ac53c439ed | 2597 | fprintf(fp,"coef0 %g\n", param.coef0); |
peccu | 0:b9ac53c439ed | 2598 | |
peccu | 0:b9ac53c439ed | 2599 | int nr_class = model->nr_class; |
peccu | 0:b9ac53c439ed | 2600 | int l = model->l; |
peccu | 0:b9ac53c439ed | 2601 | fprintf(fp, "nr_class %d\n", nr_class); |
peccu | 0:b9ac53c439ed | 2602 | fprintf(fp, "total_sv %d\n",l); |
peccu | 0:b9ac53c439ed | 2603 | |
peccu | 0:b9ac53c439ed | 2604 | { |
peccu | 0:b9ac53c439ed | 2605 | fprintf(fp, "rho"); |
peccu | 0:b9ac53c439ed | 2606 | for(int i=0;i<nr_class*(nr_class-1)/2;i++) |
peccu | 0:b9ac53c439ed | 2607 | fprintf(fp," %g",model->rho[i]); |
peccu | 0:b9ac53c439ed | 2608 | fprintf(fp, "\n"); |
peccu | 0:b9ac53c439ed | 2609 | } |
peccu | 0:b9ac53c439ed | 2610 | |
peccu | 0:b9ac53c439ed | 2611 | if(model->label) |
peccu | 0:b9ac53c439ed | 2612 | { |
peccu | 0:b9ac53c439ed | 2613 | fprintf(fp, "label"); |
peccu | 0:b9ac53c439ed | 2614 | for(int i=0;i<nr_class;i++) |
peccu | 0:b9ac53c439ed | 2615 | fprintf(fp," %d",model->label[i]); |
peccu | 0:b9ac53c439ed | 2616 | fprintf(fp, "\n"); |
peccu | 0:b9ac53c439ed | 2617 | } |
peccu | 0:b9ac53c439ed | 2618 | |
peccu | 0:b9ac53c439ed | 2619 | if(model->probA) // regression has probA only |
peccu | 0:b9ac53c439ed | 2620 | { |
peccu | 0:b9ac53c439ed | 2621 | fprintf(fp, "probA"); |
peccu | 0:b9ac53c439ed | 2622 | for(int i=0;i<nr_class*(nr_class-1)/2;i++) |
peccu | 0:b9ac53c439ed | 2623 | fprintf(fp," %g",model->probA[i]); |
peccu | 0:b9ac53c439ed | 2624 | fprintf(fp, "\n"); |
peccu | 0:b9ac53c439ed | 2625 | } |
peccu | 0:b9ac53c439ed | 2626 | if(model->probB) |
peccu | 0:b9ac53c439ed | 2627 | { |
peccu | 0:b9ac53c439ed | 2628 | fprintf(fp, "probB"); |
peccu | 0:b9ac53c439ed | 2629 | for(int i=0;i<nr_class*(nr_class-1)/2;i++) |
peccu | 0:b9ac53c439ed | 2630 | fprintf(fp," %g",model->probB[i]); |
peccu | 0:b9ac53c439ed | 2631 | fprintf(fp, "\n"); |
peccu | 0:b9ac53c439ed | 2632 | } |
peccu | 0:b9ac53c439ed | 2633 | |
peccu | 0:b9ac53c439ed | 2634 | if(model->nSV) |
peccu | 0:b9ac53c439ed | 2635 | { |
peccu | 0:b9ac53c439ed | 2636 | fprintf(fp, "nr_sv"); |
peccu | 0:b9ac53c439ed | 2637 | for(int i=0;i<nr_class;i++) |
peccu | 0:b9ac53c439ed | 2638 | fprintf(fp," %d",model->nSV[i]); |
peccu | 0:b9ac53c439ed | 2639 | fprintf(fp, "\n"); |
peccu | 0:b9ac53c439ed | 2640 | } |
peccu | 0:b9ac53c439ed | 2641 | |
peccu | 0:b9ac53c439ed | 2642 | fprintf(fp, "SV\n"); |
peccu | 0:b9ac53c439ed | 2643 | const double * const *sv_coef = model->sv_coef; |
peccu | 0:b9ac53c439ed | 2644 | const svm_node * const *SV = model->SV; |
peccu | 0:b9ac53c439ed | 2645 | |
peccu | 0:b9ac53c439ed | 2646 | for(int i=0;i<l;i++) |
peccu | 0:b9ac53c439ed | 2647 | { |
peccu | 0:b9ac53c439ed | 2648 | for(int j=0;j<nr_class-1;j++) |
peccu | 0:b9ac53c439ed | 2649 | fprintf(fp, "%.16g ",sv_coef[j][i]); |
peccu | 0:b9ac53c439ed | 2650 | |
peccu | 0:b9ac53c439ed | 2651 | const svm_node *p = SV[i]; |
peccu | 0:b9ac53c439ed | 2652 | |
peccu | 0:b9ac53c439ed | 2653 | if(param.kernel_type == PRECOMPUTED) |
peccu | 0:b9ac53c439ed | 2654 | fprintf(fp,"0:%d ",(int)(p->value)); |
peccu | 0:b9ac53c439ed | 2655 | else |
peccu | 0:b9ac53c439ed | 2656 | while(p->index != -1) |
peccu | 0:b9ac53c439ed | 2657 | { |
peccu | 0:b9ac53c439ed | 2658 | fprintf(fp,"%d:%.8g ",p->index,p->value); |
peccu | 0:b9ac53c439ed | 2659 | p++; |
peccu | 0:b9ac53c439ed | 2660 | } |
peccu | 0:b9ac53c439ed | 2661 | fprintf(fp, "\n"); |
peccu | 0:b9ac53c439ed | 2662 | } |
peccu | 0:b9ac53c439ed | 2663 | if (ferror(fp) != 0 || fclose(fp) != 0) return -1; |
peccu | 0:b9ac53c439ed | 2664 | else return 0; |
peccu | 0:b9ac53c439ed | 2665 | } |
peccu | 0:b9ac53c439ed | 2666 | |
// Line buffer used by readline(). svm_load_model allocates it (max_line_len
// bytes) before reading the support vectors and frees it when done.
static char *line = NULL;
static int max_line_len;

// Reads one whole line from input into the static buffer `line`, including the
// trailing '\n' when present. The buffer is doubled (updating max_line_len)
// until the line fits.
//
// Returns `line`, or NULL if nothing could be read. If growing the buffer
// fails, NULL is returned and `line` stays valid, holding the partial line. The
// original `line = realloc(line, ...)` lost the only pointer to the buffer and
// then called strlen on NULL.
static char* readline(FILE *input)
{
    if(fgets(line,max_line_len,input) == NULL)
        return NULL;

    while(strrchr(line,'\n') == NULL)
    {
        int new_len = max_line_len * 2;
        char *grown = (char *) realloc(line,new_len);
        if(grown == NULL)
            return NULL;
        line = grown;
        max_line_len = new_len;

        int len = (int) strlen(line);
        if(fgets(line+len,max_line_len-len,input) == NULL)
            break;  // EOF after a final line without '\n': return what we have
    }
    return line;
}
peccu | 0:b9ac53c439ed | 2687 | |
peccu | 0:b9ac53c439ed | 2688 | svm_model *svm_load_model(const char *model_file_name) |
peccu | 0:b9ac53c439ed | 2689 | { |
peccu | 0:b9ac53c439ed | 2690 | FILE *fp = fopen(model_file_name,"rb"); |
peccu | 0:b9ac53c439ed | 2691 | printf("load\r\n"); |
peccu | 0:b9ac53c439ed | 2692 | if(fp==NULL) return NULL; |
peccu | 0:b9ac53c439ed | 2693 | printf("loaded\r\n"); |
peccu | 0:b9ac53c439ed | 2694 | // read parameters |
peccu | 0:b9ac53c439ed | 2695 | |
peccu | 0:b9ac53c439ed | 2696 | svm_model *model = Malloc(svm_model,1); |
peccu | 0:b9ac53c439ed | 2697 | svm_parameter& param = model->param; |
peccu | 0:b9ac53c439ed | 2698 | model->rho = NULL; |
peccu | 0:b9ac53c439ed | 2699 | model->probA = NULL; |
peccu | 0:b9ac53c439ed | 2700 | model->probB = NULL; |
peccu | 0:b9ac53c439ed | 2701 | model->label = NULL; |
peccu | 0:b9ac53c439ed | 2702 | model->nSV = NULL; |
peccu | 0:b9ac53c439ed | 2703 | |
peccu | 0:b9ac53c439ed | 2704 | char cmd[81]; |
peccu | 0:b9ac53c439ed | 2705 | while(1) |
peccu | 0:b9ac53c439ed | 2706 | { |
peccu | 0:b9ac53c439ed | 2707 | fscanf(fp,"%80s",cmd); |
peccu | 0:b9ac53c439ed | 2708 | |
peccu | 0:b9ac53c439ed | 2709 | if(strcmp(cmd,"svm_type")==0) |
peccu | 0:b9ac53c439ed | 2710 | { |
peccu | 0:b9ac53c439ed | 2711 | fscanf(fp,"%80s",cmd); |
peccu | 0:b9ac53c439ed | 2712 | int i; |
peccu | 0:b9ac53c439ed | 2713 | for(i=0;svm_type_table[i];i++) |
peccu | 0:b9ac53c439ed | 2714 | { |
peccu | 0:b9ac53c439ed | 2715 | if(strcmp(svm_type_table[i],cmd)==0) |
peccu | 0:b9ac53c439ed | 2716 | { |
peccu | 0:b9ac53c439ed | 2717 | param.svm_type=i; |
peccu | 0:b9ac53c439ed | 2718 | break; |
peccu | 0:b9ac53c439ed | 2719 | } |
peccu | 0:b9ac53c439ed | 2720 | } |
peccu | 0:b9ac53c439ed | 2721 | if(svm_type_table[i] == NULL) |
peccu | 0:b9ac53c439ed | 2722 | { |
peccu | 0:b9ac53c439ed | 2723 | fprintf(stderr,"unknown svm type.\n"); |
peccu | 0:b9ac53c439ed | 2724 | free(model->rho); |
peccu | 0:b9ac53c439ed | 2725 | free(model->label); |
peccu | 0:b9ac53c439ed | 2726 | free(model->nSV); |
peccu | 0:b9ac53c439ed | 2727 | free(model); |
peccu | 0:b9ac53c439ed | 2728 | return NULL; |
peccu | 0:b9ac53c439ed | 2729 | } |
peccu | 0:b9ac53c439ed | 2730 | } |
peccu | 0:b9ac53c439ed | 2731 | else if(strcmp(cmd,"kernel_type")==0) |
peccu | 0:b9ac53c439ed | 2732 | { |
peccu | 0:b9ac53c439ed | 2733 | fscanf(fp,"%80s",cmd); |
peccu | 0:b9ac53c439ed | 2734 | int i; |
peccu | 0:b9ac53c439ed | 2735 | for(i=0;kernel_type_table[i];i++) |
peccu | 0:b9ac53c439ed | 2736 | { |
peccu | 0:b9ac53c439ed | 2737 | if(strcmp(kernel_type_table[i],cmd)==0) |
peccu | 0:b9ac53c439ed | 2738 | { |
peccu | 0:b9ac53c439ed | 2739 | param.kernel_type=i; |
peccu | 0:b9ac53c439ed | 2740 | break; |
peccu | 0:b9ac53c439ed | 2741 | } |
peccu | 0:b9ac53c439ed | 2742 | } |
peccu | 0:b9ac53c439ed | 2743 | if(kernel_type_table[i] == NULL) |
peccu | 0:b9ac53c439ed | 2744 | { |
peccu | 0:b9ac53c439ed | 2745 | fprintf(stderr,"unknown kernel function.\n"); |
peccu | 0:b9ac53c439ed | 2746 | free(model->rho); |
peccu | 0:b9ac53c439ed | 2747 | free(model->label); |
peccu | 0:b9ac53c439ed | 2748 | free(model->nSV); |
peccu | 0:b9ac53c439ed | 2749 | free(model); |
peccu | 0:b9ac53c439ed | 2750 | return NULL; |
peccu | 0:b9ac53c439ed | 2751 | } |
peccu | 0:b9ac53c439ed | 2752 | } |
peccu | 0:b9ac53c439ed | 2753 | else if(strcmp(cmd,"degree")==0) |
peccu | 0:b9ac53c439ed | 2754 | fscanf(fp,"%d",¶m.degree); |
peccu | 0:b9ac53c439ed | 2755 | else if(strcmp(cmd,"gamma")==0) |
peccu | 0:b9ac53c439ed | 2756 | fscanf(fp,"%lf",¶m.gamma); |
peccu | 0:b9ac53c439ed | 2757 | else if(strcmp(cmd,"coef0")==0) |
peccu | 0:b9ac53c439ed | 2758 | fscanf(fp,"%lf",¶m.coef0); |
peccu | 0:b9ac53c439ed | 2759 | else if(strcmp(cmd,"nr_class")==0) |
peccu | 0:b9ac53c439ed | 2760 | fscanf(fp,"%d",&model->nr_class); |
peccu | 0:b9ac53c439ed | 2761 | else if(strcmp(cmd,"total_sv")==0) |
peccu | 0:b9ac53c439ed | 2762 | fscanf(fp,"%d",&model->l); |
peccu | 0:b9ac53c439ed | 2763 | else if(strcmp(cmd,"rho")==0) |
peccu | 0:b9ac53c439ed | 2764 | { |
peccu | 0:b9ac53c439ed | 2765 | int n = model->nr_class * (model->nr_class-1)/2; |
peccu | 0:b9ac53c439ed | 2766 | model->rho = Malloc(double,n); |
peccu | 0:b9ac53c439ed | 2767 | for(int i=0;i<n;i++) |
peccu | 0:b9ac53c439ed | 2768 | fscanf(fp,"%lf",&model->rho[i]); |
peccu | 0:b9ac53c439ed | 2769 | } |
peccu | 0:b9ac53c439ed | 2770 | else if(strcmp(cmd,"label")==0) |
peccu | 0:b9ac53c439ed | 2771 | { |
peccu | 0:b9ac53c439ed | 2772 | int n = model->nr_class; |
peccu | 0:b9ac53c439ed | 2773 | model->label = Malloc(int,n); |
peccu | 0:b9ac53c439ed | 2774 | for(int i=0;i<n;i++) |
peccu | 0:b9ac53c439ed | 2775 | fscanf(fp,"%d",&model->label[i]); |
peccu | 0:b9ac53c439ed | 2776 | } |
peccu | 0:b9ac53c439ed | 2777 | else if(strcmp(cmd,"probA")==0) |
peccu | 0:b9ac53c439ed | 2778 | { |
peccu | 0:b9ac53c439ed | 2779 | int n = model->nr_class * (model->nr_class-1)/2; |
peccu | 0:b9ac53c439ed | 2780 | model->probA = Malloc(double,n); |
peccu | 0:b9ac53c439ed | 2781 | for(int i=0;i<n;i++) |
peccu | 0:b9ac53c439ed | 2782 | fscanf(fp,"%lf",&model->probA[i]); |
peccu | 0:b9ac53c439ed | 2783 | } |
peccu | 0:b9ac53c439ed | 2784 | else if(strcmp(cmd,"probB")==0) |
peccu | 0:b9ac53c439ed | 2785 | { |
peccu | 0:b9ac53c439ed | 2786 | int n = model->nr_class * (model->nr_class-1)/2; |
peccu | 0:b9ac53c439ed | 2787 | model->probB = Malloc(double,n); |
peccu | 0:b9ac53c439ed | 2788 | for(int i=0;i<n;i++) |
peccu | 0:b9ac53c439ed | 2789 | fscanf(fp,"%lf",&model->probB[i]); |
peccu | 0:b9ac53c439ed | 2790 | } |
peccu | 0:b9ac53c439ed | 2791 | else if(strcmp(cmd,"nr_sv")==0) |
peccu | 0:b9ac53c439ed | 2792 | { |
peccu | 0:b9ac53c439ed | 2793 | int n = model->nr_class; |
peccu | 0:b9ac53c439ed | 2794 | model->nSV = Malloc(int,n); |
peccu | 0:b9ac53c439ed | 2795 | for(int i=0;i<n;i++) |
peccu | 0:b9ac53c439ed | 2796 | fscanf(fp,"%d",&model->nSV[i]); |
peccu | 0:b9ac53c439ed | 2797 | } |
peccu | 0:b9ac53c439ed | 2798 | else if(strcmp(cmd,"SV")==0) |
peccu | 0:b9ac53c439ed | 2799 | { |
peccu | 0:b9ac53c439ed | 2800 | while(1) |
peccu | 0:b9ac53c439ed | 2801 | { |
peccu | 0:b9ac53c439ed | 2802 | int c = getc(fp); |
peccu | 0:b9ac53c439ed | 2803 | if(c==EOF || c=='\n') break; |
peccu | 0:b9ac53c439ed | 2804 | } |
peccu | 0:b9ac53c439ed | 2805 | break; |
peccu | 0:b9ac53c439ed | 2806 | } |
peccu | 0:b9ac53c439ed | 2807 | else |
peccu | 0:b9ac53c439ed | 2808 | { |
peccu | 0:b9ac53c439ed | 2809 | fprintf(stderr,"unknown text in model file: [%s]\n",cmd); |
peccu | 0:b9ac53c439ed | 2810 | free(model->rho); |
peccu | 0:b9ac53c439ed | 2811 | free(model->label); |
peccu | 0:b9ac53c439ed | 2812 | free(model->nSV); |
peccu | 0:b9ac53c439ed | 2813 | free(model); |
peccu | 0:b9ac53c439ed | 2814 | return NULL; |
peccu | 0:b9ac53c439ed | 2815 | } |
peccu | 0:b9ac53c439ed | 2816 | } |
peccu | 0:b9ac53c439ed | 2817 | |
peccu | 0:b9ac53c439ed | 2818 | // read sv_coef and SV |
peccu | 0:b9ac53c439ed | 2819 | |
peccu | 0:b9ac53c439ed | 2820 | int elements = 0; |
peccu | 0:b9ac53c439ed | 2821 | long pos = ftell(fp); |
peccu | 0:b9ac53c439ed | 2822 | |
peccu | 0:b9ac53c439ed | 2823 | max_line_len = 1024; |
peccu | 0:b9ac53c439ed | 2824 | line = Malloc(char,max_line_len); |
peccu | 0:b9ac53c439ed | 2825 | char *p,*endptr,*idx,*val; |
peccu | 0:b9ac53c439ed | 2826 | |
peccu | 0:b9ac53c439ed | 2827 | while(readline(fp)!=NULL) |
peccu | 0:b9ac53c439ed | 2828 | { |
peccu | 0:b9ac53c439ed | 2829 | p = strtok(line,":"); |
peccu | 0:b9ac53c439ed | 2830 | while(1) |
peccu | 0:b9ac53c439ed | 2831 | { |
peccu | 0:b9ac53c439ed | 2832 | p = strtok(NULL,":"); |
peccu | 0:b9ac53c439ed | 2833 | if(p == NULL) |
peccu | 0:b9ac53c439ed | 2834 | break; |
peccu | 0:b9ac53c439ed | 2835 | ++elements; |
peccu | 0:b9ac53c439ed | 2836 | } |
peccu | 0:b9ac53c439ed | 2837 | } |
peccu | 0:b9ac53c439ed | 2838 | elements += model->l; |
peccu | 0:b9ac53c439ed | 2839 | |
peccu | 0:b9ac53c439ed | 2840 | fseek(fp,pos,SEEK_SET); |
peccu | 0:b9ac53c439ed | 2841 | |
peccu | 0:b9ac53c439ed | 2842 | int m = model->nr_class - 1; |
peccu | 0:b9ac53c439ed | 2843 | int l = model->l; |
peccu | 0:b9ac53c439ed | 2844 | model->sv_coef = Malloc(double *,m); |
peccu | 0:b9ac53c439ed | 2845 | int i; |
peccu | 0:b9ac53c439ed | 2846 | for(i=0;i<m;i++) |
peccu | 0:b9ac53c439ed | 2847 | model->sv_coef[i] = Malloc(double,l); |
peccu | 0:b9ac53c439ed | 2848 | model->SV = Malloc(svm_node*,l); |
peccu | 0:b9ac53c439ed | 2849 | svm_node *x_space = NULL; |
peccu | 0:b9ac53c439ed | 2850 | if(l>0) x_space = Malloc(svm_node,elements); |
peccu | 0:b9ac53c439ed | 2851 | |
peccu | 0:b9ac53c439ed | 2852 | int j=0; |
peccu | 0:b9ac53c439ed | 2853 | for(i=0;i<l;i++) |
peccu | 0:b9ac53c439ed | 2854 | { |
peccu | 0:b9ac53c439ed | 2855 | readline(fp); |
peccu | 0:b9ac53c439ed | 2856 | model->SV[i] = &x_space[j]; |
peccu | 0:b9ac53c439ed | 2857 | |
peccu | 0:b9ac53c439ed | 2858 | p = strtok(line, " \t"); |
peccu | 0:b9ac53c439ed | 2859 | model->sv_coef[0][i] = strtod(p,&endptr); |
peccu | 0:b9ac53c439ed | 2860 | for(int k=1;k<m;k++) |
peccu | 0:b9ac53c439ed | 2861 | { |
peccu | 0:b9ac53c439ed | 2862 | p = strtok(NULL, " \t"); |
peccu | 0:b9ac53c439ed | 2863 | model->sv_coef[k][i] = strtod(p,&endptr); |
peccu | 0:b9ac53c439ed | 2864 | } |
peccu | 0:b9ac53c439ed | 2865 | |
peccu | 0:b9ac53c439ed | 2866 | while(1) |
peccu | 0:b9ac53c439ed | 2867 | { |
peccu | 0:b9ac53c439ed | 2868 | idx = strtok(NULL, ":"); |
peccu | 0:b9ac53c439ed | 2869 | val = strtok(NULL, " \t"); |
peccu | 0:b9ac53c439ed | 2870 | |
peccu | 0:b9ac53c439ed | 2871 | if(val == NULL) |
peccu | 0:b9ac53c439ed | 2872 | break; |
peccu | 0:b9ac53c439ed | 2873 | x_space[j].index = (int) strtol(idx,&endptr,10); |
peccu | 0:b9ac53c439ed | 2874 | x_space[j].value = strtod(val,&endptr); |
peccu | 0:b9ac53c439ed | 2875 | |
peccu | 0:b9ac53c439ed | 2876 | ++j; |
peccu | 0:b9ac53c439ed | 2877 | } |
peccu | 0:b9ac53c439ed | 2878 | x_space[j++].index = -1; |
peccu | 0:b9ac53c439ed | 2879 | } |
peccu | 0:b9ac53c439ed | 2880 | free(line); |
peccu | 0:b9ac53c439ed | 2881 | |
peccu | 0:b9ac53c439ed | 2882 | if (ferror(fp) != 0 || fclose(fp) != 0) |
peccu | 0:b9ac53c439ed | 2883 | return NULL; |
peccu | 0:b9ac53c439ed | 2884 | |
peccu | 0:b9ac53c439ed | 2885 | model->free_sv = 1; // XXX |
peccu | 0:b9ac53c439ed | 2886 | return model; |
peccu | 0:b9ac53c439ed | 2887 | } |
peccu | 0:b9ac53c439ed | 2888 | |
peccu | 0:b9ac53c439ed | 2889 | void svm_free_model_content(svm_model* model_ptr) |
peccu | 0:b9ac53c439ed | 2890 | { |
peccu | 0:b9ac53c439ed | 2891 | if(model_ptr->free_sv && model_ptr->l > 0) |
peccu | 0:b9ac53c439ed | 2892 | free((void *)(model_ptr->SV[0])); |
peccu | 0:b9ac53c439ed | 2893 | for(int i=0;i<model_ptr->nr_class-1;i++) |
peccu | 0:b9ac53c439ed | 2894 | free(model_ptr->sv_coef[i]); |
peccu | 0:b9ac53c439ed | 2895 | free(model_ptr->SV); |
peccu | 0:b9ac53c439ed | 2896 | free(model_ptr->sv_coef); |
peccu | 0:b9ac53c439ed | 2897 | free(model_ptr->rho); |
peccu | 0:b9ac53c439ed | 2898 | free(model_ptr->label); |
peccu | 0:b9ac53c439ed | 2899 | free(model_ptr->probA); |
peccu | 0:b9ac53c439ed | 2900 | free(model_ptr->probB); |
peccu | 0:b9ac53c439ed | 2901 | free(model_ptr->nSV); |
peccu | 0:b9ac53c439ed | 2902 | } |
peccu | 0:b9ac53c439ed | 2903 | |
peccu | 0:b9ac53c439ed | 2904 | void svm_free_and_destroy_model(svm_model** model_ptr_ptr) |
peccu | 0:b9ac53c439ed | 2905 | { |
peccu | 0:b9ac53c439ed | 2906 | svm_model* model_ptr = *model_ptr_ptr; |
peccu | 0:b9ac53c439ed | 2907 | if(model_ptr != NULL) |
peccu | 0:b9ac53c439ed | 2908 | { |
peccu | 0:b9ac53c439ed | 2909 | svm_free_model_content(model_ptr); |
peccu | 0:b9ac53c439ed | 2910 | free(model_ptr); |
peccu | 0:b9ac53c439ed | 2911 | } |
peccu | 0:b9ac53c439ed | 2912 | } |
peccu | 0:b9ac53c439ed | 2913 | |
peccu | 0:b9ac53c439ed | 2914 | void svm_destroy_model(svm_model* model_ptr) |
peccu | 0:b9ac53c439ed | 2915 | { |
peccu | 0:b9ac53c439ed | 2916 | fprintf(stderr,"warning: svm_destroy_model is deprecated and should not be used. Please use svm_free_and_destroy_model(svm_model **model_ptr_ptr)\n"); |
peccu | 0:b9ac53c439ed | 2917 | svm_free_and_destroy_model(&model_ptr); |
peccu | 0:b9ac53c439ed | 2918 | } |
peccu | 0:b9ac53c439ed | 2919 | |
peccu | 0:b9ac53c439ed | 2920 | void svm_destroy_param(svm_parameter* param) |
peccu | 0:b9ac53c439ed | 2921 | { |
peccu | 0:b9ac53c439ed | 2922 | free(param->weight_label); |
peccu | 0:b9ac53c439ed | 2923 | free(param->weight); |
peccu | 0:b9ac53c439ed | 2924 | } |
peccu | 0:b9ac53c439ed | 2925 | |
peccu | 0:b9ac53c439ed | 2926 | const char *svm_check_parameter(const svm_problem *prob, const svm_parameter *param) |
peccu | 0:b9ac53c439ed | 2927 | { |
peccu | 0:b9ac53c439ed | 2928 | // svm_type |
peccu | 0:b9ac53c439ed | 2929 | |
peccu | 0:b9ac53c439ed | 2930 | int svm_type = param->svm_type; |
peccu | 0:b9ac53c439ed | 2931 | if(svm_type != C_SVC && |
peccu | 0:b9ac53c439ed | 2932 | svm_type != NU_SVC && |
peccu | 0:b9ac53c439ed | 2933 | svm_type != ONE_CLASS && |
peccu | 0:b9ac53c439ed | 2934 | svm_type != EPSILON_SVR && |
peccu | 0:b9ac53c439ed | 2935 | svm_type != NU_SVR) |
peccu | 0:b9ac53c439ed | 2936 | return "unknown svm type"; |
peccu | 0:b9ac53c439ed | 2937 | |
peccu | 0:b9ac53c439ed | 2938 | // kernel_type, degree |
peccu | 0:b9ac53c439ed | 2939 | |
peccu | 0:b9ac53c439ed | 2940 | int kernel_type = param->kernel_type; |
peccu | 0:b9ac53c439ed | 2941 | if(kernel_type != LINEAR && |
peccu | 0:b9ac53c439ed | 2942 | kernel_type != POLY && |
peccu | 0:b9ac53c439ed | 2943 | kernel_type != RBF && |
peccu | 0:b9ac53c439ed | 2944 | kernel_type != SIGMOID && |
peccu | 0:b9ac53c439ed | 2945 | kernel_type != PRECOMPUTED) |
peccu | 0:b9ac53c439ed | 2946 | return "unknown kernel type"; |
peccu | 0:b9ac53c439ed | 2947 | |
peccu | 0:b9ac53c439ed | 2948 | if(param->gamma < 0) |
peccu | 0:b9ac53c439ed | 2949 | return "gamma < 0"; |
peccu | 0:b9ac53c439ed | 2950 | |
peccu | 0:b9ac53c439ed | 2951 | if(param->degree < 0) |
peccu | 0:b9ac53c439ed | 2952 | return "degree of polynomial kernel < 0"; |
peccu | 0:b9ac53c439ed | 2953 | |
peccu | 0:b9ac53c439ed | 2954 | // cache_size,eps,C,nu,p,shrinking |
peccu | 0:b9ac53c439ed | 2955 | |
peccu | 0:b9ac53c439ed | 2956 | if(param->cache_size <= 0) |
peccu | 0:b9ac53c439ed | 2957 | return "cache_size <= 0"; |
peccu | 0:b9ac53c439ed | 2958 | |
peccu | 0:b9ac53c439ed | 2959 | if(param->eps <= 0) |
peccu | 0:b9ac53c439ed | 2960 | return "eps <= 0"; |
peccu | 0:b9ac53c439ed | 2961 | |
peccu | 0:b9ac53c439ed | 2962 | if(svm_type == C_SVC || |
peccu | 0:b9ac53c439ed | 2963 | svm_type == EPSILON_SVR || |
peccu | 0:b9ac53c439ed | 2964 | svm_type == NU_SVR) |
peccu | 0:b9ac53c439ed | 2965 | if(param->C <= 0) |
peccu | 0:b9ac53c439ed | 2966 | return "C <= 0"; |
peccu | 0:b9ac53c439ed | 2967 | |
peccu | 0:b9ac53c439ed | 2968 | if(svm_type == NU_SVC || |
peccu | 0:b9ac53c439ed | 2969 | svm_type == ONE_CLASS || |
peccu | 0:b9ac53c439ed | 2970 | svm_type == NU_SVR) |
peccu | 0:b9ac53c439ed | 2971 | if(param->nu <= 0 || param->nu > 1) |
peccu | 0:b9ac53c439ed | 2972 | return "nu <= 0 or nu > 1"; |
peccu | 0:b9ac53c439ed | 2973 | |
peccu | 0:b9ac53c439ed | 2974 | if(svm_type == EPSILON_SVR) |
peccu | 0:b9ac53c439ed | 2975 | if(param->p < 0) |
peccu | 0:b9ac53c439ed | 2976 | return "p < 0"; |
peccu | 0:b9ac53c439ed | 2977 | |
peccu | 0:b9ac53c439ed | 2978 | if(param->shrinking != 0 && |
peccu | 0:b9ac53c439ed | 2979 | param->shrinking != 1) |
peccu | 0:b9ac53c439ed | 2980 | return "shrinking != 0 and shrinking != 1"; |
peccu | 0:b9ac53c439ed | 2981 | |
peccu | 0:b9ac53c439ed | 2982 | if(param->probability != 0 && |
peccu | 0:b9ac53c439ed | 2983 | param->probability != 1) |
peccu | 0:b9ac53c439ed | 2984 | return "probability != 0 and probability != 1"; |
peccu | 0:b9ac53c439ed | 2985 | |
peccu | 0:b9ac53c439ed | 2986 | if(param->probability == 1 && |
peccu | 0:b9ac53c439ed | 2987 | svm_type == ONE_CLASS) |
peccu | 0:b9ac53c439ed | 2988 | return "one-class SVM probability output not supported yet"; |
peccu | 0:b9ac53c439ed | 2989 | |
peccu | 0:b9ac53c439ed | 2990 | |
peccu | 0:b9ac53c439ed | 2991 | // check whether nu-svc is feasible |
peccu | 0:b9ac53c439ed | 2992 | |
peccu | 0:b9ac53c439ed | 2993 | if(svm_type == NU_SVC) |
peccu | 0:b9ac53c439ed | 2994 | { |
peccu | 0:b9ac53c439ed | 2995 | int l = prob->l; |
peccu | 0:b9ac53c439ed | 2996 | int max_nr_class = 16; |
peccu | 0:b9ac53c439ed | 2997 | int nr_class = 0; |
peccu | 0:b9ac53c439ed | 2998 | int *label = Malloc(int,max_nr_class); |
peccu | 0:b9ac53c439ed | 2999 | int *count = Malloc(int,max_nr_class); |
peccu | 0:b9ac53c439ed | 3000 | |
peccu | 0:b9ac53c439ed | 3001 | int i; |
peccu | 0:b9ac53c439ed | 3002 | for(i=0;i<l;i++) |
peccu | 0:b9ac53c439ed | 3003 | { |
peccu | 0:b9ac53c439ed | 3004 | int this_label = (int)prob->y[i]; |
peccu | 0:b9ac53c439ed | 3005 | int j; |
peccu | 0:b9ac53c439ed | 3006 | for(j=0;j<nr_class;j++) |
peccu | 0:b9ac53c439ed | 3007 | if(this_label == label[j]) |
peccu | 0:b9ac53c439ed | 3008 | { |
peccu | 0:b9ac53c439ed | 3009 | ++count[j]; |
peccu | 0:b9ac53c439ed | 3010 | break; |
peccu | 0:b9ac53c439ed | 3011 | } |
peccu | 0:b9ac53c439ed | 3012 | if(j == nr_class) |
peccu | 0:b9ac53c439ed | 3013 | { |
peccu | 0:b9ac53c439ed | 3014 | if(nr_class == max_nr_class) |
peccu | 0:b9ac53c439ed | 3015 | { |
peccu | 0:b9ac53c439ed | 3016 | max_nr_class *= 2; |
peccu | 0:b9ac53c439ed | 3017 | label = (int *)realloc(label,max_nr_class*sizeof(int)); |
peccu | 0:b9ac53c439ed | 3018 | count = (int *)realloc(count,max_nr_class*sizeof(int)); |
peccu | 0:b9ac53c439ed | 3019 | } |
peccu | 0:b9ac53c439ed | 3020 | label[nr_class] = this_label; |
peccu | 0:b9ac53c439ed | 3021 | count[nr_class] = 1; |
peccu | 0:b9ac53c439ed | 3022 | ++nr_class; |
peccu | 0:b9ac53c439ed | 3023 | } |
peccu | 0:b9ac53c439ed | 3024 | } |
peccu | 0:b9ac53c439ed | 3025 | |
peccu | 0:b9ac53c439ed | 3026 | for(i=0;i<nr_class;i++) |
peccu | 0:b9ac53c439ed | 3027 | { |
peccu | 0:b9ac53c439ed | 3028 | int n1 = count[i]; |
peccu | 0:b9ac53c439ed | 3029 | for(int j=i+1;j<nr_class;j++) |
peccu | 0:b9ac53c439ed | 3030 | { |
peccu | 0:b9ac53c439ed | 3031 | int n2 = count[j]; |
peccu | 0:b9ac53c439ed | 3032 | if(param->nu*(n1+n2)/2 > min(n1,n2)) |
peccu | 0:b9ac53c439ed | 3033 | { |
peccu | 0:b9ac53c439ed | 3034 | free(label); |
peccu | 0:b9ac53c439ed | 3035 | free(count); |
peccu | 0:b9ac53c439ed | 3036 | return "specified nu is infeasible"; |
peccu | 0:b9ac53c439ed | 3037 | } |
peccu | 0:b9ac53c439ed | 3038 | } |
peccu | 0:b9ac53c439ed | 3039 | } |
peccu | 0:b9ac53c439ed | 3040 | free(label); |
peccu | 0:b9ac53c439ed | 3041 | free(count); |
peccu | 0:b9ac53c439ed | 3042 | } |
peccu | 0:b9ac53c439ed | 3043 | |
peccu | 0:b9ac53c439ed | 3044 | return NULL; |
peccu | 0:b9ac53c439ed | 3045 | } |
peccu | 0:b9ac53c439ed | 3046 | |
peccu | 0:b9ac53c439ed | 3047 | int svm_check_probability_model(const svm_model *model) |
peccu | 0:b9ac53c439ed | 3048 | { |
peccu | 0:b9ac53c439ed | 3049 | return ((model->param.svm_type == C_SVC || model->param.svm_type == NU_SVC) && |
peccu | 0:b9ac53c439ed | 3050 | model->probA!=NULL && model->probB!=NULL) || |
peccu | 0:b9ac53c439ed | 3051 | ((model->param.svm_type == EPSILON_SVR || model->param.svm_type == NU_SVR) && |
peccu | 0:b9ac53c439ed | 3052 | model->probA!=NULL); |
peccu | 0:b9ac53c439ed | 3053 | } |
peccu | 0:b9ac53c439ed | 3054 | |
peccu | 0:b9ac53c439ed | 3055 | void svm_set_print_string_function(void (*print_func)(const char *)) |
peccu | 0:b9ac53c439ed | 3056 | { |
peccu | 0:b9ac53c439ed | 3057 | if(print_func == NULL) |
peccu | 0:b9ac53c439ed | 3058 | svm_print_string = &print_string_stdout; |
peccu | 0:b9ac53c439ed | 3059 | else |
peccu | 0:b9ac53c439ed | 3060 | svm_print_string = print_func; |
peccu | 0:b9ac53c439ed | 3061 | } |
peccu | 0:b9ac53c439ed | 3062 | |
peccu | 0:b9ac53c439ed | 3063 | // this function is copied by shimatani |
peccu | 0:b9ac53c439ed | 3064 | // for fopen on iPhone |
peccu | 0:b9ac53c439ed | 3065 | svm_model *svm_load_model_fp(FILE *fp) |
peccu | 0:b9ac53c439ed | 3066 | { |
peccu | 0:b9ac53c439ed | 3067 | if(fp==NULL) return NULL; |
peccu | 0:b9ac53c439ed | 3068 | |
peccu | 0:b9ac53c439ed | 3069 | // read parameters |
peccu | 0:b9ac53c439ed | 3070 | |
peccu | 0:b9ac53c439ed | 3071 | svm_model *model = Malloc(svm_model,1); |
peccu | 0:b9ac53c439ed | 3072 | svm_parameter& param = model->param; |
peccu | 0:b9ac53c439ed | 3073 | model->rho = NULL; |
peccu | 0:b9ac53c439ed | 3074 | model->probA = NULL; |
peccu | 0:b9ac53c439ed | 3075 | model->probB = NULL; |
peccu | 0:b9ac53c439ed | 3076 | model->label = NULL; |
peccu | 0:b9ac53c439ed | 3077 | model->nSV = NULL; |
peccu | 0:b9ac53c439ed | 3078 | |
peccu | 0:b9ac53c439ed | 3079 | char cmd[81]; |
peccu | 0:b9ac53c439ed | 3080 | while(1) |
peccu | 0:b9ac53c439ed | 3081 | { |
peccu | 0:b9ac53c439ed | 3082 | fscanf(fp,"%80s",cmd); |
peccu | 0:b9ac53c439ed | 3083 | |
peccu | 0:b9ac53c439ed | 3084 | if(strcmp(cmd,"svm_type")==0) |
peccu | 0:b9ac53c439ed | 3085 | { |
peccu | 0:b9ac53c439ed | 3086 | fscanf(fp,"%80s",cmd); |
peccu | 0:b9ac53c439ed | 3087 | int i; |
peccu | 0:b9ac53c439ed | 3088 | for(i=0;svm_type_table[i];i++) |
peccu | 0:b9ac53c439ed | 3089 | { |
peccu | 0:b9ac53c439ed | 3090 | if(strcmp(svm_type_table[i],cmd)==0) |
peccu | 0:b9ac53c439ed | 3091 | { |
peccu | 0:b9ac53c439ed | 3092 | param.svm_type=i; |
peccu | 0:b9ac53c439ed | 3093 | break; |
peccu | 0:b9ac53c439ed | 3094 | } |
peccu | 0:b9ac53c439ed | 3095 | } |
peccu | 0:b9ac53c439ed | 3096 | if(svm_type_table[i] == NULL) |
peccu | 0:b9ac53c439ed | 3097 | { |
peccu | 0:b9ac53c439ed | 3098 | fprintf(stderr,"unknown svm type.\n"); |
peccu | 0:b9ac53c439ed | 3099 | free(model->rho); |
peccu | 0:b9ac53c439ed | 3100 | free(model->label); |
peccu | 0:b9ac53c439ed | 3101 | free(model->nSV); |
peccu | 0:b9ac53c439ed | 3102 | free(model); |
peccu | 0:b9ac53c439ed | 3103 | return NULL; |
peccu | 0:b9ac53c439ed | 3104 | } |
peccu | 0:b9ac53c439ed | 3105 | } |
peccu | 0:b9ac53c439ed | 3106 | else if(strcmp(cmd,"kernel_type")==0) |
peccu | 0:b9ac53c439ed | 3107 | { |
peccu | 0:b9ac53c439ed | 3108 | fscanf(fp,"%80s",cmd); |
peccu | 0:b9ac53c439ed | 3109 | int i; |
peccu | 0:b9ac53c439ed | 3110 | for(i=0;kernel_type_table[i];i++) |
peccu | 0:b9ac53c439ed | 3111 | { |
peccu | 0:b9ac53c439ed | 3112 | if(strcmp(kernel_type_table[i],cmd)==0) |
peccu | 0:b9ac53c439ed | 3113 | { |
peccu | 0:b9ac53c439ed | 3114 | param.kernel_type=i; |
peccu | 0:b9ac53c439ed | 3115 | break; |
peccu | 0:b9ac53c439ed | 3116 | } |
peccu | 0:b9ac53c439ed | 3117 | } |
peccu | 0:b9ac53c439ed | 3118 | if(kernel_type_table[i] == NULL) |
peccu | 0:b9ac53c439ed | 3119 | { |
peccu | 0:b9ac53c439ed | 3120 | fprintf(stderr,"unknown kernel function.\n"); |
peccu | 0:b9ac53c439ed | 3121 | free(model->rho); |
peccu | 0:b9ac53c439ed | 3122 | free(model->label); |
peccu | 0:b9ac53c439ed | 3123 | free(model->nSV); |
peccu | 0:b9ac53c439ed | 3124 | free(model); |
peccu | 0:b9ac53c439ed | 3125 | return NULL; |
peccu | 0:b9ac53c439ed | 3126 | } |
peccu | 0:b9ac53c439ed | 3127 | } |
peccu | 0:b9ac53c439ed | 3128 | else if(strcmp(cmd,"degree")==0) |
peccu | 0:b9ac53c439ed | 3129 | fscanf(fp,"%d",¶m.degree); |
peccu | 0:b9ac53c439ed | 3130 | else if(strcmp(cmd,"gamma")==0) |
peccu | 0:b9ac53c439ed | 3131 | fscanf(fp,"%lf",¶m.gamma); |
peccu | 0:b9ac53c439ed | 3132 | else if(strcmp(cmd,"coef0")==0) |
peccu | 0:b9ac53c439ed | 3133 | fscanf(fp,"%lf",¶m.coef0); |
peccu | 0:b9ac53c439ed | 3134 | else if(strcmp(cmd,"nr_class")==0) |
peccu | 0:b9ac53c439ed | 3135 | fscanf(fp,"%d",&model->nr_class); |
peccu | 0:b9ac53c439ed | 3136 | else if(strcmp(cmd,"total_sv")==0) |
peccu | 0:b9ac53c439ed | 3137 | fscanf(fp,"%d",&model->l); |
peccu | 0:b9ac53c439ed | 3138 | else if(strcmp(cmd,"rho")==0) |
peccu | 0:b9ac53c439ed | 3139 | { |
peccu | 0:b9ac53c439ed | 3140 | int n = model->nr_class * (model->nr_class-1)/2; |
peccu | 0:b9ac53c439ed | 3141 | model->rho = Malloc(double,n); |
peccu | 0:b9ac53c439ed | 3142 | for(int i=0;i<n;i++) |
peccu | 0:b9ac53c439ed | 3143 | fscanf(fp,"%lf",&model->rho[i]); |
peccu | 0:b9ac53c439ed | 3144 | } |
peccu | 0:b9ac53c439ed | 3145 | else if(strcmp(cmd,"label")==0) |
peccu | 0:b9ac53c439ed | 3146 | { |
peccu | 0:b9ac53c439ed | 3147 | int n = model->nr_class; |
peccu | 0:b9ac53c439ed | 3148 | model->label = Malloc(int,n); |
peccu | 0:b9ac53c439ed | 3149 | for(int i=0;i<n;i++) |
peccu | 0:b9ac53c439ed | 3150 | fscanf(fp,"%d",&model->label[i]); |
peccu | 0:b9ac53c439ed | 3151 | } |
peccu | 0:b9ac53c439ed | 3152 | else if(strcmp(cmd,"probA")==0) |
peccu | 0:b9ac53c439ed | 3153 | { |
peccu | 0:b9ac53c439ed | 3154 | int n = model->nr_class * (model->nr_class-1)/2; |
peccu | 0:b9ac53c439ed | 3155 | model->probA = Malloc(double,n); |
peccu | 0:b9ac53c439ed | 3156 | for(int i=0;i<n;i++) |
peccu | 0:b9ac53c439ed | 3157 | fscanf(fp,"%lf",&model->probA[i]); |
peccu | 0:b9ac53c439ed | 3158 | } |
peccu | 0:b9ac53c439ed | 3159 | else if(strcmp(cmd,"probB")==0) |
peccu | 0:b9ac53c439ed | 3160 | { |
peccu | 0:b9ac53c439ed | 3161 | int n = model->nr_class * (model->nr_class-1)/2; |
peccu | 0:b9ac53c439ed | 3162 | model->probB = Malloc(double,n); |
peccu | 0:b9ac53c439ed | 3163 | for(int i=0;i<n;i++) |
peccu | 0:b9ac53c439ed | 3164 | fscanf(fp,"%lf",&model->probB[i]); |
peccu | 0:b9ac53c439ed | 3165 | } |
peccu | 0:b9ac53c439ed | 3166 | else if(strcmp(cmd,"nr_sv")==0) |
peccu | 0:b9ac53c439ed | 3167 | { |
peccu | 0:b9ac53c439ed | 3168 | int n = model->nr_class; |
peccu | 0:b9ac53c439ed | 3169 | model->nSV = Malloc(int,n); |
peccu | 0:b9ac53c439ed | 3170 | for(int i=0;i<n;i++) |
peccu | 0:b9ac53c439ed | 3171 | fscanf(fp,"%d",&model->nSV[i]); |
peccu | 0:b9ac53c439ed | 3172 | } |
peccu | 0:b9ac53c439ed | 3173 | else if(strcmp(cmd,"SV")==0) |
peccu | 0:b9ac53c439ed | 3174 | { |
peccu | 0:b9ac53c439ed | 3175 | while(1) |
peccu | 0:b9ac53c439ed | 3176 | { |
peccu | 0:b9ac53c439ed | 3177 | int c = getc(fp); |
peccu | 0:b9ac53c439ed | 3178 | if(c==EOF || c=='\n') break; |
peccu | 0:b9ac53c439ed | 3179 | } |
peccu | 0:b9ac53c439ed | 3180 | break; |
peccu | 0:b9ac53c439ed | 3181 | } |
peccu | 0:b9ac53c439ed | 3182 | else |
peccu | 0:b9ac53c439ed | 3183 | { |
peccu | 0:b9ac53c439ed | 3184 | fprintf(stderr,"unknown text in model file: [%s]\n",cmd); |
peccu | 0:b9ac53c439ed | 3185 | free(model->rho); |
peccu | 0:b9ac53c439ed | 3186 | free(model->label); |
peccu | 0:b9ac53c439ed | 3187 | free(model->nSV); |
peccu | 0:b9ac53c439ed | 3188 | free(model); |
peccu | 0:b9ac53c439ed | 3189 | return NULL; |
peccu | 0:b9ac53c439ed | 3190 | } |
peccu | 0:b9ac53c439ed | 3191 | } |
peccu | 0:b9ac53c439ed | 3192 | |
peccu | 0:b9ac53c439ed | 3193 | // read sv_coef and SV |
peccu | 0:b9ac53c439ed | 3194 | |
peccu | 0:b9ac53c439ed | 3195 | int elements = 0; |
peccu | 0:b9ac53c439ed | 3196 | long pos = ftell(fp); |
peccu | 0:b9ac53c439ed | 3197 | |
peccu | 0:b9ac53c439ed | 3198 | max_line_len = 1024; |
peccu | 0:b9ac53c439ed | 3199 | line = Malloc(char,max_line_len); |
peccu | 0:b9ac53c439ed | 3200 | char *p,*endptr,*idx,*val; |
peccu | 0:b9ac53c439ed | 3201 | |
peccu | 0:b9ac53c439ed | 3202 | while(readline(fp)!=NULL) |
peccu | 0:b9ac53c439ed | 3203 | { |
peccu | 0:b9ac53c439ed | 3204 | p = strtok(line,":"); |
peccu | 0:b9ac53c439ed | 3205 | while(1) |
peccu | 0:b9ac53c439ed | 3206 | { |
peccu | 0:b9ac53c439ed | 3207 | p = strtok(NULL,":"); |
peccu | 0:b9ac53c439ed | 3208 | if(p == NULL) |
peccu | 0:b9ac53c439ed | 3209 | break; |
peccu | 0:b9ac53c439ed | 3210 | ++elements; |
peccu | 0:b9ac53c439ed | 3211 | } |
peccu | 0:b9ac53c439ed | 3212 | } |
peccu | 0:b9ac53c439ed | 3213 | elements += model->l; |
peccu | 0:b9ac53c439ed | 3214 | |
peccu | 0:b9ac53c439ed | 3215 | fseek(fp,pos,SEEK_SET); |
peccu | 0:b9ac53c439ed | 3216 | |
peccu | 0:b9ac53c439ed | 3217 | int m = model->nr_class - 1; |
peccu | 0:b9ac53c439ed | 3218 | int l = model->l; |
peccu | 0:b9ac53c439ed | 3219 | model->sv_coef = Malloc(double *,m); |
peccu | 0:b9ac53c439ed | 3220 | int i; |
peccu | 0:b9ac53c439ed | 3221 | for(i=0;i<m;i++) |
peccu | 0:b9ac53c439ed | 3222 | model->sv_coef[i] = Malloc(double,l); |
peccu | 0:b9ac53c439ed | 3223 | model->SV = Malloc(svm_node*,l); |
peccu | 0:b9ac53c439ed | 3224 | svm_node *x_space = NULL; |
peccu | 0:b9ac53c439ed | 3225 | if(l>0) x_space = Malloc(svm_node,elements); |
peccu | 0:b9ac53c439ed | 3226 | |
peccu | 0:b9ac53c439ed | 3227 | int j=0; |
peccu | 0:b9ac53c439ed | 3228 | for(i=0;i<l;i++) |
peccu | 0:b9ac53c439ed | 3229 | { |
peccu | 0:b9ac53c439ed | 3230 | readline(fp); |
peccu | 0:b9ac53c439ed | 3231 | model->SV[i] = &x_space[j]; |
peccu | 0:b9ac53c439ed | 3232 | |
peccu | 0:b9ac53c439ed | 3233 | p = strtok(line, " \t"); |
peccu | 0:b9ac53c439ed | 3234 | model->sv_coef[0][i] = strtod(p,&endptr); |
peccu | 0:b9ac53c439ed | 3235 | for(int k=1;k<m;k++) |
peccu | 0:b9ac53c439ed | 3236 | { |
peccu | 0:b9ac53c439ed | 3237 | p = strtok(NULL, " \t"); |
peccu | 0:b9ac53c439ed | 3238 | model->sv_coef[k][i] = strtod(p,&endptr); |
peccu | 0:b9ac53c439ed | 3239 | } |
peccu | 0:b9ac53c439ed | 3240 | |
peccu | 0:b9ac53c439ed | 3241 | while(1) |
peccu | 0:b9ac53c439ed | 3242 | { |
peccu | 0:b9ac53c439ed | 3243 | idx = strtok(NULL, ":"); |
peccu | 0:b9ac53c439ed | 3244 | val = strtok(NULL, " \t"); |
peccu | 0:b9ac53c439ed | 3245 | |
peccu | 0:b9ac53c439ed | 3246 | if(val == NULL) |
peccu | 0:b9ac53c439ed | 3247 | break; |
peccu | 0:b9ac53c439ed | 3248 | x_space[j].index = (int) strtol(idx,&endptr,10); |
peccu | 0:b9ac53c439ed | 3249 | x_space[j].value = strtod(val,&endptr); |
peccu | 0:b9ac53c439ed | 3250 | |
peccu | 0:b9ac53c439ed | 3251 | ++j; |
peccu | 0:b9ac53c439ed | 3252 | } |
peccu | 0:b9ac53c439ed | 3253 | x_space[j++].index = -1; |
peccu | 0:b9ac53c439ed | 3254 | } |
peccu | 0:b9ac53c439ed | 3255 | free(line); |
peccu | 0:b9ac53c439ed | 3256 | |
peccu | 0:b9ac53c439ed | 3257 | if (ferror(fp) != 0 || fclose(fp) != 0) |
peccu | 0:b9ac53c439ed | 3258 | return NULL; |
peccu | 0:b9ac53c439ed | 3259 | |
peccu | 0:b9ac53c439ed | 3260 | model->free_sv = 1; // XXX |
peccu | 0:b9ac53c439ed | 3261 | return model; |
peccu | 0:b9ac53c439ed | 3262 | } |