/*
 * Websocket.cpp
 *
 * Committer: donatien
 * Date: 2012-05-31
 * Revision: 0:87e52bb764c5
 *
 * File content as of revision 0:87e52bb764c5:
 */

#define __DEBUG__ 4 //Maximum verbosity
#ifndef __MODULE__
#define __MODULE__ "Websocket.cpp"
#endif

#include "core/fwk.h"

#include "Websocket.h"
#include <string>


Websocket::Websocket(char * url) : m_sockHandle(-1) {
    // Start in the disconnected state: connected() and close() both read
    // m_connected, which was otherwise left uninitialised until connect().
    m_connected = false;

    // A zero sin_addr.s_addr makes connect() resolve ip_domain via DNS; it is
    // only overwritten when the URL carries a numeric IPv4 address.
    std::memset(&m_sockAddr, 0, sizeof(struct sockaddr_in));

    fillFields(url);
}


// Parses a dotted-quad IPv4 literal ("a.b.c.d", each part 1-3 decimal digits
// with value 0-255) into a host-byte-order value stored in *hostOrder.
// Returns false (leaving *hostOrder untouched) for anything else.
static bool parseIpv4Literal(const std::string& text, uint32_t* hostOrder) {
    uint32_t value = 0;
    int parts = 0;
    std::string::size_type start = 0;
    for (;;) {
        const std::string::size_type dot = text.find('.', start);
        const std::string part = text.substr(
            start, (dot == std::string::npos) ? std::string::npos : dot - start);
        if (part.empty() || part.size() > 3 ||
            part.find_first_not_of("0123456789") != std::string::npos) {
            return false;
        }
        const int octet = atoi(part.c_str());
        if (octet > 255 || ++parts > 4) {
            return false;
        }
        // Unsigned arithmetic: no signed-overflow UB for octets >= 128.
        value = (value << 8) | static_cast<uint32_t>(octet);
        if (dot == std::string::npos) {
            break;
        }
        start = dot + 1;
    }
    if (parts != 4) {
        return false;
    }
    *hostOrder = value;
    return true;
}

/*
 * Splits a "ws://ip-or-domain[:port][/path]" URL into the members
 * ip_domain, port (default "80") and path (stored without its leading '/').
 * When the host is a numeric IPv4 address it is also written to
 * m_sockAddr.sin_addr in network byte order; otherwise sin_addr is left
 * untouched so connect() can resolve ip_domain via DNS.
 *
 * Works on a std::string copy, so URLs of any length are accepted (the old
 * 50-byte stack buffer overflowed) and strtok's global state is not touched.
 */
void Websocket::fillFields(char * url) {
    const std::string spec = (url != NULL) ? std::string(url) : std::string();

    // Scheme: the text before the first ':' must be exactly "ws".
    const std::string::size_type schemeEnd = spec.find(':');
    if (schemeEnd == std::string::npos || spec.compare(0, schemeEnd, "ws") != 0) {
#ifdef DEBUG
        printf("\r\nFormat error: please use: \"ws://ip-or-domain[:port]/path\"\r\n\r\n");
#endif
        return;
    }

    // Authority ("host[:port]"): skip the "//" and stop at the next '/'.
    std::string authority;
    const std::string::size_type authStart = spec.find_first_not_of('/', schemeEnd + 1);
    if (authStart != std::string::npos) {
        const std::string::size_type authEnd = spec.find('/', authStart);
        if (authEnd == std::string::npos) {
            authority = spec.substr(authStart);
        } else {
            authority = spec.substr(authStart, authEnd - authStart);

            // Path: whatever follows that '/', up to the first space.
            const std::string::size_type pathStart = spec.find_first_not_of(' ', authEnd + 1);
            if (pathStart != std::string::npos) {
                const std::string::size_type pathEnd = spec.find(' ', pathStart);
                path = spec.substr(pathStart,
                    (pathEnd == std::string::npos) ? std::string::npos : pathEnd - pathStart);
            }
        }
    }

    // Port falls back to 80 when the authority has no ":port" suffix.
    port = "80";

    const std::string::size_type hostStart = authority.find_first_not_of(':');
    if (hostStart == std::string::npos) {
        return; // no host given: ip_domain is left unchanged
    }
    const std::string::size_type hostEnd = authority.find(':', hostStart);
    const std::string host = authority.substr(
        hostStart, (hostEnd == std::string::npos) ? std::string::npos : hostEnd - hostStart);

    if (hostEnd != std::string::npos) {
        const std::string::size_type portStart = authority.find_first_not_of(' ', hostEnd + 1);
        if (portStart != std::string::npos) {
            const std::string::size_type portEnd = authority.find(' ', portStart);
            port = authority.substr(portStart,
                (portEnd == std::string::npos) ? std::string::npos : portEnd - portStart);
        }
    }

    ip_domain = host;

    // Numeric host: store it directly (s_addr must be in network byte order).
    // Anything that is not a well-formed dotted quad is left for DNS.
    uint32_t hostOrder = 0;
    if (parseIpv4Literal(host, &hostOrder)) {
        m_sockAddr.sin_addr.s_addr = htonl(hostOrder);
    }
}


