#include "NetworkStack.h"
#include "TCPSocket.h"
#include "UDPSocket.h"
#include <stdio.h>
#include "string.h"

// Largest payload used by the socket data tests; also the size of the shared
// tx/rx buffers below, so no test may request more than this many bytes.
#define NSAPI_MAX_DATA_SIZE 2048
// Host name resolved by nsapi_networkstack_gethostbyname_test().
#define NSAPI_TEST_HOST "mbed.org"
// Literal IP string passed to gethostbyname() by the same test.
#define NSAPI_TEST_IP   "8.8.4.4"

// Outgoing payload: filled by nsapi_buffer_init() and used as the reference
// by nsapi_buffer_check().
uint8_t tx_test_buffer[NSAPI_MAX_DATA_SIZE];
// Incoming data: receives echoed payloads and bytes discarded by the flush helpers.
uint8_t rx_test_buffer[NSAPI_MAX_DATA_SIZE];

/**
 * Checks that the network stack reports a non-empty local IP address.
 *
 * @param stack  Network stack under test.
 * @return 0 on success, -1 if no address is reported.
 */
int nsapi_networkstack_get_ip_address_test(NetworkStack *stack)
{
    const char *ip = stack->get_ip_address();

    // Reject a NULL return as well as an empty string, so a stack without an
    // address fails the test instead of faulting on the dereference.
    if (!ip || !ip[0]) {
        printf("error: 'get_ip_address()' did not return an IP address\r\n");
        return -1;
    }

    return 0;
}

/**
 * Resolves one host name and verifies that an IP address came back.
 * Helper for nsapi_networkstack_gethostbyname_test().
 *
 * @return 0 on success, -1 if gethostbyname() reported an error, -2 if it
 *         reported success but `address` holds no IP string.
 */
static int nsapi_gethostbyname_check(NetworkStack *stack, SocketAddress *address, const char *host)
{
    int ret = stack->gethostbyname(address, host);

    if (ret) {
        printf("error: 'gethostbyname(\"%s\")' failed with code %d\r\n", host, ret);
        return -1;
    }

    // Inspect the resolved IP string. The `address` pointer itself comes from
    // the caller and says nothing about whether resolution produced a result.
    const char *ip = address->get_ip_address();
    if (!ip || !ip[0]) {
        printf("error: 'gethostbyname(\"%s\")' returned null IP address\r\n", host);
        return -2;
    }

    return 0;
}

/**
 * Exercises NetworkStack::gethostbyname() with a fixed host name
 * (NSAPI_TEST_HOST), a literal IP (NSAPI_TEST_IP) and the caller's host.
 *
 * @param stack      Network stack under test.
 * @param address    Receives each resolved address. It is left holding the
 *                   result for `test_host`, which nsapi_tests() reuses for
 *                   the socket tests.
 * @param test_host  Caller-supplied host to resolve last.
 * @return 0 on success, otherwise the first failing code from
 *         nsapi_gethostbyname_check() (-1 or -2).
 */
int nsapi_networkstack_gethostbyname_test(NetworkStack *stack, SocketAddress *address, const char *test_host)
{
    // test_host must stay last so `address` ends up holding its IP.
    const char *const hosts[] = { NSAPI_TEST_HOST, NSAPI_TEST_IP, test_host };

    for (size_t i = 0; i < sizeof hosts / sizeof hosts[0]; i++) {
        int ret = nsapi_gethostbyname_check(stack, address, hosts[i]);
        if (ret) {
            return ret;
        }
    }

    return 0;
}

/**
 * Fills `buffer[0 .. size)` with pseudo-random bytes, one rand() call per byte.
 *
 * @param buffer  Destination; must hold at least `size` bytes.
 * @param size    Number of bytes to write (0 writes nothing).
 */
static void nsapi_buffer_init(uint8_t *buffer, unsigned size)
{
    uint8_t *cursor = buffer;
    uint8_t *const end = buffer + size;

    while (cursor != end) {
        *cursor++ = rand() % 256;
    }
}

