A network stack that works with or without the Mbed OS library. Provides IPv4 or IPv6 with a full 1500-byte buffer.

Dependents:   oldheating gps motorhome heating

udp/dns/dnsreply.cpp

Committer:
andrewboyson
Date:
2017-10-19
Revision:
43:bc028d5a6424
Parent:
37:793b39683406
Child:
48:952dddb74b8b

File content as of revision 43:bc028d5a6424:

#include   "mbed.h"
#include    "log.h"
#include    "net.h"
#include "action.h"
#include    "ip4.h"
#include    "ip6.h"
#include     "ar.h"
#include     "nr.h"
#include    "dns.h"
#include "dnshdr.h"
#include "dnsname.h"
#include   "dhcp.h"

// When true, DnsReplyHandle() logs each received reply (header and, optionally, the call stack).
bool DnsReplyTrace = false;

// Results decoded from the current answer record by readAnswer().
// readAnswer() resets these before decoding each record; the arrays are reset at index 0 only,
// so index 0 acts as the "present" flag and the remaining bytes may be stale.
char     DnsReplyRecordName[256];     // owner name of the record, decoded with DnsNameDecodePtr()
uint32_t DnsReplyRecordNameAsIp4 = 0; // owner name decoded by DnsNameDecodeIp4() - presumably non-zero only for reverse (in-addr.arpa) names; confirm
char     DnsReplyRecordNameAsIp6[16]; // owner name decoded by DnsNameDecodeIp6() - presumably only for ip6.arpa names; confirm
char     DnsReplyName[64];            // target name of a PTR record
uint32_t DnsReplyIp4 = 0;             // RDATA of an A record, copied byte-for-byte from the packet (wire byte order)
char     DnsReplyIp6[16];             // RDATA of an AAAA record, copied byte-for-byte from the packet

// Parser state shared by scanQuery(), scanAnswer() and readAnswer().
static char *p;                 // read cursor into the DNS message; DnsReplyHandle() starts it at DnsHdrData
static int recordNameOffset;    // index of the current record's name, from DnsNameIndexFromPointer()
static int recordType;          // TYPE (or QTYPE) of the record most recently scanned
static int recordDataLength;    // RDLENGTH of the current answer record
static char* pRecordData;       // start of the current answer record's RDATA

// Steps the cursor p over one entry of the question section.
// Records the offset of the question name and its 16-bit QTYPE in recordNameOffset/recordType,
// then skips the 16-bit QCLASS, leaving p at the start of the next entry.
// Returns 0 on success, or -1 if DnsNameLength() reports a zero-length name.
// NOTE(review): p is never checked against the end of the received message - a malformed
// packet can make this read past the buffer; confirm whether the caller guarantees the size.
static int scanQuery()
{
    int recordNameLength = DnsNameLength(p);
    if (!recordNameLength) return -1; //failure
    
    recordNameOffset = DnsNameIndexFromPointer(p);
    p += recordNameLength;
    
    //QTYPE is a big-endian 16-bit field. Read both bytes, and read them as unsigned, so that types >= 256
    //are not aliased onto A/AAAA/PTR and a byte >= 0x80 cannot sign-extend where plain char is signed.
    recordType  = (unsigned char)*p++ << 8;
    recordType |= (unsigned char)*p++;
    
    p += 2; //skip the class
    
    return 0; //success
}

