A super trimmed down TLS stack, GPL licensed

Dependents:   MiniTLS-HTTPS-Example

MiniTLS - A super trimmed down TLS/SSL Library for embedded devices Author: Donatien Garnier Copyright (C) 2013-2014 AppNearMe Ltd

This program is free software; you can redistribute it and/or modify it under the terms of the GNU General Public License as published by the Free Software Foundation; either version 2 of the License, or (at your option) any later version.

This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for more details.

You should have received a copy of the GNU General Public License along with this program; if not, write to the Free Software Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.

tls/tls_record.c

Committer:
MiniTLS
Date:
2014-06-10
Revision:
3:eb324ffffd2b
Parent:
2:527a66d0a1a9

File content as of revision 3:eb324ffffd2b:

/*
MiniTLS - A super trimmed down TLS/SSL Library for embedded devices
Author: Donatien Garnier
Copyright (C) 2013-2014 AppNearMe Ltd

This program is free software; you can redistribute it and/or
modify it under the terms of the GNU General Public License
as published by the Free Software Foundation; either version 2
of the License, or (at your option) any later version.

This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public License
along with this program; if not, write to the Free Software
Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.
*//**
 * \file tls_record.c
 * \copyright Copyright (c) AppNearMe Ltd 2013
 * \author Donatien Garnier
 */

#define __DEBUG__ 4
#ifndef __MODULE__
#define __MODULE__ "tls_record.c"
#endif

#include "core/fwk.h"
#include "inc/minitls_config.h"
#include "inc/minitls_errors.h"
#include "tls_record.h"
#include "tls_alert.h"

#include "tls_handshake.h"
#include "tls_socket.h"

#include "socket/socket.h"

#include "crypto/crypto_aes_128_cbc.h"
#include "crypto/crypto_hmac_sha1.h"

//Wait for the record's socket to become readable / writeable, using the record's configured timeouts
static minitls_err_t record_wait_readable(tls_record_t* record);
static minitls_err_t record_wait_writeable(tls_record_t* record);

//Raw socket I/O: read exactly size bytes into record->buffer / send the whole content of data
static minitls_err_t record_socket_read(tls_record_t* record, size_t size);
static minitls_err_t record_socket_write(tls_record_t* record, buffer_t* data);

//HMAC-SHA1 record MAC helpers: append a MAC after the buffer's content / verify and strip the trailing MAC
static minitls_err_t tls_mac_append( const uint8_t* key, tls_content_type_t content_type, tls_protocol_version_t version,
    uint64_t sequence_number, buffer_t* buffer );
static minitls_err_t tls_mac_check( const uint8_t* key, tls_content_type_t content_type, tls_protocol_version_t version,
    uint64_t sequence_number, buffer_t* buffer );

//In-memory form of a record (fragment) header; in this file it is always read and written field by field
//through the buffer_nu8/nu16 helpers, never copied raw to or from the wire
typedef struct __tls_fragment_header
{
  tls_content_type_t type; //Content type of the record (change cipher spec, alert, handshake, application data)
  tls_protocol_version_t version; //Protocol version (major, minor) carried by the record
  uint16_t length; //Length of the fragment that follows the header (MAX 2^14 + 2048 = 18432)
} tls_fragment_header_t;

//Size in bytes of a record header on the wire: content type (1) + version major/minor (2) + length (2)
#define FRAGMENT_HEADER_SIZE 5

//Default socket wait timeouts installed by tls_record_init()
//NOTE(review): the unit is not visible in this file, presumably milliseconds - confirm against socket_wait_readable()
#define DEFAULT_READ_TIMEOUT 20000
#define DEFAULT_WRITE_TIMEOUT 20000

/** Initialize a record layer instance
 * Opens the underlying socket descriptor, selects the maximum fragment size that fits in the
 * supplied buffer, selects the initial protocol version and resets the security state and keys.
 * \param record record layer instance to initialize
 * \param sock TLS socket owning this record layer (stored in record->tls_socket)
 * \param buf memory used as the record buffer
 * \param buf_size size of buf in bytes; must be at least 512 + TLS_ENCRYPTION_MAX_OVERHEAD
 * \return MINITLS_OK on success, MINITLS_ERR_SOCKET_ERROR if no descriptor could be created,
 * MINITLS_ERR_BUFFER_TOO_SMALL if buf_size is too small (the descriptor is closed again in that case)
 */