/**
 * Compares the first `size` bytes of `buffer` against tx_test_buffer.
 *
 * @param buffer  Data to verify (typically the received bytes).
 * @param size    Number of bytes to compare; must not exceed the size of
 *                tx_test_buffer.
 * @return true when every byte matches (trivially true for size == 0).
 */
static bool nsapi_buffer_check(const uint8_t *buffer, unsigned size)
{
    return memcmp(buffer, tx_test_buffer, size) == 0;
}

/**
 * Discards pending TCP data before a test round.
 *
 * Sets a timeout of 1000 (in set_timeout()'s units), performs a single recv()
 * into rx_test_buffer whose result is ignored, then calls set_timeout(-1).
 * NOTE(review): only one recv() is issued, so a backlog larger than one read
 * may survive the flush.
 */
static void nsapi_tcp_flush(TCPSocket *socket)
{
    const int drain_timeout = 1000;
    const int restored_timeout = -1;

    socket->set_timeout(drain_timeout);
    (void)socket->recv(rx_test_buffer, sizeof rx_test_buffer);
    socket->set_timeout(restored_timeout);
}

int nsapi_tcp_open_test(TCPSocket *udp, NetworkStack *stack)
{
    int ret = udp->open(stack);

    if (ret) {
        printf("error: 'open(%p)' failed with code %d\r\n", stack, ret);
        return -1;
    } else {
        return 0;
    }
}

/**
 * Connects a TCP socket to `addr`.
 *
 * @param tcp   Open socket.
 * @param addr  Remote address (IP and port) to connect to.
 * @return 0 on success, -1 if connect() reported an error.
 */
int nsapi_tcp_connect_test(TCPSocket *tcp, SocketAddress *addr)
{
    int err = tcp->connect(*addr);
    if (!err) {
        return 0;
    }

    printf("error: 'connect(SocketAddress(%s, %d))' failed with code %d\r\n",
            addr->get_ip_address(), addr->get_port(), err);
    return -1;
}

/**
 * Sends `size` random bytes over a connected TCP socket in blocking mode and
 * checks that the same bytes are read back (the peer must echo them).
 *
 * @param tcp   Connected socket.
 * @param size  Payload size; must not exceed NSAPI_MAX_DATA_SIZE, the size of
 *              the shared tx/rx buffers.
 * @return 0 on success, -1 on send failure, -2 on recv failure or a 0-byte
 *         read, -3 if the received data differs from what was sent.
 */
static int nsapi_tcp_blocking_test_helper(TCPSocket *tcp, unsigned size)
{
    unsigned total;
    nsapi_tcp_flush(tcp);
    nsapi_buffer_init(tx_test_buffer, size);

    for (total = 0; total < size;) {
        int ret = tcp->send(tx_test_buffer + total, size - total);

        if (ret < 0) {
            printf("error: 'send(buffer, %u)' failed during test with code %d\r\n", size, ret);
            return -1;
        }

        total += ret;
    }

    memset(rx_test_buffer, 0, size);
    for (total = 0; total < size;) {
        // Read up to the whole buffer so surplus echoed bytes show up as a
        // length mismatch in the final check.
        int ret = tcp->recv(rx_test_buffer + total, (sizeof rx_test_buffer) - total);

        if (ret < 0) {
            printf("error: 'recv(buffer, %u)' failed during test with code %d\r\n",
                    (unsigned)sizeof rx_test_buffer, ret);
            return -2;
        }
        if (ret == 0) {
            // A 0-byte read makes no progress (typically the peer closed the
            // connection); without this check the loop would never end.
            printf("error: 'recv(buffer, %u)' returned 0 after %u of %u bytes\r\n",
                    (unsigned)sizeof rx_test_buffer, total, size);
            return -2;
        }

        total += ret;
    }

    if (total != size || !nsapi_buffer_check(rx_test_buffer, size)) {
        printf("error: 'recv(buffer, %u)' recieved incorrect data with length %u\r\n",
                (unsigned)sizeof rx_test_buffer, total);
        return -3;
    }

    return 0;
}

/**
 * Runs the blocking TCP echo check for payload sizes 64, 128, ... up to
 * NSAPI_MAX_DATA_SIZE, stopping at the first failure.
 *
 * @param tcp  Connected socket.
 * @return 0 if every size passes, otherwise the failing helper's code.
 */
