#ifndef SIMPLE_SOCKET_H
#define SIMPLE_SOCKET_H

#include "mbed.h"
#include "EthernetInterface.h"
#include <stdarg.h>

extern Serial pc;

/**
 * client socket class for communication endpoint
 */
class ClientSocket {
    friend class ServerSocket;

public:
    ClientSocket(char *hostname, int port) : preread(-1), timeout(1.5) {
        if (EthernetInterface::getIPAddress() == NULL) {
            EthernetInterface::init();
            EthernetInterface::connect();
        }
        sock = new TCPSocketConnection();
        sock->connect(hostname, port);
        sock->set_blocking(false, (int) (timeout * 1000));
    }

    ClientSocket(TCPSocketConnection *sock) : sock(sock), preread(-1), timeout(1.5) {
        sock->set_blocking(false, (int) (timeout * 1000));
    }

    ~ClientSocket() {
        // do NOT close in the destructor.
    }

    void setTimeout(float timeout) {
        sock->set_blocking(false, (int) (timeout * 1000));
        this->timeout = timeout;
    }

    bool available() {
        if (preread == -1) {
            char c;
            sock->set_blocking(false, 0);
            if (sock->receive((char *) &c, 1) <= 0) {
                sock->set_blocking(false, (int) (timeout * 1000));
                return false;
            }
            preread = c & 255;
            sock->set_blocking(false, (int) (timeout * 1000));
        }
        
        return true;
    }

    int read() {
        char c;
        if (preread != -1) {
            c = (char) preread;
            preread = -1;
        } else {
            if (sock->receive((char *) &c, 1) <= 0)
                return -1;
        }
        
        return c & 255;
    }

    int read(char *buf, int size) {
        if (size <= 0)
            return 0;

        int nread = 0;
        if (preread != -1) {
            nread = 1;
            *buf++ = (char) preread;
            preread = -1;
            size--;
        }

        if (size > 0) {
            size = sock->receive_all(buf, size - 1);
            if (size > 0)
                nread += size;
        }

        return (nread > 0) ? nread : -1;
    }

    int scanf(const char *format, ...) {
        va_list argp;
        va_start(argp, format);
        char buf[256];

        int len = read(buf, sizeof(buf) - 1);
        if (len <= 0)
            return 0;
        buf[len] = '\0';

        int ret = vsscanf(buf, format, argp);
        va_end(argp);
        
        return ret;
    }

    int write(char c) {
        return sock->send_all(&c, 1);
    }

    int write(char *buf, int size) {
        return sock->send_all(buf, size);
    }

    int printf(const char *format, ...) {
        va_list argp;
        va_start(argp, format);
        char buf[256];

        int len = vsnprintf(buf, sizeof(buf), format, argp);
        va_end(argp);
        
        return write(buf, len);
    }

    void close() {
        if (sock) {
            sock->close();
            delete sock;
            sock = 0;
        }
    }

    bool connected() {
        return sock && sock->is_connected();
    }

    operator bool() {
        return connected();
    }
    
    char* get_address() {
        return sock->get_address();
    }

private:
    TCPSocketConnection *sock;
    int preread;
    float timeout;
};

/**
 * TCP listening socket that hands out ClientSocket connections.
 *
 * The Ethernet interface is NOT brought up here: the caller must have
 * initialized and connected it before constructing a ServerSocket.
 */
class ServerSocket {
public:
    /** Bind to `port` and start listening. */
    ServerSocket(int port) {
        // Interface setup is left to the caller on purpose: running the
        // EthernetInterface init/connect sequence from here misbehaved when
        // threaded.
        ssock.bind(port);
        ssock.listen();
    }

    /**
     * Accept the next incoming connection.
     * The returned ClientSocket owns a heap-allocated connection that is only
     * released by ClientSocket::close(); call it even when connected() is
     * false (the result of the underlying accept() is not checked here).
     */
    ClientSocket accept() {
        TCPSocketConnection *conn = new TCPSocketConnection();
        ssock.accept(*conn);
        return ClientSocket(conn);
    }

private:
    TCPSocketServer ssock;
};

/**
 * UDP socket with a single internal staging buffer of `bufsize` bytes.
 *
 * Outgoing data is staged with write()/printf() and sent with send();
 * incoming data is received into the same buffer with receive() and then
 * consumed with read()/scanf(). `length` always tracks the number of valid
 * bytes in the buffer and stays within [0, bufsize].
 */
