Andrew Boyson / net

Dependents:   oldheating gps motorhome heating

tcp/tls/tls.c

Committer:
andrewboyson
Date:
2019-06-27
Revision:
150:3366e4a0c60e
Parent:
149:39d1ba392f4b
Child:
151:bde6f7da1755

File content as of revision 150:3366e4a0c60e:

#include <stdbool.h>

#include "http.h"
#include "tcpbuf.h"
#include "action.h"
#include "net.h"
#include "log.h"
#include "led.h"
#include "restart.h"
#include "mstimer.h"
#include "random.h"
#include "pri-key.h"

//Values of the content type byte at the start of a TLS record (read from byte 0 of the request stream in TlsRequest)
#define TLS_CONTENT_TYPE_ChangeCipher      20
#define TLS_CONTENT_TYPE_Alert             21
#define TLS_CONTENT_TYPE_Handshake         22
#define TLS_CONTENT_TYPE_Application       23
#define TLS_CONTENT_TYPE_Heartbeat         24

//Values of the handshake type byte which starts each handshake message (read from byte 5 of the request stream in TlsRequest)
#define TLS_HANDSHAKE_HelloRequest          0
#define TLS_HANDSHAKE_ClientHello           1
#define TLS_HANDSHAKE_ServerHello           2
#define TLS_HANDSHAKE_NewSessionTicket      4
#define TLS_HANDSHAKE_EncryptedExtensions   8
#define TLS_HANDSHAKE_Certificate          11
#define TLS_HANDSHAKE_ServerKeyExchange    12
#define TLS_HANDSHAKE_CertificateRequest   13
#define TLS_HANDSHAKE_ServerHelloDone      14
#define TLS_HANDSHAKE_CertificateVerify    15
#define TLS_HANDSHAKE_ClientKeyExchange    16
#define TLS_HANDSHAKE_Finished             20

//Values held in the toDo member of struct state to record what the connection should do next.
//NOTE(review): DO_WAIT_CLIENT_CHANGE and DO_SEND_SERVER_CHANGE are not assigned anywhere in this file as it stands.
#define DO_WAIT_CLIENT_HELLO  0
#define DO_SEND_SERVER_HELLO  1
#define DO_WAIT_CLIENT_CHANGE 2
#define DO_SEND_SERVER_CHANGE 3
#define DO_APPLICATION        4

//When true TlsRequest logs the details of each incoming record
bool TlsTrace = true;

//The bytes of the server certificate, pulled in from certificate.inc at compile time.
//sendServerHello sends them unchanged inside the certificate handshake message.
//NOTE(review): presumably a DER encoded X.509 certificate - confirm against how certificate.inc is generated.
static const char certificate[] = {
#include "certificate.inc"
};

//Initialises the TLS module. At present this only calls PriKeyInit.
//NOTE(review): PriKeyInit presumably prepares the private key declared in pri-key.h - confirm.
void TlsInit()
{
    PriKeyInit();
}


