Andrew Boyson / net

Dependents:   oldheating gps motorhome heating

Revision:
61:aad055f1b0d1
Parent:
59:e0e556c8bd46
Child:
128:79052cb4a41c
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/udp/dns/dnsserver.c	Thu Jan 11 17:38:21 2018 +0000
@@ -0,0 +1,223 @@
+#include <stdbool.h>
+#include <string.h>
+
+#include  "dnshdr.h"
+#include "dnsname.h"
+#include     "net.h"
+#include  "action.h"
+#include     "dns.h"
+#include     "log.h"
+#include    "dhcp.h"
+#include   "slaac.h"
+#include     "ip4.h"
+#include     "ip6.h"
+
+bool DnsServerTrace = false;
+
+#define RECORD_NONE        0
+#define RECORD_PTR4        1
+#define RECORD_PTR6_LOCAL  2
+#define RECORD_PTR6_GLOBAL 3
+#define RECORD_A           4
+#define RECORD_AAAA_LOCAL  5
+#define RECORD_AAAA_GLOBAL 6
+
+#define MAX_ANSWERS 4
+
+//Set by 'initialise'
+static char* p;               //Position relative to DnsHdrData and is updated while both reading questions and writing answers
+static char  myFullName4[100]; //The name, adjusted to include the domain if needed by the protocol, used when reading and when writing
+static int   myFullName4Length;
+static char  myFullName6[100]; //The name, adjusted to include the domain if needed by the protocol, used when reading and when writing
+static int   myFullName6Length;
+
+//Set by readQuestions and used by answerQuestions
+static int   answers[MAX_ANSWERS];
+static int   answerCount = 0;
+static bool  mdnsUnicastReply;
+
+static int readQuestions()
+{
+    p = DnsHdrData;
+    mdnsUnicastReply = false;
+    
+    //Get the questions
+    answerCount = 0;
+    for (int i = 0; i < DnsHdrQdcount; i++)
+    {
+        //Bomb out if there are too many answers
+        if (answerCount >= MAX_ANSWERS)
+        {
+            if (DnsServerTrace) LogTimeF("DnsServer-readQuestions - exceeded %d answers\r\n", MAX_ANSWERS);
+            break;
+        }
+        
+        //Bomb out if we are tending to overrun the buffer
+        if (p - DnsHdrData > DnsHdrDataLength)
+        {
+            if (DnsServerTrace) LogTimeF("DnsServer-readQuestions - overrunning the buffer of %d bytes\r\n", DnsHdrDataLength);
+            return -1;
+        }
+        
+        //Node name
+        int nameLength = DnsNameLength(p);
+        if (!nameLength)
+        {
+            if (DnsServerTrace) LogTimeF("DnsServer-readQuestions namelength is zero\r\n");
+            return -1;
+        }
+        bool nodeIsName4 = DnsNameComparePtr(p, myFullName4);
+        bool nodeIsName6 = DnsNameComparePtr(p, myFullName6);
+        bool nodeIsAddr4 = DnsNameCompareIp4(p, DhcpLocalIp);
+        bool nodeIsLocl6 = DnsNameCompareIp6(p, SlaacLinkLocalIp);
+        bool nodeIsGlob6 = DnsNameCompareIp6(p, SlaacGlobalIp);
+        p += nameLength;                          //Skip past the name
+                
+        //Type
+        p++ ;                                     //skip the first byte of the type
+        char recordType = *p++;                   //read the record type
+        
+        //Class
+        if (*p++ & 0x80) mdnsUnicastReply = true; //check the class 15th bit (UNICAST-RESPONSE)
+        p++;                                      //skip the class
+        
+        //Handle the questions
+        if (nodeIsName4 && recordType == DNS_RECORD_A   )   answers[answerCount++] = RECORD_A;
+        if (nodeIsName6 && recordType == DNS_RECORD_AAAA) { answers[answerCount++] = RECORD_AAAA_LOCAL; answers[answerCount++] = RECORD_AAAA_GLOBAL; }
+        if (nodeIsAddr4 && recordType == DNS_RECORD_PTR )   answers[answerCount++] = RECORD_PTR4;
+        if (nodeIsLocl6 && recordType == DNS_RECORD_PTR )   answers[answerCount++] = RECORD_PTR6_LOCAL;
+        if (nodeIsGlob6 && recordType == DNS_RECORD_PTR )   answers[answerCount++] = RECORD_PTR6_GLOBAL;
+    }
+    return 0;
+}
+static int addAnswers(int dnsProtocol)
+{    
+    //Go through each answer
+    DnsHdrAncount = 0;
+    for (int i = 0; i < answerCount; i++)
+    {                
+        //Bomb out if we are tending to overrun the buffer
+        if (p - DnsHdrData > DnsHdrDataLength)
+        {
+            if (DnsServerTrace) LogTimeF("DnsServer-addAnswers - reply is getting too big\r\n");
+            return -1;
+        }
+                
+        //Encode the node name
+        switch (answers[i])
+        {
+            case RECORD_A:           DnsNameEncodePtr(myFullName4,      &p); break;
+            case RECORD_AAAA_LOCAL:  DnsNameEncodePtr(myFullName6,      &p); break;  
+            case RECORD_AAAA_GLOBAL: DnsNameEncodePtr(myFullName6,      &p); break;
+            case RECORD_PTR4:        DnsNameEncodeIp4(DhcpLocalIp,      &p); break;
+            case RECORD_PTR6_LOCAL:  DnsNameEncodeIp6(SlaacLinkLocalIp, &p); break;
+            case RECORD_PTR6_GLOBAL: DnsNameEncodeIp6(SlaacGlobalIp,    &p); break;
+        }
+        
+        //Add the 16 bit type
+        *p++ = 0;
+        switch (answers[i])
+        {
+            case RECORD_A:           *p++ = DNS_RECORD_A;    break;
+            case RECORD_AAAA_LOCAL:  *p++ = DNS_RECORD_AAAA; break;
+            case RECORD_AAAA_GLOBAL: *p++ = DNS_RECORD_AAAA; break;
+            case RECORD_PTR4:        *p++ = DNS_RECORD_PTR;  break;
+            case RECORD_PTR6_LOCAL:  *p++ = DNS_RECORD_PTR;  break;
+            case RECORD_PTR6_GLOBAL: *p++ = DNS_RECORD_PTR;  break;
+        }
+        
+        //Add the class
+        char mdns = dnsProtocol == DNS_PROTOCOL_MDNS ? 0x80 : 0; //Set the 15th bit (CACHE_FLUSH) of the class to 1 if MDNS
+        *p++ = mdns; *p++ = 1;                                   //16 bit Class LSB QCLASS_IN = 1 - internet
+        
+        //Add the TTL
+        *p++ =    0; *p++ = 0; *p++ = 4; *p++ = 0;               //32 bit TTL seconds - 1024
+        
+        
+        //Add the 16 bit payload length
+        *p++ = 0;
+        switch (answers[i])
+        {
+            case RECORD_A:           *p++ =  4;                    break;
+            case RECORD_AAAA_LOCAL:  *p++ = 16;                    break;
+            case RECORD_AAAA_GLOBAL: *p++ = 16;                    break;
+            case RECORD_PTR4:        *p++ = myFullName4Length + 2; break; //add a byte for the initial length and another for the terminating zero length
+            case RECORD_PTR6_LOCAL:  *p++ = myFullName6Length + 2; break;
+            case RECORD_PTR6_GLOBAL: *p++ = myFullName6Length + 2; break;
+        }
+        
+        //Add the payload
+        switch (answers[i])
+        {
+            case RECORD_A:           memcpy(p, &DhcpLocalIp,      4); p +=  4; break;
+            case RECORD_AAAA_LOCAL:  memcpy(p, SlaacLinkLocalIp, 16); p += 16; break;
+            case RECORD_AAAA_GLOBAL: memcpy(p, SlaacGlobalIp,    16); p += 16; break;
+            case RECORD_PTR4:        DnsNameEncodePtr(myFullName4, &p);        break;
+            case RECORD_PTR6_LOCAL:  DnsNameEncodePtr(myFullName6, &p);        break;
+            case RECORD_PTR6_GLOBAL: DnsNameEncodePtr(myFullName6, &p);        break;
+        }
+        //Increment the number of good answers to send
+        DnsHdrAncount++;
+    }
+    return 0;
+}
+
+int DnsServerHandleQuery(void (*traceback)(void), int dnsProtocol, void* pPacketTx, int *pSizeTx) //Received an mdns or llmnr query on port 5353 or 5355
+{            
+    myFullName4Length = DnsMakeFullNameFromName(dnsProtocol, NetName4, sizeof(myFullName4), myFullName4);
+    myFullName6Length = DnsMakeFullNameFromName(dnsProtocol, NetName6, sizeof(myFullName6), myFullName6);
+    
+    if (readQuestions()) return DO_NOTHING;
+    if (!answerCount) return DO_NOTHING;
+    
+    if (DnsServerTrace || NetTraceHostGetMatched())
+    {
+        if (NetTraceNewLine) Log("\r\n");
+        LogTimeF("DnsServer received query\r\n");
+        if (NetTraceStack) traceback();
+        DnsHdrLog(dnsProtocol);
+    }
+    
+    char* pRx = DnsHdrData;
+    char* pEndRx = p;
+    int qdcount = DnsHdrQdcount;
+    int nscount = DnsHdrNscount;
+    int arcount = DnsHdrArcount;
+    
+    DnsHdrSetup(pPacketTx, *pSizeTx);
+    p = DnsHdrData; 
+    
+    //Add the questions if this is not MDNS
+    if (dnsProtocol == DNS_PROTOCOL_MDNS)
+    {
+        DnsHdrQdcount = 0;
+        DnsHdrNscount = 0;
+        DnsHdrArcount = 0;
+    }
+    else
+    {
+        DnsHdrQdcount = qdcount;
+        DnsHdrNscount = nscount;
+        DnsHdrArcount = arcount;
+        while (pRx < pEndRx) *p++ = *pRx++;
+    }
+    
+
+    if (addAnswers(dnsProtocol)) return DO_NOTHING;
+    
+    DnsHdrIsReply          = true;
+    DnsHdrIsAuthoritative  = false;
+    DnsHdrIsRecursiveQuery = false;
+    
+    DnsHdrWrite();
+    
+    *pSizeTx = p - DnsHdrPacket;
+    
+    if (DnsServerTrace || NetTraceHostGetMatched()) DnsHdrLog(dnsProtocol);
+
+    int dest;
+    if (dnsProtocol == DNS_PROTOCOL_MDNS && !mdnsUnicastReply) dest = MULTICAST_MDNS;
+    else                                                       dest =   UNICAST;
+    
+    return ActionMakeFromDestAndTrace(dest, NetTraceStack && DnsServerTrace || NetTraceHostGetMatched());
+}