#include "Websocket.h"

#include <cstdio>
#include <cstring>

#define MAX_TRY_WRITE 20
// NOTE(review): MAX_TRY_READ is not referenced in this file; read() loops on MAX_TRY_WRITE.
#define MAX_TRY_READ 10

// Logging macros. Debug output is disabled by default; define DEBUG (e.g. on
// the compiler command line) to enable it.
// The format literal is separated from `x` by spaces: without them, C++11 and
// later parse `"..."x` as a user-defined literal and the macros do not compile.
// The bodies carry no trailing ';' so `if (c) DBG("..."); else ...` stays valid.
#ifdef DEBUG
#define DBG(x, ...) std::printf("[WebSocket : DBG]" x "\r\n", ##__VA_ARGS__)
#define WARN(x, ...) std::printf("[WebSocket : WARN]" x "\r\n", ##__VA_ARGS__)
#define ERR(x, ...) std::printf("[WebSocket : ERR]" x "\r\n", ##__VA_ARGS__)
#define INFO(x, ...) std::printf("[WebSocket : INFO]" x "\r\n", ##__VA_ARGS__)
#else
#define DBG(x, ...)
#define WARN(x, ...)
#define ERR(x, ...)
#define INFO(x, ...)
#endif

Websocket::Websocket(const char * url) {
    // No ping callback until setFnPing() is called. readPing() tests fnPing
    // before calling it, so it must never hold an indeterminate value.
    fnPing = NULL;
    fillFields(url);
    socket.set_blocking(false, 400);
}

Websocket::~Websocket () {
    // Drop the TCP connection. close() simply returns false when it is
    // already disconnected.
    this->close();
}

void  Websocket::setFnPing (void (*fn)()) {
    fnPing = fn;
}

void Websocket::fillFields(const char * url) {
  // Splits url into the scheme/host/port/path members used by connect().
  int ret = parseURL(url, scheme, sizeof(scheme), host, sizeof(host), &port, path, sizeof(path));
  if(ret)
  {
    ERR("URL parsing failed; please use: \"ws://ip-or-domain[:port]/path\"");
    // parseURL may have filled some outputs before failing. Reset all of them
    // so connect() never uses stale or uninitialised data.
    scheme[0] = '\0';
    host[0] = '\0';
    path[0] = '\0';
    port = 0;
    return;
  }

  if(port == 0) //TODO do handle WSS->443
  {
    port = 80;
  }

  // Only plain "ws" is supported. Other schemes are reported but not rejected.
  if(strcmp(scheme, "ws"))
  {
    ERR("Wrong scheme, please use \"ws\" instead");
  }
}

int Websocket::parseURL(const char* url, char* scheme, size_t maxSchemeLen, char* host, size_t maxHostLen, uint16_t* port, char* path, size_t maxPathLen) //Parse URL
{
  // Splits "scheme://host[:port][/path][#fragment]" into NUL-terminated parts.
  // - *port is set to 0 when the URL has no explicit port.
  // - A URL without a path (e.g. "ws://host:8080") yields the path "/".
  // - The fragment is dropped. A query string stays part of the path.
  // Returns 0 on success, or -1 on malformed input or a too-small buffer.
  const char* schemePtr = url;
  const char* hostPtr = strstr(url, "://");
  if(hostPtr == NULL)
  {
    WARN("Could not find host");
    return -1; //URL is invalid
  }

  const size_t schemeLen = hostPtr - schemePtr;
  if( maxSchemeLen < schemeLen + 1 ) //including NULL-terminating char
  {
    WARN("Scheme str is too small (%u < %u)", static_cast<unsigned>(maxSchemeLen), static_cast<unsigned>(schemeLen + 1));
    return -1;
  }
  memcpy(scheme, schemePtr, schemeLen);
  scheme[schemeLen] = '\0';

  hostPtr += 3;

  // The authority ("host[:port]") ends at the first '/' or '#', or at the end
  // of the string. Only that span is searched for ':', so a colon inside the
  // path is never taken as the port separator.
  const size_t authorityLen = strcspn(hostPtr, "/#");
  const char* pathPtr = hostPtr + authorityLen;

  size_t hostLen = authorityLen;
  const char* portPtr = static_cast<const char*>(memchr(hostPtr, ':', authorityLen));
  if( portPtr != NULL )
  {
    hostLen = portPtr - hostPtr;
    if( sscanf(portPtr + 1, "%hu", port) != 1)
    {
      WARN("Could not find port");
      return -1;
    }
  }
  else
  {
    *port = 0;
  }

  if( maxHostLen < hostLen + 1 ) //including NULL-terminating char
  {
    WARN("Host str is too small (%u < %u)", static_cast<unsigned>(maxHostLen), static_cast<unsigned>(hostLen + 1));
    return -1;
  }
  memcpy(host, hostPtr, hostLen);
  host[hostLen] = '\0';

  // pathPtr points at '/', '#' or the terminating NUL. The path runs up to
  // the fragment marker (if any).
  size_t pathLen = strcspn(pathPtr, "#");
  if( pathLen == 0 )
  {
    // No path in the URL: request the root resource.
    pathPtr = "/";
    pathLen = 1;
  }

  if( maxPathLen < pathLen + 1 ) //including NULL-terminating char
  {
    WARN("Path str is too small (%u < %u)", static_cast<unsigned>(maxPathLen), static_cast<unsigned>(pathLen + 1));
    return -1;
  }
  memcpy(path, pathPtr, pathLen);
  path[pathLen] = '\0';

  return 0;
}


