20210305

Dependencies:   mbed FastPWM

Revision:
176:589ea3edcf3c
Parent:
175:2f7289dbd488
Child:
177:8e9cf31d63f4
--- a/main.cpp	Tue Nov 24 08:15:29 2020 +0000
+++ b/main.cpp	Tue Nov 24 10:16:10 2020 +0000
@@ -1,4 +1,4 @@
-//201124_3
+//201124_4
 #include "mbed.h"
 #include "FastPWM.h"
 #include "INIT_HW.h"
@@ -610,10 +610,10 @@
 
     float G_ha1[num_input_RL][num_hidden_unit1] = {0.0f};
     float G_ba1[num_hidden_unit1] = {0.0f};
-    float d_x_d_ha1[num_input_RL][num_hidden_unit1] = {0.0f};
-    float d_x_d_ba1[num_hidden_unit1] = {0.0f};
-    float d_y_d_ha1[num_input_RL][num_hidden_unit1] = {0.0f};
-    float d_y_d_ba1[num_hidden_unit1] = {0.0f};
+    float d_x_d_ha1[batch_size][num_input_RL][num_hidden_unit1] = {0.0f};
+    float d_x_d_ba1[batch_size][num_hidden_unit1] = {0.0f};
+    float d_y_d_ha1[batch_size][num_input_RL][num_hidden_unit1] = {0.0f};
+    float d_y_d_ba1[batch_size][num_hidden_unit1] = {0.0f};
 
     for (int index2 = 0; index2 < num_hidden_unit1; index2++) {
         for (int index1 = 0; index1 < num_input_RL; index1++) {
@@ -624,20 +624,20 @@
                     for(int k=0; k<num_hidden_unit2; k++) {
                         if (hxh_a_sum_array[n][k] >= 0) {
                             if (hx_a_sum_array[n][index2] > 0) {
-                                d_x_d_ha1[index1][index2] = d_x_d_ha1[index1][index2] + arr[n][index1]*ha2_temp[index2][k]*ha3_temp[k][0];
-                                d_y_d_ha1[index1][index2] = d_y_d_ha1[index1][index2] + arr[n][index1]*ha2_temp[index2][k]*ha3_temp[k][1];
+                                d_x_d_ha1[n][index1][index2] = d_x_d_ha1[n][index1][index2] + arr[n][index1]*ha2_temp[index2][k]*ha3_temp[k][0];
+                                d_y_d_ha1[n][index1][index2] = d_y_d_ha1[n][index1][index2] + arr[n][index1]*ha2_temp[index2][k]*ha3_temp[k][1];
                             }
                         }
                     }
                     float d_mean_d_ha1 = 0.0f;
                     float d_dev_d_ha1 = 0.0f;
-                    d_mean_d_ha1 = exp(hxhh_a_sum_array[n][0])/(1.0f+exp(hxhh_a_sum_array[n][0]))*d_x_d_ha1[index1][index2];
-                    d_dev_d_ha1 = exp(hxhh_a_sum_array[n][1])/(1.0f+exp(hxhh_a_sum_array[n][1]))*d_y_d_ha1[index1][index2];
+                    d_mean_d_ha1 = exp(hxhh_a_sum_array[n][0])/(1.0f+exp(hxhh_a_sum_array[n][0]))*d_x_d_ha1[n][index1][index2];
+                    d_dev_d_ha1 = exp(hxhh_a_sum_array[n][1])/(1.0f+exp(hxhh_a_sum_array[n][1]))*d_y_d_ha1[n][index1][index2];
 
                     G_ha1[index1][index2] = G_ha1[index1][index2] + advantage[n]/pi_old[n]*(d_mean_d_ha1*Grad_Normal_Dist_Mean(mean_array[n],deviation_array[n],action_array[n])+d_dev_d_ha1*Grad_Normal_Dist_Deviation(mean_array[n],deviation_array[n],action_array[n]));
                 }
             }
-            G_ha1[index1][index2] = G_ha1[index1][index2] / batch_size;
+            G_ha1[index1][index2] = -G_ha1[index1][index2] / batch_size;
             //ha1_temp[index1][index2] = ha1_temp[index1][index2] - gradient_rate * G_ha1[index1][index2];
         }
 
