Andrew Boyson / net

Dependents:   oldheating gps motorhome heating

Revision:
32:679654f2d023
Parent:
22:914b970356f0
Child:
33:714a0345e59b
diff -r eb98bc176c3d -r 679654f2d023 udp/dns/dnsserver.cpp
--- a/udp/dns/dnsserver.cpp	Fri Aug 11 17:41:52 2017 +0000
+++ b/udp/dns/dnsserver.cpp	Thu Aug 17 14:21:02 2017 +0000
@@ -10,27 +10,32 @@
 
 #define DEBUG false
 
-int DnsServerHandleQuery(int dnsProtocol, int *pSize) //Received an mdns or llmnr query on port 5353 or 5355
+//Set by 'initialise'
+char* p;             //Position relative to DnsHdrData and is updated while both reading questions and writing answers
+char  fullname[100]; //The name, adjusted to include the domain if needed by the protocol, used when reading and when writing
+int   fullnamelength;
+
+//Set by readQuestions and used by answerQuestions
+int   questionTypes[4];
+bool  mdnsUnicastReply;
+int   nodeNameType;
+
+static int initialise(int dnsProtocol, int size)
 {    
-    if (DEBUG) DnsHdrLog("server received query", dnsProtocol);
+    if (         size > 512) { if (DEBUG) LogTimeF("DnsServer-initialise length %d too long\r\n",  size         ); return -1; }
+    if (DnsHdrQdcount >   4) { if (DEBUG) LogTimeF("DnsServer-initialise too many queries %d\r\n", DnsHdrQdcount); return -1; }
     
-    if (       *pSize > 512)
-    {
-        if (DEBUG) LogTimeF("DnsServerHandleQuery length %d too long\r\n",  *pSize);
-        return DO_NOTHING;
-    }
-    if (DnsHdrQdcount >   4)
-    {
-        if (DEBUG) LogTimeF("DnsServerHandleQuery too many queries %d\r\n", *pSize);
-        return DO_NOTHING;
-    }
+    p = DnsHdrData;
+    
+    fullnamelength = DnsMakeFullNameFromName(dnsProtocol, NetName, sizeof(fullname), fullname);
     
-    char *p = DnsHdrData;
-    
+    return 0;
+}
+static int readQuestions()
+{
     char iEncodedName;
-    bool isMe = false;
-    int  types[4];
-    bool mdnsUnicastReply = false;
+    nodeNameType = DNS_RECORD_NONE;
+    mdnsUnicastReply = false;
     
     //Get the questions
     DnsHdrAncount = 0;
@@ -39,14 +44,19 @@
         iEncodedName = DnsNameIndexFromPointer(p);
         int nameLength = DnsNameLength(p);
         if (!nameLength) return DO_NOTHING;
-        if (!q) isMe = DnsNameCompare(p, NetName); //get the name: rtc.local; 3.1.168.192.inaddr.arpa; etc
-        p += nameLength;                           //Skip past the name
-        p++ ;                                      //skip the first byte of the type
-        char recordType = *p++;                    //read the record type
-        if (*p++ & 0x80) mdnsUnicastReply = true;  //Check the 15th bit (UNICAST-RESPONSE)
-        p += 1;                                    //skip the class
+        if (!q)
+        {
+            if (!nodeNameType && DnsNameCompare   (p, fullname)        ) nodeNameType = DNS_RECORD_PTR;  //rtc.local
+            if (!nodeNameType && DnsNameCompareIp4(p, DhcpLocalIp)     ) nodeNameType = DNS_RECORD_A;    //3.1.168.192.inaddr.arpa
+            if (!nodeNameType && DnsNameCompareIp6(p, SlaacLinkLocalIp)) nodeNameType = DNS_RECORD_AAAA; //5.8.c.c.0.1.e.f.f.f.2.3.1.1.2.0.8.7.d.0.9.0.f.1.0.7.4.0.1.0.0.2.ip6.arpa
+        }
+        p += nameLength;                            //Skip past the name
+        p++ ;                                       //skip the first byte of the type
+        char recordType = *p++;                     //read the record type
+        if (*p++ & 0x80) mdnsUnicastReply = true;   //Check the 15th bit (UNICAST-RESPONSE)
+        p += 1;                                     //skip the class
         
-        types[q] = recordType;
+        questionTypes[q] = recordType;
         
         if (DEBUG)
         {
@@ -62,81 +72,69 @@
             LogF("  Name     %s\r\n", text);
         }    
     }