/*
 * Opens the TCP connection to ip_domain:port and performs the WebSocket
 * opening handshake (fixed Sec-WebSocket-Key, protocol version 13).
 *
 * Returns true when the server's reply contains the expected
 * Sec-WebSocket-Accept value; false on DNS, socket, connect, send or receive
 * failure, or on an unexpected reply (which is drained and printed first).
 */
bool Websocket::connect() {
    // Scratch buffer for chunks of the handshake response. One byte is kept
    // free so every chunk can be NUL-terminated after read().
    char cmd[192];
    const int cmdCapacity = static_cast<int>(sizeof(cmd)) - 1;
    // Upper bound on how much of the HTTP response header is buffered while
    // waiting for the blank line that terminates it.
    const std::string::size_type maxResponseLen = 1024;

    //Resolve DNS if needed (a zero address means no numeric IP was set from the URL)
    if(m_sockAddr.sin_addr.s_addr == 0)
    {
        DBG("Resolving DNS socket");
        struct hostent *server = socket::gethostbyname(ip_domain.c_str());
        if(server == NULL || server->h_addr_list[0] == NULL ||
           server->h_length < static_cast<int>(sizeof(m_sockAddr.sin_addr.s_addr)))
        {
            return false;
        }
        // Copy exactly one IPv4 address so a larger h_length cannot overrun sin_addr.
        memcpy(&m_sockAddr.sin_addr.s_addr, server->h_addr_list[0], sizeof(m_sockAddr.sin_addr.s_addr));
    }

    m_sockAddr.sin_family = AF_INET;
    m_sockAddr.sin_port = htons(atoi(port.c_str()));

    //Create socket
    DBG("Creating socket");
    m_sockHandle = socket::socket(AF_INET, SOCK_STREAM, 0);
    if (m_sockHandle < 0)
    {
        ERR("Could not create socket");
        return false;
    }
    DBG("Handle is %d",m_sockHandle);

    //Connect
    DBG("Connecting socket to %s:%d", inet_ntoa(m_sockAddr.sin_addr), ntohs(m_sockAddr.sin_port));
    int ret = socket::connect(m_sockHandle, (const struct sockaddr *)&m_sockAddr, sizeof(m_sockAddr));
    if (ret < 0)
    {
        socket::close(m_sockHandle);
        m_sockHandle = -1; // do not keep a descriptor number we no longer own
        ERR("Could not connect");
        return false;
    }

    m_connected = true;

    DBG("Sending HTTP request");
    // Build the whole upgrade request as one string: no fixed-size buffer to
    // overflow with a long path or host name, and a single write() call.
    std::string request;
    request += "GET /";
    request += path;
    request += " HTTP/1.1\r\n";
    request += "Host: ";
    request += ip_domain;
    request += ":";
    request += port;
    request += "\r\n";
    request += "Upgrade: WebSocket\r\n";
    request += "Connection: Upgrade\r\n";
//    request += "Origin: null\r\n";
    request += "Sec-WebSocket-Key: L159VM0TWUzyDxwJEIEzjw==\r\n";
    request += "Sec-WebSocket-Version: 13\r\n\r\n";

    ret = write(reinterpret_cast<uint8_t*>(&request[0]), static_cast<int>(request.size()));
    if(ret < 0)
    {
        close();
        ERR("Could not send request");
        m_connected = false;
        return false;
    }

    DBG("Waiting for answer");
    // The response header may arrive in several TCP segments: keep reading
    // until the terminating blank line is seen (bounded by maxResponseLen).
    std::string response;
    while (response.find("\r\n\r\n") == std::string::npos && response.size() < maxResponseLen)
    {
        ret = read(reinterpret_cast<uint8_t*>(cmd), 0, cmdCapacity);
        if(ret < 0)
        {
            close();
            ERR("Could not receive answer");
            m_connected = false;
            return false;
        }
        if(ret == 0)
        {
            break; // peer closed before finishing the header
        }
        response.append(cmd, ret);
    }

    DBG("Comparing answer");
    if( response.find("Sec-WebSocket-Accept: DdLWT/1JcX+nQFHebYP+rqEx5xI=") == std::string::npos )
    {
        ERR("Wrong answer from server, got \"%s\" instead", response.c_str());
        // Dump whatever else the server sends until it closes or times out.
        do{
            ret = read(reinterpret_cast<uint8_t*>(cmd), 0, cmdCapacity);
            if(ret < 0)
            {
                ERR("Could not receive answer");
                return false;
            }
            cmd[ret] = '\0'; // ret <= cmdCapacity, so this stays inside cmd
            printf("%s",cmd);
        } while(ret > 0);
        close();
        m_connected = false;
        return false;
    }
    DBG("\r\nip_domain: %s\r\npath: /%s\r\nport: %s\r\n\r\n",this->ip_domain.c_str(), this->path.c_str(), this->port.c_str());
    return true;
}