minitls_err_t tls_record_init(tls_record_t* record, tls_socket_t* sock, uint8_t* buf, size_t buf_size)
{

  record->handshake_done = false;

  //Open BSD socket
  record->socket_fd = socket_socket();
  if(record->socket_fd < 0)
  {
    ERR("Could not create socket descriptor");
    return MINITLS_ERR_SOCKET_ERROR;
  }

  record->read_timeout = DEFAULT_READ_TIMEOUT;
  record->write_timeout = DEFAULT_WRITE_TIMEOUT;

  //Pick the biggest supported fragment size which, together with the encryption overhead, fits in the buffer
  if(buf_size >= TLS_DEFAULT_MAX_FRAGMENT_SIZE)
  {
    record->max_fragment_size = TLS_DEFAULT_MAX_FRAGMENT_SIZE;
  }
  else if( buf_size >= 4096 + TLS_ENCRYPTION_MAX_OVERHEAD )
  {
    record->max_fragment_size = 4096;
  }
  else if( buf_size >= 2048 + TLS_ENCRYPTION_MAX_OVERHEAD )
  {
    record->max_fragment_size = 2048;
  }
  else if( buf_size >= 1024 + TLS_ENCRYPTION_MAX_OVERHEAD )
  {
    record->max_fragment_size = 1024;
  }
  else if( buf_size >= 512 + TLS_ENCRYPTION_MAX_OVERHEAD )
  {
    record->max_fragment_size = 512;
  }
  else
  {
    ERR("Buffer is too small");
    //Release the descriptor opened above, otherwise it would leak on this error path
    socket_close(record->socket_fd);
    record->socket_fd = -1;
    return MINITLS_ERR_BUFFER_TOO_SMALL;
  }

  DBG("Max fragment size: %d bytes", record->max_fragment_size);

  //Any size that is not exactly one of the thresholds wastes memory
  if( (buf_size != TLS_DEFAULT_MAX_FRAGMENT_SIZE)
      && (buf_size != (record->max_fragment_size + TLS_ENCRYPTION_MAX_OVERHEAD)) )
  {
    WARN("Buffer size is not optimum");
  }

  //Initialize with oldest protocol version by default (as recommended by RFC 5246's Annex E)
#if MINITLS_CFG_PROTOCOL_SSL_3
  record->version.major = SSL_3_VERSION_MAJOR;
  record->version.minor = SSL_3_VERSION_MINOR;
#elif MINITLS_CFG_PROTOCOL_TLS_1_0
  record->version.major = TLS_1_0_VERSION_MAJOR;
  record->version.minor = TLS_1_0_VERSION_MINOR;
#elif MINITLS_CFG_PROTOCOL_TLS_1_1
  record->version.major = TLS_1_1_VERSION_MAJOR;
  record->version.minor = TLS_1_1_VERSION_MINOR;
#elif MINITLS_CFG_PROTOCOL_TLS_1_2
  record->version.major = TLS_1_2_VERSION_MAJOR;
  record->version.minor = TLS_1_2_VERSION_MINOR;
#else
#error No SSL/TLS protocol version enabled
#endif

  buffer_init(&record->buffer, buf, buf_size);

  record->tls_socket = sock;

  //Init security: records are exchanged in plain until keys are set and a change cipher spec occurs
  record->security_rx_state = TLS_SECURITY_NONE;
  record->security_tx_state = TLS_SECURITY_NONE;
  record->security_type = TLS_SECURITY_TYPE_NULL_NULL_NULL;

  //Memset keys
  memset(&record->client_write_mac_key, 0, TLS_HMAC_SHA1_KEY_SIZE);
  memset(&record->server_write_mac_key, 0, TLS_HMAC_SHA1_KEY_SIZE);
  memset(&record->client_write_cipher_key, 0, AES_128_KEY_SIZE);
  memset(&record->server_write_cipher_key, 0, AES_128_KEY_SIZE);

  return MINITLS_OK;
}

/** Set the protocol version used by the record layer
 * \param record record layer instance
 * \param major major version number to store
 * \param minor minor version number to store
 */
void tls_record_set_protocol_version(tls_record_t* record, uint8_t major, uint8_t minor)
{
  tls_protocol_version_t* version = &record->version;

  version->major = major;
  version->minor = minor;
}

/** Get the protocol version currently used by the record layer
 * \param record record layer instance
 * \param major receives the major version number
 * \param minor receives the minor version number
 */
void tls_record_get_protocol_version(tls_record_t* record, uint8_t* major, uint8_t* minor)
{
  const tls_protocol_version_t* version = &record->version;

  *major = version->major;
  *minor = version->minor;
}

/** Activate the negotiated security parameters for one direction
 * The direction must have been initialized (keys set) beforehand and not be active yet.
 * \param record record layer instance
 * \param tx_nrx true to switch the transmit direction, false to switch the receive direction
 * \return MINITLS_OK on success, MINITLS_ERR_PROTOCOL_NON_CONFORMANT if the selected direction
 * is not in the initialized state
 */
minitls_err_t tls_record_change_cipher_spec(tls_record_t* record, bool tx_nrx)
{
  if(tx_nrx)
  {
    //Transmit direction
    if(record->security_tx_state != TLS_SECURITY_INTIALIZED)
    {
      return MINITLS_ERR_PROTOCOL_NON_CONFORMANT;
    }
    record->security_tx_state = TLS_SECURITY_ACTIVE;
    return MINITLS_OK;
  }

  //Receive direction
  if(record->security_rx_state != TLS_SECURITY_INTIALIZED)
  {
    return MINITLS_ERR_PROTOCOL_NON_CONFORMANT;
  }
  record->security_rx_state = TLS_SECURITY_ACTIVE;
  return MINITLS_OK;
}

/** Tell whether security is active in both directions
 * \param record record layer instance
 * \return true only if both the transmit and the receive security states are active
 */
bool tls_record_is_secure(tls_record_t* record)
{
  return (record->security_tx_state == TLS_SECURITY_ACTIVE)
      && (record->security_rx_state == TLS_SECURITY_ACTIVE);
}

