[Ver 1.0] The code was provided by Seunghoon Shin and was used for a hydraulic quadrupedal robot. Buyoun Cho will revise the code for Post-LIGHT (the robot's name has not been determined yet).

Dependencies:   mbed FastPWM

Revision:
175:2f7289dbd488
Parent:
174:c828479f53f9
Child:
176:589ea3edcf3c
--- a/main.cpp	Tue Nov 24 06:09:50 2020 +0000
+++ b/main.cpp	Tue Nov 24 08:15:29 2020 +0000
@@ -1,4 +1,4 @@
-//201124_1
+//201124_3
 #include "mbed.h"
 #include "FastPWM.h"
 #include "INIT_HW.h"
@@ -490,22 +490,22 @@
 void update_Critic_Networks(float (*arr)[num_input_RL])
 {
     float gradient_rate = 0.001f;
-//    float hx_sum = 0.0f;
-
-
-    ///////////////////////////////////////////////////////////CRITIC
     float G_hc1[num_input_RL][num_hidden_unit1] = {0.0f};
-    float G_bc1[num_hidden_unit1] = {0.0f};
+    float d_V_d_hc1[batch_size][num_input_RL][num_hidden_unit1] = {0.0f};   ////////////////1
+    float G_bc1[num_hidden_unit1] = {0.0f}; 
+    float d_V_d_bc1[batch_size][num_hidden_unit1] = {0.0f};                 ////////////////2
     for (int index2 = 0; index2 < num_hidden_unit1; index2++) {
         for (int index1 = 0; index1 < num_input_RL; index1++) {
             for (int n=0; n<batch_size; n++) {
                 for(int k=0; k<num_hidden_unit2; k++) {
                     if (hxh_c_sum_array[n][k] >= 0) {
                         if (hx_c_sum_array[n][index2] > 0) {
-                            G_hc1[index1][index2] = G_hc1[index1][index2] + arr[n][index1]*hc2_temp[index2][k]*hc3_temp[k];
+                            //G_hc1[index1][index2] = G_hc1[index1][index2] + arr[n][index1]*hc2_temp[index2][k]*hc3_temp[k];               //////////////////////3
+                            d_V_d_hc1[n][index1][index2] = d_V_d_hc1[n][index1][index2] + arr[n][index1]*hc2_temp[index2][k]*hc3_temp[k];   //////////////////////4
                         }
                     }
                 }
+                G_hc1[index1][index2] = G_hc1[index1][index2] + 2.0f*(return_G[n]-V[n])*(-d_V_d_hc1[n][index1][index2]);                    /////////////////////5
             }
             G_hc1[index1][index2] = G_hc1[index1][index2] / batch_size;
             //hc1_temp[index1][index2] = hc1_temp[index1][index2] - gradient_rate * G_hc1[index1][index2];
@@ -514,53 +514,68 @@
             for(int k=0; k<num_hidden_unit2; k++) {
                 if (hxh_c_sum_array[n][k] >= 0) {
                     if (hx_c_sum_array[n][index2] > 0) {
-                        G_bc1[index2] = G_bc1[index2] + hc2_temp[index2][k]*hc3_temp[k];
+                        //G_bc1[index2] = G_bc1[index2] + hc2_temp[index2][k]*hc3_temp[k];                                                  //////////////////6
+                        d_V_d_bc1[n][index2] = d_V_d_bc1[n][index2] + hc2_temp[index2][k]*hc3_temp[k];                                      //////////////////7
                     }
 
                 }
             }
+            G_bc1[index2] = G_bc1[index2] + 2.0f*(return_G[n]-V[n])*(-d_V_d_bc1[n][index2]);                                                /////////////////////8
         }
         G_bc1[index2] = G_bc1[index2] / batch_size;
         //bc1_temp[index2] = bc1_temp[index2] - gradient_rate * G_bc1[index2];
     }
 
+
     float G_hc2[num_hidden_unit1][num_hidden_unit2] = {0.0f};
+    float d_V_d_hc2[batch_size][num_hidden_unit1][num_hidden_unit2] = {0.0f};   
     float G_bc2[num_hidden_unit2] = {0.0f};
+    float d_V_d_bc2[batch_size][num_hidden_unit2] = {0.0f};                    
     for (int index2 = 0; index2 < num_hidden_unit2; index2++) {
         for (int index1 = 0; index1 < num_hidden_unit1; index1++) {
             for (int n=0; n<batch_size; n++) {
                 if (hxh_c_sum_array[n][index2] >= 0) {
                     if (hx_c_sum_array[n][index1] > 0) {
-                        G_hc2[index1][index2] = G_hc2[index1][index2] + hx_c_sum_array[n][index1]*hc3_temp[index2];
+                        //G_hc2[index1][index2] = G_hc2[index1][index2] + hx_c_sum_array[n][index1]*hc3_temp[index2];
+                        d_V_d_hc2[n][index1][index2] = hx_c_sum_array[n][index1]*hc3_temp[index2];
                     }
                 }
+                G_hc2[index1][index2] = G_hc2[index1][index2] + 2.0f*(return_G[n]-V[n])*(-d_V_d_hc2[n][index1][index2]);                    
             }
             G_hc2[index1][index2] = G_hc2[index1][index2] / batch_size;
             //hc2_temp[index1][index2] = hc2_temp[index1][index2] - gradient_rate * G_hc2[index1][index2];
         }
         for (int n=0; n<batch_size; n++) {
             if (hxh_c_sum_array[n][index2] >= 0) {
-                G_bc2[index2] = G_bc2[index2] + hc3_temp[index2];
+                //G_bc2[index2] = G_bc2[index2] + hc3_temp[index2];
+                d_V_d_bc2[n][index2] = hc3_temp[index2];
             }
+            G_bc2[index2] = G_bc2[index2] + 2.0f*(return_G[n]-V[n])*(-d_V_d_bc2[n][index2]); 
         }
         G_bc2[index2] = G_bc2[index2] / batch_size;
         //bc2_temp[index2] = bc2_temp[index2] - gradient_rate * G_bc2[index2];
     }
 
     float G_hc3[num_hidden_unit2]= {0.0f};
+    float d_V_d_hc3[batch_size][num_hidden_unit2] = {0.0f};  
     float G_bc3 = 0.0f;
+    float d_V_d_bc3[batch_size] = {0.0f}; 
     for (int index2 = 0; index2 < 1; index2++) {
         for (int index1 = 0; index1 < num_hidden_unit2; index1++) {
             for (int n=0; n<batch_size; n++) {
                 if (hxh_c_sum_array[n][index1] >= 0) {
-                    G_hc3[index1] = G_hc3[index1] + hxh_c_sum_array[n][index1];
+                    //G_hc3[index1] = G_hc3[index1] + hxh_c_sum_array[n][index1];
+                    d_V_d_hc3[n][index1] = d_V_d_hc3[n][index1] + hxh_c_sum_array[n][index1];
                 }
+                G_hc3[index1] = G_hc3[index1] + 2.0f*(return_G[n]-V[n])*(-d_V_d_hc3[n][index1]);
             }
             G_hc3[index1] = G_hc3[index1] / batch_size;
             //hc3_temp[index1] = hc3_temp[index1] - gradient_rate * G_hc3[index1];
         }
         for (int n=0; n<batch_size; n++) {
-            G_bc2[index2] = G_bc2[index2] + 1.0f;
+            //G_bc2[index2] = G_bc2[index2] + 1.0f;
+            d_V_d_bc3[n] = 1.0f;
+            G_bc3 = G_bc3 + 2.0f*(return_G[n]-V[n])*(-d_V_d_bc3[n]); 
         }
         G_bc3 = G_bc3 / batch_size;
         //bc3_temp = bc3_temp - gradient_rate * G_bc3;
@@ -623,7 +638,7 @@
                 }
             }
             G_ha1[index1][index2] = G_ha1[index1][index2] / batch_size;
-            ha1_temp[index1][index2] = ha1_temp[index1][index2] - gradient_rate * G_ha1[index1][index2];
+            //ha1_temp[index1][index2] = ha1_temp[index1][index2] - gradient_rate * G_ha1[index1][index2];
         }
 
         for (int n=0; n<batch_size; n++) {
@@ -647,7 +662,7 @@
             }
         }
         G_ba1[index2] = G_ba1[index2] / batch_size;
