A network stack that works with or without the Mbed OS library. Provides IPv4 or IPv6 with a full 1500-byte buffer.

Dependents:   oldheating gps motorhome heating

udp/dns/dnsreply.cpp

Committer:
andrewboyson
Date:
2017-10-04
Revision:
37:793b39683406
Parent:
35:93c39d260a83
Child:
43:bc028d5a6424

File content as of revision 37:793b39683406:

#include   "mbed.h"
#include    "log.h"
#include    "net.h"
#include "action.h"
#include    "ip4.h"
#include    "ip6.h"
#include     "ar.h"
#include     "nr.h"
#include    "dns.h"
#include "dnshdr.h"
#include "dnsname.h"
#include   "dhcp.h"

//When true, the reply parser logs its progress (header, each decoded answer) via LogTimeF/LogF.
bool DnsReplyTrace = false;

//Results of decoding the most recent answer record (filled by readAnswer, consumed by sendToDnsCache).
//They are cleared before every decode. A zero integer, or a zero first byte for the char arrays,
//is treated as "not present".
//NOTE(review): an IPv6 address whose first byte really is 0 (e.g. ::1) is indistinguishable from "absent".
char     DnsReplyRecordName[256];        //owner name of the record, as text
uint32_t DnsReplyRecordNameAsIp4 = 0;    //owner name decoded as an IPv4 address (presumably only for in-addr.arpa names - confirm in dnsname)
char     DnsReplyRecordNameAsIp6[16];    //owner name decoded as an IPv6 address (presumably only for ip6.arpa names - confirm in dnsname)
char     DnsReplyName[64];               //target name carried by a PTR record
uint32_t DnsReplyIp4 = 0;                //address from an A record, bytes copied in wire order
char     DnsReplyIp6[16];                //address from an AAAA record

//Parser state shared by scanQuery, scanAnswer and readAnswer while DnsReplyHandle walks one message.
static char *p;                //read cursor into the message; starts at DnsHdrData
static int recordNameOffset;   //index of the current record's name, as returned by DnsNameIndexFromPointer
static int recordType;         //TYPE (QTYPE) of the record just scanned
static int recordDataLength;   //RDLENGTH of the answer just scanned
static char* pRecordData;      //start of the RDATA of the answer just scanned

/*
 * Advances the cursor p over one entry of the question section.
 *
 * Stores the name's index (DnsNameIndexFromPointer) in recordNameOffset and the
 * 16-bit QTYPE in recordType, then skips the 16-bit QCLASS.
 *
 * Returns 0 on success, or -1 if DnsNameLength() reports a zero-length name at p.
 * NOTE(review): nothing here checks that p stays inside the received message.
 */
static int scanQuery()
{
    int recordNameLength = DnsNameLength(p);
    if (!recordNameLength)
    {
        if (DnsReplyTrace) LogTimeF("DnsReply scanRecord name length of zero\r\n");
        return -1; //failure
    }

    recordNameOffset = DnsNameIndexFromPointer(p);
    p += recordNameLength;

    //QTYPE is a big-endian 16-bit value. Read each byte as unsigned so a signed plain
    //char cannot sign-extend, and keep the high byte so types above 255 do not alias.
    recordType  = (uint8_t)*p++ << 8;
    recordType |= (uint8_t)*p++;

    p += 2; //skip the class

    return 0; //success
}

/*
 * Advances the cursor p over one resource record of the answer section.
 *
 * Records:
 *   recordNameOffset - index of the owner name (DnsNameIndexFromPointer)
 *   recordType       - the 16-bit TYPE
 *   recordDataLength - the 16-bit RDLENGTH
 *   pRecordData      - pointer to the first byte of RDATA
 * CLASS and TTL are skipped, and p is left at the start of the next record.
 *
 * Returns 0 on success, or -1 if DnsNameLength() reports a zero-length name at p.
 * NOTE(review): RDLENGTH is trusted as-is; nothing checks that p or the RDATA stay
 * inside the received message.
 */