-    if (!isMe) return DO_NOTHING;
-    
-    //Respond to the questions
+    return 0;
+}
+static int addAnswers(int dnsProtocol)
+{
     for (int q = 0; q < DnsHdrQdcount; q++)
     {
-        switch (types[q])
+        //Skip unwanted record types
+        switch (questionTypes[q])
         {
             case DNS_RECORD_A:
-                if (p - DnsHdrPacket > 500)
-                {
-                    LogTimeF("DNS server ip4 query reply is getting too big\r\n");
-                    return DO_NOTHING;
-                }
-                DnsHdrAncount++;
-                //Name
-                DnsNameEncode(NetName, &p);
-                
-                //16 bit Type
-                *p++ = 0; //MSB type
-                *p++ = DNS_RECORD_A;
-                
-                //Class
-                *p++ = dnsProtocol == DNS_PROTOCOL_MDNS ? 0x80 : 0; //Set the 15th bit (CACHE_FLUSH) to 1 if MDNS
-                *p++ = 1;                                           //QCLASS_IN = 1 - internet
-                
-                 //32 bit TTL seconds
-                *p++ = 0;
-                *p++ = 0;
-                *p++ = 4; //1024 seconds
-                *p++ = 0;
-                
-                //16bit length in bytes
-                *p++ = 0;
-                *p++ = 4;
-                
-                //Value
-                memcpy(p, &DhcpLocalIp, 4);
-                p += 4;
+            case DNS_RECORD_AAAA:
+            case DNS_RECORD_PTR: continue;
+        }
+        if (p - DnsHdrPacket > 500)
+        {
+            LogTimeF("DnsServer-addAnswers Ip4 query reply is getting too big\r\n");
+            return -1;
+        }
+        DnsHdrAncount++;
+        int lenPayload = 0;
+        char* pPayload = 0;
+        switch (questionTypes[q])
+        {
+            case DNS_RECORD_A:
+                DnsNameEncode(fullname, &p);
+                lenPayload = 4;
+                pPayload = (char*)&DhcpLocalIp;
                 break;
                 
             case DNS_RECORD_AAAA:
-                if (p - DnsHdrPacket > 500)
-                {
-                    LogTimeF("DNS server Ip6 query reply is getting too big\r\n");
-                    return DO_NOTHING;
-                }
-                DnsHdrAncount++;
-                //Name
-                DnsNameEncode(NetName, &p);
-                
-                //16 bit Type
-                *p++ = 0; //MSB type
-                *p++ = DNS_RECORD_AAAA;
+                DnsNameEncode(fullname, &p);
+                lenPayload = 16;
+                pPayload = SlaacLinkLocalIp;
+                break;
                 
-                //Class
-                *p++ = dnsProtocol == DNS_PROTOCOL_MDNS ? 0x80 : 0; //Set the 15th bit (CACHE_FLUSH) to 1 if MDNS
-                *p++ = 1;                                           //QCLASS_IN = 1 - internet
-                
-                 //32 bit TTL seconds
-                *p++ = 0;
-                *p++ = 0;
-                *p++ = 4; //1024 seconds
-                *p++ = 0;
-                
-                //16bit length in bytes
-                *p++ = 0;
-                *p++ = 16;
-                
-                //Value
-                memcpy(p, SlaacLinkLocalIp, 16);
-                p += 16;
+            case DNS_RECORD_PTR:
+                if (nodeNameType == DNS_RECORD_A   ) DnsNameEncodeIp4(DhcpLocalIp,      &p);
+                if (nodeNameType == DNS_RECORD_AAAA) DnsNameEncodeIp6(SlaacLinkLocalIp, &p);
+                lenPayload = fullnamelength;
+                pPayload = fullname;
                 break;
         }
+        char mdns = dnsProtocol == DNS_PROTOCOL_MDNS ? 0x80 : 0; //Set the 15th bit (CACHE_FLUSH) of the class to 1 if MDNS
+        *p++ =    0; *p++ = questionTypes[q];                    //16 bit Type
+        *p++ = mdns; *p++ = 1;                                   //16 bit Class LSB QCLASS_IN = 1 - internet
+        *p++ =    0; *p++ = 0; *p++ = 4; *p++ = 0;               //32 bit TTL seconds - 1024
+        *p++ =    0; *p++ = lenPayload;                          //16 bit length in bytes
+        memcpy(p, pPayload, lenPayload);                         //Copy the payload
+        p += lenPayload;                                         //Adjust the pointer to the next character afetr the payload
+
     }
-    if (!DnsHdrAncount) return DO_NOTHING;
+    return 0;
+}
+
+int DnsServerHandleQuery(int dnsProtocol, int *pSize) //Received an mdns or llmnr query on port 5353 or 5355
+{
+    if (DEBUG) DnsHdrLog("DnsServer received query", dnsProtocol);
+    
+    if (initialise(dnsProtocol, *pSize)) return DO_NOTHING;
+    if (readQuestions()                ) return DO_NOTHING;
+    if (!nodeNameType)                   return DO_NOTHING;
+    if (addAnswers(dnsProtocol)        ) return DO_NOTHING;
+    if (!DnsHdrAncount)                  return DO_NOTHING;
     
     DnsHdrIsReply         = true;
     DnsHdrIsAuthoritative = true;
@@ -144,7 +142,7 @@
     
     *pSize = p - DnsHdrPacket;
     
-    if (DEBUG) DnsHdrLog("server sending reply", dnsProtocol);
+    if (DEBUG) DnsHdrLog("DnsServer sending reply", dnsProtocol);
 
     if (dnsProtocol == DNS_PROTOCOL_MDNS && !mdnsUnicastReply) return MULTICAST_MDNS;
     return UNICAST;