/** Connect the record layer's socket to a remote host
 * On failure the socket descriptor is closed and invalidated.
 * \param record record layer instance
 * \param hostname remote host to connect to
 * \param port remote port
 * \return MINITLS_OK on success, MINITLS_ERR_SOCKET_ERROR if the connection could not be established
 */
minitls_err_t tls_record_connect(tls_record_t* record, const char* hostname, uint16_t port)
{
  DBG("Trying to connect to %s:%d", hostname, port);

  int status = socket_connect(record->socket_fd, hostname, port);
  if(status >= 0)
  {
    return MINITLS_OK;
  }

  //Connection failed: the descriptor is of no use anymore
  socket_close(record->socket_fd);
  record->socket_fd = -1;
  return MINITLS_ERR_SOCKET_ERROR;
}

/** Receive and process one record
 * Reads a record header and its payload from the socket into record->buffer, checks the protocol
 * version and content type, deciphers the payload and verifies its MAC when receive security is
 * active, then dispatches the plaintext according to the content type (change cipher spec, alert,
 * handshake or application data).
 * A fatal alert is sent to the peer on most error paths before returning.
 * \param record record layer instance
 * \return MINITLS_OK on success (including when the peer closed the connection with an alert
 * reported as MINITLS_ERR_CONNECTION_CLOSED by tls_alert_process()), MINITLS_ERR_SOCKET_CLOSED if
 * the socket was closed while reading the header, or the error code of the step that failed
 */
