#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <time.h>

#ifdef MBED_MAJOR_VERSION
#include "mbed.h"
// mbed-only serial port that main() uses to report benchmark timings.
// Constructor arguments are (tx, rx) per mbed's Serial API; the specific pins
// are board-specific — TODO confirm against the target's pinout.
Serial pc(P0_25, P0_8);
#endif

// definitions
#include "abac_them.h"

// policy constructors

/* Build an integer-valued attribute. `name` is stored by reference, not copied. */
attr_v2 new_attr_integer(char *name, int value)
{
    attr_v2 result;
    result.name = name;
    result.integer = value;
    result.data_type = abac_integer;
    return result;
}

/* Build a real (float) attribute. `name` is stored by reference, not copied. */
attr_v2 new_attr_real(char *name, float value)
{
    attr_v2 result;
    result.name = name;
    result.real = value;
    result.data_type = abac_real;
    return result;
}

/*
 * Build an inclusive integer range [min, max] attribute, used on the policy
 * side. match_attr_v2 tests a request integer against both bounds.
 */
attr_v2 new_attr_integer_range(char *name, int min, int max)
{
    attr_v2 result;
    result.name = name;
    result.data_type = abac_integer_range;
    result.ran.integer_min = min;
    result.ran.integer_max = max;
    return result;
}

/*
 * Build an inclusive real range [min, max] attribute, used on the policy
 * side. match_attr_v2 tests a request real against both bounds.
 */
attr_v2 new_attr_real_range(char *name, float min, float max)
{
    attr_v2 result;
    result.name = name;
    result.data_type = abac_real_range;
    result.ran.real_min = min;
    result.ran.real_max = max;
    return result;
}

/* Build a string attribute. Both `name` and `value` are stored by reference. */
attr_v2 new_attr_string(char *name, char *value)
{
    attr_v2 result;
    result.name = name;
    result.string = value;
    result.data_type = abac_string;
    return result;
}

/*
 * Build a string-list attribute with room for `len` string pointers.
 *
 * The caller fills string_list[0 .. len-1]. Slots start out NULL because the
 * list is obtained from calloc, which also rejects a len * sizeof(char *)
 * overflow. string_list is heap-allocated and owned by whoever owns the
 * attribute.
 *
 * If the allocation fails, string_list is NULL and inner_list_len is 0. Code
 * that iterates inner_list_len entries therefore never dereferences NULL.
 */
attr_v2 new_attr_string_list(char *name, size_t len)
{
    attr_v2 at;
    at.data_type = abac_string_list;
    at.name = name;
    at.string_list = (char **) calloc(len, sizeof(char *));
    /* keep the advertised length consistent with what was actually allocated */
    at.inner_list_len = (at.string_list != NULL) ? len : 0;
    return at;
}

/*
 * Build a dictionary attribute from `len` nested attribute pointers.
 * The `value` array is stored by reference (not copied), so it must outlive
 * the returned attribute.
 */
attr_v2 new_attr_dictionary(char *name, attr_v2 **value, size_t len)
{
    attr_v2 dict;
    dict.inner_attrs = value;
    dict.inner_list_len = len;
    dict.name = name;
    dict.data_type = abac_dictionary;
    return dict;
}

/*
 * Allocate an array of `len` attribute pointers with uninitialised entries.
 * The caller owns the array. The result is whatever malloc returns, so it
 * may be NULL.
 */
attr_v2 **new_attr_list(size_t len)
{
    return (attr_v2 **) malloc(len * sizeof(attr_v2 *));
}

/*
 * Allocate an array of `len` operation-name pointers with uninitialised
 * entries. The caller owns the array. The result is whatever malloc returns,
 * so it may be NULL.
 */
char **new_operations_list(size_t len)
{
    return (char **) malloc(len * sizeof(char *));
}

// graph constructors

/* Create an unconnected graph node labelled `value` (stored by reference). */
node new_graph_node(char *value)
{
    node created;
    created.next = NULL;
    created.value = value;
    return created;
}

