/*
MiniTLS - A super trimmed down TLS/SSL Library for embedded devices
Author: Donatien Garnier
Copyright (C) 2013-2014 AppNearMe Ltd

This program is free software; you can redistribute it and/or
modify it under the terms of the GNU General Public License
as published by the Free Software Foundation; either version 2
of the License, or (at your option) any later version.

This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public License
along with this program; if not, write to the Free Software
Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.
*//**
 * \file crypto_rsa.c
 * \copyright Copyright (c) AppNearMe Ltd 2014
 * \author Donatien Garnier
 */

#define __DEBUG__ 0
#ifndef __MODULE__
#define __MODULE__ "crypto_rsa.c"
#endif

#include "core/fwk.h"
#include "crypto_rsa.h"
#include "inc/minitls_errors.h"
#include "inc/minitls_config.h"

#include "crypto_math.h"
#include "ltc/ltc.h"

/* File-local helpers; their definitions appear further down in this file. */

/* Build an EME-PKCS1-v1_5 (block type 2) encoded message of modulus length into out. */
static minitls_err_t crypto_pkcs_1_v1_5_encode(const uint8_t* msg, size_t msglen,
    size_t modulus_bitlen, crypto_prng_t* prng,
    uint8_t* out, size_t* outlen);

/* Raw RSA public operation: out = in^e mod N. */
static minitls_err_t crypto_rsa_exptmod(const uint8_t* in, size_t inlen,
    uint8_t *out, size_t* outlen, crypto_rsa_public_key_t* key);

/* Parse the modulus N and public exponent e out of a DER-encoded key. */
static minitls_err_t crypto_ecc_dsa_check_get_asn1_Ne(void* N, void* e,
    const uint8_t* key, size_t key_size);

/**
 * Load an RSA public key (modulus N and exponent e) from its DER encoding.
 * \param key    key structure whose N and e big integers are initialised and filled
 * \param pkcs1  DER-encoded key bytes
 * \param size   number of bytes in pkcs1
 * \return MINITLS_OK, or the error reported by the big-integer init or by the parser;
 *         on a parse error both big integers are cleared again before returning
 */
minitls_err_t crypto_rsa_pkcs1_import(crypto_rsa_public_key_t* key, const uint8_t* pkcs1, size_t size)
{
  int rc;

  //Both integers must exist before the parser writes into them
  rc = mp_init_multi(&key->e, &key->N, NULL);
  if( rc != MINITLS_OK )
  {
    return rc;
  }

  //Decode; on rejection release what was initialised above
  rc = crypto_ecc_dsa_check_get_asn1_Ne(&key->N, &key->e, pkcs1, size);
  if( rc != MINITLS_OK )
  {
    mp_clear_multi(&key->e, &key->N, NULL);
    return rc;
  }

  return MINITLS_OK;
}

/**
 * Encrypt plaintext with an RSA public key using PKCS #1 v1.5 block type 2 padding.
 * \param public_key       public key (see crypto_rsa_pkcs1_import); only read
 * \param plaintext        data to encrypt, plaintext_size bytes; only read
 * \param plaintext_size   length of plaintext
 * \param secret           output buffer receiving the ciphertext
 * \param max_secret_size  capacity of secret in bytes
 * \param secret_size      on success: ciphertext length (the modulus size in bytes);
 *                         on MINITLS_ERR_BUFFER_TOO_SMALL: the required size
 * \param prng             random source for the padding string
 * \return MINITLS_OK or an error code from padding / exponentiation
 */
minitls_err_t crypto_rsa_encrypt(const crypto_rsa_public_key_t* public_key,
    uint8_t* plaintext, size_t plaintext_size,
    uint8_t* secret, size_t max_secret_size, size_t* secret_size, crypto_prng_t* prng)
{
  minitls_err_t ret;

  /* get modulus len in bits */
  size_t modulus_bitlen = mp_count_bits( (&public_key->N));

  /* outlen must be at least the size of the modulus */
  size_t modulus_bytelen = mp_unsigned_bin_size( (&public_key->N));
  if (modulus_bytelen > max_secret_size) {
    //%d expects int: passing size_t directly is undefined where the widths differ
    WARN("modulus_bytelen = %d but max_secret_size = %d", (int)modulus_bytelen, (int)max_secret_size);
    *secret_size = modulus_bytelen;
    return MINITLS_ERR_BUFFER_TOO_SMALL;
  }

  //Apply padding: secret <- 0x00 0x02 PS 0x00 plaintext
  *secret_size = max_secret_size;
  ret = crypto_pkcs_1_v1_5_encode(plaintext, plaintext_size, modulus_bitlen, prng, secret, secret_size);
  if(ret)
  {
    return ret;
  }

  //Exponentiate in place: exptmod converts its whole input before writing the output.
  //exptmod only reads the key, so the explicit const cast is safe.
  ret = crypto_rsa_exptmod(secret, *secret_size, secret, secret_size, (crypto_rsa_public_key_t*)public_key);
  if(ret)
  {
    return ret;
  }

  return MINITLS_OK;
}

