Andrew Boyson / net

Dependents:   oldheating gps motorhome heating

Revision:
44:83ce5ace337b
Parent:
43:bc028d5a6424
Child:
57:e0fb648acf48
--- a/udp/dns/dnsserver.cpp	Thu Oct 19 20:56:58 2017 +0000
+++ b/udp/dns/dnsserver.cpp	Sun Oct 22 17:19:17 2017 +0000
@@ -13,32 +13,41 @@
 
 bool DnsServerTrace = false;
 
-#define MAX_QUERIES 16
+#define RECORD_NONE        0
+#define RECORD_PTR4        1
+#define RECORD_PTR6_LOCAL  2
+#define RECORD_PTR6_GLOBAL 3
+#define RECORD_A           4
+#define RECORD_AAAA_LOCAL  5
+#define RECORD_AAAA_GLOBAL 6
+
+#define MAX_ANSWERS 4
 
 //Set by 'initialise'
 static char* p;               //Position relative to DnsHdrData and is updated while both reading questions and writing answers
-static char  myFullName[100]; //The name, adjusted to include the domain if needed by the protocol, used when reading and when writing
-static int   myFullNameLength;
+static char  myFullName4[100]; //The name, adjusted to include the domain if needed by the protocol, used when reading and when writing
+static int   myFullName4Length;
+static char  myFullName6[100]; //The name, adjusted to include the domain if needed by the protocol, used when reading and when writing
+static int   myFullName6Length;
 
 //Set by readQuestions and used by answerQuestions
-static int   questionTypes[MAX_QUERIES];
-static int   nodeNameTypes[MAX_QUERIES];
-static bool  hadQueryForMe;
+static int   answers[MAX_ANSWERS];
+static int   answerCount = 0;
 static bool  mdnsUnicastReply;
 
 static int readQuestions()
 {
     p = DnsHdrData;
     mdnsUnicastReply = false;
-    hadQueryForMe    = false;
     
     //Get the questions
-    for (int q = 0; q < DnsHdrQdcount; q++)
+    answerCount = 0;
+    for (int i = 0; i < DnsHdrQdcount; i++)
     {
-        //Bomb out if we are tending to overrun the buffer
-        if (q >= MAX_QUERIES)
+        //Bomb out if there are too many answers
+        if (answerCount >= MAX_ANSWERS)
         {
-            if (DnsServerTrace) LogTimeF("DnsServer-readQuestions - exceeded MAX_QUERIES\r\n");
+            if (DnsServerTrace) LogTimeF("DnsServer-readQuestions - exceeded %d answers\r\n", MAX_ANSWERS);
             break;
         }
         
@@ -56,22 +65,27 @@
             if (DnsServerTrace) LogTimeF("DnsServer-readQuestions namelength is zero\r\n");
             return -1;
         }
-        nodeNameTypes[q] = DNS_RECORD_NONE;
-        if (!nodeNameTypes[q] && DnsNameComparePtr(p, myFullName)      ) nodeNameTypes[q] = DNS_RECORD_PTR;  //rtc.local
-        if (!nodeNameTypes[q] && DnsNameCompareIp4(p, DhcpLocalIp)     ) nodeNameTypes[q] = DNS_RECORD_A;    //3.1.168.192.inaddr.arpa
-        if (!nodeNameTypes[q] && DnsNameCompareIp6(p, SlaacLinkLocalIp)) nodeNameTypes[q] = DNS_RECORD_AAAA; //5.8.c.c.0.1.e.f.f.f.2.3.1.1.2.0.8.7.d.0.9.0.f.1.0.7.4.0.1.0.0.2.ip6.arpa
-        p += nameLength;                            //Skip past the name
-        
-        //Remember had a query for me
-        if (nodeNameTypes[q]) hadQueryForMe = true;
-        
+        bool nodeIsName4 = DnsNameComparePtr(p, myFullName4);
+        bool nodeIsName6 = DnsNameComparePtr(p, myFullName6);
+        bool nodeIsAddr4 = DnsNameCompareIp4(p, DhcpLocalIp);
+        bool nodeIsLocl6 = DnsNameCompareIp6(p, SlaacLinkLocalIp);
+        bool nodeIsGlob6 = DnsNameCompareIp6(p, SlaacGlobalIp);
+        p += nameLength;                          //Skip past the name
+                
         //Type
-        p++ ;                                       //skip the first byte of the type
-        questionTypes[q] = *p++;                    //read the record type
+        p++ ;                                     //skip the first byte of the type
+        char recordType = *p++;                   //read the record type
         
         //Class
-        if (*p++ & 0x80) mdnsUnicastReply = true;   //check the class 15th bit (UNICAST-RESPONSE)
-        p++;                                        //skip the class
+        if (*p++ & 0x80) mdnsUnicastReply = true; //check the class 15th bit (UNICAST-RESPONSE)
+        p++;                                      //skip the class
+        
+        //Handle the questions
+        if (nodeIsName4 && recordType == DNS_RECORD_A   )   answers[answerCount++] = RECORD_A;
+        if (nodeIsName6 && recordType == DNS_RECORD_AAAA) { answers[answerCount++] = RECORD_AAAA_LOCAL; answers[answerCount++] = RECORD_AAAA_GLOBAL; }
+        if (nodeIsAddr4 && recordType == DNS_RECORD_PTR )   answers[answerCount++] = RECORD_PTR4;
+        if (nodeIsLocl6 && recordType == DNS_RECORD_PTR )   answers[answerCount++] = RECORD_PTR6_LOCAL;
+        if (nodeIsGlob6 && recordType == DNS_RECORD_PTR )   answers[answerCount++] = RECORD_PTR6_GLOBAL;
     }
     return 0;
 }
