/*
 * Copyright 2010-2015 Amazon.com, Inc. or its affiliates. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License").
 * You may not use this file except in compliance with the License.
 * A copy of the License is located at
 *
 *  http://aws.amazon.com/apache2.0
 *
 * or in the "license" file accompanying this file. This file is distributed
 * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
 * express or implied. See the License for the specific language governing
 * permissions and limitations under the License.
 */

#include <stdbool.h>
#include <string.h>

#include "aws_iot_config.h"
#include "aws_iot_error.h"
#include "aws_iot_log.h"
#include "network_interface.h"
#include "mbedtls/config.h"

#include "mbedtls/net.h"
#include "mbedtls/ssl.h"
#include "mbedtls/entropy.h"
#include "mbedtls/ctr_drbg.h"
#include "mbedtls/certs.h"
#include "mbedtls/x509.h"
#include "mbedtls/error.h"
#include "mbedtls/debug.h"
#include "mbedtls/timing.h"
#include "mbedtls/net_sockets.h"
#include "pem.h"

#include "platform.h"
#include "WNCTCPSocketConnection.h"

#ifdef USING_AVNET_SHIELD
// Used for BIO connections
// Socket object defined elsewhere (presumably by the WNC shield driver / main.cpp - TODO confirm);
// iot_tls_connect() hands it to mbedtls_ssl_set_bio() as the I/O context.
extern WNCTCPSocketConnection* _tcpsocket;
#endif

// SD File System
#include "SDFileSystem.h"

// SD defines
// Capacity of fp_buffer; config, certificate and key files larger than this are truncated.
#define CERT_MAX_SIZE 4096

// SD file pointer/buffer
// Shared by every SD loader in this file; no locking is done, so loaders must not run concurrently.
FILE *fp;
char fp_buffer[CERT_MAX_SIZE];

// From main.cpp
// Destination buffers filled by mbedtls_mqtt_config_parse_file(). Values are copied
// without a bound check, so each value plus its NUL must fit the declared size.
// NOTE(review): PortString[4] cannot hold a 4-digit port such as "8883" plus its NUL - TODO confirm size in main.cpp.
extern char HostAddress[255];
extern char MqttClientID[32];
extern char ThingName[32];
extern char PortString[4];

/*
 * This is a function to do further verification if needed on the cert received
 */
static int myCertVerify(void *data, mbedtls_x509_crt *crt, int depth, uint32_t *flags) {
	char buf[1024];
	((void) data);

	DEBUG("\nVerify requested for (Depth %d):\n", depth);
	mbedtls_x509_crt_info(buf, sizeof(buf) - 1, "", crt);
	DEBUG("%s", buf);

	if ((*flags) == 0) {
		DEBUG("  This certificate has no flags\n");
	} else {
		DEBUG(buf, sizeof(buf), "  ! ", *flags); DEBUG("%s\n", buf);
	}

	return (0);
}

// TLS session state shared by all iot_tls_* functions in this file. There is a
// single set of these contexts, so only one TLS connection is supported at a time.
// `ret` is a scratch return code reused by several functions; `i` appears unused in this file.
static int ret = 0, i;
static mbedtls_entropy_context entropy;
static mbedtls_ctr_drbg_context ctr_drbg;
static mbedtls_ssl_context ssl;
static mbedtls_ssl_config conf;
// Result of mbedtls_ssl_get_verify_result() after the handshake.
static uint32_t flags;
static mbedtls_x509_crt cacert;
static mbedtls_x509_crt clicert;
static mbedtls_pk_context pkey;
static mbedtls_net_context server_fd;

// Used to zero the given buffer
/*
 * Overwrite the first n bytes of v with zeros.
 * Writes go through a volatile pointer so the compiler cannot drop them as
 * dead stores (the buffer may hold key material).
 */
static void mbedtls_zeroize( char *v, size_t n ) {
	volatile char *cursor = v;
	for (size_t remaining = n; remaining > 0; --remaining) {
		*cursor = 0;
		++cursor;
	}
}