/*
 * Emits the second frame-header byte plus any extended payload length.
 * The MASK bit (0x80) is always set, because this client always sends a
 * masking key (see sendMask()). Per RFC 6455 §5.2, extended lengths are sent
 * in network byte order (most significant byte first). The minimal encoding
 * is used: 7-bit for len < 126, 16-bit for len <= 65535, else 64-bit.
 */
void Websocket::sendLength(uint32_t len) {
    const uint8_t maskBit = 0x80;
    if (len < 126) {
        sendChar(static_cast<uint8_t>(len | maskBit));
    } else if (len <= 0xFFFF) {
        sendChar(static_cast<uint8_t>(126 | maskBit));
        sendChar(static_cast<uint8_t>((len >> 8) & 0xff));
        sendChar(static_cast<uint8_t>(len & 0xff));
    } else {
        sendChar(static_cast<uint8_t>(127 | maskBit));
        // A uint32_t only fills the low four bytes of the 64-bit length, so
        // the high four are zero (shifting a uint32_t by >= 32 would be UB).
        for (int i = 0; i < 4; i++) {
            sendChar(0);
        }
        for (int shift = 24; shift >= 0; shift -= 8) {
            sendChar(static_cast<uint8_t>((len >> shift) & 0xff));
        }
    }
}

/*
 * Writes one byte to the socket. The return value of write() is ignored,
 * so a failed send is not reported to the caller.
 */
void Websocket::sendChar(uint8_t c) {
    uint8_t* const byte = &c;
    write(byte, sizeof(uint8_t));
}

/*
 * Reads exactly one byte into *pC.
 *
 * block == true waits up to 36000000 ms, otherwise up to 3000 ms.
 * Returns true only when a byte was actually stored. It returns false on
 * error, on timeout (read() closes the connection in that case) or when the
 * peer has closed the connection (read() returns 0 and *pC is untouched).
 *
 * read()'s parameters are (buf, minLen, maxLen, timeout). The previous call
 * passed the timeout as maxLen, which let recv() write up to 3000 or
 * 36000000 bytes into the single byte behind pC.
 */
bool Websocket::readChar(uint8_t* pC, bool block)
{
    const uint32_t timeoutMs = block ? 36000000 : 3000;
    return read(pC, 1, 1, timeoutMs) == 1;
}

/*
 * Emits the first frame-header byte: bit 7 (0x80, FIN) set, and the low
 * nibble of `opcode` as the frame opcode. Bits 4-6 are always cleared.
 */
void Websocket::sendOpcode(uint8_t opcode) {
    const uint8_t finBit = 0x80;
    const uint8_t lowNibble = opcode & 0x0f;
    sendChar(finBit | lowNibble);
}

