action recognizer with theremin

Dependencies:   mbed

Committer:
peccu
Date:
Wed Sep 14 13:42:46 2011 +0000
Revision:
0:b9ac53c439ed

        

Who changed what in which revision?

User | Revision | Line number | New contents of line
peccu 0:b9ac53c439ed 1 #include <math.h>
peccu 0:b9ac53c439ed 2 #include <stdio.h>
peccu 0:b9ac53c439ed 3 #include <stdlib.h>
peccu 0:b9ac53c439ed 4 #include <ctype.h>
peccu 0:b9ac53c439ed 5 #include <float.h>
peccu 0:b9ac53c439ed 6 #include <string.h>
peccu 0:b9ac53c439ed 7 #include <stdarg.h>
peccu 0:b9ac53c439ed 8 #include "svm.h"
peccu 0:b9ac53c439ed 9 int libsvm_version = LIBSVM_VERSION;
peccu 0:b9ac53c439ed 10 typedef float Qfloat;
peccu 0:b9ac53c439ed 11 typedef signed char schar;
// Generic helpers. min/max are only provided when no macro of the same name
// is already visible.
#ifndef min
template <class T> static inline T min(T x,T y)
{
	if(x < y)
		return x;
	return y;
}
#endif
#ifndef max
template <class T> static inline T max(T x,T y)
{
	if(x > y)
		return x;
	return y;
}
#endif
// Exchanges the values of x and y.
template <class T> static inline void swap(T& x, T& y)
{
	T tmp(x);
	x = y;
	y = tmp;
}
// Allocates dst with new T[n] and bitwise-copies n elements from src into it.
// The caller owns dst and must release it with delete[].
template <class S, class T> static inline void clone(T*& dst, S* src, int n)
{
	const size_t bytes = sizeof(T) * n;
	dst = new T[n];
	memcpy(static_cast<void *>(dst), static_cast<const void *>(src), bytes);
}
// Raises base to a non-negative integer power by square-and-multiply.
// A non-positive exponent yields 1.0.
static inline double powi(double base, int times)
{
	double result = 1.0;
	double square = base;
	int remaining = times;
	while(remaining > 0)
	{
		if(remaining % 2 == 1)
			result *= square;
		square *= square;
		remaining /= 2;
	}
	return result;
}
// Positive infinity; used as the starting value when searching for extrema.
#define INF HUGE_VAL
// Tiny positive value substituted for a non-positive quadratic coefficient in
// the SMO update so the step length stays finite.
#define TAU 1e-12
// Typed malloc. The whole expansion is parenthesized so that a postfix
// operator applied to the result (e.g. Malloc(int,n)[0]) binds to the typed
// pointer rather than to the void* returned by malloc.
#define Malloc(type,n) ((type *)malloc((n)*sizeof(type)))
peccu 0:b9ac53c439ed 38
// Default output sink: writes s to stdout and flushes right away.
static void print_string_stdout(const char *s)
{
	fputs(s, stdout);
	fflush(stdout);
}