int nsapi_tcp_blocking_test(TCPSocket *tcp)
{
    int ret = 0;
    unsigned size = 64;

    while (ret == 0 && size <= NSAPI_MAX_DATA_SIZE) {
        printf("%s: size %d\r\n", __func__, size);
        ret = nsapi_tcp_blocking_test_helper(tcp, size);
        size *= 2;
    }

    return ret;
}

/**
 * Non-blocking variant of the TCP echo check.
 *
 * Switches the socket to non-blocking mode, verifies that recv() on the
 * flushed socket returns NSAPI_ERROR_WOULD_BLOCK, then sends `size` random
 * bytes and polls until the same number of bytes has been read back.
 *
 * @param tcp   Connected socket; left in non-blocking mode on return.
 * @param size  Payload size; must not exceed NSAPI_MAX_DATA_SIZE.
 * @return 0 on success, -1 on send failure, -2 on recv failure or a 0-byte
 *         read, -3 on data mismatch, -4 if the initial recv() failed, -5 if
 *         the initial recv() unexpectedly returned data.
 */
static int nsapi_tcp_non_blocking_test_helper(TCPSocket *tcp, unsigned size)
{
    unsigned total;
    int ret;
    nsapi_tcp_flush(tcp);
    nsapi_buffer_init(tx_test_buffer, size);

    // With the socket flushed and in non-blocking mode, recv() must report
    // NSAPI_ERROR_WOULD_BLOCK immediately instead of waiting or returning data.
    tcp->set_blocking(false);
    ret = tcp->recv(rx_test_buffer, sizeof rx_test_buffer);

    if (ret != NSAPI_ERROR_WOULD_BLOCK) {
        if (ret < 0) {
            printf("error: 'recv(buffer, %u)' failed during test with code %d\r\n",
                    (unsigned)sizeof rx_test_buffer, ret);
            return -4;
        }
        printf("error: 'recv(buffer, %u)' returned %d when no data was expected\r\n",
                (unsigned)sizeof rx_test_buffer, ret);
        return -5;
    }

    for (total = 0; total < size;) {
        ret = tcp->send(tx_test_buffer + total, size - total);

        if (ret == NSAPI_ERROR_WOULD_BLOCK) {
            // A non-blocking send may not be able to queue data yet; retry
            // rather than treating this as a failure.
            continue;
        } else if (ret < 0) {
            printf("error: 'send(buffer, %u)' failed during test with code %d\r\n", size, ret);
            return -1;
        }

        total += ret;
    }

    memset(rx_test_buffer, 0, size);
    for (total = 0; total < size;) {
        ret = tcp->recv(rx_test_buffer + total, (sizeof rx_test_buffer) - total);

        if (ret == NSAPI_ERROR_WOULD_BLOCK) {
            // NOTE(review): busy-polls with no upper bound on the wait.
            continue;
        } else if (ret < 0) {
            printf("error: 'recv(buffer, %u)' failed during test with code %d\r\n",
                    (unsigned)sizeof rx_test_buffer, ret);
            return -2;
        } else if (ret == 0) {
            // A 0-byte read makes no progress (typically the peer closed the
            // connection); without this check the poll loop would spin forever.
            printf("error: 'recv(buffer, %u)' returned 0 after %u of %u bytes\r\n",
                    (unsigned)sizeof rx_test_buffer, total, size);
            return -2;
        }

        total += ret;
    }

    if (total != size || !nsapi_buffer_check(rx_test_buffer, size)) {
        printf("error: 'recv(buffer, %u)' recieved incorrect data with length %u\r\n",
                (unsigned)sizeof rx_test_buffer, total);
        return -3;
    }

    return 0;
}

/**
 * Runs the non-blocking TCP echo check for payload sizes 64, 128, ... up to
 * NSAPI_MAX_DATA_SIZE, stopping at the first failure.
 *
 * @param tcp  Connected socket.
 * @return 0 if every size passes, otherwise the failing helper's code.
 */
