/*
 * $Id: Sockets.c 29 2011-06-11 14:53:08Z benoit $
 * $Author: benoit $
 * $Date: 2011-06-11 16:53:08 +0200 (sam., 11 juin 2011) $
 * $Rev: 29 $
 * 
 * 
 * 
 * 
 * 
 */
 
#include "NetIF.h"
#include "Sockets.h"
#include "IPv4.h"
#include "UDPv4.h"
#include "Debug.h"
#include "CQueue.h"
#include <stdint.h>
#include <stdlib.h>
#include <string.h>


/* Module identification, presumably consumed by the DEBUG_MODULE() logging
 * macros from Debug.h (used throughout this file) — confirm in Debug.h. */
#define    DEBUG_CURRENT_MODULE_NAME    "sockets"
#define    DEBUG_CURRENT_MODULE_ID        DEBUG_MODULE_SOCKETS


/* Offset of the UDP payload from the start of the 8-byte UDP header.
 * NOTE(review): not referenced anywhere in this file, which uses
 * sizeof(UDPv4_Header_t) instead — candidate for removal. */
#define    UDPV4_DATA_OFFSET            8


/*
 * One received datagram payload waiting in a socket's data queue.
 * dataPtr owns the heap copy of the payload, readPtr is the next unread byte
 * inside that copy, and remainingSize counts the bytes still unread out of
 * totalSize (a datagram may be consumed over several receive calls).
 */
typedef struct DataBlock DataBlock_t;

struct DataBlock
{
    uint8_t    *dataPtr;
    uint8_t    *readPtr;
    int16_t    totalSize;
    int16_t    remainingSize;
};


/*
 * Life cycle of a socket table slot:
 *   State_Close --Sockets_Open--> State_Open --Sockets_Bind--> State_Bound
 * Sockets_Close returns a slot to State_Close. State_Close must stay 0 so that
 * the zero-filled table produced by Init() starts with every slot free.
 */
typedef enum State
{
    State_Close = 0,
    State_Open  = 1,
    State_Bound = 2
} State_t;


/* Bookkeeping for one slot of the socket table; the slot index is the handle. */
typedef struct Socket_Entry Socket_Entry_t;

struct Socket_Entry
{
    Socket_Family_t      family;       /* address family given to Sockets_Open (only AF_INET is accepted) */
    Socket_Protocol_t    protocol;     /* SOCK_DGRAM / SOCK_STREAM / SOCK_RAW as given to Sockets_Open */
    int32_t              options;      /* stored as given; not interpreted by this file */
    State_t              state;        /* life-cycle state of the slot */
    Socket_Addr_t        *localAddr;   /* heap record allocated when binding; NULL otherwise */
    Socket_Addr_t        *remoteAddr;  /* heap record allocated when binding; Hook_UDPv4 stores datagram senders here */
    CQueue_t             *dataQueue;   /* queue of DataBlock_t pointers, allocated when binding */
    int32_t              index;        /* position of this slot in the socket table */
};


/* Fixed-size socket table; a Socket_t handle is an index into it (see GetSocketEntry). */
static Socket_Entry_t    socketEntryTable[SOCKET_MAX_COUNT];
/* Set by the first Init() call so later calls do not wipe the table again. */
static Bool_t            socketAPIInitialized = False;


/* Forward declarations of the file-local helpers defined further down. */
static void				Init(void);
static int32_t			Hook(NetIF_t *netIF, Protocol_ID_t protocolID, NetPacket_t *packet);
static void				Hook_UDPv4(NetIF_t *netIF, NetPacket_t *packet, Socket_Entry_t *entry);
static Socket_Entry_t	*GetSocketEntry(Socket_t socket);
static int32_t			BindUDPv4(Socket_Entry_t    *entry, Socket_AddrIn_t *addrIn);
static int32_t			Recv_Data(Socket_Entry_t *entry, uint8_t *data, int32_t length);
static int32_t			SendToUDPv4(Socket_Entry_t *entry, uint8_t *data, int32_t length, Socket_AddrIn_t *remoteAddr);