@@ -80,49 +94,39 @@
     //Strip the questions if this is MDNS
     //if (dnsProtocol == DNS_PROTOCOL_MDNS) p = DnsHdrData; 
     
-    //Go through each question
+    //Go through each answer
     DnsHdrAncount = 0;
-    for (int q = 0; q < DnsHdrQdcount; q++)
-    {        
-        //Skip unwanted record types
-        switch (questionTypes[q])
-        {
-            case DNS_RECORD_A:
-            case DNS_RECORD_AAAA:
-            case DNS_RECORD_PTR: break;
-            default: continue;
-        }
-        
-        //Skip queries which are not addressed to me
-        if (!nodeNameTypes[q]) continue; 
-        
+    for (int i = 0; i < answerCount; i++)
+    {                
         //Bomb out if we are tending to overrun the buffer
         if (p - DnsHdrPacket > 512)
         {
             if (DnsServerTrace) LogTimeF("DnsServer-addAnswers - reply is getting too big\r\n");
             return -1;
         }
-        
-        //Count the number of answers
-        DnsHdrAncount++;
                 
         //Encode the node name
-        switch (questionTypes[q])
+        switch (answers[i])
         {
-            case DNS_RECORD_A:
-                DnsNameEncodePtr(myFullName, &p);
-                break;
-            case DNS_RECORD_AAAA:
-                DnsNameEncodePtr(myFullName, &p);
-                break;
-            case DNS_RECORD_PTR:
-                if (nodeNameTypes[q] == DNS_RECORD_A   ) DnsNameEncodeIp4(DhcpLocalIp,      &p);
-                if (nodeNameTypes[q] == DNS_RECORD_AAAA) DnsNameEncodeIp6(SlaacLinkLocalIp, &p);
-                break;
+            case RECORD_A:           DnsNameEncodePtr(myFullName4,      &p); break;
+            case RECORD_AAAA_LOCAL:  DnsNameEncodePtr(myFullName6,      &p); break;  
+            case RECORD_AAAA_GLOBAL: DnsNameEncodePtr(myFullName6,      &p); break;
+            case RECORD_PTR4:        DnsNameEncodeIp4(DhcpLocalIp,      &p); break;
+            case RECORD_PTR6_LOCAL:  DnsNameEncodeIp6(SlaacLinkLocalIp, &p); break;
+            case RECORD_PTR6_GLOBAL: DnsNameEncodeIp6(SlaacGlobalIp,    &p); break;
         }
         
-        //Add the type
-        *p++ =    0; *p++ = questionTypes[q];                    //16 bit Type
+        //Add the 16 bit type
+        *p++ = 0;
+        switch (answers[i])
+        {
+            case RECORD_A:           *p++ = DNS_RECORD_A;    break;
+            case RECORD_AAAA_LOCAL:  *p++ = DNS_RECORD_AAAA; break;
+            case RECORD_AAAA_GLOBAL: *p++ = DNS_RECORD_AAAA; break;
+            case RECORD_PTR4:        *p++ = DNS_RECORD_PTR;  break;
+            case RECORD_PTR6_LOCAL:  *p++ = DNS_RECORD_PTR;  break;
+            case RECORD_PTR6_GLOBAL: *p++ = DNS_RECORD_PTR;  break;
+        }
         
         //Add the class
         char mdns = dnsProtocol == DNS_PROTOCOL_MDNS ? 0x80 : 0; //Set the 15th bit (CACHE_FLUSH) of the class to 1 if MDNS
@@ -133,20 +137,29 @@
         
         
         //Add the 16 bit payload length
-        switch (questionTypes[q])
+        *p++ = 0;
+        switch (answers[i])
         {
-            case DNS_RECORD_A:    *p++ = 0; *p++ =  4;                   break;
-            case DNS_RECORD_AAAA: *p++ = 0; *p++ = 16;                   break;
-            case DNS_RECORD_PTR:  *p++ = 0; *p++ = myFullNameLength + 2; break; //name length plus one byte for initial length plus one byte for terminating zero
+            case RECORD_A:           *p++ =  4;                    break;
+            case RECORD_AAAA_LOCAL:  *p++ = 16;                    break;
+            case RECORD_AAAA_GLOBAL: *p++ = 16;                    break;
+            case RECORD_PTR4:        *p++ = myFullName4Length + 2; break; //add a byte for the initial length and another for the terminating zero length
+            case RECORD_PTR6_LOCAL:  *p++ = myFullName6Length + 2; break;
+            case RECORD_PTR6_GLOBAL: *p++ = myFullName6Length + 2; break;
         }
         
         //Add the payload
-        switch (questionTypes[q])
+        switch (answers[i])
         {
-            case DNS_RECORD_A:    memcpy(p, &DhcpLocalIp,      4); p +=  4; break;
-            case DNS_RECORD_AAAA: memcpy(p, SlaacLinkLocalIp, 16); p += 16; break;
-            case DNS_RECORD_PTR:  DnsNameEncodePtr(myFullName, &p);         break;
+            case RECORD_A:           memcpy(p, &DhcpLocalIp,      4); p +=  4; break;
+            case RECORD_AAAA_LOCAL:  memcpy(p, SlaacLinkLocalIp, 16); p += 16; break;
+            case RECORD_AAAA_GLOBAL: memcpy(p, SlaacGlobalIp,    16); p += 16; break;
+            case RECORD_PTR4:        DnsNameEncodePtr(myFullName4, &p);        break;
+            case RECORD_PTR6_LOCAL:  DnsNameEncodePtr(myFullName6, &p);        break;
+            case RECORD_PTR6_GLOBAL: DnsNameEncodePtr(myFullName6, &p);        break;
         }
+        //Increment the number of good answers to send
+        DnsHdrAncount++;
     }
     return 0;
 }
@@ -164,10 +177,11 @@
         return DO_NOTHING;
     }
     
-    myFullNameLength = DnsMakeFullNameFromName(dnsProtocol, NetName, sizeof(myFullName), myFullName);
+    myFullName4Length = DnsMakeFullNameFromName(dnsProtocol, NetName4, sizeof(myFullName4), myFullName4);
+    myFullName6Length = DnsMakeFullNameFromName(dnsProtocol, NetName6, sizeof(myFullName6), myFullName6);
     
     if (readQuestions()) return DO_NOTHING;
-    if (!hadQueryForMe ) return DO_NOTHING;
+    if (!answerCount) return DO_NOTHING;
     
     if (DnsServerTrace)
     {
@@ -178,7 +192,6 @@
     }
     
     if (addAnswers(dnsProtocol)) return DO_NOTHING;
-    if (!DnsHdrAncount)          return DO_NOTHING;
     
     DnsHdrIsReply         = true;
     DnsHdrIsAuthoritative = true;