int nsapi_tcp_non_blocking_test(TCPSocket *tcp)
{
    int ret = 0;

    for (unsigned size = 64; ret == 0 && size <= NSAPI_MAX_DATA_SIZE; size <<= 1) {
        printf("%s: size %d\r\n", __func__, size);
        ret = nsapi_tcp_non_blocking_test_helper(tcp, size);
    }

    return ret;
}

/**
 * Closes a TCP socket.
 *
 * @param tcp  Socket to close.
 * @return 0 on success, -1 if close() reported an error.
 */
int nsapi_tcp_close_test(TCPSocket *tcp)
{
    const int err = tcp->close();
    if (err == 0) {
        return 0;
    }

    printf("error 'close()' failed with code %d\r\n", err);
    return -1;
}

/**
 * Discards pending UDP data before a test round.
 *
 * Sets a timeout of 1000 (in set_timeout()'s units) and performs one
 * recvfrom() into rx_test_buffer, passing NULL so the sender address is not
 * stored. It then calls set_timeout(-1).
 * NOTE(review): only one datagram is drained per call.
 */
static void nsapi_udp_flush(UDPSocket *udp)
{
    const int drain_timeout = 1000;

    udp->set_timeout(drain_timeout);
    (void)udp->recvfrom(NULL, rx_test_buffer, sizeof rx_test_buffer);
    udp->set_timeout(-1);
}

/**
 * Opens a UDP socket on the given network stack.
 *
 * @param udp    Socket to open.
 * @param stack  Network stack the socket is opened on.
 * @return 0 on success, -1 if open() reported an error.
 */
int nsapi_udp_open_test(UDPSocket *udp, NetworkStack *stack)
{
    int ret = udp->open(stack);

    if (ret) {
        // %p is only defined for void pointers, so cast explicitly.
        printf("error: 'open(%p)' failed with code %d\r\n", (void *)stack, ret);
        return -1;
    }

    return 0;
}

/**
 * Sends `size` random bytes to `addr` over a UDP socket in blocking mode and
 * checks that the same bytes come back (the peer must echo them).
 *
 * @param udp   Open UDP socket.
 * @param addr  Address of the echo peer.
 * @param size  Payload size; must not exceed NSAPI_MAX_DATA_SIZE, the size of
 *              the shared tx/rx buffers.
 * @return 0 on success, -1 on sendto failure, -2 on recvfrom failure,
 *         -3 if the received data differs from what was sent.
 */
static int nsapi_udp_blocking_test_helper(UDPSocket *udp, SocketAddress *addr, unsigned size)
{
    unsigned total;
    nsapi_udp_flush(udp);
    nsapi_buffer_init(tx_test_buffer, size);

    for (total = 0; total < size;) {
        int ret = udp->sendto(*addr, tx_test_buffer + total, size - total);

        if (ret < 0) {
            printf("error: 'sendto(SocketAddress(%s, %d), buffer, %u)' failed during test with code %d\r\n",
                    addr->get_ip_address(), addr->get_port(), size, ret);
            return -1;
        }

        total += ret;
    }

    memset(rx_test_buffer, 0, size);
    // NOTE(review): nsapi_udp_flush() ends with set_timeout(-1), so a lost
    // datagram may leave this loop waiting indefinitely.
    for (total = 0; total < size;) {
        int ret = udp->recvfrom(0, rx_test_buffer + total, (sizeof rx_test_buffer) - total);

        if (ret < 0) {
            printf("error: 'recvfrom(0, buffer, %u)' failed during test with code %d\r\n",
                    (unsigned)sizeof rx_test_buffer, ret);
            return -2;
        }

        total += ret;
    }

    if (total != size || !nsapi_buffer_check(rx_test_buffer, size)) {
        printf("error: 'recvfrom(0, buffer, %u)' recieved incorrect data with length %u\r\n",
                (unsigned)sizeof rx_test_buffer, total);
        return -3;
    }

    return 0;
}

/**
 * Runs the blocking UDP echo check for payload sizes 64, 128, ... up to
 * NSAPI_MAX_DATA_SIZE, stopping at the first failure.
 *
 * @param udp   Open UDP socket.
 * @param addr  Address of the echo peer.
 * @return 0 if every size passes, otherwise the failing helper's code.
 */