// Parser sub function
// Parser sub function
//
// Find the first occurrence of str_to_find (e.g. "AWS_IOT_MQTT_HOST=") in
// *search_str and copy the text after it, up to the end of that line, into
// param. A trailing '\r' is dropped, so both LF and CRLF files are accepted.
// A value on the last line without a newline runs to the end of the string.
//
// param is cleared (up to its current strlen) before the search, so it is
// empty when the key is missing.
// NOTE: the copy is unbounded; param must be large enough for the value + NUL.
//
// Returns 0 when the key was found, -1 otherwise.
int mqtt_parse_sub(std::string *search_str, char *param, const char *str_to_find)
{
	memset(param, 0, strlen(param));

	const std::string::size_type key_pos = search_str->find(str_to_find);
	if (key_pos == std::string::npos)
		return -1;

	const std::string::size_type value_start = key_pos + strlen(str_to_find);
	std::string::size_type value_end = search_str->find('\n', value_start);
	if (value_end == std::string::npos)
		value_end = search_str->size();

	// Only strip a CR that is really there; LF-only files keep their last char.
	if (value_end > value_start && (*search_str)[value_end - 1] == '\r')
		--value_end;

	strcpy(param, search_str->substr(value_start, value_end - value_start).c_str());

	return 0;
}

// Read MQTT config info
// Read MQTT config info
//
// Load a KEY=value text file from the SD card into the shared fp_buffer and
// extract AWS_IOT_MQTT_HOST=, AWS_IOT_MQTT_PORT=, AWS_IOT_MQTT_CLIENT_ID= and
// AWS_IOT_MY_THING_NAME=. Each value is copied into its global buffer from
// main.cpp and sp is pointed at it (the port is converted with atoi()).
// At most CERT_MAX_SIZE - 1 bytes of the file are read.
//
// Returns 0 on success, -1 if the file cannot be opened or a key is missing
// (parsing stops at the first missing key).
// NOTE(review): values are copied without a bound; each must fit its
// destination buffer including the NUL.
int mbedtls_mqtt_config_parse_file(ShadowParameters_t *sp, const char *path )
{
	int ret, size;
    mbedtls_zeroize(fp_buffer, CERT_MAX_SIZE);
    
    INFO("...Reading MQTT data from SD");
    fp = fopen(path, "r");   
    if (fp != NULL) {
        // Leave the last byte as NUL: fp_buffer is used as a C string below.
        size = fread(fp_buffer, sizeof(char), CERT_MAX_SIZE - 1, fp);
        DEBUG("...Number of data read: %d, text from file: %s", size, fp_buffer);
        fclose(fp);
    }
    else {
        ERROR("Could not open file: %s", path);
        return -1;
    }
    
    std::string filestr(fp_buffer);
        
    ret = mqtt_parse_sub(&filestr, HostAddress, "AWS_IOT_MQTT_HOST=");
    sp->pHost = HostAddress;
    INFO("...Host=%s", sp->pHost);
    if (ret < 0) {
        ERROR("Could not parse AWS_IOT_MQTT_HOST string.");
        return ret;
    }
    
    ret = mqtt_parse_sub(&filestr, PortString, "AWS_IOT_MQTT_PORT=");
    sp->port = atoi(PortString);
    INFO("...Port=%d", sp->port);
    if (ret < 0) {
        ERROR("Could not parse AWS_IOT_MQTT_PORT string.");
        return ret;
    }
    
    ret = mqtt_parse_sub(&filestr, MqttClientID, "AWS_IOT_MQTT_CLIENT_ID=");
    sp->pMqttClientId = MqttClientID;
    INFO("...pMqttClientId=%s", sp->pMqttClientId);
    if (ret < 0) {
        ERROR("Could not parse AWS_IOT_MQTT_CLIENT_ID string.");
        return ret;
    }
    
    ret = mqtt_parse_sub(&filestr, ThingName, "AWS_IOT_MY_THING_NAME=");
    sp->pMyThingName = ThingName;
    INFO("...pMyThingName=%s", sp->pMyThingName);
    if (ret < 0) {
        ERROR("Could not parse AWS_IOT_MY_THING_NAME string.");
        return ret;
    }

    return( ret );
}