/*
 * Make `b` the successor of `a`. Any previous successor of `a` is
 * overwritten, so every node has at most one outgoing edge.
 */
void create_directed_edge(node *a, node *b)
{
    (*a).next = b;
}

/*
 * Create a graph with room for `len` node pointers. The caller fills
 * list[0 .. len-1] and owns the heap-allocated list. The list pointer is
 * whatever malloc returns.
 */
graph new_graph(size_t len)
{
    graph result;
    result.list = (node **) malloc(len * sizeof(node *));
    result.len = len;
    return result;
}

// graph algorithms

/* Return 1 if some node among the first `v_len` of `list` has the same value string as `k`, else 0. */
int is_node_in(node k, node *list, size_t v_len)
{
    for (size_t idx = 0; idx < v_len; ++idx) {
        if (!strcmp(list[idx].value, k.value))
            return 1;
    }
    return 0;
}

/*
 * Collect `n` and every node reachable from it by following `next` edges.
 *
 * Each node stores a single successor (create_directed_edge overwrites
 * `next`), so the depth-first search reduces to walking one chain. No
 * explicit stack is needed. The walk stops at the end of the chain, when a
 * value repeats (a cycle), or once the buffer (sized to g.len) is full.
 * It therefore always terminates and never writes out of bounds.
 *
 * Returns a heap-allocated array of node copies, which the caller must free.
 * Its length is stored in *v_len. The copies share their `value` strings
 * with the original nodes. Returns NULL with *v_len == 0 if the allocation
 * fails.
 */
node *find_ancestors_dfs(graph g, node n, size_t *v_len)
{
    /* one slot per graph node, but always at least one for `n` itself */
    size_t capacity = g.len > 0 ? g.len : 1;
    node *visited = (node *) malloc(sizeof(node) * capacity);
    size_t count = 0;

    *v_len = 0;
    if (visited == NULL)
        return NULL;

    for (const node *cur = &n; cur != NULL && count < capacity; cur = cur->next) {
        if (is_node_in(*cur, visited, count))
            break; /* cycle: everything further along was already collected */
        visited[count++] = *cur;
    }

    *v_len = count;
    return visited;
}

/*
 * Return the first node in `g` whose value equals the string value of `at`,
 * or NULL if there is none. Reads at->string, so `at` is expected to be a
 * string attribute.
 */
node *find_in_graph(attr_v2 *at, graph g)
{
    for (size_t idx = 0; idx < g.len; ++idx) {
        node *candidate = g.list[idx];
        if (strcmp(at->string, candidate->value) == 0)
            return candidate;
    }
    return NULL;
}

/*
 * Replace a string attribute with a string-list attribute of the same name.
 * The list holds the string itself plus every node value reachable from it
 * in `g` (see find_ancestors_dfs).
 *
 * Non-string attributes, and strings that do not name a node of `g`, are
 * left untouched. On success *at_orig is overwritten with a newly
 * heap-allocated attr_v2 whose string_list is also heap-allocated; the list
 * entries alias the node value strings. The previous *at_orig is not freed
 * here. On any allocation failure *at_orig is left unchanged.
 */
void expand_attr(attr_v2 **at_orig, graph g)
{
    if (at_orig == NULL || *at_orig == NULL || (*at_orig)->data_type != abac_string)
        return;

    node *n = find_in_graph(*at_orig, g);
    if (n == NULL)
        return;

    size_t v_len = 0;
    node *visited = find_ancestors_dfs(g, *n, &v_len);
    if (visited == NULL)
        return;

    attr_v2 *at = (attr_v2 *) malloc(sizeof(attr_v2));
    if (at == NULL) {
        free(visited);
        return;
    }
    *at = new_attr_string_list((*at_orig)->name, v_len);
    if (at->string_list == NULL) {
        free(at);
        free(visited);
        return;
    }

    for (size_t j = 0; j < v_len; ++j)
        at->string_list[j] = visited[j].value;

    /* only the node copies are released; their value strings stay shared with g */
    free(visited);
    *at_orig = at;
}