minitls_err_t tls_record_process(tls_record_t* record)
{
  //Reset buffer length
  buffer_reset(&record->buffer);

  //Read header
  minitls_err_t ret = record_socket_read(record, FRAGMENT_HEADER_SIZE);
  if(ret == MINITLS_ERR_SOCKET_CLOSED)
  {
    return MINITLS_ERR_SOCKET_CLOSED;
  }
  else if(ret)
  {
    ERR("Socket err %d", ret);
    tls_alert_send( record, TLS_ALERT_FATAL, INTERNAL_ERROR, &record->buffer);
    return ret;
  }

  //Read version
  tls_fragment_header_t header;

  header.type = buffer_nu8_read(&record->buffer);
  header.version.major = buffer_nu8_read(&record->buffer);
  header.version.minor = buffer_nu8_read(&record->buffer);
  header.length = buffer_nu16_read(&record->buffer);

#if 1 //TODO how to relax this?
  if( (header.version.major != record->version.major) || (header.version.minor != record->version.minor) )
  {
    ERR("Version mismatch");
    tls_alert_send( record, TLS_ALERT_FATAL, PROTOCOL_VERSION, &record->buffer);
    return MINITLS_ERR_PROTOCOL_VERSION;
  }
#endif
  //Check content type
  //Check that encryption level is OK for this content type
  switch( header.type )
  {
  //All of these are OK in plain mode
  case TLS_CHANGE_CIPHER_SPEC:
  case TLS_ALERT:
  case TLS_HANDSHAKE:
    break;
  //This is only acceptable in ciphered mode:
  case TLS_APPLICATION_DATA:
    if( (!tls_record_is_secure(record)) || (!record->handshake_done) )
    {
      tls_alert_send( record, TLS_ALERT_FATAL, INSUFFICIENT_SECURITY, &record->buffer);
      return MINITLS_ERR_PROTOCOL_NON_CONFORMANT;
    }
    break;
  default:
    tls_alert_send( record, TLS_ALERT_FATAL, ILLEGAL_PARAMETER, &record->buffer);
    return MINITLS_ERR_PROTOCOL_NON_CONFORMANT;
  }

  //Reset buffer
  buffer_reset(&record->buffer);

  //Read payload (record_socket_read() refuses lengths that do not fit in the buffer)
  ret = record_socket_read(record, header.length);
  if(ret)
  {
    ERR("Socket err %d", ret);
    tls_alert_send( record, TLS_ALERT_FATAL, INTERNAL_ERROR, &record->buffer);
    return ret;
  }

  if( record->security_rx_state == TLS_SECURITY_ACTIVE )
  {

#if CRYPTO_AES_128
    if(record->security_type == TLS_SECURITY_TYPE_AES_128_CBC_SHA)
    {
      DBG("IV + Ciphertext");
      DBG_BLOCK(buffer_dump(&record->buffer);)

      buffer_t buffer_iv_header;
      //Need at least the IV plus one ciphered block, and a whole number of blocks
      if( (buffer_length(&record->buffer) < 2*AES_128_BLOCK_SIZE) || ( (buffer_length(&record->buffer) % AES_128_BLOCK_SIZE) != 0 ) )
      {
        tls_alert_send( record, TLS_ALERT_FATAL, UNEXPECTED_MESSAGE, &record->buffer );
        return MINITLS_ERR_PROTOCOL_NON_CONFORMANT;
      }
      buffer_byref(&buffer_iv_header, buffer_current_read_position(&record->buffer), AES_128_BLOCK_SIZE); //Extract IV vector
      buffer_n_discard(&record->buffer, AES_128_BLOCK_SIZE);

      //Decrypt message
      ret = crypto_aes_128_cbc_decrypt( &record->cipher_rx.aes_128, &buffer_iv_header, &record->buffer );
      if(ret)
      {
        ERR("Failed to decipher, ret %d", ret);
        tls_alert_send( record, TLS_ALERT_FATAL, DECRYPT_ERROR, &record->buffer );
        return ret;
      }

      DBG("Plaintext + MAC + padding + padding length");
      DBG_BLOCK(buffer_dump(&record->buffer);)

      //Check and remove padding: the last byte is the padding length
      size_t padding_length = *(buffer_current_write_position(&record->buffer) - 1);

      if( padding_length + 1 > buffer_length(&record->buffer) )
      {
        ERR("Wrong padding length");
        tls_alert_send( record, TLS_ALERT_FATAL, BAD_RECORD_MAC, &record->buffer );
        return MINITLS_ERR_CRYPTO;
      }

      size_t p;
      //Check each padding byte: the padding length byte plus the padding_length bytes preceding it
      //must all hold padding_length (hence the inclusive bound, which stays within the buffer thanks
      //to the length check above)
      for(p = 0; p <= padding_length; p++)
      {
        if( *(buffer_current_write_position(&record->buffer) - 1 - p) != padding_length )
        {
          ERR("Wrong padding");
          tls_alert_send( record, TLS_ALERT_FATAL, BAD_RECORD_MAC, &record->buffer );
          return MINITLS_ERR_CRYPTO;
        }
      }

      //Remove trailing padding + padding length
      buffer_set_length(&record->buffer, buffer_length(&record->buffer) - 1 - padding_length);
    }
    else
#endif
#if CRYPTO_ARC4
    if(record->security_type == TLS_SECURITY_TYPE_ARC4_SHA)
    {
      DBG("Ciphertext");
      DBG_BLOCK(buffer_dump(&record->buffer);)

      //Decrypt message
      crypto_arc4_process( &record->cipher_rx.arc4, &record->buffer );
    }
    else
#endif
    {}

    DBG("Plaintext + MAC");
    DBG_BLOCK(buffer_dump(&record->buffer);)

    //Check MAC (this also strips it from the buffer on success)
    ret = tls_mac_check( record->server_write_mac_key, header.type, header.version, record->sequence_number_rx, &record->buffer );
    if(ret)
    {
      ERR("MAC Check failed, ret %d", ret);
      tls_alert_send( record, TLS_ALERT_FATAL, BAD_RECORD_MAC, &record->buffer );
      return ret;
    }

    DBG("Plaintext");
    DBG_BLOCK(buffer_dump(&record->buffer);)

    //Increment seq number
    record->sequence_number_rx++;
  }
  else
  {
    //No security
  }

  //Now dispatch depending on content type
  switch( header.type )
  {
  case TLS_CHANGE_CIPHER_SPEC:
    ret = tls_record_change_cipher_spec(record, false);
    if(ret)
    {
      ERR("Invalid change cipher spec request, ret %d", ret);
      tls_alert_send( record, TLS_ALERT_FATAL, UNEXPECTED_MESSAGE, &record->buffer );
      return ret;
    }
    break;
  case TLS_ALERT:
    ret = tls_alert_process( record, &record->buffer );
    if(ret)
    {
      tls_record_close(record);
      //Close connection in any case
      if(ret == MINITLS_ERR_CONNECTION_CLOSED)
      {
        DBG("Connection closed by remote party");
        return MINITLS_OK;
      }
      //FIXME Do something
      ERR("Alert received, ret %d", ret);
      return ret;
    }
    break;
  case TLS_HANDSHAKE:
    if(/*(record->tls_socket->handshake != NULL) &&*/ !tls_handshake_is_done(&record->tls_socket->handshake))
    {
      ret = tls_handshake_process(&record->tls_socket->handshake, &record->buffer );
      if(ret)
      {
        ERR("Handshake process returned %d", ret);
        //TLS alert already sent by handshake function
        tls_handshake_clean(&record->tls_socket->handshake); //Cleanup handshake
        //record->tls_socket->handshake = NULL;
        return ret;
      }
      if(tls_handshake_is_done(&record->tls_socket->handshake))
      {
        tls_handshake_clean(&record->tls_socket->handshake); //Cleanup handshake
        //record->tls_socket->handshake = NULL;
        record->handshake_done = true; //Enable application data layer
      }
      return MINITLS_OK;
    }
    else
    {
      //ret still holds MINITLS_OK from the previous steps: set a real error code so that the
      //caller does not see a success after a fatal alert has been sent
      ret = MINITLS_ERR_PROTOCOL_NON_CONFORMANT;
      ERR("Unexpected handshake message, ret %d", ret);
      tls_alert_send( record, TLS_ALERT_FATAL, UNEXPECTED_MESSAGE, &record->buffer );
      return ret;
    }
  case TLS_APPLICATION_DATA:
    //Pass message to socket layer
    return tls_socket_readable_callback(record->tls_socket,  &record->buffer);
  default:
    //Has already been checked above
    return MINITLS_ERR_PROTOCOL_NON_CONFORMANT;
  }

  return MINITLS_OK;
}

/** Send one record
 * When transmit security is active, a MAC is appended to the payload and the result is ciphered
 * in place (AES-128-CBC with a random IV sent in front of the ciphertext, or ARC4 without IV).
 * The record header, the IV (if any) and the payload are then written to the socket.
 * \param record record layer instance
 * \param content_type content type of the record
 * \param payload data to send; it is modified in place, so it must have enough free space for the
 * MAC and padding, and it is consumed while being written to the socket
 * \return MINITLS_OK on success, MINITLS_ERR_PROTOCOL_NON_CONFORMANT if the content type is unknown
 * or is application data while the connection is not secure, MINITLS_ERR_BUFFER_TOO_SMALL if the
 * payload buffer cannot hold the MAC or padding, or the error code of the cipher / socket step that failed
 */
