#include "omni_wheel.h"

/**
 * Default constructor: sets up a drive with four wheels.
 *
 * Records a wheel count of 4 and allocates an array of four
 * default-constructed Wheel objects with new[].
 *
 * NOTE(review): the matching delete[] is expected to live in the destructor,
 * which is not visible in this file — confirm.
 */
OmniWheel::OmniWheel()
    : wheelNumber(4)
{
    wheel = new Wheel[4];
}

/**
 * Sets up a drive with a caller-chosen number of wheels.
 *
 * @param wheelNumber Number of Wheel objects to allocate; also stored in the
 *                    member of the same name.
 *
 * NOTE(review): the count is handed to new[] without validation; a negative
 * value is not handled here — confirm callers never pass one.
 */
OmniWheel::OmniWheel(int wheelNumber)
    : wheelNumber(wheelNumber)
{
    // Within the body the bare name `wheelNumber` is the parameter, which
    // shadows the member initialised above; both hold the same value.
    Wheel *const allocated = new Wheel[wheelNumber];
    this->wheel = allocated;
}

/**
 * Applies a translation command given in Cartesian form.
 *
 * The (X, Y) pair is converted to polar form — magnitude hypot(X, Y) and
 * angle atan2(Y, X) — and forwarded to the five-argument computeCircular().
 *
 * @param X      X component of the translation command.
 * @param Y      Y component of the translation command.
 * @param gX     Forwarded unchanged to computeCircular().
 * @param gY     Forwarded unchanged to computeCircular().
 * @param moment Forwarded unchanged to computeCircular().
 */
void OmniWheel::computeXY(double X, double Y, double gX, double gY, double moment)
{
    const double magnitude = hypot(X, Y);
    const double heading = atan2(Y, X);
    computeCircular(magnitude, heading, gX, gY, moment);
}

/**
 * Applies a Cartesian translation command with gX and gY fixed at zero.
 *
 * Converts (X, Y) to magnitude hypot(X, Y) and angle atan2(Y, X), then calls
 * the five-argument computeCircular() with zero for both gX and gY.
 *
 * @param X      X component of the translation command.
 * @param Y      Y component of the translation command.
 * @param moment Forwarded unchanged to computeCircular().
 */
void OmniWheel::computeXY(double X, double Y, double moment)
{
    const double magnitude = hypot(X, Y);
    const double heading = atan2(Y, X);
    const double zeroGX = 0;
    const double zeroGY = 0;
    computeCircular(magnitude, heading, zeroGX, zeroGY, moment);
}

/**
 * Computes and applies the output of every wheel for a polar translation
 * command combined with a rotation command.
 *
 * For each wheel, the translation share is Wheel::calculateShift(r, theta)
 * and the rotation share is Wheel::calculateRotate(gX, gY, moment). Their sum
 * is passed to Wheel::setOutput(). If the estimated range of those sums
 * leaves [-1, 1], every sum is divided by one common factor, so the ratios
 * between wheels are preserved.
 *
 * The range estimate is conservative: the upper bound is
 * (largest shift + largest rotate) and the lower bound is
 * (smallest shift + smallest rotate), even when those extremes belong to
 * different wheels.
 *
 * Does nothing when wheelNumber is zero or negative.
 *
 * @param r      Magnitude of the translation command.
 * @param theta  Angle of the translation command (as produced by atan2 in
 *               computeXY()).
 * @param gX     Passed unchanged to Wheel::calculateRotate().
 * @param gY     Passed unchanged to Wheel::calculateRotate().
 * @param moment Passed unchanged to Wheel::calculateRotate().
 */
void OmniWheel::computeCircular(double r, double theta, double gX, double gY, double moment)
{
    if (wheelNumber <= 0) return;

    // Scratch buffer holding shift + rotate for each wheel.
    double *combined = new double[wheelNumber];

    // The running extremes start at -1 / +1, so a maximum never ends up below
    // -1 and a minimum never ends up above +1.
    // NOTE(review): this presumes the per-wheel shares nominally lie in
    // [-1, 1] — confirm against Wheel.
    double shiftMax = -1.0;
    double rotateMax = -1.0;
    double shiftMin = 1.0;
    double rotateMin = 1.0;

    for (int i = 0; i < wheelNumber; i++) {
        const double shift = wheel[i].calculateShift(r, theta);
        const double rotate = wheel[i].calculateRotate(gX, gY, moment);
        combined[i] = shift + rotate;

        if (shift > shiftMax) shiftMax = shift;
        if (shift < shiftMin) shiftMin = shift;
        if (rotate > rotateMax) rotateMax = rotate;
        if (rotate < rotateMin) rotateMin = rotate;
    }

    // How far the estimated range reaches in the positive and negative
    // directions. The divisor must be the larger of the two: using only the
    // positive bound when both limits are exceeded could leave outputs
    // below -1.
    const double positiveReach = shiftMax + rotateMax;
    const double negativeReach = -(shiftMin + rotateMin);
    const double reach = (positiveReach > negativeReach) ? positiveReach : negativeReach;

    // Only scale down; commands that already fit pass through unchanged.
    const bool needsScaling = reach > 1.0;
    for (int i = 0; i < wheelNumber; i++) {
        wheel[i].setOutput(needsScaling ? combined[i] / reach : combined[i]);
    }

    delete[] combined;
}

/**
 * Applies a polar translation command with gX and gY fixed at zero.
 *
 * Forwards to the five-argument computeCircular(), passing zero for both
 * gX and gY.
 *
 * @param r      Magnitude of the translation command.
 * @param theta  Angle of the translation command.
 * @param moment Forwarded unchanged.
 */
void OmniWheel::computeCircular(double r, double theta, double moment)
{
    const double zeroGX = 0;
    const double zeroGY = 0;
    computeCircular(r, theta, zeroGX, zeroGY, moment);
}