/* Registration record for this module: API identifier, initialisation entry
 * point and packet hook. NOTE(review): presumably registered with the stack
 * core elsewhere, which then calls Init() and Hook() — confirm. */
Net_API_t    sockets = 
{
    API_ID_Sockets,
    Init,
    Hook
};


/*
 * One-time initialisation: marks every socket slot as closed by zero-filling
 * the table. Calls after the first one do nothing.
 */
static void Init(void)
{
    if (!socketAPIInitialized)
    {
        DEBUG_MODULE(DEBUG_LEVEL_INFO, ("Initializing"));
        memset(socketEntryTable, 0, sizeof(socketEntryTable));
        socketAPIInitialized = True;
    }
}


/*
 * Packet hook: offers a packet to every bound socket of the matching kind.
 * UDPv4 packets are handed to each bound SOCK_DGRAM socket via Hook_UDPv4();
 * other protocol IDs are ignored. Always returns 0.
 */
static int32_t Hook(NetIF_t *netIF, Protocol_ID_t protocolID, NetPacket_t *packet)
{
    Socket_Entry_t    *current;
    int32_t            slot;

    DEBUG_MODULE(DEBUG_LEVEL_VERBOSE0, ("Hook(%s%d, %s, %d bytes)",
        netIF->name,
        netIF->index,
        protocol_IDNames[protocolID],
        packet->length
    ));

    for (slot = 0; slot < SOCKET_MAX_COUNT; slot++)
    {
        current = &socketEntryTable[slot];
        if ((current->state == State_Bound) &&
            (protocolID == Protocol_ID_UDPv4) &&
            (current->protocol == SOCK_DGRAM))
        {
            Hook_UDPv4(netIF, packet, current);
        }
    }

    return 0;
}

/*
 * Delivers one received UDPv4 datagram to a bound SOCK_DGRAM socket.
 *
 * The datagram is accepted when its destination port equals the socket's local
 * port and the socket is bound either to IPADDR_ANY or to the datagram's
 * destination address. An accepted payload is copied into a heap DataBlock_t
 * and pushed onto entry->dataQueue; the socket's remoteAddr is then set to the
 * datagram's source address and port. Datagrams are dropped silently when the
 * queue is full or when their UDP length field is inconsistent with the data
 * received. Allocation failures set mbedNet_LastError to
 * mbedNetResult_NotEnoughMemory.
 *
 * NOTE: remoteAddr holds the sender of the most recently queued datagram, not
 * necessarily the sender of the datagram at the head of the queue.
 *
 * Assumes packet->headerPtrTable[packet->depth] is the IPv4 header and that
 * packet->data / packet->length describe the bytes starting at the UDP header
 * (the payload copy below relies on packet->data) — TODO confirm against the
 * IPv4/UDPv4 layers.
 */
