Andrew Boyson / net

Dependents:   oldheating gps motorhome heating

udp/dns/dnsserver.cpp

Committer:
andrewboyson
Date:
2017-10-19
Revision:
43:bc028d5a6424
Parent:
37:793b39683406
Child:
44:83ce5ace337b

File content as of revision 43:bc028d5a6424:

#include    "mbed.h"
#include  "dnshdr.h"
#include "dnsname.h"
#include     "net.h"
#include  "action.h"
#include     "dns.h"
#include     "log.h"
#include    "dhcp.h"
#include   "slaac.h"
#include      "io.h"
#include     "ip4.h"
#include     "ip6.h"

//Capacity of the per-question tables below; readQuestions never records more than this many questions
#define MAX_QUERIES 16

//Set true to trace the queries handled by this server
bool DnsServerTrace = false;

//Working state prepared by DnsServerHandleQuery before the questions are read:
//  p                - cursor into the packet; starts at DnsHdrData, is moved past the questions while reading them
//                     and past each answer while writing
//  myFullName       - this node's name, with the domain added when the protocol needs it
//  myFullNameLength - the length returned by DnsMakeFullNameFromName for myFullName
static char *p;
static char  myFullName[100];
static int   myFullNameLength;

//Per-question results written by readQuestions and read by addAnswers
static int  questionTypes[MAX_QUERIES]; //record type asked for by each question
static int  nodeNameTypes[MAX_QUERIES]; //which of my names each question asked about (a DNS_RECORD_ value, or DNS_RECORD_NONE)
static bool hadQueryForMe;              //at least one question named this node
static bool mdnsUnicastReply;           //at least one question set the unicast-response bit of its class

//Read the question section, which starts at DnsHdrData, leaving p just past the last question.
//For each question q it records:
//  questionTypes[q] - the 16 bit record type asked for
//  nodeNameTypes[q] - which of my names was asked about: DNS_RECORD_PTR for my full name, DNS_RECORD_A for my IPv4
//                     reverse name, DNS_RECORD_AAAA for my IPv6 link local reverse name, otherwise DNS_RECORD_NONE
//It also sets hadQueryForMe if any question named this node, and mdnsUnicastReply if any question had the top bit
//of its class (the mDNS UNICAST-RESPONSE bit) set.
//Returns 0 on success or -1 if the question section cannot be handled.
static int readQuestions()
{
    p = DnsHdrData;
    mdnsUnicastReply = false;
    hadQueryForMe    = false;
    
    //Only MAX_QUERIES questions can be recorded. Stopping part way would leave p in the middle of the question
    //section (so answers would be written over the unread questions) and the unread questions would have no
    //entries in questionTypes and nodeNameTypes, so refuse the whole packet instead.
    if (DnsHdrQdcount > MAX_QUERIES)
    {
        if (DnsServerTrace) LogTimeF("DnsServer-readQuestions - exceeded MAX_QUERIES\r\n");
        return -1;
    }
    
    //Get the questions
    for (int q = 0; q < DnsHdrQdcount; q++)
    {
        //Bomb out if we are tending to overrun the buffer
        if (p - DnsHdrPacket > 512)
        {
            if (DnsServerTrace) LogTimeF("DnsServer-readQuestions - overrunning the buffer\r\n");
            return -1;
        }
        
        //Node name
        int nameLength = DnsNameLength(p);
        if (!nameLength)
        {
            if (DnsServerTrace) LogTimeF("DnsServer-readQuestions namelength is zero\r\n");
            return -1;
        }
        nodeNameTypes[q] = DNS_RECORD_NONE;
        if (!nodeNameTypes[q] && DnsNameComparePtr(p, myFullName)      ) nodeNameTypes[q] = DNS_RECORD_PTR;  //rtc.local
        if (!nodeNameTypes[q] && DnsNameCompareIp4(p, DhcpLocalIp)     ) nodeNameTypes[q] = DNS_RECORD_A;    //3.1.168.192.inaddr.arpa
        if (!nodeNameTypes[q] && DnsNameCompareIp6(p, SlaacLinkLocalIp)) nodeNameTypes[q] = DNS_RECORD_AAAA; //5.8.c.c.0.1.e.f.f.f.2.3.1.1.2.0.8.7.d.0.9.0.f.1.0.7.4.0.1.0.0.2.ip6.arpa
        p += nameLength;                            //Skip past the name
        
        //Remember had a query for me
        if (nodeNameTypes[q]) hadQueryForMe = true;
        
        //The type and class that follow the name take 4 more bytes
        if (p - DnsHdrPacket + 4 > 512)
        {
            if (DnsServerTrace) LogTimeF("DnsServer-readQuestions - overrunning the buffer\r\n");
            return -1;
        }
        
        //Type - 16 bits, most significant byte first. Both bytes are read so that a type above 255 is not
        //mistaken for whichever type happens to share its low byte.
        questionTypes[q] = ((unsigned char)p[0] << 8) | (unsigned char)p[1];
        p += 2;
        
        //Class - 16 bits; the top bit is the mDNS UNICAST-RESPONSE flag
        if ((unsigned char)p[0] & 0x80) mdnsUnicastReply = true;
        p += 2;
    }
    return 0;
}
static int addAnswers(int dnsProtocol)
{
    //Strip the questions if this is MDNS
    //if (dnsProtocol == DNS_PROTOCOL_MDNS) p = DnsHdrData; 
    
    //Go through each question
    DnsHdrAncount = 0;
    for (int q = 0; q < DnsHdrQdcount; q++)
    {        
        //Skip unwanted record types
        switch (questionTypes[q])
        {
            case DNS_RECORD_A:
            case DNS_RECORD_AAAA:
            case DNS_RECORD_PTR: break;
            default: continue;
        }
        
        //Skip queries which are not addressed to me
        if (!nodeNameTypes[q]) continue; 
        
        //Bomb out if we are tending to overrun the buffer
        if (p - DnsHdrPacket > 512)
        {
            if (DnsServerTrace) LogTimeF("DnsServer-addAnswers - reply is getting too big\r\n");
            return -1;
        }
        
        //Count the number of answers
        DnsHdrAncount++;
                
        //Encode the node name
        switch (questionTypes[q])
        {
            case DNS_RECORD_A:
                DnsNameEncodePtr(myFullName, &p);
                break;
            case DNS_RECORD_AAAA:
                DnsNameEncodePtr(myFullName, &p);
                break;
            case DNS_RECORD_PTR:
                if (nodeNameTypes[q] == DNS_RECORD_A   ) DnsNameEncodeIp4(DhcpLocalIp,      &p);
                if (nodeNameTypes[q] == DNS_RECORD_AAAA) DnsNameEncodeIp6(SlaacLinkLocalIp, &p);
                break;
        }
        
        //Add the type
        *p++ =    0; *p++ = questionTypes[q];                    //16 bit Type
        
        //Add the class
        char mdns = dnsProtocol == DNS_PROTOCOL_MDNS ? 0x80 : 0; //Set the 15th bit (CACHE_FLUSH) of the class to 1 if MDNS
        *p++ = mdns; *p++ = 1;                                   //16 bit Class LSB QCLASS_IN = 1 - internet
        
        //Add the TTL
        *p++ =    0; *p++ = 0; *p++ = 4; *p++ = 0;               //32 bit TTL seconds - 1024
        
        
        //Add the 16 bit payload length
        switch (questionTypes[q])
        {
            case DNS_RECORD_A:    *p++ = 0; *p++ =  4;                   break;
            case DNS_RECORD_AAAA: *p++ = 0; *p++ = 16;                   break;
            case DNS_RECORD_PTR:  *p++ = 0; *p++ = myFullNameLength + 2; break; //name length plus one byte for initial length plus one byte for terminating zero
        }
        
        //Add the payload
        switch (questionTypes[q])
        {
            case DNS_RECORD_A:    memcpy(p, &DhcpLocalIp,      4); p +=  4; break;
            case DNS_RECORD_AAAA: memcpy(p, SlaacLinkLocalIp, 16); p += 16; break;
            case DNS_RECORD_PTR:  DnsNameEncodePtr(myFullName, &p);         break;
        }
    }
    return 0;
}

