/*
 * mbed Tiny DNS Resolver
 * Copyright (c) 2011 Hiroshi Suga
 * Released under the MIT License: http://mbed.org/license/mit
 */

/** @file
 * @brief Tiny DNS Resolver
 */

#include "mbed.h"
#include "EthernetNetIf.h"
#include "UDPSocket.h"
#include "TinyResolver.h"

// Byte-order helpers. These swap bytes unconditionally, so they only act as
// host<->network conversions on a little-endian host (NOTE(review): assumes
// the mbed target is little-endian -- confirm before porting). Each macro
// evaluates its argument more than once; never pass an expression with side
// effects.
// host to network short
#define htons( x ) ( (( (x) << 8 ) & 0xFF00) | (( (x) >> 8 ) & 0x00FF) )
#define ntohs( x ) htons(x)
// host to network long
#define htonl( x ) ( (( (x) << 24 ) & 0xFF000000)  \
                   | (( (x) <<  8 ) & 0x00FF0000)  \
                   | (( (x) >>  8 ) & 0x0000FF00)  \
                   | (( (x) >> 24 ) & 0x000000FF)  )
#define ntohl( x ) htonl(x)

// Network interface defined elsewhere; its IP is used as the local bind address.
extern EthernetNetIf eth;
// Socket of the query in flight: created and deleted by getHostByName(),
// read from inside the isr_dns() event callback.
static UDPSocket *dns;
// Address stored by the isr_dns() callback; 0 means "no answer yet".
static volatile unsigned long dnsaddr;
// Id of the most recent query: advanced by createDnsRequest() and compared
// against the reply header in getDnsResponse().
static volatile unsigned long id;

int createDnsRequest (char *name, char *buf) {
    struct DnsHeader *dnsHeader;
    struct DnsQuestionEnd *dnsEnd;
    int len, num;

    id ++;
    dnsHeader = (struct DnsHeader*)buf;
    dnsHeader->id = htons(id);
    dnsHeader->flags = htons(0x100);
    dnsHeader->questions = htons(1);
    dnsHeader->answers = 0;
    dnsHeader->authorities = 0;
    dnsHeader->additional = 0;

    len = sizeof(struct DnsHeader);
    while ((num = (int)strchr(name, '.')) != NULL) {
        num = num - (int)name;
        buf[len] = num;
        len ++;
        strncpy(&buf[len], name, num); 
        name = name + num + 1;
        len = len + num;
    }

    if ((num = strlen(name)) != NULL) {
        buf[len] = num;
        len ++; 
        strncpy(&buf[len], name, num); 
        len = len + num;
    }
    buf[len] = 0;
    len ++; 

    dnsEnd = (struct DnsQuestionEnd *)&buf[len];
    dnsEnd->type = htons(DNS_QUERY_A);
    dnsEnd->clas = htons(DNS_CLASS_IN);

    return len + sizeof(struct DnsQuestionEnd);
}

int getDnsResponse (const char *buf, int len, uint32_t *addr) {
    int i;
    struct DnsHeader *dnsHeader;
    struct DnsAnswer *dnsAnswer;

    // header
    dnsHeader = (struct DnsHeader*)buf;
    if (ntohs(dnsHeader->id) != id || (ntohs(dnsHeader->flags) & 0x800f) != 0x8000) {
        return -1;
    }

    // skip question
    for (i = sizeof(struct DnsHeader); buf[i] && i < len; i ++);
    i = i + 1 + sizeof(struct DnsQuestionEnd);

    // answer
    while (i < len) {
        dnsAnswer = (struct DnsAnswer*)&buf[i];

        if (dnsAnswer->clas != htons(DNS_CLASS_IN)) {
            return -1;
        }

        i = i + sizeof(struct DnsAnswer);
        if (dnsAnswer->type == htons(DNS_QUERY_A)) {
            // A record
            *addr = ((uint32_t)buf[i] << 24) + ((uint32_t)buf[i + 1] << 16) + ((uint32_t)buf[i + 2] << 8) + (uint32_t)buf[i + 3];
            return 0;
        }
        // next answer
        i = i + dnsAnswer->length;
    }

    return -1;
}

