// MyNetDnsRequest.cpp 2012/4/19
#include "mbed.h"
#include "MyNetDnsRequest.h"
#include "UDPSocket.h"
#include <string>
#include "dnsname.h"
#include "W5200NetIf.h"
//#define __DEBUG
#include "dbg/dbg.h"

// DBG2(fmt, ...): debug trace prefixed with the object address, source line
// and enclosing function. It references `this`, so it is usable only inside
// member functions. Both expansions are single statements WITHOUT a trailing
// semicolon, so `DBG2(...);` is safe in an unbraced if/else.
#ifdef __DEBUG
#define DBG2(...) do{ DebugStream::debug("%p %d %s ", this,__LINE__,__PRETTY_FUNCTION__); DebugStream::debug(__VA_ARGS__); } while(0)
#else
#define DBG2(...) do{} while(0)
#endif //__DEBUG

// Uncomment the next line to enable PRINT_FUNC() tracing and the hex dumps
// guarded by DEBUG elsewhere in this file.
//#define DEBUG

#if defined(DEBUG)
#include "Utils.h"
// Prints the object address, source line and enclosing function; uses
// `this`, so it is only valid inside member functions.
#define PRINT_FUNC() printf("%p %d:%s\n", this, __LINE__, __PRETTY_FUNCTION__)
#else
// Tracing disabled: expands to nothing.
#define PRINT_FUNC()
#endif // DEBUG


// Builds a request that resolves `hostname`; the query itself is issued
// later, from poll(), while the request is in MYNETDNS_START.
MyNetDnsRequest::MyNetDnsRequest(const char* hostname)
    : NetDnsRequest(hostname)
    , m_state(MYNETDNS_START)
    , m_cbFired(false)
    , m_closing(false)
    , m_udp(NULL) // created on demand by resolve()
{
    PRINT_FUNC();
}

// Builds a request for the host described by `pHost`; the query itself is
// issued later, from poll(), while the request is in MYNETDNS_START.
MyNetDnsRequest::MyNetDnsRequest(Host* pHost)
    : NetDnsRequest(pHost)
    , m_state(MYNETDNS_START)
    , m_cbFired(false)
    , m_closing(false)
    , m_udp(NULL) // created on demand by resolve()
{
    PRINT_FUNC();
}

// Releases the UDP socket owned by this request (if resolve() ever created one).
MyNetDnsRequest::~MyNetDnsRequest() {
    PRINT_FUNC();
    // Deleting a null pointer is a no-op, so no explicit check is required.
    delete m_udp;
}

// UDP socket event handler: reads one datagram and, if it answers the
// query most recently sent by query() (matching 16-bit ID), parses it and
// moves the request to MYNETDNS_OK or MYNETDNS_NOTFOUND.
// @param e  socket event (not inspected here).
void MyNetDnsRequest::callback(UDPSocketEvent e)
{
    PRINT_FUNC();
    DBG2("m_id[]=%02x:%02x\n", m_id[0], m_id[1]);
    uint8_t buf[512];
    Host host;
    int len = m_udp->recvfrom((char*)buf, sizeof(buf), &host);
    // A negative length is a receive error, and anything shorter than the
    // 12-byte DNS header cannot be a reply; without this check the ID compare
    // and the parser would read uninitialized stack bytes.
    if (len < 12) {
        return;
    }
    // Drop datagrams whose ID is not the one of our latest query (e.g. a late
    // answer to an attempt that was superseded by a retry).
    if (memcmp(buf+0, m_id, 2) != 0) {
        return;
    }
    // A matching reply arrived: stop the timeout timer so poll() does not
    // treat the request as timed out and re-send it.
    m_interval.stop();
    int rcode = response(buf, len);
    if (rcode == 0) {
        m_state = MYNETDNS_OK;
    } else {
        m_state = MYNETDNS_NOTFOUND;
    }
}