minitls_err_t tls_record_send(tls_record_t* record, tls_content_type_t content_type, buffer_t* payload)
{
  minitls_err_t ret;
  size_t padding_item;
  /*
  struct {
      opaque IV[SecurityParameters.record_iv_length];
      block-ciphered struct {
          opaque content[TLSCompressed.length];
          opaque MAC[SecurityParameters.mac_length];
          uint8 padding[GenericBlockCipher.padding_length];
          uint8 padding_length;
      };
  } GenericBlockCipher;
  */

  //Check content type
  //Check that encryption level is OK for this content type
  switch( content_type )
  {
  //All of these are OK in plain mode
  case TLS_CHANGE_CIPHER_SPEC:
  case TLS_ALERT:
  case TLS_HANDSHAKE:
    break;
  //This is only acceptable in ciphered mode:
  case TLS_APPLICATION_DATA:
    if( (!tls_record_is_secure(record)) || (!record->handshake_done) )
    {
      tls_alert_send( record, TLS_ALERT_FATAL, INSUFFICIENT_SECURITY, &record->buffer);
      return MINITLS_ERR_PROTOCOL_NON_CONFORMANT;
    }
    break;
  default:
    tls_alert_send( record, TLS_ALERT_FATAL, ILLEGAL_PARAMETER, &record->buffer);
    return MINITLS_ERR_PROTOCOL_NON_CONFORMANT;
  }

  //Buffer must have enough space to add IV (head) + MAC (tail) + padding (tail)

  buffer_t header_iv;
  uint8_t header_iv_data[AES_128_BLOCK_SIZE];

  //No IV by default (0 length); only the AES-CBC branch below replaces it.
  //This guarantees header_iv is always initialized before it is read further down, whatever
  //the security state and the set of ciphers compiled in
  buffer_init( &header_iv, NULL, 0 );

  if( record->security_tx_state == TLS_SECURITY_ACTIVE )
  {
    DBG("Plaintext");
    DBG_BLOCK(buffer_dump(payload);)

    //Compute & append MAC
    //The sequence number is 64 bits wide: narrow it explicitly to match the %d conversion
    DBG("Sequence number: %d", (int) record->sequence_number_tx);
    ret = tls_mac_append( record->client_write_mac_key, content_type, record->version, record->sequence_number_tx, payload );
    if(ret)
    {
      ERR("Could not append MAC, ret %d", ret);
      return ret;
    }

    //Increment sequence number
    record->sequence_number_tx++;

    DBG("Plaintext + MAC");
    DBG_BLOCK(buffer_dump(payload);)

#if CRYPTO_AES_128
    if(record->security_type == TLS_SECURITY_TYPE_AES_128_CBC_SHA)
    {

      //Add padding so that content + MAC + padding + padding length byte is a whole number of blocks
      size_t padding_length = AES_128_BLOCK_SIZE - (buffer_length(payload) % AES_128_BLOCK_SIZE) - 1;

      //padding_length padding bytes plus the padding length byte itself are appended
      if(buffer_space(payload) < padding_length + 1)
      {
        return MINITLS_ERR_BUFFER_TOO_SMALL;
      }

      //Every padding byte, as well as the trailing padding length byte, holds padding_length
      for(padding_item = 0; padding_item <= padding_length; padding_item++)
      {
        buffer_nu8_write(payload, padding_length);
      }

      DBG("Plaintext + MAC + Padding + Padding Length");
      DBG_BLOCK(buffer_dump(payload);)

      //Generate a random IV for this record
      buffer_init( &header_iv, header_iv_data, AES_128_BLOCK_SIZE );

      crypto_prng_get(record->tls_socket->minitls->prng, buffer_current_write_position(&header_iv), AES_128_BLOCK_SIZE);
      buffer_n_skip(&header_iv, AES_128_BLOCK_SIZE);

      //Encrypt message
      ret = crypto_aes_128_cbc_encrypt( &record->cipher_tx.aes_128, &header_iv, payload );
      if(ret)
      {
        ERR("Failed to encipher, ret %d", ret);
        return ret;
      }
    }
    else
#endif
#if CRYPTO_ARC4
    if(record->security_type == TLS_SECURITY_TYPE_ARC4_SHA)
    {
      //No IV (header_iv keeps its 0 length)

      //Encrypt message
      crypto_arc4_process( &record->cipher_tx.arc4, payload );
    }
    else
#endif
    {}

    DBG("Ciphertext");
    DBG_BLOCK(buffer_dump(payload);)
  }

  //Now send message header
  tls_fragment_header_t header;
  header.type = content_type;
  header.version.major = record->version.major;
  header.version.minor = record->version.minor;
  header.length = buffer_length( &header_iv ) + buffer_length(payload);

  buffer_t header_fragment;
  uint8_t header_fragment_data[FRAGMENT_HEADER_SIZE];

  buffer_init( &header_fragment, header_fragment_data, FRAGMENT_HEADER_SIZE );

  buffer_nu8_write(&header_fragment, header.type);
  buffer_nu8_write(&header_fragment, header.version.major);
  buffer_nu8_write(&header_fragment, header.version.minor);
  buffer_nu16_write(&header_fragment, header.length);

  //Send fragment header
  ret = record_socket_write(record, &header_fragment);
  if(ret)
  {
    return ret;
  }

  //Send IV (nothing is written when it has 0 length)
  ret = record_socket_write(record, &header_iv);
  if(ret)
  {
    return ret;
  }

  //Send payload
  ret = record_socket_write(record, payload);
  if(ret)
  {
    return ret;
  }

  return MINITLS_OK;
}