// Output hook used by info(); starts out pointing at the stdout sink.
static void (*svm_print_string) (const char *) = print_string_stdout;
peccu 0:b9ac53c439ed 45 #if 1
peccu 0:b9ac53c439ed 46 static void info(const char *fmt,...)
peccu 0:b9ac53c439ed 47 {
peccu 0:b9ac53c439ed 48 char buf[BUFSIZ];
peccu 0:b9ac53c439ed 49 va_list ap;
peccu 0:b9ac53c439ed 50 va_start(ap,fmt);
peccu 0:b9ac53c439ed 51 vsprintf(buf,fmt,ap);
peccu 0:b9ac53c439ed 52 va_end(ap);
peccu 0:b9ac53c439ed 53 (*svm_print_string)(buf);
peccu 0:b9ac53c439ed 54 }
peccu 0:b9ac53c439ed 55 #else
peccu 0:b9ac53c439ed 56 static void info(const char *fmt,...) {}
peccu 0:b9ac53c439ed 57 #endif
peccu 0:b9ac53c439ed 58
peccu 0:b9ac53c439ed 59 //
peccu 0:b9ac53c439ed 60 // Kernel Cache
peccu 0:b9ac53c439ed 61 //
peccu 0:b9ac53c439ed 62 // l is the number of total data items
peccu 0:b9ac53c439ed 63 // size is the cache size limit in bytes
peccu 0:b9ac53c439ed 64 //
peccu 0:b9ac53c439ed 65 class Cache
peccu 0:b9ac53c439ed 66 {
peccu 0:b9ac53c439ed 67 public:
peccu 0:b9ac53c439ed 68 Cache(int l,long int size);
peccu 0:b9ac53c439ed 69 ~Cache();
peccu 0:b9ac53c439ed 70
peccu 0:b9ac53c439ed 71 // request data [0,len)
peccu 0:b9ac53c439ed 72 // return some position p where [p,len) need to be filled
peccu 0:b9ac53c439ed 73 // (p >= len if nothing needs to be filled)
peccu 0:b9ac53c439ed 74 int get_data(const int index, Qfloat **data, int len);
peccu 0:b9ac53c439ed 75 void swap_index(int i, int j);
peccu 0:b9ac53c439ed 76 private:
peccu 0:b9ac53c439ed 77 int l;
peccu 0:b9ac53c439ed 78 long int size;
peccu 0:b9ac53c439ed 79 struct head_t
peccu 0:b9ac53c439ed 80 {
peccu 0:b9ac53c439ed 81 head_t *prev, *next; // a circular list
peccu 0:b9ac53c439ed 82 Qfloat *data;
peccu 0:b9ac53c439ed 83 int len; // data[0,len) is cached in this entry
peccu 0:b9ac53c439ed 84 };
peccu 0:b9ac53c439ed 85
peccu 0:b9ac53c439ed 86 head_t *head;
peccu 0:b9ac53c439ed 87 head_t lru_head;
peccu 0:b9ac53c439ed 88 void lru_delete(head_t *h);
peccu 0:b9ac53c439ed 89 void lru_insert(head_t *h);
peccu 0:b9ac53c439ed 90 };
peccu 0:b9ac53c439ed 91
peccu 0:b9ac53c439ed 92 Cache::Cache(int l_,long int size_):l(l_),size(size_)
peccu 0:b9ac53c439ed 93 {
peccu 0:b9ac53c439ed 94 head = (head_t *)calloc(l,sizeof(head_t)); // initialized to 0
peccu 0:b9ac53c439ed 95 size /= sizeof(Qfloat);
peccu 0:b9ac53c439ed 96 size -= l * sizeof(head_t) / sizeof(Qfloat);
peccu 0:b9ac53c439ed 97 size = max(size, 2 * (long int) l); // cache must be large enough for two columns
peccu 0:b9ac53c439ed 98 lru_head.next = lru_head.prev = &lru_head;
peccu 0:b9ac53c439ed 99 }
peccu 0:b9ac53c439ed 100
peccu 0:b9ac53c439ed 101 Cache::~Cache()
peccu 0:b9ac53c439ed 102 {
peccu 0:b9ac53c439ed 103 for(head_t *h = lru_head.next; h != &lru_head; h=h->next)
peccu 0:b9ac53c439ed 104 free(h->data);
peccu 0:b9ac53c439ed 105 free(head);
peccu 0:b9ac53c439ed 106 }
peccu 0:b9ac53c439ed 107
peccu 0:b9ac53c439ed 108 void Cache::lru_delete(head_t *h)
peccu 0:b9ac53c439ed 109 {
peccu 0:b9ac53c439ed 110 // delete from current location
peccu 0:b9ac53c439ed 111 h->prev->next = h->next;
peccu 0:b9ac53c439ed 112 h->next->prev = h->prev;
peccu 0:b9ac53c439ed 113 }
peccu 0:b9ac53c439ed 114
peccu 0:b9ac53c439ed 115 void Cache::lru_insert(head_t *h)
peccu 0:b9ac53c439ed 116 {
peccu 0:b9ac53c439ed 117 // insert to last position
peccu 0:b9ac53c439ed 118 h->next = &lru_head;
peccu 0:b9ac53c439ed 119 h->prev = lru_head.prev;
peccu 0:b9ac53c439ed 120 h->prev->next = h;
peccu 0:b9ac53c439ed 121 h->next->prev = h;
peccu 0:b9ac53c439ed 122 }
peccu 0:b9ac53c439ed 123
peccu 0:b9ac53c439ed 124 int Cache::get_data(const int index, Qfloat **data, int len)
peccu 0:b9ac53c439ed 125 {
peccu 0:b9ac53c439ed 126 head_t *h = &head[index];
peccu 0:b9ac53c439ed 127 if(h->len) lru_delete(h);
peccu 0:b9ac53c439ed 128 int more = len - h->len;
peccu 0:b9ac53c439ed 129
peccu 0:b9ac53c439ed 130 if(more > 0)
peccu 0:b9ac53c439ed 131 {
peccu 0:b9ac53c439ed 132 // free old space
peccu 0:b9ac53c439ed 133 while(size < more)
peccu 0:b9ac53c439ed 134 {
peccu 0:b9ac53c439ed 135 head_t *old = lru_head.next;
peccu 0:b9ac53c439ed 136 lru_delete(old);
peccu 0:b9ac53c439ed 137 free(old->data);
peccu 0:b9ac53c439ed 138 size += old->len;
peccu 0:b9ac53c439ed 139 old->data = 0;
peccu 0:b9ac53c439ed 140 old->len = 0;
peccu 0:b9ac53c439ed 141 }
peccu 0:b9ac53c439ed 142
peccu 0:b9ac53c439ed 143 // allocate new space
peccu 0:b9ac53c439ed 144 h->data = (Qfloat *)realloc(h->data,sizeof(Qfloat)*len);
peccu 0:b9ac53c439ed 145 size -= more;
peccu 0:b9ac53c439ed 146 swap(h->len,len);
peccu 0:b9ac53c439ed 147 }
peccu 0:b9ac53c439ed 148
peccu 0:b9ac53c439ed 149 lru_insert(h);
peccu 0:b9ac53c439ed 150 *data = h->data;
peccu 0:b9ac53c439ed 151 return len;
peccu 0:b9ac53c439ed 152 }
peccu 0:b9ac53c439ed 153
peccu 0:b9ac53c439ed 154 void Cache::swap_index(int i, int j)
peccu 0:b9ac53c439ed 155 {
peccu 0:b9ac53c439ed 156 if(i==j) return;
peccu 0:b9ac53c439ed 157
peccu 0:b9ac53c439ed 158 if(head[i].len) lru_delete(&head[i]);
peccu 0:b9ac53c439ed 159 if(head[j].len) lru_delete(&head[j]);
peccu 0:b9ac53c439ed 160 swap(head[i].data,head[j].data);
peccu 0:b9ac53c439ed 161 swap(head[i].len,head[j].len);
peccu 0:b9ac53c439ed 162 if(head[i].len) lru_insert(&head[i]);
peccu 0:b9ac53c439ed 163 if(head[j].len) lru_insert(&head[j]);
peccu 0:b9ac53c439ed 164
peccu 0:b9ac53c439ed 165 if(i>j) swap(i,j);
peccu 0:b9ac53c439ed 166 for(head_t *h = lru_head.next; h!=&lru_head; h=h->next)
peccu 0:b9ac53c439ed 167 {
peccu 0:b9ac53c439ed 168 if(h->len > i)
peccu 0:b9ac53c439ed 169 {
peccu 0:b9ac53c439ed 170 if(h->len > j)
peccu 0:b9ac53c439ed 171 swap(h->data[i],h->data[j]);
peccu 0:b9ac53c439ed 172 else
peccu 0:b9ac53c439ed 173 {
peccu 0:b9ac53c439ed 174 // give up
peccu 0:b9ac53c439ed 175 lru_delete(h);
peccu 0:b9ac53c439ed 176 free(h->data);
peccu 0:b9ac53c439ed 177 size += h->len;
peccu 0:b9ac53c439ed 178 h->data = 0;
peccu 0:b9ac53c439ed 179 h->len = 0;
peccu 0:b9ac53c439ed 180 }
peccu 0:b9ac53c439ed 181 }
peccu 0:b9ac53c439ed 182 }
peccu 0:b9ac53c439ed 183 }
peccu 0:b9ac53c439ed 184
peccu 0:b9ac53c439ed 185 //
peccu 0:b9ac53c439ed 186 // Kernel evaluation
peccu 0:b9ac53c439ed 187 //
peccu 0:b9ac53c439ed 188 // the static method k_function is for doing single kernel evaluation
peccu 0:b9ac53c439ed 189 // the constructor of Kernel prepares to calculate the l*l kernel matrix
peccu 0:b9ac53c439ed 190 // the member function get_Q is for getting one column from the Q Matrix
peccu 0:b9ac53c439ed 191 //
peccu 0:b9ac53c439ed 192 class QMatrix {
peccu 0:b9ac53c439ed 193 public:
peccu 0:b9ac53c439ed 194 virtual Qfloat *get_Q(int column, int len) const = 0;
peccu 0:b9ac53c439ed 195 virtual double *get_QD() const = 0;
peccu 0:b9ac53c439ed 196 virtual void swap_index(int i, int j) const = 0;
peccu 0:b9ac53c439ed 197 virtual ~QMatrix() {}
peccu 0:b9ac53c439ed 198 };
peccu 0:b9ac53c439ed 199
peccu 0:b9ac53c439ed 200 class Kernel: public QMatrix {
peccu 0:b9ac53c439ed 201 public:
peccu 0:b9ac53c439ed 202 Kernel(int l, svm_node * const * x, const svm_parameter& param);
peccu 0:b9ac53c439ed 203 virtual ~Kernel();
peccu 0:b9ac53c439ed 204
peccu 0:b9ac53c439ed 205 static double k_function(const svm_node *x, const svm_node *y,
peccu 0:b9ac53c439ed 206 const svm_parameter& param);
peccu 0:b9ac53c439ed 207 virtual Qfloat *get_Q(int column, int len) const = 0;
peccu 0:b9ac53c439ed 208 virtual double *get_QD() const = 0;
peccu 0:b9ac53c439ed 209 virtual void swap_index(int i, int j) const // no so const...
peccu 0:b9ac53c439ed 210 {
peccu 0:b9ac53c439ed 211 swap(x[i],x[j]);
peccu 0:b9ac53c439ed 212 if(x_square) swap(x_square[i],x_square[j]);
peccu 0:b9ac53c439ed 213 }
peccu 0:b9ac53c439ed 214 protected:
peccu 0:b9ac53c439ed 215
peccu 0:b9ac53c439ed 216 double (Kernel::*kernel_function)(int i, int j) const;
peccu 0:b9ac53c439ed 217
peccu 0:b9ac53c439ed 218 private:
peccu 0:b9ac53c439ed 219 const svm_node **x;
peccu 0:b9ac53c439ed 220 double *x_square;
peccu 0:b9ac53c439ed 221
peccu 0:b9ac53c439ed 222 // svm_parameter
peccu 0:b9ac53c439ed 223 const int kernel_type;
peccu 0:b9ac53c439ed 224 const int degree;
peccu 0:b9ac53c439ed 225 const double gamma;
peccu 0:b9ac53c439ed 226 const double coef0;
peccu 0:b9ac53c439ed 227
peccu 0:b9ac53c439ed 228 static double dot(const svm_node *px, const svm_node *py);
peccu 0:b9ac53c439ed 229 double kernel_linear(int i, int j) const
peccu 0:b9ac53c439ed 230 {
peccu 0:b9ac53c439ed 231 return dot(x[i],x[j]);
peccu 0:b9ac53c439ed 232 }
peccu 0:b9ac53c439ed 233 double kernel_poly(int i, int j) const
peccu 0:b9ac53c439ed 234 {
peccu 0:b9ac53c439ed 235 return powi(gamma*dot(x[i],x[j])+coef0,degree);
peccu 0:b9ac53c439ed 236 }
peccu 0:b9ac53c439ed 237 double kernel_rbf(int i, int j) const
peccu 0:b9ac53c439ed 238 {
peccu 0:b9ac53c439ed 239 return exp(-gamma*(x_square[i]+x_square[j]-2*dot(x[i],x[j])));
peccu 0:b9ac53c439ed 240 }
peccu 0:b9ac53c439ed 241 double kernel_sigmoid(int i, int j) const
peccu 0:b9ac53c439ed 242 {
peccu 0:b9ac53c439ed 243 return tanh(gamma*dot(x[i],x[j])+coef0);
peccu 0:b9ac53c439ed 244 }
peccu 0:b9ac53c439ed 245 double kernel_precomputed(int i, int j) const
peccu 0:b9ac53c439ed 246 {
peccu 0:b9ac53c439ed 247 return x[i][(int)(x[j][0].value)].value;
peccu 0:b9ac53c439ed 248 }
peccu 0:b9ac53c439ed 249 };
peccu 0:b9ac53c439ed 250
peccu 0:b9ac53c439ed 251 Kernel::Kernel(int l, svm_node * const * x_, const svm_parameter& param)
peccu 0:b9ac53c439ed 252 :kernel_type(param.kernel_type), degree(param.degree),
peccu 0:b9ac53c439ed 253 gamma(param.gamma), coef0(param.coef0)
peccu 0:b9ac53c439ed 254 {
peccu 0:b9ac53c439ed 255 switch(kernel_type)
peccu 0:b9ac53c439ed 256 {
peccu 0:b9ac53c439ed 257 case LINEAR:
peccu 0:b9ac53c439ed 258 kernel_function = &Kernel::kernel_linear;
peccu 0:b9ac53c439ed 259 break;
peccu 0:b9ac53c439ed 260 case POLY:
peccu 0:b9ac53c439ed 261 kernel_function = &Kernel::kernel_poly;
peccu 0:b9ac53c439ed 262 break;
peccu 0:b9ac53c439ed 263 case RBF:
peccu 0:b9ac53c439ed 264 kernel_function = &Kernel::kernel_rbf;
peccu 0:b9ac53c439ed 265 break;
peccu 0:b9ac53c439ed 266 case SIGMOID:
peccu 0:b9ac53c439ed 267 kernel_function = &Kernel::kernel_sigmoid;
peccu 0:b9ac53c439ed 268 break;
peccu 0:b9ac53c439ed 269 case PRECOMPUTED:
peccu 0:b9ac53c439ed 270 kernel_function = &Kernel::kernel_precomputed;
peccu 0:b9ac53c439ed 271 break;
peccu 0:b9ac53c439ed 272 }
peccu 0:b9ac53c439ed 273
peccu 0:b9ac53c439ed 274 clone(x,x_,l);
peccu 0:b9ac53c439ed 275
peccu 0:b9ac53c439ed 276 if(kernel_type == RBF)
peccu 0:b9ac53c439ed 277 {
peccu 0:b9ac53c439ed 278 x_square = new double[l];
peccu 0:b9ac53c439ed 279 for(int i=0;i<l;i++)
peccu 0:b9ac53c439ed 280 x_square[i] = dot(x[i],x[i]);
peccu 0:b9ac53c439ed 281 }
peccu 0:b9ac53c439ed 282 else
peccu 0:b9ac53c439ed 283 x_square = 0;
peccu 0:b9ac53c439ed 284 }
peccu 0:b9ac53c439ed 285
peccu 0:b9ac53c439ed 286 Kernel::~Kernel()
peccu 0:b9ac53c439ed 287 {
peccu 0:b9ac53c439ed 288 delete[] x;
peccu 0:b9ac53c439ed 289 delete[] x_square;
peccu 0:b9ac53c439ed 290 }
peccu 0:b9ac53c439ed 291
peccu 0:b9ac53c439ed 292 double Kernel::dot(const svm_node *px, const svm_node *py)
peccu 0:b9ac53c439ed 293 {
peccu 0:b9ac53c439ed 294 double sum = 0;
peccu 0:b9ac53c439ed 295 while(px->index != -1 && py->index != -1)
peccu 0:b9ac53c439ed 296 {
peccu 0:b9ac53c439ed 297 if(px->index == py->index)
peccu 0:b9ac53c439ed 298 {
peccu 0:b9ac53c439ed 299 sum += px->value * py->value;
peccu 0:b9ac53c439ed 300 ++px;
peccu 0:b9ac53c439ed 301 ++py;
peccu 0:b9ac53c439ed 302 }
peccu 0:b9ac53c439ed 303 else
peccu 0:b9ac53c439ed 304 {
peccu 0:b9ac53c439ed 305 if(px->index > py->index)
peccu 0:b9ac53c439ed 306 ++py;
peccu 0:b9ac53c439ed 307 else
peccu 0:b9ac53c439ed 308 ++px;
peccu 0:b9ac53c439ed 309 }
peccu 0:b9ac53c439ed 310 }
peccu 0:b9ac53c439ed 311 return sum;
peccu 0:b9ac53c439ed 312 }
peccu 0:b9ac53c439ed 313
peccu 0:b9ac53c439ed 314 double Kernel::k_function(const svm_node *x, const svm_node *y,
peccu 0:b9ac53c439ed 315 const svm_parameter& param)
peccu 0:b9ac53c439ed 316 {
peccu 0:b9ac53c439ed 317 switch(param.kernel_type)
peccu 0:b9ac53c439ed 318 {
peccu 0:b9ac53c439ed 319 case LINEAR:
peccu 0:b9ac53c439ed 320 return dot(x,y);
peccu 0:b9ac53c439ed 321 case POLY:
peccu 0:b9ac53c439ed 322 return powi(param.gamma*dot(x,y)+param.coef0,param.degree);
peccu 0:b9ac53c439ed 323 case RBF:
peccu 0:b9ac53c439ed 324 {
peccu 0:b9ac53c439ed 325 double sum = 0;
peccu 0:b9ac53c439ed 326 while(x->index != -1 && y->index !=-1)
peccu 0:b9ac53c439ed 327 {
peccu 0:b9ac53c439ed 328 if(x->index == y->index)
peccu 0:b9ac53c439ed 329 {
peccu 0:b9ac53c439ed 330 double d = x->value - y->value;
peccu 0:b9ac53c439ed 331 sum += d*d;
peccu 0:b9ac53c439ed 332 ++x;
peccu 0:b9ac53c439ed 333 ++y;
peccu 0:b9ac53c439ed 334 }
peccu 0:b9ac53c439ed 335 else
peccu 0:b9ac53c439ed 336 {
peccu 0:b9ac53c439ed 337 if(x->index > y->index)
peccu 0:b9ac53c439ed 338 {
peccu 0:b9ac53c439ed 339 sum += y->value * y->value;
peccu 0:b9ac53c439ed 340 ++y;
peccu 0:b9ac53c439ed 341 }
peccu 0:b9ac53c439ed 342 else
peccu 0:b9ac53c439ed 343 {
peccu 0:b9ac53c439ed 344 sum += x->value * x->value;
peccu 0:b9ac53c439ed 345 ++x;
peccu 0:b9ac53c439ed 346 }
peccu 0:b9ac53c439ed 347 }
peccu 0:b9ac53c439ed 348 }
peccu 0:b9ac53c439ed 349
peccu 0:b9ac53c439ed 350 while(x->index != -1)
peccu 0:b9ac53c439ed 351 {
peccu 0:b9ac53c439ed 352 sum += x->value * x->value;
peccu 0:b9ac53c439ed 353 ++x;
peccu 0:b9ac53c439ed 354 }
peccu 0:b9ac53c439ed 355
peccu 0:b9ac53c439ed 356 while(y->index != -1)
peccu 0:b9ac53c439ed 357 {
peccu 0:b9ac53c439ed 358 sum += y->value * y->value;
peccu 0:b9ac53c439ed 359 ++y;
peccu 0:b9ac53c439ed 360 }
peccu 0:b9ac53c439ed 361
peccu 0:b9ac53c439ed 362 return exp(-param.gamma*sum);
peccu 0:b9ac53c439ed 363 }
peccu 0:b9ac53c439ed 364 case SIGMOID:
peccu 0:b9ac53c439ed 365 return tanh(param.gamma*dot(x,y)+param.coef0);
peccu 0:b9ac53c439ed 366 case PRECOMPUTED: //x: test (validation), y: SV
peccu 0:b9ac53c439ed 367 return x[(int)(y->value)].value;
peccu 0:b9ac53c439ed 368 default:
peccu 0:b9ac53c439ed 369 return 0; // Unreachable
peccu 0:b9ac53c439ed 370 }
peccu 0:b9ac53c439ed 371 }
peccu 0:b9ac53c439ed 372
peccu 0:b9ac53c439ed 373 // An SMO algorithm in Fan et al., JMLR 6(2005), p. 1889--1918
peccu 0:b9ac53c439ed 374 // Solves:
peccu 0:b9ac53c439ed 375 //
peccu 0:b9ac53c439ed 376 // min 0.5(\alpha^T Q \alpha) + p^T \alpha
peccu 0:b9ac53c439ed 377 //
peccu 0:b9ac53c439ed 378 // subject to y^T \alpha = \delta
peccu 0:b9ac53c439ed 379 // y_i = +1 or -1
peccu 0:b9ac53c439ed 380 // 0 <= alpha_i <= Cp for y_i = 1
peccu 0:b9ac53c439ed 381 // 0 <= alpha_i <= Cn for y_i = -1
peccu 0:b9ac53c439ed 382 //
peccu 0:b9ac53c439ed 383 // Given:
peccu 0:b9ac53c439ed 384 //
peccu 0:b9ac53c439ed 385 // Q, p, y, Cp, Cn, and an initial feasible point \alpha
peccu 0:b9ac53c439ed 386 // l is the size of vectors and matrices
peccu 0:b9ac53c439ed 387 // eps is the stopping tolerance
peccu 0:b9ac53c439ed 388 //
peccu 0:b9ac53c439ed 389 // solution will be put in \alpha, objective value will be put in obj
peccu 0:b9ac53c439ed 390 //
peccu 0:b9ac53c439ed 391 class Solver {
peccu 0:b9ac53c439ed 392 public:
peccu 0:b9ac53c439ed 393 Solver() {};
peccu 0:b9ac53c439ed 394 virtual ~Solver() {};
peccu 0:b9ac53c439ed 395
peccu 0:b9ac53c439ed 396 struct SolutionInfo {
peccu 0:b9ac53c439ed 397 double obj;
peccu 0:b9ac53c439ed 398 double rho;
peccu 0:b9ac53c439ed 399 double upper_bound_p;
peccu 0:b9ac53c439ed 400 double upper_bound_n;
peccu 0:b9ac53c439ed 401 double r; // for Solver_NU
peccu 0:b9ac53c439ed 402 };
peccu 0:b9ac53c439ed 403
peccu 0:b9ac53c439ed 404 void Solve(int l, const QMatrix& Q, const double *p_, const schar *y_,
peccu 0:b9ac53c439ed 405 double *alpha_, double Cp, double Cn, double eps,
peccu 0:b9ac53c439ed 406 SolutionInfo* si, int shrinking);
peccu 0:b9ac53c439ed 407 protected:
peccu 0:b9ac53c439ed 408 int active_size;
peccu 0:b9ac53c439ed 409 schar *y;
peccu 0:b9ac53c439ed 410 double *G; // gradient of objective function
peccu 0:b9ac53c439ed 411 enum { LOWER_BOUND, UPPER_BOUND, FREE };
peccu 0:b9ac53c439ed 412 char *alpha_status; // LOWER_BOUND, UPPER_BOUND, FREE
peccu 0:b9ac53c439ed 413 double *alpha;
peccu 0:b9ac53c439ed 414 const QMatrix *Q;
peccu 0:b9ac53c439ed 415 const double *QD;
peccu 0:b9ac53c439ed 416 double eps;
peccu 0:b9ac53c439ed 417 double Cp,Cn;
peccu 0:b9ac53c439ed 418 double *p;
peccu 0:b9ac53c439ed 419 int *active_set;
peccu 0:b9ac53c439ed 420 double *G_bar; // gradient, if we treat free variables as 0
peccu 0:b9ac53c439ed 421 int l;
peccu 0:b9ac53c439ed 422 bool unshrink; // XXX
peccu 0:b9ac53c439ed 423
peccu 0:b9ac53c439ed 424 double get_C(int i)
peccu 0:b9ac53c439ed 425 {
peccu 0:b9ac53c439ed 426 return (y[i] > 0)? Cp : Cn;
peccu 0:b9ac53c439ed 427 }
peccu 0:b9ac53c439ed 428 void update_alpha_status(int i)
peccu 0:b9ac53c439ed 429 {
peccu 0:b9ac53c439ed 430 if(alpha[i] >= get_C(i))
peccu 0:b9ac53c439ed 431 alpha_status[i] = UPPER_BOUND;
peccu 0:b9ac53c439ed 432 else if(alpha[i] <= 0)
peccu 0:b9ac53c439ed 433 alpha_status[i] = LOWER_BOUND;
peccu 0:b9ac53c439ed 434 else alpha_status[i] = FREE;
peccu 0:b9ac53c439ed 435 }
peccu 0:b9ac53c439ed 436 bool is_upper_bound(int i) { return alpha_status[i] == UPPER_BOUND; }
peccu 0:b9ac53c439ed 437 bool is_lower_bound(int i) { return alpha_status[i] == LOWER_BOUND; }
peccu 0:b9ac53c439ed 438 bool is_free(int i) { return alpha_status[i] == FREE; }
peccu 0:b9ac53c439ed 439 void swap_index(int i, int j);
peccu 0:b9ac53c439ed 440 void reconstruct_gradient();
peccu 0:b9ac53c439ed 441 virtual int select_working_set(int &i, int &j);
peccu 0:b9ac53c439ed 442 virtual double calculate_rho();
peccu 0:b9ac53c439ed 443 virtual void do_shrinking();
peccu 0:b9ac53c439ed 444 private:
peccu 0:b9ac53c439ed 445 bool be_shrunk(int i, double Gmax1, double Gmax2);
peccu 0:b9ac53c439ed 446 };
peccu 0:b9ac53c439ed 447
peccu 0:b9ac53c439ed 448 void Solver::swap_index(int i, int j)
peccu 0:b9ac53c439ed 449 {
peccu 0:b9ac53c439ed 450 Q->swap_index(i,j);
peccu 0:b9ac53c439ed 451 swap(y[i],y[j]);
peccu 0:b9ac53c439ed 452 swap(G[i],G[j]);
peccu 0:b9ac53c439ed 453 swap(alpha_status[i],alpha_status[j]);
peccu 0:b9ac53c439ed 454 swap(alpha[i],alpha[j]);
peccu 0:b9ac53c439ed 455 swap(p[i],p[j]);
peccu 0:b9ac53c439ed 456 swap(active_set[i],active_set[j]);
peccu 0:b9ac53c439ed 457 swap(G_bar[i],G_bar[j]);
peccu 0:b9ac53c439ed 458 }
peccu 0:b9ac53c439ed 459
peccu 0:b9ac53c439ed 460 void Solver::reconstruct_gradient()
peccu 0:b9ac53c439ed 461 {
peccu 0:b9ac53c439ed 462 // reconstruct inactive elements of G from G_bar and free variables
peccu 0:b9ac53c439ed 463
peccu 0:b9ac53c439ed 464 if(active_size == l) return;
peccu 0:b9ac53c439ed 465
peccu 0:b9ac53c439ed 466 int i,j;
peccu 0:b9ac53c439ed 467 int nr_free = 0;
peccu 0:b9ac53c439ed 468
peccu 0:b9ac53c439ed 469 for(j=active_size;j<l;j++)
peccu 0:b9ac53c439ed 470 G[j] = G_bar[j] + p[j];
peccu 0:b9ac53c439ed 471
peccu 0:b9ac53c439ed 472 for(j=0;j<active_size;j++)
peccu 0:b9ac53c439ed 473 if(is_free(j))
peccu 0:b9ac53c439ed 474 nr_free++;
peccu 0:b9ac53c439ed 475
peccu 0:b9ac53c439ed 476 if(2*nr_free < active_size)
peccu 0:b9ac53c439ed 477 info("\nWarning: using -h 0 may be faster\n");
peccu 0:b9ac53c439ed 478
peccu 0:b9ac53c439ed 479 if (nr_free*l > 2*active_size*(l-active_size))
peccu 0:b9ac53c439ed 480 {
peccu 0:b9ac53c439ed 481 for(i=active_size;i<l;i++)
peccu 0:b9ac53c439ed 482 {
peccu 0:b9ac53c439ed 483 const Qfloat *Q_i = Q->get_Q(i,active_size);
peccu 0:b9ac53c439ed 484 for(j=0;j<active_size;j++)
peccu 0:b9ac53c439ed 485 if(is_free(j))
peccu 0:b9ac53c439ed 486 G[i] += alpha[j] * Q_i[j];
peccu 0:b9ac53c439ed 487 }
peccu 0:b9ac53c439ed 488 }
peccu 0:b9ac53c439ed 489 else
peccu 0:b9ac53c439ed 490 {
peccu 0:b9ac53c439ed 491 for(i=0;i<active_size;i++)
peccu 0:b9ac53c439ed 492 if(is_free(i))
peccu 0:b9ac53c439ed 493 {
peccu 0:b9ac53c439ed 494 const Qfloat *Q_i = Q->get_Q(i,l);
peccu 0:b9ac53c439ed 495 double alpha_i = alpha[i];
peccu 0:b9ac53c439ed 496 for(j=active_size;j<l;j++)
peccu 0:b9ac53c439ed 497 G[j] += alpha_i * Q_i[j];
peccu 0:b9ac53c439ed 498 }
peccu 0:b9ac53c439ed 499 }
peccu 0:b9ac53c439ed 500 }
peccu 0:b9ac53c439ed 501
peccu 0:b9ac53c439ed 502 void Solver::Solve(int l, const QMatrix& Q, const double *p_, const schar *y_,
peccu 0:b9ac53c439ed 503 double *alpha_, double Cp, double Cn, double eps,
peccu 0:b9ac53c439ed 504 SolutionInfo* si, int shrinking)
peccu 0:b9ac53c439ed 505 {
peccu 0:b9ac53c439ed 506 this->l = l;
peccu 0:b9ac53c439ed 507 this->Q = &Q;
peccu 0:b9ac53c439ed 508 QD=Q.get_QD();
peccu 0:b9ac53c439ed 509 clone(p, p_,l);
peccu 0:b9ac53c439ed 510 clone(y, y_,l);
peccu 0:b9ac53c439ed 511 clone(alpha,alpha_,l);
peccu 0:b9ac53c439ed 512 this->Cp = Cp;
peccu 0:b9ac53c439ed 513 this->Cn = Cn;
peccu 0:b9ac53c439ed 514 this->eps = eps;
peccu 0:b9ac53c439ed 515 unshrink = false;
peccu 0:b9ac53c439ed 516
peccu 0:b9ac53c439ed 517 // initialize alpha_status
peccu 0:b9ac53c439ed 518 {
peccu 0:b9ac53c439ed 519 alpha_status = new char[l];
peccu 0:b9ac53c439ed 520 for(int i=0;i<l;i++)
peccu 0:b9ac53c439ed 521 update_alpha_status(i);
peccu 0:b9ac53c439ed 522 }
peccu 0:b9ac53c439ed 523
peccu 0:b9ac53c439ed 524 // initialize active set (for shrinking)
peccu 0:b9ac53c439ed 525 {
peccu 0:b9ac53c439ed 526 active_set = new int[l];
peccu 0:b9ac53c439ed 527 for(int i=0;i<l;i++)
peccu 0:b9ac53c439ed 528 active_set[i] = i;
peccu 0:b9ac53c439ed 529 active_size = l;
peccu 0:b9ac53c439ed 530 }
peccu 0:b9ac53c439ed 531
peccu 0:b9ac53c439ed 532 // initialize gradient
peccu 0:b9ac53c439ed 533 {
peccu 0:b9ac53c439ed 534 G = new double[l];
peccu 0:b9ac53c439ed 535 G_bar = new double[l];
peccu 0:b9ac53c439ed 536 int i;
peccu 0:b9ac53c439ed 537 for(i=0;i<l;i++)
peccu 0:b9ac53c439ed 538 {
peccu 0:b9ac53c439ed 539 G[i] = p[i];
peccu 0:b9ac53c439ed 540 G_bar[i] = 0;
peccu 0:b9ac53c439ed 541 }
peccu 0:b9ac53c439ed 542 for(i=0;i<l;i++)
peccu 0:b9ac53c439ed 543 if(!is_lower_bound(i))
peccu 0:b9ac53c439ed 544 {
peccu 0:b9ac53c439ed 545 const Qfloat *Q_i = Q.get_Q(i,l);
peccu 0:b9ac53c439ed 546 double alpha_i = alpha[i];
peccu 0:b9ac53c439ed 547 int j;
peccu 0:b9ac53c439ed 548 for(j=0;j<l;j++)
peccu 0:b9ac53c439ed 549 G[j] += alpha_i*Q_i[j];
peccu 0:b9ac53c439ed 550 if(is_upper_bound(i))
peccu 0:b9ac53c439ed 551 for(j=0;j<l;j++)
peccu 0:b9ac53c439ed 552 G_bar[j] += get_C(i) * Q_i[j];
peccu 0:b9ac53c439ed 553 }
peccu 0:b9ac53c439ed 554 }
peccu 0:b9ac53c439ed 555
peccu 0:b9ac53c439ed 556 // optimization step
peccu 0:b9ac53c439ed 557
peccu 0:b9ac53c439ed 558 int iter = 0;
peccu 0:b9ac53c439ed 559 int counter = min(l,1000)+1;
peccu 0:b9ac53c439ed 560
peccu 0:b9ac53c439ed 561 while(1)
peccu 0:b9ac53c439ed 562 {
peccu 0:b9ac53c439ed 563 // show progress and do shrinking
peccu 0:b9ac53c439ed 564
peccu 0:b9ac53c439ed 565 if(--counter == 0)
peccu 0:b9ac53c439ed 566 {
peccu 0:b9ac53c439ed 567 counter = min(l,1000);
peccu 0:b9ac53c439ed 568 if(shrinking) do_shrinking();
peccu 0:b9ac53c439ed 569 info(".");
peccu 0:b9ac53c439ed 570 }
peccu 0:b9ac53c439ed 571
peccu 0:b9ac53c439ed 572 int i,j;
peccu 0:b9ac53c439ed 573 if(select_working_set(i,j)!=0)
peccu 0:b9ac53c439ed 574 {
peccu 0:b9ac53c439ed 575 // reconstruct the whole gradient
peccu 0:b9ac53c439ed 576 reconstruct_gradient();
peccu 0:b9ac53c439ed 577 // reset active set size and check
peccu 0:b9ac53c439ed 578 active_size = l;
peccu 0:b9ac53c439ed 579 info("*");
peccu 0:b9ac53c439ed 580 if(select_working_set(i,j)!=0)
peccu 0:b9ac53c439ed 581 break;
peccu 0:b9ac53c439ed 582 else
peccu 0:b9ac53c439ed 583 counter = 1; // do shrinking next iteration
peccu 0:b9ac53c439ed 584 }
peccu 0:b9ac53c439ed 585
peccu 0:b9ac53c439ed 586 ++iter;
peccu 0:b9ac53c439ed 587
peccu 0:b9ac53c439ed 588 // update alpha[i] and alpha[j], handle bounds carefully
peccu 0:b9ac53c439ed 589
peccu 0:b9ac53c439ed 590 const Qfloat *Q_i = Q.get_Q(i,active_size);
peccu 0:b9ac53c439ed 591 const Qfloat *Q_j = Q.get_Q(j,active_size);
peccu 0:b9ac53c439ed 592
peccu 0:b9ac53c439ed 593 double C_i = get_C(i);
peccu 0:b9ac53c439ed 594 double C_j = get_C(j);
peccu 0:b9ac53c439ed 595
peccu 0:b9ac53c439ed 596 double old_alpha_i = alpha[i];
peccu 0:b9ac53c439ed 597 double old_alpha_j = alpha[j];
peccu 0:b9ac53c439ed 598
peccu 0:b9ac53c439ed 599 if(y[i]!=y[j])
peccu 0:b9ac53c439ed 600 {
peccu 0:b9ac53c439ed 601 double quad_coef = QD[i]+QD[j]+2*Q_i[j];
peccu 0:b9ac53c439ed 602 if (quad_coef <= 0)
peccu 0:b9ac53c439ed 603 quad_coef = TAU;
peccu 0:b9ac53c439ed 604 double delta = (-G[i]-G[j])/quad_coef;
peccu 0:b9ac53c439ed 605 double diff = alpha[i] - alpha[j];
peccu 0:b9ac53c439ed 606 alpha[i] += delta;
peccu 0:b9ac53c439ed 607 alpha[j] += delta;
peccu 0:b9ac53c439ed 608
peccu 0:b9ac53c439ed 609 if(diff > 0)
peccu 0:b9ac53c439ed 610 {
peccu 0:b9ac53c439ed 611 if(alpha[j] < 0)
peccu 0:b9ac53c439ed 612 {
peccu 0:b9ac53c439ed 613 alpha[j] = 0;
peccu 0:b9ac53c439ed 614 alpha[i] = diff;
peccu 0:b9ac53c439ed 615 }
peccu 0:b9ac53c439ed 616 }
peccu 0:b9ac53c439ed 617 else
peccu 0:b9ac53c439ed 618 {
peccu 0:b9ac53c439ed 619 if(alpha[i] < 0)
peccu 0:b9ac53c439ed 620 {
peccu 0:b9ac53c439ed 621 alpha[i] = 0;
peccu 0:b9ac53c439ed 622 alpha[j] = -diff;
peccu 0:b9ac53c439ed 623 }
peccu 0:b9ac53c439ed 624 }
peccu 0:b9ac53c439ed 625 if(diff > C_i - C_j)
peccu 0:b9ac53c439ed 626 {
peccu 0:b9ac53c439ed 627 if(alpha[i] > C_i)
peccu 0:b9ac53c439ed 628 {
peccu 0:b9ac53c439ed 629 alpha[i] = C_i;
peccu 0:b9ac53c439ed 630 alpha[j] = C_i - diff;
peccu 0:b9ac53c439ed 631 }
peccu 0:b9ac53c439ed 632 }
peccu 0:b9ac53c439ed 633 else
peccu 0:b9ac53c439ed 634 {
peccu 0:b9ac53c439ed 635 if(alpha[j] > C_j)
peccu 0:b9ac53c439ed 636 {
peccu 0:b9ac53c439ed 637 alpha[j] = C_j;
peccu 0:b9ac53c439ed 638 alpha[i] = C_j + diff;
peccu 0:b9ac53c439ed 639 }
peccu 0:b9ac53c439ed 640 }
peccu 0:b9ac53c439ed 641 }
peccu 0:b9ac53c439ed 642 else
peccu 0:b9ac53c439ed 643 {
peccu 0:b9ac53c439ed 644 double quad_coef = QD[i]+QD[j]-2*Q_i[j];
peccu 0:b9ac53c439ed 645 if (quad_coef <= 0)
peccu 0:b9ac53c439ed 646 quad_coef = TAU;
peccu 0:b9ac53c439ed 647 double delta = (G[i]-G[j])/quad_coef;
peccu 0:b9ac53c439ed 648 double sum = alpha[i] + alpha[j];
peccu 0:b9ac53c439ed 649 alpha[i] -= delta;
peccu 0:b9ac53c439ed 650 alpha[j] += delta;
peccu 0:b9ac53c439ed 651
peccu 0:b9ac53c439ed 652 if(sum > C_i)
peccu 0:b9ac53c439ed 653 {
peccu 0:b9ac53c439ed 654 if(alpha[i] > C_i)
peccu 0:b9ac53c439ed 655 {
peccu 0:b9ac53c439ed 656 alpha[i] = C_i;
peccu 0:b9ac53c439ed 657 alpha[j] = sum - C_i;
peccu 0:b9ac53c439ed 658 }
peccu 0:b9ac53c439ed 659 }
peccu 0:b9ac53c439ed 660 else
peccu 0:b9ac53c439ed 661 {
peccu 0:b9ac53c439ed 662 if(alpha[j] < 0)
peccu 0:b9ac53c439ed 663 {
peccu 0:b9ac53c439ed 664 alpha[j] = 0;
peccu 0:b9ac53c439ed 665 alpha[i] = sum;
peccu 0:b9ac53c439ed 666 }
peccu 0:b9ac53c439ed 667 }
peccu 0:b9ac53c439ed 668 if(sum > C_j)
peccu 0:b9ac53c439ed 669 {
peccu 0:b9ac53c439ed 670 if(alpha[j] > C_j)
peccu 0:b9ac53c439ed 671 {
peccu 0:b9ac53c439ed 672 alpha[j] = C_j;
peccu 0:b9ac53c439ed 673 alpha[i] = sum - C_j;
peccu 0:b9ac53c439ed 674 }
peccu 0:b9ac53c439ed 675 }
peccu 0:b9ac53c439ed 676 else
peccu 0:b9ac53c439ed 677 {
peccu 0:b9ac53c439ed 678 if(alpha[i] < 0)
peccu 0:b9ac53c439ed 679 {
peccu 0:b9ac53c439ed 680 alpha[i] = 0;
peccu 0:b9ac53c439ed 681 alpha[j] = sum;
peccu 0:b9ac53c439ed 682 }
peccu 0:b9ac53c439ed 683 }
peccu 0:b9ac53c439ed 684 }
peccu 0:b9ac53c439ed 685
peccu 0:b9ac53c439ed 686 // update G
peccu 0:b9ac53c439ed 687
peccu 0:b9ac53c439ed 688 double delta_alpha_i = alpha[i] - old_alpha_i;
peccu 0:b9ac53c439ed 689 double delta_alpha_j = alpha[j] - old_alpha_j;
peccu 0:b9ac53c439ed 690
peccu 0:b9ac53c439ed 691 for(int k=0;k<active_size;k++)
peccu 0:b9ac53c439ed 692 {
peccu 0:b9ac53c439ed 693 G[k] += Q_i[k]*delta_alpha_i + Q_j[k]*delta_alpha_j;
peccu 0:b9ac53c439ed 694 }
peccu 0:b9ac53c439ed 695
peccu 0:b9ac53c439ed 696 // update alpha_status and G_bar
peccu 0:b9ac53c439ed 697
peccu 0:b9ac53c439ed 698 {
peccu 0:b9ac53c439ed 699 bool ui = is_upper_bound(i);
peccu 0:b9ac53c439ed 700 bool uj = is_upper_bound(j);
peccu 0:b9ac53c439ed 701 update_alpha_status(i);
peccu 0:b9ac53c439ed 702 update_alpha_status(j);
peccu 0:b9ac53c439ed 703 int k;
peccu 0:b9ac53c439ed 704 if(ui != is_upper_bound(i))
peccu 0:b9ac53c439ed 705 {
peccu 0:b9ac53c439ed 706 Q_i = Q.get_Q(i,l);
peccu 0:b9ac53c439ed 707 if(ui)
peccu 0:b9ac53c439ed 708 for(k=0;k<l;k++)
peccu 0:b9ac53c439ed 709 G_bar[k] -= C_i * Q_i[k];
peccu 0:b9ac53c439ed 710 else
peccu 0:b9ac53c439ed 711 for(k=0;k<l;k++)
peccu 0:b9ac53c439ed 712 G_bar[k] += C_i * Q_i[k];
peccu 0:b9ac53c439ed 713 }
peccu 0:b9ac53c439ed 714
peccu 0:b9ac53c439ed 715 if(uj != is_upper_bound(j))
peccu 0:b9ac53c439ed 716 {
peccu 0:b9ac53c439ed 717 Q_j = Q.get_Q(j,l);
peccu 0:b9ac53c439ed 718 if(uj)
peccu 0:b9ac53c439ed 719 for(k=0;k<l;k++)
peccu 0:b9ac53c439ed 720 G_bar[k] -= C_j * Q_j[k];
peccu 0:b9ac53c439ed 721 else
peccu 0:b9ac53c439ed 722 for(k=0;k<l;k++)
peccu 0:b9ac53c439ed 723 G_bar[k] += C_j * Q_j[k];
peccu 0:b9ac53c439ed 724 }
peccu 0:b9ac53c439ed 725 }
peccu 0:b9ac53c439ed 726 }
peccu 0:b9ac53c439ed 727
peccu 0:b9ac53c439ed 728 // calculate rho
peccu 0:b9ac53c439ed 729
peccu 0:b9ac53c439ed 730 si->rho = calculate_rho();
peccu 0:b9ac53c439ed 731
peccu 0:b9ac53c439ed 732 // calculate objective value
peccu 0:b9ac53c439ed 733 {
peccu 0:b9ac53c439ed 734 double v = 0;
peccu 0:b9ac53c439ed 735 int i;
peccu 0:b9ac53c439ed 736 for(i=0;i<l;i++)
peccu 0:b9ac53c439ed 737 v += alpha[i] * (G[i] + p[i]);
peccu 0:b9ac53c439ed 738
peccu 0:b9ac53c439ed 739 si->obj = v/2;
peccu 0:b9ac53c439ed 740 }
peccu 0:b9ac53c439ed 741
peccu 0:b9ac53c439ed 742 // put back the solution
peccu 0:b9ac53c439ed 743 {
peccu 0:b9ac53c439ed 744 for(int i=0;i<l;i++)
peccu 0:b9ac53c439ed 745 alpha_[active_set[i]] = alpha[i];
peccu 0:b9ac53c439ed 746 }
peccu 0:b9ac53c439ed 747
peccu 0:b9ac53c439ed 748 // juggle everything back
peccu 0:b9ac53c439ed 749 /*{
peccu 0:b9ac53c439ed 750 for(int i=0;i<l;i++)
peccu 0:b9ac53c439ed 751 while(active_set[i] != i)
peccu 0:b9ac53c439ed 752 swap_index(i,active_set[i]);
peccu 0:b9ac53c439ed 753 // or Q.swap_index(i,active_set[i]);
peccu 0:b9ac53c439ed 754 }*/
peccu 0:b9ac53c439ed 755
peccu 0:b9ac53c439ed 756 si->upper_bound_p = Cp;
peccu 0:b9ac53c439ed 757 si->upper_bound_n = Cn;
peccu 0:b9ac53c439ed 758
peccu 0:b9ac53c439ed 759 info("\noptimization finished, #iter = %d\n",iter);
peccu 0:b9ac53c439ed 760
peccu 0:b9ac53c439ed 761 delete[] p;
peccu 0:b9ac53c439ed 762 delete[] y;
peccu 0:b9ac53c439ed 763 delete[] alpha;
peccu 0:b9ac53c439ed 764 delete[] alpha_status;
peccu 0:b9ac53c439ed 765 delete[] active_set;
peccu 0:b9ac53c439ed 766 delete[] G;
peccu 0:b9ac53c439ed 767 delete[] G_bar;
peccu 0:b9ac53c439ed 768 }
peccu 0:b9ac53c439ed 769
peccu 0:b9ac53c439ed 770 // return 1 if already optimal, return 0 otherwise
peccu 0:b9ac53c439ed 771 int Solver::select_working_set(int &out_i, int &out_j)
peccu 0:b9ac53c439ed 772 {
peccu 0:b9ac53c439ed 773 // return i,j such that
peccu 0:b9ac53c439ed 774 // i: maximizes -y_i * grad(f)_i, i in I_up(\alpha)
peccu 0:b9ac53c439ed 775 // j: minimizes the decrease of obj value
peccu 0:b9ac53c439ed 776 // (if quadratic coefficeint <= 0, replace it with tau)
peccu 0:b9ac53c439ed 777 // -y_j*grad(f)_j < -y_i*grad(f)_i, j in I_low(\alpha)
peccu 0:b9ac53c439ed 778
peccu 0:b9ac53c439ed 779 double Gmax = -INF;
peccu 0:b9ac53c439ed 780 double Gmax2 = -INF;
peccu 0:b9ac53c439ed 781 int Gmax_idx = -1;
peccu 0:b9ac53c439ed 782 int Gmin_idx = -1;
peccu 0:b9ac53c439ed 783 double obj_diff_min = INF;
peccu 0:b9ac53c439ed 784
peccu 0:b9ac53c439ed 785 for(int t=0;t<active_size;t++)
peccu 0:b9ac53c439ed 786 if(y[t]==+1)
peccu 0:b9ac53c439ed 787 {
peccu 0:b9ac53c439ed 788 if(!is_upper_bound(t))
peccu 0:b9ac53c439ed 789 if(-G[t] >= Gmax)
peccu 0:b9ac53c439ed 790 {
peccu 0:b9ac53c439ed 791 Gmax = -G[t];
peccu 0:b9ac53c439ed 792 Gmax_idx = t;
peccu 0:b9ac53c439ed 793 }
peccu 0:b9ac53c439ed 794 }
peccu 0:b9ac53c439ed 795 else
peccu 0:b9ac53c439ed 796 {
peccu 0:b9ac53c439ed 797 if(!is_lower_bound(t))
peccu 0:b9ac53c439ed 798 if(G[t] >= Gmax)
peccu 0:b9ac53c439ed 799 {
peccu 0:b9ac53c439ed 800 Gmax = G[t];
peccu 0:b9ac53c439ed 801 Gmax_idx = t;
peccu 0:b9ac53c439ed 802 }
peccu 0:b9ac53c439ed 803 }
peccu 0:b9ac53c439ed 804
peccu 0:b9ac53c439ed 805 int i = Gmax_idx;
peccu 0:b9ac53c439ed 806 const Qfloat *Q_i = NULL;
peccu 0:b9ac53c439ed 807 if(i != -1) // NULL Q_i not accessed: Gmax=-INF if i=-1
peccu 0:b9ac53c439ed 808 Q_i = Q->get_Q(i,active_size);
peccu 0:b9ac53c439ed 809
peccu 0:b9ac53c439ed 810 for(int j=0;j<active_size;j++)
peccu 0:b9ac53c439ed 811 {
peccu 0:b9ac53c439ed 812 if(y[j]==+1)
peccu 0:b9ac53c439ed 813 {
peccu 0:b9ac53c439ed 814 if (!is_lower_bound(j))
peccu 0:b9ac53c439ed 815 {
peccu 0:b9ac53c439ed 816 double grad_diff=Gmax+G[j];
peccu 0:b9ac53c439ed 817 if (G[j] >= Gmax2)
peccu 0:b9ac53c439ed 818 Gmax2 = G[j];
peccu 0:b9ac53c439ed 819 if (grad_diff > 0)
peccu 0:b9ac53c439ed 820 {
peccu 0:b9ac53c439ed 821 double obj_diff;
peccu 0:b9ac53c439ed 822 double quad_coef = QD[i]+QD[j]-2.0*y[i]*Q_i[j];
peccu 0:b9ac53c439ed 823 if (quad_coef > 0)
peccu 0:b9ac53c439ed 824 obj_diff = -(grad_diff*grad_diff)/quad_coef;
peccu 0:b9ac53c439ed 825 else
peccu 0:b9ac53c439ed 826 obj_diff = -(grad_diff*grad_diff)/TAU;
peccu 0:b9ac53c439ed 827
peccu 0:b9ac53c439ed 828 if (obj_diff <= obj_diff_min)
peccu 0:b9ac53c439ed 829 {
peccu 0:b9ac53c439ed 830 Gmin_idx=j;
peccu 0:b9ac53c439ed 831 obj_diff_min = obj_diff;
peccu 0:b9ac53c439ed 832 }
peccu 0:b9ac53c439ed 833 }
peccu 0:b9ac53c439ed 834 }
peccu 0:b9ac53c439ed 835 }
peccu 0:b9ac53c439ed 836 else
peccu 0:b9ac53c439ed 837 {
peccu 0:b9ac53c439ed 838 if (!is_upper_bound(j))
peccu 0:b9ac53c439ed 839 {
peccu 0:b9ac53c439ed 840 double grad_diff= Gmax-G[j];
peccu 0:b9ac53c439ed 841 if (-G[j] >= Gmax2)
peccu 0:b9ac53c439ed 842 Gmax2 = -G[j];
peccu 0:b9ac53c439ed 843 if (grad_diff > 0)
peccu 0:b9ac53c439ed 844 {
peccu 0:b9ac53c439ed 845 double obj_diff;
peccu 0:b9ac53c439ed 846 double quad_coef = QD[i]+QD[j]+2.0*y[i]*Q_i[j];
peccu 0:b9ac53c439ed 847 if (quad_coef > 0)
peccu 0:b9ac53c439ed 848 obj_diff = -(grad_diff*grad_diff)/quad_coef;
peccu 0:b9ac53c439ed 849 else
peccu 0:b9ac53c439ed 850 obj_diff = -(grad_diff*grad_diff)/TAU;
peccu 0:b9ac53c439ed 851
peccu 0:b9ac53c439ed 852 if (obj_diff <= obj_diff_min)
peccu 0:b9ac53c439ed 853 {
peccu 0:b9ac53c439ed 854 Gmin_idx=j;
peccu 0:b9ac53c439ed 855 obj_diff_min = obj_diff;
peccu 0:b9ac53c439ed 856 }
peccu 0:b9ac53c439ed 857 }
peccu 0:b9ac53c439ed 858 }
peccu 0:b9ac53c439ed 859 }
peccu 0:b9ac53c439ed 860 }
peccu 0:b9ac53c439ed 861
peccu 0:b9ac53c439ed 862 if(Gmax+Gmax2 < eps)
peccu 0:b9ac53c439ed 863 return 1;
peccu 0:b9ac53c439ed 864
peccu 0:b9ac53c439ed 865 out_i = Gmax_idx;
peccu 0:b9ac53c439ed 866 out_j = Gmin_idx;
peccu 0:b9ac53c439ed 867 return 0;
peccu 0:b9ac53c439ed 868 }
peccu 0:b9ac53c439ed 869
peccu 0:b9ac53c439ed 870 bool Solver::be_shrunk(int i, double Gmax1, double Gmax2)
peccu 0:b9ac53c439ed 871 {
peccu 0:b9ac53c439ed 872 if(is_upper_bound(i))
peccu 0:b9ac53c439ed 873 {
peccu 0:b9ac53c439ed 874 if(y[i]==+1)
peccu 0:b9ac53c439ed 875 return(-G[i] > Gmax1);
peccu 0:b9ac53c439ed 876 else
peccu 0:b9ac53c439ed 877 return(-G[i] > Gmax2);
peccu 0:b9ac53c439ed 878 }
peccu 0:b9ac53c439ed 879 else if(is_lower_bound(i))
peccu 0:b9ac53c439ed 880 {
peccu 0:b9ac53c439ed 881 if(y[i]==+1)
peccu 0:b9ac53c439ed 882 return(G[i] > Gmax2);
peccu 0:b9ac53c439ed 883 else
peccu 0:b9ac53c439ed 884 return(G[i] > Gmax1);
peccu 0:b9ac53c439ed 885 }
peccu 0:b9ac53c439ed 886 else
peccu 0:b9ac53c439ed 887 return(false);
peccu 0:b9ac53c439ed 888 }
peccu 0:b9ac53c439ed 889
peccu 0:b9ac53c439ed 890 void Solver::do_shrinking()
peccu 0:b9ac53c439ed 891 {
peccu 0:b9ac53c439ed 892 int i;
peccu 0:b9ac53c439ed 893 double Gmax1 = -INF; // max { -y_i * grad(f)_i | i in I_up(\alpha) }
peccu 0:b9ac53c439ed 894 double Gmax2 = -INF; // max { y_i * grad(f)_i | i in I_low(\alpha) }
peccu 0:b9ac53c439ed 895
peccu 0:b9ac53c439ed 896 // find maximal violating pair first
peccu 0:b9ac53c439ed 897 for(i=0;i<active_size;i++)
peccu 0:b9ac53c439ed 898 {
peccu 0:b9ac53c439ed 899 if(y[i]==+1)
peccu 0:b9ac53c439ed 900 {
peccu 0:b9ac53c439ed 901 if(!is_upper_bound(i))
peccu 0:b9ac53c439ed 902 {
peccu 0:b9ac53c439ed 903 if(-G[i] >= Gmax1)
peccu 0:b9ac53c439ed 904 Gmax1 = -G[i];
peccu 0:b9ac53c439ed 905 }
peccu 0:b9ac53c439ed 906 if(!is_lower_bound(i))
peccu 0:b9ac53c439ed 907 {
peccu 0:b9ac53c439ed 908 if(G[i] >= Gmax2)
peccu 0:b9ac53c439ed 909 Gmax2 = G[i];
peccu 0:b9ac53c439ed 910 }
peccu 0:b9ac53c439ed 911 }
peccu 0:b9ac53c439ed 912 else
peccu 0:b9ac53c439ed 913 {
peccu 0:b9ac53c439ed 914 if(!is_upper_bound(i))
peccu 0:b9ac53c439ed 915 {
peccu 0:b9ac53c439ed 916 if(-G[i] >= Gmax2)
peccu 0:b9ac53c439ed 917 Gmax2 = -G[i];
peccu 0:b9ac53c439ed 918 }
peccu 0:b9ac53c439ed 919 if(!is_lower_bound(i))
peccu 0:b9ac53c439ed 920 {
peccu 0:b9ac53c439ed 921 if(G[i] >= Gmax1)
peccu 0:b9ac53c439ed 922 Gmax1 = G[i];
peccu 0:b9ac53c439ed 923 }
peccu 0:b9ac53c439ed 924 }
peccu 0:b9ac53c439ed 925 }
peccu 0:b9ac53c439ed 926
peccu 0:b9ac53c439ed 927 if(unshrink == false && Gmax1 + Gmax2 <= eps*10)
peccu 0:b9ac53c439ed 928 {
peccu 0:b9ac53c439ed 929 unshrink = true;
peccu 0:b9ac53c439ed 930 reconstruct_gradient();
peccu 0:b9ac53c439ed 931 active_size = l;
peccu 0:b9ac53c439ed 932 info("*");
peccu 0:b9ac53c439ed 933 }
peccu 0:b9ac53c439ed 934
peccu 0:b9ac53c439ed 935 for(i=0;i<active_size;i++)
peccu 0:b9ac53c439ed 936 if (be_shrunk(i, Gmax1, Gmax2))
peccu 0:b9ac53c439ed 937 {
peccu 0:b9ac53c439ed 938 active_size--;
peccu 0:b9ac53c439ed 939 while (active_size > i)
peccu 0:b9ac53c439ed 940 {
peccu 0:b9ac53c439ed 941 if (!be_shrunk(active_size, Gmax1, Gmax2))
peccu 0:b9ac53c439ed 942 {
peccu 0:b9ac53c439ed 943 swap_index(i,active_size);
peccu 0:b9ac53c439ed 944 break;
peccu 0:b9ac53c439ed 945 }
peccu 0:b9ac53c439ed 946 active_size--;
peccu 0:b9ac53c439ed 947 }
peccu 0:b9ac53c439ed 948 }
peccu 0:b9ac53c439ed 949 }
peccu 0:b9ac53c439ed 950
peccu 0:b9ac53c439ed 951 double Solver::calculate_rho()
peccu 0:b9ac53c439ed 952 {
peccu 0:b9ac53c439ed 953 double r;
peccu 0:b9ac53c439ed 954 int nr_free = 0;
peccu 0:b9ac53c439ed 955 double ub = INF, lb = -INF, sum_free = 0;
peccu 0:b9ac53c439ed 956 for(int i=0;i<active_size;i++)
peccu 0:b9ac53c439ed 957 {
peccu 0:b9ac53c439ed 958 double yG = y[i]*G[i];
peccu 0:b9ac53c439ed 959
peccu 0:b9ac53c439ed 960 if(is_upper_bound(i))
peccu 0:b9ac53c439ed 961 {
peccu 0:b9ac53c439ed 962 if(y[i]==-1)
peccu 0:b9ac53c439ed 963 ub = min(ub,yG);
peccu 0:b9ac53c439ed 964 else
peccu 0:b9ac53c439ed 965 lb = max(lb,yG);
peccu 0:b9ac53c439ed 966 }
peccu 0:b9ac53c439ed 967 else if(is_lower_bound(i))
peccu 0:b9ac53c439ed 968 {
peccu 0:b9ac53c439ed 969 if(y[i]==+1)
peccu 0:b9ac53c439ed 970 ub = min(ub,yG);
peccu 0:b9ac53c439ed 971 else
peccu 0:b9ac53c439ed 972 lb = max(lb,yG);
peccu 0:b9ac53c439ed 973 }
peccu 0:b9ac53c439ed 974 else
peccu 0:b9ac53c439ed 975 {
peccu 0:b9ac53c439ed 976 ++nr_free;
peccu 0:b9ac53c439ed 977 sum_free += yG;
peccu 0:b9ac53c439ed 978 }
peccu 0:b9ac53c439ed 979 }
peccu 0:b9ac53c439ed 980
peccu 0:b9ac53c439ed 981 if(nr_free>0)
peccu 0:b9ac53c439ed 982 r = sum_free/nr_free;
peccu 0:b9ac53c439ed 983 else
peccu 0:b9ac53c439ed 984 r = (ub+lb)/2;
peccu 0:b9ac53c439ed 985
peccu 0:b9ac53c439ed 986 return r;
peccu 0:b9ac53c439ed 987 }
peccu 0:b9ac53c439ed 988
peccu 0:b9ac53c439ed 989 //
peccu 0:b9ac53c439ed 990 // Solver for nu-svm classification and regression
peccu 0:b9ac53c439ed 991 //
peccu 0:b9ac53c439ed 992 // additional constraint: e^T \alpha = constant
peccu 0:b9ac53c439ed 993 //
peccu 0:b9ac53c439ed 994 class Solver_NU : public Solver
peccu 0:b9ac53c439ed 995 {
peccu 0:b9ac53c439ed 996 public:
peccu 0:b9ac53c439ed 997 Solver_NU() {}
peccu 0:b9ac53c439ed 998 void Solve(int l, const QMatrix& Q, const double *p, const schar *y,
peccu 0:b9ac53c439ed 999 double *alpha, double Cp, double Cn, double eps,
peccu 0:b9ac53c439ed 1000 SolutionInfo* si, int shrinking)
peccu 0:b9ac53c439ed 1001 {
peccu 0:b9ac53c439ed 1002 this->si = si;
peccu 0:b9ac53c439ed 1003 Solver::Solve(l,Q,p,y,alpha,Cp,Cn,eps,si,shrinking);
peccu 0:b9ac53c439ed 1004 }
peccu 0:b9ac53c439ed 1005 private:
peccu 0:b9ac53c439ed 1006 SolutionInfo *si;
peccu 0:b9ac53c439ed 1007 int select_working_set(int &i, int &j);
peccu 0:b9ac53c439ed 1008 double calculate_rho();
peccu 0:b9ac53c439ed 1009 bool be_shrunk(int i, double Gmax1, double Gmax2, double Gmax3, double Gmax4);
peccu 0:b9ac53c439ed 1010 void do_shrinking();
peccu 0:b9ac53c439ed 1011 };
peccu 0:b9ac53c439ed 1012
peccu 0:b9ac53c439ed 1013 // return 1 if already optimal, return 0 otherwise
peccu 0:b9ac53c439ed 1014 int Solver_NU::select_working_set(int &out_i, int &out_j)
peccu 0:b9ac53c439ed 1015 {
peccu 0:b9ac53c439ed 1016 // return i,j such that y_i = y_j and
peccu 0:b9ac53c439ed 1017 // i: maximizes -y_i * grad(f)_i, i in I_up(\alpha)
peccu 0:b9ac53c439ed 1018 // j: minimizes the decrease of obj value
peccu 0:b9ac53c439ed 1019 // (if quadratic coefficeint <= 0, replace it with tau)
peccu 0:b9ac53c439ed 1020 // -y_j*grad(f)_j < -y_i*grad(f)_i, j in I_low(\alpha)
peccu 0:b9ac53c439ed 1021
peccu 0:b9ac53c439ed 1022 double Gmaxp = -INF;
peccu 0:b9ac53c439ed 1023 double Gmaxp2 = -INF;
peccu 0:b9ac53c439ed 1024 int Gmaxp_idx = -1;
peccu 0:b9ac53c439ed 1025
peccu 0:b9ac53c439ed 1026 double Gmaxn = -INF;
peccu 0:b9ac53c439ed 1027 double Gmaxn2 = -INF;
peccu 0:b9ac53c439ed 1028 int Gmaxn_idx = -1;
peccu 0:b9ac53c439ed 1029
peccu 0:b9ac53c439ed 1030 int Gmin_idx = -1;
peccu 0:b9ac53c439ed 1031 double obj_diff_min = INF;
peccu 0:b9ac53c439ed 1032
peccu 0:b9ac53c439ed 1033 for(int t=0;t<active_size;t++)
peccu 0:b9ac53c439ed 1034 if(y[t]==+1)
peccu 0:b9ac53c439ed 1035 {
peccu 0:b9ac53c439ed 1036 if(!is_upper_bound(t))
peccu 0:b9ac53c439ed 1037 if(-G[t] >= Gmaxp)
peccu 0:b9ac53c439ed 1038 {
peccu 0:b9ac53c439ed 1039 Gmaxp = -G[t];
peccu 0:b9ac53c439ed 1040 Gmaxp_idx = t;
peccu 0:b9ac53c439ed 1041 }
peccu 0:b9ac53c439ed 1042 }
peccu 0:b9ac53c439ed 1043 else
peccu 0:b9ac53c439ed 1044 {
peccu 0:b9ac53c439ed 1045 if(!is_lower_bound(t))
peccu 0:b9ac53c439ed 1046 if(G[t] >= Gmaxn)
peccu 0:b9ac53c439ed 1047 {
peccu 0:b9ac53c439ed 1048 Gmaxn = G[t];
peccu 0:b9ac53c439ed 1049 Gmaxn_idx = t;
peccu 0:b9ac53c439ed 1050 }
peccu 0:b9ac53c439ed 1051 }
peccu 0:b9ac53c439ed 1052
peccu 0:b9ac53c439ed 1053 int ip = Gmaxp_idx;
peccu 0:b9ac53c439ed 1054 int in = Gmaxn_idx;
peccu 0:b9ac53c439ed 1055 const Qfloat *Q_ip = NULL;
peccu 0:b9ac53c439ed 1056 const Qfloat *Q_in = NULL;
peccu 0:b9ac53c439ed 1057 if(ip != -1) // NULL Q_ip not accessed: Gmaxp=-INF if ip=-1
peccu 0:b9ac53c439ed 1058 Q_ip = Q->get_Q(ip,active_size);
peccu 0:b9ac53c439ed 1059 if(in != -1)
peccu 0:b9ac53c439ed 1060 Q_in = Q->get_Q(in,active_size);
peccu 0:b9ac53c439ed 1061
peccu 0:b9ac53c439ed 1062 for(int j=0;j<active_size;j++)
peccu 0:b9ac53c439ed 1063 {
peccu 0:b9ac53c439ed 1064 if(y[j]==+1)
peccu 0:b9ac53c439ed 1065 {
peccu 0:b9ac53c439ed 1066 if (!is_lower_bound(j))
peccu 0:b9ac53c439ed 1067 {
peccu 0:b9ac53c439ed 1068 double grad_diff=Gmaxp+G[j];
peccu 0:b9ac53c439ed 1069 if (G[j] >= Gmaxp2)
peccu 0:b9ac53c439ed 1070 Gmaxp2 = G[j];
peccu 0:b9ac53c439ed 1071 if (grad_diff > 0)
peccu 0:b9ac53c439ed 1072 {
peccu 0:b9ac53c439ed 1073 double obj_diff;
peccu 0:b9ac53c439ed 1074 double quad_coef = QD[ip]+QD[j]-2*Q_ip[j];
peccu 0:b9ac53c439ed 1075 if (quad_coef > 0)
peccu 0:b9ac53c439ed 1076 obj_diff = -(grad_diff*grad_diff)/quad_coef;
peccu 0:b9ac53c439ed 1077 else
peccu 0:b9ac53c439ed 1078 obj_diff = -(grad_diff*grad_diff)/TAU;
peccu 0:b9ac53c439ed 1079
peccu 0:b9ac53c439ed 1080 if (obj_diff <= obj_diff_min)
peccu 0:b9ac53c439ed 1081 {
peccu 0:b9ac53c439ed 1082 Gmin_idx=j;
peccu 0:b9ac53c439ed 1083 obj_diff_min = obj_diff;
peccu 0:b9ac53c439ed 1084 }
peccu 0:b9ac53c439ed 1085 }
peccu 0:b9ac53c439ed 1086 }
peccu 0:b9ac53c439ed 1087 }
peccu 0:b9ac53c439ed 1088 else
peccu 0:b9ac53c439ed 1089 {
peccu 0:b9ac53c439ed 1090 if (!is_upper_bound(j))
peccu 0:b9ac53c439ed 1091 {
peccu 0:b9ac53c439ed 1092 double grad_diff=Gmaxn-G[j];
peccu 0:b9ac53c439ed 1093 if (-G[j] >= Gmaxn2)
peccu 0:b9ac53c439ed 1094 Gmaxn2 = -G[j];
peccu 0:b9ac53c439ed 1095 if (grad_diff > 0)
peccu 0:b9ac53c439ed 1096 {
peccu 0:b9ac53c439ed 1097 double obj_diff;
peccu 0:b9ac53c439ed 1098 double quad_coef = QD[in]+QD[j]-2*Q_in[j];
peccu 0:b9ac53c439ed 1099 if (quad_coef > 0)
peccu 0:b9ac53c439ed 1100 obj_diff = -(grad_diff*grad_diff)/quad_coef;
peccu 0:b9ac53c439ed 1101 else
peccu 0:b9ac53c439ed 1102 obj_diff = -(grad_diff*grad_diff)/TAU;
peccu 0:b9ac53c439ed 1103
peccu 0:b9ac53c439ed 1104 if (obj_diff <= obj_diff_min)
peccu 0:b9ac53c439ed 1105 {
peccu 0:b9ac53c439ed 1106 Gmin_idx=j;
peccu 0:b9ac53c439ed 1107 obj_diff_min = obj_diff;
peccu 0:b9ac53c439ed 1108 }
peccu 0:b9ac53c439ed 1109 }
peccu 0:b9ac53c439ed 1110 }
peccu 0:b9ac53c439ed 1111 }
peccu 0:b9ac53c439ed 1112 }
peccu 0:b9ac53c439ed 1113
peccu 0:b9ac53c439ed 1114 if(max(Gmaxp+Gmaxp2,Gmaxn+Gmaxn2) < eps)
peccu 0:b9ac53c439ed 1115 return 1;
peccu 0:b9ac53c439ed 1116
peccu 0:b9ac53c439ed 1117 if (y[Gmin_idx] == +1)
peccu 0:b9ac53c439ed 1118 out_i = Gmaxp_idx;
peccu 0:b9ac53c439ed 1119 else
peccu 0:b9ac53c439ed 1120 out_i = Gmaxn_idx;
peccu 0:b9ac53c439ed 1121 out_j = Gmin_idx;
peccu 0:b9ac53c439ed 1122
peccu 0:b9ac53c439ed 1123 return 0;
peccu 0:b9ac53c439ed 1124 }
peccu 0:b9ac53c439ed 1125
peccu 0:b9ac53c439ed 1126 bool Solver_NU::be_shrunk(int i, double Gmax1, double Gmax2, double Gmax3, double Gmax4)
peccu 0:b9ac53c439ed 1127 {
peccu 0:b9ac53c439ed 1128 if(is_upper_bound(i))
peccu 0:b9ac53c439ed 1129 {
peccu 0:b9ac53c439ed 1130 if(y[i]==+1)
peccu 0:b9ac53c439ed 1131 return(-G[i] > Gmax1);
peccu 0:b9ac53c439ed 1132 else
peccu 0:b9ac53c439ed 1133 return(-G[i] > Gmax4);
peccu 0:b9ac53c439ed 1134 }
peccu 0:b9ac53c439ed 1135 else if(is_lower_bound(i))
peccu 0:b9ac53c439ed 1136 {
peccu 0:b9ac53c439ed 1137 if(y[i]==+1)
peccu 0:b9ac53c439ed 1138 return(G[i] > Gmax2);
peccu 0:b9ac53c439ed 1139 else
peccu 0:b9ac53c439ed 1140 return(G[i] > Gmax3);
peccu 0:b9ac53c439ed 1141 }
peccu 0:b9ac53c439ed 1142 else
peccu 0:b9ac53c439ed 1143 return(false);
peccu 0:b9ac53c439ed 1144 }
peccu 0:b9ac53c439ed 1145
peccu 0:b9ac53c439ed 1146 void Solver_NU::do_shrinking()
peccu 0:b9ac53c439ed 1147 {
peccu 0:b9ac53c439ed 1148 double Gmax1 = -INF; // max { -y_i * grad(f)_i | y_i = +1, i in I_up(\alpha) }
peccu 0:b9ac53c439ed 1149 double Gmax2 = -INF; // max { y_i * grad(f)_i | y_i = +1, i in I_low(\alpha) }
peccu 0:b9ac53c439ed 1150 double Gmax3 = -INF; // max { -y_i * grad(f)_i | y_i = -1, i in I_up(\alpha) }
peccu 0:b9ac53c439ed 1151 double Gmax4 = -INF; // max { y_i * grad(f)_i | y_i = -1, i in I_low(\alpha) }
peccu 0:b9ac53c439ed 1152
peccu 0:b9ac53c439ed 1153 // find maximal violating pair first
peccu 0:b9ac53c439ed 1154 int i;
peccu 0:b9ac53c439ed 1155 for(i=0;i<active_size;i++)
peccu 0:b9ac53c439ed 1156 {
peccu 0:b9ac53c439ed 1157 if(!is_upper_bound(i))
peccu 0:b9ac53c439ed 1158 {
peccu 0:b9ac53c439ed 1159 if(y[i]==+1)
peccu 0:b9ac53c439ed 1160 {
peccu 0:b9ac53c439ed 1161 if(-G[i] > Gmax1) Gmax1 = -G[i];
peccu 0:b9ac53c439ed 1162 }
peccu 0:b9ac53c439ed 1163 else if(-G[i] > Gmax4) Gmax4 = -G[i];
peccu 0:b9ac53c439ed 1164 }
peccu 0:b9ac53c439ed 1165 if(!is_lower_bound(i))
peccu 0:b9ac53c439ed 1166 {
peccu 0:b9ac53c439ed 1167 if(y[i]==+1)
peccu 0:b9ac53c439ed 1168 {
peccu 0:b9ac53c439ed 1169 if(G[i] > Gmax2) Gmax2 = G[i];
peccu 0:b9ac53c439ed 1170 }
peccu 0:b9ac53c439ed 1171 else if(G[i] > Gmax3) Gmax3 = G[i];
peccu 0:b9ac53c439ed 1172 }
peccu 0:b9ac53c439ed 1173 }
peccu 0:b9ac53c439ed 1174
peccu 0:b9ac53c439ed 1175 if(unshrink == false && max(Gmax1+Gmax2,Gmax3+Gmax4) <= eps*10)
peccu 0:b9ac53c439ed 1176 {
peccu 0:b9ac53c439ed 1177 unshrink = true;
peccu 0:b9ac53c439ed 1178 reconstruct_gradient();
peccu 0:b9ac53c439ed 1179 active_size = l;
peccu 0:b9ac53c439ed 1180 }
peccu 0:b9ac53c439ed 1181
peccu 0:b9ac53c439ed 1182 for(i=0;i<active_size;i++)
peccu 0:b9ac53c439ed 1183 if (be_shrunk(i, Gmax1, Gmax2, Gmax3, Gmax4))
peccu 0:b9ac53c439ed 1184 {
peccu 0:b9ac53c439ed 1185 active_size--;
peccu 0:b9ac53c439ed 1186 while (active_size > i)
peccu 0:b9ac53c439ed 1187 {
peccu 0:b9ac53c439ed 1188 if (!be_shrunk(active_size, Gmax1, Gmax2, Gmax3, Gmax4))
peccu 0:b9ac53c439ed 1189 {
peccu 0:b9ac53c439ed 1190 swap_index(i,active_size);
peccu 0:b9ac53c439ed 1191 break;
peccu 0:b9ac53c439ed 1192 }
peccu 0:b9ac53c439ed 1193 active_size--;
peccu 0:b9ac53c439ed 1194 }
peccu 0:b9ac53c439ed 1195 }
peccu 0:b9ac53c439ed 1196 }
peccu 0:b9ac53c439ed 1197
peccu 0:b9ac53c439ed 1198 double Solver_NU::calculate_rho()
peccu 0:b9ac53c439ed 1199 {
peccu 0:b9ac53c439ed 1200 int nr_free1 = 0,nr_free2 = 0;
peccu 0:b9ac53c439ed 1201 double ub1 = INF, ub2 = INF;
peccu 0:b9ac53c439ed 1202 double lb1 = -INF, lb2 = -INF;
peccu 0:b9ac53c439ed 1203 double sum_free1 = 0, sum_free2 = 0;
peccu 0:b9ac53c439ed 1204
peccu 0:b9ac53c439ed 1205 for(int i=0;i<active_size;i++)
peccu 0:b9ac53c439ed 1206 {
peccu 0:b9ac53c439ed 1207 if(y[i]==+1)
peccu 0:b9ac53c439ed 1208 {
peccu 0:b9ac53c439ed 1209 if(is_upper_bound(i))
peccu 0:b9ac53c439ed 1210 lb1 = max(lb1,G[i]);
peccu 0:b9ac53c439ed 1211 else if(is_lower_bound(i))
peccu 0:b9ac53c439ed 1212 ub1 = min(ub1,G[i]);
peccu 0:b9ac53c439ed 1213 else
peccu 0:b9ac53c439ed 1214 {
peccu 0:b9ac53c439ed 1215 ++nr_free1;
peccu 0:b9ac53c439ed 1216 sum_free1 += G[i];
peccu 0:b9ac53c439ed 1217 }
peccu 0:b9ac53c439ed 1218 }
peccu 0:b9ac53c439ed 1219 else
peccu 0:b9ac53c439ed 1220 {
peccu 0:b9ac53c439ed 1221 if(is_upper_bound(i))
peccu 0:b9ac53c439ed 1222 lb2 = max(lb2,G[i]);
peccu 0:b9ac53c439ed 1223 else if(is_lower_bound(i))
peccu 0:b9ac53c439ed 1224 ub2 = min(ub2,G[i]);
peccu 0:b9ac53c439ed 1225 else
peccu 0:b9ac53c439ed 1226 {
peccu 0:b9ac53c439ed 1227 ++nr_free2;
peccu 0:b9ac53c439ed 1228 sum_free2 += G[i];
peccu 0:b9ac53c439ed 1229 }
peccu 0:b9ac53c439ed 1230 }
peccu 0:b9ac53c439ed 1231 }
peccu 0:b9ac53c439ed 1232
peccu 0:b9ac53c439ed 1233 double r1,r2;
peccu 0:b9ac53c439ed 1234 if(nr_free1 > 0)
peccu 0:b9ac53c439ed 1235 r1 = sum_free1/nr_free1;
peccu 0:b9ac53c439ed 1236 else
peccu 0:b9ac53c439ed 1237 r1 = (ub1+lb1)/2;
peccu 0:b9ac53c439ed 1238
peccu 0:b9ac53c439ed 1239 if(nr_free2 > 0)
peccu 0:b9ac53c439ed 1240 r2 = sum_free2/nr_free2;
peccu 0:b9ac53c439ed 1241 else
peccu 0:b9ac53c439ed 1242 r2 = (ub2+lb2)/2;
peccu 0:b9ac53c439ed 1243
peccu 0:b9ac53c439ed 1244 si->r = (r1+r2)/2;
peccu 0:b9ac53c439ed 1245 return (r1-r2)/2;
peccu 0:b9ac53c439ed 1246 }
peccu 0:b9ac53c439ed 1247
peccu 0:b9ac53c439ed 1248 //
peccu 0:b9ac53c439ed 1249 // Q matrices for various formulations
peccu 0:b9ac53c439ed 1250 //
peccu 0:b9ac53c439ed 1251 class SVC_Q: public Kernel
peccu 0:b9ac53c439ed 1252 {
peccu 0:b9ac53c439ed 1253 public:
peccu 0:b9ac53c439ed 1254 SVC_Q(const svm_problem& prob, const svm_parameter& param, const schar *y_)
peccu 0:b9ac53c439ed 1255 :Kernel(prob.l, prob.x, param)
peccu 0:b9ac53c439ed 1256 {
peccu 0:b9ac53c439ed 1257 clone(y,y_,prob.l);
peccu 0:b9ac53c439ed 1258 cache = new Cache(prob.l,(long int)(param.cache_size*(1<<20)));
peccu 0:b9ac53c439ed 1259 QD = new double[prob.l];
peccu 0:b9ac53c439ed 1260 for(int i=0;i<prob.l;i++)
peccu 0:b9ac53c439ed 1261 QD[i] = (this->*kernel_function)(i,i);
peccu 0:b9ac53c439ed 1262 }
peccu 0:b9ac53c439ed 1263
peccu 0:b9ac53c439ed 1264 Qfloat *get_Q(int i, int len) const
peccu 0:b9ac53c439ed 1265 {
peccu 0:b9ac53c439ed 1266 Qfloat *data;
peccu 0:b9ac53c439ed 1267 int start, j;
peccu 0:b9ac53c439ed 1268 if((start = cache->get_data(i,&data,len)) < len)
peccu 0:b9ac53c439ed 1269 {
peccu 0:b9ac53c439ed 1270 for(j=start;j<len;j++)
peccu 0:b9ac53c439ed 1271 data[j] = (Qfloat)(y[i]*y[j]*(this->*kernel_function)(i,j));
peccu 0:b9ac53c439ed 1272 }
peccu 0:b9ac53c439ed 1273 return data;
peccu 0:b9ac53c439ed 1274 }
peccu 0:b9ac53c439ed 1275
peccu 0:b9ac53c439ed 1276 double *get_QD() const
peccu 0:b9ac53c439ed 1277 {
peccu 0:b9ac53c439ed 1278 return QD;
peccu 0:b9ac53c439ed 1279 }
peccu 0:b9ac53c439ed 1280
peccu 0:b9ac53c439ed 1281 void swap_index(int i, int j) const
peccu 0:b9ac53c439ed 1282 {
peccu 0:b9ac53c439ed 1283 cache->swap_index(i,j);
peccu 0:b9ac53c439ed 1284 Kernel::swap_index(i,j);
peccu 0:b9ac53c439ed 1285 swap(y[i],y[j]);
peccu 0:b9ac53c439ed 1286 swap(QD[i],QD[j]);
peccu 0:b9ac53c439ed 1287 }
peccu 0:b9ac53c439ed 1288
peccu 0:b9ac53c439ed 1289 ~SVC_Q()
peccu 0:b9ac53c439ed 1290 {
peccu 0:b9ac53c439ed 1291 delete[] y;
peccu 0:b9ac53c439ed 1292 delete cache;
peccu 0:b9ac53c439ed 1293 delete[] QD;
peccu 0:b9ac53c439ed 1294 }
peccu 0:b9ac53c439ed 1295 private:
peccu 0:b9ac53c439ed 1296 schar *y;
peccu 0:b9ac53c439ed 1297 Cache *cache;
peccu 0:b9ac53c439ed 1298 double *QD;
peccu 0:b9ac53c439ed 1299 };
peccu 0:b9ac53c439ed 1300
peccu 0:b9ac53c439ed 1301 class ONE_CLASS_Q: public Kernel
peccu 0:b9ac53c439ed 1302 {
peccu 0:b9ac53c439ed 1303 public:
peccu 0:b9ac53c439ed 1304 ONE_CLASS_Q(const svm_problem& prob, const svm_parameter& param)
peccu 0:b9ac53c439ed 1305 :Kernel(prob.l, prob.x, param)
peccu 0:b9ac53c439ed 1306 {
peccu 0:b9ac53c439ed 1307 cache = new Cache(prob.l,(long int)(param.cache_size*(1<<20)));
peccu 0:b9ac53c439ed 1308 QD = new double[prob.l];
peccu 0:b9ac53c439ed 1309 for(int i=0;i<prob.l;i++)
peccu 0:b9ac53c439ed 1310 QD[i] = (this->*kernel_function)(i,i);
peccu 0:b9ac53c439ed 1311 }
peccu 0:b9ac53c439ed 1312
peccu 0:b9ac53c439ed 1313 Qfloat *get_Q(int i, int len) const
peccu 0:b9ac53c439ed 1314 {
peccu 0:b9ac53c439ed 1315 Qfloat *data;
peccu 0:b9ac53c439ed 1316 int start, j;
peccu 0:b9ac53c439ed 1317 if((start = cache->get_data(i,&data,len)) < len)
peccu 0:b9ac53c439ed 1318 {
peccu 0:b9ac53c439ed 1319 for(j=start;j<len;j++)
peccu 0:b9ac53c439ed 1320 data[j] = (Qfloat)(this->*kernel_function)(i,j);
peccu 0:b9ac53c439ed 1321 }
peccu 0:b9ac53c439ed 1322 return data;
peccu 0:b9ac53c439ed 1323 }
peccu 0:b9ac53c439ed 1324
peccu 0:b9ac53c439ed 1325 double *get_QD() const
peccu 0:b9ac53c439ed 1326 {
peccu 0:b9ac53c439ed 1327 return QD;
peccu 0:b9ac53c439ed 1328 }
peccu 0:b9ac53c439ed 1329
peccu 0:b9ac53c439ed 1330 void swap_index(int i, int j) const
peccu 0:b9ac53c439ed 1331 {
peccu 0:b9ac53c439ed 1332 cache->swap_index(i,j);
peccu 0:b9ac53c439ed 1333 Kernel::swap_index(i,j);
peccu 0:b9ac53c439ed 1334 swap(QD[i],QD[j]);
peccu 0:b9ac53c439ed 1335 }
peccu 0:b9ac53c439ed 1336
peccu 0:b9ac53c439ed 1337 ~ONE_CLASS_Q()
peccu 0:b9ac53c439ed 1338 {
peccu 0:b9ac53c439ed 1339 delete cache;
peccu 0:b9ac53c439ed 1340 delete[] QD;
peccu 0:b9ac53c439ed 1341 }
peccu 0:b9ac53c439ed 1342 private:
peccu 0:b9ac53c439ed 1343 Cache *cache;
peccu 0:b9ac53c439ed 1344 double *QD;
peccu 0:b9ac53c439ed 1345 };
peccu 0:b9ac53c439ed 1346
peccu 0:b9ac53c439ed 1347 class SVR_Q: public Kernel
peccu 0:b9ac53c439ed 1348 {
peccu 0:b9ac53c439ed 1349 public:
peccu 0:b9ac53c439ed 1350 SVR_Q(const svm_problem& prob, const svm_parameter& param)
peccu 0:b9ac53c439ed 1351 :Kernel(prob.l, prob.x, param)
peccu 0:b9ac53c439ed 1352 {
peccu 0:b9ac53c439ed 1353 l = prob.l;
peccu 0:b9ac53c439ed 1354 cache = new Cache(l,(long int)(param.cache_size*(1<<20)));
peccu 0:b9ac53c439ed 1355 QD = new double[2*l];
peccu 0:b9ac53c439ed 1356 sign = new schar[2*l];
peccu 0:b9ac53c439ed 1357 index = new int[2*l];
peccu 0:b9ac53c439ed 1358 for(int k=0;k<l;k++)
peccu 0:b9ac53c439ed 1359 {
peccu 0:b9ac53c439ed 1360 sign[k] = 1;
peccu 0:b9ac53c439ed 1361 sign[k+l] = -1;
peccu 0:b9ac53c439ed 1362 index[k] = k;
peccu 0:b9ac53c439ed 1363 index[k+l] = k;
peccu 0:b9ac53c439ed 1364 QD[k] = (this->*kernel_function)(k,k);
peccu 0:b9ac53c439ed 1365 QD[k+l] = QD[k];
peccu 0:b9ac53c439ed 1366 }
peccu 0:b9ac53c439ed 1367 buffer[0] = new Qfloat[2*l];
peccu 0:b9ac53c439ed 1368 buffer[1] = new Qfloat[2*l];
peccu 0:b9ac53c439ed 1369 next_buffer = 0;
peccu 0:b9ac53c439ed 1370 }
peccu 0:b9ac53c439ed 1371
peccu 0:b9ac53c439ed 1372 void swap_index(int i, int j) const
peccu 0:b9ac53c439ed 1373 {
peccu 0:b9ac53c439ed 1374 swap(sign[i],sign[j]);
peccu 0:b9ac53c439ed 1375 swap(index[i],index[j]);
peccu 0:b9ac53c439ed 1376 swap(QD[i],QD[j]);
peccu 0:b9ac53c439ed 1377 }
peccu 0:b9ac53c439ed 1378
peccu 0:b9ac53c439ed 1379 Qfloat *get_Q(int i, int len) const
peccu 0:b9ac53c439ed 1380 {
peccu 0:b9ac53c439ed 1381 Qfloat *data;
peccu 0:b9ac53c439ed 1382 int j, real_i = index[i];
peccu 0:b9ac53c439ed 1383 if(cache->get_data(real_i,&data,l) < l)
peccu 0:b9ac53c439ed 1384 {
peccu 0:b9ac53c439ed 1385 for(j=0;j<l;j++)
peccu 0:b9ac53c439ed 1386 data[j] = (Qfloat)(this->*kernel_function)(real_i,j);
peccu 0:b9ac53c439ed 1387 }
peccu 0:b9ac53c439ed 1388
peccu 0:b9ac53c439ed 1389 // reorder and copy
peccu 0:b9ac53c439ed 1390 Qfloat *buf = buffer[next_buffer];
peccu 0:b9ac53c439ed 1391 next_buffer = 1 - next_buffer;
peccu 0:b9ac53c439ed 1392 schar si = sign[i];
peccu 0:b9ac53c439ed 1393 for(j=0;j<len;j++)
peccu 0:b9ac53c439ed 1394 buf[j] = (Qfloat) si * (Qfloat) sign[j] * data[index[j]];
peccu 0:b9ac53c439ed 1395 return buf;
peccu 0:b9ac53c439ed 1396 }
peccu 0:b9ac53c439ed 1397
peccu 0:b9ac53c439ed 1398 double *get_QD() const
peccu 0:b9ac53c439ed 1399 {
peccu 0:b9ac53c439ed 1400 return QD;
peccu 0:b9ac53c439ed 1401 }
peccu 0:b9ac53c439ed 1402
peccu 0:b9ac53c439ed 1403 ~SVR_Q()
peccu 0:b9ac53c439ed 1404 {
peccu 0:b9ac53c439ed 1405 delete cache;
peccu 0:b9ac53c439ed 1406 delete[] sign;
peccu 0:b9ac53c439ed 1407 delete[] index;
peccu 0:b9ac53c439ed 1408 delete[] buffer[0];
peccu 0:b9ac53c439ed 1409 delete[] buffer[1];
peccu 0:b9ac53c439ed 1410 delete[] QD;
peccu 0:b9ac53c439ed 1411 }
peccu 0:b9ac53c439ed 1412 private:
peccu 0:b9ac53c439ed 1413 int l;
peccu 0:b9ac53c439ed 1414 Cache *cache;
peccu 0:b9ac53c439ed 1415 schar *sign;
peccu 0:b9ac53c439ed 1416 int *index;
peccu 0:b9ac53c439ed 1417 mutable int next_buffer;
peccu 0:b9ac53c439ed 1418 Qfloat *buffer[2];
peccu 0:b9ac53c439ed 1419 double *QD;
peccu 0:b9ac53c439ed 1420 };
peccu 0:b9ac53c439ed 1421
peccu 0:b9ac53c439ed 1422 //
peccu 0:b9ac53c439ed 1423 // construct and solve various formulations
peccu 0:b9ac53c439ed 1424 //
peccu 0:b9ac53c439ed 1425 static void solve_c_svc(
peccu 0:b9ac53c439ed 1426 const svm_problem *prob, const svm_parameter* param,
peccu 0:b9ac53c439ed 1427 double *alpha, Solver::SolutionInfo* si, double Cp, double Cn)
peccu 0:b9ac53c439ed 1428 {
peccu 0:b9ac53c439ed 1429 int l = prob->l;
peccu 0:b9ac53c439ed 1430 double *minus_ones = new double[l];
peccu 0:b9ac53c439ed 1431 schar *y = new schar[l];
peccu 0:b9ac53c439ed 1432
peccu 0:b9ac53c439ed 1433 int i;
peccu 0:b9ac53c439ed 1434
peccu 0:b9ac53c439ed 1435 for(i=0;i<l;i++)
peccu 0:b9ac53c439ed 1436 {
peccu 0:b9ac53c439ed 1437 alpha[i] = 0;
peccu 0:b9ac53c439ed 1438 minus_ones[i] = -1;
peccu 0:b9ac53c439ed 1439 if(prob->y[i] > 0) y[i] = +1; else y[i] = -1;
peccu 0:b9ac53c439ed 1440 }
peccu 0:b9ac53c439ed 1441
peccu 0:b9ac53c439ed 1442 Solver s;
peccu 0:b9ac53c439ed 1443 s.Solve(l, SVC_Q(*prob,*param,y), minus_ones, y,
peccu 0:b9ac53c439ed 1444 alpha, Cp, Cn, param->eps, si, param->shrinking);
peccu 0:b9ac53c439ed 1445
peccu 0:b9ac53c439ed 1446 double sum_alpha=0;
peccu 0:b9ac53c439ed 1447 for(i=0;i<l;i++)
peccu 0:b9ac53c439ed 1448 sum_alpha += alpha[i];
peccu 0:b9ac53c439ed 1449
peccu 0:b9ac53c439ed 1450 if (Cp==Cn)
peccu 0:b9ac53c439ed 1451 info("nu = %f\n", sum_alpha/(Cp*prob->l));
peccu 0:b9ac53c439ed 1452
peccu 0:b9ac53c439ed 1453 for(i=0;i<l;i++)
peccu 0:b9ac53c439ed 1454 alpha[i] *= y[i];
peccu 0:b9ac53c439ed 1455
peccu 0:b9ac53c439ed 1456 delete[] minus_ones;
peccu 0:b9ac53c439ed 1457 delete[] y;
peccu 0:b9ac53c439ed 1458 }
peccu 0:b9ac53c439ed 1459
peccu 0:b9ac53c439ed 1460 static void solve_nu_svc(
peccu 0:b9ac53c439ed 1461 const svm_problem *prob, const svm_parameter *param,
peccu 0:b9ac53c439ed 1462 double *alpha, Solver::SolutionInfo* si)
peccu 0:b9ac53c439ed 1463 {
peccu 0:b9ac53c439ed 1464 int i;
peccu 0:b9ac53c439ed 1465 int l = prob->l;
peccu 0:b9ac53c439ed 1466 double nu = param->nu;
peccu 0:b9ac53c439ed 1467
peccu 0:b9ac53c439ed 1468 schar *y = new schar[l];
peccu 0:b9ac53c439ed 1469
peccu 0:b9ac53c439ed 1470 for(i=0;i<l;i++)
peccu 0:b9ac53c439ed 1471 if(prob->y[i]>0)
peccu 0:b9ac53c439ed 1472 y[i] = +1;
peccu 0:b9ac53c439ed 1473 else
peccu 0:b9ac53c439ed 1474 y[i] = -1;
peccu 0:b9ac53c439ed 1475
peccu 0:b9ac53c439ed 1476 double sum_pos = nu*l/2;
peccu 0:b9ac53c439ed 1477 double sum_neg = nu*l/2;
peccu 0:b9ac53c439ed 1478
peccu 0:b9ac53c439ed 1479 for(i=0;i<l;i++)
peccu 0:b9ac53c439ed 1480 if(y[i] == +1)
peccu 0:b9ac53c439ed 1481 {
peccu 0:b9ac53c439ed 1482 alpha[i] = min(1.0,sum_pos);
peccu 0:b9ac53c439ed 1483 sum_pos -= alpha[i];
peccu 0:b9ac53c439ed 1484 }
peccu 0:b9ac53c439ed 1485 else
peccu 0:b9ac53c439ed 1486 {
peccu 0:b9ac53c439ed 1487 alpha[i] = min(1.0,sum_neg);
peccu 0:b9ac53c439ed 1488 sum_neg -= alpha[i];
peccu 0:b9ac53c439ed 1489 }
peccu 0:b9ac53c439ed 1490
peccu 0:b9ac53c439ed 1491 double *zeros = new double[l];
peccu 0:b9ac53c439ed 1492
peccu 0:b9ac53c439ed 1493 for(i=0;i<l;i++)
peccu 0:b9ac53c439ed 1494 zeros[i] = 0;
peccu 0:b9ac53c439ed 1495
peccu 0:b9ac53c439ed 1496 Solver_NU s;
peccu 0:b9ac53c439ed 1497 s.Solve(l, SVC_Q(*prob,*param,y), zeros, y,
peccu 0:b9ac53c439ed 1498 alpha, 1.0, 1.0, param->eps, si, param->shrinking);
peccu 0:b9ac53c439ed 1499 double r = si->r;
peccu 0:b9ac53c439ed 1500
peccu 0:b9ac53c439ed 1501 info("C = %f\n",1/r);
peccu 0:b9ac53c439ed 1502
peccu 0:b9ac53c439ed 1503 for(i=0;i<l;i++)
peccu 0:b9ac53c439ed 1504 alpha[i] *= y[i]/r;
peccu 0:b9ac53c439ed 1505
peccu 0:b9ac53c439ed 1506 si->rho /= r;
peccu 0:b9ac53c439ed 1507 si->obj /= (r*r);
peccu 0:b9ac53c439ed 1508 si->upper_bound_p = 1/r;
peccu 0:b9ac53c439ed 1509 si->upper_bound_n = 1/r;
peccu 0:b9ac53c439ed 1510
peccu 0:b9ac53c439ed 1511 delete[] y;
peccu 0:b9ac53c439ed 1512 delete[] zeros;
peccu 0:b9ac53c439ed 1513 }
peccu 0:b9ac53c439ed 1514
peccu 0:b9ac53c439ed 1515 static void solve_one_class(
peccu 0:b9ac53c439ed 1516 const svm_problem *prob, const svm_parameter *param,
peccu 0:b9ac53c439ed 1517 double *alpha, Solver::SolutionInfo* si)
peccu 0:b9ac53c439ed 1518 {
peccu 0:b9ac53c439ed 1519 int l = prob->l;
peccu 0:b9ac53c439ed 1520 double *zeros = new double[l];
peccu 0:b9ac53c439ed 1521 schar *ones = new schar[l];
peccu 0:b9ac53c439ed 1522 int i;
peccu 0:b9ac53c439ed 1523
peccu 0:b9ac53c439ed 1524 int n = (int)(param->nu*prob->l); // # of alpha's at upper bound
peccu 0:b9ac53c439ed 1525
peccu 0:b9ac53c439ed 1526 for(i=0;i<n;i++)
peccu 0:b9ac53c439ed 1527 alpha[i] = 1;
peccu 0:b9ac53c439ed 1528 if(n<prob->l)
peccu 0:b9ac53c439ed 1529 alpha[n] = param->nu * prob->l - n;
peccu 0:b9ac53c439ed 1530 for(i=n+1;i<l;i++)
peccu 0:b9ac53c439ed 1531 alpha[i] = 0;
peccu 0:b9ac53c439ed 1532
peccu 0:b9ac53c439ed 1533 for(i=0;i<l;i++)
peccu 0:b9ac53c439ed 1534 {
peccu 0:b9ac53c439ed 1535 zeros[i] = 0;
peccu 0:b9ac53c439ed 1536 ones[i] = 1;
peccu 0:b9ac53c439ed 1537 }
peccu 0:b9ac53c439ed 1538
peccu 0:b9ac53c439ed 1539 Solver s;
peccu 0:b9ac53c439ed 1540 s.Solve(l, ONE_CLASS_Q(*prob,*param), zeros, ones,
peccu 0:b9ac53c439ed 1541 alpha, 1.0, 1.0, param->eps, si, param->shrinking);
peccu 0:b9ac53c439ed 1542
peccu 0:b9ac53c439ed 1543 delete[] zeros;
peccu 0:b9ac53c439ed 1544 delete[] ones;
peccu 0:b9ac53c439ed 1545 }
peccu 0:b9ac53c439ed 1546
peccu 0:b9ac53c439ed 1547 static void solve_epsilon_svr(
peccu 0:b9ac53c439ed 1548 const svm_problem *prob, const svm_parameter *param,
peccu 0:b9ac53c439ed 1549 double *alpha, Solver::SolutionInfo* si)
peccu 0:b9ac53c439ed 1550 {
peccu 0:b9ac53c439ed 1551 int l = prob->l;
peccu 0:b9ac53c439ed 1552 double *alpha2 = new double[2*l];
peccu 0:b9ac53c439ed 1553 double *linear_term = new double[2*l];
peccu 0:b9ac53c439ed 1554 schar *y = new schar[2*l];
peccu 0:b9ac53c439ed 1555 int i;
peccu 0:b9ac53c439ed 1556
peccu 0:b9ac53c439ed 1557 for(i=0;i<l;i++)
peccu 0:b9ac53c439ed 1558 {
peccu 0:b9ac53c439ed 1559 alpha2[i] = 0;
peccu 0:b9ac53c439ed 1560 linear_term[i] = param->p - prob->y[i];
peccu 0:b9ac53c439ed 1561 y[i] = 1;
peccu 0:b9ac53c439ed 1562
peccu 0:b9ac53c439ed 1563 alpha2[i+l] = 0;
peccu 0:b9ac53c439ed 1564 linear_term[i+l] = param->p + prob->y[i];
peccu 0:b9ac53c439ed 1565 y[i+l] = -1;
peccu 0:b9ac53c439ed 1566 }
peccu 0:b9ac53c439ed 1567
peccu 0:b9ac53c439ed 1568 Solver s;
peccu 0:b9ac53c439ed 1569 s.Solve(2*l, SVR_Q(*prob,*param), linear_term, y,
peccu 0:b9ac53c439ed 1570 alpha2, param->C, param->C, param->eps, si, param->shrinking);
peccu 0:b9ac53c439ed 1571
peccu 0:b9ac53c439ed 1572 double sum_alpha = 0;
peccu 0:b9ac53c439ed 1573 for(i=0;i<l;i++)
peccu 0:b9ac53c439ed 1574 {
peccu 0:b9ac53c439ed 1575 alpha[i] = alpha2[i] - alpha2[i+l];
peccu 0:b9ac53c439ed 1576 sum_alpha += fabs(alpha[i]);
peccu 0:b9ac53c439ed 1577 }
peccu 0:b9ac53c439ed 1578 info("nu = %f\n",sum_alpha/(param->C*l));
peccu 0:b9ac53c439ed 1579
peccu 0:b9ac53c439ed 1580 delete[] alpha2;
peccu 0:b9ac53c439ed 1581 delete[] linear_term;
peccu 0:b9ac53c439ed 1582 delete[] y;
peccu 0:b9ac53c439ed 1583 }
peccu 0:b9ac53c439ed 1584
peccu 0:b9ac53c439ed 1585 static void solve_nu_svr(
peccu 0:b9ac53c439ed 1586 const svm_problem *prob, const svm_parameter *param,
peccu 0:b9ac53c439ed 1587 double *alpha, Solver::SolutionInfo* si)
peccu 0:b9ac53c439ed 1588 {
peccu 0:b9ac53c439ed 1589 int l = prob->l;
peccu 0:b9ac53c439ed 1590 double C = param->C;
peccu 0:b9ac53c439ed 1591 double *alpha2 = new double[2*l];
peccu 0:b9ac53c439ed 1592 double *linear_term = new double[2*l];
peccu 0:b9ac53c439ed 1593 schar *y = new schar[2*l];
peccu 0:b9ac53c439ed 1594 int i;
peccu 0:b9ac53c439ed 1595
peccu 0:b9ac53c439ed 1596 double sum = C * param->nu * l / 2;
peccu 0:b9ac53c439ed 1597 for(i=0;i<l;i++)
peccu 0:b9ac53c439ed 1598 {
peccu 0:b9ac53c439ed 1599 alpha2[i] = alpha2[i+l] = min(sum,C);
peccu 0:b9ac53c439ed 1600 sum -= alpha2[i];
peccu 0:b9ac53c439ed 1601
peccu 0:b9ac53c439ed 1602 linear_term[i] = - prob->y[i];
peccu 0:b9ac53c439ed 1603 y[i] = 1;
peccu 0:b9ac53c439ed 1604
peccu 0:b9ac53c439ed 1605 linear_term[i+l] = prob->y[i];
peccu 0:b9ac53c439ed 1606 y[i+l] = -1;
peccu 0:b9ac53c439ed 1607 }
peccu 0:b9ac53c439ed 1608
peccu 0:b9ac53c439ed 1609 Solver_NU s;
peccu 0:b9ac53c439ed 1610 s.Solve(2*l, SVR_Q(*prob,*param), linear_term, y,
peccu 0:b9ac53c439ed 1611 alpha2, C, C, param->eps, si, param->shrinking);
peccu 0:b9ac53c439ed 1612
peccu 0:b9ac53c439ed 1613 info("epsilon = %f\n",-si->r);
peccu 0:b9ac53c439ed 1614
peccu 0:b9ac53c439ed 1615 for(i=0;i<l;i++)
peccu 0:b9ac53c439ed 1616 alpha[i] = alpha2[i] - alpha2[i+l];
peccu 0:b9ac53c439ed 1617
peccu 0:b9ac53c439ed 1618 delete[] alpha2;
peccu 0:b9ac53c439ed 1619 delete[] linear_term;
peccu 0:b9ac53c439ed 1620 delete[] y;
peccu 0:b9ac53c439ed 1621 }
peccu 0:b9ac53c439ed 1622
//
// decision_function
//
// Result of training a single sub-problem with svm_train_one.
struct decision_function
{
	// One coefficient per training sample of the sub-problem. It is
	// allocated with Malloc in svm_train_one; the caller releases it with free().
	double *alpha;
	// The solver's rho (offset) for this sub-problem.
	double rho;
};
peccu 0:b9ac53c439ed 1631
peccu 0:b9ac53c439ed 1632 static decision_function svm_train_one(
peccu 0:b9ac53c439ed 1633 const svm_problem *prob, const svm_parameter *param,
peccu 0:b9ac53c439ed 1634 double Cp, double Cn)
peccu 0:b9ac53c439ed 1635 {
peccu 0:b9ac53c439ed 1636 double *alpha = Malloc(double,prob->l);
peccu 0:b9ac53c439ed 1637 Solver::SolutionInfo si;
peccu 0:b9ac53c439ed 1638 switch(param->svm_type)
peccu 0:b9ac53c439ed 1639 {
peccu 0:b9ac53c439ed 1640 case C_SVC:
peccu 0:b9ac53c439ed 1641 solve_c_svc(prob,param,alpha,&si,Cp,Cn);
peccu 0:b9ac53c439ed 1642 break;
peccu 0:b9ac53c439ed 1643 case NU_SVC:
peccu 0:b9ac53c439ed 1644 solve_nu_svc(prob,param,alpha,&si);
peccu 0:b9ac53c439ed 1645 break;
peccu 0:b9ac53c439ed 1646 case ONE_CLASS:
peccu 0:b9ac53c439ed 1647 solve_one_class(prob,param,alpha,&si);
peccu 0:b9ac53c439ed 1648 break;
peccu 0:b9ac53c439ed 1649 case EPSILON_SVR:
peccu 0:b9ac53c439ed 1650 solve_epsilon_svr(prob,param,alpha,&si);
peccu 0:b9ac53c439ed 1651 break;
peccu 0:b9ac53c439ed 1652 case NU_SVR:
peccu 0:b9ac53c439ed 1653 solve_nu_svr(prob,param,alpha,&si);
peccu 0:b9ac53c439ed 1654 break;
peccu 0:b9ac53c439ed 1655 }
peccu 0:b9ac53c439ed 1656
peccu 0:b9ac53c439ed 1657 info("obj = %f, rho = %f\n",si.obj,si.rho);
peccu 0:b9ac53c439ed 1658
peccu 0:b9ac53c439ed 1659 // output SVs
peccu 0:b9ac53c439ed 1660
peccu 0:b9ac53c439ed 1661 int nSV = 0;
peccu 0:b9ac53c439ed 1662 int nBSV = 0;
peccu 0:b9ac53c439ed 1663 for(int i=0;i<prob->l;i++)
peccu 0:b9ac53c439ed 1664 {
peccu 0:b9ac53c439ed 1665 if(fabs(alpha[i]) > 0)
peccu 0:b9ac53c439ed 1666 {
peccu 0:b9ac53c439ed 1667 ++nSV;
peccu 0:b9ac53c439ed 1668 if(prob->y[i] > 0)
peccu 0:b9ac53c439ed 1669 {
peccu 0:b9ac53c439ed 1670 if(fabs(alpha[i]) >= si.upper_bound_p)
peccu 0:b9ac53c439ed 1671 ++nBSV;
peccu 0:b9ac53c439ed 1672 }
peccu 0:b9ac53c439ed 1673 else
peccu 0:b9ac53c439ed 1674 {
peccu 0:b9ac53c439ed 1675 if(fabs(alpha[i]) >= si.upper_bound_n)
peccu 0:b9ac53c439ed 1676 ++nBSV;
peccu 0:b9ac53c439ed 1677 }
peccu 0:b9ac53c439ed 1678 }
peccu 0:b9ac53c439ed 1679 }
peccu 0:b9ac53c439ed 1680
peccu 0:b9ac53c439ed 1681 info("nSV = %d, nBSV = %d\n",nSV,nBSV);
peccu 0:b9ac53c439ed 1682
peccu 0:b9ac53c439ed 1683 decision_function f;
peccu 0:b9ac53c439ed 1684 f.alpha = alpha;
peccu 0:b9ac53c439ed 1685 f.rho = si.rho;
peccu 0:b9ac53c439ed 1686 return f;
peccu 0:b9ac53c439ed 1687 }
peccu 0:b9ac53c439ed 1688
peccu 0:b9ac53c439ed 1689 // Platt's binary SVM Probablistic Output: an improvement from Lin et al.
peccu 0:b9ac53c439ed 1690 static void sigmoid_train(
peccu 0:b9ac53c439ed 1691 int l, const double *dec_values, const double *labels,
peccu 0:b9ac53c439ed 1692 double& A, double& B)
peccu 0:b9ac53c439ed 1693 {
peccu 0:b9ac53c439ed 1694 double prior1=0, prior0 = 0;
peccu 0:b9ac53c439ed 1695 int i;
peccu 0:b9ac53c439ed 1696
peccu 0:b9ac53c439ed 1697 for (i=0;i<l;i++)
peccu 0:b9ac53c439ed 1698 if (labels[i] > 0) prior1+=1;
peccu 0:b9ac53c439ed 1699 else prior0+=1;
peccu 0:b9ac53c439ed 1700
peccu 0:b9ac53c439ed 1701 int max_iter=100; // Maximal number of iterations
peccu 0:b9ac53c439ed 1702 double min_step=1e-10; // Minimal step taken in line search
peccu 0:b9ac53c439ed 1703 double sigma=1e-12; // For numerically strict PD of Hessian
peccu 0:b9ac53c439ed 1704 double eps=1e-5;
peccu 0:b9ac53c439ed 1705 double hiTarget=(prior1+1.0)/(prior1+2.0);
peccu 0:b9ac53c439ed 1706 double loTarget=1/(prior0+2.0);
peccu 0:b9ac53c439ed 1707 double *t=Malloc(double,l);
peccu 0:b9ac53c439ed 1708 double fApB,p,q,h11,h22,h21,g1,g2,det,dA,dB,gd,stepsize;
peccu 0:b9ac53c439ed 1709 double newA,newB,newf,d1,d2;
peccu 0:b9ac53c439ed 1710 int iter;
peccu 0:b9ac53c439ed 1711
peccu 0:b9ac53c439ed 1712 // Initial Point and Initial Fun Value
peccu 0:b9ac53c439ed 1713 A=0.0; B=log((prior0+1.0)/(prior1+1.0));
peccu 0:b9ac53c439ed 1714 double fval = 0.0;
peccu 0:b9ac53c439ed 1715
peccu 0:b9ac53c439ed 1716 for (i=0;i<l;i++)
peccu 0:b9ac53c439ed 1717 {
peccu 0:b9ac53c439ed 1718 if (labels[i]>0) t[i]=hiTarget;
peccu 0:b9ac53c439ed 1719 else t[i]=loTarget;
peccu 0:b9ac53c439ed 1720 fApB = dec_values[i]*A+B;
peccu 0:b9ac53c439ed 1721 if (fApB>=0)
peccu 0:b9ac53c439ed 1722 fval += t[i]*fApB + log(1+exp(-fApB));
peccu 0:b9ac53c439ed 1723 else
peccu 0:b9ac53c439ed 1724 fval += (t[i] - 1)*fApB +log(1+exp(fApB));
peccu 0:b9ac53c439ed 1725 }
peccu 0:b9ac53c439ed 1726 for (iter=0;iter<max_iter;iter++)
peccu 0:b9ac53c439ed 1727 {
peccu 0:b9ac53c439ed 1728 // Update Gradient and Hessian (use H' = H + sigma I)
peccu 0:b9ac53c439ed 1729 h11=sigma; // numerically ensures strict PD
peccu 0:b9ac53c439ed 1730 h22=sigma;
peccu 0:b9ac53c439ed 1731 h21=0.0;g1=0.0;g2=0.0;
peccu 0:b9ac53c439ed 1732 for (i=0;i<l;i++)
peccu 0:b9ac53c439ed 1733 {
peccu 0:b9ac53c439ed 1734 fApB = dec_values[i]*A+B;
peccu 0:b9ac53c439ed 1735 if (fApB >= 0)
peccu 0:b9ac53c439ed 1736 {
peccu 0:b9ac53c439ed 1737 p=exp(-fApB)/(1.0+exp(-fApB));
peccu 0:b9ac53c439ed 1738 q=1.0/(1.0+exp(-fApB));
peccu 0:b9ac53c439ed 1739 }
peccu 0:b9ac53c439ed 1740 else
peccu 0:b9ac53c439ed 1741 {
peccu 0:b9ac53c439ed 1742 p=1.0/(1.0+exp(fApB));
peccu 0:b9ac53c439ed 1743 q=exp(fApB)/(1.0+exp(fApB));
peccu 0:b9ac53c439ed 1744 }
peccu 0:b9ac53c439ed 1745 d2=p*q;
peccu 0:b9ac53c439ed 1746 h11+=dec_values[i]*dec_values[i]*d2;
peccu 0:b9ac53c439ed 1747 h22+=d2;
peccu 0:b9ac53c439ed 1748 h21+=dec_values[i]*d2;
peccu 0:b9ac53c439ed 1749 d1=t[i]-p;
peccu 0:b9ac53c439ed 1750 g1+=dec_values[i]*d1;
peccu 0:b9ac53c439ed 1751 g2+=d1;
peccu 0:b9ac53c439ed 1752 }
peccu 0:b9ac53c439ed 1753
peccu 0:b9ac53c439ed 1754 // Stopping Criteria
peccu 0:b9ac53c439ed 1755 if (fabs(g1)<eps && fabs(g2)<eps)
peccu 0:b9ac53c439ed 1756 break;
peccu 0:b9ac53c439ed 1757
peccu 0:b9ac53c439ed 1758 // Finding Newton direction: -inv(H') * g
peccu 0:b9ac53c439ed 1759 det=h11*h22-h21*h21;
peccu 0:b9ac53c439ed 1760 dA=-(h22*g1 - h21 * g2) / det;
peccu 0:b9ac53c439ed 1761 dB=-(-h21*g1+ h11 * g2) / det;
peccu 0:b9ac53c439ed 1762 gd=g1*dA+g2*dB;
peccu 0:b9ac53c439ed 1763
peccu 0:b9ac53c439ed 1764
peccu 0:b9ac53c439ed 1765 stepsize = 1; // Line Search
peccu 0:b9ac53c439ed 1766 while (stepsize >= min_step)
peccu 0:b9ac53c439ed 1767 {
peccu 0:b9ac53c439ed 1768 newA = A + stepsize * dA;
peccu 0:b9ac53c439ed 1769 newB = B + stepsize * dB;
peccu 0:b9ac53c439ed 1770
peccu 0:b9ac53c439ed 1771 // New function value
peccu 0:b9ac53c439ed 1772 newf = 0.0;
peccu 0:b9ac53c439ed 1773 for (i=0;i<l;i++)
peccu 0:b9ac53c439ed 1774 {
peccu 0:b9ac53c439ed 1775 fApB = dec_values[i]*newA+newB;
peccu 0:b9ac53c439ed 1776 if (fApB >= 0)
peccu 0:b9ac53c439ed 1777 newf += t[i]*fApB + log(1+exp(-fApB));
peccu 0:b9ac53c439ed 1778 else
peccu 0:b9ac53c439ed 1779 newf += (t[i] - 1)*fApB +log(1+exp(fApB));
peccu 0:b9ac53c439ed 1780 }
peccu 0:b9ac53c439ed 1781 // Check sufficient decrease
peccu 0:b9ac53c439ed 1782 if (newf<fval+0.0001*stepsize*gd)
peccu 0:b9ac53c439ed 1783 {
peccu 0:b9ac53c439ed 1784 A=newA;B=newB;fval=newf;
peccu 0:b9ac53c439ed 1785 break;
peccu 0:b9ac53c439ed 1786 }
peccu 0:b9ac53c439ed 1787 else
peccu 0:b9ac53c439ed 1788 stepsize = stepsize / 2.0;
peccu 0:b9ac53c439ed 1789 }
peccu 0:b9ac53c439ed 1790
peccu 0:b9ac53c439ed 1791 if (stepsize < min_step)
peccu 0:b9ac53c439ed 1792 {
peccu 0:b9ac53c439ed 1793 info("Line search fails in two-class probability estimates\n");
peccu 0:b9ac53c439ed 1794 break;
peccu 0:b9ac53c439ed 1795 }
peccu 0:b9ac53c439ed 1796 }
peccu 0:b9ac53c439ed 1797
peccu 0:b9ac53c439ed 1798 if (iter>=max_iter)
peccu 0:b9ac53c439ed 1799 info("Reaching maximal iterations in two-class probability estimates\n");
peccu 0:b9ac53c439ed 1800 free(t);
peccu 0:b9ac53c439ed 1801 }
peccu 0:b9ac53c439ed 1802
// Evaluate the fitted sigmoid 1/(1+exp(A*f+B)) at decision value f.
// Of the two algebraically equal forms, it uses the one whose exp()
// argument is non-positive, so a large |A*f+B| cannot overflow.
static double sigmoid_predict(double decision_value, double A, double B)
{
	const double fApB = decision_value*A+B;
	if (fApB < 0)
		return 1.0/(1+exp(fApB));

	const double e = exp(-fApB);
	return e/(1.0+e);
}
peccu 0:b9ac53c439ed 1811
peccu 0:b9ac53c439ed 1812 // Method 2 from the multiclass_prob paper by Wu, Lin, and Weng
peccu 0:b9ac53c439ed 1813 static void multiclass_probability(int k, double **r, double *p)
peccu 0:b9ac53c439ed 1814 {
peccu 0:b9ac53c439ed 1815 int t,j;
peccu 0:b9ac53c439ed 1816 int iter = 0, max_iter=max(100,k);
peccu 0:b9ac53c439ed 1817 double **Q=Malloc(double *,k);
peccu 0:b9ac53c439ed 1818 double *Qp=Malloc(double,k);
peccu 0:b9ac53c439ed 1819 double pQp, eps=0.005/k;
peccu 0:b9ac53c439ed 1820
peccu 0:b9ac53c439ed 1821 for (t=0;t<k;t++)
peccu 0:b9ac53c439ed 1822 {
peccu 0:b9ac53c439ed 1823 p[t]=1.0/k; // Valid if k = 1
peccu 0:b9ac53c439ed 1824 Q[t]=Malloc(double,k);
peccu 0:b9ac53c439ed 1825 Q[t][t]=0;
peccu 0:b9ac53c439ed 1826 for (j=0;j<t;j++)
peccu 0:b9ac53c439ed 1827 {
peccu 0:b9ac53c439ed 1828 Q[t][t]+=r[j][t]*r[j][t];
peccu 0:b9ac53c439ed 1829 Q[t][j]=Q[j][t];
peccu 0:b9ac53c439ed 1830 }
peccu 0:b9ac53c439ed 1831 for (j=t+1;j<k;j++)
peccu 0:b9ac53c439ed 1832 {
peccu 0:b9ac53c439ed 1833 Q[t][t]+=r[j][t]*r[j][t];
peccu 0:b9ac53c439ed 1834 Q[t][j]=-r[j][t]*r[t][j];
peccu 0:b9ac53c439ed 1835 }
peccu 0:b9ac53c439ed 1836 }
peccu 0:b9ac53c439ed 1837 for (iter=0;iter<max_iter;iter++)
peccu 0:b9ac53c439ed 1838 {
peccu 0:b9ac53c439ed 1839 // stopping condition, recalculate QP,pQP for numerical accuracy
peccu 0:b9ac53c439ed 1840 pQp=0;
peccu 0:b9ac53c439ed 1841 for (t=0;t<k;t++)
peccu 0:b9ac53c439ed 1842 {
peccu 0:b9ac53c439ed 1843 Qp[t]=0;
peccu 0:b9ac53c439ed 1844 for (j=0;j<k;j++)
peccu 0:b9ac53c439ed 1845 Qp[t]+=Q[t][j]*p[j];
peccu 0:b9ac53c439ed 1846 pQp+=p[t]*Qp[t];
peccu 0:b9ac53c439ed 1847 }
peccu 0:b9ac53c439ed 1848 double max_error=0;
peccu 0:b9ac53c439ed 1849 for (t=0;t<k;t++)
peccu 0:b9ac53c439ed 1850 {
peccu 0:b9ac53c439ed 1851 double error=fabs(Qp[t]-pQp);
peccu 0:b9ac53c439ed 1852 if (error>max_error)
peccu 0:b9ac53c439ed 1853 max_error=error;
peccu 0:b9ac53c439ed 1854 }
peccu 0:b9ac53c439ed 1855 if (max_error<eps) break;
peccu 0:b9ac53c439ed 1856
peccu 0:b9ac53c439ed 1857 for (t=0;t<k;t++)
peccu 0:b9ac53c439ed 1858 {
peccu 0:b9ac53c439ed 1859 double diff=(-Qp[t]+pQp)/Q[t][t];
peccu 0:b9ac53c439ed 1860 p[t]+=diff;
peccu 0:b9ac53c439ed 1861 pQp=(pQp+diff*(diff*Q[t][t]+2*Qp[t]))/(1+diff)/(1+diff);
peccu 0:b9ac53c439ed 1862 for (j=0;j<k;j++)
peccu 0:b9ac53c439ed 1863 {
peccu 0:b9ac53c439ed 1864 Qp[j]=(Qp[j]+diff*Q[t][j])/(1+diff);
peccu 0:b9ac53c439ed 1865 p[j]/=(1+diff);
peccu 0:b9ac53c439ed 1866 }
peccu 0:b9ac53c439ed 1867 }
peccu 0:b9ac53c439ed 1868 }
peccu 0:b9ac53c439ed 1869 if (iter>=max_iter)
peccu 0:b9ac53c439ed 1870 info("Exceeds max_iter in multiclass_prob\n");
peccu 0:b9ac53c439ed 1871 for(t=0;t<k;t++) free(Q[t]);
peccu 0:b9ac53c439ed 1872 free(Q);
peccu 0:b9ac53c439ed 1873 free(Qp);
peccu 0:b9ac53c439ed 1874 }
peccu 0:b9ac53c439ed 1875
peccu 0:b9ac53c439ed 1876 // Cross-validation decision values for probability estimates
peccu 0:b9ac53c439ed 1877 static void svm_binary_svc_probability(
peccu 0:b9ac53c439ed 1878 const svm_problem *prob, const svm_parameter *param,
peccu 0:b9ac53c439ed 1879 double Cp, double Cn, double& probA, double& probB)
peccu 0:b9ac53c439ed 1880 {
peccu 0:b9ac53c439ed 1881 int i;
peccu 0:b9ac53c439ed 1882 int nr_fold = 5;
peccu 0:b9ac53c439ed 1883 int *perm = Malloc(int,prob->l);
peccu 0:b9ac53c439ed 1884 double *dec_values = Malloc(double,prob->l);
peccu 0:b9ac53c439ed 1885
peccu 0:b9ac53c439ed 1886 // random shuffle
peccu 0:b9ac53c439ed 1887 for(i=0;i<prob->l;i++) perm[i]=i;
peccu 0:b9ac53c439ed 1888 for(i=0;i<prob->l;i++)
peccu 0:b9ac53c439ed 1889 {
peccu 0:b9ac53c439ed 1890 int j = i+rand()%(prob->l-i);
peccu 0:b9ac53c439ed 1891 swap(perm[i],perm[j]);
peccu 0:b9ac53c439ed 1892 }
peccu 0:b9ac53c439ed 1893 for(i=0;i<nr_fold;i++)
peccu 0:b9ac53c439ed 1894 {
peccu 0:b9ac53c439ed 1895 int begin = i*prob->l/nr_fold;
peccu 0:b9ac53c439ed 1896 int end = (i+1)*prob->l/nr_fold;
peccu 0:b9ac53c439ed 1897 int j,k;
peccu 0:b9ac53c439ed 1898 struct svm_problem subprob;
peccu 0:b9ac53c439ed 1899
peccu 0:b9ac53c439ed 1900 subprob.l = prob->l-(end-begin);
peccu 0:b9ac53c439ed 1901 subprob.x = Malloc(struct svm_node*,subprob.l);
peccu 0:b9ac53c439ed 1902 subprob.y = Malloc(double,subprob.l);
peccu 0:b9ac53c439ed 1903
peccu 0:b9ac53c439ed 1904 k=0;
peccu 0:b9ac53c439ed 1905 for(j=0;j<begin;j++)
peccu 0:b9ac53c439ed 1906 {
peccu 0:b9ac53c439ed 1907 subprob.x[k] = prob->x[perm[j]];
peccu 0:b9ac53c439ed 1908 subprob.y[k] = prob->y[perm[j]];
peccu 0:b9ac53c439ed 1909 ++k;
peccu 0:b9ac53c439ed 1910 }
peccu 0:b9ac53c439ed 1911 for(j=end;j<prob->l;j++)
peccu 0:b9ac53c439ed 1912 {
peccu 0:b9ac53c439ed 1913 subprob.x[k] = prob->x[perm[j]];
peccu 0:b9ac53c439ed 1914 subprob.y[k] = prob->y[perm[j]];
peccu 0:b9ac53c439ed 1915 ++k;
peccu 0:b9ac53c439ed 1916 }
peccu 0:b9ac53c439ed 1917 int p_count=0,n_count=0;
peccu 0:b9ac53c439ed 1918 for(j=0;j<k;j++)
peccu 0:b9ac53c439ed 1919 if(subprob.y[j]>0)
peccu 0:b9ac53c439ed 1920 p_count++;
peccu 0:b9ac53c439ed 1921 else
peccu 0:b9ac53c439ed 1922 n_count++;
peccu 0:b9ac53c439ed 1923
peccu 0:b9ac53c439ed 1924 if(p_count==0 && n_count==0)
peccu 0:b9ac53c439ed 1925 for(j=begin;j<end;j++)
peccu 0:b9ac53c439ed 1926 dec_values[perm[j]] = 0;
peccu 0:b9ac53c439ed 1927 else if(p_count > 0 && n_count == 0)
peccu 0:b9ac53c439ed 1928 for(j=begin;j<end;j++)
peccu 0:b9ac53c439ed 1929 dec_values[perm[j]] = 1;
peccu 0:b9ac53c439ed 1930 else if(p_count == 0 && n_count > 0)
peccu 0:b9ac53c439ed 1931 for(j=begin;j<end;j++)
peccu 0:b9ac53c439ed 1932 dec_values[perm[j]] = -1;
peccu 0:b9ac53c439ed 1933 else
peccu 0:b9ac53c439ed 1934 {
peccu 0:b9ac53c439ed 1935 svm_parameter subparam = *param;
peccu 0:b9ac53c439ed 1936 subparam.probability=0;
peccu 0:b9ac53c439ed 1937 subparam.C=1.0;
peccu 0:b9ac53c439ed 1938 subparam.nr_weight=2;
peccu 0:b9ac53c439ed 1939 subparam.weight_label = Malloc(int,2);
peccu 0:b9ac53c439ed 1940 subparam.weight = Malloc(double,2);
peccu 0:b9ac53c439ed 1941 subparam.weight_label[0]=+1;
peccu 0:b9ac53c439ed 1942 subparam.weight_label[1]=-1;
peccu 0:b9ac53c439ed 1943 subparam.weight[0]=Cp;
peccu 0:b9ac53c439ed 1944 subparam.weight[1]=Cn;
peccu 0:b9ac53c439ed 1945 struct svm_model *submodel = svm_train(&subprob,&subparam);
peccu 0:b9ac53c439ed 1946 for(j=begin;j<end;j++)
peccu 0:b9ac53c439ed 1947 {
peccu 0:b9ac53c439ed 1948 svm_predict_values(submodel,prob->x[perm[j]],&(dec_values[perm[j]]));
peccu 0:b9ac53c439ed 1949 // ensure +1 -1 order; reason not using CV subroutine
peccu 0:b9ac53c439ed 1950 dec_values[perm[j]] *= submodel->label[0];
peccu 0:b9ac53c439ed 1951 }
peccu 0:b9ac53c439ed 1952 svm_free_and_destroy_model(&submodel);
peccu 0:b9ac53c439ed 1953 svm_destroy_param(&subparam);
peccu 0:b9ac53c439ed 1954 }
peccu 0:b9ac53c439ed 1955 free(subprob.x);
peccu 0:b9ac53c439ed 1956 free(subprob.y);
peccu 0:b9ac53c439ed 1957 }
peccu 0:b9ac53c439ed 1958 sigmoid_train(prob->l,dec_values,prob->y,probA,probB);
peccu 0:b9ac53c439ed 1959 free(dec_values);
peccu 0:b9ac53c439ed 1960 free(perm);
peccu 0:b9ac53c439ed 1961 }
peccu 0:b9ac53c439ed 1962
peccu 0:b9ac53c439ed 1963 // Return parameter of a Laplace distribution
peccu 0:b9ac53c439ed 1964 static double svm_svr_probability(
peccu 0:b9ac53c439ed 1965 const svm_problem *prob, const svm_parameter *param)
peccu 0:b9ac53c439ed 1966 {
peccu 0:b9ac53c439ed 1967 int i;
peccu 0:b9ac53c439ed 1968 int nr_fold = 5;
peccu 0:b9ac53c439ed 1969 double *ymv = Malloc(double,prob->l);
peccu 0:b9ac53c439ed 1970 double mae = 0;
peccu 0:b9ac53c439ed 1971
peccu 0:b9ac53c439ed 1972 svm_parameter newparam = *param;
peccu 0:b9ac53c439ed 1973 newparam.probability = 0;
peccu 0:b9ac53c439ed 1974 svm_cross_validation(prob,&newparam,nr_fold,ymv);
peccu 0:b9ac53c439ed 1975 for(i=0;i<prob->l;i++)
peccu 0:b9ac53c439ed 1976 {
peccu 0:b9ac53c439ed 1977 ymv[i]=prob->y[i]-ymv[i];
peccu 0:b9ac53c439ed 1978 mae += fabs(ymv[i]);
peccu 0:b9ac53c439ed 1979 }
peccu 0:b9ac53c439ed 1980 mae /= prob->l;
peccu 0:b9ac53c439ed 1981 double std=sqrt(2*mae*mae);
peccu 0:b9ac53c439ed 1982 int count=0;
peccu 0:b9ac53c439ed 1983 mae=0;
peccu 0:b9ac53c439ed 1984 for(i=0;i<prob->l;i++)
peccu 0:b9ac53c439ed 1985 if (fabs(ymv[i]) > 5*std)
peccu 0:b9ac53c439ed 1986 count=count+1;
peccu 0:b9ac53c439ed 1987 else
peccu 0:b9ac53c439ed 1988 mae+=fabs(ymv[i]);
peccu 0:b9ac53c439ed 1989 mae /= (prob->l-count);
peccu 0:b9ac53c439ed 1990 info("Prob. model for test data: target value = predicted value + z,\nz: Laplace distribution e^(-|z|/sigma)/(2sigma),sigma= %g\n",mae);
peccu 0:b9ac53c439ed 1991 free(ymv);
peccu 0:b9ac53c439ed 1992 return mae;
peccu 0:b9ac53c439ed 1993 }
peccu 0:b9ac53c439ed 1994
peccu 0:b9ac53c439ed 1995
peccu 0:b9ac53c439ed 1996 // label: label name, start: begin of each class, count: #data of classes, perm: indices to the original data
peccu 0:b9ac53c439ed 1997 // perm, length l, must be allocated before calling this subroutine
peccu 0:b9ac53c439ed 1998 static void svm_group_classes(const svm_problem *prob, int *nr_class_ret, int **label_ret, int **start_ret, int **count_ret, int *perm)
peccu 0:b9ac53c439ed 1999 {
peccu 0:b9ac53c439ed 2000 int l = prob->l;
peccu 0:b9ac53c439ed 2001 int max_nr_class = 16;
peccu 0:b9ac53c439ed 2002 int nr_class = 0;
peccu 0:b9ac53c439ed 2003 int *label = Malloc(int,max_nr_class);
peccu 0:b9ac53c439ed 2004 int *count = Malloc(int,max_nr_class);
peccu 0:b9ac53c439ed 2005 int *data_label = Malloc(int,l);
peccu 0:b9ac53c439ed 2006 int i;
peccu 0:b9ac53c439ed 2007
peccu 0:b9ac53c439ed 2008 for(i=0;i<l;i++)
peccu 0:b9ac53c439ed 2009 {
peccu 0:b9ac53c439ed 2010 int this_label = (int)prob->y[i];
peccu 0:b9ac53c439ed 2011 int j;
peccu 0:b9ac53c439ed 2012 for(j=0;j<nr_class;j++)
peccu 0:b9ac53c439ed 2013 {
peccu 0:b9ac53c439ed 2014 if(this_label == label[j])
peccu 0:b9ac53c439ed 2015 {
peccu 0:b9ac53c439ed 2016 ++count[j];
peccu 0:b9ac53c439ed 2017 break;
peccu 0:b9ac53c439ed 2018 }
peccu 0:b9ac53c439ed 2019 }
peccu 0:b9ac53c439ed 2020 data_label[i] = j;
peccu 0:b9ac53c439ed 2021 if(j == nr_class)
peccu 0:b9ac53c439ed 2022 {
peccu 0:b9ac53c439ed 2023 if(nr_class == max_nr_class)
peccu 0:b9ac53c439ed 2024 {
peccu 0:b9ac53c439ed 2025 max_nr_class *= 2;
peccu 0:b9ac53c439ed 2026 label = (int *)realloc(label,max_nr_class*sizeof(int));
peccu 0:b9ac53c439ed 2027 count = (int *)realloc(count,max_nr_class*sizeof(int));
peccu 0:b9ac53c439ed 2028 }
peccu 0:b9ac53c439ed 2029 label[nr_class] = this_label;
peccu 0:b9ac53c439ed 2030 count[nr_class] = 1;
peccu 0:b9ac53c439ed 2031 ++nr_class;
peccu 0:b9ac53c439ed 2032 }
peccu 0:b9ac53c439ed 2033 }
peccu 0:b9ac53c439ed 2034
peccu 0:b9ac53c439ed 2035 int *start = Malloc(int,nr_class);
peccu 0:b9ac53c439ed 2036 start[0] = 0;
peccu 0:b9ac53c439ed 2037 for(i=1;i<nr_class;i++)
peccu 0:b9ac53c439ed 2038 start[i] = start[i-1]+count[i-1];
peccu 0:b9ac53c439ed 2039 for(i=0;i<l;i++)
peccu 0:b9ac53c439ed 2040 {
peccu 0:b9ac53c439ed 2041 perm[start[data_label[i]]] = i;
peccu 0:b9ac53c439ed 2042 ++start[data_label[i]];
peccu 0:b9ac53c439ed 2043 }
peccu 0:b9ac53c439ed 2044 start[0] = 0;
peccu 0:b9ac53c439ed 2045 for(i=1;i<nr_class;i++)
peccu 0:b9ac53c439ed 2046 start[i] = start[i-1]+count[i-1];
peccu 0:b9ac53c439ed 2047
peccu 0:b9ac53c439ed 2048 *nr_class_ret = nr_class;
peccu 0:b9ac53c439ed 2049 *label_ret = label;
peccu 0:b9ac53c439ed 2050 *start_ret = start;
peccu 0:b9ac53c439ed 2051 *count_ret = count;
peccu 0:b9ac53c439ed 2052 free(data_label);
peccu 0:b9ac53c439ed 2053 }
peccu 0:b9ac53c439ed 2054
peccu 0:b9ac53c439ed 2055 //
peccu 0:b9ac53c439ed 2056 // Interface functions
peccu 0:b9ac53c439ed 2057 //
peccu 0:b9ac53c439ed 2058 svm_model *svm_train(const svm_problem *prob, const svm_parameter *param)
peccu 0:b9ac53c439ed 2059 {
peccu 0:b9ac53c439ed 2060 svm_model *model = Malloc(svm_model,1);
peccu 0:b9ac53c439ed 2061 model->param = *param;
peccu 0:b9ac53c439ed 2062 model->free_sv = 0; // XXX
peccu 0:b9ac53c439ed 2063
peccu 0:b9ac53c439ed 2064 if(param->svm_type == ONE_CLASS ||
peccu 0:b9ac53c439ed 2065 param->svm_type == EPSILON_SVR ||
peccu 0:b9ac53c439ed 2066 param->svm_type == NU_SVR)
peccu 0:b9ac53c439ed 2067 {
peccu 0:b9ac53c439ed 2068 // regression or one-class-svm
peccu 0:b9ac53c439ed 2069 model->nr_class = 2;
peccu 0:b9ac53c439ed 2070 model->label = NULL;
peccu 0:b9ac53c439ed 2071 model->nSV = NULL;
peccu 0:b9ac53c439ed 2072 model->probA = NULL; model->probB = NULL;
peccu 0:b9ac53c439ed 2073 model->sv_coef = Malloc(double *,1);
peccu 0:b9ac53c439ed 2074
peccu 0:b9ac53c439ed 2075 if(param->probability &&
peccu 0:b9ac53c439ed 2076 (param->svm_type == EPSILON_SVR ||
peccu 0:b9ac53c439ed 2077 param->svm_type == NU_SVR))
peccu 0:b9ac53c439ed 2078 {
peccu 0:b9ac53c439ed 2079 model->probA = Malloc(double,1);
peccu 0:b9ac53c439ed 2080 model->probA[0] = svm_svr_probability(prob,param);
peccu 0:b9ac53c439ed 2081 }
peccu 0:b9ac53c439ed 2082
peccu 0:b9ac53c439ed 2083 decision_function f = svm_train_one(prob,param,0,0);
peccu 0:b9ac53c439ed 2084 model->rho = Malloc(double,1);
peccu 0:b9ac53c439ed 2085 model->rho[0] = f.rho;
peccu 0:b9ac53c439ed 2086
peccu 0:b9ac53c439ed 2087 int nSV = 0;
peccu 0:b9ac53c439ed 2088 int i;
peccu 0:b9ac53c439ed 2089 for(i=0;i<prob->l;i++)
peccu 0:b9ac53c439ed 2090 if(fabs(f.alpha[i]) > 0) ++nSV;
peccu 0:b9ac53c439ed 2091 model->l = nSV;
peccu 0:b9ac53c439ed 2092 model->SV = Malloc(svm_node *,nSV);
peccu 0:b9ac53c439ed 2093 model->sv_coef[0] = Malloc(double,nSV);
peccu 0:b9ac53c439ed 2094 int j = 0;
peccu 0:b9ac53c439ed 2095 for(i=0;i<prob->l;i++)
peccu 0:b9ac53c439ed 2096 if(fabs(f.alpha[i]) > 0)
peccu 0:b9ac53c439ed 2097 {
peccu 0:b9ac53c439ed 2098 model->SV[j] = prob->x[i];
peccu 0:b9ac53c439ed 2099 model->sv_coef[0][j] = f.alpha[i];
peccu 0:b9ac53c439ed 2100 ++j;
peccu 0:b9ac53c439ed 2101 }
peccu 0:b9ac53c439ed 2102
peccu 0:b9ac53c439ed 2103 free(f.alpha);
peccu 0:b9ac53c439ed 2104 }
peccu 0:b9ac53c439ed 2105 else
peccu 0:b9ac53c439ed 2106 {
peccu 0:b9ac53c439ed 2107 // classification
peccu 0:b9ac53c439ed 2108 int l = prob->l;
peccu 0:b9ac53c439ed 2109 int nr_class;
peccu 0:b9ac53c439ed 2110 int *label = NULL;
peccu 0:b9ac53c439ed 2111 int *start = NULL;
peccu 0:b9ac53c439ed 2112 int *count = NULL;
peccu 0:b9ac53c439ed 2113 int *perm = Malloc(int,l);
peccu 0:b9ac53c439ed 2114
peccu 0:b9ac53c439ed 2115 // group training data of the same class
peccu 0:b9ac53c439ed 2116 svm_group_classes(prob,&nr_class,&label,&start,&count,perm);
peccu 0:b9ac53c439ed 2117 svm_node **x = Malloc(svm_node *,l);
peccu 0:b9ac53c439ed 2118 int i;
peccu 0:b9ac53c439ed 2119 for(i=0;i<l;i++)
peccu 0:b9ac53c439ed 2120 x[i] = prob->x[perm[i]];
peccu 0:b9ac53c439ed 2121
peccu 0:b9ac53c439ed 2122 // calculate weighted C
peccu 0:b9ac53c439ed 2123
peccu 0:b9ac53c439ed 2124 double *weighted_C = Malloc(double, nr_class);
peccu 0:b9ac53c439ed 2125 for(i=0;i<nr_class;i++)
peccu 0:b9ac53c439ed 2126 weighted_C[i] = param->C;
peccu 0:b9ac53c439ed 2127 for(i=0;i<param->nr_weight;i++)
peccu 0:b9ac53c439ed 2128 {
peccu 0:b9ac53c439ed 2129 int j;
peccu 0:b9ac53c439ed 2130 for(j=0;j<nr_class;j++)
peccu 0:b9ac53c439ed 2131 if(param->weight_label[i] == label[j])
peccu 0:b9ac53c439ed 2132 break;
peccu 0:b9ac53c439ed 2133 if(j == nr_class)
peccu 0:b9ac53c439ed 2134 fprintf(stderr,"warning: class label %d specified in weight is not found\n", param->weight_label[i]);
peccu 0:b9ac53c439ed 2135 else
peccu 0:b9ac53c439ed 2136 weighted_C[j] *= param->weight[i];
peccu 0:b9ac53c439ed 2137 }
peccu 0:b9ac53c439ed 2138
peccu 0:b9ac53c439ed 2139 // train k*(k-1)/2 models
peccu 0:b9ac53c439ed 2140
peccu 0:b9ac53c439ed 2141 bool *nonzero = Malloc(bool,l);
peccu 0:b9ac53c439ed 2142 for(i=0;i<l;i++)
peccu 0:b9ac53c439ed 2143 nonzero[i] = false;
peccu 0:b9ac53c439ed 2144 decision_function *f = Malloc(decision_function,nr_class*(nr_class-1)/2);
peccu 0:b9ac53c439ed 2145
peccu 0:b9ac53c439ed 2146 double *probA=NULL,*probB=NULL;
peccu 0:b9ac53c439ed 2147 if (param->probability)
peccu 0:b9ac53c439ed 2148 {
peccu 0:b9ac53c439ed 2149 probA=Malloc(double,nr_class*(nr_class-1)/2);
peccu 0:b9ac53c439ed 2150 probB=Malloc(double,nr_class*(nr_class-1)/2);
peccu 0:b9ac53c439ed 2151 }
peccu 0:b9ac53c439ed 2152
peccu 0:b9ac53c439ed 2153 int p = 0;
peccu 0:b9ac53c439ed 2154 for(i=0;i<nr_class;i++)
peccu 0:b9ac53c439ed 2155 for(int j=i+1;j<nr_class;j++)
peccu 0:b9ac53c439ed 2156 {
peccu 0:b9ac53c439ed 2157 svm_problem sub_prob;
peccu 0:b9ac53c439ed 2158 int si = start[i], sj = start[j];
peccu 0:b9ac53c439ed 2159 int ci = count[i], cj = count[j];
peccu 0:b9ac53c439ed 2160 sub_prob.l = ci+cj;
peccu 0:b9ac53c439ed 2161 sub_prob.x = Malloc(svm_node *,sub_prob.l);
peccu 0:b9ac53c439ed 2162 sub_prob.y = Malloc(double,sub_prob.l);
peccu 0:b9ac53c439ed 2163 int k;
peccu 0:b9ac53c439ed 2164 for(k=0;k<ci;k++)
peccu 0:b9ac53c439ed 2165 {
peccu 0:b9ac53c439ed 2166 sub_prob.x[k] = x[si+k];
peccu 0:b9ac53c439ed 2167 sub_prob.y[k] = +1;
peccu 0:b9ac53c439ed 2168 }
peccu 0:b9ac53c439ed 2169 for(k=0;k<cj;k++)
peccu 0:b9ac53c439ed 2170 {
peccu 0:b9ac53c439ed 2171 sub_prob.x[ci+k] = x[sj+k];
peccu 0:b9ac53c439ed 2172 sub_prob.y[ci+k] = -1;
peccu 0:b9ac53c439ed 2173 }
peccu 0:b9ac53c439ed 2174
peccu 0:b9ac53c439ed 2175 if(param->probability)
peccu 0:b9ac53c439ed 2176 svm_binary_svc_probability(&sub_prob,param,weighted_C[i],weighted_C[j],probA[p],probB[p]);
peccu 0:b9ac53c439ed 2177
peccu 0:b9ac53c439ed 2178 f[p] = svm_train_one(&sub_prob,param,weighted_C[i],weighted_C[j]);
peccu 0:b9ac53c439ed 2179 for(k=0;k<ci;k++)
peccu 0:b9ac53c439ed 2180 if(!nonzero[si+k] && fabs(f[p].alpha[k]) > 0)
peccu 0:b9ac53c439ed 2181 nonzero[si+k] = true;
peccu 0:b9ac53c439ed 2182 for(k=0;k<cj;k++)
peccu 0:b9ac53c439ed 2183 if(!nonzero[sj+k] && fabs(f[p].alpha[ci+k]) > 0)
peccu 0:b9ac53c439ed 2184 nonzero[sj+k] = true;
peccu 0:b9ac53c439ed 2185 free(sub_prob.x);
peccu 0:b9ac53c439ed 2186 free(sub_prob.y);
peccu 0:b9ac53c439ed 2187 ++p;
peccu 0:b9ac53c439ed 2188 }
peccu 0:b9ac53c439ed 2189
peccu 0:b9ac53c439ed 2190 // build output
peccu 0:b9ac53c439ed 2191
peccu 0:b9ac53c439ed 2192 model->nr_class = nr_class;
peccu 0:b9ac53c439ed 2193
peccu 0:b9ac53c439ed 2194 model->label = Malloc(int,nr_class);
peccu 0:b9ac53c439ed 2195 for(i=0;i<nr_class;i++)
peccu 0:b9ac53c439ed 2196 model->label[i] = label[i];
peccu 0:b9ac53c439ed 2197
peccu 0:b9ac53c439ed 2198 model->rho = Malloc(double,nr_class*(nr_class-1)/2);
peccu 0:b9ac53c439ed 2199 for(i=0;i<nr_class*(nr_class-1)/2;i++)
peccu 0:b9ac53c439ed 2200 model->rho[i] = f[i].rho;
peccu 0:b9ac53c439ed 2201
peccu 0:b9ac53c439ed 2202 if(param->probability)
peccu 0:b9ac53c439ed 2203 {
peccu 0:b9ac53c439ed 2204 model->probA = Malloc(double,nr_class*(nr_class-1)/2);
peccu 0:b9ac53c439ed 2205 model->probB = Malloc(double,nr_class*(nr_class-1)/2);
peccu 0:b9ac53c439ed 2206 for(i=0;i<nr_class*(nr_class-1)/2;i++)
peccu 0:b9ac53c439ed 2207 {
peccu 0:b9ac53c439ed 2208 model->probA[i] = probA[i];
peccu 0:b9ac53c439ed 2209 model->probB[i] = probB[i];
peccu 0:b9ac53c439ed 2210 }
peccu 0:b9ac53c439ed 2211 }
peccu 0:b9ac53c439ed 2212 else
peccu 0:b9ac53c439ed 2213 {
peccu 0:b9ac53c439ed 2214 model->probA=NULL;
peccu 0:b9ac53c439ed 2215 model->probB=NULL;
peccu 0:b9ac53c439ed 2216 }
peccu 0:b9ac53c439ed 2217
peccu 0:b9ac53c439ed 2218 int total_sv = 0;
peccu 0:b9ac53c439ed 2219 int *nz_count = Malloc(int,nr_class);
peccu 0:b9ac53c439ed 2220 model->nSV = Malloc(int,nr_class);
peccu 0:b9ac53c439ed 2221 for(i=0;i<nr_class;i++)
peccu 0:b9ac53c439ed 2222 {
peccu 0:b9ac53c439ed 2223 int nSV = 0;
peccu 0:b9ac53c439ed 2224 for(int j=0;j<count[i];j++)
peccu 0:b9ac53c439ed 2225 if(nonzero[start[i]+j])
peccu 0:b9ac53c439ed 2226 {
peccu 0:b9ac53c439ed 2227 ++nSV;
peccu 0:b9ac53c439ed 2228 ++total_sv;
peccu 0:b9ac53c439ed 2229 }
peccu 0:b9ac53c439ed 2230 model->nSV[i] = nSV;
peccu 0:b9ac53c439ed 2231 nz_count[i] = nSV;
peccu 0:b9ac53c439ed 2232 }
peccu 0:b9ac53c439ed 2233
peccu 0:b9ac53c439ed 2234 info("Total nSV = %d\n",total_sv);
peccu 0:b9ac53c439ed 2235
peccu 0:b9ac53c439ed 2236 model->l = total_sv;
peccu 0:b9ac53c439ed 2237 model->SV = Malloc(svm_node *,total_sv);
peccu 0:b9ac53c439ed 2238 p = 0;
peccu 0:b9ac53c439ed 2239 for(i=0;i<l;i++)
peccu 0:b9ac53c439ed 2240 if(nonzero[i]) model->SV[p++] = x[i];
peccu 0:b9ac53c439ed 2241
peccu 0:b9ac53c439ed 2242 int *nz_start = Malloc(int,nr_class);
peccu 0:b9ac53c439ed 2243 nz_start[0] = 0;
peccu 0:b9ac53c439ed 2244 for(i=1;i<nr_class;i++)
peccu 0:b9ac53c439ed 2245 nz_start[i] = nz_start[i-1]+nz_count[i-1];
peccu 0:b9ac53c439ed 2246
peccu 0:b9ac53c439ed 2247 model->sv_coef = Malloc(double *,nr_class-1);
peccu 0:b9ac53c439ed 2248 for(i=0;i<nr_class-1;i++)
peccu 0:b9ac53c439ed 2249 model->sv_coef[i] = Malloc(double,total_sv);
peccu 0:b9ac53c439ed 2250
peccu 0:b9ac53c439ed 2251 p = 0;
peccu 0:b9ac53c439ed 2252 for(i=0;i<nr_class;i++)
peccu 0:b9ac53c439ed 2253 for(int j=i+1;j<nr_class;j++)
peccu 0:b9ac53c439ed 2254 {
peccu 0:b9ac53c439ed 2255 // classifier (i,j): coefficients with
peccu 0:b9ac53c439ed 2256 // i are in sv_coef[j-1][nz_start[i]...],
peccu 0:b9ac53c439ed 2257 // j are in sv_coef[i][nz_start[j]...]
peccu 0:b9ac53c439ed 2258
peccu 0:b9ac53c439ed 2259 int si = start[i];
peccu 0:b9ac53c439ed 2260 int sj = start[j];
peccu 0:b9ac53c439ed 2261 int ci = count[i];
peccu 0:b9ac53c439ed 2262 int cj = count[j];
peccu 0:b9ac53c439ed 2263
peccu 0:b9ac53c439ed 2264 int q = nz_start[i];
peccu 0:b9ac53c439ed 2265 int k;
peccu 0:b9ac53c439ed 2266 for(k=0;k<ci;k++)
peccu 0:b9ac53c439ed 2267 if(nonzero[si+k])
peccu 0:b9ac53c439ed 2268 model->sv_coef[j-1][q++] = f[p].alpha[k];
peccu 0:b9ac53c439ed 2269 q = nz_start[j];
peccu 0:b9ac53c439ed 2270 for(k=0;k<cj;k++)
peccu 0:b9ac53c439ed 2271 if(nonzero[sj+k])
peccu 0:b9ac53c439ed 2272 model->sv_coef[i][q++] = f[p].alpha[ci+k];
peccu 0:b9ac53c439ed 2273 ++p;
peccu 0:b9ac53c439ed 2274 }
peccu 0:b9ac53c439ed 2275
peccu 0:b9ac53c439ed 2276 free(label);
peccu 0:b9ac53c439ed 2277 free(probA);
peccu 0:b9ac53c439ed 2278 free(probB);
peccu 0:b9ac53c439ed 2279 free(count);
peccu 0:b9ac53c439ed 2280 free(perm);
peccu 0:b9ac53c439ed 2281 free(start);
peccu 0:b9ac53c439ed 2282 free(x);
peccu 0:b9ac53c439ed 2283 free(weighted_C);
peccu 0:b9ac53c439ed 2284 free(nonzero);
peccu 0:b9ac53c439ed 2285 for(i=0;i<nr_class*(nr_class-1)/2;i++)
peccu 0:b9ac53c439ed 2286 free(f[i].alpha);
peccu 0:b9ac53c439ed 2287 free(f);
peccu 0:b9ac53c439ed 2288 free(nz_count);
peccu 0:b9ac53c439ed 2289 free(nz_start);
peccu 0:b9ac53c439ed 2290 }
peccu 0:b9ac53c439ed 2291 return model;
peccu 0:b9ac53c439ed 2292 }
peccu 0:b9ac53c439ed 2293
peccu 0:b9ac53c439ed 2294 // Stratified cross validation
peccu 0:b9ac53c439ed 2295 void svm_cross_validation(const svm_problem *prob, const svm_parameter *param, int nr_fold, double *target)
peccu 0:b9ac53c439ed 2296 {
peccu 0:b9ac53c439ed 2297 int i;
peccu 0:b9ac53c439ed 2298 int *fold_start = Malloc(int,nr_fold+1);
peccu 0:b9ac53c439ed 2299 int l = prob->l;
peccu 0:b9ac53c439ed 2300 int *perm = Malloc(int,l);
peccu 0:b9ac53c439ed 2301 int nr_class;
peccu 0:b9ac53c439ed 2302
peccu 0:b9ac53c439ed 2303 // stratified cv may not give leave-one-out rate
peccu 0:b9ac53c439ed 2304 // Each class to l folds -> some folds may have zero elements
peccu 0:b9ac53c439ed 2305 if((param->svm_type == C_SVC ||
peccu 0:b9ac53c439ed 2306 param->svm_type == NU_SVC) && nr_fold < l)
peccu 0:b9ac53c439ed 2307 {
peccu 0:b9ac53c439ed 2308 int *start = NULL;
peccu 0:b9ac53c439ed 2309 int *label = NULL;
peccu 0:b9ac53c439ed 2310 int *count = NULL;
peccu 0:b9ac53c439ed 2311 svm_group_classes(prob,&nr_class,&label,&start,&count,perm);
peccu 0:b9ac53c439ed 2312
peccu 0:b9ac53c439ed 2313 // random shuffle and then data grouped by fold using the array perm
peccu 0:b9ac53c439ed 2314 int *fold_count = Malloc(int,nr_fold);
peccu 0:b9ac53c439ed 2315 int c;
peccu 0:b9ac53c439ed 2316 int *index = Malloc(int,l);
peccu 0:b9ac53c439ed 2317 for(i=0;i<l;i++)
peccu 0:b9ac53c439ed 2318 index[i]=perm[i];
peccu 0:b9ac53c439ed 2319 for (c=0; c<nr_class; c++)
peccu 0:b9ac53c439ed 2320 for(i=0;i<count[c];i++)
peccu 0:b9ac53c439ed 2321 {
peccu 0:b9ac53c439ed 2322 int j = i+rand()%(count[c]-i);
peccu 0:b9ac53c439ed 2323 swap(index[start[c]+j],index[start[c]+i]);
peccu 0:b9ac53c439ed 2324 }
peccu 0:b9ac53c439ed 2325 for(i=0;i<nr_fold;i++)
peccu 0:b9ac53c439ed 2326 {
peccu 0:b9ac53c439ed 2327 fold_count[i] = 0;
peccu 0:b9ac53c439ed 2328 for (c=0; c<nr_class;c++)
peccu 0:b9ac53c439ed 2329 fold_count[i]+=(i+1)*count[c]/nr_fold-i*count[c]/nr_fold;
peccu 0:b9ac53c439ed 2330 }
peccu 0:b9ac53c439ed 2331 fold_start[0]=0;
peccu 0:b9ac53c439ed 2332 for (i=1;i<=nr_fold;i++)
peccu 0:b9ac53c439ed 2333 fold_start[i] = fold_start[i-1]+fold_count[i-1];
peccu 0:b9ac53c439ed 2334 for (c=0; c<nr_class;c++)
peccu 0:b9ac53c439ed 2335 for(i=0;i<nr_fold;i++)
peccu 0:b9ac53c439ed 2336 {
peccu 0:b9ac53c439ed 2337 int begin = start[c]+i*count[c]/nr_fold;
peccu 0:b9ac53c439ed 2338 int end = start[c]+(i+1)*count[c]/nr_fold;
peccu 0:b9ac53c439ed 2339 for(int j=begin;j<end;j++)
peccu 0:b9ac53c439ed 2340 {
peccu 0:b9ac53c439ed 2341 perm[fold_start[i]] = index[j];
peccu 0:b9ac53c439ed 2342 fold_start[i]++;
peccu 0:b9ac53c439ed 2343 }
peccu 0:b9ac53c439ed 2344 }
peccu 0:b9ac53c439ed 2345 fold_start[0]=0;
peccu 0:b9ac53c439ed 2346 for (i=1;i<=nr_fold;i++)
peccu 0:b9ac53c439ed 2347 fold_start[i] = fold_start[i-1]+fold_count[i-1];
peccu 0:b9ac53c439ed 2348 free(start);
peccu 0:b9ac53c439ed 2349 free(label);
peccu 0:b9ac53c439ed 2350 free(count);
peccu 0:b9ac53c439ed 2351 free(index);
peccu 0:b9ac53c439ed 2352 free(fold_count);
peccu 0:b9ac53c439ed 2353 }
peccu 0:b9ac53c439ed 2354 else
peccu 0:b9ac53c439ed 2355 {
peccu 0:b9ac53c439ed 2356 for(i=0;i<l;i++) perm[i]=i;
peccu 0:b9ac53c439ed 2357 for(i=0;i<l;i++)
peccu 0:b9ac53c439ed 2358 {
peccu 0:b9ac53c439ed 2359 int j = i+rand()%(l-i);
peccu 0:b9ac53c439ed 2360 swap(perm[i],perm[j]);
peccu 0:b9ac53c439ed 2361 }
peccu 0:b9ac53c439ed 2362 for(i=0;i<=nr_fold;i++)
peccu 0:b9ac53c439ed 2363 fold_start[i]=i*l/nr_fold;
peccu 0:b9ac53c439ed 2364 }
peccu 0:b9ac53c439ed 2365
peccu 0:b9ac53c439ed 2366 for(i=0;i<nr_fold;i++)
peccu 0:b9ac53c439ed 2367 {
peccu 0:b9ac53c439ed 2368 int begin = fold_start[i];
peccu 0:b9ac53c439ed 2369 int end = fold_start[i+1];
peccu 0:b9ac53c439ed 2370 int j,k;
peccu 0:b9ac53c439ed 2371 struct svm_problem subprob;
peccu 0:b9ac53c439ed 2372
peccu 0:b9ac53c439ed 2373 subprob.l = l-(end-begin);
peccu 0:b9ac53c439ed 2374 subprob.x = Malloc(struct svm_node*,subprob.l);
peccu 0:b9ac53c439ed 2375 subprob.y = Malloc(double,subprob.l);
peccu 0:b9ac53c439ed 2376
peccu 0:b9ac53c439ed 2377 k=0;
peccu 0:b9ac53c439ed 2378 for(j=0;j<begin;j++)
peccu 0:b9ac53c439ed 2379 {
peccu 0:b9ac53c439ed 2380 subprob.x[k] = prob->x[perm[j]];
peccu 0:b9ac53c439ed 2381 subprob.y[k] = prob->y[perm[j]];
peccu 0:b9ac53c439ed 2382 ++k;
peccu 0:b9ac53c439ed 2383 }
peccu 0:b9ac53c439ed 2384 for(j=end;j<l;j++)
peccu 0:b9ac53c439ed 2385 {
peccu 0:b9ac53c439ed 2386 subprob.x[k] = prob->x[perm[j]];
peccu 0:b9ac53c439ed 2387 subprob.y[k] = prob->y[perm[j]];
peccu 0:b9ac53c439ed 2388 ++k;
peccu 0:b9ac53c439ed 2389 }
peccu 0:b9ac53c439ed 2390 struct svm_model *submodel = svm_train(&subprob,param);
peccu 0:b9ac53c439ed 2391 if(param->probability &&
peccu 0:b9ac53c439ed 2392 (param->svm_type == C_SVC || param->svm_type == NU_SVC))
peccu 0:b9ac53c439ed 2393 {
peccu 0:b9ac53c439ed 2394 double *prob_estimates=Malloc(double,svm_get_nr_class(submodel));
peccu 0:b9ac53c439ed 2395 for(j=begin;j<end;j++)
peccu 0:b9ac53c439ed 2396 target[perm[j]] = svm_predict_probability(submodel,prob->x[perm[j]],prob_estimates);
peccu 0:b9ac53c439ed 2397 free(prob_estimates);
peccu 0:b9ac53c439ed 2398 }
peccu 0:b9ac53c439ed 2399 else
peccu 0:b9ac53c439ed 2400 for(j=begin;j<end;j++)
peccu 0:b9ac53c439ed 2401 target[perm[j]] = svm_predict(submodel,prob->x[perm[j]]);
peccu 0:b9ac53c439ed 2402 svm_free_and_destroy_model(&submodel);
peccu 0:b9ac53c439ed 2403 free(subprob.x);
peccu 0:b9ac53c439ed 2404 free(subprob.y);
peccu 0:b9ac53c439ed 2405 }
peccu 0:b9ac53c439ed 2406 free(fold_start);
peccu 0:b9ac53c439ed 2407 free(perm);
peccu 0:b9ac53c439ed 2408 }
peccu 0:b9ac53c439ed 2409
peccu 0:b9ac53c439ed 2410
peccu 0:b9ac53c439ed 2411 int svm_get_svm_type(const svm_model *model)
peccu 0:b9ac53c439ed 2412 {
peccu 0:b9ac53c439ed 2413 return model->param.svm_type;
peccu 0:b9ac53c439ed 2414 }
peccu 0:b9ac53c439ed 2415
peccu 0:b9ac53c439ed 2416 int svm_get_nr_class(const svm_model *model)
peccu 0:b9ac53c439ed 2417 {
peccu 0:b9ac53c439ed 2418 return model->nr_class;
peccu 0:b9ac53c439ed 2419 }
peccu 0:b9ac53c439ed 2420
peccu 0:b9ac53c439ed 2421 void svm_get_labels(const svm_model *model, int* label)
peccu 0:b9ac53c439ed 2422 {
peccu 0:b9ac53c439ed 2423 if (model->label != NULL)
peccu 0:b9ac53c439ed 2424 for(int i=0;i<model->nr_class;i++)
peccu 0:b9ac53c439ed 2425 label[i] = model->label[i];
peccu 0:b9ac53c439ed 2426 }
peccu 0:b9ac53c439ed 2427
peccu 0:b9ac53c439ed 2428 double svm_get_svr_probability(const svm_model *model)
peccu 0:b9ac53c439ed 2429 {
peccu 0:b9ac53c439ed 2430 if ((model->param.svm_type == EPSILON_SVR || model->param.svm_type == NU_SVR) &&
peccu 0:b9ac53c439ed 2431 model->probA!=NULL)
peccu 0:b9ac53c439ed 2432 return model->probA[0];
peccu 0:b9ac53c439ed 2433 else
peccu 0:b9ac53c439ed 2434 {
peccu 0:b9ac53c439ed 2435 fprintf(stderr,"Model doesn't contain information for SVR probability inference\n");
peccu 0:b9ac53c439ed 2436 return 0;
peccu 0:b9ac53c439ed 2437 }
peccu 0:b9ac53c439ed 2438 }
peccu 0:b9ac53c439ed 2439
peccu 0:b9ac53c439ed 2440 double svm_predict_values(const svm_model *model, const svm_node *x, double* dec_values)
peccu 0:b9ac53c439ed 2441 {
peccu 0:b9ac53c439ed 2442 if(model->param.svm_type == ONE_CLASS ||
peccu 0:b9ac53c439ed 2443 model->param.svm_type == EPSILON_SVR ||
peccu 0:b9ac53c439ed 2444 model->param.svm_type == NU_SVR)
peccu 0:b9ac53c439ed 2445 {
peccu 0:b9ac53c439ed 2446 double *sv_coef = model->sv_coef[0];
peccu 0:b9ac53c439ed 2447 double sum = 0;
peccu 0:b9ac53c439ed 2448 for(int i=0;i<model->l;i++)
peccu 0:b9ac53c439ed 2449 sum += sv_coef[i] * Kernel::k_function(x,model->SV[i],model->param);
peccu 0:b9ac53c439ed 2450 sum -= model->rho[0];
peccu 0:b9ac53c439ed 2451 *dec_values = sum;
peccu 0:b9ac53c439ed 2452
peccu 0:b9ac53c439ed 2453 if(model->param.svm_type == ONE_CLASS)
peccu 0:b9ac53c439ed 2454 return (sum>0)?1:-1;
peccu 0:b9ac53c439ed 2455 else
peccu 0:b9ac53c439ed 2456 return sum;
peccu 0:b9ac53c439ed 2457 }
peccu 0:b9ac53c439ed 2458 else
peccu 0:b9ac53c439ed 2459 {
peccu 0:b9ac53c439ed 2460 int i;
peccu 0:b9ac53c439ed 2461 int nr_class = model->nr_class;
peccu 0:b9ac53c439ed 2462 int l = model->l;
peccu 0:b9ac53c439ed 2463
peccu 0:b9ac53c439ed 2464 double *kvalue = Malloc(double,l);
peccu 0:b9ac53c439ed 2465 for(i=0;i<l;i++)
peccu 0:b9ac53c439ed 2466 kvalue[i] = Kernel::k_function(x,model->SV[i],model->param);
peccu 0:b9ac53c439ed 2467
peccu 0:b9ac53c439ed 2468 int *start = Malloc(int,nr_class);
peccu 0:b9ac53c439ed 2469 start[0] = 0;
peccu 0:b9ac53c439ed 2470 for(i=1;i<nr_class;i++)
peccu 0:b9ac53c439ed 2471 start[i] = start[i-1]+model->nSV[i-1];
peccu 0:b9ac53c439ed 2472
peccu 0:b9ac53c439ed 2473 int *vote = Malloc(int,nr_class);
peccu 0:b9ac53c439ed 2474 for(i=0;i<nr_class;i++)
peccu 0:b9ac53c439ed 2475 vote[i] = 0;
peccu 0:b9ac53c439ed 2476
peccu 0:b9ac53c439ed 2477 int p=0;
peccu 0:b9ac53c439ed 2478 for(i=0;i<nr_class;i++)
peccu 0:b9ac53c439ed 2479 for(int j=i+1;j<nr_class;j++)
peccu 0:b9ac53c439ed 2480 {
peccu 0:b9ac53c439ed 2481 double sum = 0;
peccu 0:b9ac53c439ed 2482 int si = start[i];
peccu 0:b9ac53c439ed 2483 int sj = start[j];
peccu 0:b9ac53c439ed 2484 int ci = model->nSV[i];
peccu 0:b9ac53c439ed 2485 int cj = model->nSV[j];
peccu 0:b9ac53c439ed 2486
peccu 0:b9ac53c439ed 2487 int k;
peccu 0:b9ac53c439ed 2488 double *coef1 = model->sv_coef[j-1];
peccu 0:b9ac53c439ed 2489 double *coef2 = model->sv_coef[i];
peccu 0:b9ac53c439ed 2490 for(k=0;k<ci;k++)
peccu 0:b9ac53c439ed 2491 sum += coef1[si+k] * kvalue[si+k];
peccu 0:b9ac53c439ed 2492 for(k=0;k<cj;k++)
peccu 0:b9ac53c439ed 2493 sum += coef2[sj+k] * kvalue[sj+k];
peccu 0:b9ac53c439ed 2494 sum -= model->rho[p];
peccu 0:b9ac53c439ed 2495 dec_values[p] = sum;
peccu 0:b9ac53c439ed 2496
peccu 0:b9ac53c439ed 2497 if(dec_values[p] > 0)
peccu 0:b9ac53c439ed 2498 ++vote[i];
peccu 0:b9ac53c439ed 2499 else
peccu 0:b9ac53c439ed 2500 ++vote[j];
peccu 0:b9ac53c439ed 2501 p++;
peccu 0:b9ac53c439ed 2502 }
peccu 0:b9ac53c439ed 2503
peccu 0:b9ac53c439ed 2504 int vote_max_idx = 0;
peccu 0:b9ac53c439ed 2505 for(i=1;i<nr_class;i++)
peccu 0:b9ac53c439ed 2506 if(vote[i] > vote[vote_max_idx])
peccu 0:b9ac53c439ed 2507 vote_max_idx = i;
peccu 0:b9ac53c439ed 2508
peccu 0:b9ac53c439ed 2509 free(kvalue);
peccu 0:b9ac53c439ed 2510 free(start);
peccu 0:b9ac53c439ed 2511 free(vote);
peccu 0:b9ac53c439ed 2512 return model->label[vote_max_idx];
peccu 0:b9ac53c439ed 2513 }
peccu 0:b9ac53c439ed 2514 }
peccu 0:b9ac53c439ed 2515
peccu 0:b9ac53c439ed 2516 double svm_predict(const svm_model *model, const svm_node *x)
peccu 0:b9ac53c439ed 2517 {
peccu 0:b9ac53c439ed 2518 int nr_class = model->nr_class;
peccu 0:b9ac53c439ed 2519 double *dec_values;
peccu 0:b9ac53c439ed 2520 if(model->param.svm_type == ONE_CLASS ||
peccu 0:b9ac53c439ed 2521 model->param.svm_type == EPSILON_SVR ||
peccu 0:b9ac53c439ed 2522 model->param.svm_type == NU_SVR)
peccu 0:b9ac53c439ed 2523 dec_values = Malloc(double, 1);
peccu 0:b9ac53c439ed 2524 else
peccu 0:b9ac53c439ed 2525 dec_values = Malloc(double, nr_class*(nr_class-1)/2);
peccu 0:b9ac53c439ed 2526 double pred_result = svm_predict_values(model, x, dec_values);
peccu 0:b9ac53c439ed 2527 free(dec_values);
peccu 0:b9ac53c439ed 2528 return pred_result;
peccu 0:b9ac53c439ed 2529 }
peccu 0:b9ac53c439ed 2530
peccu 0:b9ac53c439ed 2531 double svm_predict_probability(
peccu 0:b9ac53c439ed 2532 const svm_model *model, const svm_node *x, double *prob_estimates)
peccu 0:b9ac53c439ed 2533 {
peccu 0:b9ac53c439ed 2534 if ((model->param.svm_type == C_SVC || model->param.svm_type == NU_SVC) &&
peccu 0:b9ac53c439ed 2535 model->probA!=NULL && model->probB!=NULL)
peccu 0:b9ac53c439ed 2536 {
peccu 0:b9ac53c439ed 2537 int i;
peccu 0:b9ac53c439ed 2538 int nr_class = model->nr_class;
peccu 0:b9ac53c439ed 2539 double *dec_values = Malloc(double, nr_class*(nr_class-1)/2);
peccu 0:b9ac53c439ed 2540 svm_predict_values(model, x, dec_values);
peccu 0:b9ac53c439ed 2541
peccu 0:b9ac53c439ed 2542 double min_prob=1e-7;
peccu 0:b9ac53c439ed 2543 double **pairwise_prob=Malloc(double *,nr_class);
peccu 0:b9ac53c439ed 2544 for(i=0;i<nr_class;i++)
peccu 0:b9ac53c439ed 2545 pairwise_prob[i]=Malloc(double,nr_class);
peccu 0:b9ac53c439ed 2546 int k=0;
peccu 0:b9ac53c439ed 2547 for(i=0;i<nr_class;i++)
peccu 0:b9ac53c439ed 2548 for(int j=i+1;j<nr_class;j++)
peccu 0:b9ac53c439ed 2549 {
peccu 0:b9ac53c439ed 2550 pairwise_prob[i][j]=min(max(sigmoid_predict(dec_values[k],model->probA[k],model->probB[k]),min_prob),1-min_prob);
peccu 0:b9ac53c439ed 2551 pairwise_prob[j][i]=1-pairwise_prob[i][j];
peccu 0:b9ac53c439ed 2552 k++;
peccu 0:b9ac53c439ed 2553 }
peccu 0:b9ac53c439ed 2554 multiclass_probability(nr_class,pairwise_prob,prob_estimates);
peccu 0:b9ac53c439ed 2555
peccu 0:b9ac53c439ed 2556 int prob_max_idx = 0;
peccu 0:b9ac53c439ed 2557 for(i=1;i<nr_class;i++)
peccu 0:b9ac53c439ed 2558 if(prob_estimates[i] > prob_estimates[prob_max_idx])
peccu 0:b9ac53c439ed 2559 prob_max_idx = i;
peccu 0:b9ac53c439ed 2560 for(i=0;i<nr_class;i++)
peccu 0:b9ac53c439ed 2561 free(pairwise_prob[i]);
peccu 0:b9ac53c439ed 2562 free(dec_values);
peccu 0:b9ac53c439ed 2563 free(pairwise_prob);
peccu 0:b9ac53c439ed 2564 return model->label[prob_max_idx];
peccu 0:b9ac53c439ed 2565 }
peccu 0:b9ac53c439ed 2566 else
peccu 0:b9ac53c439ed 2567 return svm_predict(model, x);
peccu 0:b9ac53c439ed 2568 }
peccu 0:b9ac53c439ed 2569
// Names written to and read from model files. The position of each name is
// the integer stored in svm_parameter (svm_save_model indexes these tables by
// param.svm_type / param.kernel_type, and svm_load_model stores the matching
// index back). Both tables are NULL-terminated.
static const char *svm_type_table[] =
{
	"c_svc",
	"nu_svc",
	"one_class",
	"epsilon_svr",
	"nu_svr",
	NULL
};

