#include "tree.h"
#include <stdlib.h>
#include <stdio.h>
#include <math.h>
#include <string.h>

/* Forward declarations for functions referenced before their definition */
float eval_node(struct node* node, int level);
void set_number_node(struct node* node, float num);

/* local variables */
struct node* expr_root_node;
float Xvalue;

/*
 * Evaluates the expression tree rooted at `node` with the variable "x"
 * bound to `x_value`.
 *
 * The binding is stored in the file-scope Xvalue before evaluation starts,
 * and evaluation begins at nesting level 0.
 */
float eval_expr(struct node* node, float x_value) {
  const int root_level = 0;

  Xvalue = x_value;

  float result = eval_node(node, root_level);
  return result;
}

/*
 * Pass-through hook wrapped around every operator result in eval_op().
 * It returns its argument unchanged; it exists as a single place to add
 * tracing output while debugging.
 */
float p(float l) {
  float value = l;
  return value;
}

/*
 * Evaluates a binary-operator node.
 *
 * op:    node whose data.op selects the operation and whose left/right
 *        children are the operands.
 * level: nesting depth, passed on (incremented) to the operand evaluation.
 *
 * Returns the result of applying the operator to the two operand values.
 * Exits the process with status 1 if an operand is missing or if data.op is
 * not one of the handled operators.
 */
float eval_op (struct node* op, int level) {
  /* new_binop_node_no_right() creates operators with a NULL right child;
     report that instead of dereferencing NULL further down. */
  if (op->left == NULL || op->right == NULL) {
    fprintf(stderr, "Operator %d is missing an operand\n", op->data.op);
    exit(1);
  }

  /* The right operand is evaluated first; the order is observable because
     function nodes print while being evaluated. */
  float right = eval_node(op->right, level+1);
  float left  = eval_node(op->left,  level+1);

  switch(op->data.op) {
  case OP_ADD:  return p(left + right);
  case OP_SUB:  return p(left - right);
  case OP_MUL:  return p(left * right);
  case OP_DIV:  return p(left / right);
  case OP_EXP:  return p(pow(left,right));
  default:
    fprintf(stderr, "%d not defined as an operator in eval_op\n", op->data.op);
    exit(1);
  }

}

/* Records `node` in the file-scope expr_root_node pointer. */
void set_root_node(struct node* node) {
  struct node* new_root = node;

  expr_root_node = new_root;
}

/*
 * Recursively evaluates one node of the expression tree.
 *
 * node:  node to evaluate.
 * level: nesting depth, incremented on each recursive step.
 *
 * Returns:
 *   NODE_NUMBER - the stored constant;
 *   NODE_VAR    - Xvalue when the variable is named "x";
 *   NODE_BINOP  - the result of eval_op();
 *   NODE_FUNC   - sin/cos/tan applied to the value of the left child.
 *
 * Exits the process with status 1 (after printing a diagnostic to stderr)
 * when the node is NULL, the variable is not "x", the function name is not
 * recognised, or the node type is not handled.
 */
float eval_node(struct node* node, int level) {
  /* Operators and functions can be built without an operand
     (new_binop_node_no_right, new_func_node_no_left), so a NULL child is
     reachable here. */
  if (node == NULL) {
    fprintf(stderr, "Missing operand (NULL node) at level %d\n", level);
    exit(1);
  }

  switch(node->type) {
  case NODE_NUMBER:
    return node->data.fval;

  case NODE_VAR:
    if (strcmp(node->data.name,"x")==0) {
      return Xvalue;
    }
    fprintf(stderr,"Cannot parse variables: %s\n", node->data.name);
    exit(1);

  case NODE_BINOP:
    return eval_op(node, level+1);

  case NODE_FUNC: {
    double (*fn)(double) = NULL;

    if (strcmp(node->data.name,"sin")==0) {
      fn = sin;
    } else if (strcmp(node->data.name,"cos")==0) {
      fn = cos;
    } else if (strcmp(node->data.name,"tan")==0) {
      fn = tan;
    }

    if (fn == NULL) {
      fprintf(stderr, "Unknown function: %s\n", node->data.name);
      exit(1);
    }

    /* The trace line is printed before the argument is evaluated. */
    printf("Calling function: %s(...)\n", node->data.name);
    float y = eval_node(node->left, level+1);
    return fn(y);
  }

  default:
    break;
  }

  fprintf(stderr, "Unknown node type %d\n", (int) node->type);
  exit(1);
}


/*
 * Allocates a NODE_NUMBER leaf holding the constant `num`.
 *
 * Returns the new node (no parent, no children), or NULL if the allocation
 * fails.
 */
struct node* new_number_node(float num){
  struct node* node = calloc(1, sizeof(*node));
  if (node == NULL) {
    return NULL;
  }
  node->type = NODE_NUMBER;
  node->data.fval = num;
  node->left = NULL;
  node->right = NULL;
  node->parent = NULL;
  return node;
}


/*
 * Allocates a NODE_VAR leaf named `str`.
 *
 * The node stores the pointer itself, not a copy, so the string must stay
 * valid for as long as the node is used.
 *
 * Returns the new node (no parent, no children), or NULL if the allocation
 * fails.
 */
struct node* new_var_node(char* str){
  struct node* node = calloc(1, sizeof(*node));
  if (node == NULL) {
    return NULL;
  }
  node->type = NODE_VAR;
  node->data.name = str;
  node->left = NULL;
  node->right = NULL;
  node->parent = NULL;
  return node;
}


/*
 * Allocates a NODE_BINOP node for operator `op` with both operands attached.
 *
 * Each non-NULL child has its parent pointer set to the new node; a NULL
 * child is stored as-is rather than dereferenced.
 *
 * Returns the new node (no parent), or NULL if the allocation fails, in
 * which case neither child is modified.
 */
struct node* new_binop_node(enum op_type op, struct node* left, struct node *right){
  struct node* node = calloc(1, sizeof(*node));
  if (node == NULL) {
    return NULL;
  }
  node->type = NODE_BINOP;
  node->data.op = op;
  node->left = left;
  if (left != NULL) {
    left->parent = node;
  }
  node->right = right;
  if (right != NULL) {
    right->parent = node;
  }
  node->parent = NULL;
  return node;
}

/*
 * Allocates a NODE_BINOP node for operator `op` with only the left operand
 * attached; the right child is left NULL so it can be filled in later.
 *
 * A non-NULL `left` has its parent pointer set to the new node.
 *
 * Returns the new node (no parent), or NULL if the allocation fails, in
 * which case `left` is not modified.
 */
struct node* new_binop_node_no_right(enum op_type op, struct node* left){
  struct node* node = calloc(1, sizeof(*node));
  if (node == NULL) {
    return NULL;
  }
  node->type = NODE_BINOP;
  node->data.op = op;
  node->left = left;
  if (left != NULL) {
    left->parent = node;
  }
  node->right = NULL;
  node->parent = NULL;
  return node;
}

/*
 * Allocates a NODE_FUNC node named `name` whose argument is `left`.
 *
 * The node stores the `name` pointer itself, not a copy, so the string must
 * stay valid for as long as the node is used. A non-NULL `left` has its
 * parent pointer set to the new node.
 *
 * Returns the new node (no parent, no right child), or NULL if the
 * allocation fails, in which case `left` is not modified.
 */
struct node* new_func_node(char* name, struct node* left){
  struct node* node = calloc(1, sizeof(*node));
  if (node == NULL) {
    return NULL;
  }
  node->type = NODE_FUNC;
  node->data.name = name;
  node->left = left;
  if (left != NULL) {
    left->parent = node;
  }
  node->right = NULL;
  node->parent = NULL;
  return node;
}

/*
 * Allocates a NODE_FUNC node named `name` with no argument attached yet.
 *
 * The node stores the `name` pointer itself, not a copy, so the string must
 * stay valid for as long as the node is used.
 *
 * Returns the new node (no parent, no children), or NULL if the allocation
 * fails.
 */
struct node* new_func_node_no_left(char* name) {
  struct node* node = calloc(1, sizeof(*node));
  if (node == NULL) {
    return NULL;
  }
  node->type = NODE_FUNC;
  node->data.name = name;
  node->left = NULL;
  node->right = NULL;
  node->parent = NULL;
  return node;
}

// Try to match the remainder of a number token at the start of expr.
//
// Digits and at most one '.' are consumed until a terminating character
// (whitespace, a parenthesis or an operator) or the end of the string.
//
// Returns:
//   >0  number of characters matched
//    0  nothing matched (expr is empty or starts with a terminator)
//   <0  syntax error: -2 for a second '.', otherwise -1 when the very first
//       character is illegal, or -(index) of the illegal character
int isNumber(const char* expr) {
    int len = (int) strlen(expr);
    int num_of_dots = 0;

    for (int i = 0; i < len; i++) {
        char c = expr[i];

        // digits
        if (c >= '0' && c <= '9')
            continue;

        // a single decimal point is allowed
        if (c == '.') {
            num_of_dots++;
            if (num_of_dots > 1)
                return -num_of_dots; // an error
            continue;
        }

        switch (c) {
        // legal end of match
        case ' ' :
        case '\t':
        case '\n':
        case '(' :
        case ')' :
        case '+' :
        case '-' :
        case '*' :
        case '/' :
        case '^' :
            return i;

        // error end of match; -0 would read as "no match", so report -1
        default:
            return i == 0 ? -1 : -i;
        }
    }

    // the number runs to the end of the string: everything matched
    return len;
}

// Try to match the remainder of a function name at the start of expr.
//
// Letters and '_' are consumed; the name must be followed by '('.
//
// Returns:
//   >=0  number of name characters before the '('
//    -1  the first character is neither a name character nor '(', or the
//        string ends before a '(' is found
//   <-1  -(index) of the first illegal character
int isFunc(const char* expr) {
    // Stop at the terminator instead of re-running strlen() on every pass.
    for (int i = 0; expr[i] != '\0'; i++) {
        char c = expr[i];

        if ((c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || c == '_')
            continue;

        if (c == '(')
            return i;

        return i == 0 ? -1 : -i;
    }
    return -1;
}

// Try to match the remainder of a variable name at the start of expr.
//
// Letters and '_' are consumed until a terminating character (whitespace,
// ')' or an operator) or the end of the string.
//
// Returns:
//   >=0  number of name characters matched (0 when expr is empty or starts
//        with a terminator)
//    -1  the first character is illegal
//   <-1  -(index) of the first illegal character
int isVar(const char* expr) {
    int i = 0;

    for (; expr[i] != '\0'; i++) {
        char c = expr[i];

        // match var name
        if ((c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || c == '_')
            continue;

        switch (c) {
        // match terminating character
        case ' ' :
        case '\t':
        case '\n':
        case ')' :
        case '+' :
        case '-' :
        case '*' :
        case '/' :
        case '^' :
            return i;
        default:
            return i == 0 ? -1 : -i;
        }
    }

    // End of input terminates a name just like an empty remainder does.
    return i;
}

// Attach `right` as the right child of `node` and link it back to `node`.
// Returns the attached child, not `node`.
struct node* add_right(struct node* node, struct node* right){
  struct node* child = right;

  node->right = child;
  child->parent = node;

  return child;
}

// Attach `left` as the left child of `node` and link it back to `node`.
// Returns the attached child, not `node`.
struct node* add_left(struct node* node, struct node* left){
  struct node* child = left;

  node->left = child;
  child->parent = node;

  return child;
}

// Insert a new binary operator `op` whose left operand is `left` (or one of
// its ancestors), and return the new operator node; its right child is NULL
// until the caller attaches one.
//
// Starting at `left`, the tree is climbed while the parent's operator
// compares greater than `op`. The node reached becomes the left operand of
// the new operator, and the new operator takes that node's place as the
// right child of its former parent (or becomes a new root when there was no
// parent).
//
// Exits with status 1 if `left` is NULL (an operator with nothing before it).
//
// NOTE(review): precedence is decided by comparing enum op_type values
// numerically, so this assumes the enum is declared in increasing precedence
// order - confirm in tree.h. Because an equal value stops the climb, chains
// of the same operator nest to the right ("a - b - c" is built as
// a - (b - c)); confirm whether that is intended.
// NOTE(review): the parent is assumed to be an operator node; data.op of a
// function node is not meaningful here - confirm callers never hit that case.
struct node* add_op(struct node* left, enum op_type op) {
    if (left == NULL) {
        fprintf(stderr, "Operator %d has no left operand\n", (int) op);
        exit(1);
    }

    // climb until there is no parent or the parent does not bind tighter
    struct node* operand = left;
    while (operand->parent != NULL && op < operand->parent->data.op) {
        operand = operand->parent;
    }

    // remember the old parent first: creating the operator re-parents operand
    struct node* parent = operand->parent;
    struct node* binop = new_binop_node_no_right(op, operand);

    if (parent != NULL) {
        add_right(parent, binop);
    }
    return binop;
}

// scanner and parser
struct node* scanner(const char* es) {
    char buff[200];
    int len = strlen(es);
    struct node* left = NULL;
    int lookahead = -1; // 0: fails, >0: success 

    for (int i=0; i<len; i++) {

        switch(es[i]) {
        // skip space
        case ' ' :  continue; // printf("[space]");
        case '\t':  continue; // printf("[tab]");
        case '\n':  continue; // printf("[newline]");

        // find number
        case '0' ... '9':
        case '.':
            lookahead = isNumber(&es[i+1]);
            if (lookahead>=0) {
                memcpy(buff, &es[i], lookahead+1);
                buff[lookahead+1] = (char) 0;
                i = i + lookahead;
//                printf("token-number: '%s' (atof:%f)\n", buff,atof(buff));
                
                if (left==NULL) {
                    left = new_number_node(atof(buff));
                } else if (left->type==NODE_BINOP) {
                    left = add_right(left, new_number_node(atof(buff)));
                } else if (left->type==NODE_FUNC) {
                    left = add_left(left, new_number_node(atof(buff)));
            left = left->parent;
                }
            }
            continue;

        // operators
        case '+' :
//            printf("token-add: '%c'\n", es[i]);
            left = add_op(left, OP_ADD);
            continue;
        case '-' :
//            printf("token-sub: '%c'\n", es[i]);
            left = add_op(left, OP_SUB);
            continue;
        case '*' :
//            printf("token-mul: '%c'\n", es[i]);
            left = add_op(left, OP_MUL);
            continue;
        case '/' :
//            printf("token-div: '%c'\n", es[i]);
            left = add_op(left, OP_DIV);
            continue;
        case '^' :       
//            printf("token-exp: '%c'\n", es[i]);
            left = add_op(left, OP_EXP);
            continue;

        case 'a' ... 'z':
        case 'A' ... 'Z':
            // find function
            lookahead = isFunc(&es[i+1]);
            if (lookahead>=0) {
                memcpy(buff, &es[i], lookahead+1);
                buff[lookahead+1] = (char) 0;
                printf("token-func: '%s'\n", buff);
                i = i + lookahead;
/*
                if (left==NULL) {
                    left = new_func_node(buff);
                } else if (left->type==NODE_BINOP) {
                    left = add_right(left, new_func_node(buff));
                }
 */               continue;
            }

            // find variable
            lookahead = isVar(&es[i+1]);
            if (lookahead>=0) {
                memcpy(buff, &es[i], lookahead+1);
                buff[lookahead+1] = (char) 0;
//                printf("token-var: '%s'\n", buff);
                i = i + lookahead;

                if (left==NULL) {
                    left = new_var_node(buff);
                } else if (left->type==NODE_BINOP) {
                    left = add_right(left, new_var_node(buff));
                }
                continue;
            }

        // grouping
        case '(' :
            printf("token-l-paren: '%c'\n", es[i]);
            continue;
        case ')' :
            printf("token-r-paren: '%c'\n", es[i]);
            continue;


        default:    ;
        }
    }   

    // find root
    while (left->parent != NULL) {
        left = left->parent;
    }
    return left;
}



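/* Example usage, kept for reference only (this block is commented out and
   not compiled; comp_expr and x are not defined in the code shown above).
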
   static const char expr_str[] = "5*4 * 3 / 2 - 1 + 37\n";

   // compile expression
   struct node* expr = comp_expr(expr_str);

   // evaluate expression
   eval_expr(expr, 1);

   // some output
   printf("Expression: %s\n", expr_str);
   x = 3.0; printf("Value[x=%2.3g]: %2.3g\n", x, eval_expr(expr, x));
   x = 3.5; printf("Value[x=%2.3g]: %2.3g\n", x, eval_expr(expr, x));
   x = 4.0; printf("Value[x=%2.3g]: %2.3g\n", x, eval_expr(expr, x));

   // example 2
   static const char str2[] = "5*4 * 3 / 2 - x + 37\n";
   expr = comp_expr(str2);
   eval_expr(expr, 1);

   printf("Expression: %s\n", str2);
   x = 1.0; printf("Value[x=%2.3g]: %2.3g\n", x, eval_expr(expr, x));
   x = 3.0; printf("Value[x=%2.3g]: %2.3g\n", x, eval_expr(expr, x));
   x = 3.5; printf("Value[x=%2.3g]: %2.3g\n", x, eval_expr(expr, x));
   x = 4.0; printf("Value[x=%2.3g]: %2.3g\n", x, eval_expr(expr, x));


   // example 3
   static const char str3[] = "5*4 * 3 / 2  - 1 + 37 + sin(x)\n";
   expr = comp_expr(str3);
   eval_expr(expr, 1);

   printf("Expression: %s\n", str3);
   x =-1.0; printf("Value[x=%2.3g]: %2.3g\n", x, eval_expr(expr, x));
   x = 1.0; printf("Value[x=%2.3g]: %2.3g\n", x, eval_expr(expr, x));
   x = 3.0; printf("Value[x=%2.3g]: %2.3g\n", x, eval_expr(expr, x));
   x = 3.5; printf("Value[x=%2.3g]: %2.3g\n", x, eval_expr(expr, x));
   x = 4.0; printf("Value[x=%2.3g]: %2.3g\n", x, eval_expr(expr, x));

   // example 4
   static const char str4[] = "5*4 * 3 / 2  - 1 + 37 + sin(x) + 2^3\n";
   expr = comp_expr(str4);
   eval_expr(expr, 1);

   printf("Expression: %s\n", str4);
   x = 1.0; printf("Value[x=%2.3g]: %2.3g\n", x, eval_expr(expr, x));
   x = 5.0; printf("Value[x=%2.3g]: %2.3g\n", x, eval_expr(expr, x));
   x = 7.0; printf("Value[x=%2.3g]: %2.3g\n", x, eval_expr(expr, x));
   x = 9.0; printf("Value[x=%2.3g]: %2.3g\n", x, eval_expr(expr, x));
   x =11.0; printf("Value[x=%2.3g]: %2.3g\n", x, eval_expr(expr, x));
*/