/*
 * Run expand_attr over every user, object and context attribute of `req`.
 * Expanded entries are replaced in place in req's attribute arrays.
 */
void expand_attrs(rule *req, graph g)
{
    int k;

    for (k = 0; k < req->users_len; k++)
        expand_attr(req->users + k, g);

    for (k = 0; k < req->objects_len; k++)
        expand_attr(req->objects + k, g);

    for (k = 0; k < req->contexts_len; k++)
        expand_attr(req->contexts + k, g);
}

// authorization

/*
 * Return 1 if every string in ro[0 .. ro_len-1] also occurs in
 * po[0 .. po_len-1], else 0. An empty `ro` is trivially a subset.
 */
int is_subset(char **ro, size_t ro_len, char **po, size_t po_len)
{
    for (size_t r = 0; r < ro_len; r++) {
        size_t p = 0;
        while (p < po_len && strcmp(ro[r], po[p]) != 0)
            p++;
        if (p == po_len)
            return 0; /* ro[r] has no counterpart in po */
    }
    return 1;
}

/* Return 1 if string `a` equals any of b[0 .. b_len-1], else 0. */
int is_string_in(char *a, char **b, size_t b_len)
{
    for (size_t pos = b_len; pos > 0; --pos) {
        if (strcmp(a, b[pos - 1]) == 0)
            return 1;
    }
    return 0;
}

/* Dictionaries are matched recursively, and match_attrs_v2 is defined after this function. */
int match_attrs_v2(attr_v2 **ras, size_t ras_len, attr_v2 **pas, size_t pas_len);

/*
 * Check whether request attribute `ra` satisfies policy attribute `pa`.
 * Returns 1 on a match, 0 otherwise.
 *
 * Names must be equal, and the request value must have the type the policy
 * expects. Only then is the value field compared, so a same-named
 * attribute of another type never reads fields its constructor left unset.
 *   integer / real       : exact equality
 *   integer / real range : inclusive bounds check
 *   string               : equality, or membership when the request value is
 *                          a string list (e.g. produced by expand_attr)
 *   dictionary           : every policy entry matched by some request entry
 * Policy attributes of any other type (e.g. abac_string_list) never match.
 */
int match_attr_v2(attr_v2 ra, attr_v2 pa)
{
    if (strcmp(ra.name, pa.name) != 0)
        return 0;

    switch (pa.data_type) {
    case abac_integer:
        return ra.data_type == abac_integer && ra.integer == pa.integer;
    case abac_real:
        return ra.data_type == abac_real && ra.real == pa.real;
    case abac_integer_range:
        return ra.data_type == abac_integer &&
               ra.integer >= pa.ran.integer_min && ra.integer <= pa.ran.integer_max;
    case abac_real_range:
        return ra.data_type == abac_real &&
               ra.real >= pa.ran.real_min && ra.real <= pa.ran.real_max;
    case abac_string:
        if (ra.data_type == abac_string)
            return strcmp(ra.string, pa.string) == 0;
        if (ra.data_type == abac_string_list)
            return is_string_in(pa.string, ra.string_list, ra.inner_list_len);
        return 0;
    case abac_dictionary:
        return ra.data_type == abac_dictionary &&
               match_attrs_v2(ra.inner_attrs, ra.inner_list_len,
                              pa.inner_attrs, pa.inner_list_len);
    default:
        return 0;
    }
}

/*
 * Return 1 if every policy attribute pas[0 .. pas_len-1] is satisfied (per
 * match_attr_v2) by at least one request attribute ras[0 .. ras_len-1], and
 * 0 otherwise.
 *
 * Matching does not depend on position: each policy attribute is compared
 * against every request attribute. An empty policy list imposes no
 * constraint, so it matches (for example, a permission without context
 * attributes accepts any context).
 */
int match_attrs_v2(attr_v2 **ras, size_t ras_len, attr_v2 **pas, size_t pas_len)
{
    for (size_t i = 0; i < pas_len; i++) {
        int found = 0;
        for (size_t j = 0; j < ras_len && !found; j++)
            found = match_attr_v2(*ras[j], *pas[i]);
        if (!found)
            return 0;
    }
    return 1;
}