/**
 * UDP event callback for the resolver socket.
 *
 * On UDPSOCKET_READABLE it reads one datagram and, if it parses as a reply
 * to the current query, publishes the address through the file-global
 * dnsaddr polled by getHostByName(). Other events are ignored.
 */
void isr_dns (UDPSocketEvent e) {
    char buf[512];
    Host dsthost;
    uint32_t addr;
    int len;

    if (e != UDPSOCKET_READABLE) {
        return;
    }

    // receive response
    len = dns->recvfrom(buf, sizeof(buf), &dsthost);
#ifdef DEBUG
    for (int i = 0; i < len; i ++) {
        printf(" %02x", (unsigned char)buf[i]);
    }
    puts("\r");
#endif
    // Compare as int: a negative recvfrom() error must not be promoted to a
    // huge unsigned value and slip past the size check.
    if (len >= (int)sizeof(struct DnsHeader) && getDnsResponse(buf, len, &addr) == 0) {
        // Parse into a local, then store; avoids casting away volatile.
        dnsaddr = addr;
    }
}

/**
 * Resolve @p name to an IPv4 address by querying @p nameserver over UDP.
 *
 * Blocks for up to DNS_TIMEOUT ms while polling the network stack. The
 * query is re-sent every 500 polling iterations (about 5 s at 10 ms each).
 *
 * @param nameserver  DNS server to query (port DNS_PORT)
 * @param name        host name; "localhost" is answered locally
 * @param addr        on success receives the address with the first octet in
 *                    the most significant byte (127.0.0.1 -> 0x7f000001)
 * @return 0 on success, -1 on bind failure, an over-long name, or timeout
 */
int getHostByName (IpAddr nameserver, const char *name, uint32_t *addr) {
    UDPSocketErr err;
    Host myhost, dnshost;
    char buf[100];
    int i, len;

    // localhost -> 127.0.0.1, in the same byte layout getDnsResponse() uses
    if (!strcmp(name, "localhost")) {
        *addr = 0x7f000001;
        return 0;
    }

    // createDnsRequest() does no bounds checking; the encoded query needs
    // header + strlen(name) + 2 + question trailer bytes.
    if (strlen(name) + 2 + sizeof(struct DnsHeader) + sizeof(struct DnsQuestionEnd) > sizeof(buf)) {
        return -1;
    }

    dnsaddr = 0;
    dns = new UDPSocket;
    dns->setOnEvent(isr_dns);

    // bind
    myhost.setIp(eth.getIp());
    myhost.setPort(DNS_SRC_PORT);
    err = dns->bind(myhost);
    if (err != UDPSOCKET_OK) goto exit;

    // send request
    dnshost.setIp(nameserver);
    dnshost.setPort(DNS_PORT);
    len = createDnsRequest((char*)name, buf);
#ifdef DEBUG
    for (int i = 0; i < len; i ++) {
        printf(" %02x", (unsigned char)buf[i]);
    }
    puts("\r");
#endif
    dns->sendto(buf, len, &dnshost);

    // Wait for a response; isr_dns() sets dnsaddr from within Net::poll().
    for (i = 0; i < DNS_TIMEOUT / 10 && !dnsaddr; i ++) {
        if (i % 500 == 499) {
            // retry
            dns->sendto(buf, len, &dnshost);
        }
        Net::poll();
        wait_ms(10);
    }

exit:
    dns->resetOnEvent();
    delete dns;
    dns = NULL;

    // Checked after the loop so an answer from the final poll is not lost.
    if (dnsaddr) {
        *addr = dnsaddr;
        return 0;
    }
    return -1;
}