/**
 * Raw RSA public-key operation (RSAEP): out = in^e mod N.
 * \param in      big-endian input integer of inlen bytes
 * \param inlen   length of in
 * \param out     output buffer; may be the same buffer as in, because in is fully
 *                converted into a big integer before out is written
 * \param outlen  in: capacity of out; out: modulus size in bytes (also set to the
 *                required size when MINITLS_ERR_BUFFER_TOO_SMALL is returned)
 * \param key     public key; N and e are only read
 * \return MINITLS_OK, MINITLS_ERR_WRONG_LENGTH if in >= N, MINITLS_ERR_BUFFER_TOO_SMALL,
 *         MINITLS_ERR_CRYPTO, or an error from the big-integer layer
 */
static minitls_err_t crypto_rsa_exptmod(const uint8_t* in, size_t inlen,
                      uint8_t *out, size_t* outlen,
                      crypto_rsa_public_key_t* key)
{
   fp_int        tmp;
   unsigned long x;
   int           err;

   /* init and copy into tmp */
   if ((err = mp_init_multi(&tmp, NULL)) != MINITLS_OK) {
      return err;
   }
   mp_read_unsigned_bin(&tmp, (unsigned char *)in, (int)inlen);

   /* RSAEP is only defined for 0 <= m < N: reject m == N as well as m > N */
   if (mp_cmp(&tmp, &key->N) != MP_LT) {
      err = MINITLS_ERR_WRONG_LENGTH;
      goto error;
   }

   /* exptmod it (result written back into tmp) */
   if ((err = mp_exptmod(&tmp, &key->e, &key->N, &tmp)) != MINITLS_OK) {
      goto error;
   }

   /* the output is always exactly as long as the modulus */
   x = (unsigned long)mp_unsigned_bin_size(&key->N);
   if (x > *outlen) {
      *outlen = x;
      err = MINITLS_ERR_BUFFER_TOO_SMALL;
      goto error;
   }

   /* a reduced result cannot be longer than N; this also keeps the offset below non-negative */
   if (mp_unsigned_bin_size(&tmp) > mp_unsigned_bin_size(&key->N)) {
      err = MINITLS_ERR_CRYPTO;
      goto error;
   }
   *outlen = x;

   /* left-pad with zeros so the big-endian result fills exactly x bytes */
   zeromem(out, x);
   mp_to_unsigned_bin(&tmp, out + (x - mp_unsigned_bin_size(&tmp)));

   /* clean up and return */
   err = MINITLS_OK;
error:
   mp_clear_multi(&tmp, NULL);
   return err;
}

/**
 * EME-PKCS1-v1_5 encoding (block type 2): out = 0x00 || 0x02 || PS || 0x00 || msg,
 * where PS is at least 8 non-zero random bytes and the total is the modulus length.
 * \param msg             message to embed, msglen bytes
 * \param msglen          length of msg; must be at most modulus length - 11
 * \param modulus_bitlen  size of the RSA modulus in bits
 * \param prng            random source for PS
 * \param out             output buffer
 * \param outlen          in: capacity of out; out: modulus length in bytes (also set
 *                        to the required size on MINITLS_ERR_BUFFER_TOO_SMALL)
 * \return MINITLS_OK, MINITLS_ERR_WRONG_LENGTH if msg does not fit,
 *         or MINITLS_ERR_BUFFER_TOO_SMALL
 */
static minitls_err_t crypto_pkcs_1_v1_5_encode(const uint8_t* msg,
                             size_t  msglen,
                             size_t  modulus_bitlen,
                             crypto_prng_t* prng,
                             uint8_t* out,
                             size_t* outlen)
{
  /* modulus length in bytes, rounded up */
  size_t modulus_len = (modulus_bitlen >> 3) + ((modulus_bitlen & 7) ? 1 : 0);
  size_t ps_len;
  uint8_t* ps;

  /* 11 bytes of overhead: 0x00 0x02, at least 8 bytes of PS, 0x00.
   * Written as a subtraction so a huge msglen cannot wrap around (msglen + 11). */
  if ((modulus_len < 11) || (msglen > modulus_len - 11)) {
    return MINITLS_ERR_WRONG_LENGTH;
  }

  if (*outlen < modulus_len) {
    *outlen = modulus_len;
    return MINITLS_ERR_BUFFER_TOO_SMALL;
  }

  /* PS fills everything between the 2-byte header and the 0x00 separator */
  ps = &out[2];
  ps_len = modulus_len - msglen - 3;

  /* NOTE(review): crypto_prng_get's result is not checked here; confirm it cannot fail,
   * otherwise the zero-byte loop below may never terminate. */
  crypto_prng_get(prng, ps, ps_len);

  /* every PS byte must be non-zero: the first zero byte marks the end of the padding */
  for (size_t i = 0; i < ps_len; i++) {
    while (ps[i] == 0) {
      crypto_prng_get(prng, &ps[i], 1);
    }
  }

  /* create string of length modulus_len */
  out[0]          = 0x00;
  out[1]          = 0x02;  /* block type 2 (PKCS #1 v1.5 encryption padding) */
  out[2 + ps_len] = 0x00;
  memcpy(&out[3 + ps_len], msg, msglen);
  *outlen = modulus_len;

  return MINITLS_OK;
}


