#include "mbed.h"
#include "SerialInterfaceProtocol.h"
#include "CommandPacket2.h"

SerialInterfaceProtocol::SerialInterfaceProtocol(SerialBuffer_t *in, SerialBuffer_t *out)
{
    // Remember where raw bytes are read from (in) and where encoded replies
    // are queued (out); only the pointers are stored, nothing is copied.
    SerialOutputBuffer = out;
    SerialInputBuffer = in;

    // Checksum verification starts disabled and the parser starts idle.
    isChecksumEnabled = false;
    state = NONE;

    // No command handlers are registered yet.
    int slot = 0;
    while (slot < SIP_CMD_VECTOR_TABLE_SZ)
    {
        CommandVectorTable[slot++] = NULL;
    }
}

SerialInterfaceProtocol::~SerialInterfaceProtocol()
{
    // Nothing to release: the serial buffers handed to the constructor are
    // not deleted here, so their lifetime remains the caller's concern.
}

void SerialInterfaceProtocol::registerCommand(uint8_t command, callback_func f)
{
    CommandVectorTable[command] = f;
}

void SerialInterfaceProtocol::deRegisterCommand(uint8_t command)
{
    CommandVectorTable[command] = NULL;
}

int SerialInterfaceProtocol::execute(uint8_t *response, uint8_t *response_length)
{
    // read command from packet buffer
    uint8_t command = PacketBuffer.getCommand();
    
    // execute the command if it's already been registered
    if (CommandVectorTable[command] != NULL)
    {
        // extract length and payload from packet buffer
        uint8_t length = PacketBuffer.getLength();
        uint8_t payload[length];
        memset(payload, 0x0, length);
        
        for (uint8_t i = 0; i < length; i++)
        {
            payload[i] = PacketBuffer.getPayload(i);
        }
        
        return CommandVectorTable[command](
            payload, 
            length,
            response, 
            response_length
        );
    }
    
    return -1;
}

// Queue a reply frame whose command byte encodes the outcome of the request:
//   high nibble 0xD -> the reply carries a payload (response_length != 0)
//   high nibble 0xE -> the reply has no payload
//   low nibble      -> the low 4 bits of the current errno status code
// NOTE(review): the previous comment here said "E" / "F", which disagrees
// with the 0xD0 / 0xE0 values actually sent; confirm against the protocol spec.
// Always returns 0; respond()'s return value is not propagated.
int SerialInterfaceProtocol::assemble(uint8_t *response, uint8_t response_length)
{
    const uint8_t status_nibble = static_cast<uint8_t>(errno) & 0x0F;
    const uint8_t reply_kind = (response_length != 0) ? 0xD0 : 0xE0;

    respond(
        reply_kind | status_nibble,
        response,
        response_length
    );

    return 0;
}

// Build a CommandPacket2 carrying `command` and `response_length` payload
// bytes, serialize it, and enqueue every serialized byte on SerialOutputBuffer.
// Always returns 0.
int SerialInterfaceProtocol::respond(uint8_t command, uint8_t *response, uint8_t response_length)
{
    // Serialized frame size: SFLAG (1) + command (2) + length (2) +
    // checksum (2) + EFLAG (1) = 8 characters, plus 2 hex characters per
    // payload byte. With a uint8_t length the largest frame is known at
    // compile time, so a fixed buffer replaces the former variable-length
    // array (not ISO C++).
    const int max_frame_len = UINT8_MAX * 2 + 8;

    // create a new packet buffer
    CommandPacket2 responsePacket;

    // set sflag
    responsePacket.setSFlag();  // use default sflag

    // return flag as specified by the user
    responsePacket.setCommand(command);

    // copy buffer length
    responsePacket.setLength(response_length);

    // copy buffer to payload
    for (uint8_t i = 0; i < response_length; i++)
    {
        responsePacket.setPayload(i, response[i]);
    }

    // generate checksum
    responsePacket.generateChecksum();

    // end flag
    responsePacket.setEFlag();

    // serialize the packet
    const int total_len = response_length * 2 + 8;
    uint8_t buffer[max_frame_len];
    memset(buffer, 0x0, sizeof(buffer));

    responsePacket.serialize(buffer);

#ifdef DEBUG_SIP
    printf("SIP::respond:total: %d chars\r\n", total_len);
#endif

    // add to ring buffer
    for (int i = 0; i < total_len; i++)
    {
#ifdef DEBUG_SIP
        printf("SIP::respond:0x%x (%c)\r\n", buffer[i], buffer[i]);
#endif
        SerialOutputBuffer->enqueue(buffer[i]);
    }

    return 0;
}
    