// Override function: Parses CRT from SD
// Override function: Parses CRT from SD
//
// Replacement for the mbedTLS file loader. It reads the certificate(s) at path
// from the SD card into fp_buffer and passes them to mbedtls_x509_crt_parse(),
// which appends them to chain. At most CERT_MAX_SIZE - 1 bytes are read, so
// the buffer is always NUL-terminated. Larger files (e.g. long CA bundles)
// are truncated.
//
// Returns -1 if the file cannot be opened; otherwise the result of
// mbedtls_x509_crt_parse(): 0 on success, a positive count of certificates
// that failed to parse, or a negative mbedTLS error.
int mbedtls_x509_crt_parse_file( mbedtls_x509_crt *chain, const char *path )
{
    int ret, size;
    mbedtls_zeroize(fp_buffer, CERT_MAX_SIZE);
    
    INFO("...Reading CERT data from SD");
    fp = fopen(path, "r");
    if (fp != NULL) {
        // Reserve one byte so fp_buffer[size] is always the terminating NUL.
        size = fread(fp_buffer, sizeof(char), CERT_MAX_SIZE - 1, fp);
        DEBUG("...Number of data read: %d, text from file: %s", size, fp_buffer);
        fclose(fp);
    }
    else {
        ERROR("Could not open file: %s", path);
        return -1;
    }

    DEBUG("...CRT Parse");
    // For PEM input, mbedTLS requires buflen to include the terminating NUL.
    // size + 1 covers it and stays inside fp_buffer.
    ret = mbedtls_x509_crt_parse( chain, (unsigned char *)fp_buffer, size+1);
  
    return( ret );
}

// Override function: Parses KEY from SD
// Override function: Parses KEY from SD
//
// Replacement for the mbedTLS file loader. It reads the private key at path
// from the SD card into fp_buffer (at most CERT_MAX_SIZE - 1 bytes, so the
// buffer is always NUL-terminated) and passes it to mbedtls_pk_parse_key().
// pwd may be NULL; otherwise strlen(pwd) bytes are used as the password.
// The buffer is wiped afterwards because it held private key material.
//
// Returns -1 if the file cannot be opened; otherwise the result of
// mbedtls_pk_parse_key() (0 on success, negative mbedTLS error on failure).
int mbedtls_pk_parse_keyfile( mbedtls_pk_context *ctx,
                              const char *path, const char *pwd )
{
    int ret, size;  
    mbedtls_zeroize(fp_buffer, CERT_MAX_SIZE);

    INFO("...Reading KEY data from SD");
    fp = fopen(path, "r");
    if (fp != NULL) {
        // Reserve one byte so fp_buffer[size] is always the terminating NUL.
        size = fread(fp_buffer, sizeof(char), CERT_MAX_SIZE - 1, fp);
        // Only the size is logged: the buffer contains the private key.
        DEBUG("...Number of data read: %d", size);
        fclose(fp);
    }
    else {
        ERROR("Could not open file: %s", path);
        return -1;
    }

    DEBUG("...Key Parse");
    // size + 1 includes the NUL, which mbedTLS requires for PEM input.
    if( pwd == NULL ) {
        DEBUG("...No PWD");
        ret = mbedtls_pk_parse_key( ctx, (unsigned char *)fp_buffer, size+1, NULL, 0 );
    }
    else {
        DEBUG("...Using PWD");
        ret = mbedtls_pk_parse_key( ctx, (unsigned char *)fp_buffer, size+1, (const unsigned char *) pwd, strlen( pwd ) );
    }

    // Do not leave key material behind in the shared buffer.
    mbedtls_zeroize(fp_buffer, CERT_MAX_SIZE);
  
    return( ret );
}


/* personalization string for the drbg */
const char *DRBG_PERS = "mbed TLS helloword client";

/*
 * Initialise every mbedTLS context used by this wrapper, seed the CTR-DRBG
 * from the entropy source, and install the iot_tls_* callbacks on pNetwork.
 *
 * Returns NONE_ERROR on success. If seeding fails, the (negative)
 * mbedtls_ctr_drbg_seed() error code is returned and pNetwork is not modified.
 */
int iot_tls_init(Network *pNetwork) {
	IoT_Error_t ret_val = NONE_ERROR;
	
	mbedtls_net_init(&server_fd);
	mbedtls_ssl_init(&ssl);
	mbedtls_ssl_config_init(&conf);
	mbedtls_ctr_drbg_init(&ctr_drbg);
	mbedtls_x509_crt_init(&cacert);
	mbedtls_x509_crt_init(&clicert);
	mbedtls_pk_init(&pkey);
	
	DEBUG("...Seeding the random number generator");
	mbedtls_entropy_init(&entropy);
	/* The personalization length is the string length; sizeof(DRBG_PERS)
	 * would only be the size of the pointer. */
	if ((ret = mbedtls_ctr_drbg_seed(&ctr_drbg, mbedtls_entropy_func, &entropy,
			(const unsigned char *) DRBG_PERS, strlen(DRBG_PERS))) != 0) {
		ERROR(" failed\n  ! mbedtls_ctr_drbg_seed returned -0x%x\n", -ret);
		/* Propagate the failure; returning NONE_ERROR here would hide it. */
		return ret;
	}
	DEBUG(" ok\n");

	pNetwork->my_socket = 0;
	pNetwork->connect = iot_tls_connect;
	pNetwork->mqttread = iot_tls_read;
	pNetwork->mqttwrite = iot_tls_write;
	pNetwork->disconnect = iot_tls_disconnect;
	pNetwork->isConnected = iot_tls_is_connected;
	pNetwork->destroy = iot_tls_destroy;

	return ret_val;
}