@@ -648,29 +648,29 @@
                 for(int k=0; k<num_hidden_unit2; k++) {
                     if (hxh_a_sum_array[n][k] >= 0) {
                         if (hx_a_sum_array[n][index2] > 0) {
-                            d_x_d_ba1[index2] = d_x_d_ba1[index2] + ha2_temp[index2][k]*ha3_temp[k][0];
-                            d_y_d_ba1[index2] = d_y_d_ba1[index2] + ha2_temp[index2][k]*ha3_temp[k][1];
+                            d_x_d_ba1[n][index2] = d_x_d_ba1[n][index2] + ha2_temp[index2][k]*ha3_temp[k][0];
+                            d_y_d_ba1[n][index2] = d_y_d_ba1[n][index2] + ha2_temp[index2][k]*ha3_temp[k][1];
                         }
                     }
                 }
                 float d_mean_d_ba1 = 0.0f;
                 float d_dev_d_ba1 = 0.0f;
-                d_mean_d_ba1 = exp(hxhh_a_sum_array[n][0])/(1.0f+exp(hxhh_a_sum_array[n][0]))*d_x_d_ba1[index2];
-                d_dev_d_ba1 = exp(hxhh_a_sum_array[n][1])/(1.0f+exp(hxhh_a_sum_array[n][1]))*d_y_d_ba1[index2];
+                d_mean_d_ba1 = exp(hxhh_a_sum_array[n][0])/(1.0f+exp(hxhh_a_sum_array[n][0]))*d_x_d_ba1[n][index2];
+                d_dev_d_ba1 = exp(hxhh_a_sum_array[n][1])/(1.0f+exp(hxhh_a_sum_array[n][1]))*d_y_d_ba1[n][index2];
 
                 G_ba1[index2] = G_ba1[index2] + advantage[n]/pi_old[n]*(d_mean_d_ba1*Grad_Normal_Dist_Mean(mean_array[n],deviation_array[n],action_array[n])+d_dev_d_ba1*Grad_Normal_Dist_Deviation(mean_array[n],deviation_array[n],action_array[n]));
             }
         }
-        G_ba1[index2] = G_ba1[index2] / batch_size;
+        G_ba1[index2] = -G_ba1[index2] / batch_size;
         //ba1_temp[index2] = ba1_temp[index2] - gradient_rate * G_ba1[index2];
     }
 
     float G_ha2[num_hidden_unit1][num_hidden_unit2] = {0.0f};
     float G_ba2[num_hidden_unit2] = {0.0f};
-    float d_x_d_ha2[num_hidden_unit1][num_hidden_unit2] = {0.0f};
-    float d_x_d_ba2[num_hidden_unit2] = {0.0f};
-    float d_y_d_ha2[num_hidden_unit1][num_hidden_unit2] = {0.0f};
-    float d_y_d_ba2[num_hidden_unit2] = {0.0f};
+    float d_x_d_ha2[batch_size][num_hidden_unit1][num_hidden_unit2] = {0.0f};
+    float d_x_d_ba2[batch_size][num_hidden_unit2] = {0.0f};
+    float d_y_d_ha2[batch_size][num_hidden_unit1][num_hidden_unit2] = {0.0f};
+    float d_y_d_ba2[batch_size][num_hidden_unit2] = {0.0f};
 
     for (int index2 = 0; index2 < num_hidden_unit2; index2++) {
         for (int index1 = 0; index1 < num_hidden_unit1; index1++) {
@@ -678,23 +678,22 @@
                 if((advantage[n] >= 0.0f && ratio[n] >= 1.0f + epsilon) || (advantage[n] < 0.0f && ratio[n] < 1.0f - epsilon)) {
                     G_ha2[index1][index2] = G_ha2[index1][index2];
                 } else {
-
                     if (hxh_a_sum_array[n][index2] >= 0) {
                         if (hx_a_sum_array[n][index1] > 0) {
-                            d_x_d_ha2[index1][index2] = d_x_d_ha2[index1][index2] + hx_a_sum_array[n][index1]*ha3_temp[index2][0];
-                            d_y_d_ha2[index1][index2] = d_y_d_ha2[index1][index2] + hx_a_sum_array[n][index1]*ha3_temp[index2][1];
+                            d_x_d_ha2[n][index1][index2] = hx_a_sum_array[n][index1]*ha3_temp[index2][0];
+                            d_y_d_ha2[n][index1][index2] = hx_a_sum_array[n][index1]*ha3_temp[index2][1];
                         }
                     }
 
                     float d_mean_d_ha2 = 0.0f;
                     float d_dev_d_ha2 = 0.0f;
-                    d_mean_d_ha2 = exp(hxhh_a_sum_array[n][0])/(1.0f+exp(hxhh_a_sum_array[n][0]))*d_x_d_ha2[index1][index2];
-                    d_dev_d_ha2 = exp(hxhh_a_sum_array[n][1])/(1.0f+exp(hxhh_a_sum_array[n][1]))*d_y_d_ha2[index1][index2];
+                    d_mean_d_ha2 = exp(hxhh_a_sum_array[n][0])/(1.0f+exp(hxhh_a_sum_array[n][0]))*d_x_d_ha2[n][index1][index2];
+                    d_dev_d_ha2 = exp(hxhh_a_sum_array[n][1])/(1.0f+exp(hxhh_a_sum_array[n][1]))*d_y_d_ha2[n][index1][index2];
 
                     G_ha2[index1][index2] = G_ha2[index1][index2] + advantage[n]/pi_old[n]*(d_mean_d_ha2*Grad_Normal_Dist_Mean(mean_array[n],deviation_array[n],action_array[n])+d_dev_d_ha2*Grad_Normal_Dist_Deviation(mean_array[n],deviation_array[n],action_array[n]));
                 }
             }
-            G_ha2[index1][index2] = G_ha2[index1][index2] / batch_size;
+            G_ha2[index1][index2] = -G_ha2[index1][index2] / batch_size;
             //ha2_temp[index1][index2] = ha2_temp[index1][index2] - gradient_rate * G_ha2[index1][index2];
         }
 