int nsapi_udp_blocking_test(UDPSocket *udp, SocketAddress *addr)
{
    unsigned size = 64;

    while (size <= NSAPI_MAX_DATA_SIZE) {
        printf("%s: size %d\r\n", __func__, size);

        const int rc = nsapi_udp_blocking_test_helper(udp, addr, size);
        if (rc != 0) {
            return rc;
        }

        size *= 2;
    }

    return 0;
}

/**
 * Non-blocking variant of the UDP echo check.
 *
 * Switches the socket to non-blocking mode, verifies that recvfrom() on the
 * flushed socket returns NSAPI_ERROR_WOULD_BLOCK, then sends `size` random
 * bytes to `addr` and polls until the same number of bytes has been read back.
 *
 * @param udp   Open UDP socket; left in non-blocking mode on return.
 * @param addr  Address of the echo peer.
 * @param size  Payload size; must not exceed NSAPI_MAX_DATA_SIZE.
 * @return 0 on success, -1 on sendto failure, -2 on recvfrom failure,
 *         -3 on data mismatch, -4 if the initial recvfrom() failed, -5 if the
 *         initial recvfrom() unexpectedly returned data.
 */
int nsapi_udp_non_blocking_test_helper(UDPSocket *udp, SocketAddress *addr, unsigned size)
{
    unsigned total;
    int ret;
    nsapi_udp_flush(udp);
    nsapi_buffer_init(tx_test_buffer, size);

    // With the socket flushed and in non-blocking mode, recvfrom() must report
    // NSAPI_ERROR_WOULD_BLOCK immediately instead of waiting or returning data.
    udp->set_blocking(false);
    ret = udp->recvfrom(0, rx_test_buffer, sizeof rx_test_buffer);

    if (ret != NSAPI_ERROR_WOULD_BLOCK) {
        if (ret < 0) {
            printf("error: 'recvfrom(0, buffer, %u)' failed during test with code %d\r\n",
                    (unsigned)sizeof rx_test_buffer, ret);
            return -4;
        }
        printf("error: 'recvfrom(0, buffer, %u)' returned %d when no data was expected\r\n",
                (unsigned)sizeof rx_test_buffer, ret);
        return -5;
    }

    for (total = 0; total < size;) {
        ret = udp->sendto(*addr, tx_test_buffer + total, size - total);

        if (ret == NSAPI_ERROR_WOULD_BLOCK) {
            // A non-blocking sendto may not be able to queue data yet; retry
            // rather than treating this as a failure.
            continue;
        } else if (ret < 0) {
            printf("error: 'sendto(SocketAddress(%s, %d), buffer, %u)' failed during test with code %d\r\n",
                    addr->get_ip_address(), addr->get_port(), size, ret);
            return -1;
        }

        total += ret;
    }

    memset(rx_test_buffer, 0, size);
    for (total = 0; total < size;) {
        ret = udp->recvfrom(0, rx_test_buffer + total, (sizeof rx_test_buffer) - total);

        if (ret == NSAPI_ERROR_WOULD_BLOCK) {
            // NOTE(review): busy-polls with no upper bound; a lost datagram
            // makes this loop spin forever.
            continue;
        } else if (ret < 0) {
            printf("error: 'recvfrom(0, buffer, %u)' failed during test with code %d\r\n",
                    (unsigned)sizeof rx_test_buffer, ret);
            return -2;
        }

        total += ret;
    }

    if (total != size || !nsapi_buffer_check(rx_test_buffer, size)) {
        printf("error: 'recvfrom(0, buffer, %u)' recieved incorrect data with length %u\r\n",
                (unsigned)sizeof rx_test_buffer, total);
        return -3;
    }

    return 0;
}

/**
 * Runs the non-blocking UDP echo check for payload sizes 64, 128, ... up to
 * NSAPI_MAX_DATA_SIZE, stopping at the first failure.
 *
 * @param udp   Open UDP socket.
 * @param addr  Address of the echo peer.
 * @return 0 if every size passes, otherwise the failing helper's code.
 */