-        ba1_temp[index2] = ba1_temp[index2] - gradient_rate * G_ba1[index2];
+        //ba1_temp[index2] = ba1_temp[index2] - gradient_rate * G_ba1[index2];
     }
 
     float G_ha2[num_hidden_unit1][num_hidden_unit2] = {0.0f};
@@ -680,7 +695,7 @@
                 }
             }
             G_ha2[index1][index2] = G_ha2[index1][index2] / batch_size;
-            ha2_temp[index1][index2] = ha2_temp[index1][index2] - gradient_rate * G_ha2[index1][index2];
+            //ha2_temp[index1][index2] = ha2_temp[index1][index2] - gradient_rate * G_ha2[index1][index2];
         }
 
         for (int n=0; n<batch_size; n++) {
@@ -701,7 +716,7 @@
             }
         }
         G_ba2[index2] = G_ba2[index2] / batch_size;
-        ba2_temp[index2] = ba2_temp[index2] - gradient_rate * G_ba2[index2];
+        //ba2_temp[index2] = ba2_temp[index2] - gradient_rate * G_ba2[index2];
     }
 
     float G_ha3[num_hidden_unit2][2] = {0.0f};
@@ -722,7 +737,6 @@
                         if (hx_a_sum_array[n][index1] > 0) {
                             d_x_d_ha3[index1][index2] = d_x_d_ha3[index1][index2] + hxh_a_sum_array[n][index1];
                             d_y_d_ha3[index1][index2] = d_y_d_ha3[index1][index2] + hxh_a_sum_array[n][index1];
-
                         }
                     }
                     float d_mean_d_ha3 = 0.0f;