@@ -704,27 +703,27 @@
             } else {
 
                 if (hxh_a_sum_array[n][index2] >= 0) {
-                    d_x_d_ba2[index2] = d_x_d_ba2[index2] + ha3_temp[index2][0];
-                    d_y_d_ba2[index2] = d_y_d_ba2[index2] + ha3_temp[index2][1];
+                    d_x_d_ba2[n][index2] = ha3_temp[index2][0];
+                    d_y_d_ba2[n][index2] = ha3_temp[index2][1];
                 }
                 float d_mean_d_ba2= 0.0f;
                 float d_dev_d_ba2= 0.0f;
-                d_mean_d_ba2 = exp(hxhh_a_sum_array[n][0])/(1.0f+exp(hxhh_a_sum_array[n][0]))*d_x_d_ba2[index2];
-                d_dev_d_ba2 = exp(hxhh_a_sum_array[n][1])/(1.0f+exp(hxhh_a_sum_array[n][1]))*d_y_d_ba2[index2];
+                d_mean_d_ba2 = exp(hxhh_a_sum_array[n][0])/(1.0f+exp(hxhh_a_sum_array[n][0]))*d_x_d_ba2[n][index2];
+                d_dev_d_ba2 = exp(hxhh_a_sum_array[n][1])/(1.0f+exp(hxhh_a_sum_array[n][1]))*d_y_d_ba2[n][index2];
 
                 G_ba2[index2] = G_ba2[index2] + advantage[n]/pi_old[n]*(d_mean_d_ba2*Grad_Normal_Dist_Mean(mean_array[n],deviation_array[n],action_array[n])+d_dev_d_ba2*Grad_Normal_Dist_Deviation(mean_array[n],deviation_array[n],action_array[n]));
             }
         }
-        G_ba2[index2] = G_ba2[index2] / batch_size;
+        G_ba2[index2] = -G_ba2[index2] / batch_size;
         //ba2_temp[index2] = ba2_temp[index2] - gradient_rate * G_ba2[index2];
     }
 
     float G_ha3[num_hidden_unit2][2] = {0.0f};
     float G_ba3[2] = {0.0f};
