A simple library to support serving HTTPS.

Dependents:   oldheating gps motorhome heating

tls/tls.c

Committer:
andrewboyson
Date:
2019-07-26
Revision:
1:9c66a551a67e
Child:
2:82268409e83f

File content as of revision 1:9c66a551a67e:

#include <stdbool.h>

#include "http.h"
#include "tcpbuf.h"
#include "action.h"
#include "net.h"
#include "log.h"
#include "led.h"
#include "restart.h"
#include "mstimer.h"
#include "random.h"
#include "pri-key.h"
#include "ser-cer.h"
#include "tls-prf.h"

//TLS record layer content types: the first byte of every record
#define TLS_CONTENT_TYPE_ChangeCipher      20
#define TLS_CONTENT_TYPE_Alert             21
#define TLS_CONTENT_TYPE_Handshake         22
#define TLS_CONTENT_TYPE_Application       23
#define TLS_CONTENT_TYPE_Heartbeat         24

//Handshake message types: the first byte of each message inside a handshake record
#define TLS_HANDSHAKE_HelloRequest          0
#define TLS_HANDSHAKE_ClientHello           1
#define TLS_HANDSHAKE_ServerHello           2
#define TLS_HANDSHAKE_NewSessionTicket      4
#define TLS_HANDSHAKE_EncryptedExtensions   8
#define TLS_HANDSHAKE_Certificate          11
#define TLS_HANDSHAKE_ServerKeyExchange    12
#define TLS_HANDSHAKE_CertificateRequest   13
#define TLS_HANDSHAKE_ServerHelloDone      14
#define TLS_HANDSHAKE_CertificateVerify    15
#define TLS_HANDSHAKE_ClientKeyExchange    16
#define TLS_HANDSHAKE_Finished             20

//Values of struct state.toDo: the next thing a connection waits for (DO_WAIT_...), must
//send (DO_SEND_...), or DO_APPLICATION once application data is being exchanged.
//Set by TlsRequest and handleHandshake on receipt and by TlsPoll and TlsReply as they act.
#define DO_WAIT_CLIENT_HELLO                0
#define DO_SEND_SERVER_HELLO                1
#define DO_WAIT_CLIENT_CHANGE               2
#define DO_WAIT_DECRYPT_MASTER_SECRET       3
#define DO_SEND_SERVER_CHANGE               4
#define DO_APPLICATION                      5
#define DO_SEND_ALERT_ILLEGAL_PARAMETER     6
#define DO_SEND_ALERT_INTERNAL_ERROR        7

//When true, received records and their decoded fields are logged
bool TlsTrace = true;

//Filled by PriKeyDecryptStart() from the client key exchange; TlsPoll dumps it once
//PriKeyDecryptFinished() reports true.
//NOTE(review): presumably this is the RSA-decrypted, still padded, pre-master secret rather
//than the master secret the name suggests — confirm.
//NOTE(review): this and clientHelloRandom are single file-scope buffers written for whichever
//connection is handshaking, so two concurrent handshakes would overwrite each other — confirm
//only one TLS handshake can be in progress at a time.
char paddedMasterSecret[128];
//Intended to hold the 32 byte client random from the most recent client hello (written by handleClientHello)
char clientHelloRandom[32];

//One-time start-up of the TLS layer: calls SerCerInit, PriKeyInit and then TlsPrfTest, in that order.
//NOTE(review): TlsPrfTest() runs on every start; presumably a development self-test — confirm it
//is meant to stay in production builds.
void TlsInit()
{
    SerCerInit();
    PriKeyInit();
    TlsPrfTest();
}

