#include "mbed.h"
#include "EthernetInterface.h"
#include "rtos.h"

// Base flash address of the region that write_flash() erases and then programs.
#define START_ADDRESS   0x80000
// Number of sectors write_flash() erases, starting at START_ADDRESS.
#define NUM_SECTORS     120
// Address stride, in bytes, between consecutive sector erases.
#define SECTOR_SIZE     4096
// Size, in bytes, of the network receive buffer in write_flash(); also the
// chunk size handed to program_flash_boot() on every programming call.
#define BUFFER_SIZE     1024


/*************************************************************/
//               Bootloader functions
/*************************************************************/
// Signatures of the routines that are reached through the fixed addresses below.
typedef int *(*program_flash_fn)(int, char *, unsigned int);
typedef int *(*erase_sector_fn)(int);
typedef void *(*bootloader_entry_fn)(int);

// NOTE(review): the addresses are hard-coded; presumably they must match the
// bootloader image that is resident on the target - confirm when it is rebuilt.

// Called by write_flash() as (flash address, data pointer, byte count); a
// non-zero return value is treated there as a programming error.
program_flash_fn program_flash_boot = reinterpret_cast<program_flash_fn>(0x795C9);

// Called by write_flash() with the address of the sector to erase.
erase_sector_fn erase_sector_boot = reinterpret_cast<erase_sector_fn>(0x79465);

// Called by write_flash() with the number of bytes that were programmed.
bootloader_entry_fn bootloader = reinterpret_cast<bootloader_entry_fn>(0x79121);


/*************************************************************/
//           Retrieving binary from server
/*************************************************************/
/**
 * Print an error message followed by CR/LF and halt forever.
 *
 * Halting in an endless loop is the error-handling convention of this file:
 * after a failed update there is nothing sensible left to do.
 */
static void halt_with_error(const char *message)
{
    printf("%s\r\n", message);
    while(1);
}

/**
 * Download a binary over HTTP and stage it in flash, then start the bootloader.
 *
 * Steps:
 *  1. Erase NUM_SECTORS sectors starting at START_ADDRESS.
 *  2. Bring up ethernet (DHCP), connect to the server and send a GET request.
 *  3. Receive the HTTP header, check the status and read Content-Length.
 *  4. Receive the body and program it in BUFFER_SIZE chunks starting at
 *     START_ADDRESS. The final chunk is padded with 0xFF.
 *  5. Close the connection and call bootloader() with the number of bytes
 *     programmed (the binary size rounded up to a multiple of BUFFER_SIZE).
 *
 * Any failure prints a message and halts; the function does not return an
 * error to the caller.
 */
void write_flash(void)
{
    printf("Erasing flash!\r\n");
    for (int i = 0; i < NUM_SECTORS; i++)
        erase_sector_boot(START_ADDRESS + i * SECTOR_SIZE);

    printf("Connecting ethernet\r\n");
    EthernetInterface eth;
    if (eth.init() < 0) //Use DHCP
        halt_with_error("Ethernet init failed");
    if (eth.connect() < 0)
        halt_with_error("Ethernet connect failed");
    printf("IP Address is %s\r\n", eth.getIPAddress());

    TCPSocketConnection sock;

    //----------------YOUR SERVER DETAILS-----------------------//
    if (sock.connect("192.168.0.18", 80) < 0)
        halt_with_error("Could not connect to server");
    // NOTE(review): the Host value below contains a space ("192.168.0.18 80");
    // HTTP expects "host" or "host:port". Left unchanged here - confirm against
    // the server in use.
    char http_cmd[] = "GET /loader_test.bin HTTP/1.1\r\nHost: 192.168.0.18 80\r\n\r\n";
    //----------------YOUR SERVER DETAILS-----------------------//

    if (sock.send_all(http_cmd, sizeof(http_cmd)-1) < 0)
        halt_with_error("Could not send request to server");

    char buffer[BUFFER_SIZE];

    // While parsing the header one byte is reserved for a terminating NUL,
    // because strstr()/sscanf() require a NUL-terminated string.
    const int header_capacity = static_cast<int>(sizeof(buffer)) - 1;
    int received = 0;
    char *body = 0;

    // Receive until the blank line that ends the HTTP header has arrived
    while (body == 0) {
        if (received == header_capacity)
            halt_with_error("Could not find start of bin file");

        int got = sock.receive(buffer + received, header_capacity - received);
        if (got <= 0) {
            halt_with_error(received == 0 ? "No data received from server"
                                          : "Connection lost while receiving HTTP header");
        }
        received += got;
        buffer[received] = '\0';

        char *header_end = strstr(buffer, "\r\n\r\n");
        if (header_end != 0)
            body = header_end + 4;   //'4' is size of "\r\n\r\n"
    }

    // Refuse anything but a 200 response, so an error page is never flashed
    int status = 0;
    if (sscanf(buffer, "HTTP/%*d.%*d %d", &status) != 1 || status != 200) {
        printf("Unexpected HTTP status %d\r\n", status);
        while(1);
    }

    // Determine size of the file; the field must lie inside the header
    static const char length_tag[] = "Content-Length: ";
    int binsize = 0;
    char *length_field = strstr(buffer, length_tag);
    if (length_field == 0 || length_field >= body ||
        sscanf(length_field + sizeof(length_tag) - 1, "%d", &binsize) != 1) {
        halt_with_error("Could not determine size of bin file");
    }
    printf("Size of the binary = %d bytes\r\n", binsize);

    // Only the erased area may be programmed
    if (binsize <= 0 || binsize > NUM_SECTORS * SECTOR_SIZE)
        halt_with_error("Binary size is invalid or does not fit in the erased flash area");

    // See how much of the bin file we already received, and move this to the start of the buffer
    int bufcount = received - static_cast<int>(body - buffer);
    if (bufcount > binsize)
        bufcount = binsize;          // ignore anything beyond Content-Length
    memmove(buffer, body, bufcount);
    printf("Received %d bytes\r\n", bufcount);

    int remaining = binsize - bufcount;  // bytes of the file not yet received
    int count = 0;                       // bytes programmed so far

    // Each pass fills the buffer (or takes the tail of the file) and programs
    // it exactly once. This also covers a file that arrived completely
    // together with the header.
    while (bufcount > 0 || remaining > 0) {
        while (bufcount < static_cast<int>(sizeof(buffer)) && remaining > 0) {
            // Never ask for more than the buffer can hold or the file still has
            int wanted = static_cast<int>(sizeof(buffer)) - bufcount;
            if (wanted > remaining)
                wanted = remaining;

            int got = sock.receive(&buffer[bufcount], wanted);
            printf("Received %d\r\n", got);
            if (got <= 0)
                halt_with_error("Error, should not happen");

            //Track how much we received and how much is left
            bufcount += got;
            remaining -= got;
        }

        // Pad a partial final chunk so no stale data from the previous chunk
        // is written. 0xFF is the erased state of typical NOR flash
        // (assumption - confirm for the target device).
        memset(buffer + bufcount, 0xFF, sizeof(buffer) - bufcount);

        if (program_flash_boot(count+START_ADDRESS, buffer, sizeof(buffer)) != 0) {
            printf("Error @ 0x%X!\r\n", count);
            while(1);
        }
        count += sizeof(buffer);
        bufcount = 0;
    }
    printf("Done\r\n");
    sock.close();

    eth.disconnect();

    bootloader(count);

}