/*
 * Report whether the link is up. Physical-layer disconnect detection is not
 * implemented, so this unconditionally reports "connected" (1).
 */
int iot_tls_is_connected(Network *pNetwork) {
	(void) pNetwork; /* hook: add a physical-layer disconnect check here */
	const int connected = 1;
	return connected;
}

/*
 * Load credentials, open the TCP connection and complete the TLS handshake
 * with params.pDestinationURL:params.DestinationPort.
 *
 * Credentials come from the SD card when USING_SD_CARD is defined, or from
 * the compiled-in AWS_IOT_ROOT_CA / AWS_IOT_CERTIFICATE / AWS_IOT_PRIVATE_KEY
 * strings otherwise. With params.ServerVerificationFlag set, the server
 * certificate must verify against the CA chain (VERIFY_REQUIRED); otherwise
 * verification is optional.
 *
 * Returns 0 (NONE_ERROR) on success, or a non-zero mbedTLS error code from the
 * step that failed. Contexts already set up are not released here.
 * NOTE(review): presumably the caller runs iot_tls_destroy() on failure - TODO confirm.
 */
int iot_tls_connect(Network *pNetwork, TLSConnectParams params) {
	DEBUG("...Loading the CA root certificate");	
#ifdef USING_SD_CARD
	ret = mbedtls_x509_crt_parse_file(&cacert, AWS_IOT_ROOT_CA_FILENAME);
#else
	ret = mbedtls_x509_crt_parse(&cacert, (const unsigned char *)AWS_IOT_ROOT_CA, strlen ((const char *)AWS_IOT_ROOT_CA)+1);
#endif
	/* A positive result counts certificates that failed to parse. Only a
	 * negative result is fatal for the CA chain. */
	if (ret < 0) {
		ERROR(" failed\n  !  mbedtls_x509_crt_parse returned -0x%x\n\n", -ret);
		return ret;
	} 
	DEBUG(" ok (%d skipped)", ret);
    
    
	DEBUG("...Loading the client cert");
#ifdef USING_SD_CARD
	ret = mbedtls_x509_crt_parse_file(&clicert, AWS_IOT_CERTIFICATE_FILENAME);
#else
	ret = mbedtls_x509_crt_parse(&clicert, (const unsigned char *)AWS_IOT_CERTIFICATE, strlen ((const char *)AWS_IOT_CERTIFICATE)+1);
#endif
	if (ret != 0) {
		ERROR(" failed\n  !  mbedtls_x509_crt_parse returned -0x%x\n\n", -ret);
		return ret;
	}
	DEBUG(" ok");
	
	DEBUG("...Loading the client key");
#ifdef USING_SD_CARD
	ret = mbedtls_pk_parse_keyfile(&pkey, AWS_IOT_PRIVATE_KEY_FILENAME, "");
#else
	ret = mbedtls_pk_parse_key(&pkey, (const unsigned char *)AWS_IOT_PRIVATE_KEY, strlen ((const char *)AWS_IOT_PRIVATE_KEY)+1, NULL, 0 );	
#endif	
	if (ret != 0) {
		ERROR(" failed\n  !  mbedtls_pk_parse_key returned -0x%x\n\n", -ret);
		return ret;
	} 
	DEBUG(" ok");

	/* Room for any 16-bit port plus NUL; snprintf keeps out-of-range values in bounds. */
	char portBuffer[6];
	snprintf(portBuffer, sizeof(portBuffer), "%d", params.DestinationPort);
	DEBUG("...Connecting to %s/%s", params.pDestinationURL, portBuffer);
	if ((ret = mbedtls_net_connect(&server_fd, params.pDestinationURL, portBuffer, MBEDTLS_NET_PROTO_TCP)) != 0) {
		ERROR(" failed\n  ! mbedtls_net_connect returned -0x%x\n\n", -ret);
		return ret;
	}

    
	ret = mbedtls_net_set_block(&server_fd);
	if (ret != 0) {
		ERROR(" failed\n  ! net_set_(non)block() returned -0x%x\n\n", -ret);
		return ret;
	} 
	DEBUG(" ok");
	

	DEBUG("...Setting up the SSL/TLS structure");
	if ((ret = mbedtls_ssl_config_defaults(&conf, MBEDTLS_SSL_IS_CLIENT, MBEDTLS_SSL_TRANSPORT_STREAM,
			MBEDTLS_SSL_PRESET_DEFAULT)) != 0) {
		ERROR(" failed\n  ! mbedtls_ssl_config_defaults returned -0x%x\n\n", -ret);
		return ret;
	}


	mbedtls_ssl_conf_verify(&conf, myCertVerify, NULL);
	if (params.ServerVerificationFlag == true) {
		mbedtls_ssl_conf_authmode(&conf, MBEDTLS_SSL_VERIFY_REQUIRED);
	} else {
		mbedtls_ssl_conf_authmode(&conf, MBEDTLS_SSL_VERIFY_OPTIONAL);
	}
	mbedtls_ssl_conf_rng(&conf, mbedtls_ctr_drbg_random, &ctr_drbg);

	mbedtls_ssl_conf_ca_chain(&conf, &cacert, NULL);
	if ((ret = mbedtls_ssl_conf_own_cert(&conf, &clicert, &pkey)) != 0) {
		ERROR(" failed\n  ! mbedtls_ssl_conf_own_cert returned %d\n\n", ret);
		return ret;
	}

	mbedtls_ssl_conf_read_timeout(&conf, params.timeout_ms);

	if ((ret = mbedtls_ssl_setup(&ssl, &conf)) != 0) {
		ERROR(" failed\n  ! mbedtls_ssl_setup returned -0x%x\n\n", -ret);
		return ret;
	}
	if ((ret = mbedtls_ssl_set_hostname(&ssl, params.pDestinationURL)) != 0) {
		ERROR(" failed\n  ! mbedtls_ssl_set_hostname returned %d\n\n", ret);
		return ret;
	}
	
	DEBUG("...Set Socket I/O Functions");
#ifdef USING_AVNET_SHIELD
	/* The WNC shield socket object is the I/O context for the send/recv callbacks. */
	mbedtls_ssl_set_bio(&ssl, static_cast<void *>(_tcpsocket), mbedtls_net_send, NULL, mbedtls_net_recv_timeout );
#else
	/* _tcpsocket is only declared for the Avnet shield; use the connected net context. */
	mbedtls_ssl_set_bio(&ssl, &server_fd, mbedtls_net_send, NULL, mbedtls_net_recv_timeout );
#endif
	DEBUG(" ok");
	

	DEBUG("...Performing the SSL/TLS handshake");
	while ((ret = mbedtls_ssl_handshake(&ssl)) != 0) {
		if (ret != MBEDTLS_ERR_SSL_WANT_READ && ret != MBEDTLS_ERR_SSL_WANT_WRITE) {
			ERROR(" failed\n  ! mbedtls_ssl_handshake returned -0x%x\n", -ret);
			if (ret == MBEDTLS_ERR_X509_CERT_VERIFY_FAILED) {
				ERROR("    Unable to verify the server's certificate. "
						"Either it is invalid,\n"
						"    or you didn't set ca_file or ca_path "
						"to an appropriate value.\n"
						"    Alternatively, you may want to use "
						"auth_mode=optional for testing purposes.\n");
			}
			return ret;
		}
	}
	

	DEBUG(" ok\n    [ Protocol is %s ]\n    [ Ciphersuite is %s ]\n", mbedtls_ssl_get_version(&ssl), mbedtls_ssl_get_ciphersuite(&ssl));
	if ((ret = mbedtls_ssl_get_record_expansion(&ssl)) >= 0) {
		DEBUG("    [ Record expansion is %d ]\n", ret);
	} else {
		DEBUG("    [ Record expansion is unknown (compression) ]\n");
	}


	DEBUG("...Verifying peer X.509 certificate");
	if (params.ServerVerificationFlag == true) {
		if ((flags = mbedtls_ssl_get_verify_result(&ssl)) != 0) {
			char vrfy_buf[512];
			ERROR(" failed\n");
			mbedtls_x509_crt_verify_info(vrfy_buf, sizeof(vrfy_buf), "  ! ", flags);
			ERROR("%s\n", vrfy_buf);
			/* Report a real error rather than the record-expansion value still in ret. */
			ret = MBEDTLS_ERR_X509_CERT_VERIFY_FAILED;
		} else {
			DEBUG(" ok\n");
			ret = NONE_ERROR;
		}
	} else {
		DEBUG(" Server Verification skipped\n");
		ret = NONE_ERROR;
	}


    DEBUG("...SSL get peer cert");
	const mbedtls_x509_crt *peer_cert = mbedtls_ssl_get_peer_cert(&ssl);
	if (peer_cert != NULL) {
		DEBUG("...Peer certificate information");		
		const uint32_t buf_size = 1024;
		char *buf = new char[buf_size];
		mbedtls_x509_crt_info(buf, buf_size, "      ", peer_cert);
		DEBUG("...Server certificate:\r\n%s\r", buf);
		/* Free the temporary buffer; it used to leak on every connect. */
		delete[] buf;
	}

	/* After the handshake, reads use a 10 ms timeout (mbedTLS units are ms). */
	mbedtls_ssl_conf_read_timeout(&conf, 10);

	return ret;
}