static const char *kernel_type_table[]=
{
	"linear",
	"polynomial",
	"rbf",
	"sigmoid",
	"precomputed",
	NULL
};
peccu 0:b9ac53c439ed 2579
peccu 0:b9ac53c439ed 2580 int svm_save_model(const char *model_file_name, const svm_model *model)
peccu 0:b9ac53c439ed 2581 {
peccu 0:b9ac53c439ed 2582 FILE *fp = fopen(model_file_name,"w");
peccu 0:b9ac53c439ed 2583 if(fp==NULL) return -1;
peccu 0:b9ac53c439ed 2584
peccu 0:b9ac53c439ed 2585 const svm_parameter& param = model->param;
peccu 0:b9ac53c439ed 2586
peccu 0:b9ac53c439ed 2587 fprintf(fp,"svm_type %s\n", svm_type_table[param.svm_type]);
peccu 0:b9ac53c439ed 2588 fprintf(fp,"kernel_type %s\n", kernel_type_table[param.kernel_type]);
peccu 0:b9ac53c439ed 2589
peccu 0:b9ac53c439ed 2590 if(param.kernel_type == POLY)
peccu 0:b9ac53c439ed 2591 fprintf(fp,"degree %d\n", param.degree);
peccu 0:b9ac53c439ed 2592
peccu 0:b9ac53c439ed 2593 if(param.kernel_type == POLY || param.kernel_type == RBF || param.kernel_type == SIGMOID)
peccu 0:b9ac53c439ed 2594 fprintf(fp,"gamma %g\n", param.gamma);
peccu 0:b9ac53c439ed 2595
peccu 0:b9ac53c439ed 2596 if(param.kernel_type == POLY || param.kernel_type == SIGMOID)
peccu 0:b9ac53c439ed 2597 fprintf(fp,"coef0 %g\n", param.coef0);
peccu 0:b9ac53c439ed 2598
peccu 0:b9ac53c439ed 2599 int nr_class = model->nr_class;
peccu 0:b9ac53c439ed 2600 int l = model->l;
peccu 0:b9ac53c439ed 2601 fprintf(fp, "nr_class %d\n", nr_class);
peccu 0:b9ac53c439ed 2602 fprintf(fp, "total_sv %d\n",l);
peccu 0:b9ac53c439ed 2603
peccu 0:b9ac53c439ed 2604 {
peccu 0:b9ac53c439ed 2605 fprintf(fp, "rho");
peccu 0:b9ac53c439ed 2606 for(int i=0;i<nr_class*(nr_class-1)/2;i++)
peccu 0:b9ac53c439ed 2607 fprintf(fp," %g",model->rho[i]);
peccu 0:b9ac53c439ed 2608 fprintf(fp, "\n");
peccu 0:b9ac53c439ed 2609 }
peccu 0:b9ac53c439ed 2610
peccu 0:b9ac53c439ed 2611 if(model->label)
peccu 0:b9ac53c439ed 2612 {
peccu 0:b9ac53c439ed 2613 fprintf(fp, "label");
peccu 0:b9ac53c439ed 2614 for(int i=0;i<nr_class;i++)
peccu 0:b9ac53c439ed 2615 fprintf(fp," %d",model->label[i]);
peccu 0:b9ac53c439ed 2616 fprintf(fp, "\n");
peccu 0:b9ac53c439ed 2617 }
peccu 0:b9ac53c439ed 2618
peccu 0:b9ac53c439ed 2619 if(model->probA) // regression has probA only
peccu 0:b9ac53c439ed 2620 {
peccu 0:b9ac53c439ed 2621 fprintf(fp, "probA");
peccu 0:b9ac53c439ed 2622 for(int i=0;i<nr_class*(nr_class-1)/2;i++)
peccu 0:b9ac53c439ed 2623 fprintf(fp," %g",model->probA[i]);
peccu 0:b9ac53c439ed 2624 fprintf(fp, "\n");
peccu 0:b9ac53c439ed 2625 }
peccu 0:b9ac53c439ed 2626 if(model->probB)
peccu 0:b9ac53c439ed 2627 {
peccu 0:b9ac53c439ed 2628 fprintf(fp, "probB");
peccu 0:b9ac53c439ed 2629 for(int i=0;i<nr_class*(nr_class-1)/2;i++)
peccu 0:b9ac53c439ed 2630 fprintf(fp," %g",model->probB[i]);
peccu 0:b9ac53c439ed 2631 fprintf(fp, "\n");
peccu 0:b9ac53c439ed 2632 }
peccu 0:b9ac53c439ed 2633
peccu 0:b9ac53c439ed 2634 if(model->nSV)
peccu 0:b9ac53c439ed 2635 {
peccu 0:b9ac53c439ed 2636 fprintf(fp, "nr_sv");
peccu 0:b9ac53c439ed 2637 for(int i=0;i<nr_class;i++)
peccu 0:b9ac53c439ed 2638 fprintf(fp," %d",model->nSV[i]);
peccu 0:b9ac53c439ed 2639 fprintf(fp, "\n");
peccu 0:b9ac53c439ed 2640 }
peccu 0:b9ac53c439ed 2641
peccu 0:b9ac53c439ed 2642 fprintf(fp, "SV\n");
peccu 0:b9ac53c439ed 2643 const double * const *sv_coef = model->sv_coef;
peccu 0:b9ac53c439ed 2644 const svm_node * const *SV = model->SV;
peccu 0:b9ac53c439ed 2645
peccu 0:b9ac53c439ed 2646 for(int i=0;i<l;i++)
peccu 0:b9ac53c439ed 2647 {
peccu 0:b9ac53c439ed 2648 for(int j=0;j<nr_class-1;j++)
peccu 0:b9ac53c439ed 2649 fprintf(fp, "%.16g ",sv_coef[j][i]);
peccu 0:b9ac53c439ed 2650
peccu 0:b9ac53c439ed 2651 const svm_node *p = SV[i];
peccu 0:b9ac53c439ed 2652
peccu 0:b9ac53c439ed 2653 if(param.kernel_type == PRECOMPUTED)
peccu 0:b9ac53c439ed 2654 fprintf(fp,"0:%d ",(int)(p->value));
peccu 0:b9ac53c439ed 2655 else
peccu 0:b9ac53c439ed 2656 while(p->index != -1)
peccu 0:b9ac53c439ed 2657 {
peccu 0:b9ac53c439ed 2658 fprintf(fp,"%d:%.8g ",p->index,p->value);
peccu 0:b9ac53c439ed 2659 p++;
peccu 0:b9ac53c439ed 2660 }
peccu 0:b9ac53c439ed 2661 fprintf(fp, "\n");
peccu 0:b9ac53c439ed 2662 }
peccu 0:b9ac53c439ed 2663 if (ferror(fp) != 0 || fclose(fp) != 0) return -1;
peccu 0:b9ac53c439ed 2664 else return 0;
peccu 0:b9ac53c439ed 2665 }
peccu 0:b9ac53c439ed 2666
// Line buffer shared by readline() and svm_load_model(). The caller must set
// max_line_len and allocate `line` with that many bytes before calling
// readline(); readline() grows both as needed.
static char *line = NULL;
static int max_line_len;

