#include "mbed.h"
#include <vector>

// Debug console: all progress and diagnostic output in main() goes through
// pc.printf.
Serial pc(USBTX, USBRX);
// UART link to the peer device; packets are exchanged over this port.
// NOTE(review): assumes p13/p14 are the TX/RX pins wired to the peer — confirm.
Serial uart(p13, p14);
// Active-low handshake lines shared with the peer. main() releases a line by
// making it an input with a pull-up, and asserts it by writing 0 and switching
// it to an output. The sending side pulls nREQ low to request a transfer; the
// receiving side answers by pulling nRDY low.
DigitalInOut nREQ(p15);
DigitalInOut nRDY(p16);

// 16-bit CRC over a packet payload (seeded with the packet length).
uint16_t ccitt(uint8_t len, uint8_t* buf);

// Table-free, byte-at-a-time CRC-16 (polynomial 0x1021, MSB-first) over the
// first `len` bytes of `buf`.
//
// The register is seeded with 0xFFFF XOR'd with the packet length, so the
// length field is folded into the CRC as well as the payload. Because `len`
// is a uint8_t, it already has the 8-bit width of the packet's length field.
//
// Returns the final register value; no final XOR is applied.
// `buf` is never dereferenced when `len` is 0, so it may be null then.
uint16_t ccitt(uint8_t len, uint8_t* buf) {
    uint16_t reg = static_cast<uint16_t>(0xFFFFu ^ len);

    const uint8_t* next = buf;
    for (uint8_t remaining = len; remaining != 0; --remaining) {
        // Swap the register's bytes, then mix in the next data byte.
        reg = static_cast<uint16_t>((reg << 8) | (reg >> 8));
        reg ^= *next++;
        // Apply the polynomial with shifts and XORs instead of a lookup table.
        reg ^= static_cast<uint16_t>((reg & 0x00FFu) >> 4);
        reg ^= static_cast<uint16_t>(reg << 12);
        reg ^= static_cast<uint16_t>((reg & 0x00FFu) << 5);
    }

    return reg;
}

// UART packet-exchange test against a peer device.
//
// 1. Echoes to the console any bytes already waiting on `uart`.
// 2. Receive: waits for the peer to pull nREQ low, answers by driving nRDY
//    low, then reads a packet framed as
//        [len][len payload bytes][CRC low byte][CRC high byte]
//    prints it, and compares the received CRC with ccitt() of the payload.
// 3. Send: drives nREQ low, waits for the peer to pull nRDY low, then sends a
//    packet in the same framing.
// 4. Idles forever.
//
// Handshake lines are released by switching them back to inputs with pull-ups.
int main() {
    pc.baud(38400);
    uart.baud(4800);

    pc.printf("\n\r--------\n\rUART Test\n\r");

    if (uart.readable()) {
        pc.printf("Leftover characters: ");
        while (uart.readable()) {
            pc.printf("%c", uart.getc());
        }
        pc.printf("\n\r");
    }

    // Release both handshake lines.
    nREQ.input();
    nREQ.mode(PullUp);
    nRDY.input();
    nRDY.mode(PullUp);

    pc.printf("Waiting\n\r");

    // Busy-wait for the peer to request a transfer.
    while (nREQ != 0) {
    }

    pc.printf("nREQ has gone low\n\r");

    // Write the level before switching to output so the pin is driven low as
    // soon as it becomes an output.
    nRDY = 0;
    nRDY.output();

    uint8_t len = uart.getc();
    pc.printf("recieved len: %u\n\r", len);
    std::vector<uint8_t> payload;
    payload.reserve(len);
    for (int i = 0; i < len; i++) {
        payload.push_back(uart.getc());
    }

    pc.printf("No more stuff to load...\n\r");

    // The CRC is two bytes, low byte first. It must be accumulated in a
    // 16-bit variable; a uint8_t would silently drop the high byte.
    uint16_t checksum = 0;
    for (int i = 0; i < 2; i++) {
        checksum |= static_cast<uint16_t>((uart.getc() & 0xFF) << (i * 8));
    }

    pc.printf("len: %d\n\r", len);
    pc.printf("payload: ");
    for (int i = 0; i < len; i++) {
        pc.printf("%c", payload[i]);
    }
    pc.printf("\n\r");
    pc.printf("checksum: 0x%X\n\r", checksum);

    // &payload[0] is undefined on an empty vector. Pass a null pointer when
    // len == 0; ccitt() does not read buf in that case.
    uint8_t* payload_ptr = 0;
    if (!payload.empty()) {
        payload_ptr = &payload[0];
    }
    uint16_t crc = ccitt(len, payload_ptr);
    pc.printf("ccitt: 0x%X\n\r", crc);
    pc.printf("checksum %s\n\r", checksum == crc ? "OK" : "MISMATCH");

    nRDY.input();
    nRDY.mode(PullUp);

    pc.printf("Sending\n\r");

    // Outgoing message buffer; only the first str_len bytes are transmitted.
    // A writable array avoids casting away the constness of a string literal.
    uint8_t str[] = "HelloHelloHelloHelloHello";
    const uint8_t str_len = 5;
    const uint16_t str_checksum = ccitt(str_len, str);

    nREQ = 0;
    nREQ.output();

    while (nRDY != 0) {
        wait(0.1);
    }

    pc.printf("got nRDY\n\r");

    // Same framing as the receive path: the length, exactly str_len payload
    // bytes (previously str_len + 1 were sent), then the CRC low byte first.
    uart.putc(str_len);
    for (int i = 0; i < str_len; i++) {
        uart.putc(str[i]);
    }
    uart.putc(str_checksum & 0xFF);
    uart.putc((str_checksum >> 8) & 0xFF);

    nREQ.input();
    nREQ.mode(PullUp);

    pc.printf("All done\n\r");

    // Idle forever. The wait() call gives the loop a side effect; an empty
    // infinite loop is undefined behavior in C++11 and later.
    while (1) {
        wait(1.0);
    }
}