/** Install the session keys and initialize the ciphers
 * Copies the MAC and cipher keys into the record, resets both sequence numbers, initializes the
 * transmit cipher with the client write key and the receive cipher with the server write key, and
 * moves both security states to initialized (they become active on change cipher spec).
 * \param record record layer instance
 * \param security cipher suite; only AES-128-CBC-SHA and ARC4-SHA are accepted
 * \param client_write_mac_key MAC key for sent records (TLS_HMAC_SHA1_KEY_SIZE bytes are copied)
 * \param server_write_mac_key MAC key for received records (TLS_HMAC_SHA1_KEY_SIZE bytes are copied)
 * \param client_write_cipher_key cipher key for sent records (AES_128_KEY_SIZE bytes are copied)
 * \param server_write_cipher_key cipher key for received records (AES_128_KEY_SIZE bytes are copied)
 * \return MINITLS_OK on success, MINITLS_ERR_NOT_IMPLEMENTED for any other cipher suite
 */
minitls_err_t tls_record_set_keys(tls_record_t* record, tls_security_type_t security, const uint8_t* client_write_mac_key, const uint8_t* server_write_mac_key,
    const uint8_t* client_write_cipher_key, const uint8_t* server_write_cipher_key)
{
  bool supported = (security == TLS_SECURITY_TYPE_AES_128_CBC_SHA) || (security == TLS_SECURITY_TYPE_ARC4_SHA);
  if( !supported )
  {
    return MINITLS_ERR_NOT_IMPLEMENTED;
  }

  //Copy MAC keys
  memcpy(&record->client_write_mac_key, client_write_mac_key, TLS_HMAC_SHA1_KEY_SIZE);
  memcpy(&record->server_write_mac_key, server_write_mac_key, TLS_HMAC_SHA1_KEY_SIZE);

  //Copy cipher keys
  memcpy(&record->client_write_cipher_key, client_write_cipher_key, AES_128_KEY_SIZE); //TODO generic key size
  memcpy(&record->server_write_cipher_key, server_write_cipher_key, AES_128_KEY_SIZE);

  //Sequence numbers restart from 0 with the new keys
  record->sequence_number_rx = 0;
  record->sequence_number_tx = 0;

  //Initialize ciphers: client key for transmission, server key for reception
  switch(security)
  {
#if CRYPTO_AES_128
  case TLS_SECURITY_TYPE_AES_128_CBC_SHA:
    crypto_aes_128_init(&record->cipher_tx.aes_128, record->client_write_cipher_key, expand_encryption_key);
    crypto_aes_128_init(&record->cipher_rx.aes_128, record->server_write_cipher_key, expand_decryption_key);
    break;
#endif
#if CRYPTO_ARC4
  case TLS_SECURITY_TYPE_ARC4_SHA:
    crypto_arc4_init(&record->cipher_tx.arc4, record->client_write_cipher_key, TLS_ARC4_KEY_SIZE );
    crypto_arc4_init(&record->cipher_rx.arc4, record->server_write_cipher_key, TLS_ARC4_KEY_SIZE);
    break;
#endif
  default:
    break;
  }

  //Keys are ready but not in use until the change cipher spec
  record->security_type = security;
  record->security_rx_state = TLS_SECURITY_INTIALIZED;
  record->security_tx_state = TLS_SECURITY_INTIALIZED;

  return MINITLS_OK;
}

/** Close the connection
 * Sends a close notify alert (best effort), then closes and invalidates the socket descriptor.
 * Does nothing if the socket is already closed.
 * \param record record layer instance
 * \return MINITLS_OK in all cases
 */
minitls_err_t tls_record_close(tls_record_t* record)
{
  if(record->socket_fd >= 0) //Still open
  {
    //Tell the peer we are leaving; the result is deliberately ignored
    tls_alert_send(record, TLS_ALERT_WARNING, CLOSE_NOTIFY, &record->buffer);

    //Close socket
    socket_close(record->socket_fd);
    record->socket_fd = -1;
  }

  return MINITLS_OK;
}

/** Set the timeout used when waiting for the socket to become readable
 * \param record record layer instance
 * \param timeout new timeout, passed as is to socket_wait_readable()
 * \return MINITLS_OK in all cases
 */
minitls_err_t tls_record_set_read_timeout(tls_record_t* record, int timeout)
{
  //Stored only; it is picked up on the next wait
  record->read_timeout = timeout;
  return MINITLS_OK;
}

/** Set the timeout used when waiting for the socket to become writeable
 * \param record record layer instance
 * \param timeout new timeout, passed as is to socket_wait_writeable()
 * \return MINITLS_OK in all cases
 */
