A simple library to support serving HTTPS.

Dependents:   oldheating gps motorhome heating

tls/tls-request.c

Committer:
andrewboyson
Date:
2019-09-05
Revision:
7:94ef5824c3c0
Parent:
6:819c17738dc2
Child:
8:5e66a6b4b38c

File content as of revision 7:94ef5824c3c0:

#include "tls.h"
#include "tls-defs.h"
#include "tls-connection.h"
#include "tls-session.h"
#include "tls-log.h"
#include "mstimer.h"
#include "log.h"
#include "pri-key.h"
#include "aes128.h"

//Parses the body of a ClientHello handshake message and attaches a session to the connection.
//  length      - number of bytes in the handshake body; messages shorter than 100 bytes are rejected
//  pBuffer     - start of the handshake body: 2 version bytes, 32 random bytes, 1 session id length byte, session id...
//  pConnection - receives the client random and the index of the session chosen for this connection
//A one byte session id is treated as the index of an existing session; if that session does not exist
//or is not in the valid state then the oldest session is taken over and marked as started.
static int handleClientHello(int length, uint8_t* pBuffer, struct TlsConnection* pConnection) //returns 0 on success; -1 on error
{   
    //Check things look ok
    uint8_t* p = pBuffer;
    if (length < 100)
    {
        LogF("TLS - %d byte client hello message is not at least 100 bytes long\r\n", length);
        return -1;
    }
    
    //Read in the parameters
    uint8_t versionH         = *p++;
    uint8_t versionL         = *p++;
    
    for (int i = 0; i < 32; i++) pConnection->clientRandom[i] = *p++;
    
    int sessionIdLength = *p++;
    uint8_t* pSessionId = p;
    
    //The session id length is supplied by the peer so make sure the id lies inside the message before it is read or logged
    int bytesAfterSessionIdLength = length - (int)(p - pBuffer);
    if (sessionIdLength > bytesAfterSessionIdLength)
    {
        LogF("TLS - %d byte client hello session id does not fit in the %d byte message\r\n", sessionIdLength, length);
        return -1;
    }
    
    //Handle the parameters
    pConnection->session = -1;
    if (sessionIdLength == 1) pConnection->session = *pSessionId;
    struct TlsSession* pSession = TlsSessionOrNull(pConnection->session);
    if (!pSession || pSession->state != TLS_SESSION_STATE_VALID)
    {
        //No usable session was offered so recycle the oldest one
        pSession = TlsSessionGetOldest(); //NOTE(review): assumed never to return NULL - confirm against tls-session
        pSession->state = TLS_SESSION_STATE_STARTED;
    }
    pConnection->session = TlsSessionGetIndex(pSession);

    pSession->lastUsed = MsTimerCount;
    
    //Log the parameters
    if (TlsTrace)
    {
        LogF("- client version HH:LL: %02x:%02x\r\n", versionH, versionL);
        Log ("- client random:\r\n");     LogBytesAsHex(pConnection->clientRandom, 32); Log("\r\n");
        Log ("- client session id:\r\n"); LogBytesAsHex(pSessionId, sessionIdLength); Log("\r\n");
        LogF("- session index: %d\r\n",  pConnection->session);
    }
    return 0;
}
//Parses the body of a ClientKeyExchange handshake message and starts decryption of the pre master secret.
//  length      - number of bytes in the handshake body; must be exactly 130 (2 length bytes + 128 encrypted bytes)
//  pBuffer     - start of the handshake body
//  pConnection - connection whose session receives the decryption slot returned by PriKeyDecryptStart
//The result of the decryption is not available on return; only the slot is recorded in the session.
static int handleClientKeyExchange(int length, uint8_t* pBuffer, struct TlsConnection* pConnection) //returns 0 on success; -1 on error
{
    struct TlsSession* pSession = TlsSessionOrNull(pConnection->session);
    if (!pSession)
    {
        LogTimeF("handleClientKeyExchange - invalid session %d\r\n", pConnection->session);
        return -1;
    }
    
    if (length != 130)
    {
        LogF("TLS - %d byte client key exchange message is not 130 bytes long\r\n", length);
        return -1;
    }
    int premasterLength = pBuffer[0] << 8 | pBuffer[1]; //Overall length 2 bytes
    if (premasterLength != 128)
    {
        //Report the length which actually failed the check rather than the message length
        LogF("TLS - %d byte encrypted pre master secret is not 128 bytes long\r\n", premasterLength);
        return -1;
    }
    uint8_t* pEncryptedPreMasterSecret = pBuffer + 2;
    pSession->slotPriKeyDecryption = PriKeyDecryptStart(pEncryptedPreMasterSecret);
    
    if (TlsTrace)
    {
        Log("- encrypted pre master\r\n");
        LogBytesAsHex(pEncryptedPreMasterSecret, premasterLength);
        Log("\r\n");
    }
    
    return 0;
}
//Handles a change cipher record: optionally traces its first byte then marks the connection
//so that subsequent content from the client is treated as encrypted.
//  length      - record payload length (not used)
//  pBuffer     - record payload; only the first byte is read
//  pConnection - connection whose clientEncrypted flag is set
static void changeCipher(int length, uint8_t* pBuffer, struct TlsConnection* pConnection)
{
    uint8_t changeCipherSpec = *pBuffer;
    
    if (TlsTrace) LogF("- message: %d\r\n", changeCipherSpec);
    
    pConnection->clientEncrypted = true;
}
//Traces the level and description carried by an alert record; no connection state is changed.
//  length  - record payload length; at least 2 bytes are needed (level then description)
//  pBuffer - record payload
static void handleAlert(int length, uint8_t* pBuffer)
{
    //Do not read the two alert bytes from a record which is too short to contain them
    if (length < 2)
    {
        LogF("TLS - %d byte alert message is not at least 2 bytes long\r\n", length);
        return;
    }
    
    uint8_t level       = pBuffer[0];
    uint8_t description = pBuffer[1];
    if (TlsTrace)
    {
        Log("- alert level:       "); TlsLogAlertLevel      (level);       Log("\r\n");
        Log("- alert description: "); TlsLogAlertDescription(description); Log("\r\n");
    }
}
//Traces the payload of an application data record as hex; does nothing when tracing is off.
//  length  - number of payload bytes to dump
//  pBuffer - record payload
static void handleApplication(int length, uint8_t* pBuffer)
{
    if (!TlsTrace) return;
    
    Log("- application data:\r\n");
    LogBytesAsHex(pBuffer, length);
    Log("\r\n");
}