@@ -734,7 +748,7 @@
                 }
             }
             G_ha3[index1][index2] = G_ha3[index1][index2] / batch_size;
-            ha3_temp[index1][index2] = ha3_temp[index1][index2] - gradient_rate * G_ha3[index1][index2];
+            //ha3_temp[index1][index2] = ha3_temp[index1][index2] - gradient_rate * G_ha3[index1][index2];
         }
 
         for (int n=0; n<batch_size; n++) {
@@ -754,8 +768,29 @@
             }
         }
         G_ba3[index2] = G_ba3[index2] / batch_size;
+        //ba3_temp[index2] = ba3_temp[index2] - gradient_rate * G_ba3[index2];
+    }
+    
+    // Simultaneous Update
+    for (int index2 = 0; index2 < num_hidden_unit1; index2++) {
+        for (int index1 = 0; index1 < num_input_RL; index1++) {
+            ha1_temp[index1][index2] = ha1_temp[index1][index2] - gradient_rate * G_ha1[index1][index2];
+        }
+        ba1_temp[index2] = ba1_temp[index2] - gradient_rate * G_ba1[index2];
+    }
+    for (int index2 = 0; index2 < num_hidden_unit2; index2++) {
+        for (int index1 = 0; index1 < num_hidden_unit1; index1++) {
+            ha2_temp[index1][index2] = ha2_temp[index1][index2] - gradient_rate * G_ha2[index1][index2];
+        }
+        ba2_temp[index2] = ba2_temp[index2] - gradient_rate * G_ba2[index2];
+    }
+    for (int index2 = 0; index2 < 2; index2++) {
+        for (int index1 = 0; index1 < num_hidden_unit2; index1++) {
+            ha3_temp[index1][index2] = ha3_temp[index1][index2] - gradient_rate * G_ha3[index1][index2];
+        }
         ba3_temp[index2] = ba3_temp[index2] - gradient_rate * G_ba3[index2];
     }
+    
 }
 
 ///////////////////////////ReLU - Bad performance//////////////////////////////////