/*
 * Emits the 4-byte masking key, which is always all zeros. XOR with a zero
 * key leaves the payload unchanged, so send() writes payload bytes as-is.
 * NOTE(review): RFC 6455 §5.3 asks for an unpredictable masking key; a
 * random key would also require send() to XOR the payload.
 */
void Websocket::sendMask() {
    const uint8_t zeroKeyByte = 0;
    int remaining = 4;
    while (remaining-- > 0) {
        sendChar(zeroKeyByte);
    }
}

/*
 * Sends the NUL-terminated `str` as one final text frame (opcode 0x1).
 * The header carries the masked payload length and an all-zero masking key,
 * so the payload bytes are written unmodified. Write errors are not reported.
 */
void Websocket::send(char * str) {
    sendOpcode(0x01);
    const size_t payloadLen = strlen(str);
    sendLength(payloadLen);
    sendMask();

    uint8_t* const payload = reinterpret_cast<uint8_t*>(str);
    write(payload, payloadLen);
}



/*
 * Receives one text message into `message` and NUL-terminates it.
 *
 * Incoming bytes are discarded until a 0x81 byte (FIN + text opcode) is seen.
 * The scan gives up after about 3 s, or when a non-blocking readChar() fails.
 * Then the length (7-bit, or 16/64-bit big-endian per RFC 6455 §5.2) is
 * parsed, the masking key is read if the MASK bit is set, and the payload is
 * unmasked.
 *
 * Returns false on timeout, on a read failure, or for a zero-length payload.
 * NOTE(review): no capacity for `message` is passed in. The caller must supply
 * at least payload length + 1 bytes, and this cannot be checked here.
 */
bool Websocket::read(char * message) {
    uint8_t c = 0;
    Timer tmr;

    // Scan for the first byte of a final text frame.
    tmr.start();
    for (;;) {
        if (tmr.read() > 3) {
            return false;
        }
        if (!readChar(&c, false)) {
            return false;
        }
        // Compare as uint8_t: a plain (possibly signed) char never equals 0x81.
        if (c == 0x81) {
            break;
        }
    }
#ifdef DEBUG
    printf("opcode: 0x%X\r\n", static_cast<unsigned int>(c));
#endif

    // Second header byte: MASK flag in bit 7, 7-bit payload length below it.
    if (!readChar(&c)) {
        return false;
    }
    const bool masked = (c & 0x80) != 0;
    uint32_t len_msg = c & 0x7f;

    // 126 / 127 announce a 16-bit / 64-bit extended length, most significant
    // byte first. Only the low 32 bits of a 64-bit length are kept (unsigned
    // shifts simply discard the high bytes).
    int extendedBytes = 0;
    if (len_msg == 126) {
        extendedBytes = 2;
    } else if (len_msg == 127) {
        extendedBytes = 8;
    }
    if (extendedBytes > 0) {
        len_msg = 0;
        for (int i = 0; i < extendedBytes; i++) {
            if (!readChar(&c)) {
                return false;
            }
            len_msg = (len_msg << 8) | c;
        }
    }
    if (len_msg == 0) {
        return false;
    }
#ifdef DEBUG
    printf("length: %u\r\n", static_cast<unsigned int>(len_msg));
#endif

    // Masking key: present only when the MASK bit of the second byte was set.
    uint8_t mask[4] = {0, 0, 0, 0};
    if (masked) {
        for (int i = 0; i < 4; i++) {
            if (!readChar(&mask[i])) {
                return false;
            }
        }
    }

    for (uint32_t i = 0; i < len_msg; i++) {
        if (!readChar(&c)) {
            return false;
        }
        message[i] = static_cast<char>(c ^ mask[i % 4]);
    }

    message[len_msg] = '\0';
    return true;
}

/*
 * Closes the socket if it is open and marked connected.
 * The connected flag is cleared before the close attempt.
 * Returns false when there was nothing to close or socket::close() failed.
 */
bool Websocket::close() {
    const bool isOpen = (m_sockHandle >= 0) && m_connected;
    if (!isOpen)
    {
      return false;
    }

    m_connected = false;
    const int closeResult = socket::close(m_sockHandle);
    if (closeResult >= 0)
    {
      return true;
    }
    ERR("Could not disconnect");
    return false;
}



/*
 * Reports the connection flag, which connect() sets and close() clears.
 */
bool Websocket::connected() {
    const bool isConnected = m_connected;
    return isConnected;
}

