/*
MiniTLS - A super trimmed down TLS/SSL Library for embedded devices
Author: Donatien Garnier
Copyright (C) 2013-2014 AppNearMe Ltd

This program is free software; you can redistribute it and/or
modify it under the terms of the GNU General Public License
as published by the Free Software Foundation; either version 2
of the License, or (at your option) any later version.

This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public License
along with this program; if not, write to the Free Software
Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.
*//**
 * \file crypto_ecc.c
 * \copyright Copyright (c) AppNearMe Ltd 2013
 * \author Donatien Garnier
 */

#define __DEBUG__ 0//4
#ifndef __MODULE__
#define __MODULE__ "crypto_ecc.c"
#endif

#include "core/fwk.h"
#include "crypto_ecc.h"
#include "inc/minitls_errors.h"
#include "inc/minitls_config.h"

#include "crypto_math.h"
#include "ltc/ltc.h"

static minitls_err_t crypto_ecc_dsa_check_get_asn1_rs(void* r, void* s, const uint8_t* signature, size_t signature_size);

static const crypto_ecc_curve_t crypto_ecc_curves[];

/**
 * Look up the domain parameters of a named curve.
 *
 * Only the curves enabled by the CRYPTO_ECC* switches in the curve table are
 * available. GnuTLS supports SECP192R1, SECP224R1, SECP256R1, SECP384R1 and
 * SECP521R1 (ECC-192 ... ECC-521 in libtomcrypt), so the same set is offered.
 *
 * \param curve receives a pointer into the static curve table on success
 * \param type  curve to look up
 * \return MINITLS_OK, or MINITLS_ERR_NOT_IMPLEMENTED if the curve is not in the table
 */
minitls_err_t crypto_ecc_curve_get(const crypto_ecc_curve_t** curve, crypto_ecc_curve_type_t type)
{
  //The table is terminated by a sentinel entry whose size is 0
  for(size_t i = 0; crypto_ecc_curves[i].size != 0; i++)
  {
    if(crypto_ecc_curves[i].type == type)
    {
      *curve = &crypto_ecc_curves[i];
      return MINITLS_OK;
    }
  }
  return MINITLS_ERR_NOT_IMPLEMENTED;
}

/**
 * Import a public point encoded as format byte || X || Y into \p key.
 *
 * The format byte must be 4, 6 or 7, and each coordinate must be exactly
 * curve->size bytes long. The point is stored with Z = 1.
 * NOTE(review): the point is not checked to lie on \p curve, and for formats
 * 6/7 the Y parity carried in the format byte is not checked either.
 *
 * \param key   receives the point and, on success, the curve pointer
 * \param curve curve the point belongs to
 * \param x963  encoded point
 * \param size  length of \p x963 in bytes
 * \return MINITLS_OK, MINITLS_ERR_PARAMETERS on a malformed encoding or length
 *         mismatch, MINITLS_ERR_MEMORY if the big numbers cannot be initialised
 */
minitls_err_t crypto_ecc_ansi_x963_import(crypto_ecc_public_key_t* key, const crypto_ecc_curve_t* curve, const uint8_t* x963, size_t size)
{
  int status;
  size_t coord_len;

  //One format byte plus two equally long coordinates: the total must be odd
  if ((size & 1) != 1) {
    return MINITLS_ERR_PARAMETERS;
  }
  coord_len = (size - 1) >> 1;

  if (mp_init_multi(&key->pubkey.x, &key->pubkey.y, &key->pubkey.z, NULL) != MINITLS_OK) {
    return MINITLS_ERR_MEMORY;
  }

  //Accepted format bytes
  switch (x963[0]) {
  case 4:
  case 6:
  case 7:
    break;
  default:
    status = MINITLS_ERR_PARAMETERS;
    goto fail;
  }

  //Big-endian X followed by big-endian Y; Z is fixed to 1
  mp_read_unsigned_bin(&key->pubkey.x, (unsigned char *)x963 + 1, coord_len);
  mp_read_unsigned_bin(&key->pubkey.y, (unsigned char *)x963 + 1 + coord_len, coord_len);
  mp_set(&key->pubkey.z, 1);

  if (coord_len != (unsigned long) curve->size) {
    status = MINITLS_ERR_PARAMETERS;
    goto fail;
  }

  key->curve = curve;
  return MINITLS_OK;

fail:
  mp_clear_multi(&key->pubkey.x, &key->pubkey.y, &key->pubkey.z, NULL);
  return status;
}