bool Websocket::connect() {
    // Opens the TCP connection and performs the WebSocket opening handshake.
    // Returns true when the server answered with the expected
    // Sec-WebSocket-Accept value. Any bytes received after the response
    // headers are left in membuf for read().
    if (socket.connect(host, port) < 0) {
        ERR("Unable to connect to (%s) on port (%d)", host, port);
        wait(0.2);
        return false;
    }

    // Handshake header lines that need no formatting.
    static const char* const kStaticHeaders[] = {
        "Upgrade: WebSocket\r\n",
        "Connection: Upgrade\r\n",
        "Sec-WebSocket-Key: L159VM0TWUzyDxwJEIEzjw==\r\n",
        "Sec-WebSocket-Version: 13\r\n\r\n"
    };
    const size_t kStaticHeaderCount = sizeof(kStaticHeaders) / sizeof(kStaticHeaders[0]);

    // Send the HTTP upgrade request one line at a time, staging each line in
    // membuf's storage. Every write is checked, so a failure on any line (not
    // only the last one) aborts the handshake.
    int n = sprintf(membuf.ptr, "GET %s HTTP/1.1\r\n", path);
    bool sent = (write(membuf.ptr, n) == n);
    if (sent) {
        n = sprintf(membuf.ptr, "Host: %s:%d\r\n", host, port);
        sent = (write(membuf.ptr, n) == n);
    }
    for (size_t i = 0; sent && i < kStaticHeaderCount; i++) {
        n = sprintf(membuf.ptr, "%s", kStaticHeaders[i]);
        sent = (write(membuf.ptr, n) == n);
    }

    // The request text staged above is outgoing data, not something received.
    // Mark membuf empty so read() does not hand it back as part of the reply.
    membuf.size = 0;

    if (!sent) {
        close();
        ERR("Could not send request");
        return false;
    }

    // Keep at most 199 bytes so the NUL terminator below stays within the
    // first 200 bytes of membuf. Assumes membuf's storage holds at least 200
    // bytes, as the original 200-byte read implied (TODO confirm).
    const int kMaxReply = 199;
    n = read(membuf.ptr, kMaxReply, 100);
    if (n < 0) {
        close();
        ERR("Could not receive answer\r\n");
        return false;
    }
    membuf.size = n;
    membuf.ptr[n] = '\0';
    DBG("recv: %s\r\n", membuf.ptr);

    // Sec-WebSocket-Accept value the server must return for the fixed key above.
    char*  p = strstr (membuf.ptr, "DdLWT/1JcX+nQFHebYP+rqEx5xI=");
    if (p == NULL) {
        ERR("Wrong answer from server, got \"%s\" instead\r\n", membuf.ptr);
        // Drain whatever else the server sends, then drop the connection.
        // membuf is emptied first so read() pulls straight from the socket.
        membuf.size = 0;
        do {
            n = read(membuf.ptr, kMaxReply, 100);
        } while (n > 0);
        membuf.size = 0;
        close();
        return false;
    }

    INFO("\r\nhost: %s\r\npath: %s\r\nport: %d\r\n\r\n", host, path, port);

    // Anything after the blank line that ends the HTTP headers is already
    // WebSocket frame data, so keep it buffered for read().
    p = strstr (p, "\r\n\r\n");
    if (p) {
        membuf.skipTo (p + 4);
    } else {
        // NOTE(review): the headers did not fit in the buffer. The buffered
        // bytes are header text, not frame data, so discard them.
        membuf.size = 0;
    }
    return true;
}