/*
 * Returns a copy of the resource path parsed from the URL (no leading '/').
 */
std::string Websocket::getPath() {
    const std::string& current = path;
    return current;
}

/*
 * Waits up to `timeout` milliseconds for the socket to become readable.
 * Returns 0 when it is readable, -1 on timeout or select() error.
 */
int Websocket::waitReadable(uint32_t timeout)
{
  fd_set readSet;
  FD_ZERO(&readSet);
  FD_SET(m_sockHandle, &readSet);

  // Split the millisecond timeout into seconds + microseconds.
  struct timeval limit;
  limit.tv_sec = timeout / 1000;
  limit.tv_usec = (timeout % 1000) * 1000;

  const int ready = socket::select(FD_SETSIZE, &readSet, NULL, NULL, &limit);
  const bool hasData = (ready > 0) && FD_ISSET(m_sockHandle, &readSet);
  return hasData ? 0 : -1;
}

/*
 * Waits up to `timeout` milliseconds for the socket to accept more data.
 * Returns 0 when it is writeable, -1 on timeout or select() error.
 */
int Websocket::waitWriteable(uint32_t timeout)
{
  struct timeval limit;
  limit.tv_sec = timeout / 1000;
  limit.tv_usec = (timeout % 1000) * 1000;

  fd_set writeSet;
  FD_ZERO(&writeSet);
  FD_SET(m_sockHandle, &writeSet);

  const int ready = socket::select(FD_SETSIZE, NULL, &writeSet, NULL, &limit);
  if (ready > 0 && FD_ISSET(m_sockHandle, &writeSet))
  {
    return 0;
  }
  return -1; //Timeout
}

/*
 * Receives into buf until at least minLen bytes (at most maxLen) have arrived.
 * Always performs at least one wait/recv round, even when minLen is 0.
 *
 * Returns the byte count. It may be below minLen if the peer closed the
 * connection (recv() == 0). Returns -1 when not connected, on a readability
 * timeout, or on a recv() error; the last two also close the connection.
 */
int Websocket::read(uint8_t* buf, int minLen, int maxLen, uint32_t timeout)
{
  if(!m_connected)
  {
    return -1;
  }

  int received = 0;
  for(;;)
  {
    const int waitStatus = waitReadable(timeout);
    if(waitStatus == -1)
    {
      WARN("Wait readable returned %d",waitStatus);
      close();
      return -1;
    }

    const int chunk = socket::recv(m_sockHandle, buf + received, maxLen - received, 0/*MSG_DONTWAIT*/);
    if(chunk == 0) //Connection closed
    {
      WARN("Recv returned %d",chunk);
      return received;
    }
    if(chunk < 0)
    {
      WARN("Recv returned %d",chunk);
      close();
      return -1;
    }

    received += chunk;
    if(received >= minLen)
    {
      return received;
    }
  }
}

/*
 * Sends len bytes from buf, looping over partial sends.
 * Always performs at least one wait/send round, even when len is 0.
 *
 * Returns the number of bytes sent. It may be below len if send() returns 0.
 * Returns -1 when not connected, on a writeability timeout, or on a send()
 * error; the last two also close the connection.
 */
int Websocket::write(uint8_t* buf, int len, uint32_t timeout)
{
  if(!m_connected)
  {
    return -1;
  }

  int sent = 0;
  for(;;)
  {
    const int waitStatus = waitWriteable(timeout);
    if(waitStatus == -1)
    {
      WARN("Wait writeable returned %d",waitStatus);
      close();
      return -1;
    }

    //FIXME Probably DO WAIT to avoid overflow
    const int chunk = socket::send(m_sockHandle, buf + sent, len - sent, 0/*MSG_DONTWAIT*/);
    if(chunk < 0)
    {
      WARN("Send returned %d",chunk);
      close(); //Must reset
      return -1;
    }
    if(chunk == 0) //Connection closed
    {
      WARN("Send returned %d",chunk);
      return sent;
    }

    sent += chunk;
    if(sent >= len)
    {
      return sent;
    }
  }
}

/*
 * Placeholder hook: it does nothing.
 * NOTE(review): nothing in this file calls it; confirm whether any caller
 * still relies on it before adding behaviour or removing it.
 */
void Websocket::timeoutHandler()
{
  return;
}