minitls_err_t tls_record_set_write_timeout(tls_record_t* record, int timeout)
{
  //Stored only; it is picked up on the next wait
  record->write_timeout = timeout;
  return MINITLS_OK;
}

/** Wait until the socket is readable
 * \param record record layer instance
 * \return MINITLS_OK if the wait succeeded, MINITLS_ERR_SOCKET_CLOSED if there is no open socket,
 * MINITLS_ERR_TIMEOUT if socket_wait_readable() reported a negative result
 */
minitls_err_t record_wait_readable(tls_record_t* record)
{
  if(record->socket_fd < 0)
  {
    return MINITLS_ERR_SOCKET_CLOSED;
  }

  //Block for at most the configured read timeout; any negative result is reported as a timeout
  int status = socket_wait_readable(record->socket_fd, record->read_timeout );
  return (status < 0) ? MINITLS_ERR_TIMEOUT : MINITLS_OK;
}

/** Wait until the socket is writeable
 * \param record record layer instance
 * \return MINITLS_OK if the wait succeeded, MINITLS_ERR_SOCKET_CLOSED if there is no open socket,
 * MINITLS_ERR_TIMEOUT if socket_wait_writeable() reported a negative result
 */
minitls_err_t record_wait_writeable(tls_record_t* record)
{
  if(record->socket_fd < 0)
  {
    return MINITLS_ERR_SOCKET_CLOSED;
  }

  //Block for at most the configured write timeout; any negative result is reported as a timeout
  int status = socket_wait_writeable(record->socket_fd, record->write_timeout );
  return (status < 0) ? MINITLS_ERR_TIMEOUT : MINITLS_OK;
}

/** Read exactly size bytes from the socket into the record buffer
 * Data is appended at the buffer's current write position; the function loops until all the
 * requested bytes have been received.
 * \param record record layer instance
 * \param size number of bytes to read
 * \return MINITLS_OK once size bytes have been read, MINITLS_ERR_SOCKET_CLOSED if there is no open
 * socket or the peer closed it, MINITLS_ERR_BUFFER_TOO_SMALL if the remaining bytes do not fit in
 * the buffer, MINITLS_ERR_SOCKET_ERROR on a receive error, or the error returned by the wait
 */
minitls_err_t record_socket_read(tls_record_t* record, size_t size)
{
  minitls_err_t ret;
  if(record->socket_fd < 0)
  {
    return MINITLS_ERR_SOCKET_CLOSED;
  }

  DBG("Trying to read %d bytes", size);
  for( ; size > 0; )
  {
    //What is left to read must fit in the buffer
    if( size > buffer_space(&record->buffer) )
    {
      ERR("Won't be able to read packet (%d bytes to read - %d bytes of space)", size, buffer_space(&record->buffer));
      return MINITLS_ERR_BUFFER_TOO_SMALL;
    }

    ret = record_wait_readable(record);
    if(ret)
    {
     ERR("Timeout");
     return ret;
    }

    int received = socket_recv(record->socket_fd, buffer_current_write_position(&record->buffer), size);
    if( received < 0 )
    {
      ERR("Error (returned %d)", received);
      return MINITLS_ERR_SOCKET_ERROR;
    }
    if( received == 0 )
    {
      WARN("Socket closed");
      return MINITLS_ERR_SOCKET_CLOSED;
    }

    //Account for the bytes just received
    buffer_n_skip(&record->buffer, received);
    size -= received;
  }

  DBG_BLOCK(buffer_dump(&record->buffer);)

  return MINITLS_OK;
}

/** Write the whole content of a buffer to the socket
 * Sent bytes are discarded from data as they go, so data is empty on success.
 * \param record record layer instance
 * \param data buffer to send
 * \return MINITLS_OK once everything has been sent (immediately if data is empty),
 * MINITLS_ERR_SOCKET_CLOSED if there is no open socket or the send returned 0,
 * MINITLS_ERR_SOCKET_ERROR on a send error, or the error returned by the wait
 */
minitls_err_t record_socket_write(tls_record_t* record, buffer_t* data)
{
  minitls_err_t ret;
  if(record->socket_fd < 0)
  {
    return MINITLS_ERR_SOCKET_CLOSED;
  }

  DBG("Trying to write %d bytes", buffer_length(data));
  DBG_BLOCK(buffer_dump(data);)
  for( ; buffer_length(data) > 0; )
  {
    ret = record_wait_writeable(record);
    if(ret)
    {
     ERR("Timeout");
     return ret;
    }

    int sent = socket_send(record->socket_fd, buffer_current_read_position(data), buffer_length(data));
    if( sent < 0 )
    {
      ERR("Error (returned %d)", sent);
      return MINITLS_ERR_SOCKET_ERROR;
    }
    if( sent == 0 )
    {
      WARN("Socket closed");
      return MINITLS_ERR_SOCKET_CLOSED;
    }

    //Drop what has been sent, the remainder goes out on the next iteration
    buffer_n_discard(data, sent);
  }
  DBG("Done");
  return MINITLS_OK;
}


/** Compute the HMAC-SHA1 of a record and append it to the buffer
 * The MAC covers a 13 bytes pseudo header (sequence number, content type, protocol version,
 * content length) followed by the buffer's content.
 * \param key MAC key (TLS_HMAC_SHA1_KEY_SIZE bytes)
 * \param content_type content type of the record
 * \param version protocol version of the record
 * \param sequence_number sequence number of the record
 * \param buffer record content; the MAC (HMAC_SHA1_SIZE bytes) is written after it
 * \return MINITLS_OK on success, MINITLS_ERR_BUFFER_TOO_SMALL if the buffer has no room for the MAC
 */
