#include "Mat.h"

template<typename T>
class EKF
{
    private :

    T dt;           /* time step (set from the constructor; original note: 0.005 by default) */
    int nbr_state;  /* size of the state vector */
    int nbr_ctrl;   /* size of the control vector */
    int nbr_obs;    /* size of the observation vector */

    Mat<T> dX;      /* desired state: passed to the sensor functions and read by computeCommand(mode <= -1) */
    Mat<T>* _X;     /* previous state (saved by correctState) */
    Mat<T>* X;      /* current state estimate */
    Mat<T>* X_;     /* derivative / next state (continuous/discrete); not read anywhere in this class */
    Mat<T>* u;      /* controls */
    Mat<T>* z;      /* observations/measurements */

    Mat<T>* Ki;     /* integral gain,          nbr_ctrl x nbr_state */
    Mat<T>* Kp;     /* proportional gain,      nbr_ctrl x nbr_state */
    Mat<T>* Kd;     /* derivative gain,        nbr_ctrl x nbr_state */
    Mat<T>* Kdd;    /* second-derivative gain, nbr_ctrl x nbr_state */

    Mat<T>* A;      /* linear state-transition matrix (non-extended mode), nbr_state x nbr_state. */
    /* Intended model: X_next = A * X + B * u, i.e. x_i = x_i + dt * x_i_ + b_i * u_i (see initA). */
    Mat<T>* B;      /* linear control matrix (non-extended mode), nbr_state x nbr_ctrl; zero until initB. */
    Mat<T>* C;      /* linear observation matrix (non-extended mode), nbr_obs x nbr_state. */
    /* Typical choice: C = [1 0], i.e. positions are observed but not their derivatives. */

    /*Noise*/
    T std_noise;    /* noise standard deviation (original note: 0.0005 by default) */
    Mat<T>* Sigma;  /* state covariance matrix */
    Mat<T>* Q;      /* process noise covariance = std_noise^2 * I */
    Mat<T>* R;      /* measurement noise covariance = std_noise^2 * I */

    /*Prediction*/
    Mat<T> K;       /* Kalman gain, computed in state_Callback */
    Mat<T> Sigma_p; /* predicted covariance, computed in calculateKalmanGain */

    /*Others*/
    Mat<T>* Identity; /* nbr_state x nbr_state identity */

    /*Extended*/
    bool extended;
    Mat<T> (*ptrMotion)(Mat<T> state, Mat<T> command, T dt);
    Mat<T> (*ptrSensor)(Mat<T> state, Mat<T> command, Mat<T> d_state, T dt);
    Mat<T> G;       /* Jacobian returned by ptrJMotion */
    Mat<T> H;       /* Jacobian returned by ptrJSensor */
    Mat<T> (*ptrJMotion)(Mat<T> state, Mat<T> command, T dt);
    Mat<T> (*ptrJSensor)(Mat<T> state, Mat<T> command, Mat<T> d_state, T dt);

    /* Throws (as a const char*) when extended mode is enabled but one of the four
       user-supplied model/Jacobian functions has not been registered. */
    void requireExtendedFunctions() const
    {
        if( extended && (ptrMotion == NULL || ptrSensor == NULL || ptrJMotion == NULL || ptrJSensor == NULL) )
        {
            throw("ERREUR : les fonctions ne sont pas initialisees...");
        }
    }

    /* Copies every member that is not an owning pointer (scalars, value matrices,
       function pointers). Shared by the copy constructor and copy assignment. */
    void copyPlainMembers(const EKF& o)
    {
        dt = o.dt;
        nbr_state = o.nbr_state;
        nbr_ctrl = o.nbr_ctrl;
        nbr_obs = o.nbr_obs;
        dX = o.dX;
        std_noise = o.std_noise;
        K = o.K;
        Sigma_p = o.Sigma_p;
        extended = o.extended;
        ptrMotion = o.ptrMotion;
        ptrSensor = o.ptrSensor;
        ptrJMotion = o.ptrJMotion;
        ptrJSensor = o.ptrJSensor;
        G = o.G;
        H = o.H;
    }


