#include "mbed.h"
#include "rtos.h"

// To expose the deadlock, keep the DEADLOCK define below (comment it out for normal execution)
#define DEADLOCK 

// To avoid the deadlock by giving each mutex lock a timeout, keep the DEADLOCK_FIX define below (comment it out to see the hang)
#define DEADLOCK_FIX 

// Recommended order of experiments:
// 1. define neither macro to see normal execution
// 2. define only DEADLOCK to see the program hang
// 3. define both DEADLOCK and DEADLOCK_FIX to see the hang avoided

// Account: one bank account record.
//   balance - funds currently held in the account
//   id      - identifier printed next to every balance update
struct account {
    double balance;
    int id;
};
typedef struct account Account;

// Globals
//two accounts
Account personA;
Account personB;
//pointers to those accounts
Account * personA_ptr;
Account * personB_ptr;
//mutexes for the two accounts
Mutex personA_mutex;
Mutex personB_mutex;

// function: withdraw
// inputs: Account pointer, double amount to withdraw
// description: lowers the Account's balance by amount, then prints the new balance and the account id.
//              No locking or overdraft check happens here; the caller is expected to hold the
//              account's mutex (as transfer does).
void withdraw (Account * person, double amount){
    Account &acct = *person;
    acct.balance = acct.balance - amount;
    printf("Balance withdraw: %f from id %d\n", acct.balance, acct.id);
}

// function: deposit
// inputs: Account pointer, double amount to deposit
// description: raises the Account's balance by amount, then prints the new balance and the account id.
//              No locking happens here; the caller is expected to hold the account's mutex
//              (as transfer does).
void deposit(Account * person, double amount){
    Account &acct = *person;
    acct.balance = acct.balance + amount;
    printf("Balance deposit: %f from id %d\n", acct.balance, acct.id);
}
        
// function: transfer
// inputs: From Account mutex, To Account mutex, From Account pointer, To Account pointer, double amount to transfer
// description: protects access to each account using correspoinding mutex
//              withdraws amount from From Account and deposits amount to To Account
void transfer(Mutex f_m, Mutex t_m, Account * from, Account* to, double amount){ 
#ifdef DEADLOCK_FIX
    f_m.lock(500);
#else
    f_m.lock();
#endif

//putting a wait between grabing the locks allows the higher priority task to come in and grab the other lockk
#ifdef DEADLOCK 
    Thread::wait(5000);
#endif 

#ifdef DEADLOCK_FIX
    t_m.lock(500);
#else
    t_m.lock();
#endif
    
    withdraw(from, amount);
    deposit(to, amount);
    t_m.unlock();
    f_m.unlock();
}

// thread: AtoB
// description: transfers 500 from A to B, prints balance after transfer
void AtoB(void const *args) {
    transfer(personA_mutex, personB_mutex, personA_ptr, personB_ptr, 500);
    printf("Balance A (AtoB): %f\n", personA_ptr->balance);
    printf("Balance B (AtoB): %f\n", personB_ptr->balance);
}

// thread: BtoA
// description: transfers 200 from B to A, prints balance after transfer
void BtoA(void const *args) {
    transfer(personB_mutex, personA_mutex, personB_ptr, personA_ptr, 200);
    printf("Balance A (BtoA): %f\n", personA_ptr->balance);
    printf("Balance B (BtoA): %f\n", personB_ptr->balance);
}

int main () {
    printf("\n\nStart new transaction\n");

    //populate values for Account personA, set pointer
    personA_ptr = &personA;
    personA_ptr->id = 1;
    personA_ptr->balance = 1000;
    
    //populate values for Account personB, set pointer
    personB_ptr = &personB;
    personB_ptr->id = 2;
    personB_ptr->balance = 1000;

    //start threads
    Thread thread1(AtoB);
    Thread thread2(BtoA);   
    
    thread1.set_priority(osPriorityNormal);
    //second thread needs to have a higher priority than the first thread to crate the deadlock condition
    thread2.set_priority(osPriorityHigh);
    
    while (true){
    }
    
}