//Per-connection TLS state, overlaid on the caller-supplied pTlsState buffer
//(see the casts in TlsRequest, TlsPoll and TlsReply).
//NOTE(review): assumes that buffer is at least sizeof(struct state) bytes and suitably
//aligned for an int — confirm with the code that allocates it.
struct state
{
    int      toDo; //One of the DO_... values: what the connection must wait for or send next
};
//Logs the name of a TLS record content type; an unknown type is logged as two hex digits.
static void logContentType(char contentType)
{
    switch (contentType)
    {
        case TLS_CONTENT_TYPE_ChangeCipher: Log ("Change cipher"     ); break;
        case TLS_CONTENT_TYPE_Alert:        Log ("Alert"             ); break;
        case TLS_CONTENT_TYPE_Handshake:    Log ("Handshake"         ); break;
        case TLS_CONTENT_TYPE_Application:  Log ("Application"       ); break;
        case TLS_CONTENT_TYPE_Heartbeat:    Log ("Heartbeat"         ); break;
        default:
            //Go through unsigned char so a byte >= 0x80 prints as two digits (e.g. "80")
            //instead of being sign-extended to "FF80" on targets where char is signed.
            LogF("%02X", (unsigned)(unsigned char)contentType);
            break;
    }
}
//Logs the name of a TLS handshake message type; an unknown type is logged as two hex digits.
static void logHandshakeType(char handshakeType)
{
    switch (handshakeType)
    {
        case TLS_HANDSHAKE_HelloRequest:        Log ("Hello request"       ); break;
        case TLS_HANDSHAKE_ClientHello:         Log ("Client hello"        ); break;
        case TLS_HANDSHAKE_ServerHello:         Log ("Server hello"        ); break;
        case TLS_HANDSHAKE_NewSessionTicket:    Log ("New session ticket"  ); break;
        case TLS_HANDSHAKE_EncryptedExtensions: Log ("Encrypted extensions"); break;
        case TLS_HANDSHAKE_Certificate:         Log ("Certificate"         ); break;
        case TLS_HANDSHAKE_ServerKeyExchange:   Log ("Server key exchange" ); break;
        case TLS_HANDSHAKE_CertificateRequest:  Log ("Certificate request" ); break;
        case TLS_HANDSHAKE_ServerHelloDone:     Log ("Server hello done"   ); break;
        case TLS_HANDSHAKE_CertificateVerify:   Log ("Certificate verify"  ); break;
        case TLS_HANDSHAKE_ClientKeyExchange:   Log ("Client key exchange" ); break;
        case TLS_HANDSHAKE_Finished:            Log ("Finished"            ); break;
        default:
            //Go through unsigned char so a byte >= 0x80 prints as two digits (e.g. "80")
            //instead of being sign-extended to "FF80" on targets where char is signed.
            LogF("%02X", (unsigned)(unsigned char)handshakeType);
            break;
    }
}
//Logs the name of a TLS alert level; an unknown level is logged as a decimal byte value.
static void logAlertLevel(char level)
{
    switch (level)
    {
        case  1: Log ("Warning"  ); break;
        case  2: Log ("Fatal"    ); break;
        default:
            //unsigned char so a byte >= 0x80 is not printed as negative where char is signed
            LogF("%u", (unsigned)(unsigned char)level);
            break;
    }
}
//Logs the name of a TLS alert description; an unknown description is logged as a decimal byte value.
static void logAlertDescription(char description)
{
    switch (description)
    {
        case   0: Log("Close notify"                   ); break;
        case  10: Log("Unexpected message"             ); break;
        case  20: Log("Bad record MAC"                 ); break;
        case  21: Log("Decryption failed"              ); break;
        case  22: Log("Record overflow"                ); break;
        case  30: Log("Decompression failure"          ); break;
        case  40: Log("Handshake failure"              ); break;
        case  41: Log("No certificate"                 ); break;
        case  42: Log("Bad certificate"                ); break;
        case  43: Log("Unsupported certificate"        ); break;
        case  44: Log("Certificate revoked"            ); break;
        case  45: Log("Certificate expired"            ); break;
        case  46: Log("Certificate unknown"            ); break;
        case  47: Log("Illegal parameter"              ); break;
        case  48: Log("Unknown CA"                     ); break;
        case  49: Log("Access denied"                  ); break;
        case  50: Log("Decode error"                   ); break;
        case  51: Log("Decrypt error"                  ); break;
        case  60: Log("Export restriction"             ); break;
        case  70: Log("Protocol version"               ); break;
        case  71: Log("Insufficient security"          ); break;
        case  80: Log("Internal error"                 ); break;
        case  86: Log("Inappropriate Fallback"         ); break;
        case  90: Log("User cancelled"                 ); break;
        case 100: Log("No renegotiation"               ); break;
        case 110: Log("Unsupported extension"          ); break;
        case 111: Log("Certificate unobtainable"       ); break;
        case 112: Log("Unrecognized name"              ); break;
        case 113: Log("Bad certificate status response"); break;
        case 114: Log("Bad certificate hash value"     ); break;
        case 115: Log("Unknown PSK identity"           ); break;
        case 120: Log("No Application Protocol"        ); break;
        default:
            //unsigned char so a byte >= 0x80 is not printed as negative where char is signed
            LogF("%u", (unsigned)(unsigned char)description);
            break;
    }
}
//Handles the body of a client hello handshake message (the bytes after its 4 byte header).
//The body starts with the 2 byte client version and then the 32 byte client random. These
//are followed by the session id, cipher suites, compression methods and extensions, which
//are not parsed here. The random is copied into clientHelloRandom.
//Returns 0 on success; -1 if the message is too short to contain the version and the random.
static int handleClientHello(int length, char* pBuffer) //returns 0 on success; -1 on error
{
    const int versionSize = 2;  //client_version precedes the random
    const int randomSize  = 32;
    if (length < versionSize + randomSize)
    {
        LogF("TLS - %d byte client hello message is too short to hold a version and a 32 byte random\r\n", length);
        return -1;
    }
    const char* pRandom = pBuffer + versionSize;
    for (int i = 0; i < randomSize; i++) clientHelloRandom[i] = pRandom[i];
    if (TlsTrace)
    {
        Log("- random:\r\n");
        LogBytesAsHex(clientHelloRandom, randomSize);
        Log("\r\n");
    }
    return 0;
}
//Handles the body of a client key exchange handshake message. The body is a 2 byte
//big-endian length followed by the encrypted pre-master secret. Only a 130 byte body
//is accepted: 2 length bytes plus 128 encrypted bytes.
//NOTE(review): 128 presumably matches the size of the server's RSA key — confirm.
//On success, PriKeyDecryptStart begins decrypting into paddedMasterSecret; TlsPoll then
//waits for PriKeyDecryptFinished().
//Returns 0 on success; -1 if either length is not the one expected.
static int handleClientKeyExchange(int length, char* pBuffer) //returns 0 on success; -1 on error
{
    if (length != 130)
    {
        LogF("TLS - %d byte client key exchange message is not 130 bytes long\r\n", length);
        return -1;
    }
    //Read the length through unsigned char: where char is signed the 0x80 in 0x0080 (128)
    //would otherwise sign-extend and the expected value could never be matched.
    const unsigned char* pBytes = (const unsigned char*)pBuffer;
    int premasterLength = pBytes[0] << 8 | pBytes[1]; //Overall length 2 bytes
    if (premasterLength != 128)
    {
        LogF("TLS - %d byte encrypted pre master secret is not 128 bytes long\r\n", premasterLength);
        return -1;
    }
    char* pEncryptedPreMasterSecret = pBuffer + 2;
    PriKeyDecryptStart(pEncryptedPreMasterSecret, paddedMasterSecret);
    
    if (TlsTrace)
    {
        LogF("- encrypted premaster (%d bytes little endian)\r\n", premasterLength);
        LogBytesAsHex(pEncryptedPreMasterSecret, 128);
        Log("\r\n");
    }
    
    return 0;
}
//Walks the handshake messages in the body of one handshake record and dispatches those this
//server understands. Each message is a 1 byte type, a 3 byte big-endian length and then that
//many body bytes.
//A client hello or client key exchange sets *pToDo to the next step, or to
//DO_SEND_ALERT_ILLEGAL_PARAMETER if its handler rejects it. Other types are logged and skipped.
//If a header or body would run past pBuffer + length, parsing stops there, so no handler is
//ever given bytes outside the record.
static void handleHandshake(int length, char* pBuffer, int* pToDo)
{
    char* p    = pBuffer;
    char* pEnd = pBuffer + (length > 0 ? length : 0);
    while (p < pEnd)
    {
        if (pEnd - p < 4)
        {
            LogF("TLS - %d bytes left in the record are too few for a handshake header\r\n", (int)(pEnd - p));
            return;
        }
        //Read each header byte on its own, unsigned, before moving p: the length bytes must not
        //be sign-extended, and p must not be modified more than once within one expression.
        const unsigned char* pHeader = (const unsigned char*)p;
        char handshakeType   = p[0];
        int  handshakeLength = pHeader[1] << 16 | pHeader[2] << 8 | pHeader[3]; //Handshake length 3 bytes
        p += 4;
        if (TlsTrace)
        {
            Log ("- handshake type: "); logHandshakeType(handshakeType); Log("\r\n");
            LogF("- handshake length: %d\r\n", handshakeLength);
        }
        if (handshakeLength > pEnd - p)
        {
            LogF("TLS - %d byte handshake message overruns the %d bytes left in the record\r\n", handshakeLength, (int)(pEnd - p));
            return;
        }
        int r = -1;
        switch (handshakeType)
        {
            case TLS_HANDSHAKE_ClientHello:
                r = handleClientHello(handshakeLength, p);
                *pToDo = r ? DO_SEND_ALERT_ILLEGAL_PARAMETER : DO_SEND_SERVER_HELLO;
                break;
                
            case TLS_HANDSHAKE_ClientKeyExchange:
                r = handleClientKeyExchange(handshakeLength, p);
                *pToDo = r ? DO_SEND_ALERT_ILLEGAL_PARAMETER : DO_WAIT_DECRYPT_MASTER_SECRET;
                break;
                
            default:
                LogF("TLS - ignoring untreated %d byte handshake type ", handshakeLength);
                logHandshakeType(handshakeType);
                Log("\r\n");
                break;
        }
        p += handshakeLength;
    }
}
//Handles the body of an alert record: a 1 byte level followed by a 1 byte description.
//Both are logged when tracing; nothing else is done with the alert yet.
static void handleAlert(int length, char* pBuffer)
{
    if (length < 2)
    {
        LogF("TLS - %d byte alert is too short to hold a level and a description\r\n", length);
        return;
    }
    char level       = pBuffer[0];
    char description = pBuffer[1];
    if (TlsTrace)
    {
        Log("- alert level:       "); logAlertLevel      (level);       Log("\r\n");
        Log("- alert description: "); logAlertDescription(description); Log("\r\n");
    }
}
//Handles the body of an application data record. For now the bytes are only
//dumped as hex, and only when tracing is enabled.
static void handleApplication(int length, char* pBuffer)
{
    if (!TlsTrace) return;

    Log("- application data:\r\n");
    LogBytesAsHex(pBuffer, length);
    Log("\r\n");
}
//Handles bytes received from the client on a TLS connection.
//Decodes the 5 byte record header at the start of pRequestStream: content type, 2 byte legacy
//version and 2 byte big-endian body length. The body is then passed to the handler for that
//content type, and the connection's toDo state is updated. Only the first record in
//pRequestStream is looked at.
//A request shorter than a record header is logged and ignored. A body length beyond the bytes
//actually received is clamped to them, so the handlers never read past pRequestStream + size.
//pWebState is not used here.
void TlsRequest(char* pTlsState, char* pWebState, int size, char* pRequestStream, uint32_t positionInRequestStream)
{
    struct state* pState = (struct state*)pTlsState;
    
    if (TlsTrace) LogF("TLS <<< %d (%u)\r\n", size, (unsigned)positionInRequestStream);

    if (size == 0) return;
    //if (positionInRequestStream != 0) return;
    if (size < 5)
    {
        LogF("TLS - %d byte request is too short to hold a record header\r\n", size);
        return;
    }
    //Header bytes are read unsigned so the version prints as two hex digits and the length is never negative
    const unsigned char* pHeader = (const unsigned char*)pRequestStream;
    char     contentType = pRequestStream[0];
    unsigned versionH    = pHeader[1];
    unsigned versionL    = pHeader[2];
    int      length      = pHeader[3] << 8 | pHeader[4]; //Length (2 bytes)
    if (TlsTrace)
    {
        Log ("- content type: "); logContentType(contentType); Log("\r\n");
        LogF("- legacy HH:LL: %02x:%02x\r\n", versionH, versionL);
        LogF("- length      : %d\r\n"       , length);
    }
    if (length > size - 5)
    {
        LogF("TLS - %d byte record body exceeds the %d bytes received; handling only those\r\n", length, size - 5);
        length = size - 5;
    }
    char* pBody = pRequestStream + 5;
    switch (contentType)
    {
        case TLS_CONTENT_TYPE_Handshake:
        {
            handleHandshake(length, pBody, &pState->toDo);
            return;
        }
        case TLS_CONTENT_TYPE_Alert:
        {
            handleAlert(length, pBody);
            return;
        }
        case TLS_CONTENT_TYPE_Application:
        {
            handleApplication(length, pBody);
            pState->toDo = DO_APPLICATION;
            return;
        }
        
        default:
            //NOTE(review): a client ChangeCipher record (type 20) also lands here and resets the
            //handshake to DO_WAIT_CLIENT_HELLO — confirm that is intended.
            Log("TLS - untreated content type "); logContentType(contentType); Log("\r\n");
            pState->toDo = DO_WAIT_CLIENT_HELLO;
            return;
    }
}
//Returns the high byte of a 16 bit length (size shifted right by 8, narrowed to char).
char lengthH(int size)
{
    int highByte = size >> 8;
    return highByte;
}
//Returns the low byte of a 16 bit length (size masked with 0xFF, narrowed to char).
char lengthL(int size)
{
    int lowByte = size & 0xFF;
    return lowByte;
}
//Appends size to the TCP output buffer as a 2 byte big-endian value, high byte first.
//Used for the TLS record and handshake lengths written by the send functions.
void addSize(int size)
{
    int high = size >> 8;
    int low  = size & 0xFF;
    TcpBufAddChar(high);
    TcpBufAddChar(low);
}