    public :

    /* Builds a filter with nbr_state_ states, nbr_ctrl_ controls and nbr_obs_ observations.
       currentState is the initial state estimate. Sigma starts as the identity, and
       Q and R as std_noise_^2 * identity. With ext == true, the four model/Jacobian
       functions must be registered (initMotion, initSensor, initJMotion, initJSensor)
       before any callback is used. */
    EKF(int nbr_state_, int nbr_ctrl_, int nbr_obs_, T dt_, T std_noise_, Mat<T> currentState, bool ext = false)
    {
        /*extension*/
        extended = ext;
        ptrMotion = NULL;
        ptrSensor = NULL;
        ptrJMotion = NULL;
        ptrJSensor = NULL;
        G = Mat<T>((T)0, nbr_state_, nbr_state_);
        H = Mat<T>((T)0, nbr_obs_, nbr_state_);

        dt = dt_;
        nbr_state = nbr_state_;
        nbr_ctrl = nbr_ctrl_;
        nbr_obs = nbr_obs_;

        /* dX is read by the sensor functions and by computeCommand before any caller
           may have set it; start from a zero vector of state size. */
        dX = Mat<T>((T)0, nbr_state, (int)1);

        _X = new Mat<T>((T)0, nbr_state, (int)1);       /*previous state*/
        X = new Mat<T>(currentState);                   /*states*/
        X_ = new Mat<T>((T)0, nbr_state, (int)1);       /*derivative states*/
        u = new Mat<T>((T)0, nbr_ctrl, (int)1);         /*controls*/
        z = new Mat<T>((T)0, nbr_obs, (int)1);          /*observations*/
        A = new Mat<T>((T)0, nbr_state, nbr_state);     /*state transition (or its Jacobian)*/
        B = new Mat<T>((T)0, nbr_state, nbr_ctrl);      /*control matrix*/
        C = new Mat<T>((T)0, nbr_obs, nbr_state);       /*observation matrix (or its Jacobian)*/

        Ki = new Mat<T>((T)0.08, nbr_ctrl, nbr_state);
        Kp = new Mat<T>((T)0.08, nbr_ctrl, nbr_state);
        Kd = new Mat<T>((T)0.08, nbr_ctrl, nbr_state);
        Kdd = new Mat<T>((T)0.08, nbr_ctrl, nbr_state);

        std_noise = std_noise_;
        Sigma = new Mat<T>((T)0, nbr_state, nbr_state);
        Q = new Mat<T>((T)0, nbr_state, nbr_state);
        R = new Mat<T>((T)0, nbr_obs, nbr_obs);

        /* Sigma starts as the nbr_state x nbr_state identity. */
        for(int i=1;i<=nbr_state;i++)
        {
            Sigma->set((T)1, i, i);
        }

        /* R is nbr_obs x nbr_obs, so its diagonal must be indexed up to nbr_obs
           (indexing up to nbr_state overruns R when nbr_obs < nbr_state). */
        for(int i=1;i<=nbr_obs;i++)
        {
            R->set((T)1, i, i);
        }

        Identity = new Mat<T>(*Sigma);
        *Q = (std_noise*std_noise)*(*Identity);
        *R = (std_noise*std_noise)*(*R);
    }

    /* Deep copy: every owned matrix is duplicated, so two filters never share
       (and later double-delete) the same allocation. */
    EKF(const EKF& other)
    {
        copyPlainMembers(other);

        _X = new Mat<T>(*other._X);
        X = new Mat<T>(*other.X);
        X_ = new Mat<T>(*other.X_);
        u = new Mat<T>(*other.u);
        z = new Mat<T>(*other.z);
        A = new Mat<T>(*other.A);
        B = new Mat<T>(*other.B);
        C = new Mat<T>(*other.C);

        Ki = new Mat<T>(*other.Ki);
        Kp = new Mat<T>(*other.Kp);
        Kd = new Mat<T>(*other.Kd);
        Kdd = new Mat<T>(*other.Kdd);

        Sigma = new Mat<T>(*other.Sigma);
        Q = new Mat<T>(*other.Q);
        R = new Mat<T>(*other.R);

        Identity = new Mat<T>(*other.Identity);
    }