static void Hook_UDPv4(NetIF_t *netIF, NetPacket_t *packet, Socket_Entry_t *entry)
{
    IPv4_Header_t      *ipv4Header;
    UDPv4_Header_t     *udpv4Header;
    Socket_AddrIn_t    *localAddrIn, *remoteAddrIn;
    int32_t            depth;
    int32_t            udpLength;
    int32_t            payloadLength;
    DataBlock_t        *dataBlock;

    depth = packet->depth;
    ipv4Header = (IPv4_Header_t *)packet->headerPtrTable[depth];
    udpv4Header = (UDPv4_Header_t *)(ipv4Header + 1);
    localAddrIn = (Socket_AddrIn_t *)entry->localAddr;
    remoteAddrIn = (Socket_AddrIn_t *)entry->remoteAddr;
    DEBUG_MODULE(DEBUG_LEVEL_VERBOSE0, ("ports: %d.%d.%d.%d:%d to %d.%d.%d.%d:%d size:%d",
        ipv4Header->source.IP0,
        ipv4Header->source.IP1,
        ipv4Header->source.IP2,
        ipv4Header->source.IP3,
        ntohs(udpv4Header->sourcePort),
        localAddrIn->address.IP0,
        localAddrIn->address.IP1,
        localAddrIn->address.IP2,
        localAddrIn->address.IP3,
        ntohs(localAddrIn->port),
        ntohs(udpv4Header->length)
    ));

    /* Not addressed to this socket: wrong port, or bound to another local address. */
    if (localAddrIn->port != udpv4Header->destPort) goto Exit;
    if ((localAddrIn->address.addr != IPADDR_ANY) && (ipv4Header->dest.addr != localAddrIn->address.addr)) goto Exit;

    /* No room left in the receive queue: drop the datagram. */
    if (CQueue_IsFull(entry->dataQueue)) goto Exit;

    /* The UDP length field comes from the network: it must cover at least the
     * UDP header, must not claim more bytes than were received, and the payload
     * must fit DataBlock_t's int16_t size fields. */
    udpLength = ntohs(udpv4Header->length);
    payloadLength = udpLength - (int32_t)sizeof(UDPv4_Header_t);
    if ((payloadLength < 0) || (udpLength > packet->length) || (payloadLength > INT16_MAX))
    {
        DEBUG_MODULE(DEBUG_LEVEL_WARNING, ("Dropping UDP datagram with bad length %d", udpLength));
        goto Exit;
    }

    dataBlock = (DataBlock_t *)malloc(sizeof(DataBlock_t));
    if (dataBlock == NULL)
    {
        mbedNet_LastError = mbedNetResult_NotEnoughMemory;
        goto Exit;
    }
    dataBlock->totalSize = (int16_t)payloadLength;
    dataBlock->remainingSize = dataBlock->totalSize;
    /* Allocate at least one byte so an empty datagram is not mistaken for an
     * allocation failure (malloc(0) may legitimately return NULL). */
    dataBlock->dataPtr = (uint8_t *)malloc((payloadLength > 0) ? (size_t)payloadLength : 1u);
    if (dataBlock->dataPtr == NULL)
    {
        free(dataBlock);
        mbedNet_LastError = mbedNetResult_NotEnoughMemory;
        goto Exit;
    }
    dataBlock->readPtr = dataBlock->dataPtr;
    if (payloadLength > 0)
    {
        memcpy(dataBlock->dataPtr, packet->data + sizeof(UDPv4_Header_t), (size_t)payloadLength);
    }
    CQueue_Push(entry->dataQueue, (void *)dataBlock);

    /* Record the sender only once its datagram has actually been queued. */
    remoteAddrIn->address = ipv4Header->source;
    remoteAddrIn->port = udpv4Header->sourcePort;
    DEBUG_MODULE(DEBUG_LEVEL_VERBOSE0, ("Added block of %d bytes to socket %d", dataBlock->totalSize, entry->index));

Exit:
    return;
}

/*
 * Allocates and fills the address records of a SOCK_DGRAM socket.
 * The local record is a copy of addrIn. The remote record starts as a copy of
 * the local one with port and address cleared; Hook_UDPv4 later fills it with
 * datagram senders.
 *
 * Returns 0 on success. Returns -1 on allocation failure with mbedNet_LastError
 * set to mbedNetResult_NotEnoughMemory. On failure entry->localAddr is NULL,
 * and nothing allocated by this call is left behind.
 */
static int32_t BindUDPv4(Socket_Entry_t *entry, Socket_AddrIn_t *addrIn)
{
    int32_t            result = -1;
    Socket_AddrIn_t    *localAddrIn,
                       *remoteAddrIn;

    /* Allocate local internet v4 addr */
    entry->localAddr = (Socket_Addr_t *)malloc(sizeof(Socket_AddrIn_t));
    if (entry->localAddr == NULL)
    {
        mbedNet_LastError = mbedNetResult_NotEnoughMemory;
        goto Exit;
    }

    /* Allocate remote internet v4 addr */
    entry->remoteAddr = (Socket_Addr_t *)malloc(sizeof(Socket_AddrIn_t));
    if (entry->remoteAddr == NULL)
    {
        free(entry->localAddr);
        /* Do not leave a dangling pointer behind for Sockets_Close to free again. */
        entry->localAddr = NULL;
        mbedNet_LastError = mbedNetResult_NotEnoughMemory;
        goto Exit;
    }

    /* Setup local socket address */
    localAddrIn = (Socket_AddrIn_t *)entry->localAddr;
    memcpy(localAddrIn, addrIn, sizeof(Socket_AddrIn_t));

    /* Setup remote socket address: copy of the local one with port & address zeroed */
    remoteAddrIn = (Socket_AddrIn_t *)entry->remoteAddr;
    *remoteAddrIn = *localAddrIn;
    remoteAddrIn->port = 0;
    remoteAddrIn->address.addr = 0;

    DEBUG_MODULE(DEBUG_LEVEL_INFO, ("Binding socket %d to %d.%d.%d.%d:%d", 
        entry->index, 
        addrIn->address.IP0, 
        addrIn->address.IP1, 
        addrIn->address.IP2, 
        addrIn->address.IP3, 
        ntohs(addrIn->port)
    ));

    result = 0;

Exit:
    return result;
}