-    float d_x_d_ha3[num_hidden_unit2][2] = {0.0f};
-    float d_x_d_ba3[2] = {0.0f};
-    float d_y_d_ha3[num_hidden_unit2][2] = {0.0f};
-    float d_y_d_ba3[2] = {0.0f};
+    float d_x_d_ha3[batch_size][num_hidden_unit2][2] = {0.0f};
+    float d_x_d_ba3[batch_size][2] = {0.0f};
+    float d_y_d_ha3[batch_size][num_hidden_unit2][2] = {0.0f};
+    float d_y_d_ba3[batch_size][2] = {0.0f};
 
     for (int index2 = 0; index2 < 2; index2++) {
         for (int index1 = 0; index1 < num_hidden_unit2; index1++) {
@@ -732,22 +731,21 @@
                 if((advantage[n] >= 0.0f && ratio[n] >= 1.0f + epsilon) || (advantage[n] < 0.0f && ratio[n] < 1.0f - epsilon)) {
                     G_ha3[index1][index2] = G_ha3[index1][index2];
                 } else {
-
                     if (hxh_a_sum_array[n][index1] >= 0) {
                         if (hx_a_sum_array[n][index1] > 0) {
-                            d_x_d_ha3[index1][index2] = d_x_d_ha3[index1][index2] + hxh_a_sum_array[n][index1];
-                            d_y_d_ha3[index1][index2] = d_y_d_ha3[index1][index2] + hxh_a_sum_array[n][index1];
+                            d_x_d_ha3[n][index1][index2] = hxh_a_sum_array[n][index1];
+                            d_y_d_ha3[n][index1][index2] = hxh_a_sum_array[n][index1];
                         }
                     }
                     float d_mean_d_ha3 = 0.0f;
                     float d_dev_d_ha3 = 0.0f;
-                    d_mean_d_ha3 = exp(hxhh_a_sum_array[n][0])/(1.0f+exp(hxhh_a_sum_array[n][0]))*d_x_d_ha3[index1][index2];
-                    d_dev_d_ha3 = exp(hxhh_a_sum_array[n][1])/(1.0f+exp(hxhh_a_sum_array[n][1]))*d_y_d_ha3[index1][index2];
+                    d_mean_d_ha3 = exp(hxhh_a_sum_array[n][0])/(1.0f+exp(hxhh_a_sum_array[n][0]))*d_x_d_ha3[n][index1][index2];
+                    d_dev_d_ha3 = exp(hxhh_a_sum_array[n][1])/(1.0f+exp(hxhh_a_sum_array[n][1]))*d_y_d_ha3[n][index1][index2];
 
                     G_ha3[index1][index2] = G_ha3[index1][index2] + advantage[n]/pi_old[n]*(d_mean_d_ha3*Grad_Normal_Dist_Mean(mean_array[n],deviation_array[n],action_array[n])+d_dev_d_ha3*Grad_Normal_Dist_Deviation(mean_array[n],deviation_array[n],action_array[n]));
                 }
             }
-            G_ha3[index1][index2] = G_ha3[index1][index2] / batch_size;
+            G_ha3[index1][index2] = -G_ha3[index1][index2] / batch_size;
             //ha3_temp[index1][index2] = ha3_temp[index1][index2] - gradient_rate * G_ha3[index1][index2];
         }
 
@@ -756,18 +754,18 @@
                 G_ba3[index2] = G_ba3[index2];
             } else {
 
-                d_x_d_ba3[index2] = d_x_d_ba3[index2] + 1.0f;
-                d_y_d_ba3[index2] = d_y_d_ba3[index2] + 1.0f;
+                d_x_d_ba3[n][index2] = 1.0f;
+                d_y_d_ba3[n][index2] = 1.0f;
 
                 float d_mean_d_ba3= 0.0f;
                 float d_dev_d_ba3= 0.0f;
-                d_mean_d_ba3 = exp(hxhh_a_sum_array[n][0])/(1.0f+exp(hxhh_a_sum_array[n][0]))*d_x_d_ba3[index2];
-                d_dev_d_ba3 = exp(hxhh_a_sum_array[n][1])/(1.0f+exp(hxhh_a_sum_array[n][1]))*d_y_d_ba3[index2];
+                d_mean_d_ba3 = exp(hxhh_a_sum_array[n][0])/(1.0f+exp(hxhh_a_sum_array[n][0]))*d_x_d_ba3[n][index2];
+                d_dev_d_ba3 = exp(hxhh_a_sum_array[n][1])/(1.0f+exp(hxhh_a_sum_array[n][1]))*d_y_d_ba3[n][index2];
 
                 G_ba3[index2] = G_ba3[index2] + advantage[n]/pi_old[n]*(d_mean_d_ba3*Grad_Normal_Dist_Mean(mean_array[n],deviation_array[n],action_array[n])+d_dev_d_ba3*Grad_Normal_Dist_Deviation(mean_array[n],deviation_array[n],action_array[n]));
             }
         }
-        G_ba3[index2] = G_ba3[index2] / batch_size;
+        G_ba3[index2] = -G_ba3[index2] / batch_size;
         //ba3_temp[index2] = ba3_temp[index2] - gradient_rate * G_ba3[index2];
     }