    /* Deep copy assignment: both objects always own all their matrices, so the
       contents are assigned in place and no reallocation is needed. */
    EKF& operator=(const EKF& other)
    {
        if(this != &other)
        {
            copyPlainMembers(other);

            *_X = *other._X;
            *X = *other.X;
            *X_ = *other.X_;
            *u = *other.u;
            *z = *other.z;
            *A = *other.A;
            *B = *other.B;
            *C = *other.C;

            *Ki = *other.Ki;
            *Kp = *other.Kp;
            *Kd = *other.Kd;
            *Kdd = *other.Kdd;

            *Sigma = *other.Sigma;
            *Q = *other.Q;
            *R = *other.R;

            *Identity = *other.Identity;
        }
        return *this;
    }

    ~EKF()
    {
        delete _X;
        delete X;
        delete X_;
        delete u;
        delete z;
        delete A;
        delete B;
        delete C;

        delete Ki;
        delete Kp;
        delete Kd;
        delete Kdd;

        delete Sigma;
        delete Q;
        delete R;

        delete Identity;
    }

/*-------------------------------------------------------------*/
/*-------------------------------------------------------------*/
/*-------------------------------------------------------------*/


    /* Sets the state-transition matrix A (nbr_state x nbr_state).
       If A_ equals the identity, A becomes I plus dt couplings A(i, i + nbr_state/2) = dt,
       i.e. x_i = x_i + dt * x_i_ with the derivatives stored in the second half of the state.
       Returns 1 on success and 0 (after printing an error) on a dimension mismatch. */
    int initA( Mat<T> A_)
    {
        if(A_.getColumn() != nbr_state || A_.getLine() != nbr_state)
        {
            cout << "ERREUR : mauvais format de matrice d'initialisation de A." << endl;
            return 0;
        }

        /* Start from A_ itself: adding only the dt couplings to a zero matrix would drop
           the unit diagonal and lose the x_i term of the model. */
        *A = A_;

        if(A_ == *Identity)
        {
            for(int i=1;i<=nbr_state/2;i++)
            {
                A->set( dt, i, i+nbr_state/2);
            }
        }

        return 1;
    }


    /* Sets the control matrix B (nbr_state x nbr_ctrl). Returns 1 on success, 0 on a dimension mismatch. */
    int initB( Mat<T> B_)
    {
        if(B_.getColumn() == nbr_ctrl && B_.getLine() == nbr_state)
        {
            *B = B_;
            return 1;
        }

        cout << "ERREUR : mauvais format de matrice d'initialisation de B." << endl;
        return 0;
    }


    /* Sets the observation matrix C (nbr_obs x nbr_state). Returns 1 on success, 0 on a dimension mismatch. */
    int initC( Mat<T> C_)
    {
        if(C_.getColumn() == nbr_state && C_.getLine() == nbr_obs)
        {
            *C = C_;
            return 1;
        }

        cout << "ERREUR : mauvais format de matrice d'initialisation de C." << endl;
        return 0;
    }

    /*extension: registration of the non-linear models and their Jacobians*/

    /* Motion model f(state, command, dt) -> predicted state. */
    void initMotion( Mat<T> motion(Mat<T>, Mat<T>, T) )
    {
        ptrMotion = motion;
    }

    /* Sensor model h(state, command, desired_state, dt) -> expected observation. */
    void initSensor( Mat<T> sensor(Mat<T>, Mat<T>, Mat<T>, T) )
    {
        ptrSensor = sensor;
    }

    /* Jacobian of the motion model (nbr_state x nbr_state). */
    void initJMotion( Mat<T> jmotion(Mat<T>, Mat<T>, T) )
    {
        ptrJMotion = jmotion;
    }

    /* Jacobian of the sensor model (nbr_obs x nbr_state). */
    void initJSensor( Mat<T> jsensor(Mat<T>, Mat<T>, Mat<T>, T) )
    {
        ptrJSensor = jsensor;
    }


/*-------------------------------------------------------------*/
/*-------------------------------------------------------------*/
/*-------------------------------------------------------------*/


