#include "niMQTT.h"

/**
 * Creates an MQTT client and connects it to the broker.
 *
 * @param server   broker address, passed to the socket's connect() in init()
 * @param callback called as callback(topic, message) by call_callback() when a PUBLISH arrives
 * @param id       client identifier sent in the CONNECT payload
 * @param port     broker TCP port
 * @param username CONNECT user name; an empty string means "no user name"
 * @param password CONNECT password; an empty string means "no password"
 * @param debug    when true, trace messages are printed
 *
 * The worker thread is created in the member-initializer list. It immediately blocks in
 * thread_worker() on signal_wait(START_THREAD), and init() raises that signal once the socket
 * is connected. The value returned by init() is ignored here.
 * NOTE(review): members are initialised in their declaration order in niMQTT.h, not in the
 * order written below. Confirm that `debug` is declared before `thread`, because
 * thread_worker() reads it before waiting for the signal.
 */
niMQTT::niMQTT(char *server, void (*callback)(const char*, const char*), char *id, int port, char *username, char *password, bool debug):
    server(server), port(port), id(id), callback(callback), username(username), password(password),
    debug(debug), connected(true), message_id(0), thread(&niMQTT::thread_starter, this),
    waiting_new_packet(true), packet_sent(false), waiting_connack(0), waiting_suback(0), waiting_pingresp(0) {
    init();
}

/**
 * Opens the TCP connection to the broker, releases the worker thread and sends CONNECT.
 *
 * Retries the TCP connect until it succeeds.
 *
 * @return the result of connect() (non-zero when the whole CONNECT packet was sent).
 */
int niMQTT::init() {
    if (debug) printf("*init\r\n");

    socket = new TCPSocketConnection;
    for (;;) {
        printf("Socket connection...\r\n");
        if (socket->connect(server, port) >= 0) break;
    }
    // Timeout argument KEEP_ALIVE*500, i.e. KEEP_ALIVE / 2 in seconds.
    socket->set_blocking(true, KEEP_ALIVE * 500);

    printf("Socket connected.\r\n");

    // Wake thread_worker(), which is parked on this signal, then log in.
    thread.signal_set(START_THREAD);
    return connect();
}

int niMQTT::send(char *packet, int size) {
    //if (debug) {
        printf("*send: ");
        for(int i=0; i<size; i++) printf("0x%x ", packet[i]);
        printf("\r\n");
    //}

    int j = -1;
    do j = socket->send_all(packet, size); while (j < 0);

    if (j != size) printf ("%d bytes sent (%d expected)...\r\n", j, size);
    else if (debug) printf("packet sent\r\n");
    packet_sent = true;

    return (j == size);
}

/**
 * Reads one MQTT control packet from the socket and dispatches it to its handler.
 *
 * First it polls every 100 ms, at most TIMEOUT/100 times, until the previous packet has been
 * consumed, i.e. until a handler has set waiting_new_packet back to true. It then reads the
 * 1-byte fixed header and calls the handler for that packet type. The handler reads the rest
 * of the packet.
 *
 * @return 0 when a packet was dispatched. -1 on timeout, on a failed header read, or on an
 *         unknown packet type; in all three cases reconnect() has been called.
 */