minitls_err_t crypto_ecc_ansi_x963_export(const crypto_ecc_public_key_t* key, uint8_t* x963, size_t max_size, size_t* size)
{

  unsigned char buf[ECC_BUF_SIZE];
  unsigned long numlen;

  numlen = crypto_ecc_get_key_size_for_curve(key->curve);

  if (max_size < (1 + 2*numlen)) {
     *size = 1 + 2*numlen;
     return MINITLS_ERR_BUFFER_TOO_SMALL;
  }

  /* store byte 0x04 */
  x963[0] = 0x04;

  /* pad and store x */
  zeromem(buf, sizeof(buf));
  mp_to_unsigned_bin(&key->pubkey.x, buf + (numlen - mp_unsigned_bin_size(&key->pubkey.x)));
  memcpy(x963+1, buf, numlen);

  /* pad and store y */
  zeromem(buf, sizeof(buf));
  mp_to_unsigned_bin(&key->pubkey.y, buf + (numlen - mp_unsigned_bin_size(&key->pubkey.y)));
  memcpy(x963+1+numlen, buf, numlen);

  *size = 1 + 2*numlen;
  return MINITLS_OK;
}

/**
 * Generate a new ECC key pair on \p curve.
 *
 * Key generation runs as follows:
 * - Draw crypto_ecc_get_key_size_for_curve(curve) random bytes from \p prng.
 * - Read them as the private scalar.
 * - Reduce the scalar modulo the curve order if it is not already below it.
 * - Multiply the base point by the scalar to obtain the public point.
 *
 * \param key   receives the private scalar and public point; key->pub.curve is set on success
 * \param curve curve to generate the key on
 * \param prng  random source
 * \return MINITLS_OK, MINITLS_ERR_PARAMETERS if the curve has no known key size,
 *         MINITLS_ERR_MEMORY if big-number initialisation fails, or an error from the math layer
 */