/*
 * Send len bytes from pMsg over the TLS session.
 *
 * mbedtls_ssl_write() may accept only part of the data, so the call is
 * repeated until everything is sent. WANT_READ / WANT_WRITE are retried.
 *
 * Returns the number of bytes written (len; 0 if len <= 0), or the first
 * other non-positive mbedtls_ssl_write() result as an error.
 * timeout_ms is currently ignored.
 */
int iot_tls_write(Network *pNetwork, unsigned char *pMsg, int len, int timeout_ms) {
	int written = 0;

	while (written < len) {
		int rc;
		while ((rc = mbedtls_ssl_write(&ssl, pMsg + written, len - written)) <= 0) {
			if (rc != MBEDTLS_ERR_SSL_WANT_READ && rc != MBEDTLS_ERR_SSL_WANT_WRITE) {
				ERROR(" failed\n  ! mbedtls_ssl_write returned -0x%x\n\n", -rc);
				return rc;
			}
		}
		written += rc;
	}

	return written;
}

/*
 * Read exactly len bytes from the TLS session into pMsg.
 *
 * mbedtls_ssl_read() may return fewer bytes than requested, so each chunk is
 * appended at pMsg + rxLen until len bytes have arrived. WANT_READ is retried.
 * Any other non-positive result ends the read.
 *
 * Returns len when all bytes were received (0 if len <= 0). Otherwise it
 * returns the last mbedtls_ssl_read() result: 0 if the peer closed the
 * connection, or a negative mbedTLS error such as a read timeout.
 *
 * timeout_ms is not applied. The effective read timeout is whatever is
 * configured on conf; iot_tls_connect() sets it to 10 ms after the handshake.
 * TODO: check this against the upstream SDK, which applied timeout_ms here.
 */