//Per connection TLS state. TlsRequest, TlsPoll and TlsReply each cast their pTlsState argument to a pointer to this.
struct state
{
    int      toDo; //One of the DO_ values: set by TlsRequest, read by TlsPoll and TlsReply
};
//Logs a readable name for a TLS record content type.
//A value which is not recognised is logged as two upper case hex digits.
static void logContentType(char contentType)
{
    switch (contentType)
    {
        case TLS_CONTENT_TYPE_ChangeCipher: Log ("Change cipher");      break;
        case TLS_CONTENT_TYPE_Alert:        Log ("Alert");              break;
        case TLS_CONTENT_TYPE_Handshake:    Log ("Handshake");          break;
        case TLS_CONTENT_TYPE_Application:  Log ("Application");        break;
        case TLS_CONTENT_TYPE_Heartbeat:    Log ("Heartbeat");          break;
        //Go via unsigned char so that a byte of 0x80 or more is not sign extended (and so
        //printed as four digits such as FF80) on platforms where plain char is signed.
        default:                            LogF("%02X", (unsigned)(unsigned char)contentType); break;
    }
}
//Logs a readable name for a TLS handshake message type.
//A value which is not recognised is logged as two upper case hex digits.
static void logHandshakeType(char handshakeType)
{
    switch (handshakeType)
    {
        case TLS_HANDSHAKE_HelloRequest:        Log ("Hello request");        break;
        case TLS_HANDSHAKE_ClientHello:         Log ("Client hello");         break;
        case TLS_HANDSHAKE_ServerHello:         Log ("Server hello");         break;
        case TLS_HANDSHAKE_NewSessionTicket:    Log ("New session ticket");   break;
        case TLS_HANDSHAKE_EncryptedExtensions: Log ("Encrypted extensions"); break;
        case TLS_HANDSHAKE_Certificate:         Log ("Certificate");          break;
        case TLS_HANDSHAKE_ServerKeyExchange:   Log ("Server key exchange");  break;
        case TLS_HANDSHAKE_CertificateRequest:  Log ("Certificate request");  break;
        case TLS_HANDSHAKE_ServerHelloDone:     Log ("Server hello done");    break;
        case TLS_HANDSHAKE_CertificateVerify:   Log ("Certificate verify");   break;
        case TLS_HANDSHAKE_ClientKeyExchange:   Log ("Client key exchange");  break;
        case TLS_HANDSHAKE_Finished:            Log ("Finished");             break;
        //Go via unsigned char so that a byte of 0x80 or more is not sign extended (and so
        //printed as four digits such as FF80) on platforms where plain char is signed.
        default:                                LogF("%02X", (unsigned)(unsigned char)handshakeType); break;
    }
}
//Examines incoming TLS data for one connection and records in the TLS state what to do next.
//  pTlsState               per connection storage, used here as a struct state
//  pWebState               not used by this function
//  size                    number of bytes available at pRequestStream
//  pRequestStream          the received bytes
//  positionInRequestStream offset of these bytes within the stream; only data at position 0 is examined
//A handshake record sets toDo to DO_SEND_SERVER_HELLO whatever its handshake type,
//an application record sets it to DO_APPLICATION and any other content type resets it to DO_WAIT_CLIENT_HELLO.
//Data which is too short to hold a record header leaves the state untouched.
void TlsRequest(char* pTlsState, char* pWebState, int size, char* pRequestStream, uint32_t positionInRequestStream)
{
    struct state* pState = (struct state*)pTlsState;
    
    //uint32_t need not be unsigned int so convert explicitly to match the format specifier
    if (TlsTrace) LogF("TLS <<< %d (%lu)\r\n", size, (unsigned long)positionInRequestStream);

    if (size == 0) return;
    if (positionInRequestStream != 0) return;
    
    //The record header read below is five bytes: content type, two version bytes and a two byte length.
    //Never read beyond the data which has actually arrived.
    if (size < 5)
    {
        if (TlsTrace) Log("TLS - ignoring incomplete record header\r\n");
        return;
    }
    
    //View the stream as unsigned bytes so that values of 0x80 or more are not sign extended where plain char is signed
    const unsigned char* pBytes = (const unsigned char*)pRequestStream;
    unsigned char contentType = pBytes[0];
    unsigned char versionH    = pBytes[1];
    unsigned char versionL    = pBytes[2];
    int length                = pBytes[3] << 8 | pBytes[4]; //Length (2 bytes), most significant byte first
    if (TlsTrace)
    {
        Log ("   content type: "); logContentType((char)contentType); Log("\r\n");
        LogF("   legacy HH:LL: %02x:%02x\r\n", (unsigned)versionH, (unsigned)versionL);
        LogF("   length      : %d\r\n"       , length);
    }
    switch (contentType)
    {
        case TLS_CONTENT_TYPE_Handshake:
        {
            //The handshake type is only needed for the trace so only read it if it has arrived
            if (TlsTrace && size > 5)
            {
                char handshakeType = (char)pBytes[5];
                Log("      handshake type: "); logHandshakeType(handshakeType); Log("\r\n");
            }
            pState->toDo = DO_SEND_SERVER_HELLO;
            return;
        }
        case TLS_CONTENT_TYPE_Application:
        {
            pState->toDo = DO_APPLICATION;
            return;
        }
        
        default:
            Log("TLS - ignoring untreated content type\r\n");
            pState->toDo = DO_WAIT_CLIENT_HELLO;
            return;
    }
}
//Returns size shifted right by eight bits, converted to char: the high byte of a two byte length
char lengthH(int size)
{
    const int shifted = size >> 8;
    return (char)shifted;
}
//Returns the bottom eight bits of size, converted to char: the low byte of a two byte length
char lengthL(int size)
{
    const int masked = size & 0xFF;
    return (char)masked;
}
//Adds a length to the TCP buffer as two characters: first size shifted right by eight bits, then its bottom eight bits
void addSize(int size)
{
    const int highPart = size >> 8;
    const int lowPart  = size & 0xFF;
    
    TcpBufAddChar(highPart);
    TcpBufAddChar(lowPart);
}