minitls_err_t crypto_ecc_generate_key(crypto_ecc_private_key_t* key, const crypto_ecc_curve_t* curve, crypto_prng_t* prng)
{
  ecc_point     base;
  fp_int        prime, order;
  int ret;

  size_t keysize = crypto_ecc_get_key_size_for_curve(curve);

  DBG("Generating key of size %d", (int)keysize);

  if(keysize == 0)
  {
    //Unknown curve; also avoids declaring a zero-length VLA below (undefined behaviour)
    return MINITLS_ERR_PARAMETERS;
  }

  uint8_t buf[keysize];

  /* make up random string */
  DBG("Getting data from PRNG");
  //NOTE(review): the PRNG status is not checked; the disabled block below would map a failure to MINITLS_ERR_PRNG
  crypto_prng_get(prng, buf, keysize);
#if 0
  if( crypto_prng_get(prng, buf, keysize) != MINITLS_OK ) {
     ret = MINITLS_ERR_PRNG;
     goto errbuf;
  }
#endif

  DBG("Initializing numbers");
  /* setup the key variables */
  if (mp_init_multi(&key->pub.pubkey.x, &key->pub.pubkey.y, &key->pub.pubkey.z, &key->privkey, &prime, &order, NULL) != MINITLS_OK) {
     ret = MINITLS_ERR_MEMORY;
     goto errbuf;
  }
  if (mp_init_multi(&base.x, &base.y, &base.z, NULL) != MINITLS_OK) {
     ret = MINITLS_ERR_MEMORY;
     goto errkey;
  }

  DBG("Reading the key specs");
  /* read in the specs for this key */
  if ((ret = mp_read_radix(&prime,  (char *)curve->prime, 16)) != MINITLS_OK) { goto errkey; }
  if ((ret = mp_read_radix(&order,  (char *)curve->order, 16)) != MINITLS_OK) { goto errkey; }
  if ((ret = mp_read_radix(&base.x, (char *)curve->Gx, 16)) != MINITLS_OK)    { goto errkey; }
  if ((ret = mp_read_radix(&base.y, (char *)curve->Gy, 16)) != MINITLS_OK)    { goto errkey; }
  mp_set(&base.z, 1);
  mp_read_unsigned_bin(&key->privkey, (unsigned char *)buf, keysize);

  /* the key should be smaller than the order of base point */
  if (mp_cmp(&key->privkey, &order) != MP_LT) {
      if((ret = mp_mod(&key->privkey, &order, &key->privkey)) != MINITLS_OK) { goto errkey; }
  }

  DBG("Compute public key");
  /* make the public key */
  if ((ret = ltc_ecc_mulmod(&key->privkey, &base, &key->pub.pubkey, &prime, 1)) != MINITLS_OK) { goto errkey; }

  //Save curve
  key->pub.curve = curve;

  DBG("Done");
  /* free up ram */
  ret = MINITLS_OK;
  goto cleanup;
errkey:
   mp_clear_multi(&key->pub.pubkey.x, &key->pub.pubkey.y, &key->pub.pubkey.z, &key->privkey, NULL);
cleanup:
   mp_clear_multi(&base.x, &base.y, &base.z, &prime, &order, NULL);
errbuf:
#ifdef MINITLS_CLEAN_STACK
   //Wipe the random bytes the private key was derived from (length is keysize,
   //not a pointer; key->pub.curve is not yet set on the error paths)
   zeromem(buf, keysize);
#endif

   return ret;
}

/**
 * Size in bytes of one coordinate / scalar for \p curve, i.e. ceil(bits / 8).
 *
 * This must agree with the per-curve size in the curve table. X9.63 import
 * compares coordinate lengths against curve->size, while export pads
 * coordinates to the value returned here.
 *
 * \param curve curve to query
 * \return the byte size, or 0 if the curve type is not handled here
 */
size_t crypto_ecc_get_key_size_for_curve(const crypto_ecc_curve_t* curve)
{
  switch(curve->type)
  {
  case secp192r1:
    return 192 / 8;
  case secp224r1:
    return 224 / 8;
  case secp256r1:
    return 256 / 8;
  case secp384r1:
    return 384 / 8;
  case secp521r1:
    return (521 + 7) / 8; //66 bytes: a 521-bit value does not fit in 64 bytes
  default:
    break;
  }

  return 0;
}

/**
 * Return the public half embedded in \p private_key.
 * The returned pointer aliases \p private_key; no copy is made.
 */
const crypto_ecc_public_key_t* crypto_ecc_get_public_key(const crypto_ecc_private_key_t* private_key)
{
  const crypto_ecc_public_key_t* public_part = &(private_key->pub);
  return public_part;
}