/*
 * Copies up to `length` bytes of the datagram at the head of entry->dataQueue
 * into `data`. A datagram whose unread part fits completely is dequeued and
 * freed. A larger one is consumed partially, and its remainder stays queued for
 * the next call.
 *
 * Returns the number of bytes copied. Returns -1 with mbedNet_LastError set
 * when `length` is negative, when `data` is NULL with a positive `length`
 * (mbedNetResult_InvalidParameter), or when the queue is empty
 * (mbedNetResult_WouldBlock).
 */
static int32_t Recv_Data(Socket_Entry_t *entry, uint8_t *data, int32_t length)
{
    int32_t            count = -1;
    DataBlock_t        *dataBlock = NULL;

    if ((length < 0) || ((data == NULL) && (length > 0)))
    {
        mbedNet_LastError = mbedNetResult_InvalidParameter;
        goto Exit;
    }

    /* Callers normally test CQueue_IsEmpty() first; stay safe if they did not. */
    if ((CQueue_Peek(entry->dataQueue, (void **)&dataBlock) == -1) || (dataBlock == NULL))
    {
        mbedNet_LastError = mbedNetResult_WouldBlock;
        goto Exit;
    }

    if (dataBlock->remainingSize <= length)
    {
        /* Rest of the datagram fits: hand it out and release the block. */
        count = dataBlock->remainingSize;
        CQueue_Pop(entry->dataQueue, (void **)&dataBlock);
        if (count > 0) memcpy(data, dataBlock->readPtr, (size_t)count);
        free(dataBlock->dataPtr);
        free(dataBlock);
    }
    else
    {
        /* Partial read: advance the cursor and keep the block queued. */
        count = length;
        if (count > 0) memcpy(data, dataBlock->readPtr, (size_t)count);
        dataBlock->readPtr += count;
        dataBlock->remainingSize -= count;
    }

Exit:
    return count;
}


/*
 * Builds an IPv4 + UDP packet carrying `length` bytes of `data` to
 * remoteAddrIn and hands it to NetIF_SendIPv4Packet().
 *
 * The UDP source port is the socket's bound local port, or 0 when the socket
 * has not been bound (RFC 768: a zero source port means "not used"). All
 * header fields not set below stay zero from the memset. NOTE(review): the IPv4
 * source address and checksums are presumably filled in by the IPv4/NetIF
 * layer — confirm.
 *
 * Returns the value of NetIF_SendIPv4Packet(). Returns -1 with
 * mbedNet_LastError set when the destination is missing, when the length is
 * invalid (negative, or too large for the 16-bit IPv4 Total Length), or when
 * the packet buffer cannot be allocated.
 */