// Parses a DNS response held in buf[0..size) and stores the last A record
// found in m_ip.
// @param buf   received datagram (untrusted network data).
// @param size  number of valid bytes in buf.
// @return 0 if at least one A record was stored in m_ip; the header RCODE
//         (1-15) if the server reported an error; -1 if the message is
//         truncated/malformed or contains no usable A record.
int MyNetDnsRequest::response(uint8_t buf[], int size) {
    PRINT_FUNC();
#ifdef DEBUG
    printHex(buf, size);
#endif //DEBUG
    const int HEADER_SIZE = 12; // fixed-size DNS message header
    if (size < HEADER_SIZE) {
        return -1; // not even a complete header
    }
    int rcode = buf[3] & 0x0f;
    if (rcode != 0) {
        return rcode;
    }
    int qdcount = buf[4]<<8|buf[5];
    int ancount = buf[6]<<8|buf[7];
    int pos = HEADER_SIZE;
    // NOTE(review): dnsname::decode() is not given `size`; confirm it cannot
    // read past the received data when following labels/pointers.
    while(qdcount-- > 0) {
        dnsname qname(buf);
        pos = qname.decode(pos); // qname
        if (pos < 0 || pos + 4 > size) {
            return -1; // question section runs past the received data
        }
        pos += 4; // qtype qclass
    }
    bool found = false;
    while(ancount-- > 0) {
        dnsname name(buf);
        pos = name.decode(pos); // name
        // type(2) class(2) TTL(4) rdlength(2) must all be present
        if (pos < 0 || pos + 10 > size) {
            break;
        }
        int type = buf[pos]<<8|buf[pos+1];
        pos += 8; // type class TTL
        int rdlength = buf[pos]<<8|buf[pos+1]; pos += 2;
        int rdata_pos = pos;
        pos += rdlength;
        if (pos > size) {
            break; // RDATA truncated
        }
        if (type == 1 && rdlength == 4) { // A record: exactly 4 address bytes
            m_ip = IpAddr(buf[rdata_pos],buf[rdata_pos+1],buf[rdata_pos+2],buf[rdata_pos+3]);
            found = true;
        }
#ifdef DEBUG
        printf("%s", name.str.c_str());
        if (type == 1 && rdlength == 4) {
            printf(" A %d.%d.%d.%d\n", 
                buf[rdata_pos],buf[rdata_pos+1],buf[rdata_pos+2],buf[rdata_pos+3]);
        } else if (type == 5) {
            dnsname rdname(buf);
            rdname.decode(rdata_pos);
            printf(" CNAME %s\n", rdname.str.c_str());
        } else {
            printf(" TYPE:%d", type);
            printfBytes(" RDATA:", &buf[rdata_pos], rdlength);
        }
#endif //DEBUG
    }
    // A successful RCODE with no A record (e.g. CNAME only) yields no address.
    return found ? 0 : -1;
}

// Builds a standard recursive query for the A/IN record of `hostname` into
// buf and records its 16-bit ID (derived from clock()) in m_id so that
// callback() can match the reply.
// @param buf       output buffer.
// @param size      capacity of buf in bytes.
// @param hostname  name to look up.
// @return number of bytes written, or 0 if hostname is NULL or the query
//         would not fit in `size` bytes (nothing is written in that case).
int MyNetDnsRequest::query(uint8_t buf[], int size, const char* hostname) {
    PRINT_FUNC();
    const uint8_t header[] = {
        0x00,0x00,0x01,0x00, // id=0x0000 QR=0 rd=1 opcode=0 rcode=0
        0x00,0x01,0x00,0x00, // qdcount=1 ancount=0
        0x00,0x00,0x00,0x00};// nscount=0 arcount=0 
    const uint8_t tail[] = {0x00,0x01,0x00,0x01}; // qtype=A qclass=IN
    if (hostname == NULL || size < 0) {
        return 0;
    }
    // Worst-case size of the message. NOTE(review): assumes dnsname::encode
    // emits at most strlen(hostname) + 2 bytes (one length octet per label
    // plus the terminating root label) — confirm against dnsname.
    size_t needed = sizeof(header) + strlen(hostname) + 2 + sizeof(tail);
    if (needed > static_cast<size_t>(size)) {
        return 0; // would overflow buf
    }
    memcpy(buf, header, sizeof(header));
    int t = clock();
    m_id[0] = t>>8;
    m_id[1] = t;
    memcpy(buf, m_id, 2); // overwrite the placeholder ID in the header
    dnsname qname(buf);
    int pos = qname.encode(sizeof(header), (char*)hostname);
    memcpy(buf+pos, tail, sizeof(tail));
    pos += sizeof(tail);
    return pos;
}