/*
 * Return 1 if request `r` is covered by permission `perm`, else 0.
 * r's operations must be a subset of perm's operations, and each of perm's
 * user, object and context attributes must be matched by r's.
 * The checks run in that order and stop at the first failure.
 */
int match_permission(rule r, rule perm)
{
    if (!is_subset(r.operations, r.operations_len, perm.operations, perm.operations_len))
        return 0;
    if (!match_attrs_v2(r.users, r.users_len, perm.users, perm.users_len))
        return 0;
    if (!match_attrs_v2(r.objects, r.objects_len, perm.objects, perm.objects_len))
        return 0;
    return match_attrs_v2(r.contexts, r.contexts_len, perm.contexts, perm.contexts_len) ? 1 : 0;
}

/* Return 1 as soon as some permission in perms[0 .. p_len-1] covers `req`; 0 if none does. */
int authorize_permissions(rule req, rule *perms, size_t p_len)
{
    size_t idx = 0;
    while (idx < p_len && !match_permission(req, perms[idx]))
        ++idx;
    return idx < p_len;
}

int authorize_permissions_expand(rule req, rule *perms, size_t p_len, graph g)
{
    expand_attrs(&req, g);
    for (int i = 0; i < p_len; i++)
        if (match_permission(req, perms[i]))
            return 1;
    return 0;
}

// debug

/*
 * Debug helper: print one attribute as "name: value" on its own line.
 * String lists are printed space-separated. Dictionaries are printed
 * recursively between a "[name:" line and a "]" line. Unknown types print
 * nothing.
 */
void show_attr_v2(attr_v2 at)
{
    switch (at.data_type) {
    case abac_dictionary:
        printf("[%s:\n", at.name);
        for (size_t k = 0; k < at.inner_list_len; ++k)
            show_attr_v2(*at.inner_attrs[k]);
        printf("]\n");
        return;
    case abac_string_list:
        printf("%s: ", at.name);
        for (size_t k = 0; k < at.inner_list_len; ++k)
            printf("%s ", at.string_list[k]);
        printf("\n");
        return;
    case abac_string:
        printf("%s: %s\n", at.name, at.string);
        return;
    case abac_real_range:
        printf("%s: %.2f..%.2f\n", at.name, at.ran.real_min, at.ran.real_max);
        return;
    case abac_integer_range:
        printf("%s: %d..%d\n", at.name, at.ran.integer_min, at.ran.integer_max);
        return;
    case abac_real:
        printf("%s: %.2f\n", at.name, at.real);
        return;
    case abac_integer:
        printf("%s: %d\n", at.name, at.integer);
        return;
    default:
        return;
    }
}

/* Debug helper: print the first `len` operation names space-separated, then a newline. */
void show_operations(char **ops, size_t len)
{
    size_t k = 0;
    while (k < len)
        printf("%s ", ops[k++]);
    printf("\n");
}

/*
 * Debug helper: print a ">desc" header, then the rule's users, objects and
 * contexts (one attribute per line via show_attr_v2), then its operations.
 */
void show_rule(rule r, char *desc)
{
    int idx;

    printf("\n>%s\n", desc);

    printf("#users:\n");
    for (idx = 0; idx < r.users_len; ++idx)
        show_attr_v2(*r.users[idx]);

    printf("#objects:\n");
    for (idx = 0; idx < r.objects_len; ++idx)
        show_attr_v2(*r.objects[idx]);

    printf("#contexts:\n");
    for (idx = 0; idx < r.contexts_len; ++idx)
        show_attr_v2(*r.contexts[idx]);

    printf("#operations:\n");
    show_operations(r.operations, r.operations_len);
}

/* Debug helper: print "visited: " followed by the values of the first `v_len` nodes. */
void show_visited(node *visited, size_t v_len)
{
    size_t k = 0;

    printf("visited: ");
    while (k < v_len) {
        printf("%s ", visited[k].value);
        ++k;
    }
    printf("\n");
}