minitls_err_t crypto_ecc_dsa_check(const crypto_ecc_public_key_t* key, const uint8_t* hash, size_t hash_size, const uint8_t* signature, size_t signature_size)
{
  crypto_ecc_point_t    mG, mQ;
  fp_int          r, s, v, w, u1, u2, e, p, m;
  fp_digit        mp;
  int           err;

  bool valid = false;

  /* default to invalid signature */
  valid = false;


  /* allocate ints */
  if ((err = mp_init_multi(&r, &s, &v, &w, &u1, &u2, &p, &e, &m, NULL)) != MINITLS_OK) {
     return MINITLS_ERR_MEMORY;
  }

  /* allocate points */
  err = mp_init_multi(&mG.x, &mG.y, &mG.z, &mQ.x, &mQ.y, &mQ.z, NULL);
  if (err) {
     mp_clear_multi(&r, &s, &v, &w, &u1, &u2, &p, &e, &m, NULL);
     return MINITLS_ERR_MEMORY;
  }

#if 0
  /* parse header */
  //TODO
  if ((err = der_decode_sequence_multi(signature, signature_size,
                                 LTC_ASN1_INTEGER, 1UL, &r,
                                 LTC_ASN1_INTEGER, 1UL, &s,
                                 LTC_ASN1_EOL, 0UL, NULL)) != MINITLS_OK) {
     goto error;
  }
#endif

  //Decode ASN.1 sequence: [INTEGER:&r, INTEGER:&s]
  if( (err = crypto_ecc_dsa_check_get_asn1_rs(&r, &s, signature, signature_size) ) != MINITLS_OK )
  {
    goto error;
  }

  /* get the order */
  if ((err = mp_read_radix(&p, (char *)key->curve->order, 16)) != MINITLS_OK)                                { goto error; }

  /* get the modulus */
  if ((err = mp_read_radix(&m, (char *)key->curve->prime, 16)) != MINITLS_OK)                                { goto error; }

  /* check for zero */
  if (mp_iszero(&r) || mp_iszero(&s) || mp_cmp(&r, &p) != MP_LT || mp_cmp(&s, &p) != MP_LT) {
     err = MINITLS_ERR_PARAMETERS;
     goto error;
  }

  /* read hash */
  /*if ((err =*/ mp_read_unsigned_bin(&e, (unsigned char *)hash, (int)hash_size);/*) != MINITLS_OK)                { goto error; }*/

  /*  &w  = &s^-1 mod n */
  if ((err = mp_invmod(&s, &p, &w)) != MINITLS_OK)                                                          { goto error; }

  /* &u1 = ew */
  if ((err = mp_mulmod(&e, &w, &p, &u1)) != MINITLS_OK)                                                      { goto error; }

  /* &u2 = rw */
  if ((err = mp_mulmod(&r, &w, &p, &u2)) != MINITLS_OK)                                                      { goto error; }

  /* find mG and mQ */
  if ((err = mp_read_radix(&mG.x, (char *)key->curve->Gx, 16)) != MINITLS_OK)                               { goto error; }
  if ((err = mp_read_radix(&mG.y, (char *)key->curve->Gy, 16)) != MINITLS_OK)                               { goto error; }
  /*if ((err = */mp_set(&mG.z, 1);/*) != MINITLS_OK)                                                            { goto error; }*/

  /*if ((err = */mp_copy(&key->pubkey.x, &mQ.x);/*) != MINITLS_OK)                                               { goto error; }*/
  /*if ((err = */mp_copy(&key->pubkey.y, &mQ.y);/*) != MINITLS_OK)                                               { goto error; }*/
  /*if ((err = */mp_copy(&key->pubkey.z, &mQ.z);/*) != MINITLS_OK)                                               { goto error; }*/

  /* compute &u1*mG + &u2*mQ = mG */
#ifndef LTC_ECC_SHAMIR
//  if (ltc_mp.ecc_mul2add == NULL) {
#endif
     if ((err = ltc_ecc_mulmod(&u1, &mG, &mG, &m, 0)) != MINITLS_OK)                                       { goto error; }
     if ((err = ltc_ecc_mulmod(&u2, &mQ, &mQ, &m, 0)) != MINITLS_OK)                                       { goto error; }

     /* find the montgomery mp */
     if ((err = mp_montgomery_setup(&m, &mp)) != MINITLS_OK)                                              { goto error; }

     /* add them */
     if ((err = ltc_ecc_projective_add_point(&mQ, &mG, &mG, &m, &mp)) != MINITLS_OK)                                      { goto error; }

     /* reduce */
     if ((err = ltc_ecc_map(&mG, &m, &mp)) != MINITLS_OK)                                                { goto error; }
#ifdef LTC_ECC_SHAMIR
     /* use Shamir'&s trick to compute &u1*mG + &u2*mQ using half of the doubles */
     if ((err = ltc_ecc_mul2add(&mG, &u1, &mQ, &u2, &mG, &m)) != MINITLS_OK)                                { goto error; }
#endif

  /* &v = X_x1 mod n */
  if ((err = mp_mod(&mG.x, &p, &v)) != MINITLS_OK)                                                         { goto error; }

  /* does &v == &r */
  if (mp_cmp(&v, &r) == MP_EQ) {
    valid = true;
  }

error:
  mp_clear_multi(&mG.x, &mG.y, &mG.z, &mQ.x, &mQ.y, &mQ.z, &r, &s, &v, &w, &u1, &u2, &p, &e, &m, NULL);
  mp_montgomery_free(&mp);
  if(err == MINITLS_OK)
  {
    if(valid)
    {
      return MINITLS_OK;
    }
    else
    {
      return MINITLS_ERR_WRONG_ECDSA;
    }
  }
  else
  {
    return err;
  }
}