int niMQTT::recv() {
    if (debug) printf("*recv\r\n");

    int timeout = 0;
    while (!waiting_new_packet && timeout < TIMEOUT/100) {
        wait(0.1);
        timeout++;
    }
    // Decide using the flag itself, so a packet that became ready during the last poll is not
    // reported as a timeout.
    if (!waiting_new_packet) {
        printf("RECV TIMEOUT\r\n");
        if (waiting_connack > 0) printf("CONNACK not received !\r\n");
        if (waiting_suback > 0) printf("SUBACK not received !\r\n");
        if (waiting_pingresp > 0) printf("PINGRESP not received !\r\n");
        // The connection is rebuilt, so the next byte is a fresh header. If the flag stayed
        // false, every later call would time out and reconnect again.
        waiting_new_packet = true;
        reconnect();
        return -1;
    }

    if (debug) printf("Receiving new packet...\r\n");

    // receive() returns <= 0 on error or when the peer closed the connection. The header
    // byte must not be dispatched in that case.
    char header_received = 0;
    if (socket->receive(&header_received, 1) <= 0) {
        printf("RECV FAILED\r\n");
        reconnect();
        return -1;
    }
    if (debug) printf("Received 0x%x\r\n", static_cast<unsigned char>(header_received));

    waiting_new_packet = false;
    // Flag bits of the fixed header (MQTT 3.1), currently unused:
    //bool DUP = ((header_received & 8) == 8);
    //int QoS = (header_received & 6) >> 1;
    //bool RETAIN = ((header_received & 1) == 1);

    // The packet type lives in the high nibble; each handler consumes the rest of the packet
    // and sets waiting_new_packet back to true.
    switch (header_received & 0xf0) {
        case CONNACK: connack(); break;
        case PUBLISH: publish_received(); break;
        case PUBACK: puback_received(); break;
        case SUBACK: suback(); break;
        case UNSUBACK: suback(true); break;
        case PINGRESP: pingresp(); break;
        default:
            waiting_new_packet = true;
            reconnect();
            printf("BAD HEADER: 0x%x\r\n", static_cast<unsigned char>(header_received));
            return -1;
    }

    return 0;
}

int niMQTT::connect() {
    if (debug) printf("*connect\r\n");
    int username_length = strlen(username);
    int password_length = strlen(password);
    int id_length = strlen(id);

    int use_username = (username_length != 0);
    int use_password = (password_length != 0);

    char variable_header[] = {0,6,77,81,73,115,100,112,3,
        use_username << 7 | use_password << 6,
        KEEP_ALIVE / 256, KEEP_ALIVE % 256 };

    int remaining_length = 14 + id_length + username_length + password_length + 2*(use_username + use_password);
    int packet_length = 2 + remaining_length;

    char fixed_header[] = { CONNECT, remaining_length };

    char packet[packet_length];
    memcpy(packet, fixed_header, 2);
    memcpy(packet + 2, variable_header, 12);

    // Adds the payload: id
    char id_size[2] = { id_length / 256, id_length % 256 };
    memcpy(packet + 14, id_size, 2);
    memcpy(packet + 16, id, id_length);

    // Adds username & Password to the payload
    if (use_username) {
        char username_size[2] = { username_length / 256, username_length % 256 };
        memcpy(packet + 16 + id_length, username_size, 2);
        memcpy(packet + 18 + id_length, username, username_length);
    }
    if (use_password) {
        char password_size[2] = { password_length / 256, password_length % 256 };
        memcpy(packet + 18 + id_length + username_length, password_size, 2);
        memcpy(packet + 20 + id_length + username_length, password, password_length);
    }

    waiting_connack++;

    return send(packet, packet_length);
}

/**
 * Handles a CONNACK whose fixed header has already been read by recv().
 *
 * Reads the remaining 3 bytes (remaining length, reserved byte, return code) and prints the
 * broker's verdict.
 *
 * @return 1 when the connection was accepted, 0 when it was refused, -2 when the CONNACK was
 *         unexpected or malformed (reconnect() has then been called).
 */
int niMQTT::connack() {
    if (debug) printf("CONNACK Received\r\n");
    if (waiting_connack > 0) waiting_connack--;
    else {
        printf("CONNACK UNEXPECTED !\r\n");
        // The connection is rebuilt, so the next byte is a fresh header. Otherwise the next
        // recv() would time out.
        waiting_new_packet = true;
        reconnect();
        return -2;
    }

    // Zero-initialised so a short read is detected as a bad length byte, not read as garbage.
    char resp[3] = {0, 0, 0};
    socket->receive(resp, 3);
    waiting_new_packet = true;

    if (resp[0] != 0x2) {
        // resp[0] is the packet's second byte (its remaining length), so that is the one to show.
        printf("Wrong second byte of CONNACK, get 0x%x instead of 0x2\r\n", static_cast<unsigned char>(resp[0]));
        reconnect();
        return -2;
    }
    switch (resp[2]) {
        case 0: printf("Connection Accepted\r\n"); break;
        case 1: printf("Connection Refused: unacceptable protocol version\r\n"); break;
        case 2: printf("Connection Refused: identifier rejected\r\n"); break;
        case 3: printf("Connection Refused: server unavailable\r\n"); break;
        case 4: printf("Connection Refused: bad user name or password\r\n"); break;
        case 5: printf("Connection Refused: not authorized\r\n"); break;
        default: printf("I have no idea what I am doing\r\n");
    }

    return (resp[2] == 0);
}