static int32_t SendToUDPv4(Socket_Entry_t *entry, uint8_t *data, int32_t length, Socket_AddrIn_t *remoteAddrIn)
{
    const int32_t      headersLength = (int32_t)(sizeof(UDPv4_Header_t) + sizeof(IPv4_Header_t));
    int32_t            count = -1,
                       totalLength;
    IPv4_Header_t      *ipv4Header;
    UDPv4_Header_t     *udpv4Header;
    Socket_AddrIn_t    *localAddrIn;

    if (remoteAddrIn == NULL)
    {
        DEBUG_SOURCE(DEBUG_LEVEL_ERROR, ("No destination address"));
        mbedNet_LastError = mbedNetResult_InvalidParameter;
        goto Exit;
    }

    if ((length < 0) || (length > 0xFFFF - headersLength) || ((data == NULL) && (length > 0)))
    {
        DEBUG_SOURCE(DEBUG_LEVEL_ERROR, ("Invalid UDPv4 payload length %d", length));
        mbedNet_LastError = mbedNetResult_InvalidParameter;
        goto Exit;
    }

    /* NULL when the socket was opened but never bound. */
    localAddrIn = (Socket_AddrIn_t *)entry->localAddr;
    totalLength = length + headersLength;
    ipv4Header = (IPv4_Header_t *)malloc((size_t)totalLength);
    if (ipv4Header == NULL)
    {
        DEBUG_SOURCE(DEBUG_LEVEL_ERROR, ("Not enough memory (needed %d bytes)", totalLength));
        mbedNet_LastError = mbedNetResult_NotEnoughMemory;
        goto Exit;
    }

    memset(ipv4Header, 0, (size_t)totalLength);

    DEBUG_SOURCE(DEBUG_LEVEL_VERBOSE0, ("UDPv4 sending %d bytes to %d.%d.%d.%d:%d",
        length,
        remoteAddrIn->address.IP0,
        remoteAddrIn->address.IP1,
        remoteAddrIn->address.IP2,
        remoteAddrIn->address.IP3,
        ntohs(remoteAddrIn->port)
    ));

    udpv4Header = (UDPv4_Header_t *)(ipv4Header + 1);

    ipv4Header->ihl = 5;
    ipv4Header->version = IPV4_VERSION;
    ipv4Header->tos = 0;
    /* Total Length covers the IPv4 header, the UDP header and the payload, which
     * is exactly totalLength (the IPv4 header must not be counted twice). */
    ipv4Header->totalLength = htons((uint16_t)totalLength);
    ipv4Header->id = 0;
    ipv4Header->fragmentFlags = 0;
    ipv4Header->ttl = NET_DEFAULT_TTL;
    ipv4Header->protocol = IPV4_PROTO_UDPV4;
    ipv4Header->dest = remoteAddrIn->address;

    udpv4Header->sourcePort = (localAddrIn != NULL) ? localAddrIn->port : 0;
    udpv4Header->destPort = remoteAddrIn->port;
    udpv4Header->length = htons((uint16_t)(length + sizeof(UDPv4_Header_t)));

    if (length > 0) memcpy(udpv4Header + 1, data, (size_t)length);

    DEBUG_BLOCK(DEBUG_LEVEL_VERBOSE0)
    {
        IPv4_DumpIPv4Header("Sockets:", ipv4Header);
    }

    count = NetIF_SendIPv4Packet(ipv4Header);
    free(ipv4Header);

Exit:
    return count;
}


/*
 * Claims a free slot in the socket table for a new socket.
 * Only the AF_INET family is accepted. `protocol` and `options` are stored
 * as given; unsupported protocols are rejected later by Sockets_Bind().
 *
 * Returns the new socket handle (the slot index). Returns -1 with
 * mbedNet_LastError set when the family is unsupported
 * (mbedNetResult_NotIplemented) or every slot is in use
 * (mbedNetResult_TooManyOpenSockets).
 */