    /* Each gain setter expects an nbr_ctrl x nbr_state matrix; returns 1 on success, 0 on mismatch. */
    int setKi( Mat<T> Ki_)
    {
        if(Ki_.getColumn() == nbr_state && Ki_.getLine() == nbr_ctrl)
        {
            *Ki = Ki_;
            return 1;
        }

        cout << "ERREUR : mauvais format de vecteur d'initialisation de Ki." << endl;
        return 0;
    }

    int setKp( Mat<T> Kp_)
    {
        if(Kp_.getColumn() == nbr_state && Kp_.getLine() == nbr_ctrl)
        {
            *Kp = Kp_;
            return 1;
        }

        cout << "ERREUR : mauvais format de vecteur d'initialisation de Kp." << endl;
        return 0;
    }

    int setKd( Mat<T> Kd_)
    {
        if(Kd_.getColumn() == nbr_state && Kd_.getLine() == nbr_ctrl)
        {
            *Kd = Kd_;
            return 1;
        }

        cout << "ERREUR : mauvais format de vecteur d'initialisation de Kd." << endl;
        return 0;
    }

    int setKdd( Mat<T> Kdd_)
    {
        if(Kdd_.getColumn() == nbr_state && Kdd_.getLine() == nbr_ctrl)
        {
            *Kdd = Kdd_;
            return 1;
        }

        cout << "ERREUR : mauvais format de vecteur d'initialisation de Kdd." << endl;
        return 0;
    }


    /* Takes T (not float) so a double-precision filter does not lose precision on dt. */
    void setdt( T dt_)
    {
        dt = dt_;
    }

/*-------------------------------------------------------------*/
/*-------------------------------------------------------------*/
/*-------------------------------------------------------------*/


    Mat<T> getCommand() const
    {
        return *u;
    }

    Mat<T> getX() const
    {
        return *X;
    }

    Mat<T> getSigma() const
    {
        return *Sigma;
    }


/*-------------------------------------------------------------*/
/*-------------------------------------------------------------*/
/*-------------------------------------------------------------*/


    /* Predicted state: A*X + B*u (linear) or ptrMotion(X, u, dt) (extended). */
    Mat<T> predictState()
    {
        return (!extended ? (*A)*(*X)+(*B)*(*u) : ptrMotion(*X, *u, dt) );
    }


    /* Predicted covariance: A*Sigma*A^T + Q (linear) or G*Sigma*G^T + Q with G = ptrJMotion(X, u, dt). */
    Mat<T> predictCovariance()
    {
        if(extended)
            G = ptrJMotion(*X, *u, dt);

        return ( (!extended ? ((*A)*(*Sigma))* transpose(*A) : G*(*Sigma)*transpose(G) ) + *Q);
    }


    /* Kalman gain K = Sigma_p * C^T * (C * Sigma_p * C^T + R)^-1, with C replaced by the
       sensor Jacobian H in extended mode. Also stores Sigma_p for correctCovariance. */
    Mat<T> calculateKalmanGain()
    {
        Sigma_p = predictCovariance();

        /* NOTE(review): H is evaluated at the current X, whereas correctState evaluates the
           sensor model at the predicted state X_p; a textbook EKF uses X_p for both. */
        if(extended)
            H = ptrJSensor(*X, *u, dX, dt);

        return ( !extended ? Sigma_p * transpose(*C) * invGJ( (*C) * Sigma_p * transpose(*C) + *R)
                           : Sigma_p * transpose(H) * invGJ( H * Sigma_p * transpose(H) + *R) );
    }


    /* Updates X = X_p + K * (z - expected observation); saves the old X in _X and returns the new X. */
    Mat<T> correctState()
    {
        Mat<T> X_p( predictState());

        *_X = *X;
        *X = X_p + K*( (*z) - (!extended ? (*C)*X_p  : ptrSensor(X_p, *u, dX, dt) ) );

        return *X;
    }