int nsapi_udp_non_blocking_test(UDPSocket *udp, SocketAddress *addr)
{
    int status = 0;
    unsigned size = 64;

    while (status == 0 && size <= NSAPI_MAX_DATA_SIZE) {
        printf("%s: size %d\r\n", __func__, size);
        status = nsapi_udp_non_blocking_test_helper(udp, addr, size);
        size <<= 1;
    }

    return status;
}

/**
 * Closes a UDP socket.
 *
 * @param udp  Socket to close.
 * @return 0 on success, -1 if close() reported an error.
 */
int nsapi_udp_close_test(UDPSocket *udp)
{
    const int err = udp->close();
    if (!err) {
        return 0;
    }

    printf("error 'close()' failed with code %d\r\n", err);
    return -1;
}

/**
 * Runs the NetworkStack, UDPSocket and TCPSocket test suite in a fixed order
 * and prints a PASS/FAIL line for each test.
 *
 * All tests share one SocketAddress. The gethostbyname test resolves
 * `test_host` last, and the socket tests then use that address with
 * `test_port`. Every test runs even when an earlier one failed.
 *
 * @param name       Not used by this function. NOTE(review): presumably a
 *                   label for the stack under test — confirm with callers.
 * @param stack      Network stack under test.
 * @param test_host  Host whose resolved address the socket tests use; the
 *                   data tests compare received bytes against the sent ones,
 *                   so it must echo traffic back.
 * @param test_port  Port used for the socket tests.
 * @return 0 if every test passed, otherwise non-zero (the bitwise OR of the
 *         individual test return codes).
 */
int nsapi_tests(const char *name, NetworkStack *stack, const char *test_host, uint16_t test_port)
{
    SocketAddress address(0, test_port);
    TCPSocket tcp;
    UDPSocket udp;

    int result = 0;

#define NSAPI_MARK_TESTS(tests)                                 \
    printf("\r\n\r\nRunning %s Tests\r\n\r\n", tests)

// Runs one test, reports PASS/FAIL and folds its code into `result`.
#define NSAPI_RUN_TEST(test, ...)                               \
    do {                                                        \
        printf("---------------------\r\n");                    \
        printf("%s: running...\r\n", #test);                    \
        int test##_result = test(__VA_ARGS__);                  \
        if (!test##_result) {                                   \
            printf("%s: PASS\r\n", #test);                      \
        } else {                                                \
            printf("%s: FAIL (%d)\r\n", #test, test##_result);  \
        }                                                       \
        result |= test##_result;                                \
    } while (0)

    NSAPI_MARK_TESTS("NetworkStack");
    NSAPI_RUN_TEST(nsapi_networkstack_get_ip_address_test, stack);
    NSAPI_RUN_TEST(nsapi_networkstack_gethostbyname_test, stack, &address, test_host);

    NSAPI_MARK_TESTS("UDPSocket");
    NSAPI_RUN_TEST(nsapi_udp_open_test, &udp, stack);
    NSAPI_RUN_TEST(nsapi_udp_blocking_test, &udp, &address);
    NSAPI_RUN_TEST(nsapi_udp_non_blocking_test, &udp, &address);
    NSAPI_RUN_TEST(nsapi_udp_close_test, &udp);

    NSAPI_MARK_TESTS("TCPSocket");
    NSAPI_RUN_TEST(nsapi_tcp_open_test, &tcp, stack);
    NSAPI_RUN_TEST(nsapi_tcp_connect_test, &tcp, &address);
    NSAPI_RUN_TEST(nsapi_tcp_blocking_test, &tcp);
    NSAPI_RUN_TEST(nsapi_tcp_non_blocking_test, &tcp);
    NSAPI_RUN_TEST(nsapi_tcp_close_test, &tcp);

    // The helper macros are only meaningful inside this function; don't let
    // them leak into the rest of the translation unit.
#undef NSAPI_RUN_TEST
#undef NSAPI_MARK_TESTS

    if (result == 0) {
        printf("\r\n\r\n--- ALL TESTS PASSING ---\r\n");
    } else {
        printf("\r\n\r\n--- TEST FAILURES OCCURRED ---\r\n");
    }

    return result;
}