static int scanAnswer()
{
    int recordNameLength = DnsNameLength(p);
    if (!recordNameLength)
    {
        if (DnsReplyTrace) LogTimeF("DnsReply scanRecord name length of zero\r\n");
        return -1; //failure
    }

    recordNameOffset = DnsNameIndexFromPointer(p);
    p += recordNameLength;

    //TYPE and RDLENGTH are big-endian 16-bit values. Read the bytes as unsigned: with a signed
    //plain char a byte >= 0x80 would sign-extend, corrupt the value and make the shift undefined.
    recordType  = (uint8_t)*p++ << 8;
    recordType |= (uint8_t)*p++;

    p += 6; //skip the class, TTL

    recordDataLength  = (uint8_t)*p++ << 8;
    recordDataLength |= (uint8_t)*p++;

    pRecordData = p; //record the start of the data

    p += recordDataLength; //Move to the start of the next record

    return 0; //success
}
//Writes the answer currently held in the DnsReply* globals to the log (used only when DnsReplyTrace is set).
static void traceAnswer()
{
    char text[100];

    LogF("  answer: %s",  DnsReplyRecordName);
    if (DnsReplyRecordNameAsIp4)
    {
        Ip4AddressToString(DnsReplyRecordNameAsIp4, sizeof(text), text);
        LogF(" (%s)",  text);
    }
    if (DnsReplyRecordNameAsIp6[0])
    {
        Ip6AddressToString(DnsReplyRecordNameAsIp6, sizeof(text), text);
        LogF(" (%s)",  text);
    }
    LogF(" == ");

    if (recordType == DNS_RECORD_A)
    {
        Ip4AddressToString(DnsReplyIp4, sizeof(text), text);
        LogF("%s\r\n",  text);
    }
    else if (recordType == DNS_RECORD_AAAA)
    {
        Ip6AddressToString(DnsReplyIp6, sizeof(text), text);
        LogF("%s\r\n",  text);
    }
    else if (recordType == DNS_RECORD_PTR)
    {
        LogF("%s\r\n",  DnsReplyName);
    }
    else
    {
        //SRV and TXT are accepted but not decoded, so just report their size
        DnsRecordTypeToString(recordType, sizeof(text), text);
        LogF("%d bytes of %s\r\n", recordDataLength, text);
    }
}

/*
 * Decodes the answer last located by scanAnswer() into the public DnsReply* globals.
 *
 * All globals are cleared first. An unsupported record type is logged (always,
 * regardless of DnsReplyTrace) and leaves everything cleared. For supported types
 * the owner name is decoded as text, as IPv4 and as IPv6. Then:
 *   A    - RDATA must be 4 bytes and is copied to DnsReplyIp4
 *   AAAA - RDATA must be 16 bytes and is copied to DnsReplyIp6
 *   PTR  - the target name is decoded into DnsReplyName
 * A bad RDATA length is logged and stops decoding, but the owner-name fields stay filled.
 */
static void readAnswer()
{
    //Nothing from a previous answer may leak into this one
    DnsReplyRecordName[0]      = 0;
    DnsReplyRecordNameAsIp4    = 0;
    DnsReplyRecordNameAsIp6[0] = 0;
    DnsReplyName[0]            = 0;
    DnsReplyIp4                = 0;
    DnsReplyIp6[0]             = 0;

    bool isSupported = recordType == DNS_RECORD_A
                    || recordType == DNS_RECORD_AAAA
                    || recordType == DNS_RECORD_PTR
                    || recordType == DNS_RECORD_SRV
                    || recordType == DNS_RECORD_TXT;
    if (!isSupported)
    {
        LogTimeF("DnsReply readAnswer unrecognised record type %d\r\n", recordType);
        return;
    }

    DnsNameDecodePtr(recordNameOffset, sizeof(DnsReplyRecordName), DnsReplyRecordName);
    DnsNameDecodeIp4(recordNameOffset, &DnsReplyRecordNameAsIp4);
    DnsNameDecodeIp6(recordNameOffset,  DnsReplyRecordNameAsIp6);

    if (recordType == DNS_RECORD_A)
    {
        if (recordDataLength != 4)
        {
            LogTimeF("DnsReply A type length of %d\r\n", recordDataLength);
            return;
        }
        memcpy(&DnsReplyIp4, pRecordData, 4);
    }
    else if (recordType == DNS_RECORD_AAAA)
    {
        if (recordDataLength != 16)
        {
            LogTimeF("DnsReply AAAA type length of %d\r\n", recordDataLength);
            return;
        }
        memcpy(DnsReplyIp6, pRecordData, 16);
    }
    else if (recordType == DNS_RECORD_PTR)
    {
        if (recordDataLength > DNS_MAX_LABEL_LENGTH)
        {
            LogTimeF("DnsReply PTR type length %d is greater than max DNS label length of %d\r\n", recordDataLength, DNS_MAX_LABEL_LENGTH);
            return;
        }
        DnsNameDecodePtr(DnsNameIndexFromPointer(pRecordData), sizeof(DnsReplyName), DnsReplyName);
    }

    if (DnsReplyTrace) traceAnswer();
}