class DatagramSocket {
public:
    /**
     * Bind a UDP socket to `port`, bringing up the Ethernet interface first
     * if it has no IP address yet.
     * @param port    local port passed to bind()
     * @param bufsize capacity of the staging buffer in bytes
     */
    DatagramSocket(int port = 0, int bufsize = 512) : buf(0), bufsize(bufsize), length(0) {
        setup(port);
    }

    /** Same as above, binding to the port of `local`. */
    DatagramSocket(Endpoint& local, int bufsize = 512) : buf(0), bufsize(bufsize), length(0) {
        setup(local.get_port());
    }

    ~DatagramSocket() {
        delete[] buf;
    }

    /** Set the receive/send timeout in seconds. */
    void setTimeout(float timeout = 1.5) {
        usock.set_blocking(false, (int) (timeout * 1000));
    }

    /**
     * Stage `length` bytes from `buf` for the next send(), clamped to
     * [0, bufsize]. Returns the number of bytes staged.
     */
    int write(char *buf, int length) {
        if (length < 0) length = 0;
        if (length > bufsize) length = bufsize;
        this->length = length;
        memcpy(this->buf, buf, length);

        return length;
    }

    /**
     * Format into the staging buffer for the next send(). Output longer than
     * bufsize is truncated. Returns vsnprintf's result (the untruncated
     * length, or negative on error, in which case the buffer is emptied).
     */
    int printf(const char* format, ...) {
        va_list argp;
        va_start(argp, format);
        // buf holds bufsize + 1 bytes, so bufsize characters plus the terminator fit.
        int len = vsnprintf(buf, bufsize + 1, format, argp);
        va_end(argp);

        if (len < 0)
            length = 0;
        else
            length = (len > bufsize) ? bufsize : len;  // never exceed what was written

        return len;
    }

    /** Send the staged bytes to host:port. */
    void send(const char *host, int port) {
        Endpoint remote;
        remote.reset_address();
        remote.set_address(host, port);
        usock.sendTo(remote, buf, length);
    }

    /** Send the staged bytes to `remote`. */
    void send(Endpoint& remote) {
        usock.sendTo(remote, buf, length);
    }

    /**
     * Receive one datagram into the staging buffer.
     * @param host if non-null, receives the sender's address via strcpy.
     *             NOTE(review): no bound is enforced; the caller's buffer must
     *             be large enough for the address string.
     * @param port if non-null, receives the sender's port
     * @return receiveFrom's result (bytes received, or <= 0 on timeout/error)
     */
    int receive(char *host = 0, int *port = 0) {
        Endpoint remote;
        int n = usock.receiveFrom(remote, buf, bufsize);
        length = (n > 0) ? n : 0;  // a failed receive leaves an empty buffer, not -1
        if (n > 0) {
            if (host) strcpy(host, remote.get_address());
            if (port) *port = remote.get_port();
        }

        return n;
    }

    /** Receive one datagram into the staging buffer, storing the sender in `remote`. */
    int receive(Endpoint& remote) {
        int n = usock.receiveFrom(remote, buf, bufsize);
        length = (n > 0) ? n : 0;  // keep read()/scanf() in sync with this datagram
        return n;
    }

    /** Copy up to `size` valid bytes from the staging buffer into `buf`. */
    int read(char *buf, int size) {
        int len = length < size ? length : size;
        if (len > 0)
            memcpy(buf, this->buf, len);

        return len;
    }

    /** Parse the staged bytes with vsscanf. */
    int scanf(const char* format, ...) {
        va_list argp;
        va_start(argp, format);
        buf[length] = '\0';  // safe: 0 <= length <= bufsize and buf has bufsize + 1 bytes
        int ret = vsscanf(buf, format, argp);
        va_end(argp);

        return ret;
    }

private:
    // Shared constructor body: bring up Ethernet if needed, allocate the
    // staging buffer (+1 for a terminator) and bind the UDP socket.
    void setup(int port) {
        if (EthernetInterface::getIPAddress() == NULL) {
            EthernetInterface::init();
            EthernetInterface::connect();
        }
        buf = new char[bufsize + 1];
        usock.init();
        usock.bind(port);
        usock.set_blocking(false, 1500);
    }

    // Not copyable: `buf` is owned and freed in the destructor, so an
    // implicit copy would delete it twice. Declared but never defined.
    DatagramSocket(const DatagramSocket&);
    DatagramSocket& operator=(const DatagramSocket&);

    char *buf;
    int bufsize;
    int length;
    UDPSocket usock;
};

#endif