int niMQTT::pub(char *topic, char *message) {
    if (debug) printf("*pub\r\n");
    int topic_length = strlen(topic);
    int message_length = strlen(message);

    int remaining_length = topic_length + message_length + 2;
    int remaining_length_2 = remaining_length_length(remaining_length);
    int packet_length = 1 + remaining_length + remaining_length_2;

    char header = PUBLISH;
    char packet[packet_length];
    // header
    memcpy(packet, &header, 1);
    get_remaining_length(remaining_length, packet);

    // variable header: topic name
    char topic_size[2] = { topic_length / 256, topic_length % 256 };
    memcpy(packet + 1 + remaining_length_2, topic_size, 2);
    memcpy(packet + 3 + remaining_length_2, topic, topic_length);

    // payload: message
    memcpy(packet + 3 + remaining_length_2 + topic_length, message, message_length);

    return send(packet, packet_length);
}

void niMQTT::publish_received() {
    //remaining length
    int remaining_length = decode_remaining_length();

    // topic
    char mqtt_utf8_length[2];
    socket->receive(mqtt_utf8_length, 2);
    int utf8_length = mqtt_utf8_length[0] * 256 + mqtt_utf8_length[1];

    if (debug) printf("PUBLISH Received: %i, %i\r\n", remaining_length, utf8_length);

    char topic[utf8_length + 1];
    socket->receive(topic, utf8_length);
    topic[utf8_length] = 0;

    // payload
    int message_length = remaining_length - utf8_length - 2;
    char message[message_length + 1];
    socket->receive(message, message_length);
    message[message_length] = 0;

    waiting_new_packet = true;

    call_callback(topic, message);
}

/**
 * Sends a PUBACK carrying the current message_id.
 *
 * @return the result of send().
 */
int niMQTT::puback() {
    if (debug) printf("*puback\r\n");
    // Fixed header (type, remaining length 2) followed by the big-endian message id.
    char packet[4];
    packet[0] = PUBACK;
    packet[1] = 2;
    packet[2] = static_cast<char>(message_id / 256);
    packet[3] = static_cast<char>(message_id % 256);
    return send(packet, 4);
}

int niMQTT::puback_received() {
    waiting_new_packet = true;
    return 0; // TODO
}

/**
 * Sends a SUBSCRIBE (or an UNSUBSCRIBE when unsub is true) for a single topic filter.
 *
 * A fresh message identifier is generated and waiting_suback is incremented, so that
 * suback() accepts the matching (UN)SUBACK.
 *
 * @param topic NUL-terminated topic filter
 * @param unsub true to unsubscribe instead of subscribing
 * @return the result of send().
 */
int niMQTT::sub(char *topic, bool unsub) {
    if (debug) printf("*sub\r\n");
    char command = (unsub) ? UNSUBSCRIBE : SUBSCRIBE;
    int topic_length = strlen(topic);

    // Message id (2) + topic length prefix (2) + topic. Only SUBSCRIBE carries the extra
    // requested-QoS byte; an UNSUBSCRIBE payload is just the topic filter.
    int remaining_length = topic_length + 4 + (unsub ? 0 : 1);
    int remaining_length_2 = remaining_length_length(remaining_length);
    int packet_length = 1 + remaining_length + remaining_length_2;

    char header = command | LEAST_ONCE;
    char packet[packet_length];
    packet[0] = header;
    get_remaining_length(remaining_length, packet);

    // Variable header: message identifier. It is 16 bits and 0 is not a valid id, so cycle
    // through 1..65535. This keeps the value suback() compares against equal to the one sent.
    message_id = message_id % 65535 + 1;
    char *cursor = packet + 1 + remaining_length_2;
    *cursor++ = static_cast<char>(message_id / 256);
    *cursor++ = static_cast<char>(message_id % 256);

    // Payload: length-prefixed topic filter, then (SUBSCRIBE only) the requested QoS.
    *cursor++ = static_cast<char>(topic_length / 256);
    *cursor++ = static_cast<char>(topic_length % 256);
    memcpy(cursor, topic, topic_length);
    cursor += topic_length;
    if (!unsub) *cursor = MOST_ONCE;

    waiting_suback++;

    return send(packet, packet_length);
}