// Drain SerialInputBuffer and advance the packet-parsing state machine one
// received byte at a time.
//
// Frame layout consumed here, one state per character:
//   SFLAG, COMMAND_H, COMMAND_L, LENGTH_H, LENGTH_L,
//   [PAYLOAD_H, PAYLOAD_L] x length (skipped when length == 0),
//   CHECKSUM_H, CHECKSUM_L, EFLAG
// Command, length, payload and checksum bytes each arrive as two ASCII hex
// characters (high nibble first) and are decoded with hexchar_to_uint8().
//
// Receiving CommandPacket2::CP_SFLAG at any point restarts parsing. When a
// frame ends with CommandPacket2::CP_EFLAG, execute() runs the registered
// handler and assemble() queues the reply. Checksum failures and stray bytes
// outside a frame also queue error replies via assemble().
//
// NOTE(review): the partially decoded values (payload_counter, command,
// length, payload, checksum) are function-local statics, so they are shared
// by every SerialInterfaceProtocol instance, while `state` is set per object
// in the constructor. Interleaved polling of two instances would mix their
// frames; confirm only one instance exists, or move these into the class.
void SerialInterfaceProtocol::poll()
{
    static uint8_t payload_counter = 0;
    
    // temp variable for building full byte from hex strings
    // (static so a byte split across two poll() calls is still assembled)
    static uint8_t command = 0;
    static uint8_t length = 0;
    static uint8_t payload = 0;
    static uint8_t checksum = 0;
    
    uint8_t response[SIP_MAX_RESP_LEN];
    uint8_t response_length;
    
    // fetch data from ring buffer
    while (SerialInputBuffer->getCounter() > 0)
    {
        uint8_t ch;
        ch = SerialInputBuffer->dequeue();
        
        // reset state to keep sync: a start flag always begins a new frame,
        // whatever state the parser was in
        if (ch == CommandPacket2::CP_SFLAG)
        {
            state = SFLAG;
            
            // reset variable
            payload_counter = 0;
            errno = NO_ERROR;
        }
        
        // NOTE(review): hexchar_to_uint8() returns 0 for non-hex characters,
        // so malformed digits are silently decoded as 0 in the cases below.
        switch (state)
        {
            case SFLAG:
                PacketBuffer.setSFlag(ch);
                state = COMMAND_H;
#ifdef DEBUG_SIP
                printf("SIP::SFLAG: 0x%x\r\n", PacketBuffer.getSFlag());
#endif
                break;
                
            case COMMAND_H:
                command = hexchar_to_uint8(ch) << 4;
                state = COMMAND_L;
#ifdef DEBUG_SIP
                printf("SIP::COMMAND_H: 0x%x\r\n", command);
#endif
                break;
                
            case COMMAND_L:
                command |= (hexchar_to_uint8(ch) & 0x0f);
                
                // store command
                PacketBuffer.setCommand(command);
                
                state = LENGTH_H;
#ifdef DEBUG_SIP
                printf("SIP::COMMAND_L: 0x%x\r\n", command);
#endif
                break;
                
            case LENGTH_H:
                length = hexchar_to_uint8(ch) << 4;
                state = LENGTH_L;
#ifdef DEBUG_SIP
                printf("SIP::LENGTH_H: 0x%x\r\n", length);
#endif
                break;
                
            case LENGTH_L:
                length |= (hexchar_to_uint8(ch) & 0x0f);
                
                // store length
                PacketBuffer.setLength(length);
                
                if (length != 0) // if the length is not zero, then proceed to payload state
                {
                    state = PAYLOAD_H;
                }
                else // otherwise proceed to checksum state
                {
                    state = CHECKSUM_H;
                }
#ifdef DEBUG_SIP
                printf("SIP::LENGTH_L: 0x%x\r\n", length);
#endif
                break;
            
            case PAYLOAD_H:
                payload = hexchar_to_uint8(ch) << 4; // store higher 4 bits of payload
                state = PAYLOAD_L;
#ifdef DEBUG_SIP
                printf("SIP::PAYLOAD_H: 0x%x\r\n", payload);
#endif
                break;
                
            case PAYLOAD_L:
                payload |= (hexchar_to_uint8(ch) & 0x0f); // store lower 4 bits of payload
                
                // store payload
                PacketBuffer.setPayload(payload_counter++, payload);
                
                if (payload_counter < PacketBuffer.getLength()) // append ch to payload until reach the length
                {
                    state = PAYLOAD_H;
                }
                else
                {
                    state = CHECKSUM_H;
                }
#ifdef DEBUG_SIP
                printf("SIP::PAYLOAD_L: 0x%x\r\n", payload);
#endif
                break;
                
            case CHECKSUM_H:
                checksum = hexchar_to_uint8(ch) << 4;
                state = CHECKSUM_L;
#ifdef DEBUG_SIP
                printf("SIP::CHECKSUM_H: 0x%x\r\n", checksum);
#endif
                break;
                
            case CHECKSUM_L:
                checksum |= (hexchar_to_uint8(ch) & 0x0f);
                
                // store checksum
                PacketBuffer.setChecksum(checksum);
                
                // checksum can be turned off
                if (isChecksumEnabled)
                {
                    if (PacketBuffer.verify()) // checksum match
                    {
                        state = EFLAG;
                    }
                    else // checksum mismatch
                    {
                        // clear response and response length
                        response_length = 0;
                        memset(response, 0x0, sizeof(response));
                        
                        // prepare for checksum error response
                        errno = INVALID_CS_ERROR;
                        assemble(response, response_length);
                        
                        // NOTE(review): with state = NONE, the frame's trailing
                        // EFLAG byte (and any other byte before the next SFLAG)
                        // reaches `case NONE` and queues an additional
                        // INVALID_SFLAG_ERROR reply after this one; confirm
                        // that two replies per bad frame are intended.
                        state = NONE;
                    }
                }
                else
                {
                    state = EFLAG;
                }
                
                
#ifdef DEBUG_SIP
                printf("SIP::CHECKSUM_L: 0x%x\r\n", checksum);
#endif
                break;
                
            case EFLAG:
                if (ch == CommandPacket2::CP_EFLAG)
                {
                    PacketBuffer.setEFlag(ch);
                    
                    // clear response and response length
                    response_length = 0;
                    memset(response, 0x0, sizeof(response));
                    
                    // execute command
                    int ret = execute(response, &response_length);
                    
                    // map execute()'s result onto the status code assemble()
                    // encodes into the reply
                    if (ret < 0) // command not registered
                    {
                        errno = INVALID_CMD_ERROR;
                    }
                    else if (ret != 0) // error to execute
                    {
                        errno = INVALID_EXEC_ERROR;
                    }
                    else
                    {
                        errno = NO_ERROR;
                    }
                    
                    assemble(response, response_length);
                }
                // A byte other than CP_EFLAG here drops the frame without
                // queuing any reply.
                state = NONE;
#ifdef DEBUG_SIP
                // NOTE: when ch was not CP_EFLAG this prints the EFLAG stored
                // by an earlier frame, since setEFlag() was not called.
                printf("SIP::EFLAG: 0x%x\r\n", PacketBuffer.getEFlag());
#endif
                break;
                
            case NONE:
                // Byte received outside a frame (after EFLAG or an error, before
                // the next SFLAG): each such byte, including e.g. trailing
                // line terminators, queues one INVALID_SFLAG_ERROR reply.
                // clear response and response length
                response_length = 0;
                memset(response, 0x0, sizeof(response));
                
                // Execute error generator
                errno = INVALID_SFLAG_ERROR;
                assemble(response, response_length);
                
#ifdef DEBUG_SIP
                printf("SIP::NONE\r\n");
#endif
                break;
                
            default:
                break;
        }
    }
}

void SerialInterfaceProtocol::disableChecksum()
{
    isChecksumEnabled = false;
}

void SerialInterfaceProtocol::enableChecksum()
{
    isChecksumEnabled = true;
}

// Convert one ASCII hex digit ('0'-'9', 'A'-'F', 'a'-'f') to its value 0..15.
// Any other character yields 0, so callers cannot distinguish '0' from an
// invalid digit.
uint8_t hexchar_to_uint8(uint8_t ch)
{
    if ('0' <= ch && ch <= '9')
    {
        return static_cast<uint8_t>(ch - '0');
    }

    if ('A' <= ch && ch <= 'F')
    {
        return static_cast<uint8_t>(10 + (ch - 'A'));
    }

    if ('a' <= ch && ch <= 'f')
    {
        return static_cast<uint8_t>(10 + (ch - 'a'));
    }

    // not a hex digit
    return 0;
}