/* Debug helper: print "desc: " followed by the values of the first `len` nodes. */
void show_node_list(node *list, size_t len, char *desc)
{
    printf("%s: ", desc);
    for (size_t k = 0; k != len; k++)
        printf("%s ", list[k].value);
    printf("\n");
}



int main() {

    attr_v2 **at_list;

    attr_v2 id_alice = new_attr_string("id", "alice");
    attr_v2 id_camera1 = new_attr_string("id", "camera1");
    attr_v2 id_lamp1 = new_attr_string("id", "lamp1");
    attr_v2 id_some_device_x = new_attr_string("id", "some-device-x");
    attr_v2 owner_alice = new_attr_string("owner", "alice");
    attr_v2 year_2020 = new_attr_integer("year", 2020);
    attr_v2 month_6 = new_attr_integer("month", 6);
    attr_v2 day_30 = new_attr_integer("day", 30);
    attr_v2 hour_17 = new_attr_integer("hour", 17);
    attr_v2 luminosity_25 = new_attr_integer("outdoorLuminosity", 25);
    attr_v2 age_min18 = new_attr_integer_range("age", 18, 120);
    attr_v2 minute_20_25 = new_attr_integer_range("minute", 20, 25);
    attr_v2 luminosity_max33 = new_attr_integer_range("outdoorLuminosity", 0, 33);
    attr_v2 reputation_min4 = new_attr_real_range("reputation", 4, 5);
    attr_v2 type_security = new_attr_string("type", "securityAppliance");
    attr_v2 type_lighting = new_attr_string("type", "lightingAppliance");
    attr_v2 household_role_child = new_attr_string("role", "child");
    attr_v2 household_id_home1 = new_attr_string("id", "home-1");
    attr_v2 type_camera = new_attr_string("type", "securityCamera");
    attr_v2 location_outdoor = new_attr_string("location", "outdoor");

    at_list = new_attr_list(1);
    at_list[0] = &household_id_home1;
    attr_v2 household_with_id = new_attr_dictionary("household", at_list, 1);

    at_list = new_attr_list(2);
    at_list[0] = &household_id_home1;
    at_list[1] = &household_role_child;
    attr_v2 household_with_id_role = new_attr_dictionary("household", at_list, 2);

    // p1
    rule perm1;
    perm1.users = new_attr_list(1);
    perm1.users_len = 1;
    perm1.users[0] = &id_alice;

    perm1.objects = new_attr_list(1);
    perm1.objects_len = 1;
    perm1.objects[0] = &owner_alice;

    perm1.contexts_len = 0;

    perm1.operations = new_operations_list(4);
    perm1.operations_len = 4;
    perm1.operations[0] = "create";
    perm1.operations[1] = "read";
    perm1.operations[2] = "update";
    perm1.operations[3] = "delete";
    show_rule(perm1, "perm1\0");

    // p2, authorizes req_e
    rule perm2;
    perm2.users = new_attr_list(2);
    perm2.users_len = 2;
    perm2.users[0] = &age_min18;
    perm2.users[1] = &household_with_id;

    perm2.objects = new_attr_list(2);
    perm2.objects_len = 2;
    perm2.objects[0] = &type_security;
    perm2.objects[1] = &household_with_id;

    perm2.contexts_len = 0;

    perm2.operations = new_operations_list(2);
    perm2.operations_len = 2;
    perm2.operations[0] = "read";
    perm2.operations[1] = "update";
    show_rule(perm2, "perm2\0");

    // p3, authorizes req
    rule perm3;
    perm3.users = new_attr_list(1);
    perm3.users_len = 1;
    perm3.users[0] = &household_with_id_role;

    perm3.objects = new_attr_list(2);
    perm3.objects_len = 2;
    perm3.objects[0] = &type_lighting;
    perm3.objects[1] = &household_with_id;

    perm3.contexts = new_attr_list(1);
    perm3.contexts_len = 1;
    perm3.contexts[0] = &luminosity_max33;

    perm3.operations = new_operations_list(2);
    perm3.operations_len = 2;
    perm3.operations[0] = "read";
    perm3.operations[1] = "update";
    show_rule(perm3, "perm3\0");

    // p4
    rule perm4;
    perm4.users = new_attr_list(1);
    perm4.users_len = 1;
    perm4.users[0] = &id_camera1;

    perm4.objects = new_attr_list(1);
    perm4.objects_len = 1;
    perm4.objects[0] = &id_lamp1;

    perm4.contexts_len = 0;

    perm4.operations = new_operations_list(2);
    perm4.operations_len = 2;
    perm4.operations[0] = "read";
    perm4.operations[1] = "update";
    show_rule(perm4, "perm4\0");

    // p5
    rule perm5;
    perm5.users = new_attr_list(1);
    perm5.users_len = 1;
    perm5.users[0] = &reputation_min4;

    perm5.objects = new_attr_list(3);
    perm5.objects_len = 3;
    perm5.objects[0] = &type_camera;
    perm5.objects[1] = &household_with_id;
    perm5.objects[2] = &location_outdoor;

    perm5.contexts = new_attr_list(1);
    perm5.contexts_len = 1;
    perm5.contexts[0] = &luminosity_max33;

    perm5.operations = new_operations_list(1);
    perm5.operations_len = 1;
    perm5.operations[0] = "contract";
    show_rule(perm5, "perm5\0");

    // p6
    rule perm6;
    perm6.users = new_attr_list(1);
    perm6.users_len = 1;
    perm6.users[0] = &id_some_device_x;

    perm6.objects = new_attr_list(5);
    perm6.objects_len = 5;
    perm6.objects[0] = &year_2020;
    perm6.objects[1] = &month_6;
    perm6.objects[2] = &day_30;
    perm6.objects[3] = &hour_17;
    perm6.objects[4] = &minute_20_25;

    perm6.contexts = new_attr_list(1);
    perm6.contexts_len = 1;
    perm6.contexts[0] = &luminosity_max33;

    perm6.operations = new_operations_list(1);
    perm6.operations_len = 1;
    perm6.operations[0] = "contract";
    show_rule(perm6, "perm6\0");

    // list of perms
    rule *perms = (rule *) malloc(sizeof(rule) * 6);
    perms[0] = perm1;
    perms[1] = perm2;
    perms[2] = perm3;
    perms[3] = perm4;
    perms[4] = perm5;
    perms[5] = perm6;

    // a request
    rule req;
    req.users = new_attr_list(1);
    req.users_len = 1;
    req.users[0] = &household_with_id_role;

    req.objects = new_attr_list(2);
    req.objects_len = 2;
    req.objects[0] = &type_lighting;
    req.objects[1] = &household_with_id;

    req.contexts = new_attr_list(1);
    req.contexts_len = 1;
    req.contexts[0] = &luminosity_25;

    req.operations = new_operations_list(1);
    req.operations_len = 1;
    req.operations[0] = "read";
    show_rule(req, "request\0");

    if (authorize_permissions(req, perms, 6))
        printf("\nauthorized request for policy #3\n");

    // creating a graph

    node n_child = new_graph_node("child");
    node n_father = new_graph_node("father");
    node n_mother = new_graph_node("mother");
    node n_adultFamilyMember = new_graph_node("adultFamilyMember");
    node n_family_member = new_graph_node("familyMember");
    node n_person = new_graph_node("person");
    create_directed_edge(&n_child, &n_family_member);
    create_directed_edge(&n_father, &n_adultFamilyMember);
    create_directed_edge(&n_mother, &n_adultFamilyMember);
    create_directed_edge(&n_adultFamilyMember, &n_family_member);
    create_directed_edge(&n_family_member, &n_person);

    node n_securityCamera = new_graph_node("securityCamera");
    node n_intrusionAlarm = new_graph_node("intrusionAlarm");
    node n_securityAppliance = new_graph_node("securityAppliance");
    create_directed_edge(&n_securityCamera, &n_securityAppliance);
    create_directed_edge(&n_intrusionAlarm, &n_securityAppliance);

    graph g = new_graph(6+3);
    g.list[0] = &n_child;
    g.list[1] = &n_father;
    g.list[2] = &n_mother;
    g.list[3] = &n_adultFamilyMember;
    g.list[4] = &n_family_member;
    g.list[5] = &n_person;
    g.list[6] = &n_securityCamera;
    g.list[7] = &n_intrusionAlarm;
    g.list[8] = &n_securityAppliance;

    // a request to expand
    rule req_e;
    req_e.users = new_attr_list(2);
    req_e.users_len = 2;
    attr_v2 age_25 = new_attr_integer("age", 25);
    req_e.users[0] = &age_25;
    req_e.users[1] = &household_with_id;

    req_e.objects = new_attr_list(2);
    req_e.objects_len = 2;
    req_e.objects[0] = &type_camera;
    req_e.objects[1] = &household_with_id;

    req_e.contexts_len = 0;

    req_e.operations = new_operations_list(1);
    req_e.operations_len = 1;
    req_e.operations[0] = "read";
    show_rule(req_e, "request that will be expanded\0");

    if (!authorize_permissions(req_e, perms, 6))
        printf("\ndenied non-expanded request for policy #2\n");
    if (authorize_permissions_expand(req_e, perms, 6, g))
        printf("\nauthorized expanded request for policy #2\n\n");

    // many policies
    int n_perms = 3000, median;
    median = (int) (n_perms / 2);
    rule *many_perms = (rule *) malloc(sizeof(rule) * n_perms);
    for (int i = 0; i < n_perms; ++i)
        many_perms[i] = perm5;
    many_perms[median] = perm2;

    // benchmark

    int runs = 3000;
#ifdef MBED_MAJOR_VERSION
    Timer t;
    t.start();
    for (int i = 0; i < runs; i++)
        authorize_permissions_expand(req_e, perms, 6, g);
    t.stop();
    pc.printf("The time taken to authorize 1 request against 6 policies, %d times, was %f ms\n", runs, t.read() * 1000);
    printf("> The time taken to authorize 1 request against 6 policies, %d times, was %f ms\n", runs, t.read() * 1000);

    t.start();
    authorize_permissions_expand(req_e, many_perms, n_perms, g);
    t.stop();
    pc.printf("The time taken to authorize 1 request against %d policies was %f ms\n", n_perms, t.read() * 1000);
    printf("> The time taken to authorize 1 request against %d policies was %f ms\n", n_perms, t.read() * 1000);
#elif defined(ESP32)
    unsigned long startTime, endTime;
    startTime = millis();
    for (int i = 0; i < runs; i++)
        authorize_permissions_expand(req_e, perms, 6, g);
    endTime = millis();
    Serial.print("The time taken to authorize 1 request against 6 policies, ");
    Serial.print(runs);
    Serial.print(" times, was ");
    Serial.print(endTime - startTime);
    Serial.println(" ms");

    startTime = millis();
    authorize_permissions_expand(req_e, many_perms, n_perms, g);
    endTime = millis();
    Serial.print("The time taken to authorize 1 request against ");
    Serial.print(n_perms);
    Serial.print(" policies was ");
    Serial.print(endTime - startTime);
    Serial.println(" ms");
#elif defined(__unix__)
    #include <time.h>
    clock_t t;
    t = clock();
    double elapsed;
    for (int i = 0; i < runs; i++)
        authorize_permissions_expand(req_e, perms, 6, g);
    t = clock() - t;
    elapsed = ((double) t) / CLOCKS_PER_SEC;
    printf("The time taken to authorize 1 request against 6 policies, %d times, was %f ms\n", runs, elapsed * 1000);

    t = clock();
    authorize_permissions_expand(req_e, many_perms, n_perms, g);
    t = clock() - t;
    elapsed = ((double) t) / CLOCKS_PER_SEC;
    printf("The time taken to authorize 1 request against %d policies was %f ms\n", n_perms, elapsed * 1000);
#endif

    free(many_perms);
}