Socket_t Sockets_Open(Socket_Family_t family, Socket_Protocol_t protocol, int32_t options)
{
    int32_t            result = -1,
                       index = 0;
    Socket_Entry_t     *entry = NULL;

    if (family != AF_INET)
    {
        DEBUG_MODULE(DEBUG_LEVEL_ERROR, ("Protocol family not supported"));
        mbedNet_LastError = mbedNetResult_NotIplemented;
        goto Exit;
    }

    /* First closed slot wins. */
    for (index = 0; index < SOCKET_MAX_COUNT; index++)
    {
        if (socketEntryTable[index].state != State_Close) continue;
        entry = socketEntryTable + index;
        break;
    }

    if (entry == NULL)
    {
        DEBUG_MODULE(DEBUG_LEVEL_WARNING, ("Too many open sockets"));
        mbedNet_LastError = mbedNetResult_TooManyOpenSockets;
        goto Exit;
    }

    entry->family = family;
    entry->protocol = protocol;
    entry->options = options;
    entry->state = State_Open;
    /* A fresh socket owns no addresses or queue until it is bound. */
    entry->localAddr = NULL;
    entry->remoteAddr = NULL;
    entry->dataQueue = NULL;
    entry->index = index;
    result = index;

    /* Logged only on success; failure paths would report a meaningless index. */
    DEBUG_MODULE(DEBUG_LEVEL_INFO, ("opened socket %d", index));

Exit:
    return result;
}


/*
 * Binds an open, not yet bound socket to a local address.
 * Only AF_INET/SOCK_DGRAM is supported. `addr` must point to a Socket_AddrIn_t,
 * and `addrLen` must equal sizeof(Socket_AddrIn_t). On success the socket gets
 * its address records and a receive queue of SOCKET_DATAQUEUE_ENTRY_COUNT
 * entries, and it enters State_Bound.
 *
 * Returns 0 on success. Returns -1 with mbedNet_LastError set on failure; a
 * failed bind leaves the socket open and unbound, with no addresses or queue
 * attached.
 */
int32_t Sockets_Bind(Socket_t socket, Socket_Addr_t *addr, int32_t addrLen)
{
    int32_t            result = -1;
    Socket_Entry_t     *entry;

    /* GetSocketEntry() reports invalid or closed handles itself. */
    if ((entry = GetSocketEntry(socket)) == NULL) goto Exit;

    /* Binding again would leak the existing address records and queue. */
    if (entry->state != State_Open)
    {
        DEBUG_MODULE(DEBUG_LEVEL_ERROR, ("Socket %d already bound", socket));
        mbedNet_LastError = mbedNetResult_InvalidParameter;
        goto Exit;
    }

    if (addr == NULL)
    {
        mbedNet_LastError = mbedNetResult_InvalidParameter;
        goto Exit;
    }

    /* Allocate address entry */
    switch(entry->family)
    {
        case AF_INET:
            switch(entry->protocol)
            {
                case SOCK_DGRAM:
                    if (addrLen != sizeof(Socket_AddrIn_t))
                    {
                        mbedNet_LastError = mbedNetResult_InvalidParameter;
                        goto Exit;
                    }
                    BindUDPv4(entry, (Socket_AddrIn_t *)addr);
                    /* BindUDPv4() reports allocation failure (and sets
                     * mbedNet_LastError) by leaving an address pointer NULL. When
                     * the remote record fails it has already freed the local one,
                     * so only our references are dropped here; freeing again
                     * would be a double free. */
                    if ((entry->localAddr == NULL) || (entry->remoteAddr == NULL))
                    {
                        entry->localAddr = NULL;
                        entry->remoteAddr = NULL;
                        goto Exit;
                    }
                    break;

                case SOCK_STREAM:
                case SOCK_RAW:
                    DEBUG_MODULE(DEBUG_LEVEL_ERROR, ("Protocol not supported"));
                    mbedNet_LastError = mbedNetResult_NotIplemented;
                    goto Exit;

                default:
                    DEBUG_MODULE(DEBUG_LEVEL_ERROR, ("Unknown socket protocol"));
                    mbedNet_LastError = mbedNetResult_InvalidParameter;
                    goto Exit;
            }
            break;

        default:
            DEBUG_MODULE(DEBUG_LEVEL_ERROR, ("Protocol family not supported"));
            mbedNet_LastError = mbedNetResult_NotIplemented;
            goto Exit;
    }

    entry->dataQueue = CQueue_Alloc(SOCKET_DATAQUEUE_ENTRY_COUNT);
    if (entry->dataQueue == NULL)
    {
        /* Undo the address allocation so the socket is left cleanly unbound. */
        free(entry->localAddr);
        entry->localAddr = NULL;
        free(entry->remoteAddr);
        entry->remoteAddr = NULL;
        DEBUG_MODULE(DEBUG_LEVEL_ERROR, ("Not enough memory to allocate data queue"));
        mbedNet_LastError = mbedNetResult_NotEnoughMemory;
        goto Exit;
    }

    entry->state = State_Bound;
    result = 0;

Exit:
    return result;
}