/**
 * Handles a SUBACK or UNSUBACK whose fixed header has already been read by recv().
 *
 * Reads the remaining length and body, then compares the echoed message id with the last id
 * sent by sub(). Granted QoS values are not inspected.
 *
 * @param unsub true when called for an UNSUBACK (currently unused)
 * @return 1 when the message id matches, 0 when it does not, -2 when the packet was
 *         unexpected or too short (reconnect() has then been called).
 */
int niMQTT::suback(bool unsub) {
    if (debug) printf("SUBACK received\r\n");
    if (waiting_suback > 0) waiting_suback--;
    else {
        printf("SUBACK UNEXPECTED !\r\n");
        // The connection is rebuilt, so the next byte is a fresh header.
        waiting_new_packet = true;
        reconnect();
        return -2;
    }

    //char command = (unsub) ? UNSUBACK : SUBACK;

    int remaining_length = decode_remaining_length();

    // Both packets start with the 2-byte message id. Anything shorter is malformed, and
    // reading var_resp[0..1] from it would go out of bounds.
    if (remaining_length < 2) {
        printf("(UN)SUBACK too short: remaining length %i\r\n", remaining_length);
        waiting_new_packet = true;
        reconnect();
        return -2;
    }

    // Variable Header
    char var_resp[remaining_length];
    socket->receive(var_resp, remaining_length);
    waiting_new_packet = true;
    if (debug) {
        printf("suback: ");
        for (int j=0; j<remaining_length; j++) printf("0x%x ", static_cast<unsigned char>(var_resp[j]));
        printf("\r\n");
    }

    // Assemble from unsigned bytes. With signed chars, any id whose low byte is >= 0x80
    // (e.g. 128) came out negative and never matched.
    int received_id = static_cast<unsigned char>(var_resp[0]) * 256
                    + static_cast<unsigned char>(var_resp[1]);
    if (received_id != message_id) {
        printf("wrong message identifer in (UN)SUBACK, get %i instead of %i...\r\n", received_id, message_id);
    }

    // here we should do things about the QoS if /unsuback, but let's say it's 0.

    return (received_id == message_id);
}

/**
 * Sends a PINGREQ and records that a PINGRESP is now expected.
 *
 * @return the result of send().
 */
int niMQTT::pingreq () {
    if (debug) printf("*pingreq\r\n");
    // Count the outstanding ping first; pingresp() checks this counter.
    waiting_pingresp++;
    // PINGREQ is a bare fixed header: packet type and a zero remaining length.
    char packet[2] = { static_cast<char>(PINGREQ), 0 };
    return send(packet, 2);
}

/**
 * Handles a PINGRESP whose fixed header has already been read by recv().
 *
 * Reads the single remaining-length byte, which must be 0.
 *
 * @return 1 on a well-formed, expected PINGRESP. -2 when it was unexpected, truncated or
 *         malformed (reconnect() has then been called).
 */
int niMQTT::pingresp() {
    if (debug) printf("PINGRESP Received\r\n");
    if (waiting_pingresp > 0) waiting_pingresp--;
    else {
        printf("PINGRESP Unexpected !\r\n");
        // The connection is rebuilt, so the next byte is a fresh header. Otherwise the next
        // recv() would time out.
        waiting_new_packet = true;
        reconnect();
        return -2;
    }

    char resp = 0;
    int received = socket->receive(&resp, 1);
    waiting_new_packet = true;

    // Without this check a failed read would leave resp meaningless.
    if (received != 1) {
        printf("PINGRESP truncated\r\n");
        reconnect();
        return -2;
    }
    if (resp != 0) {
        printf("Wrong second byte of PINGRESP, get 0x%x instead of 0x0\r\n", static_cast<unsigned char>(resp));
        reconnect();
        return -2;
    }

    return (resp == 0);
}