/**
 * Compute the ECDH shared secret between \p private_key and \p public_key.
 *
 * The secret is the X coordinate of privkey * pubkey, as computed by
 * ltc_ecc_mulmod() with its last argument set to 1. It is written big-endian
 * and left-padded with zeros to the byte length of the curve prime.
 * NOTE(review): the peer point is not validated here (on-curve / not infinity).
 *
 * \param private_key     local private key
 * \param public_key      peer public key; must use the same curve type
 * \param secret          output buffer
 * \param max_secret_size capacity of \p secret
 * \param secret_size     receives the secret length; also set when the buffer is too small
 * \return MINITLS_OK, MINITLS_ERR_WRONG_CURVE, MINITLS_ERR_MEMORY,
 *         MINITLS_ERR_BUFFER_TOO_SMALL, or an error from the math layer
 */
minitls_err_t crypto_ecc_dh_generate_shared_secret(const crypto_ecc_private_key_t* private_key, const crypto_ecc_public_key_t* public_key, uint8_t* secret, size_t max_secret_size, size_t* secret_size)
{
  ecc_point      shared_point;
  fp_int         modulus;
  unsigned long  secret_len;
  int            rc;

  //Both keys must live on the same curve
  if (public_key->curve->type != private_key->pub.curve->type) {
    return MINITLS_ERR_WRONG_CURVE;
  }

  if (mp_init_multi(&shared_point.x, &shared_point.y, &shared_point.z, NULL) != MINITLS_OK) {
    return MINITLS_ERR_MEMORY;
  }
  mp_init(&modulus);

  rc = mp_read_radix(&modulus, (char *)private_key->pub.curve->prime, 16);
  if (rc == MINITLS_OK) {
    rc = ltc_ecc_mulmod(&private_key->privkey, &public_key->pubkey, &shared_point, &modulus, 1);
  }

  if (rc == MINITLS_OK) {
    secret_len = (unsigned long)mp_unsigned_bin_size(&modulus);
    if (max_secret_size >= secret_len) {
      //Left-pad X with zeros up to the prime's byte length
      zeromem(secret, secret_len);
      mp_to_unsigned_bin(&shared_point.x, secret + (secret_len - mp_unsigned_bin_size(&shared_point.x)));
      *secret_size = secret_len;
    } else {
      *secret_size = secret_len;
      rc = MINITLS_ERR_BUFFER_TOO_SMALL;
    }
  }

  mp_clear(&modulus);
  mp_clear_multi(&shared_point.x, &shared_point.y, &shared_point.z, NULL);
  return rc;
}

//Decode (&r,&s) integers from ASN.1 Signature
#define ENSURE_SIZE(actual_size, min_size) do{ if( (actual_size) < (min_size) ) { return MINITLS_ERR_PARAMETERS; } }while(0)

/* Read an ASN.1 definite length field at *pp.
 * Accepted forms are short form (< 0x80) and long form with 1 to 4 big-endian length bytes.
 * On success, stores the length in *length and advances *pp / shrinks *psz past the field.
 * Returns MINITLS_ERR_PARAMETERS on truncation, indefinite length (0x80) or more than 4 length bytes. */
