Andrew Boyson / net

Dependents:   oldheating gps motorhome heating

Revision:
37:793b39683406
Parent:
36:900e24b27bfb
Child:
43:bc028d5a6424
--- a/udp/dns/dnsserver.cpp	Mon Sep 25 07:09:32 2017 +0000
+++ b/udp/dns/dnsserver.cpp	Wed Oct 04 07:51:02 2017 +0000
@@ -2,13 +2,18 @@
 #include  "dnshdr.h"
 #include "dnsname.h"
 #include     "net.h"
+#include  "action.h"
 #include     "dns.h"
 #include     "log.h"
 #include    "dhcp.h"
 #include   "slaac.h"
 #include      "io.h"
+#include     "ip4.h"
+#include     "ip6.h"
 
-#define DEBUG true
+bool DnsServerTrace = false;
+
+#define MAX_QUERIES 16
 
 //Set by 'initialise'
 static char* p;               //Position relative to DnsHdrData and is updated while both reading questions and writing answers
@@ -16,67 +21,69 @@
 static int   myFullNameLength;
 
 //Set by readQuestions and used by answerQuestions
-static int   questionTypes[4];
-static int   nodeNameTypes[4];
+static int   questionTypes[MAX_QUERIES];
+static int   nodeNameTypes[MAX_QUERIES];
+static bool  hadQueryForMe;
 static bool  mdnsUnicastReply;
 