//Adds to the TCP buffer a single handshake record holding three handshake messages in turn:
//a server hello, the certificate and a server hello done.
//Each handshake message starts with a one byte type and a three byte length; the three byte lengths
//are written here as a zero byte followed by the two bytes from addSize.
static void sendServerHello()
{
    const int handshakeHeaderSize = 4;  //Type (1 byte) plus length (3 bytes)
    const int serverHelloBodySize = 45; //Version 2, random 32, session id length 1, cipher suite 2, compression 1, extensions 2 + 5
    const int randomSize          = 32;
    
    Log("     sending server hello\r\n");
    
    //Record header
    TcpBufAddChar(TLS_CONTENT_TYPE_Handshake); //The record holds handshakes
    TcpBufAddChar(0x03);                       //Legacy TLS version
    TcpBufAddChar(0x03);
    addSize((serverHelloBodySize + handshakeHeaderSize) + (sizeof(certificate) + 6 + handshakeHeaderSize) + (0 + handshakeHeaderSize)); //Length of all three handshakes (2 bytes)
    
    //Server hello
    TcpBufAddChar(TLS_HANDSHAKE_ServerHello);
    TcpBufAddChar(0x00);                       //Length of this handshake (3 bytes)
    addSize(serverHelloBodySize);
    TcpBufAddChar(0x03);                       //TLS version 1.2
    TcpBufAddChar(0x03);
    int randomLeft = randomSize;               //32 bytes of random
    while (randomLeft-- > 0) TcpBufAddChar(RandomGetByte());
    TcpBufAddChar(0x00);                       //Session id length 0
    TcpBufAddChar(0x00);                       //Cipher suite: TLS_RSA_WITH_AES_128_CBC_SHA (0x002f)
    TcpBufAddChar(0x2f);
    TcpBufAddChar(0x00);                       //Compression method none
    TcpBufAddChar(0x00);                       //Extensions length (2 bytes): 5 bytes follow
    TcpBufAddChar(0x05);
    TcpBufAddChar(0xff);                       //Extension type: renegotiation info
    TcpBufAddChar(0x01);
    TcpBufAddChar(0x00);                       //Extension data length (2 bytes): 1 byte follows
    TcpBufAddChar(0x01);
    TcpBufAddChar(0x00);                       //Renegotiation info length is zero because this is a new connection
    
    //Certificate
    TcpBufAddChar(TLS_HANDSHAKE_Certificate);
    TcpBufAddChar(0x00);                       //Length of this handshake (3 bytes)
    addSize(sizeof(certificate) + 6);
    TcpBufAddChar(0x00);                       //Length of all the certificates (3 bytes)
    addSize(sizeof(certificate) + 3);
    TcpBufAddChar(0x00);                       //Length of the first and only certificate (3 bytes)
    addSize(sizeof(certificate)    );
    for (const char* p = certificate; p != certificate + sizeof(certificate); p++) TcpBufAddChar(*p);
    
    //Server hello done
    TcpBufAddChar(TLS_HANDSHAKE_ServerHelloDone);
    TcpBufAddChar(0x00);                       //Length of this handshake (3 bytes) which is zero
    addSize(0);
}

int TlsPoll(char* pTlsState, char* pWebState, bool clientFinished)
{
    struct state* pState = (struct state*)pTlsState;
    
    switch (pState->toDo)
    {
        case DO_WAIT_CLIENT_HELLO:
            if (clientFinished) return -1; //The client hasn't made a request and never will so finish
            else                return  0; //The client hasn't made a request yet but it could.
        case DO_APPLICATION:    return HttpPollFunction(pWebState, clientFinished); //Return whatever HTTP would be
        default:                return  1; //The client has made a request so do it
    }
}
bool TlsReply(char* pTlsState, char* pWebState)
{
    struct state* pState = (struct state*)pTlsState;
    
    switch(pState->toDo)
    {
        case DO_SEND_SERVER_HELLO: sendServerHello(); return true;
        default:                                      return true; //Finished
    }
}
//Placeholder for the cipher: hands the character back unchanged
static char encrypt(char c)
{
    const char unchanged = c; //Implement encryption
    return unchanged;
}
//Passes one character of application data through encrypt and adds the result to the TCP buffer
void TlsAddChar(char c)
{
    TcpBufAddChar(encrypt(c));
}