static minitls_err_t crypto_ecc_asn1_get_length(const uint8_t** pp, size_t* psz, size_t* length)
{
  const uint8_t* p = *pp;
  size_t sz = *psz;
  size_t header_size;
  size_t value = 0;

  ENSURE_SIZE(sz, 1);

  if(p[0] < 0x80)
  {
    //Short form: the byte is the length itself
    value = p[0];
    header_size = 1;
  }
  else
  {
    //Long form: the low 7 bits count the length bytes that follow
    size_t count = p[0] & 0x7F;
    if( (count == 0) || (count > 4) )
    {
      return MINITLS_ERR_PARAMETERS;
    }
    ENSURE_SIZE(sz, 1 + count);
    for(size_t k = 1; k <= count; k++)
    {
      //Shift in size_t: shifting a promoted int by 24 overflows for bytes >= 0x80
      value = (value << 8) | (size_t)p[k];
    }
    header_size = 1 + count;
  }

  *length = value;
  *pp = p + header_size;
  *psz = sz - header_size;
  return MINITLS_OK;
}

/* Parse a DER signature SEQUENCE (a SET tag is also tolerated) holding two INTEGERs into r and s.
 * The sequence length must match the remaining input exactly, and no bytes may follow the second integer.
 * The integer contents are read as unsigned big-endian magnitudes.
 * Returns MINITLS_OK or MINITLS_ERR_PARAMETERS. */
static minitls_err_t crypto_ecc_dsa_check_get_asn1_rs(void* r, void* s, const uint8_t* signature, size_t signature_size)
{
  const uint8_t* p = signature;
  size_t sz = signature_size;
  minitls_err_t err;

  ENSURE_SIZE(sz, 1);

  if( (p[0] != 0x30) && (p[0] != 0x31) ) //Sequence, SET types
  {
    return MINITLS_ERR_PARAMETERS;
  }

  p++;
  sz--;

  //Get sequence length
  size_t seq_size;
  if( (err = crypto_ecc_asn1_get_length(&p, &sz, &seq_size)) != MINITLS_OK )
  {
    return err;
  }

  //Check that sequence size == remaining bytes size
  if( seq_size != sz )
  {
    return MINITLS_ERR_PARAMETERS;
  }

  //Read integers
  for(int i = 0; i < 2; i++)
  {
    ENSURE_SIZE(sz, 1);

    if( p[0] != 2 ) //Integer type
    {
      return MINITLS_ERR_PARAMETERS;
    }

    p++;
    sz--;

    //Get integer length
    size_t integer_size;
    if( (err = crypto_ecc_asn1_get_length(&p, &sz, &integer_size)) != MINITLS_OK )
    {
      return err;
    }

    //Check that we have enough bytes remaining
    ENSURE_SIZE(sz, integer_size);

    //Read integer
    void* integer = (i==0)?r:s;
    mp_read_unsigned_bin(integer, (unsigned char *)p, (int)integer_size);

    p+=integer_size;
    sz-=integer_size;
  }

  if(sz > 0)
  {
    //Unread parameters left in sequence
    return MINITLS_ERR_PARAMETERS;
  }

  return MINITLS_OK;
}