// Sends one DNS query for `hostname` from a random local port to the DNS
// server of the default interface (8.8.8.8 if there is none), then restarts
// m_interval, which poll() uses for the reply timeout.
void MyNetDnsRequest::resolve(const char* hostname) {
    PRINT_FUNC();
    if (m_udp == NULL) {
        m_udp = new UDPSocket;
    }
    m_udp->setOnEvent(this, &MyNetDnsRequest::callback);
    // Parenthesised on purpose: '+' binds tighter than '&', so the former
    // `1024 + rand()&0x7fff` meant (1024 + rand()) & 0x7fff and could pick
    // ports below 1024, including 0. This yields 1024..33791.
    Host local(IpAddr(0,0,0,0), 1024 + (rand() & 0x7fff));
    IpAddr dns(8,8,8,8); // fallback when no default interface exists
    NetIf* pIf = Net::getDefaultIf();
    if (pIf) {
        // NOTE(review): assumes the default interface is a W5200NetIf — confirm.
        dns = ((W5200NetIf*)pIf)->m_dns;
    }
    Host server(dns, 53); // DNS
    m_udp->bind(local);
    uint8_t buf[256];                
    int size = query(buf, sizeof(buf), hostname);
#ifdef DEBUG
    printf("hostname:[%s]\n", hostname);
    printHex(buf, size);
#endif
    if (size > 0) {
        m_udp->sendto((char*)buf, size, &server);
    }
    // Start the timer even when nothing could be sent, so that poll()
    // eventually times the request out instead of waiting forever.
    m_interval.reset();
    m_interval.start();
}

void MyNetDnsRequest::poll() {
    PRINT_FUNC();
#ifdef DEBUG
    printf("%p m_state: %d, m_udp: %p\n", this, m_state, m_udp);
    wait_ms(400);
#endif //DEBUG
    switch(m_state) {
        case MYNETDNS_START:
            m_retry = 0;
            resolve(m_hostname);
            m_state = MYNETDNS_PROCESSING;
            break;
        case MYNETDNS_PROCESSING: 
            break;
        case MYNETDNS_NOTFOUND: 
            onReply(NETDNS_FOUND); 
            break;
        case MYNETDNS_ERROR: 
            onReply(NETDNS_ERROR);
            break;
        case MYNETDNS_OK:
            DBG2("m_retry=%d, m_interval=%d\n", m_retry, m_interval.read_ms());
            onReply(NETDNS_FOUND); 
            break;
    }
    if (m_interval.read_ms() > 1000) {
        m_interval.stop();
        DBG2("timeout m_retry=%d\n", m_retry);
        if (++m_retry > 1) {
            m_state = MYNETDNS_ERROR;
        } else {
            resolve(m_hostname);
            m_state = MYNETDNS_PROCESSING;
        }
    }
    if(m_closing && (m_state!=MYNETDNS_PROCESSING)) {
        NetDnsRequest::close();
    }
}

// Closes the request. While a query is still outstanding (PROCESSING), the
// close is only recorded in m_closing; poll() finishes it once the request
// leaves that state.
void MyNetDnsRequest::close() {
    PRINT_FUNC();
    if (m_state == MYNETDNS_PROCESSING) {
        m_closing = true; // defer until the outstanding query completes
        return;
    }
    NetDnsRequest::close();
}
