Sungwoo Kim
/
HydraulicControlBoard_Learning
for learning
Diff: main.cpp
- Revision:
- 175:2f7289dbd488
- Parent:
- 174:c828479f53f9
- Child:
- 176:589ea3edcf3c
--- a/main.cpp Tue Nov 24 06:09:50 2020 +0000 +++ b/main.cpp Tue Nov 24 08:15:29 2020 +0000 @@ -1,4 +1,4 @@ -//201124_1 +//201124_3 #include "mbed.h" #include "FastPWM.h" #include "INIT_HW.h" @@ -490,22 +490,22 @@ void update_Critic_Networks(float (*arr)[num_input_RL]) { float gradient_rate = 0.001f; -// float hx_sum = 0.0f; - - - ///////////////////////////////////////////////////////////CRITIC float G_hc1[num_input_RL][num_hidden_unit1] = {0.0f}; - float G_bc1[num_hidden_unit1] = {0.0f}; + float d_V_d_hc1[batch_size][num_input_RL][num_hidden_unit1] = {0.0f}; ////////////////1 + float G_bc1[num_hidden_unit1] = {0.0f}; + float d_V_d_bc1[batch_size][num_hidden_unit1] = {0.0f}; ////////////////2 for (int index2 = 0; index2 < num_hidden_unit1; index2++) { for (int index1 = 0; index1 < num_input_RL; index1++) { for (int n=0; n<batch_size; n++) { for(int k=0; k<num_hidden_unit2; k++) { if (hxh_c_sum_array[n][k] >= 0) { if (hx_c_sum_array[n][index2] > 0) { - G_hc1[index1][index2] = G_hc1[index1][index2] + arr[n][index1]*hc2_temp[index2][k]*hc3_temp[k]; + //G_hc1[index1][index2] = G_hc1[index1][index2] + arr[n][index1]*hc2_temp[index2][k]*hc3_temp[k]; //////////////////////3 + d_V_d_hc1[n][index1][index2] = d_V_d_hc1[n][index1][index2] + arr[n][index1]*hc2_temp[index2][k]*hc3_temp[k]; //////////////////////4 } } } + G_hc1[index1][index2] = G_hc1[index1][index2] + 2.0f*(return_G[n]-V[n])*(-d_V_d_hc1[n][index1][index2]); /////////////////////5 } G_hc1[index1][index2] = G_hc1[index1][index2] / batch_size; //hc1_temp[index1][index2] = hc1_temp[index1][index2] - gradient_rate * G_hc1[index1][index2]; @@ -514,53 +514,68 @@ for(int k=0; k<num_hidden_unit2; k++) { if (hxh_c_sum_array[n][k] >= 0) { if (hx_c_sum_array[n][index2] > 0) { - G_bc1[index2] = G_bc1[index2] + hc2_temp[index2][k]*hc3_temp[k]; + //G_bc1[index2] = G_bc1[index2] + hc2_temp[index2][k]*hc3_temp[k]; //////////////////6 + d_V_d_bc1[n][index2] = d_V_d_bc1[n][index2] + hc2_temp[index2][k]*hc3_temp[k]; //////////////////7 } } } + G_bc1[index2] = G_bc1[index2] + 2.0f*(return_G[n]-V[n])*(-d_V_d_bc1[n][index2]); /////////////////////8 } G_bc1[index2] = G_bc1[index2] / batch_size; //bc1_temp[index2] = bc1_temp[index2] - gradient_rate * G_bc1[index2]; } + float G_hc2[num_hidden_unit1][num_hidden_unit2] = {0.0f}; + float d_V_d_hc2[batch_size][num_hidden_unit1][num_hidden_unit2] = {0.0f}; float G_bc2[num_hidden_unit2] = {0.0f}; + float d_V_d_bc2[batch_size][num_hidden_unit2] = {0.0f}; for (int index2 = 0; index2 < num_hidden_unit2; index2++) { for (int index1 = 0; index1 < num_hidden_unit1; index1++) { for (int n=0; n<batch_size; n++) { if (hxh_c_sum_array[n][index2] >= 0) { if (hx_c_sum_array[n][index1] > 0) { - G_hc2[index1][index2] = G_hc2[index1][index2] + hx_c_sum_array[n][index1]*hc3_temp[index2]; + //G_hc2[index1][index2] = G_hc2[index1][index2] + hx_c_sum_array[n][index1]*hc3_temp[index2]; + d_V_d_hc2[n][index1][index2] = hx_c_sum_array[n][index1]*hc3_temp[index2]; } } + G_hc2[index1][index2] = G_hc2[index1][index2] + 2.0f*(return_G[n]-V[n])*(-d_V_d_hc2[n][index1][index2]); } G_hc2[index1][index2] = G_hc2[index1][index2] / batch_size; //hc2_temp[index1][index2] = hc2_temp[index1][index2] - gradient_rate * G_hc2[index1][index2]; } for (int n=0; n<batch_size; n++) { if (hxh_c_sum_array[n][index2] >= 0) { - G_bc2[index2] = G_bc2[index2] + hc3_temp[index2]; + //G_bc2[index2] = G_bc2[index2] + hc3_temp[index2]; + d_V_d_bc2[n][index2] = hc3_temp[index2]; } + G_bc2[index2] = G_bc2[index2] + 2.0f*(return_G[n]-V[n])*(-d_V_d_bc2[n][index2]); } G_bc2[index2] = G_bc2[index2] / batch_size; //bc2_temp[index2] = bc2_temp[index2] - gradient_rate * G_bc2[index2]; } float G_hc3[num_hidden_unit2]= {0.0f}; + float d_V_d_hc3[batch_size][num_hidden_unit2] = {0.0f}; float G_bc3 = 0.0f; + float d_V_d_bc3[batch_size] = {0.0f}; for (int index2 = 0; index2 < 1; index2++) { for (int index1 = 0; index1 < num_hidden_unit2; index1++) { for (int n=0; n<batch_size; n++) { if (hxh_c_sum_array[n][index1] >= 0) { - G_hc3[index1] = G_hc3[index1] + hxh_c_sum_array[n][index1]; + //G_hc3[index1] = G_hc3[index1] + hxh_c_sum_array[n][index1]; + d_V_d_hc3[n][index1] = d_V_d_hc3[n][index1] + hxh_c_sum_array[n][index1]; } + G_hc3[index1] = G_hc3[index1] + 2.0f*(return_G[n]-V[n])*(-d_V_d_hc3[n][index1]); } G_hc3[index1] = G_hc3[index1] / batch_size; //hc3_temp[index1] = hc3_temp[index1] - gradient_rate * G_hc3[index1]; } for (int n=0; n<batch_size; n++) { - G_bc2[index2] = G_bc2[index2] + 1.0f; + //G_bc2[index2] = G_bc2[index2] + 1.0f; + d_V_d_bc3[n] = 1.0f; + G_bc3 = G_bc3 + 2.0f*(return_G[n]-V[n])*(-d_V_d_bc3[n]); } G_bc3 = G_bc3 / batch_size; //bc3_temp = bc3_temp - gradient_rate * G_bc3; @@ -623,7 +638,7 @@ } } G_ha1[index1][index2] = G_ha1[index1][index2] / batch_size; - ha1_temp[index1][index2] = ha1_temp[index1][index2] - gradient_rate * G_ha1[index1][index2]; + //ha1_temp[index1][index2] = ha1_temp[index1][index2] - gradient_rate * G_ha1[index1][index2]; } for (int n=0; n<batch_size; n++) { @@ -647,7 +662,7 @@ } } G_ba1[index2] = G_ba1[index2] / batch_size; - ba1_temp[index2] = ba1_temp[index2] - gradient_rate * G_ba1[index2]; + //ba1_temp[index2] = ba1_temp[index2] - gradient_rate * G_ba1[index2]; } float G_ha2[num_hidden_unit1][num_hidden_unit2] = {0.0f}; @@ -680,7 +695,7 @@ } } G_ha2[index1][index2] = G_ha2[index1][index2] / batch_size; - ha2_temp[index1][index2] = ha2_temp[index1][index2] - gradient_rate * G_ha2[index1][index2]; + //ha2_temp[index1][index2] = ha2_temp[index1][index2] - gradient_rate * G_ha2[index1][index2]; } for (int n=0; n<batch_size; n++) { @@ -701,7 +716,7 @@ } } G_ba2[index2] = G_ba2[index2] / batch_size; - ba2_temp[index2] = ba2_temp[index2] - gradient_rate * G_ba2[index2]; + //ba2_temp[index2] = ba2_temp[index2] - gradient_rate * G_ba2[index2]; } float G_ha3[num_hidden_unit2][2] = {0.0f}; @@ -722,7 +737,6 @@ if (hx_a_sum_array[n][index1] > 0) { d_x_d_ha3[index1][index2] = d_x_d_ha3[index1][index2] + hxh_a_sum_array[n][index1]; d_y_d_ha3[index1][index2] = d_y_d_ha3[index1][index2] + hxh_a_sum_array[n][index1]; - } } float d_mean_d_ha3 = 0.0f; @@ -734,7 +748,7 @@ } } G_ha3[index1][index2] = G_ha3[index1][index2] / batch_size; - ha3_temp[index1][index2] = ha3_temp[index1][index2] - gradient_rate * G_ha3[index1][index2]; + //ha3_temp[index1][index2] = ha3_temp[index1][index2] - gradient_rate * G_ha3[index1][index2]; } for (int n=0; n<batch_size; n++) { @@ -754,8 +768,29 @@ } } G_ba3[index2] = G_ba3[index2] / batch_size; + //ba3_temp[index2] = ba3_temp[index2] - gradient_rate * G_ba3[index2]; + } + + // Simultaneous Update + for (int index2 = 0; index2 < num_hidden_unit1; index2++) { + for (int index1 = 0; index1 < num_input_RL; index1++) { + ha1_temp[index1][index2] = ha1_temp[index1][index2] - gradient_rate * G_ha1[index1][index2]; + } + ba1_temp[index2] = ba1_temp[index2] - gradient_rate * G_ba1[index2]; + } + for (int index2 = 0; index2 < num_hidden_unit2; index2++) { + for (int index1 = 0; index1 < num_hidden_unit1; index1++) { + ha2_temp[index1][index2] = ha2_temp[index1][index2] - gradient_rate * G_ha2[index1][index2]; + } + ba2_temp[index2] = ba2_temp[index2] - gradient_rate * G_ba2[index2]; + } + for (int index2 = 0; index2 < 2; index2++) { + for (int index1 = 0; index1 < num_hidden_unit2; index1++) { + ha3_temp[index1][index2] = ha3_temp[index1][index2] - gradient_rate * G_ha3[index1][index2]; + } ba3_temp[index2] = ba3_temp[index2] - gradient_rate * G_ba3[index2]; } + } ///////////////////////////ReLU - Bad performance////////////////////////////////// @@ -1223,40 +1258,41 @@ //Network Update(just update and hold network) for (int epoch = 0; epoch < num_epoch; epoch++) { float loss_sum = 0.0f; - for (int i=batch_size-1; i>=0; i--) { + for (int n=batch_size-1; n>=0; n--) { //Calculate Estimated V - //float temp_array[3] = {state_array[i][0], state_array[i][1], state_array[i][2]}; - float temp_array[2] = {state_array[i][0], state_array[i][1]}; - V[i] = Critic_Network_Temp(temp_array); + //float temp_array[3] = {state_array[n][0], state_array[n][1], state_array[n][2]}; + float temp_array[2] = {state_array[n][0], state_array[n][1]}; + V[n] = Critic_Network_Temp(temp_array); for (int i=0; i<num_hidden_unit1; i++) { - hx_c_sum_array[RL_timer][i] = hx_c_sum[i]; + hx_c_sum_array[n][i] = hx_c_sum[i]; } for (int i=0; i<num_hidden_unit2; i++) { - hxh_c_sum_array[RL_timer][i] = hxh_c_sum[i]; + hxh_c_sum_array[n][i] = hxh_c_sum[i]; } - hxhh_c_sum_array[RL_timer] = hxhh_c_sum; - pi[i] = exp(-(action_array[i]-mean_array[i])*(action_array[i]-mean_array[i])/(2.0f*deviation_array[i]*deviation_array[i]))/(sqrt(2.0f*PI)*deviation_array[i]); + hxhh_c_sum_array[n] = hxhh_c_sum; + + pi[n] = exp(-(action_array[n]-mean_array[n])*(action_array[n]-mean_array[n])/(2.0f*deviation_array[n]*deviation_array[n]))/(sqrt(2.0f*PI)*deviation_array[n]); Actor_Network_Old(temp_array); - pi_old[i] = exp(-(action_array[i]-mean_old)*(action_array[i]-mean_old)/(2.0f*deviation_old*deviation_old))/(sqrt(2.0f*PI)*deviation_old); - r[i] = exp(-0.25f * state_array[i][1] * state_array[i][1]); - if(i == batch_size-1) return_G[i] = 0.0f; - else return_G[i] = gamma * return_G[i+1] + r[i]; - if(i == batch_size-1) td_target[i] = r[i]; - else td_target[i] = r[i] + gamma * V[i+1]; - delta[i] = td_target[i] - V[i]; - if(i == batch_size-1) advantage[i] = 0.0f; - else advantage[i] = gamma * lmbda * advantage[i+1] + delta[i]; - ratio[i] = pi[i]/pi_old[i]; - surr1[i] = ratio[i] * advantage[i]; - if (ratio[i] > 1.0f + epsilon) { - surr2[i] = (1.0f + epsilon)*advantage[i]; - } else if( ratio[i] < 1.0f - epsilon) { - surr2[i] = (1.0f - epsilon)*advantage[i]; + pi_old[n] = exp(-(action_array[n]-mean_old)*(action_array[n]-mean_old)/(2.0f*deviation_old*deviation_old))/(sqrt(2.0f*PI)*deviation_old); + r[n] = exp(-0.25f * state_array[n][1] * state_array[n][1]); + if(n == batch_size-1) return_G[n] = 0.0f; + else return_G[n] = gamma * return_G[n+1] + r[n]; + if(n == batch_size-1) td_target[n] = r[n]; + else td_target[n] = r[n] + gamma * V[n+1]; + delta[n] = td_target[n] - V[n]; + if(n == batch_size-1) advantage[n] = 0.0f; + else advantage[n] = gamma * lmbda * advantage[n+1] + delta[n]; + ratio[n] = pi[n]/pi_old[n]; + surr1[n] = ratio[n] * advantage[n]; + if (ratio[n] > 1.0f + epsilon) { + surr2[n] = (1.0f + epsilon)*advantage[n]; + } else if( ratio[n] < 1.0f - epsilon) { + surr2[n] = (1.0f - epsilon)*advantage[n]; } else { - surr2[i] = ratio[i]*advantage[i]; + surr2[n] = ratio[n]*advantage[n]; } - loss[i] = -min(surr1[i], surr2[i]); - loss_sum = loss_sum + loss[i]; + loss[n] = -min(surr1[n], surr2[n]); + loss_sum = loss_sum + loss[n]; } reward_sum = 0.0f; for (int i=0; i<batch_size; i++) { @@ -2687,11 +2723,8 @@ hxhh_a_sum_array[RL_timer][1] = hxhh_a_sum[1]; mean_array[RL_timer] = mean; deviation_array[RL_timer] = deviation; - mean_before_SP_array[RL_timer] = mean_before_SP; - deviation_before_SP_array[RL_timer] = deviation_before_SP; action_array[RL_timer] = rand_normal(mean_array[RL_timer], deviation_array[RL_timer]); - virt_pos = virt_pos + (action_array[RL_timer] - 3.0f) * 1000.0f * 0.0002f; if (virt_pos > 70.0f ) { virt_pos = 70.0f;