/*
 * Passes whatever readAnswer() decoded to the name-resolution cache (NrAdd*Record).
 *
 * Addresses from A/AAAA answers are stored against the stripped owner name.
 * Addresses decoded from the owner name itself (reverse-lookup style, as produced
 * for PTR answers) are stored against the stripped target name. When both names
 * are present, the target name is stripped last and therefore wins.
 */
static void sendToDnsCache(int dnsProtocol)
{
    char strippedName[100];

    bool haveOwnerName  = DnsReplyRecordName[0] != 0;
    bool haveTargetName = DnsReplyName[0]       != 0;

    if (haveOwnerName ) DnsStripNameFromFullName(dnsProtocol, DnsReplyRecordName, sizeof(strippedName), strippedName);
    if (haveTargetName) DnsStripNameFromFullName(dnsProtocol, DnsReplyName      , sizeof(strippedName), strippedName);

    //Forward records: the address came from RDATA, the name from the owner
    if (haveOwnerName)
    {
        if (DnsReplyIp4   ) NrAddIp4Record(DnsReplyIp4, strippedName, dnsProtocol);
        if (DnsReplyIp6[0]) NrAddIp6Record(DnsReplyIp6, strippedName, dnsProtocol);
    }

    //Reverse records: the address came from the owner name, the name from RDATA
    if (haveTargetName)
    {
        if (DnsReplyRecordNameAsIp4   ) NrAddIp4Record(DnsReplyRecordNameAsIp4, strippedName, dnsProtocol);
        if (DnsReplyRecordNameAsIp6[0]) NrAddIp6Record(DnsReplyRecordNameAsIp6, strippedName, dnsProtocol);
    }
}
/*
 * Processes a received DNS reply whose header has already been parsed into the DnsHdr* globals.
 *
 * The question section is skipped. Each answer record is then decoded and handed to
 * the name-resolution cache. Parsing stops at the first record whose name cannot be scanned.
 *
 * traceback   - called to print the packet trace when DnsReplyTrace and NetTraceBack are both set
 * dnsProtocol - passed through to DnsHdrLog and the cache
 * pSize       - not used here (presumably the size of any reply to send - confirm with callers)
 *
 * Always returns DO_NOTHING.
 */
int DnsReplyHandle(void (*traceback)(void), int dnsProtocol, int *pSize)
{
    //A message without answers has nothing worth caching
    if (DnsHdrAncount == 0) return DO_NOTHING;

    if (DnsReplyTrace)
    {
        LogTimeF("DNS received reply\r\n");
        if (NetTraceBack) traceback();
        DnsHdrLog(dnsProtocol);
    }

    p = DnsHdrData;

    int questionsLeft = DnsHdrQdcount;
    while (questionsLeft-- > 0)
    {
        if (scanQuery()) return DO_NOTHING;
    }

    int answersLeft = DnsHdrAncount;
    while (answersLeft-- > 0)
    {
        if (scanAnswer()) return DO_NOTHING;
        readAnswer();
        sendToDnsCache(dnsProtocol);
    }

    return DO_NOTHING;
}