#include "mbed.h"

#include <vector>
 
// Serial port on the USBTX/USBRX pins; used as the debug console for all
// status output in this test.
Serial       pc(USBTX, USBRX); // tx, rx
// Chip-to-chip UART that carries the test packets (baud rate set in main()).
Serial       c2c(p13, p14);    // tx, rx
// Handshake lines shared with the peer board. Direction/mode is configured per
// role inside sendTestPacket() / receiveTestPacket(). Presumably active-low, per
// the `n` prefix — confirm against the wiring.
DigitalInOut   nREQ(p15);
DigitalInOut   nRDY(p16);


// Fletcher-16 checksum over a length-prefixed packet (not optimised).
//
// The length byte is mixed into the running sums as though it were the first
// byte of the stream, and is followed by `length_byte` payload bytes from `buf`.
// Both running sums are reduced modulo 255 after every step.
//
// @param length_byte  number of payload bytes to read from buf; also mixed into
//                     the checksum. A value of 0 reads no payload bytes.
// @param buf          payload; must point to at least length_byte readable bytes
//                     (may be null when length_byte is 0). Never modified.
// @return (sum2 << 8) | sum1 — high byte is sum2, low byte is sum1.
static uint16_t checkSum(uint8_t length_byte, const uint8_t* buf){
    // Seeding both sums with length % 255 is exactly what processing the
    // length byte as a leading data byte from zeroed sums would produce.
    uint16_t sum1 = length_byte % 255;
    uint16_t sum2 = sum1;

    for (int i = 0; i < length_byte; i++) {
        // sum1 < 255 and buf[i] <= 255, so intermediates stay well inside 16 bits.
        sum1 = (sum1 + buf[i]) % 255;
        sum2 = (sum2 + sum1) % 255;
    }

    return static_cast<uint16_t>((sum2 << 8) | sum1);
}

// Send one test packet over c2c once nRDY reads low.
//
// Bytes written to c2c, in order:
//   [len][payload...][csum low byte][csum high byte]
// where the payload is (len == 0 ? 256 : len) bytes with values (i + 31) & 0xff
// and csum = checkSum(len, payload).
//
// @param len  length byte to transmit. Defaults to 0, the value this function
//             previously hard-coded, which sends a 256-byte payload.
//
// NOTE(review): for len == 0 the payload is 256 bytes, but checkSum(0, ...)
// covers none of them and receiveTestPacket() reads `len` (i.e. 0) payload
// bytes. Both ends need to agree on what len == 0 means.
void sendTestPacket(uint8_t len = 0){
    pc.printf("\n\rTX: wait for nRDY... ");

    // nREQ: open-drain output held low.
    nREQ = 0;
    nREQ.mode(PinMode(OpenDrain));
    nREQ.output();

    // nRDY: input, polled below.
    nRDY = 1;
    nRDY.input();

    // Poll until nRDY reads low.
    while(nRDY == 1)
        wait(0.1);

    pc.printf("\n\rgot nRDY\n\r");

    c2c.putc(len);
    while(!c2c.writeable())
        ;
    pc.printf("wrote len=%d\n\r", len);

    const int count = (len == 0) ? 256 : len;
    uint8_t data[256];  // largest payload the len == 0 convention can describe
    for(int i = 0; i < count; i++){
        data[i] = uint8_t((i + 31) & 0xff);
        c2c.putc(data[i]);
        pc.printf(".");
        while(!c2c.writeable())
            ;
    }
    pc.printf("\n\rwrote data\n\r");

    // Checksum is sent little-endian: low byte first.
    const uint16_t csum = checkSum(len, data);
    c2c.putc(uint8_t(csum & 0xff));
    while(!c2c.writeable())
        ;
    c2c.putc(uint8_t((csum >> 8) & 0xff));
    pc.printf("\n\rwrote checksum\n\r");
}

// Receive one test packet over c2c and dump it to the pc console.
//
// Reads [len][len payload bytes][csum low byte][csum high byte] from c2c.
// It prints the payload bytes, the received checksum, the checksum recomputed
// locally with checkSum(len, payload), and whether the two agree.
//
// NOTE(review): the handshake below looks inverted or self-blocking; check it
// against the wiring and sendTestPacket():
//  - nREQ is made an open-drain output driven to 0 and then polled until it
//    reads 1. If the pin reads back its own driven-low level, this never exits.
//    The commented-out lines suggest nREQ was meant to be an input.
//  - nRDY is driven to 0 on entry, set to 1 once nREQ is seen, and set back to
//    0 at the end, while sendTestPacket() waits for nRDY to read 0.
void receiveTestPacket(){
    pc.printf("RX: wait for nREQ...");

    nREQ = 0;
    nRDY = 0;
    // NOTE(review): if PinMode is a plain enum rather than bit flags on this
    // target, OR-ing two values selects whichever enumerator the result equals.
    nREQ.mode(PinMode(OpenDrain | PullUp));
    nRDY.mode(PinMode(OpenDrain | PullUp));
    nREQ.output();
    nRDY.output();

//    nREQ.input();
//    nRDY = 1;
//    nRDY.mode(PinMode(OpenDrain));

    while(nREQ == 0)
        wait(0.1);

    pc.printf("\n\rgot REQ\n\r");
    nRDY = 1;

    // The length arrives as a single byte, so it is at most 255.
    const uint8_t len = uint8_t(c2c.getc());
    // Fixed-size stack buffer instead of std::vector. This avoids heap use, and
    // `msg` is a valid pointer even when len == 0. The old &vec[0] on an empty
    // vector was undefined behaviour.
    uint8_t msg[256];
    for(int i = 0; i < len; i++)
        msg[i] = uint8_t(c2c.getc());

    // Checksum arrives little-endian: low byte first.
    uint16_t csum = 0;
    for(int i = 0; i < 2; i++)
        csum |= (uint16_t(c2c.getc()) << (i*8));

    pc.printf("len=%d\n\r", len);
    for(int i = 0; i < len; i++)
        pc.printf("0x%x ", msg[i]);
    pc.printf("\n\rcsum : 0x%x\n\r", csum);

    const uint16_t check_checksum = checkSum(len, msg);
    pc.printf("check: 0x%x\n\r", check_checksum);
    pc.printf("%s\n\r", (csum == check_checksum) ? "checksum OK" : "checksum MISMATCH");

    nRDY = 0;
}

// c2c UART test entry point.
//
// Sets the c2c baud rate and prints any bytes already waiting on c2c. It then
// receives rx_count packets via receiveTestPacket() and prints a completion
// banner.
int main() {
    pc.printf("\n\r--------\n\rc2c UART Test\n\r");

    c2c.baud(38400);

    // Debug aid: uncomment to echo raw c2c traffic to the console forever.
//    while(1) {
//        if(c2c.readable()) {
//            pc.putc(c2c.getc());
//        }
//    }

    // Drain and print whatever is already sitting in the c2c receive buffer
    // before the first packet is read.
    if(c2c.readable())
        pc.printf("Leftover characters:");
    while(c2c.readable())
        pc.printf("0x%x ", int(c2c.getc()));
    pc.printf("\n\r");

    // Receive rx_count packets. The counter is decremented after each one so
    // the loop terminates and the completion banner below is reached.
    int rx_count = 1;
    while(rx_count > 0){
        receiveTestPacket();
        --rx_count;
    }

//    sendTestPacket();
//
    pc.printf("\n\r--- complete ---\n\r");
}