/*
 * Connected-mode send; not available yet, so it always returns -1.
 * Datagram sockets report mbedNetResult_DestinationAddressRequired (use
 * Sockets_SendTo() instead), and other sockets report mbedNetResult_NotIplemented.
 * An invalid handle is reported by GetSocketEntry().
 */
int32_t Sockets_Send(Socket_t socket, uint8_t *data, int32_t length, int32_t flags)
{
    Socket_Entry_t    *target = GetSocketEntry(socket);

    if (target != NULL)
    {
        mbedNet_LastError = (target->protocol == SOCK_DGRAM)
            ? mbedNetResult_DestinationAddressRequired
            : mbedNetResult_NotIplemented;
    }

    return -1;
}


/*
 * Sends `length` bytes of `data` to `remoteAddr`.
 * Only AF_INET/SOCK_DGRAM sockets are supported, and `addrLen` must equal
 * sizeof(Socket_AddrIn_t). Returns the result of SendToUDPv4(), or -1 with
 * mbedNet_LastError set.
 */
int32_t Sockets_SendTo(Socket_t socket, uint8_t *data, int32_t length, int32_t flags, const Socket_Addr_t *remoteAddr, int32_t addrLen)
{
    Socket_Entry_t    *target;

    target = GetSocketEntry(socket);
    if (target == NULL)
    {
        DEBUG_SOURCE(DEBUG_LEVEL_ERROR, ("socket not found!"));
        return -1;
    }

    if (target->family != AF_INET)
    {
        DEBUG_SOURCE(DEBUG_LEVEL_ERROR, ("Protocol family not implemented"));
        mbedNet_LastError = mbedNetResult_NotIplemented;
        return -1;
    }

    if (target->protocol != SOCK_DGRAM)
    {
        DEBUG_SOURCE(DEBUG_LEVEL_ERROR, ("Protocol not implemented"));
        mbedNet_LastError = mbedNetResult_NotIplemented;
        return -1;
    }

    if (addrLen != sizeof(Socket_AddrIn_t))
    {
        DEBUG_SOURCE(DEBUG_LEVEL_ERROR, ("Invalid socket address length"));
        mbedNet_LastError = mbedNetResult_InvalidParameter;
        return -1;
    }

    return SendToUDPv4(target, data, length, (Socket_AddrIn_t *)remoteAddr);
}


/*
 * Receives data from a bound socket without reporting the sender.
 * On a datagram socket it reads (part of) the oldest queued datagram, exactly
 * like Sockets_RecvFrom() with a NULL address; use Sockets_RecvFrom() to learn
 * the sender. `flags` is currently ignored.
 *
 * Returns the number of bytes copied. Returns -1 with mbedNet_LastError set;
 * mbedNetResult_WouldBlock means nothing is queued, including on sockets that
 * were never bound.
 */
int32_t Sockets_Recv(Socket_t socket, uint8_t *data, int32_t length, int32_t flags)
{
    int32_t            count = -1;
    Socket_Entry_t     *entry;

    entry = GetSocketEntry(socket);
    if (entry == NULL) goto Exit;

    /* Only bound sockets own a receive queue; without one nothing can arrive. */
    if ((entry->dataQueue == NULL) || CQueue_IsEmpty(entry->dataQueue))
    {
        mbedNet_LastError = mbedNetResult_WouldBlock;
        goto Exit;
    }

    count = Recv_Data(entry, data, length);

Exit:
    return count;
}


/*
 * Receives data from a bound socket and optionally reports the sender.
 * When `remoteAddr` is not NULL, `addrLen` must point to its capacity in bytes.
 * On success the socket's remote address record is copied there, and
 * *addrLen is set to the number of bytes stored (as with POSIX recvfrom()).
 * `flags` is currently ignored.
 *
 * Returns the number of bytes copied. Returns -1 with mbedNet_LastError set:
 * mbedNetResult_WouldBlock when nothing is queued, mbedNetResult_InvalidParameter
 * when `addrLen` is NULL, and mbedNetResult_BufferTooSmall when the address
 * does not fit.
 */