minitls_err_t tls_mac_append( const uint8_t* key, tls_content_type_t content_type, tls_protocol_version_t version,
    uint64_t sequence_number, buffer_t* buffer )
{
  crypto_hmac_sha1_t hmac;
  crypto_hmac_sha1_init(&hmac, key, TLS_HMAC_SHA1_KEY_SIZE);

  if( buffer_space(buffer) < HMAC_SHA1_SIZE )
  {
    return MINITLS_ERR_BUFFER_TOO_SMALL;
  }

  size_t content_length = buffer_length(buffer);

  //Pseudo header: sequence number (8) + content type (1) + version (2) + content length (2)
  uint8_t pseudo_header_data[13];
  buffer_t pseudo_header;
  buffer_init(&pseudo_header, pseudo_header_data, sizeof(pseudo_header_data));

  buffer_nu64_write(&pseudo_header, sequence_number);
  buffer_nu8_write(&pseudo_header, content_type);
  buffer_nu8_write(&pseudo_header, version.major);
  buffer_nu8_write(&pseudo_header, version.minor);
  buffer_nu16_write(&pseudo_header, content_length);

  //MAC over pseudo header then content
  crypto_hmac_sha1_update(&hmac, pseudo_header_data, sizeof(pseudo_header_data));
  crypto_hmac_sha1_update(&hmac, buffer_current_read_position(buffer), content_length);

  //Write the result right after the content and make it part of the buffer
  crypto_hmac_sha1_end(&hmac, buffer_current_write_position(buffer));
  buffer_n_skip(buffer, HMAC_SHA1_SIZE);

  return MINITLS_OK;
}

/** Verify and strip the HMAC-SHA1 found at the end of a record
 * The MAC is recomputed over a 13 bytes pseudo header (sequence number, content type, protocol
 * version, data length) followed by the data, and compared with the trailing HMAC_SHA1_SIZE bytes
 * of the buffer. The comparison takes the same time whether or not the MACs match, so that the
 * position of the first differing byte is not leaked through timing.
 * \param key MAC key (TLS_HMAC_SHA1_KEY_SIZE bytes)
 * \param content_type content type of the record
 * \param version protocol version of the record
 * \param sequence_number sequence number of the record
 * \param buffer data followed by its MAC; on success the read position is restored and the length
 * is reduced so that only the data remains (on failure the read position is left on the MAC)
 * \return MINITLS_OK if the MAC matches, MINITLS_ERR_PROTOCOL_NON_CONFORMANT if the buffer is
 * shorter than a MAC, MINITLS_ERR_WRONG_MAC if the MAC differs
 */
minitls_err_t tls_mac_check( const uint8_t* key, tls_content_type_t content_type, tls_protocol_version_t version,
    uint64_t sequence_number, buffer_t* buffer )
{
  crypto_hmac_sha1_t mac;
  crypto_hmac_sha1_init(&mac, key, TLS_HMAC_SHA1_KEY_SIZE);

  if( buffer_length(buffer) < HMAC_SHA1_SIZE )
  {
    return MINITLS_ERR_PROTOCOL_NON_CONFORMANT;
  }

  //Remember where the data starts, the read position is moved while checking
  size_t data_offset = buffer_get_read_offset(buffer);
  size_t data_length = buffer_length(buffer) - HMAC_SHA1_SIZE;

  uint8_t check[HMAC_SHA1_SIZE];

  //Pseudo header: sequence number (8) + content type (1) + version (2) + data length (2)
  uint8_t header_buf[13];
  buffer_t header;

  buffer_init(&header, header_buf, 13);

  buffer_nu64_write(&header, sequence_number);

  buffer_nu8_write(&header, content_type);

  buffer_nu8_write(&header, version.major);
  buffer_nu8_write(&header, version.minor);

  buffer_nu16_write(&header, data_length);

  crypto_hmac_sha1_update(&mac, header_buf, 13);
  crypto_hmac_sha1_update(&mac, buffer_current_read_position(buffer), data_length);
  buffer_n_discard(buffer, data_length); //Read position now points to the received MAC
  crypto_hmac_sha1_end(&mac, check);

  //Constant-time comparison: accumulate the differences of all bytes instead of stopping at the
  //first mismatch as memcmp() may do
  const uint8_t* received = (const uint8_t*) buffer_current_read_position(buffer);
  uint8_t diff = 0;
  size_t i;
  for(i = 0; i < HMAC_SHA1_SIZE; i++)
  {
    diff |= (uint8_t)(received[i] ^ check[i]);
  }

  if( diff != 0 )
  {
    ERR("MAC differs; computed MAC was:");

    buffer_t computed_mac;
    buffer_byref(&computed_mac, check, HMAC_SHA1_SIZE);
    DBG_BLOCK(buffer_dump(&computed_mac);)

    return MINITLS_ERR_WRONG_MAC;
  }

  //Reset buffer position and discard MAC
  buffer_set_read_offset(buffer, data_offset);
  buffer_set_length(buffer, data_length);

  return MINITLS_OK;
}