// Reads one complete line (including its '\n', if present) from input into the
// global buffer `line`, doubling the buffer until the whole line fits.
// Returns `line` on success. Returns NULL if nothing could be read (end of
// file or read error), or if the buffer could not be grown. In the growth
// failure case `line` still points at its previous, valid allocation, so the
// caller can free it as usual.
static char* readline(FILE *input)
{
	if(fgets(line,max_line_len,input) == NULL)
		return NULL;

	while(strrchr(line,'\n') == NULL)
	{
		int new_len = max_line_len * 2;
		// Assigning realloc's result directly to `line` would lose the only
		// pointer to the buffer on failure and crash in strlen(NULL) below.
		char *grown = (char *) realloc(line,new_len);
		if(grown == NULL)
			return NULL;
		line = grown;
		max_line_len = new_len;

		int len = (int) strlen(line);
		if(fgets(line+len,max_line_len-len,input) == NULL)
			break;
	}
	return line;
}
peccu 0:b9ac53c439ed 2687
peccu 0:b9ac53c439ed 2688 svm_model *svm_load_model(const char *model_file_name)
peccu 0:b9ac53c439ed 2689 {
peccu 0:b9ac53c439ed 2690 FILE *fp = fopen(model_file_name,"rb");
peccu 0:b9ac53c439ed 2691 printf("load\r\n");
peccu 0:b9ac53c439ed 2692 if(fp==NULL) return NULL;
peccu 0:b9ac53c439ed 2693 printf("loaded\r\n");
peccu 0:b9ac53c439ed 2694 // read parameters
peccu 0:b9ac53c439ed 2695
peccu 0:b9ac53c439ed 2696 svm_model *model = Malloc(svm_model,1);
peccu 0:b9ac53c439ed 2697 svm_parameter& param = model->param;
peccu 0:b9ac53c439ed 2698 model->rho = NULL;
peccu 0:b9ac53c439ed 2699 model->probA = NULL;
peccu 0:b9ac53c439ed 2700 model->probB = NULL;
peccu 0:b9ac53c439ed 2701 model->label = NULL;
peccu 0:b9ac53c439ed 2702 model->nSV = NULL;
peccu 0:b9ac53c439ed 2703
peccu 0:b9ac53c439ed 2704 char cmd[81];
peccu 0:b9ac53c439ed 2705 while(1)
peccu 0:b9ac53c439ed 2706 {
peccu 0:b9ac53c439ed 2707 fscanf(fp,"%80s",cmd);
peccu 0:b9ac53c439ed 2708
peccu 0:b9ac53c439ed 2709 if(strcmp(cmd,"svm_type")==0)
peccu 0:b9ac53c439ed 2710 {
peccu 0:b9ac53c439ed 2711 fscanf(fp,"%80s",cmd);
peccu 0:b9ac53c439ed 2712 int i;
peccu 0:b9ac53c439ed 2713 for(i=0;svm_type_table[i];i++)
peccu 0:b9ac53c439ed 2714 {
peccu 0:b9ac53c439ed 2715 if(strcmp(svm_type_table[i],cmd)==0)
peccu 0:b9ac53c439ed 2716 {
peccu 0:b9ac53c439ed 2717 param.svm_type=i;
peccu 0:b9ac53c439ed 2718 break;
peccu 0:b9ac53c439ed 2719 }
peccu 0:b9ac53c439ed 2720 }
peccu 0:b9ac53c439ed 2721 if(svm_type_table[i] == NULL)
peccu 0:b9ac53c439ed 2722 {
peccu 0:b9ac53c439ed 2723 fprintf(stderr,"unknown svm type.\n");
peccu 0:b9ac53c439ed 2724 free(model->rho);
peccu 0:b9ac53c439ed 2725 free(model->label);
peccu 0:b9ac53c439ed 2726 free(model->nSV);
peccu 0:b9ac53c439ed 2727 free(model);
peccu 0:b9ac53c439ed 2728 return NULL;
peccu 0:b9ac53c439ed 2729 }
peccu 0:b9ac53c439ed 2730 }
peccu 0:b9ac53c439ed 2731 else if(strcmp(cmd,"kernel_type")==0)
peccu 0:b9ac53c439ed 2732 {
peccu 0:b9ac53c439ed 2733 fscanf(fp,"%80s",cmd);
peccu 0:b9ac53c439ed 2734 int i;
peccu 0:b9ac53c439ed 2735 for(i=0;kernel_type_table[i];i++)
peccu 0:b9ac53c439ed 2736 {
peccu 0:b9ac53c439ed 2737 if(strcmp(kernel_type_table[i],cmd)==0)
peccu 0:b9ac53c439ed 2738 {
peccu 0:b9ac53c439ed 2739 param.kernel_type=i;
peccu 0:b9ac53c439ed 2740 break;
peccu 0:b9ac53c439ed 2741 }
peccu 0:b9ac53c439ed 2742 }
peccu 0:b9ac53c439ed 2743 if(kernel_type_table[i] == NULL)
peccu 0:b9ac53c439ed 2744 {
peccu 0:b9ac53c439ed 2745 fprintf(stderr,"unknown kernel function.\n");
peccu 0:b9ac53c439ed 2746 free(model->rho);
peccu 0:b9ac53c439ed 2747 free(model->label);
peccu 0:b9ac53c439ed 2748 free(model->nSV);
peccu 0:b9ac53c439ed 2749 free(model);
peccu 0:b9ac53c439ed 2750 return NULL;
peccu 0:b9ac53c439ed 2751 }
peccu 0:b9ac53c439ed 2752 }
peccu 0:b9ac53c439ed 2753 else if(strcmp(cmd,"degree")==0)
peccu 0:b9ac53c439ed 2754 fscanf(fp,"%d",&param.degree);
peccu 0:b9ac53c439ed 2755 else if(strcmp(cmd,"gamma")==0)
peccu 0:b9ac53c439ed 2756 fscanf(fp,"%lf",&param.gamma);
peccu 0:b9ac53c439ed 2757 else if(strcmp(cmd,"coef0")==0)
peccu 0:b9ac53c439ed 2758 fscanf(fp,"%lf",&param.coef0);
peccu 0:b9ac53c439ed 2759 else if(strcmp(cmd,"nr_class")==0)
peccu 0:b9ac53c439ed 2760 fscanf(fp,"%d",&model->nr_class);
peccu 0:b9ac53c439ed 2761 else if(strcmp(cmd,"total_sv")==0)
peccu 0:b9ac53c439ed 2762 fscanf(fp,"%d",&model->l);
peccu 0:b9ac53c439ed 2763 else if(strcmp(cmd,"rho")==0)
peccu 0:b9ac53c439ed 2764 {
peccu 0:b9ac53c439ed 2765 int n = model->nr_class * (model->nr_class-1)/2;
peccu 0:b9ac53c439ed 2766 model->rho = Malloc(double,n);
peccu 0:b9ac53c439ed 2767 for(int i=0;i<n;i++)
peccu 0:b9ac53c439ed 2768 fscanf(fp,"%lf",&model->rho[i]);
peccu 0:b9ac53c439ed 2769 }
peccu 0:b9ac53c439ed 2770 else if(strcmp(cmd,"label")==0)
peccu 0:b9ac53c439ed 2771 {
peccu 0:b9ac53c439ed 2772 int n = model->nr_class;
peccu 0:b9ac53c439ed 2773 model->label = Malloc(int,n);
peccu 0:b9ac53c439ed 2774 for(int i=0;i<n;i++)
peccu 0:b9ac53c439ed 2775 fscanf(fp,"%d",&model->label[i]);
peccu 0:b9ac53c439ed 2776 }
peccu 0:b9ac53c439ed 2777 else if(strcmp(cmd,"probA")==0)
peccu 0:b9ac53c439ed 2778 {
peccu 0:b9ac53c439ed 2779 int n = model->nr_class * (model->nr_class-1)/2;
peccu 0:b9ac53c439ed 2780 model->probA = Malloc(double,n);
peccu 0:b9ac53c439ed 2781 for(int i=0;i<n;i++)
peccu 0:b9ac53c439ed 2782 fscanf(fp,"%lf",&model->probA[i]);
peccu 0:b9ac53c439ed 2783 }
peccu 0:b9ac53c439ed 2784 else if(strcmp(cmd,"probB")==0)
peccu 0:b9ac53c439ed 2785 {
peccu 0:b9ac53c439ed 2786 int n = model->nr_class * (model->nr_class-1)/2;
peccu 0:b9ac53c439ed 2787 model->probB = Malloc(double,n);
peccu 0:b9ac53c439ed 2788 for(int i=0;i<n;i++)
peccu 0:b9ac53c439ed 2789 fscanf(fp,"%lf",&model->probB[i]);
peccu 0:b9ac53c439ed 2790 }
peccu 0:b9ac53c439ed 2791 else if(strcmp(cmd,"nr_sv")==0)
peccu 0:b9ac53c439ed 2792 {
peccu 0:b9ac53c439ed 2793 int n = model->nr_class;
peccu 0:b9ac53c439ed 2794 model->nSV = Malloc(int,n);
peccu 0:b9ac53c439ed 2795 for(int i=0;i<n;i++)
peccu 0:b9ac53c439ed 2796 fscanf(fp,"%d",&model->nSV[i]);
peccu 0:b9ac53c439ed 2797 }
peccu 0:b9ac53c439ed 2798 else if(strcmp(cmd,"SV")==0)
peccu 0:b9ac53c439ed 2799 {
peccu 0:b9ac53c439ed 2800 while(1)
peccu 0:b9ac53c439ed 2801 {
peccu 0:b9ac53c439ed 2802 int c = getc(fp);
peccu 0:b9ac53c439ed 2803 if(c==EOF || c=='\n') break;
peccu 0:b9ac53c439ed 2804 }
peccu 0:b9ac53c439ed 2805 break;
peccu 0:b9ac53c439ed 2806 }
peccu 0:b9ac53c439ed 2807 else
peccu 0:b9ac53c439ed 2808 {
peccu 0:b9ac53c439ed 2809 fprintf(stderr,"unknown text in model file: [%s]\n",cmd);
peccu 0:b9ac53c439ed 2810 free(model->rho);
peccu 0:b9ac53c439ed 2811 free(model->label);
peccu 0:b9ac53c439ed 2812 free(model->nSV);
peccu 0:b9ac53c439ed 2813 free(model);
peccu 0:b9ac53c439ed 2814 return NULL;
peccu 0:b9ac53c439ed 2815 }
peccu 0:b9ac53c439ed 2816 }
peccu 0:b9ac53c439ed 2817
peccu 0:b9ac53c439ed 2818 // read sv_coef and SV
peccu 0:b9ac53c439ed 2819
peccu 0:b9ac53c439ed 2820 int elements = 0;
peccu 0:b9ac53c439ed 2821 long pos = ftell(fp);
peccu 0:b9ac53c439ed 2822
peccu 0:b9ac53c439ed 2823 max_line_len = 1024;
peccu 0:b9ac53c439ed 2824 line = Malloc(char,max_line_len);
peccu 0:b9ac53c439ed 2825 char *p,*endptr,*idx,*val;
peccu 0:b9ac53c439ed 2826
peccu 0:b9ac53c439ed 2827 while(readline(fp)!=NULL)
peccu 0:b9ac53c439ed 2828 {
peccu 0:b9ac53c439ed 2829 p = strtok(line,":");
peccu 0:b9ac53c439ed 2830 while(1)
peccu 0:b9ac53c439ed 2831 {
peccu 0:b9ac53c439ed 2832 p = strtok(NULL,":");
peccu 0:b9ac53c439ed 2833 if(p == NULL)
peccu 0:b9ac53c439ed 2834 break;
peccu 0:b9ac53c439ed 2835 ++elements;
peccu 0:b9ac53c439ed 2836 }
peccu 0:b9ac53c439ed 2837 }
peccu 0:b9ac53c439ed 2838 elements += model->l;
peccu 0:b9ac53c439ed 2839
peccu 0:b9ac53c439ed 2840 fseek(fp,pos,SEEK_SET);
peccu 0:b9ac53c439ed 2841
peccu 0:b9ac53c439ed 2842 int m = model->nr_class - 1;
peccu 0:b9ac53c439ed 2843 int l = model->l;
peccu 0:b9ac53c439ed 2844 model->sv_coef = Malloc(double *,m);
peccu 0:b9ac53c439ed 2845 int i;
peccu 0:b9ac53c439ed 2846 for(i=0;i<m;i++)
peccu 0:b9ac53c439ed 2847 model->sv_coef[i] = Malloc(double,l);
peccu 0:b9ac53c439ed 2848 model->SV = Malloc(svm_node*,l);
peccu 0:b9ac53c439ed 2849 svm_node *x_space = NULL;
peccu 0:b9ac53c439ed 2850 if(l>0) x_space = Malloc(svm_node,elements);
peccu 0:b9ac53c439ed 2851
peccu 0:b9ac53c439ed 2852 int j=0;
peccu 0:b9ac53c439ed 2853 for(i=0;i<l;i++)
peccu 0:b9ac53c439ed 2854 {
peccu 0:b9ac53c439ed 2855 readline(fp);
peccu 0:b9ac53c439ed 2856 model->SV[i] = &x_space[j];
peccu 0:b9ac53c439ed 2857
peccu 0:b9ac53c439ed 2858 p = strtok(line, " \t");
peccu 0:b9ac53c439ed 2859 model->sv_coef[0][i] = strtod(p,&endptr);
peccu 0:b9ac53c439ed 2860 for(int k=1;k<m;k++)
peccu 0:b9ac53c439ed 2861 {
peccu 0:b9ac53c439ed 2862 p = strtok(NULL, " \t");
peccu 0:b9ac53c439ed 2863 model->sv_coef[k][i] = strtod(p,&endptr);
peccu 0:b9ac53c439ed 2864 }
peccu 0:b9ac53c439ed 2865
peccu 0:b9ac53c439ed 2866 while(1)
peccu 0:b9ac53c439ed 2867 {
peccu 0:b9ac53c439ed 2868 idx = strtok(NULL, ":");
peccu 0:b9ac53c439ed 2869 val = strtok(NULL, " \t");
peccu 0:b9ac53c439ed 2870
peccu 0:b9ac53c439ed 2871 if(val == NULL)
peccu 0:b9ac53c439ed 2872 break;
peccu 0:b9ac53c439ed 2873 x_space[j].index = (int) strtol(idx,&endptr,10);
peccu 0:b9ac53c439ed 2874 x_space[j].value = strtod(val,&endptr);
peccu 0:b9ac53c439ed 2875
peccu 0:b9ac53c439ed 2876 ++j;
peccu 0:b9ac53c439ed 2877 }
peccu 0:b9ac53c439ed 2878 x_space[j++].index = -1;
peccu 0:b9ac53c439ed 2879 }
peccu 0:b9ac53c439ed 2880 free(line);
peccu 0:b9ac53c439ed 2881
peccu 0:b9ac53c439ed 2882 if (ferror(fp) != 0 || fclose(fp) != 0)
peccu 0:b9ac53c439ed 2883 return NULL;
peccu 0:b9ac53c439ed 2884
peccu 0:b9ac53c439ed 2885 model->free_sv = 1; // XXX
peccu 0:b9ac53c439ed 2886 return model;
peccu 0:b9ac53c439ed 2887 }
peccu 0:b9ac53c439ed 2888
peccu 0:b9ac53c439ed 2889 void svm_free_model_content(svm_model* model_ptr)
peccu 0:b9ac53c439ed 2890 {
peccu 0:b9ac53c439ed 2891 if(model_ptr->free_sv && model_ptr->l > 0)
peccu 0:b9ac53c439ed 2892 free((void *)(model_ptr->SV[0]));
peccu 0:b9ac53c439ed 2893 for(int i=0;i<model_ptr->nr_class-1;i++)
peccu 0:b9ac53c439ed 2894 free(model_ptr->sv_coef[i]);
peccu 0:b9ac53c439ed 2895 free(model_ptr->SV);
peccu 0:b9ac53c439ed 2896 free(model_ptr->sv_coef);
peccu 0:b9ac53c439ed 2897 free(model_ptr->rho);
peccu 0:b9ac53c439ed 2898 free(model_ptr->label);
peccu 0:b9ac53c439ed 2899 free(model_ptr->probA);
peccu 0:b9ac53c439ed 2900 free(model_ptr->probB);
peccu 0:b9ac53c439ed 2901 free(model_ptr->nSV);
peccu 0:b9ac53c439ed 2902 }
peccu 0:b9ac53c439ed 2903
peccu 0:b9ac53c439ed 2904 void svm_free_and_destroy_model(svm_model** model_ptr_ptr)
peccu 0:b9ac53c439ed 2905 {
peccu 0:b9ac53c439ed 2906 svm_model* model_ptr = *model_ptr_ptr;
peccu 0:b9ac53c439ed 2907 if(model_ptr != NULL)
peccu 0:b9ac53c439ed 2908 {
peccu 0:b9ac53c439ed 2909 svm_free_model_content(model_ptr);
peccu 0:b9ac53c439ed 2910 free(model_ptr);
peccu 0:b9ac53c439ed 2911 }
peccu 0:b9ac53c439ed 2912 }
peccu 0:b9ac53c439ed 2913
peccu 0:b9ac53c439ed 2914 void svm_destroy_model(svm_model* model_ptr)
peccu 0:b9ac53c439ed 2915 {
peccu 0:b9ac53c439ed 2916 fprintf(stderr,"warning: svm_destroy_model is deprecated and should not be used. Please use svm_free_and_destroy_model(svm_model **model_ptr_ptr)\n");
peccu 0:b9ac53c439ed 2917 svm_free_and_destroy_model(&model_ptr);
peccu 0:b9ac53c439ed 2918 }
peccu 0:b9ac53c439ed 2919
peccu 0:b9ac53c439ed 2920 void svm_destroy_param(svm_parameter* param)
peccu 0:b9ac53c439ed 2921 {
peccu 0:b9ac53c439ed 2922 free(param->weight_label);
peccu 0:b9ac53c439ed 2923 free(param->weight);
peccu 0:b9ac53c439ed 2924 }
peccu 0:b9ac53c439ed 2925
peccu 0:b9ac53c439ed 2926 const char *svm_check_parameter(const svm_problem *prob, const svm_parameter *param)
peccu 0:b9ac53c439ed 2927 {
peccu 0:b9ac53c439ed 2928 // svm_type
peccu 0:b9ac53c439ed 2929
peccu 0:b9ac53c439ed 2930 int svm_type = param->svm_type;
peccu 0:b9ac53c439ed 2931 if(svm_type != C_SVC &&
peccu 0:b9ac53c439ed 2932 svm_type != NU_SVC &&
peccu 0:b9ac53c439ed 2933 svm_type != ONE_CLASS &&
peccu 0:b9ac53c439ed 2934 svm_type != EPSILON_SVR &&
peccu 0:b9ac53c439ed 2935 svm_type != NU_SVR)
peccu 0:b9ac53c439ed 2936 return "unknown svm type";
peccu 0:b9ac53c439ed 2937
peccu 0:b9ac53c439ed 2938 // kernel_type, degree
peccu 0:b9ac53c439ed 2939
peccu 0:b9ac53c439ed 2940 int kernel_type = param->kernel_type;
peccu 0:b9ac53c439ed 2941 if(kernel_type != LINEAR &&
peccu 0:b9ac53c439ed 2942 kernel_type != POLY &&
peccu 0:b9ac53c439ed 2943 kernel_type != RBF &&
peccu 0:b9ac53c439ed 2944 kernel_type != SIGMOID &&
peccu 0:b9ac53c439ed 2945 kernel_type != PRECOMPUTED)
peccu 0:b9ac53c439ed 2946 return "unknown kernel type";
peccu 0:b9ac53c439ed 2947
peccu 0:b9ac53c439ed 2948 if(param->gamma < 0)
peccu 0:b9ac53c439ed 2949 return "gamma < 0";
peccu 0:b9ac53c439ed 2950
peccu 0:b9ac53c439ed 2951 if(param->degree < 0)
peccu 0:b9ac53c439ed 2952 return "degree of polynomial kernel < 0";
peccu 0:b9ac53c439ed 2953
peccu 0:b9ac53c439ed 2954 // cache_size,eps,C,nu,p,shrinking
peccu 0:b9ac53c439ed 2955
peccu 0:b9ac53c439ed 2956 if(param->cache_size <= 0)
peccu 0:b9ac53c439ed 2957 return "cache_size <= 0";
peccu 0:b9ac53c439ed 2958
peccu 0:b9ac53c439ed 2959 if(param->eps <= 0)
peccu 0:b9ac53c439ed 2960 return "eps <= 0";
peccu 0:b9ac53c439ed 2961
peccu 0:b9ac53c439ed 2962 if(svm_type == C_SVC ||
peccu 0:b9ac53c439ed 2963 svm_type == EPSILON_SVR ||
peccu 0:b9ac53c439ed 2964 svm_type == NU_SVR)
peccu 0:b9ac53c439ed 2965 if(param->C <= 0)
peccu 0:b9ac53c439ed 2966 return "C <= 0";
peccu 0:b9ac53c439ed 2967
peccu 0:b9ac53c439ed 2968 if(svm_type == NU_SVC ||
peccu 0:b9ac53c439ed 2969 svm_type == ONE_CLASS ||
peccu 0:b9ac53c439ed 2970 svm_type == NU_SVR)
peccu 0:b9ac53c439ed 2971 if(param->nu <= 0 || param->nu > 1)
peccu 0:b9ac53c439ed 2972 return "nu <= 0 or nu > 1";
peccu 0:b9ac53c439ed 2973
peccu 0:b9ac53c439ed 2974 if(svm_type == EPSILON_SVR)
peccu 0:b9ac53c439ed 2975 if(param->p < 0)
peccu 0:b9ac53c439ed 2976 return "p < 0";
peccu 0:b9ac53c439ed 2977
peccu 0:b9ac53c439ed 2978 if(param->shrinking != 0 &&
peccu 0:b9ac53c439ed 2979 param->shrinking != 1)
peccu 0:b9ac53c439ed 2980 return "shrinking != 0 and shrinking != 1";
peccu 0:b9ac53c439ed 2981
peccu 0:b9ac53c439ed 2982 if(param->probability != 0 &&
peccu 0:b9ac53c439ed 2983 param->probability != 1)
peccu 0:b9ac53c439ed 2984 return "probability != 0 and probability != 1";
peccu 0:b9ac53c439ed 2985
peccu 0:b9ac53c439ed 2986 if(param->probability == 1 &&
peccu 0:b9ac53c439ed 2987 svm_type == ONE_CLASS)
peccu 0:b9ac53c439ed 2988 return "one-class SVM probability output not supported yet";
peccu 0:b9ac53c439ed 2989
peccu 0:b9ac53c439ed 2990
peccu 0:b9ac53c439ed 2991 // check whether nu-svc is feasible
peccu 0:b9ac53c439ed 2992
peccu 0:b9ac53c439ed 2993 if(svm_type == NU_SVC)
peccu 0:b9ac53c439ed 2994 {
peccu 0:b9ac53c439ed 2995 int l = prob->l;
peccu 0:b9ac53c439ed 2996 int max_nr_class = 16;
peccu 0:b9ac53c439ed 2997 int nr_class = 0;
peccu 0:b9ac53c439ed 2998 int *label = Malloc(int,max_nr_class);
peccu 0:b9ac53c439ed 2999 int *count = Malloc(int,max_nr_class);
peccu 0:b9ac53c439ed 3000
peccu 0:b9ac53c439ed 3001 int i;
peccu 0:b9ac53c439ed 3002 for(i=0;i<l;i++)
peccu 0:b9ac53c439ed 3003 {
peccu 0:b9ac53c439ed 3004 int this_label = (int)prob->y[i];
peccu 0:b9ac53c439ed 3005 int j;
peccu 0:b9ac53c439ed 3006 for(j=0;j<nr_class;j++)
peccu 0:b9ac53c439ed 3007 if(this_label == label[j])
peccu 0:b9ac53c439ed 3008 {
peccu 0:b9ac53c439ed 3009 ++count[j];
peccu 0:b9ac53c439ed 3010 break;
peccu 0:b9ac53c439ed 3011 }
peccu 0:b9ac53c439ed 3012 if(j == nr_class)
peccu 0:b9ac53c439ed 3013 {
peccu 0:b9ac53c439ed 3014 if(nr_class == max_nr_class)
peccu 0:b9ac53c439ed 3015 {
peccu 0:b9ac53c439ed 3016 max_nr_class *= 2;
peccu 0:b9ac53c439ed 3017 label = (int *)realloc(label,max_nr_class*sizeof(int));
peccu 0:b9ac53c439ed 3018 count = (int *)realloc(count,max_nr_class*sizeof(int));
peccu 0:b9ac53c439ed 3019 }
peccu 0:b9ac53c439ed 3020 label[nr_class] = this_label;
peccu 0:b9ac53c439ed 3021 count[nr_class] = 1;
peccu 0:b9ac53c439ed 3022 ++nr_class;
peccu 0:b9ac53c439ed 3023 }
peccu 0:b9ac53c439ed 3024 }
peccu 0:b9ac53c439ed 3025
peccu 0:b9ac53c439ed 3026 for(i=0;i<nr_class;i++)
peccu 0:b9ac53c439ed 3027 {
peccu 0:b9ac53c439ed 3028 int n1 = count[i];
peccu 0:b9ac53c439ed 3029 for(int j=i+1;j<nr_class;j++)
peccu 0:b9ac53c439ed 3030 {
peccu 0:b9ac53c439ed 3031 int n2 = count[j];
peccu 0:b9ac53c439ed 3032 if(param->nu*(n1+n2)/2 > min(n1,n2))
peccu 0:b9ac53c439ed 3033 {
peccu 0:b9ac53c439ed 3034 free(label);
peccu 0:b9ac53c439ed 3035 free(count);
peccu 0:b9ac53c439ed 3036 return "specified nu is infeasible";
peccu 0:b9ac53c439ed 3037 }
peccu 0:b9ac53c439ed 3038 }
peccu 0:b9ac53c439ed 3039 }
peccu 0:b9ac53c439ed 3040 free(label);
peccu 0:b9ac53c439ed 3041 free(count);
peccu 0:b9ac53c439ed 3042 }
peccu 0:b9ac53c439ed 3043
peccu 0:b9ac53c439ed 3044 return NULL;
peccu 0:b9ac53c439ed 3045 }
peccu 0:b9ac53c439ed 3046
peccu 0:b9ac53c439ed 3047 int svm_check_probability_model(const svm_model *model)
peccu 0:b9ac53c439ed 3048 {
peccu 0:b9ac53c439ed 3049 return ((model->param.svm_type == C_SVC || model->param.svm_type == NU_SVC) &&
peccu 0:b9ac53c439ed 3050 model->probA!=NULL && model->probB!=NULL) ||
peccu 0:b9ac53c439ed 3051 ((model->param.svm_type == EPSILON_SVR || model->param.svm_type == NU_SVR) &&
peccu 0:b9ac53c439ed 3052 model->probA!=NULL);
peccu 0:b9ac53c439ed 3053 }
peccu 0:b9ac53c439ed 3054
peccu 0:b9ac53c439ed 3055 void svm_set_print_string_function(void (*print_func)(const char *))
peccu 0:b9ac53c439ed 3056 {
peccu 0:b9ac53c439ed 3057 if(print_func == NULL)
peccu 0:b9ac53c439ed 3058 svm_print_string = &print_string_stdout;
peccu 0:b9ac53c439ed 3059 else
peccu 0:b9ac53c439ed 3060 svm_print_string = print_func;
peccu 0:b9ac53c439ed 3061 }
peccu 0:b9ac53c439ed 3062
peccu 0:b9ac53c439ed 3063 // this function is copied by shimatani
peccu 0:b9ac53c439ed 3064 // for fopen on iPhone
peccu 0:b9ac53c439ed 3065 svm_model *svm_load_model_fp(FILE *fp)
peccu 0:b9ac53c439ed 3066 {
peccu 0:b9ac53c439ed 3067 if(fp==NULL) return NULL;
peccu 0:b9ac53c439ed 3068
peccu 0:b9ac53c439ed 3069 // read parameters
peccu 0:b9ac53c439ed 3070
peccu 0:b9ac53c439ed 3071 svm_model *model = Malloc(svm_model,1);
peccu 0:b9ac53c439ed 3072 svm_parameter& param = model->param;
peccu 0:b9ac53c439ed 3073 model->rho = NULL;
peccu 0:b9ac53c439ed 3074 model->probA = NULL;
peccu 0:b9ac53c439ed 3075 model->probB = NULL;
peccu 0:b9ac53c439ed 3076 model->label = NULL;
peccu 0:b9ac53c439ed 3077 model->nSV = NULL;
peccu 0:b9ac53c439ed 3078
peccu 0:b9ac53c439ed 3079 char cmd[81];
peccu 0:b9ac53c439ed 3080 while(1)
peccu 0:b9ac53c439ed 3081 {
peccu 0:b9ac53c439ed 3082 fscanf(fp,"%80s",cmd);
peccu 0:b9ac53c439ed 3083
peccu 0:b9ac53c439ed 3084 if(strcmp(cmd,"svm_type")==0)
peccu 0:b9ac53c439ed 3085 {
peccu 0:b9ac53c439ed 3086 fscanf(fp,"%80s",cmd);
peccu 0:b9ac53c439ed 3087 int i;
peccu 0:b9ac53c439ed 3088 for(i=0;svm_type_table[i];i++)
peccu 0:b9ac53c439ed 3089 {
peccu 0:b9ac53c439ed 3090 if(strcmp(svm_type_table[i],cmd)==0)
peccu 0:b9ac53c439ed 3091 {
peccu 0:b9ac53c439ed 3092 param.svm_type=i;
peccu 0:b9ac53c439ed 3093 break;
peccu 0:b9ac53c439ed 3094 }
peccu 0:b9ac53c439ed 3095 }
peccu 0:b9ac53c439ed 3096 if(svm_type_table[i] == NULL)
peccu 0:b9ac53c439ed 3097 {
peccu 0:b9ac53c439ed 3098 fprintf(stderr,"unknown svm type.\n");
peccu 0:b9ac53c439ed 3099 free(model->rho);
peccu 0:b9ac53c439ed 3100 free(model->label);
peccu 0:b9ac53c439ed 3101 free(model->nSV);
peccu 0:b9ac53c439ed 3102 free(model);
peccu 0:b9ac53c439ed 3103 return NULL;
peccu 0:b9ac53c439ed 3104 }
peccu 0:b9ac53c439ed 3105 }
peccu 0:b9ac53c439ed 3106 else if(strcmp(cmd,"kernel_type")==0)
peccu 0:b9ac53c439ed 3107 {
peccu 0:b9ac53c439ed 3108 fscanf(fp,"%80s",cmd);
peccu 0:b9ac53c439ed 3109 int i;
peccu 0:b9ac53c439ed 3110 for(i=0;kernel_type_table[i];i++)
peccu 0:b9ac53c439ed 3111 {
peccu 0:b9ac53c439ed 3112 if(strcmp(kernel_type_table[i],cmd)==0)
peccu 0:b9ac53c439ed 3113 {
peccu 0:b9ac53c439ed 3114 param.kernel_type=i;
peccu 0:b9ac53c439ed 3115 break;
peccu 0:b9ac53c439ed 3116 }
peccu 0:b9ac53c439ed 3117 }
peccu 0:b9ac53c439ed 3118 if(kernel_type_table[i] == NULL)
peccu 0:b9ac53c439ed 3119 {
peccu 0:b9ac53c439ed 3120 fprintf(stderr,"unknown kernel function.\n");
peccu 0:b9ac53c439ed 3121 free(model->rho);
peccu 0:b9ac53c439ed 3122 free(model->label);
peccu 0:b9ac53c439ed 3123 free(model->nSV);
peccu 0:b9ac53c439ed 3124 free(model);
peccu 0:b9ac53c439ed 3125 return NULL;
peccu 0:b9ac53c439ed 3126 }
peccu 0:b9ac53c439ed 3127 }
peccu 0:b9ac53c439ed 3128 else if(strcmp(cmd,"degree")==0)
peccu 0:b9ac53c439ed 3129 fscanf(fp,"%d",&param.degree);
peccu 0:b9ac53c439ed 3130 else if(strcmp(cmd,"gamma")==0)
peccu 0:b9ac53c439ed 3131 fscanf(fp,"%lf",&param.gamma);
peccu 0:b9ac53c439ed 3132 else if(strcmp(cmd,"coef0")==0)
peccu 0:b9ac53c439ed 3133 fscanf(fp,"%lf",&param.coef0);
peccu 0:b9ac53c439ed 3134 else if(strcmp(cmd,"nr_class")==0)
peccu 0:b9ac53c439ed 3135 fscanf(fp,"%d",&model->nr_class);
peccu 0:b9ac53c439ed 3136 else if(strcmp(cmd,"total_sv")==0)
peccu 0:b9ac53c439ed 3137 fscanf(fp,"%d",&model->l);
peccu 0:b9ac53c439ed 3138 else if(strcmp(cmd,"rho")==0)
peccu 0:b9ac53c439ed 3139 {
peccu 0:b9ac53c439ed 3140 int n = model->nr_class * (model->nr_class-1)/2;
peccu 0:b9ac53c439ed 3141 model->rho = Malloc(double,n);
peccu 0:b9ac53c439ed 3142 for(int i=0;i<n;i++)
peccu 0:b9ac53c439ed 3143 fscanf(fp,"%lf",&model->rho[i]);
peccu 0:b9ac53c439ed 3144 }
peccu 0:b9ac53c439ed 3145 else if(strcmp(cmd,"label")==0)
peccu 0:b9ac53c439ed 3146 {
peccu 0:b9ac53c439ed 3147 int n = model->nr_class;
peccu 0:b9ac53c439ed 3148 model->label = Malloc(int,n);
peccu 0:b9ac53c439ed 3149 for(int i=0;i<n;i++)
peccu 0:b9ac53c439ed 3150 fscanf(fp,"%d",&model->label[i]);
peccu 0:b9ac53c439ed 3151 }
peccu 0:b9ac53c439ed 3152 else if(strcmp(cmd,"probA")==0)
peccu 0:b9ac53c439ed 3153 {
peccu 0:b9ac53c439ed 3154 int n = model->nr_class * (model->nr_class-1)/2;
peccu 0:b9ac53c439ed 3155 model->probA = Malloc(double,n);
peccu 0:b9ac53c439ed 3156 for(int i=0;i<n;i++)
peccu 0:b9ac53c439ed 3157 fscanf(fp,"%lf",&model->probA[i]);
peccu 0:b9ac53c439ed 3158 }
peccu 0:b9ac53c439ed 3159 else if(strcmp(cmd,"probB")==0)
peccu 0:b9ac53c439ed 3160 {
peccu 0:b9ac53c439ed 3161 int n = model->nr_class * (model->nr_class-1)/2;
peccu 0:b9ac53c439ed 3162 model->probB = Malloc(double,n);
peccu 0:b9ac53c439ed 3163 for(int i=0;i<n;i++)
peccu 0:b9ac53c439ed 3164 fscanf(fp,"%lf",&model->probB[i]);
peccu 0:b9ac53c439ed 3165 }
peccu 0:b9ac53c439ed 3166 else if(strcmp(cmd,"nr_sv")==0)
peccu 0:b9ac53c439ed 3167 {
peccu 0:b9ac53c439ed 3168 int n = model->nr_class;
peccu 0:b9ac53c439ed 3169 model->nSV = Malloc(int,n);
peccu 0:b9ac53c439ed 3170 for(int i=0;i<n;i++)
peccu 0:b9ac53c439ed 3171 fscanf(fp,"%d",&model->nSV[i]);
peccu 0:b9ac53c439ed 3172 }
peccu 0:b9ac53c439ed 3173 else if(strcmp(cmd,"SV")==0)
peccu 0:b9ac53c439ed 3174 {
peccu 0:b9ac53c439ed 3175 while(1)
peccu 0:b9ac53c439ed 3176 {
peccu 0:b9ac53c439ed 3177 int c = getc(fp);
peccu 0:b9ac53c439ed 3178 if(c==EOF || c=='\n') break;
peccu 0:b9ac53c439ed 3179 }
peccu 0:b9ac53c439ed 3180 break;
peccu 0:b9ac53c439ed 3181 }
peccu 0:b9ac53c439ed 3182 else
peccu 0:b9ac53c439ed 3183 {
peccu 0:b9ac53c439ed 3184 fprintf(stderr,"unknown text in model file: [%s]\n",cmd);
peccu 0:b9ac53c439ed 3185 free(model->rho);
peccu 0:b9ac53c439ed 3186 free(model->label);
peccu 0:b9ac53c439ed 3187 free(model->nSV);
peccu 0:b9ac53c439ed 3188 free(model);
peccu 0:b9ac53c439ed 3189 return NULL;
peccu 0:b9ac53c439ed 3190 }
peccu 0:b9ac53c439ed 3191 }
peccu 0:b9ac53c439ed 3192
peccu 0:b9ac53c439ed 3193 // read sv_coef and SV
peccu 0:b9ac53c439ed 3194
peccu 0:b9ac53c439ed 3195 int elements = 0;
peccu 0:b9ac53c439ed 3196 long pos = ftell(fp);
peccu 0:b9ac53c439ed 3197
peccu 0:b9ac53c439ed 3198 max_line_len = 1024;
peccu 0:b9ac53c439ed 3199 line = Malloc(char,max_line_len);
peccu 0:b9ac53c439ed 3200 char *p,*endptr,*idx,*val;
peccu 0:b9ac53c439ed 3201
peccu 0:b9ac53c439ed 3202 while(readline(fp)!=NULL)
peccu 0:b9ac53c439ed 3203 {
peccu 0:b9ac53c439ed 3204 p = strtok(line,":");
peccu 0:b9ac53c439ed 3205 while(1)
peccu 0:b9ac53c439ed 3206 {
peccu 0:b9ac53c439ed 3207 p = strtok(NULL,":");
peccu 0:b9ac53c439ed 3208 if(p == NULL)
peccu 0:b9ac53c439ed 3209 break;
peccu 0:b9ac53c439ed 3210 ++elements;
peccu 0:b9ac53c439ed 3211 }
peccu 0:b9ac53c439ed 3212 }
peccu 0:b9ac53c439ed 3213 elements += model->l;
peccu 0:b9ac53c439ed 3214
peccu 0:b9ac53c439ed 3215 fseek(fp,pos,SEEK_SET);
peccu 0:b9ac53c439ed 3216
peccu 0:b9ac53c439ed 3217 int m = model->nr_class - 1;
peccu 0:b9ac53c439ed 3218 int l = model->l;
peccu 0:b9ac53c439ed 3219 model->sv_coef = Malloc(double *,m);
peccu 0:b9ac53c439ed 3220 int i;
peccu 0:b9ac53c439ed 3221 for(i=0;i<m;i++)
peccu 0:b9ac53c439ed 3222 model->sv_coef[i] = Malloc(double,l);
peccu 0:b9ac53c439ed 3223 model->SV = Malloc(svm_node*,l);
peccu 0:b9ac53c439ed 3224 svm_node *x_space = NULL;
peccu 0:b9ac53c439ed 3225 if(l>0) x_space = Malloc(svm_node,elements);
peccu 0:b9ac53c439ed 3226
peccu 0:b9ac53c439ed 3227 int j=0;
peccu 0:b9ac53c439ed 3228 for(i=0;i<l;i++)
peccu 0:b9ac53c439ed 3229 {
peccu 0:b9ac53c439ed 3230 readline(fp);
peccu 0:b9ac53c439ed 3231 model->SV[i] = &x_space[j];
peccu 0:b9ac53c439ed 3232
peccu 0:b9ac53c439ed 3233 p = strtok(line, " \t");
peccu 0:b9ac53c439ed 3234 model->sv_coef[0][i] = strtod(p,&endptr);
peccu 0:b9ac53c439ed 3235 for(int k=1;k<m;k++)
peccu 0:b9ac53c439ed 3236 {
peccu 0:b9ac53c439ed 3237 p = strtok(NULL, " \t");
peccu 0:b9ac53c439ed 3238 model->sv_coef[k][i] = strtod(p,&endptr);
peccu 0:b9ac53c439ed 3239 }
peccu 0:b9ac53c439ed 3240
peccu 0:b9ac53c439ed 3241 while(1)
peccu 0:b9ac53c439ed 3242 {
peccu 0:b9ac53c439ed 3243 idx = strtok(NULL, ":");
peccu 0:b9ac53c439ed 3244 val = strtok(NULL, " \t");
peccu 0:b9ac53c439ed 3245
peccu 0:b9ac53c439ed 3246 if(val == NULL)
peccu 0:b9ac53c439ed 3247 break;
peccu 0:b9ac53c439ed 3248 x_space[j].index = (int) strtol(idx,&endptr,10);
peccu 0:b9ac53c439ed 3249 x_space[j].value = strtod(val,&endptr);
peccu 0:b9ac53c439ed 3250
peccu 0:b9ac53c439ed 3251 ++j;
peccu 0:b9ac53c439ed 3252 }
peccu 0:b9ac53c439ed 3253 x_space[j++].index = -1;
peccu 0:b9ac53c439ed 3254 }
peccu 0:b9ac53c439ed 3255 free(line);
peccu 0:b9ac53c439ed 3256
peccu 0:b9ac53c439ed 3257 if (ferror(fp) != 0 || fclose(fp) != 0)
peccu 0:b9ac53c439ed 3258 return NULL;
peccu 0:b9ac53c439ed 3259
peccu 0:b9ac53c439ed 3260 model->free_sv = 1; // XXX
peccu 0:b9ac53c439ed 3261 return model;
peccu 0:b9ac53c439ed 3262 }