int32_t Sockets_RecvFrom(Socket_t socket, uint8_t *data, int32_t length, int32_t flags, Socket_Addr_t *remoteAddr, int32_t *addrLen)
{
    int32_t            count = -1;
    Socket_Entry_t     *entry;

    entry = GetSocketEntry(socket);
    if (entry == NULL) goto Exit;

    /* Only bound sockets own a receive queue; without one nothing can arrive. */
    if ((entry->dataQueue == NULL) || CQueue_IsEmpty(entry->dataQueue))
    {
        mbedNet_LastError = mbedNetResult_WouldBlock;
        goto Exit;
    }

    if (remoteAddr != NULL)
    {
        if (addrLen == NULL)
        {
            mbedNet_LastError = mbedNetResult_InvalidParameter;
            goto Exit;
        }
        /* Size-check the record that is actually copied (the remote one). */
        if (entry->remoteAddr->len > *addrLen)
        {
            mbedNet_LastError = mbedNetResult_BufferTooSmall;
            goto Exit;
        }
        memcpy(remoteAddr, entry->remoteAddr, entry->remoteAddr->len);
        *addrLen = entry->remoteAddr->len;
    }

    count = Recv_Data(entry, data, length);

Exit:
    return count;
}


/*
 * Closes a socket and returns its slot to State_Close.
 * It releases the address records, every queued datagram (block and payload
 * buffer), and the receive queue itself.
 *
 * Returns 0 on success. Returns -1 when the handle is invalid or already
 * closed; GetSocketEntry() sets mbedNet_LastError in that case.
 */
int32_t Sockets_Close(Socket_t socket)
{
    int32_t            result = -1;
    Socket_Entry_t     *entry;
    DataBlock_t        *dataBlock;

    if ((entry = GetSocketEntry(socket)) == NULL) goto Exit;

    entry->state = State_Close;
    free(entry->localAddr);
    entry->localAddr = NULL;
    free(entry->remoteAddr);
    entry->remoteAddr = NULL;

    /* Sockets that were never bound have no queue. */
    if (entry->dataQueue != NULL)
    {
        /* Drain by popping: CQueue_Peek does not remove the head, so looping on
         * it never ends and frees the same block repeatedly. Each block also
         * owns its payload buffer. */
        while (!CQueue_IsEmpty(entry->dataQueue))
        {
            dataBlock = NULL;
            CQueue_Pop(entry->dataQueue, (void **)&dataBlock);
            if (dataBlock == NULL) break;
            free(dataBlock->dataPtr);
            free(dataBlock);
        }
        CQueue_Free(entry->dataQueue);
        entry->dataQueue = NULL;
    }

    result = 0;
    DEBUG_MODULE(DEBUG_LEVEL_INFO, ("closed socket %d", socket));

Exit:
    return result;
}



/*
 * Maps a socket handle to its table slot.
 * Returns NULL with mbedNet_LastError set when the handle is out of range
 * (mbedNetResult_InvalidSocketHandle) or the slot is closed
 * (mbedNetResult_SocketAlreadyClosed).
 */
static Socket_Entry_t *GetSocketEntry(Socket_t socket)
{
    Socket_Entry_t    *slot;

    if ((socket < 0) || (socket >= SOCKET_MAX_COUNT))
    {
        DEBUG_MODULE(DEBUG_LEVEL_ERROR, ("Invalid socket handle"));
        mbedNet_LastError = mbedNetResult_InvalidSocketHandle;
        return NULL;
    }

    slot = &socketEntryTable[socket];
    if (slot->state != State_Close)
    {
        return slot;
    }

    DEBUG_MODULE(DEBUG_LEVEL_WARNING, ("Socket already closed"));
    mbedNet_LastError = mbedNetResult_SocketAlreadyClosed;
    return NULL;
}