-static int initialise(int dnsProtocol, int size)
-{    
-    if (         size > 512) { if (DEBUG) LogTimeF("DnsServer-initialise length %d too long\r\n",  size         ); return -1; }
-    if (DnsHdrQdcount >   4) { if (DEBUG) LogTimeF("DnsServer-initialise too many queries %d\r\n", DnsHdrQdcount); return -1; }
-    
-    p = DnsHdrData;
-    
-    myFullNameLength = DnsMakeFullNameFromName(dnsProtocol, NetName, sizeof(myFullName), myFullName);
-    
-    return 0;
-}
 static int readQuestions()
 {
-    char iEncodedName;
+    p = DnsHdrData;
     mdnsUnicastReply = false;
+    hadQueryForMe    = false;
     
     //Get the questions
     for (int q = 0; q < DnsHdrQdcount; q++)
     {
-        iEncodedName = DnsNameIndexFromPointer(p);
+        //Bomb out if we are tending to overrun the buffer
+        if (q >= MAX_QUERIES)
+        {
+            if (DnsServerTrace) LogTimeF("DnsServer-readQuestions - exceeded MAX_QUERIES\r\n");
+            break;
+        }
+        
+        //Bomb out if we are tending to overrun the buffer
+        if (p - DnsHdrPacket > 512)
+        {
+            if (DnsServerTrace) LogTimeF("DnsServer-readQuestions - overrunning the buffer\r\n");
+            return -1;
+        }
+        
+        //Node name
         int nameLength = DnsNameLength(p);
-        if (!nameLength) { if (DEBUG) LogTimeF("DnsServer-readQuestions namelength is zero\r\n"); return -1; }
+        if (!nameLength)
+        {
+            if (DnsServerTrace) LogTimeF("DnsServer-readQuestions namelength is zero\r\n");
+            return -1;
+        }
         nodeNameTypes[q] = DNS_RECORD_NONE;
-        if (!nodeNameTypes[q] && DnsNameCompare   (p, myFullName)      ) nodeNameTypes[q] = DNS_RECORD_PTR;  //rtc.local
+        if (!nodeNameTypes[q] && DnsNameComparePtr(p, myFullName)      ) nodeNameTypes[q] = DNS_RECORD_PTR;  //rtc.local
         if (!nodeNameTypes[q] && DnsNameCompareIp4(p, DhcpLocalIp)     ) nodeNameTypes[q] = DNS_RECORD_A;    //3.1.168.192.inaddr.arpa
         if (!nodeNameTypes[q] && DnsNameCompareIp6(p, SlaacLinkLocalIp)) nodeNameTypes[q] = DNS_RECORD_AAAA; //5.8.c.c.0.1.e.f.f.f.2.3.1.1.2.0.8.7.d.0.9.0.f.1.0.7.4.0.1.0.0.2.ip6.arpa
         p += nameLength;                            //Skip past the name
-        p++ ;                                       //skip the first byte of the type
-        char recordType = *p++;                     //read the record type
-        if (*p++ & 0x80) mdnsUnicastReply = true;   //Check the 15th bit (UNICAST-RESPONSE)
-        p += 1;                                     //skip the class
         
-        questionTypes[q] = recordType;
+        //Remember had a query for me
+        if (nodeNameTypes[q]) hadQueryForMe = true;
         
-        if (DEBUG)
-        {
-            switch (recordType)
-            {
-                case DNS_RECORD_A:    LogF("  for IP4 address of "); break;
-                case DNS_RECORD_PTR:  LogF("  for name of ");        break;
-                case DNS_RECORD_AAAA: LogF("  for IP6 address of "); break;
-                default:              LogF("  for unrecognised record type %d of ", recordType); break;
-            }
-            char text[256];
-            DnsNameDecode(iEncodedName, sizeof(text), text);
-            LogF("%s\r\n", text);
-        }    
+        //Type
+        p++ ;                                       //skip the first byte of the type
+        questionTypes[q] = *p++;                    //read the record type
+        
+        //Class
+        if (*p++ & 0x80) mdnsUnicastReply = true;   //check the class 15th bit (UNICAST-RESPONSE)
+        p++;                                        //skip the class
     }
     return 0;
 }
 static int addAnswers(int dnsProtocol)
 {
+    //Strip the questions if this is MDNS
+    //if (dnsProtocol == DNS_PROTOCOL_MDNS) p = DnsHdrData; 
+    
+    //Go through each question
     DnsHdrAncount = 0;
     for (int q = 0; q < DnsHdrQdcount; q++)
-    {
-        if (DEBUG) LogF("  deal with question %d, answer %d, question %d, node %d\r\n", q, DnsHdrAncount, questionTypes[q], nodeNameTypes[q]);
-        
+    {        
         //Skip unwanted record types
         switch (questionTypes[q])
         {
@@ -85,59 +92,98 @@
             case DNS_RECORD_PTR: break;
             default: continue;
         }
-        if (!nodeNameTypes[q]) continue; //Skip queries which are not addressed to me
-        if (p - DnsHdrPacket > 500)
+        
+        //Skip queries which are not addressed to me
+        if (!nodeNameTypes[q]) continue; 
+        
+        //Bomb out if we are tending to overrun the buffer
+        if (p - DnsHdrPacket > 512)
         {
-            LogTimeF("DnsServer-addAnswers Ip4 query reply is getting too big\r\n");
+            if (DnsServerTrace) LogTimeF("DnsServer-addAnswers - reply is getting too big\r\n");
             return -1;
         }
+        
+        //Count the number of answers
         DnsHdrAncount++;
-        int lenPayload = 0;
-        char* pPayload = 0;
+        
+        //Log what we are doing
+        if (DnsServerTrace)
+        {
+            switch (questionTypes[q])
+            {
+                case DNS_RECORD_A:    Log("  replied with my IPv4 address\r\n"); break;
+                case DNS_RECORD_AAAA: Log("  replied with my IPv6 address\r\n"); break;
+                case DNS_RECORD_PTR:  Log("  replied with my name\r\n");         break;
+            }
+        }
+        
+        //Encode the node name
         switch (questionTypes[q])
         {
             case DNS_RECORD_A:
-                if (DEBUG) Log("  replied with my IPv4 address\r\n");
-                DnsNameEncode(myFullName, &p);
-                  pPayload = (char*)&DhcpLocalIp;
-                lenPayload = 4;
+                DnsNameEncodePtr(myFullName, &p);
                 break;
-                
             case DNS_RECORD_AAAA:
-                if (DEBUG) Log("  replied with my IPv6 address\r\n");
-                DnsNameEncode(myFullName, &p);
-                  pPayload = SlaacLinkLocalIp;
-                lenPayload = 16;
+                DnsNameEncodePtr(myFullName, &p);
                 break;
-                
             case DNS_RECORD_PTR:
-                if (DEBUG) Log("  replied with my name\r\n");
                 if (nodeNameTypes[q] == DNS_RECORD_A   ) DnsNameEncodeIp4(DhcpLocalIp,      &p);
                 if (nodeNameTypes[q] == DNS_RECORD_AAAA) DnsNameEncodeIp6(SlaacLinkLocalIp, &p);
-                  pPayload = myFullName;
-                lenPayload = myFullNameLength;
                 break;
         }
-        char mdns = dnsProtocol == DNS_PROTOCOL_MDNS ? 0x80 : 0; //Set the 15th bit (CACHE_FLUSH) of the class to 1 if MDNS
+        
+        //Add the type
         *p++ =    0; *p++ = questionTypes[q];                    //16 bit Type
+        
+        //Add the class
+        char mdns = dnsProtocol == DNS_PROTOCOL_MDNS ? 0x80 : 0; //Set the 15th bit (CACHE_FLUSH) of the class to 1 if MDNS
         *p++ = mdns; *p++ = 1;                                   //16 bit Class LSB QCLASS_IN = 1 - internet
+        
+        //Add the TTL
         *p++ =    0; *p++ = 0; *p++ = 4; *p++ = 0;               //32 bit TTL seconds - 1024
-        *p++ =    0; *p++ = lenPayload;                          //16 bit length in bytes
-        memcpy(p, pPayload, lenPayload);                         //Copy the payload
-        p += lenPayload;                                         //Adjust the pointer to the next character after the payload
-
+        
+        
+        //Add the 16 bit payload length
+        switch (questionTypes[q])
+        {
+            case DNS_RECORD_A:    *p++ = 0; *p++ =  4;                   break;
+            case DNS_RECORD_AAAA: *p++ = 0; *p++ = 16;                   break;
+            case DNS_RECORD_PTR:  *p++ = 0; *p++ = myFullNameLength + 2; break; //name length plus one byte for initial length plus one byte for terminating zero
+        }
+        
+        //Add the payload
+        switch (questionTypes[q])
+        {
+            case DNS_RECORD_A:    memcpy(p, &DhcpLocalIp,      4); p +=  4; break;
+            case DNS_RECORD_AAAA: memcpy(p, SlaacLinkLocalIp, 16); p += 16; break;
+            case DNS_RECORD_PTR:  DnsNameEncodePtr(myFullName, &p);         break;
+        }
     }
     return 0;
 }
 
-int DnsServerHandleQuery(int dnsProtocol, int *pSize) //Received an mdns or llmnr query on port 5353 or 5355
-{
-    if (DEBUG) DnsHdrLog("DnsServer received query", dnsProtocol);
+int DnsServerHandleQuery(void (*traceback)(void), int dnsProtocol, int *pSize) //Received an mdns or llmnr query on port 5353 or 5355
+{    
+    if (*pSize > 512)
+    {
+        if (DnsServerTrace) LogTimeF("DnsServerHandleQuery length %d too long\r\n",  *pSize );
+        return DO_NOTHING;
+    }
+    
+    myFullNameLength = DnsMakeFullNameFromName(dnsProtocol, NetName, sizeof(myFullName), myFullName);
     
-    if (initialise(dnsProtocol, *pSize)) return DO_NOTHING;
-    if (readQuestions()                ) return DO_NOTHING;
-    if (addAnswers(dnsProtocol)        ) return DO_NOTHING;
-    if (!DnsHdrAncount)                  return DO_NOTHING;
+    if (readQuestions()) return DO_NOTHING;
+    if (!hadQueryForMe ) return DO_NOTHING;
+    
+    if (DnsServerTrace)
+    {
+        LogTimeF("DnsServer received query\r\n");
+        if (NetTraceBack) traceback();
+        DnsHdrLog(dnsProtocol);
+    }
+    
+    if (addAnswers(dnsProtocol)) return DO_NOTHING;
+    if (!DnsHdrAncount)          return DO_NOTHING;
     
     DnsHdrIsReply         = true;
     DnsHdrIsAuthoritative = true;
@@ -145,8 +191,11 @@
     
     *pSize = p - DnsHdrPacket;
     
-    if (DEBUG) DnsHdrLog("DnsServer sending reply", dnsProtocol);
+    if (DnsServerTrace) DnsHdrLog(dnsProtocol);
 
-    if (dnsProtocol == DNS_PROTOCOL_MDNS && !mdnsUnicastReply) return MULTICAST_MDNS;
-    return UNICAST;
+    int dest;
+    if (dnsProtocol == DNS_PROTOCOL_MDNS && !mdnsUnicastReply) dest = MULTICAST_MDNS;
+    else                                                       dest =   UNICAST;
+    
+    return ActionMakeFromDestAndTrace(dest, NetTraceForward && DnsServerTrace);
 }