int Websocket::sendLength(uint32_t len, char * msg) {
    // Encodes the payload length of a client frame into msg, with the MASK
    // bit (0x80) set, following RFC 6455 §5.2:
    //   len < 126    -> 1 byte
    //   len <= 65535 -> marker 126 + 16-bit big-endian length
    //   otherwise    -> marker 127 + 64-bit big-endian length
    // The minimal encoding is always used, as the RFC requires.
    // Returns the number of bytes written (1, 3 or 9).
    const uint32_t kMaskBit = 0x80;

    if (len < 126) {
        msg[0] = static_cast<char>(len | kMaskBit);
        return 1;
    }

    if (len <= 0xFFFF) {
        msg[0] = static_cast<char>(126 | kMaskBit);
        msg[1] = static_cast<char>((len >> 8) & 0xff);
        msg[2] = static_cast<char>(len & 0xff);
        return 3;
    }

    msg[0] = static_cast<char>(127 | kMaskBit);
    // 64-bit network-order length. A uint32_t only fills the low four bytes.
    msg[1] = msg[2] = msg[3] = msg[4] = 0;
    for (int i = 0; i < 4; i++) {
        msg[5 + i] = static_cast<char>((len >> (24 - 8 * i)) & 0xff);
    }
    return 9;
}

int Websocket::readChar(char * pC, bool block) {
    // Reads exactly one byte into *pC via read().
    // `block` is currently ignored.
    (void)block;
    const int oneByte = 1;
    return read(pC, oneByte, oneByte);
}

int Websocket::sendOpcode(uint8_t opcode, char * msg) {
    // First header byte: FIN bit set (the frame is not fragmented) combined
    // with the low nibble of the opcode. Returns the number of bytes written.
    const uint8_t finBit = 0x80;
    const uint8_t op = opcode & 0x0f;
    *msg = static_cast<char>(finBit | op);
    return 1;
}

int Websocket::sendMask(char * msg) {
    // Writes a 4-byte masking key of all zeros and returns its length.
    // XOR with zero leaves the payload unchanged.
    // NOTE(review): RFC 6455 §5.3 asks clients to choose an unpredictable
    // masking key per frame. A constant key is not compliant.
    char* const end = msg + 4;
    for (char* cursor = msg; cursor != end; ++cursor) {
        *cursor = 0;
    }
    return 4;
}

int Websocket::sendOp(uint8_t opcode, size_t len, char* str) {
    // Sends one unfragmented frame with the given opcode and `len` bytes of
    // payload taken from str. str may be NULL, in which case an empty frame is
    // sent. The buffer size is derived from `len`, never from strlen(str), so
    // binary payloads with embedded NULs are handled.
    // Returns the number of bytes written (header + payload), or -1 if
    // nothing could be written.
    if (str == NULL)
        len = 0;

    // Header: 1 byte opcode + up to 9 bytes length + 4 bytes mask.
    char header[14];
    int idx = sendOpcode(opcode, header);
    idx += sendLength(len, header + idx);
    idx += sendMask(header + idx);

    const int sent = write(header, idx);
    if (sent != idx || len == 0)
        return sent;

    // sendMask() emitted an all-zero key, so the payload goes out unchanged.
    const int body = write(str, static_cast<int>(len));
    return body < 0 ? sent : sent + body;
}

int Websocket::send(char * str) {
    // Sends the NUL-terminated string (without the terminator) as a text
    // frame, opcode 0x1.
    const size_t textLength = strlen(str);
    return sendOp(0x01, textLength, str);
}

int Websocket::sendBin(size_t len, char* str) {
    // Sends `len` raw bytes from str as a binary frame.
    const uint8_t binaryOpcode = 0x02;
    return sendOp(binaryOpcode, len, str);
}

bool Websocket::read(char * message, size_t& size) {
    uint8_t opcode = 0;
    Timer tmr;

    // read the opcode
    tmr.start();
    while (true) {
        if (tmr.read() > 3) {
            DBG("timeout ws\r\n");
            return false;
        }
        
        if(!socket.is_connected())
        {
            WARN("Connection was closed by server");
            return false;
        }

	if (membuf.size > 0) {
	    opcode = (uint8_t)*membuf.ptr;
	    membuf.skip (1);
	} else {
	    socket.set_blocking(false, 1);
	    if (socket.receive((char*)&opcode, 1) != 1) {
		socket.set_blocking(false, 2000);
		return false;
	    }
	}

        socket.set_blocking(false, 2000);
	
	DBG("opcode: 0x%X", opcode);

	switch (opcode) {
	case 0x81:
	case 0x82:
	    DBG ("readBinFrame ()");
	    return readBinFrame (message, size);
	case 0x89:
	    DBG ("readPing ()");
	    return readPing ();
	default:;
	}
    }

    return false;
}

bool Websocket::close() {
    // Closes the TCP socket. Returns true only if a connection was open and
    // socket.close() succeeded. Returns false when already disconnected or
    // when close reports an error.
    if (is_connected()) {
        if (socket.close() >= 0)
            return true;
        ERR("Could not disconnect");
    }
    return false;
}

