A stack which works with or without an Mbed os library. Provides IPv4 or IPv6 with a full 1500 byte buffer.

Dependents:   oldheating gps motorhome heating

udp/dns/dns.c

Committer:
andrewboyson
Date:
2018-01-11
Revision:
61:aad055f1b0d1
Parent:
udp/dns/dns.cpp@ 59:e0e556c8bd46
Child:
83:08c983006a6e

File content as of revision 61:aad055f1b0d1:

#include <stdint.h>
#include <stdbool.h>
#include <string.h>
#include <stdio.h>

#include "dns.h"
#include "dnshdr.h"
#include "dnsquery.h"
#include "dnsreply.h"
#include "dnsserver.h"
#include "log.h"
#include "dhcp.h"

void DnsProtocolString(uint8_t protocol, int size, char* text)
{
    switch (protocol)
    {
        case DNS_PROTOCOL_UDNS:  strncpy (text, "DNS",   size);        break;
        case DNS_PROTOCOL_MDNS:  strncpy (text, "MDNS",  size);        break;
        case DNS_PROTOCOL_LLMNR: strncpy (text, "LLMNR", size);        break;
        default:                 snprintf(text, size, "%d", protocol); break;
    }
}

void DnsRecordTypeString(uint8_t recordtype, int size, char* text)
{
    switch (recordtype)
    {
        case DNS_RECORD_A:    strncpy (text, "A",    size);      break;
        case DNS_RECORD_AAAA: strncpy (text, "AAAA", size);      break;
        case DNS_RECORD_PTR:  strncpy (text, "PTR",  size);      break;
        case DNS_RECORD_TXT:  strncpy (text, "TXT",  size);      break;
        case DNS_RECORD_SRV:  strncpy (text, "SRV",  size);      break;
        default:              snprintf(text, size, "%d", recordtype); break;
    }
}
void DnsProtocolLog(uint8_t protocol)
{
    switch (protocol)
    {
        case DNS_PROTOCOL_UDNS:  Log ("DNS  ");        break;
        case DNS_PROTOCOL_MDNS:  Log ("MDNS ");        break;
        case DNS_PROTOCOL_LLMNR: Log ("LLMNR");        break;
        default:                 LogF("%d", protocol); break;
    }
}

void DnsRecordTypeLog(uint8_t recordtype)
{
    switch (recordtype)
    {
        case DNS_RECORD_A:    Log ("A"   );           break;
        case DNS_RECORD_AAAA: Log ("AAAA");           break;
        case DNS_RECORD_PTR:  Log ("PTR" );           break;
        case DNS_RECORD_TXT:  Log ("TXT" );           break;
        case DNS_RECORD_SRV:  Log ("SRV" );           break;
        default:              LogF("%d", recordtype); break;
    }
}
int DnsGetNextProtocol4(int protocol)
{
    switch(protocol)
    {
        case DNS_PROTOCOL_NONE:  return DNS_PROTOCOL_UDNS;
        case DNS_PROTOCOL_UDNS:  return DNS_PROTOCOL_MDNS;
        case DNS_PROTOCOL_MDNS:  return DNS_PROTOCOL_LLMNR;
        case DNS_PROTOCOL_LLMNR: return DNS_PROTOCOL_NONE;
        default: LogTimeF("DNS invalid protocol %d\r\n", protocol); return DNS_PROTOCOL_NONE;
    }
}
int DnsGetNextProtocol6(int protocol)
{
    switch(protocol)
    {
        case DNS_PROTOCOL_NONE:  return DNS_PROTOCOL_MDNS;
        case DNS_PROTOCOL_MDNS:  return DNS_PROTOCOL_LLMNR;
        case DNS_PROTOCOL_LLMNR: return DNS_PROTOCOL_UDNS;
        case DNS_PROTOCOL_UDNS:  return DNS_PROTOCOL_NONE;
        default: LogTimeF("DNS invalid protocol %d\r\n", protocol); return DNS_PROTOCOL_NONE;
    }
}
bool DnsHostNamesEquate(char* pA, char* pB)
{
    while(true)
    {
        char a = *pA++;
        char b = *pB++;
        if (a >= 'A' && a <= 'Z') a |= 0x20; //Make lower case
        if (b >= 'A' && b <= 'Z') b |= 0x20; //Make lower case
        if (a != b) return false;            //If different then stop and return the fact
        if (!a) break;                       //No need to check 'b' too as it will necessarily be equal to 'a' at this point.
    }
    return true;                             //If we get here the strings must equate.
}

int DnsMakeFullNameFromName(int protocol, const char* p, int size, char* result)
{
    int i = 0;
    char c;
    
    while (i < size - 1)
    {
        c = *p++;
        if (!c) break;
        *result++ = c;
        i++;
    }
    if (protocol == DNS_PROTOCOL_MDNS)
    {
        p = ".local";
        while (i < size - 1)
        {
            c = *p++;
            if (!c) break;
            *result++ = c;
            i++;
        }
    }
    if (protocol == DNS_PROTOCOL_UDNS && DhcpDomainName[0]) //Shouldn't do this in IPv6 as DHCP is IPv4 only
    {
        if (i < size - 1)
        {
            *result++ = '.';
            i++;
        }
        p = DhcpDomainName;
        while (i < size - 1)
        {
            c = *p++;
            if (!c) break;
            *result++ = c;
            i++;
        }
    }
    *result = 0; //Terminate the resulting string
    return i;
}
int DnsStripNameFromFullName(int protocol, char* p, int size, char* result)
{
    int i = 0;
    char c;
    
    while (i < size - 1)
    {
        c = *p++;
        if (c == 0)   break;                               //End of the fqdn so stop
        if (c == '.')
        {
            if (protocol == DNS_PROTOCOL_UDNS)
            {
                if (strcmp(p, DhcpDomainName) == 0) break; //Strip the domain from a UDNS fqdn if, and only if, it matches the domain given in DHCP. IPv4 only.
            }
            else
            {
                break;                                     //Strip the domain from an LLMNR (there shouldn't be one) or MDNS (it should always be '.local') fqdn
            }
        }
        *result++ = c;
        i++;
    }
    *result = 0;           //Terminate the copied string
    return i;
}

void DnsMain()
{
    DnsQueryMain();
}

int DnsHandlePacketReceived(void (*traceback)(void), int dnsProtocol, int sizeRx, void* pPacketRx, int* pSizeTx, void* pPacketTx)
{
    DnsHdrSetup(pPacketRx, sizeRx);
    DnsHdrRead();
    
    int action;
    if (DnsHdrIsReply)
    {
        action = DnsReplyHandle(traceback, dnsProtocol);
    }
    else
    {
        action = DnsServerHandleQuery(traceback, dnsProtocol, pPacketTx, pSizeTx);
    }
    return action;
}

int DnsPollForPacketToSend(void* pPacketTx, int* pSize)
{
    return DnsQueryPoll(pPacketTx, pSize);
}