//Decode (&N,&e) integers from a DER-encoded RSA public key
#define ENSURE_SIZE(actual_size, min_size) do{ if( (actual_size) < (min_size) ) { return MINITLS_ERR_PARAMETERS; } }while(0)

/* Decode a DER definite-length field at *pp: short form, or long form with 1 to 4
 * length octets. On success, stores the length in *len, advances *pp and decrements
 * *psz past the length octets, and returns 1. Returns 0 without touching the cursors
 * when the input is truncated or the form is unsupported (0x80 indefinite length,
 * or more than 4 length octets). */
static int crypto_rsa_asn1_read_length(const uint8_t** pp, size_t* psz, size_t* len)
{
  const uint8_t* p = *pp;
  size_t sz = *psz;
  size_t octets;
  size_t value = 0;

  if(sz < 1)
  {
    return 0;
  }

  if(p[0] < 0x80)
  {
    //Short form: the byte is the length itself
    value = p[0];
    octets = 0;
  }
  else
  {
    //Long form: low 7 bits give the count of big-endian length octets that follow
    octets = p[0] & 0x7F;
    if( (octets < 1) || (octets > 4) || (sz - 1 < octets) )
    {
      return 0;
    }
    for(size_t k = 1; k <= octets; k++)
    {
      //Accumulate in size_t: shifting a promoted (signed) int left by 24 can overflow
      value = (value << 8) | p[k];
    }
  }

  *len = value;
  *pp = p + 1 + octets;
  *psz = sz - 1 - octets;
  return 1;
}

/**
 * Parse the modulus N and public exponent e from a DER-encoded key.
 *
 * The accepted layout is the PKCS #1 RSAPublicKey structure:
 *
 *   SEQUENCE (tag 0x30; tag 0x31 is also tolerated)
 *   * INTEGER modulus N
 *   * INTEGER publicExponent e
 *
 * The SEQUENCE must span exactly key_size bytes. An X.509 SubjectPublicKeyInfo
 * wrapper (SEQUENCE { AlgorithmIdentifier, BIT STRING { RSAPublicKey } }), as
 * produced by default OpenSSL PEM public keys, is NOT unwrapped: its first inner
 * element is not an INTEGER and is rejected.
 *
 * \param N         big integer receiving the modulus (already initialised)
 * \param e         big integer receiving the public exponent (already initialised)
 * \param key       DER bytes
 * \param key_size  number of bytes in key
 * \return MINITLS_OK, or MINITLS_ERR_PARAMETERS on any malformed or trailing data
 */
minitls_err_t crypto_ecc_dsa_check_get_asn1_Ne(void* N, void* e, const uint8_t* key, size_t key_size)
{
  const uint8_t* p = key;
  size_t sz = key_size;
  size_t seq_size;

  ENSURE_SIZE(sz, 1);

  if( (p[0] != 0x30) && (p[0] != 0x31) ) //SEQUENCE, SET types
  {
    return MINITLS_ERR_PARAMETERS;
  }

  p++;
  sz--;

  if( !crypto_rsa_asn1_read_length(&p, &sz, &seq_size) )
  {
    return MINITLS_ERR_PARAMETERS;
  }

  //Check that sequence size == remaining bytes size
  if( seq_size != sz )
  {
    return MINITLS_ERR_PARAMETERS;
  }

  //Read integers: N first, then e
  for(int i = 0; i < 2; i++)
  {
    size_t integer_size;

    ENSURE_SIZE(sz, 1);

    if( p[0] != 0x02 ) //INTEGER type
    {
      return MINITLS_ERR_PARAMETERS;
    }

    p++;
    sz--;

    if( !crypto_rsa_asn1_read_length(&p, &sz, &integer_size) )
    {
      return MINITLS_ERR_PARAMETERS;
    }

    //DER requires at least one content octet for an INTEGER
    if( integer_size == 0 )
    {
      return MINITLS_ERR_PARAMETERS;
    }

    //Check that we have enough bytes remaining
    ENSURE_SIZE(sz, integer_size);

    DBG("Integer of size %d", (int)integer_size);

    //NOTE(review): mp_read_unsigned_bin's result is not checked (a check was previously
    //commented out here) - confirm whether this math backend can report failure.
    mp_read_unsigned_bin((i == 0) ? N : e, (unsigned char *)p, integer_size);

    p += integer_size;
    sz -= integer_size;
  }

  if(sz > 0)
  {
    //Unread parameters left in sequence
    return MINITLS_ERR_PARAMETERS;
  }

  return MINITLS_OK;
}