//List of supported curves. crypto_ecc_curve_get() stops at the sentinel entry whose size is 0.
//Only the curves enabled by the CRYPTO_ECC* switches are compiled in.
//Entries are positional; presumably laid out as { size, type, prime, B, order, Gx, Gy }
//(as in libtomcrypt's ltc_ecc_set_type) -- confirm against crypto_ecc.h. This file reads
//size, type, prime, order, Gx and Gy by name; hex strings are parsed with mp_read_radix(..., 16).
//Storing the parameters in strings is not optimal (they are re-parsed on every key generation,
//verification and key agreement); TODO will have to be addressed at some point.
static const crypto_ecc_curve_t crypto_ecc_curves[] = {
#if CRYPTO_ECC160
{
        20,
        secp160r1,
        "FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF7FFFFFFF",
        "1C97BEFC54BD7A8B65ACF89F81D4D4ADC565FA45",
        "0100000000000000000001F4C8F927AED3CA752257",
        "4A96B5688EF573284664698968C38BB913CBFC82",
        "23A628553168947D59DCC912042351377AC5FB32",
},
#endif
#if CRYPTO_ECC192
{
        24,
        secp192r1,
        "FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEFFFFFFFFFFFFFFFF",
        "64210519E59C80E70FA7E9AB72243049FEB8DEECC146B9B1",
        "FFFFFFFFFFFFFFFFFFFFFFFF99DEF836146BC9B1B4D22831",
        "188DA80EB03090F67CBF20EB43A18800F4FF0AFD82FF1012",
        "7192B95FFC8DA78631011ED6B24CDD573F977A11E794811", //Leading zero omitted; same value once parsed
},
#endif
#if CRYPTO_ECC224
{
        28,
        secp224r1,
        "FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF000000000000000000000001",
        "B4050A850C04B3ABF54132565044B0B7D7BFD8BA270B39432355FFB4",
        "FFFFFFFFFFFFFFFFFFFFFFFFFFFF16A2E0B8F03E13DD29455C5C2A3D",
        "B70E0CBD6BB4BF7F321390B94A03C1D356C21122343280D6115C1D21",
        "BD376388B5F723FB4C22DFE6CD4375A05A07476444D5819985007E34",
},
#endif
#if CRYPTO_ECC256
{
        32,
        secp256r1,
        "FFFFFFFF00000001000000000000000000000000FFFFFFFFFFFFFFFFFFFFFFFF",
        "5AC635D8AA3A93E7B3EBBD55769886BC651D06B0CC53B0F63BCE3C3E27D2604B",
        "FFFFFFFF00000000FFFFFFFFFFFFFFFFBCE6FAADA7179E84F3B9CAC2FC632551",
        "6B17D1F2E12C4247F8BCE6E563A440F277037D812DEB33A0F4A13945D898C296",
        "4FE342E2FE1A7F9B8EE7EB4A7C0F9E162BCE33576B315ECECBB6406837BF51F5",
},
#endif
#if CRYPTO_ECC384
{
        48,
        secp384r1,
        "FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEFFFFFFFF0000000000000000FFFFFFFF",
        "B3312FA7E23EE7E4988E056BE3F82D19181D9C6EFE8141120314088F5013875AC656398D8A2ED19D2A85C8EDD3EC2AEF",
        "FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFC7634D81F4372DDF581A0DB248B0A77AECEC196ACCC52973",
        "AA87CA22BE8B05378EB1C71EF320AD746E1D3B628BA79B9859F741E082542A385502F25DBF55296C3A545E3872760AB7",
        "3617DE4A96262C6F5D9E98BF9292DC29F8F41DBD289A147CE9DA3113B5F0B8C00A60B1CE1D7E819D7A431D7C90EA0E5F",
},
#endif
#if CRYPTO_ECC521
{
        66, //ceil(521 / 8) bytes per coordinate
        secp521r1,
        "1FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF",
        "51953EB9618E1C9A1F929A21A0B68540EEA2DA725B99B315F3B8B489918EF109E156193951EC7E937B1652C0BD3BB1BF073573DF883D2C34F1EF451FD46B503F00",
        "1FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFA51868783BF2F966B7FCC0148F709A5D03BB5C9B8899C47AEBB6FB71E91386409",
        "C6858E06B70404E9CD9E3ECB662395B4429C648139053FB521F828AF606B4D3DBAA14B5E77EFE75928FE1DC127A2FFA8DE3348B3C1856A429BF97E7E31C2E5BD66",
        "11839296A789A3BC0045C8A5FB42C7D1BD998F54449579B446817AFBD17273E662C97EE72995EF42640C550B9013FAD0761353C7086A272C24088BE94769FD16650",
},
#endif
{0,}, //Sentinel: size == 0 marks the end of the table
};