//Handle an mdns or llmnr query received on port 5353 or 5355.
//The reply is built in place in the same packet; on success *pSize is set to the reply length and the action says
//where to send it. DO_NOTHING is returned when the packet is too long, cannot be parsed, is not for this node, or
//produces no answers.
int DnsServerHandleQuery(void (*traceback)(void), int dnsProtocol, int *pSize)
{
    //Packets longer than 512 bytes are ignored
    if (*pSize > 512)
    {
        if (!DnsServerTrace) return DO_NOTHING;
        if (NetTraceNewLine) Log("\r\n");
        LogTimeF("DnsServer length %d too long\r\n",  *pSize );
        if (NetTraceStack) traceback();
        return DO_NOTHING;
    }
    
    //Work out the name this node answers to under the protocol in use
    myFullNameLength = DnsMakeFullNameFromName(dnsProtocol, NetName, sizeof(myFullName), myFullName);
    
    //Carry on only if the questions parse and at least one of them is about this node
    if (readQuestions() != 0 || !hadQueryForMe) return DO_NOTHING;
    
    if (DnsServerTrace)
    {
        if (NetTraceNewLine) Log("\r\n");
        LogTimeF("DnsServer received query\r\n");
        if (NetTraceStack) traceback();
        DnsHdrLog(dnsProtocol);
    }
    
    //Append the answers; give up if that fails or nothing was answered
    if (addAnswers(dnsProtocol) != 0 || DnsHdrAncount == 0) return DO_NOTHING;
    
    //Turn the packet round into an authoritative reply and report its new length
    DnsHdrIsReply         = true;
    DnsHdrIsAuthoritative = true;
    DnsHdrWrite();
    *pSize = p - DnsHdrPacket;
    
    if (DnsServerTrace) DnsHdrLog(dnsProtocol);
    
    //MDNS replies are multicast unless a question asked for a unicast response; everything else is unicast
    bool multicast = dnsProtocol == DNS_PROTOCOL_MDNS && !mdnsUnicastReply;
    return ActionMakeFromDestAndTrace(multicast ? MULTICAST_MDNS : UNICAST, NetTraceStack && DnsServerTrace);
}