//Handles the payload of an unencrypted handshake record.
//The whole payload is added to the connection's running handshake hash and then each handshake message
//in it (1 type byte, 3 length bytes, body) is dispatched to its handler, which decides what to do next.
//  length      - number of bytes in the record payload
//  pBuffer     - start of the record payload
//  pConnection - connection whose handshake hash and toDo are updated
//A message whose header or declared length does not fit in the payload stops the parsing and requests
//an illegal parameter alert instead of letting a handler read beyond the record.
static void handleHandshake(int length, uint8_t* pBuffer, struct TlsConnection* pConnection)
{
    Sha256Add(&pConnection->handshakeHash, pBuffer, length);
    
    uint8_t* p    = pBuffer;
    uint8_t* pEnd = pBuffer + length;
    while (p < pEnd)
    {
        //The type and length take 4 bytes so make sure they are all there before reading them
        if (pEnd - p < 4)
        {
            LogF("TLS - %d byte handshake header is not 4 bytes long\r\n", (int)(pEnd - p));
            pConnection->toDo = DO_SEND_ALERT_ILLEGAL_PARAMETER;
            return;
        }
        
        uint8_t handshakeType    = *p++;
        int     handshakeLength  = *p++ << 16;
                handshakeLength |= *p++ <<  8;
                handshakeLength |= *p++      ; //Handshake length 3 bytes
             
        if (TlsTrace)
        {
            Log ("- handshake type: "); TlsLogHandshakeType(handshakeType); Log("\r\n");
            LogF("- handshake length: %d\r\n", handshakeLength);
        }
        
        //The length is supplied by the peer so make sure the body lies inside the record before handing it on
        if (handshakeLength > pEnd - p)
        {
            LogF("TLS - %d byte handshake message does not fit in the remaining %d bytes\r\n", handshakeLength, (int)(pEnd - p));
            pConnection->toDo = DO_SEND_ALERT_ILLEGAL_PARAMETER;
            return;
        }
        
        switch (handshakeType)
        {
            case TLS_HANDSHAKE_ClientHello:
                pConnection->toDo = handleClientHello(handshakeLength, p, pConnection)
                                  ? DO_SEND_ALERT_ILLEGAL_PARAMETER
                                  : DO_SEND_SERVER_HELLO;
                break;
                
            case TLS_HANDSHAKE_ClientKeyExchange:
                pConnection->toDo = handleClientKeyExchange(handshakeLength, p, pConnection)
                                  ? DO_SEND_ALERT_ILLEGAL_PARAMETER
                                  : DO_WAIT_DECRYPT_MASTER_SECRET;
                break;
                
            default:
                LogF("TLS - ignoring handshake type ");
                TlsLogHandshakeType(handshakeType);
                LogF(" and skipping %d bytes\r\n", handshakeLength);
                break;
        }
        p += handshakeLength;
    }
}
//Handles one item of content (a record) from the request stream and dispatches it by content type.
//  pConnection - connection the record belongs to
//  pBuffer     - start of the record: 1 content type byte, 2 legacy version bytes, 2 length bytes, then the payload
//Returns the number of bytes occupied by the record, that is the payload length plus the 5 byte header,
//so that the caller can step on to the next record.
//NOTE(review): the declared length is not compared with the number of bytes actually received - the caller does not pass it in.
static int handleContent(struct TlsConnection* pConnection, uint8_t* pBuffer)
{
    //Decode the 5 byte header
    uint8_t  contentType = pBuffer[0];
    uint8_t  versionH    = pBuffer[1];
    uint8_t  versionL    = pBuffer[2];
    int      length      = pBuffer[3] << 8 | pBuffer[4];
    uint8_t* pPayload    = pBuffer + 5;
    int      overallLen  = length + 5;
            
    if (TlsTrace)
    {
        Log ("- content type: "); TlsLogContentType(contentType); Log("\r\n");
        LogF("- legacy HH:LL: %02x:%02x\r\n", versionH, versionL);
        LogF("- length      : %d\r\n"       , length);
    }
    
    switch (contentType)
    {
        case TLS_CONTENT_TYPE_Handshake:
            if (!pConnection->clientEncrypted)
            {
                //Plain handshake messages are parsed straight away
                handleHandshake(length, pPayload, pConnection);
                break;
            }
            
            //After the change cipher the handshake payload is kept as is; these bytes are logged even when tracing is off
            Log("- encrypted bytes\r\n");
            LogBytesAsHex(pPayload, length);
            Log("\r\n");
            if (length == 64)
            {
                for (int i = 0; i < 64; i++) pConnection->clientVerify[i] = pPayload[i];
            }
            else
            {
                LogF("- verify length is %d not 64\r\n", length);
            }
            break;

        case TLS_CONTENT_TYPE_CHANGE_CIPHER:
            changeCipher(length, pPayload, pConnection);
            break;

        case TLS_CONTENT_TYPE_ALERT:
            handleAlert(length, pPayload);
            break;

        case TLS_CONTENT_TYPE_Application:
            handleApplication(length, pPayload);
            pConnection->toDo = DO_APPLICATION;
            break;
        
        default:
            //Unknown content is skipped and the connection goes back to waiting for a client hello
            Log("TLS - ignoring content type ");
            TlsLogContentType(contentType);
            LogF(" and skipping %d bytes\r\n", overallLen);
            pConnection->toDo = DO_WAIT_CLIENT_HELLO;
            break;
    }
    return overallLen;
}
//Entry point for bytes received from a client.
//  connectionId            - identifies the connection the bytes arrived on
//  size                    - number of bytes available at pRequestStream
//  pRequestStream          - the received bytes, holding one or more records back to back
//  positionInRequestStream - zero for the first bytes of a request, which creates a new connection and
//                            restarts its handshake hash; non zero looks up the existing connection
//Nothing is processed if an existing connection is expected but cannot be found.
void TlsRequest(int connectionId, int size, uint8_t* pRequestStream, uint32_t positionInRequestStream)
{   
    if (TlsTrace) LogF("TLS %d <<< %d (%u)\r\n", connectionId, size, positionInRequestStream);
    
    //Find the connection: continue an existing one or start afresh
    struct TlsConnection* pConnection;
    if (positionInRequestStream)
    {
        pConnection = TlsConnectionOrNull(connectionId);
        if (!pConnection)
        {
            LogTimeF("TlsRequest - no connection corresponds to id %d\r\n", connectionId);
            return;
        }
    }
    else
    {
        pConnection = TlsConnectionNew(connectionId);
        Sha256Start(&pConnection->handshakeHash);
    }
        
    //Walk the coalesced records; each one reports how many bytes it occupied
    uint8_t* pEnd     = pRequestStream + size;
    uint8_t* pContent = pRequestStream;
    while (pContent < pEnd)
    {
        int consumed = handleContent(pConnection, pContent);
        pContent += consumed;
    }
}
//Resets the connection with the given id by handing the request straight on to TlsConnectionReset.
void TlsReset(int connectionId)
{
    TlsConnectionReset(connectionId);
}