bool Websocket::is_connected() {
    // Reports the underlying socket's connection state.
    const bool connected = socket.is_connected();
    return connected;
}

char* Websocket::getPath() {
    // Path component extracted from the URL given to the constructor.
    return &path[0];
}

int Websocket::write(char * str, int len) {
    // Sends `len` bytes from str, retrying up to MAX_TRY_WRITE times. Stops
    // early if the socket reports it is disconnected.
    // Returns len on full success, the number of bytes sent on a partial
    // send, or -1 if nothing was sent.
    int sent = 0;

    for (int attempt = 0; attempt < MAX_TRY_WRITE; ++attempt) {
        if (!socket.is_connected()) {
            WARN("Connection was closed by server");
            break;
        }

        const int res = socket.send_all(str + sent, len - sent);
        if (res != -1) {
            sent += res;
            if (sent == len)
                return len;
        }
    }

    if (sent == 0)
        return -1;
    return sent;
}

int Websocket::read(char * str, int len, int min_len) {
    // Reads up to `len` bytes into str: first any bytes still buffered in
    // membuf, then from the socket.
    //   min_len == -1 : keep reading until `len` bytes are stored or the
    //                   retries run out.
    //   min_len >= 0  : return as soon as at least min_len bytes are stored.
    // Returns the total number of bytes stored in str (buffered + received),
    // or -1 if nothing could be read.
    if (len <= 0)
        return 0;

    int copied = 0;
    if (membuf.size > 0) {
        copied = len < membuf.size ? len : membuf.size;
        // memmove, not memcpy: callers such as connect() may pass membuf.ptr
        // itself as str.
        memmove(str, membuf.ptr, copied);
        membuf.skip (copied);
        if (copied == len || (min_len != -1 && copied >= min_len))
            return copied;
    }

    const int wanted = len - copied;
    int received = 0;

    // NOTE(review): retries are bounded by MAX_TRY_WRITE although MAX_TRY_READ
    // exists; kept as-is to preserve the current timing.
    for (int j = 0; j < MAX_TRY_WRITE; j++) {
        const int res = socket.receive_all(str + copied + received, wanted - received);
        if (res < 0)
            continue;

        received += res;
        const int total = copied + received;
        if (received == wanted || (min_len != -1 && total >= min_len))
            return total;
    }

    const int total = copied + received;
    return (total == 0) ? -1 : total;
}

uint32_t  Websocket::readLenFrame (char* mask) {
    // Parses the second header byte of a frame, plus any extended payload
    // length and masking key that follow it (RFC 6455 §5.2).
    // - Extended lengths are big-endian. For a 64-bit length only the low
    //   32 bits are kept.
    // - When the frame is masked and mask is non-NULL, the 4-byte key is
    //   stored in mask[0..3]. Otherwise mask is not touched.
    // Returns the payload length, or 0 if a header byte could not be read.
    char c;
    if (read (&c, 1, 1) < 1) return 0;

    const bool is_masked = (c & 0x80) != 0;
    uint32_t len_msg = c & 0x7f;

    int extraBytes = 0;
    if (len_msg == 126) {
        extraBytes = 2;
    } else if (len_msg == 127) {
        extraBytes = 8;
    }

    if (extraBytes > 0) {
        len_msg = 0;
        for (int i = 0; i < extraBytes; i++) {
            if (read (&c, 1, 1) < 1) return 0;
            // Treat the byte as unsigned so values >= 0x80 do not sign-extend.
            // For 8-byte lengths the top four bytes shift out of the uint32_t.
            len_msg = (len_msg << 8) | static_cast<uint8_t>(c);
        }
    }

    if (is_masked) {
        DBG ("is_masked");
        for (int i = 0; i < 4; i++) {
            if (read (&c, 1, 1) < 1) return 0;
            if (mask)
                mask[i] = c;
        }
    }

    return len_msg;
}

bool  Websocket::readBinFrame (char* message, size_t& size) {
    int i;
    uint32_t len_msg;
    char mask[4] = {0, 0, 0, 0};

    len_msg = readLenFrame (mask);
    DBG("length: %d", len_msg);
    if (len_msg == 0) {
        return false;
    }
    if (size < len_msg)
        len_msg = size;
    
    int nb = read(message, len_msg, len_msg);
    if (nb != len_msg)
        return false;

    for (i = 0; i < len_msg; i++) {
        message[i] = message[i] ^ mask[i % 4];
    }
    size = len_msg;

    return true;
}

bool  Websocket::readPing () {
    uint32_t len_msg = readLenFrame (NULL);
    char  data[len_msg];
    int nb = read(data, len_msg, len_msg);
    DBG ("ping");
    sendOp (0x0A, 0, NULL);
    if (fnPing) {
	fnPing ();
    }
    return false;
}