int iot_tls_read(Network *pNetwork, unsigned char *pMsg, int len, int timeout_ms) {
	int rxLen = 0;

	while (rxLen < len) {
		const int rc = mbedtls_ssl_read(&ssl, pMsg + rxLen, len - rxLen);
		if (rc > 0) {
			rxLen += rc;
		} else if (rc != MBEDTLS_ERR_SSL_WANT_READ) {
			return rc;
		}
	}

	return rxLen;
}

/*
 * Send the TLS close_notify alert, retrying while mbedTLS reports
 * WANT_WRITE. The underlying socket is not closed here.
 */
void iot_tls_disconnect(Network *pNetwork) {
	for (;;) {
		ret = mbedtls_ssl_close_notify(&ssl);
		if (ret != MBEDTLS_ERR_SSL_WANT_WRITE) {
			break;
		}
	}
}

/*
 * Release every mbedTLS context owned by this wrapper: the network socket,
 * certificates and key, then the SSL session/config and the RNG state.
 * The release order is kept as-is. Always returns 0.
 */
int iot_tls_destroy(Network *pNetwork) {
	(void) pNetwork;

	/* transport */
	mbedtls_net_free(&server_fd);

	/* credentials */
	mbedtls_x509_crt_free(&clicert);
	mbedtls_x509_crt_free(&cacert);
	mbedtls_pk_free(&pkey);

	/* TLS session and configuration */
	mbedtls_ssl_free(&ssl);
	mbedtls_ssl_config_free(&conf);

	/* random number generation */
	mbedtls_ctr_drbg_free(&ctr_drbg);
	mbedtls_entropy_free(&entropy);

	return 0;
}