    /* Updates Sigma = (I - K*C) * Sigma_p (C replaced by H in extended mode) and returns it. */
    Mat<T> correctCovariance()
    {
        *Sigma = (*Identity - K* (!extended ? (*C) : H) ) *Sigma_p;

        return *Sigma;
    }


    /* Runs one full filter step: gain, state correction, covariance correction.
       Throws a const char* if extended mode is on and a model function is missing. */
    void state_Callback()
    {
        requireExtendedFunctions();

        K = calculateKalmanGain();
        correctState();
        correctCovariance();
    }

    /* Stores the desired state dX_ and the new observation.
       NOTE(review): in extended mode `measurements` is ignored and z is replaced by
       ptrSensor(X, u, dX, dt) - confirm this is intended. */
    void measurement_Callback(Mat<T> measurements, Mat<T> dX_)
    {
        requireExtendedFunctions();

        dX = dX_;

        *z = (!extended ? measurements : ptrSensor(*X,*u, dX, dt) );
    }

    /* Same as above but keeps the previously stored desired state dX. */
    void measurement_Callback(Mat<T> measurements)
    {
        requireExtendedFunctions();

        *z = (!extended ? measurements : ptrSensor(*X,*u, dX, dt) );
    }


    /* Computes the command u.
       mode >= 0: u = Kp*e (+ dt_*Ki*e if mode >= 1, + Kd*e/dt_ if mode >= 2, + Kdd*e/dt_^2
       if mode >= 3), with e = desiredX - X; the dt_-scaled terms are skipped when dt_ == 0.
       mode <= -1: two-wheel controller driven by z, X and the stored dX; writes u(1,1) and
       u(2,1), so it needs nbr_ctrl >= 2. The two modes are mutually exclusive. */
    void computeCommand( Mat<T> desiredX, T dt_, int mode)
    {
        if(mode > -1)
        {
            Mat<T> stateError( desiredX - (*X) );
            *u = (*Kp)*stateError;

            if(dt_ != (T)0)
            {
                if(mode >= 1)
                    *u = *u + (T)(dt_)*(*Ki)*stateError;
                if(mode >= 2)
                    *u = *u + (T)((double)1.0/dt_)*(*Kd)*stateError;
                if(mode >= 3)
                    *u = *u + (T)((double)1.0/(dt_*dt_))*(*Kdd)*stateError;
            }
            return;
        }

        /* Hard-coded robot geometry. */
        const double radius_r = 70;     /* right wheel radius */
        const double radius_l = 70;     /* left wheel radius */
        const double wheel_base = 260;  /* distance between the wheels */

        /* NOTE(review): rho comes from the measurement z while alpha/beta use the stored dX
           rather than the desiredX argument - confirm both are intended. */
        double rho = (double)sqrt(z->get(1,1)*z->get(1,1) + z->get(2,1)*z->get(2,1));
        double alpha = atan21( dX.get(2,1)-X->get(2,1), dX.get(1,1) - X->get(1,1) );
        double beta = alpha + X->get(3,1) + dX.get(3,1);

        cout << "alpha = " << alpha << endl;

        double krho = (rho > 1 ? rho*3 : 1)/100;
        double kalpha = rho*300/100;
        double kbeta = 4*beta/(rho+1)/100;

        double vd = krho*rho;
        double wd = kalpha*alpha +  kbeta*beta;

        const double phi_max = 100;

        /* NOTE(review): both wheel speeds use (wd + vd/wheel_base); with equal radii
           phi_r == phi_l, so the recomputed wd below is always 0. A differential drive
           usually applies the angular term with opposite signs on the two wheels. */
        double phi_r = wheel_base/radius_r*(wd+vd/wheel_base);
        double phi_l = wheel_base/radius_l*(wd+vd/wheel_base);

        /* Only an upper saturation is applied; negative speeds are not bounded. */
        phi_r = (phi_r <= phi_max ? phi_r : phi_max);
        phi_l = (phi_l <= phi_max ? phi_l : phi_max);

        vd = radius_r/2*(phi_r+phi_l);
        wd = radius_r/(2*wheel_base)*(phi_r-phi_l);

        u->set( vd, 1,1);
        u->set( wd, 2,1);
    }

};