/**
 * Sends a DISCONNECT packet to the broker. The socket is left open.
 *
 * @return the result of send().
 */
int niMQTT::disconnect() {
    if (debug) printf("*disconnect\r\n");
    // DISCONNECT is a bare fixed header: packet type and a zero remaining length.
    char packet[2];
    packet[0] = static_cast<char>(DISCONNECT);
    packet[1] = 0;
    return send(packet, 2);
}

/**
 * Stops the worker thread, sends DISCONNECT, then closes and frees the socket.
 */
niMQTT::~niMQTT() {
    if (debug) printf("*~niMQTT()\r\n");
    connected = false;
    // Stop the worker before touching the socket. It may be blocked in recv() on this socket.
    // Left running, it would use the socket after delete, or be woken by close() and call
    // reconnect(). Clearing `connected` only ends its loop after the current iteration.
    thread.terminate();
    disconnect();
    socket->close();
    delete socket;
}

void niMQTT::reconnect() {
    if (debug) printf("Reconnecting...\r\n");
    disconnect();
    socket->close();
    
    do printf("Socket connection...\r\n"); while (socket->connect(server, port) < 0);
    socket->set_blocking(true, KEEP_ALIVE*500); // KEEP_ALIVE / 2 in seconds
    
    connect();
}

void niMQTT::thread_starter(void const *p) {
    niMQTT *instance = (niMQTT*)p;
    instance->thread_worker();
}

/**
 * Body of the background thread.
 *
 * Waits for START_THREAD from init(). Then, while `connected` is true, it repeatedly:
 * handles one incoming packet, pauses, and sends a PINGREQ if nothing was sent since the
 * previous pass.
 */
void niMQTT::thread_worker() {
    if (debug) printf("*thread_worker\r\n");
    thread.signal_wait(START_THREAD);
    for (;;) {
        if (!connected) break;
        if (debug) printf("New loop in thread worker\r\n");
        recv();
        // Pause KEEP_ALIVE*100 (KEEP_ALIVE / 10 in seconds).
        Thread::wait(KEEP_ALIVE * 100);
        if (!packet_sent) pingreq();
        // Start the next pass with a clean "something was sent" flag.
        packet_sent = false;
    }
}

/**
 * Writes remaining_length in MQTT's variable-length format, starting at packet[1].
 *
 * Each byte holds 7 bits, least significant group first. Bit 0x80 marks that another byte
 * follows. The caller must reserve remaining_length_length(remaining_length) bytes after
 * packet[0].
 */
void niMQTT::get_remaining_length(int remaining_length, char *packet) {
    char *out = packet + 1;
    int left = remaining_length;
    for (;;) {
        char encoded = left % 0x80;
        left /= 0x80;
        if (left <= 0) {
            *out = encoded;
            return;
        }
        *out++ = encoded | 0x80;
    }
}

/**
 * Reads an MQTT variable-length "remaining length" field from the socket and returns it.
 *
 * Each byte contributes 7 bits, least significant group first. Bit 0x80 means another byte
 * follows. At most 4 bytes are read, the maximum MQTT allows. This keeps the result at or
 * below 268435455 and avoids signed-int overflow on a malformed stream.
 * NOTE(review): a receive() that keeps returning < 0 still makes this retry forever.
 */
int niMQTT::decode_remaining_length() {
    const int max_length_bytes = 4;
    int multiplier = 1;
    int value = 0;
    char digit = 0;
    int bytes_read = 0;
    do {
        while (socket->receive(&digit, 1) < 0) wait(0.1);
        value += (digit & 127) * multiplier;
        multiplier *= 128;
        bytes_read++;
    } while ((digit & 0x80) != 0 && bytes_read < max_length_bytes);
    return value;
}

/**
 * Returns how many bytes get_remaining_length() uses to encode remaining_length (at least 1,
 * plus one per further 7-bit group).
 */
int niMQTT::remaining_length_length(int remaining_length) {
    int bytes = 1;
    for (int rest = remaining_length / 0x80; rest > 0; rest /= 0x80) bytes++;
    return bytes;
}

/**
 * Invokes the user callback given to the constructor with a topic and its message.
 */
void niMQTT::call_callback(const char *topic, const char *message) {
    (*callback)(topic, message);
}