#include "mbed.h"
#include "comms.h"

/**
 * Store a 16-bit word into two bytes, most-significant byte first
 * (network / big-endian order).
 *
 * The byte order comes from shifts and masks, so the result does not
 * depend on the host's native endianness.
 *
 * @param buffer  destination; must have room for at least 2 bytes
 * @param in_word value to serialise
 */
void le2be_u16(uint8_t* buffer, uint16_t in_word) {
    const uint8_t high_byte = static_cast<uint8_t>(in_word >> 8);
    const uint8_t low_byte  = static_cast<uint8_t>(in_word & 0xFFu);
    buffer[0] = high_byte;
    buffer[1] = low_byte;
}

/**
 * Fold a 32-bit accumulator into a 16-bit one's-complement sum, as used by
 * the Internet checksum (IP/UDP/TCP).
 *
 * The upper 16 bits are added onto the lower 16 bits. A second fold then
 * absorbs any carry produced by the first ("end-around carry"). Two folds are
 * always enough for a 32-bit input.
 *
 * NOTE: this returns the folded *sum*, not its complement. Callers that
 * need the on-the-wire checksum value must apply `~` themselves.
 *
 * @param in_word running 32-bit sum to fold
 * @return 16-bit one's-complement sum
 */
uint16_t chksum_u16(uint32_t in_word) {
    uint32_t folded = (in_word >> 16) + (in_word & 0xFFFFu);  // <= 0x1FFFE
    folded = (folded >> 16) + (folded & 0xFFFFu);             // <= 0xFFFF
    return static_cast<uint16_t>(folded);
}

/**
 * Transport constructor.
 *
 * Sets up the pseudo header and the packet buffer used in communications.
 * The packet initially holds only the 8-byte header: source port
 * PORT_BASE + ID, destination port PORT_BASE, length = sizeof(header_), and
 * a checksum.
 *
 * The partial one's-complement sum of the pseudo header is cached in
 * pseudo_header_.checksum so that it can be reused later without
 * recomputing it.
 *
 * @param src_addr source IPv4 address (host-order integer; TODO confirm)
 * @param dst_addr destination IPv4 address (host-order integer; TODO confirm)
 * @param ID       offset added to PORT_BASE to form the source port
 */
Transport::Transport (uint32_t src_addr, uint32_t dst_addr, uint32_t ID) {
    // Set up IPv4 pseudo header for checksum calculation
    pseudo_header_.src_addr = src_addr;
    pseudo_header_.dst_addr = dst_addr;
    pseudo_header_.protocol = PROTOCOL;
    pseudo_header_.length = sizeof(header_);
    // Set up default ports
    header_.src_port = PORT_BASE + ID;
    header_.dst_port = PORT_BASE;
    header_.length = sizeof(header_);
    // Partial (uncomplemented) sum of the pseudo header.
    // Each 32-bit address is folded to 16 bits *before* it is added to the
    // running sum. Adding a raw address such as 0xFFFFFFFF to a non-zero
    // running sum would overflow uint32_t and lose the end-around carry.
    uint32_t w_checksum = chksum_u16(src_addr);
    w_checksum = chksum_u16(chksum_u16(dst_addr) + w_checksum);
    w_checksum = chksum_u16(pseudo_header_.protocol + w_checksum);
    pseudo_header_.checksum = chksum_u16(pseudo_header_.length + w_checksum);
    // Extend the sum over the UDP header (its checksum field counts as 0).
    w_checksum = chksum_u16(header_.src_port + pseudo_header_.checksum);
    w_checksum = chksum_u16(header_.dst_port + w_checksum);
    const uint16_t sum = chksum_u16(header_.length + w_checksum);
    // RFC 768: the transmitted field is the one's complement of the sum.
    // An all-zero field means "no checksum", so a computed 0 is sent as 0xFFFF.
    const uint16_t wire_checksum = (uint16_t)~sum;
    header_.checksum = (wire_checksum == 0x0000) ? 0xFFFF : wire_checksum;
    // Load up the packet (big-endian on the wire)
    le2be_u16((packet_+0),header_.src_port);
    le2be_u16((packet_+2),header_.dst_port);
    le2be_u16((packet_+4),header_.length);
    le2be_u16((packet_+6),header_.checksum);
}

/**
 * Expose the formatted datagram without copying it.
 *
 * @param buffer receives a pointer to the internal packet buffer; it remains
 *               owned by this Transport
 * @param length receives the current total length in bytes (header_.length)
 */
void Transport::get_packet (uint8_t** buffer, uint16_t* length) {
    uint8_t* const datagram_start = packet_;
    const uint16_t datagram_len = header_.length;
    *buffer = datagram_start;
    *length = datagram_len;
}

/**
 * Copy an array of 16-bit samples into the payload area of the packet.
 *
 * Each sample is written big-endian, immediately after the header. The
 * header's length field is then updated to header + 2 * length bytes, and
 * the checksum field is set to 0x0000 ("no checksum").
 *
 * NOTE(review): there is no check that `length` samples fit in packet_. Its
 * capacity is declared outside this file, so confirm that callers bound
 * `length`.
 *
 * @param buffer samples to copy (read-only use)
 * @param length number of 16-bit samples in buffer
 */
void Transport::load_data (uint16_t* buffer, uint16_t length) {
    uint8_t* const payload = packet_ + sizeof(header_);
    for (uint16_t n = 0; n < length; ++n) {
        le2be_u16(payload + 2 * n, buffer[n]);
    }
    // Total datagram size in octets: fixed header plus two bytes per sample.
    header_.length = sizeof(header_) + 2 * length;
    le2be_u16(packet_ + 4, header_.length);
    // Checksum not computed for data packets; zero tells the receiver to skip it.
    header_.checksum = 0x0000;
    le2be_u16(packet_ + 6, header_.checksum);
}

/**
 * Set the destination port.
 *
 * Stores the port in the cached header and writes it big-endian into the
 * transmission buffer.
 *
 * The destination port is covered by the UDP checksum, so any checksum
 * already in the packet (for example the one computed by the constructor)
 * becomes stale once the port changes. Sending it would make receivers drop
 * the datagram. The checksum is therefore reset to 0x0000, which RFC 768
 * defines as "no checksum computed". This matches what load_data() does.
 *
 * @param port new destination port
 */
void Transport::set_dst_port (uint16_t port) {
    // Store the port both in the local struct...
    header_.dst_port = port;
    // ...and in the transmission packet
    le2be_u16((packet_+2),port);
    // Invalidate the now-stale checksum.
    header_.checksum = 0x0000;
    le2be_u16((packet_+6),header_.checksum);
}