// Steps the cursor p over one resource record of the answer section.
// Records the name offset, the 16-bit TYPE and the 16-bit RDLENGTH, and sets pRecordData to the
// start of the RDATA. Skips CLASS (2 bytes) and TTL (4 bytes) and leaves p at the next record.
// Returns 0 on success, or -1 if DnsNameLength() reports a zero-length name.
// NOTE(review): neither the name, the fixed fields nor RDLENGTH are checked against the end of the
// received message - a malformed packet can make p run past the buffer; confirm the caller's guarantees.
static int scanAnswer()
{
    int recordNameLength = DnsNameLength(p);
    if (!recordNameLength) return -1; //failure
    
    recordNameOffset = DnsNameIndexFromPointer(p);
    p += recordNameLength;
    
    //TYPE is big-endian 16-bit. Reading only the low byte would alias e.g. type 257 onto A (1).
    //Read both bytes as unsigned so a byte >= 0x80 cannot sign-extend where plain char is signed.
    recordType  = (unsigned char)*p++ << 8;
    recordType |= (unsigned char)*p++;
    
    p += 6; //skip the class (2 bytes) and TTL (4 bytes)
    
    //RDLENGTH is big-endian 16-bit. Reading through unsigned char keeps it in 0..65535.
    //A signed char would give a negative length (and UB on the left shift).
    recordDataLength  = (unsigned char)*p++ << 8;
    recordDataLength |= (unsigned char)*p++;
    
    pRecordData = p; //record the start of the data
    
    p += recordDataLength; //Move to the start of the next record
    
    return 0; //success
}
// Decodes the answer record most recently located by scanAnswer() into the DnsReply* globals.
// All results are reset first. Only A, AAAA, PTR, SRV and TXT records are decoded further: for
// those the owner name is decoded as text, as an IPv4 value and as an IPv6 value. Then the RDATA
// is copied for A (exactly 4 bytes) and AAAA (exactly 16 bytes), and decoded as a name for PTR
// (length at most DNS_MAX_LABEL_LENGTH). SRV and TXT payloads are not decoded.
static void readAnswer()
{
    //Forget whatever the previous record left behind (arrays are flagged empty via their first byte)
    DnsReplyIp6[0]             = 0;
    DnsReplyIp4                = 0;
    DnsReplyName[0]            = 0;
    DnsReplyRecordNameAsIp6[0] = 0;
    DnsReplyRecordNameAsIp4    = 0;
    DnsReplyRecordName[0]      = 0;
    
    bool isSupportedType = recordType == DNS_RECORD_A   ||
                           recordType == DNS_RECORD_AAAA ||
                           recordType == DNS_RECORD_PTR ||
                           recordType == DNS_RECORD_SRV ||
                           recordType == DNS_RECORD_TXT;
    if (!isSupportedType) return;
    
    //Owner name in each of its possible forms
    DnsNameDecodePtr(recordNameOffset, sizeof(DnsReplyRecordName), DnsReplyRecordName);
    DnsNameDecodeIp4(recordNameOffset, &DnsReplyRecordNameAsIp4);
    DnsNameDecodeIp6(recordNameOffset,  DnsReplyRecordNameAsIp6);
    
    //Payload, only when its length matches what the type requires
    if (recordType == DNS_RECORD_A)
    {
        if (recordDataLength == 4) memcpy(&DnsReplyIp4, pRecordData, 4);
    }
    else if (recordType == DNS_RECORD_AAAA)
    {
        if (recordDataLength == 16) memcpy(DnsReplyIp6, pRecordData, 16);
    }
    else if (recordType == DNS_RECORD_PTR)
    {
        if (recordDataLength <= DNS_MAX_LABEL_LENGTH)
        {
            int targetOffset = DnsNameIndexFromPointer(pRecordData);
            DnsNameDecodePtr(targetOffset, sizeof(DnsReplyName), DnsReplyName);
        }
    }
}

// Passes the addresses decoded by readAnswer() to the name-resolution cache (NrAdd*Record).
// Forward records (A/AAAA) pair the owner name with DnsReplyIp4/DnsReplyIp6.
// Reverse records (PTR) pair the PTR target name with the address encoded in the owner name.
// The name is first reduced with DnsStripNameFromFullName() into a 100-byte buffer. When both
// names are present, the PTR target name is stripped last and so is the one used.
static void sendToDnsCache(int dnsProtocol)
{
    char name[100];
    
    bool haveOwnerName  = DnsReplyRecordName[0];
    bool haveTargetName = DnsReplyName[0];
    
    if (haveOwnerName ) DnsStripNameFromFullName(dnsProtocol, DnsReplyRecordName, sizeof(name), name);
    if (haveTargetName) DnsStripNameFromFullName(dnsProtocol, DnsReplyName,       sizeof(name), name);
    
    //Forward lookup: owner name -> address from RDATA
    if (haveOwnerName)
    {
        if (DnsReplyIp4   ) NrAddIp4Record(DnsReplyIp4, name, dnsProtocol);
        if (DnsReplyIp6[0]) NrAddIp6Record(DnsReplyIp6, name, dnsProtocol);
    }
    
    //Reverse lookup: address encoded in the owner name -> PTR target name
    if (haveTargetName)
    {
        if (DnsReplyRecordNameAsIp4   ) NrAddIp4Record(DnsReplyRecordNameAsIp4, name, dnsProtocol);
        if (DnsReplyRecordNameAsIp6[0]) NrAddIp6Record(DnsReplyRecordNameAsIp6, name, dnsProtocol);
    }
}
int DnsReplyHandle(void (*traceback)(void), int dnsProtocol, int *pSize)
{        
    if (!DnsHdrAncount) return DO_NOTHING;
    if (DnsReplyTrace)
    {
        if (NetTraceNewLine) Log("\r\n");
        LogTimeF("DnsReply received\r\n");
        if (NetTraceStack) traceback();
        DnsHdrLog(dnsProtocol);
    }

    p = DnsHdrData;

    for (int q = 0; q < DnsHdrQdcount; q++)
    {
        if (scanQuery()) return DO_NOTHING;
    }
    for (int a = 0; a < DnsHdrAncount; a++)
    {
        if (scanAnswer()) return DO_NOTHING;
        readAnswer();
        sendToDnsCache(dnsProtocol);
    }
    
    return DO_NOTHING;
}