//Writes the server's first flight to the TCP output buffer as a single handshake record.
//The record holds three handshake messages: ServerHello, Certificate and ServerHelloDone.
//Each handshake message has a 1 byte type and a 3 byte length. The high length byte is
//written as a literal 0x00 and the rest by addSize, so every length is assumed to fit in
//2 bytes (i.e. SerCerSize + 10 < 65536).
//The server hello body is 45 bytes:
//  version 2 + random 32 + session id length 1 + cipher suite 2 + compression 1
//  + extensions length 2 + renegotiation info extension 5.
//NOTE(review): the 32 byte server random is sent but not stored anywhere; presumably it will be
//needed alongside clientHelloRandom to derive the keys — confirm.
//NOTE(review): assumes RandomGetByte() is suitable for cryptographic use — confirm.
static void sendServerHello()
{
    Log("     sending server hello\r\n");
    TcpBufAddChar(TLS_CONTENT_TYPE_Handshake);                                   //Content is handshakes
    TcpBufAddChar(0x03); TcpBufAddChar(0x03);                                    //Legacy TLS version
    addSize((45 + 4) + (SerCerSize + 6 + 4) + (0 + 4));                          //Record length (2 bytes): each handshake body plus its 4 byte header
    
    TcpBufAddChar(TLS_HANDSHAKE_ServerHello); TcpBufAddChar(0x00);               //Handshake type server hello; high byte of the 3 byte length
    addSize(45);                                                                 //Size of this handshake (low 2 bytes)
    TcpBufAddChar(0x03); TcpBufAddChar(0x03);                                    //TLS version 1.2
    for (int i = 0; i < 32; i++) TcpBufAddChar(RandomGetByte());                 //32 byte server random
    TcpBufAddChar(0x00);                                                         //SessionId length 0
    TcpBufAddChar(0x00); TcpBufAddChar(0x2f);                                    //Cipher Suite: TLS_RSA_WITH_AES_128_CBC_SHA (0x002f)
    TcpBufAddChar(0x00);                                                         //Compression method none
    TcpBufAddChar(0x00); TcpBufAddChar(0x05);                                    //Extensions length (2 bytes) 5 bytes
    TcpBufAddChar(0xff); TcpBufAddChar(0x01);                                    //Extension Renegotiation Info
    TcpBufAddChar(0x00); TcpBufAddChar(0x01);                                    //1 byte of "Renegotiation Info" extension data follows
    TcpBufAddChar(0x00);                                                         //length is zero, because this is a new connection
    
    TcpBufAddChar(TLS_HANDSHAKE_Certificate); TcpBufAddChar(0x00);               //Handshake type certificate; high byte of the 3 byte length
    addSize(SerCerSize + 6); TcpBufAddChar(0x00);                                //Size of this handshake (low 2 bytes); high byte of the certificate list length
    addSize(SerCerSize + 3); TcpBufAddChar(0x00);                                //Size of all certificates (low 2 bytes); high byte of the first certificate's length
    addSize(SerCerSize    );                                                     //Size of first certificate (low 2 bytes)
    for (int i = 0; i < SerCerSize; i++) TcpBufAddChar(SerCerData[i]);           //Certificate
    
    TcpBufAddChar(TLS_HANDSHAKE_ServerHelloDone); TcpBufAddChar(0x00);           //Handshake type server hello done; high byte of the 3 byte length
    addSize(0);                                                                  //Size of this handshake: empty body
}
//Called by TlsReply in the DO_SEND_SERVER_CHANGE state.
//NOTE(review): currently it only logs. Nothing is written to the TCP buffer (no change cipher
//record and no finished message), so this is presumably still to be implemented.
static void sendServerChange()
{
    Log("     sending server change\r\n");
}
//Logs, then writes to the TCP output buffer, a single alert record at the fatal level
//carrying the given alert description.
static void sendFatal(char description)
{
    const char fatalLevel = 2;

    Log("     sending fatal alert: ");
    logAlertDescription(description);
    Log("\r\n");

    //Record header: content type, legacy version 3.3, then the 2 byte body length
    TcpBufAddChar(TLS_CONTENT_TYPE_Alert);
    TcpBufAddChar(0x03);
    TcpBufAddChar(0x03);
    addSize(2);

    //Alert body: level then description
    TcpBufAddChar(fatalLevel);
    TcpBufAddChar(description);
}
int TlsPoll(char* pTlsState, char* pWebState, bool clientFinished)
{
    struct state* pState = (struct state*)pTlsState;
    
    switch (pState->toDo)
    {
        case DO_WAIT_CLIENT_HELLO:
        case DO_WAIT_CLIENT_CHANGE:
            if (clientFinished) return -1; //The client hasn't made a request and never will so finish
            else                return  0; //The client hasn't made a request yet but it could.
        
        case DO_WAIT_DECRYPT_MASTER_SECRET:
            if (PriKeyDecryptFinished())
            {
                Log("Master secret\r\n");
                LogBytesAsHex(paddedMasterSecret, sizeof(paddedMasterSecret));
                Log("\r\n");
                pState->toDo = DO_SEND_SERVER_CHANGE;
                return 1;                  //Call TlsReply to do the send
            }
            else
            {
                if (clientFinished) return -1; //The client hasn't made a request and never will so finish
                else                return  0; //The client hasn't made a request yet but it could.
            }
            
        case DO_APPLICATION:
            return HttpPollFunction(pWebState, clientFinished); //Return whatever HTTP would be
            
        case DO_SEND_SERVER_HELLO:
        case DO_SEND_ALERT_ILLEGAL_PARAMETER:
        case DO_SEND_ALERT_INTERNAL_ERROR:
            return 1;
            
        default:
            LogTimeF("TlsPoll - unspecified TLS state %d\r\n", pState->toDo);
            return -1; //Finish
    }
}
bool TlsReply(char* pTlsState, char* pWebState)
{
    struct state* pState = (struct state*)pTlsState;
    
    switch(pState->toDo)
    {
        case DO_SEND_SERVER_HELLO:
            sendServerHello();
            pState->toDo = DO_WAIT_CLIENT_CHANGE;
            return false; //Not finished
            
        case DO_SEND_SERVER_CHANGE:
            sendServerChange();
            pState->toDo = DO_APPLICATION;
            return false; //Not finished
            
        case DO_APPLICATION:
            return HttpReplyFunction(pWebState); //Return whatever HTTP would be
            
        case DO_SEND_ALERT_ILLEGAL_PARAMETER:
            sendFatal(47);
            pState->toDo = DO_WAIT_CLIENT_HELLO;
            return true; //Finished
            
        case DO_SEND_ALERT_INTERNAL_ERROR:
            sendFatal(80);
            pState->toDo = DO_WAIT_CLIENT_HELLO;
            return true; //Finished
        
        default:
            LogTimeF("TlsReply - unspecified TLS state %d\r\n", pState->toDo);
            return true; //Finished
    }
}
//Placeholder for record encryption: currently returns c unchanged, so TlsAddChar
//writes its bytes to the TCP buffer as plaintext.
static char encrypt(char c)
{
    return c; //Implement encryption
}
//Appends one byte of outgoing data to the TCP output buffer after passing it through encrypt().
void TlsAddChar(char c)
{
    TcpBufAddChar(encrypt(c));
}