@@ -1223,40 +1258,41 @@
                 //Network Update(just update and hold network)
                 for (int epoch = 0; epoch < num_epoch; epoch++) {
                     float loss_sum = 0.0f;
-                    for (int i=batch_size-1; i>=0; i--) {
+                    for (int n=batch_size-1; n>=0; n--) {
                         //Calculate Estimated V
-                        //float temp_array[3] = {state_array[i][0], state_array[i][1], state_array[i][2]};
-                        float temp_array[2] = {state_array[i][0], state_array[i][1]};
-                        V[i] = Critic_Network_Temp(temp_array);
+                        //float temp_array[3] = {state_array[n][0], state_array[n][1], state_array[n][2]};
+                        float temp_array[2] = {state_array[n][0], state_array[n][1]};
+                        V[n] = Critic_Network_Temp(temp_array);
                         for (int i=0; i<num_hidden_unit1; i++) {
-                            hx_c_sum_array[RL_timer][i] = hx_c_sum[i];
+                            hx_c_sum_array[n][i] = hx_c_sum[i];
                         }
                         for (int i=0; i<num_hidden_unit2; i++) {
-                            hxh_c_sum_array[RL_timer][i] = hxh_c_sum[i];
+                            hxh_c_sum_array[n][i] = hxh_c_sum[i];
                         }
-                        hxhh_c_sum_array[RL_timer] = hxhh_c_sum;
-                        pi[i] = exp(-(action_array[i]-mean_array[i])*(action_array[i]-mean_array[i])/(2.0f*deviation_array[i]*deviation_array[i]))/(sqrt(2.0f*PI)*deviation_array[i]);
+                        hxhh_c_sum_array[n] = hxhh_c_sum;
+                        
+                        pi[n] = exp(-(action_array[n]-mean_array[n])*(action_array[n]-mean_array[n])/(2.0f*deviation_array[n]*deviation_array[n]))/(sqrt(2.0f*PI)*deviation_array[n]);
                         Actor_Network_Old(temp_array);
-                        pi_old[i] = exp(-(action_array[i]-mean_old)*(action_array[i]-mean_old)/(2.0f*deviation_old*deviation_old))/(sqrt(2.0f*PI)*deviation_old);
-                        r[i] = exp(-0.25f * state_array[i][1] * state_array[i][1]);
-                        if(i == batch_size-1) return_G[i] = 0.0f;
-                        else return_G[i] = gamma * return_G[i+1] + r[i];
-                        if(i == batch_size-1) td_target[i] = r[i];
-                        else td_target[i] = r[i] + gamma * V[i+1];
-                        delta[i] = td_target[i] - V[i];
-                        if(i == batch_size-1) advantage[i] = 0.0f;
-                        else advantage[i] = gamma * lmbda * advantage[i+1] + delta[i];
-                        ratio[i] = pi[i]/pi_old[i];
-                        surr1[i] = ratio[i] * advantage[i];
-                        if (ratio[i] > 1.0f + epsilon) {
-                            surr2[i] = (1.0f + epsilon)*advantage[i];
-                        } else if( ratio[i] < 1.0f - epsilon) {
-                            surr2[i] = (1.0f - epsilon)*advantage[i];
+                        pi_old[n] = exp(-(action_array[n]-mean_old)*(action_array[n]-mean_old)/(2.0f*deviation_old*deviation_old))/(sqrt(2.0f*PI)*deviation_old);
+                        r[n] = exp(-0.25f * state_array[n][1] * state_array[n][1]);
+                        if(n == batch_size-1) return_G[n] = 0.0f;
+                        else return_G[n] = gamma * return_G[n+1] + r[n];
+                        if(n == batch_size-1) td_target[n] = r[n];
+                        else td_target[n] = r[n] + gamma * V[n+1];
+                        delta[n] = td_target[n] - V[n];
+                        if(n == batch_size-1) advantage[n] = 0.0f;
+                        else advantage[n] = gamma * lmbda * advantage[n+1] + delta[n];
+                        ratio[n] = pi[n]/pi_old[n];
+                        surr1[n] = ratio[n] * advantage[n];
+                        if (ratio[n] > 1.0f + epsilon) {
+                            surr2[n] = (1.0f + epsilon)*advantage[n];
+                        } else if( ratio[n] < 1.0f - epsilon) {
+                            surr2[n] = (1.0f - epsilon)*advantage[n];
                         } else {
-                            surr2[i] = ratio[i]*advantage[i];
+                            surr2[n] = ratio[n]*advantage[n];
                         }
-                        loss[i] = -min(surr1[i], surr2[i]);
-                        loss_sum = loss_sum + loss[i];
+                        loss[n] = -min(surr1[n], surr2[n]);
+                        loss_sum = loss_sum + loss[n];
                     }
                     reward_sum = 0.0f;
                     for (int i=0; i<batch_size; i++) {
@@ -2687,11 +2723,8 @@
                     hxhh_a_sum_array[RL_timer][1] = hxhh_a_sum[1];
                     mean_array[RL_timer] = mean;
                     deviation_array[RL_timer] = deviation;
-                    mean_before_SP_array[RL_timer] = mean_before_SP;
-                    deviation_before_SP_array[RL_timer] = deviation_before_SP;
                     action_array[RL_timer] = rand_normal(mean_array[RL_timer], deviation_array[RL_timer]);
 
-
                     virt_pos = virt_pos + (action_array[RL_timer] - 3.0f) * 1000.0f * 0.0002f;
                     if (virt_pos > 70.0f ) {
                         virt_pos = 70.0f;