NuMaker connection with AWS IoT thru MQTT/HTTPS

Dependencies:   MQTT

Revision:
41:b878d7cd7035
Parent:
36:1bec082ad582
Child:
44:2f9dc54e7f95
--- a/main.cpp	Thu Nov 12 16:24:00 2020 +0800
+++ b/main.cpp	Fri Mar 27 14:32:06 2020 +0800
@@ -14,8 +14,6 @@
 #define AWS_IOT_HTTPS_TEST      0
 
 #include "mbed.h"
-
-/* MyTLSSocket = Mbed TLS over TCPSocket */
 #include "MyTLSSocket.h"
 
 #if AWS_IOT_MQTT_TEST
@@ -238,13 +236,8 @@
      * @param[in] net_iface Network interface
      */
     AWS_IoT_MQTT_Test(const char * domain, const uint16_t port, NetworkInterface *net_iface) :
-        _domain(domain), _port(port) {
-        _tlssocket = new MyTLSSocket(net_iface, SSL_CA_CERT_PEM, SSL_USER_CERT_PEM, SSL_USER_PRIV_KEY_PEM);
-        /* Blocking mode */
-        _tlssocket->set_blocking(true);
-        /* Print Mbed TLS handshake log */
-        _tlssocket->set_debug(true);
-
+        _domain(domain), _port(port), _net_iface(net_iface) {
+        _tlssocket = new MyTLSSocket;
         _mqtt_client = new MQTT::Client<MyTLSSocket, Countdown, MAX_MQTT_PACKET_SIZE>(*_tlssocket);
     }
 
@@ -266,21 +259,61 @@
 
         int tls_rc;
         int mqtt_rc;
-        
+
         do {
+            /* Set host name of the remote host, used for certificate checking */
+            _tlssocket->set_hostname(_domain);
+
+            /* Set the certification of Root CA */
+            tls_rc = _tlssocket->set_root_ca_cert(SSL_CA_CERT_PEM);
+            if (tls_rc != NSAPI_ERROR_OK) {
+                printf("TLSSocket::set_root_ca_cert(...) returned %d\n", tls_rc);
+                break;
+            }
+
+            /* Set client certificate and client private key */
+            tls_rc = _tlssocket->set_client_cert_key(SSL_USER_CERT_PEM, SSL_USER_PRIV_KEY_PEM);
+            if (tls_rc != NSAPI_ERROR_OK) {
+                printf("TLSSocket::set_client_cert_key(...) returned %d\n", tls_rc);
+                break;
+            }
+
+            /* Blocking mode */
+            _tlssocket->set_blocking(true);
+
+            /* Open a network socket on the network stack of the given network interface */
+            printf("Opening network socket on network stack\n");
+            tls_rc = _tlssocket->open(_net_iface);
+            if (tls_rc != NSAPI_ERROR_OK) {
+                printf("Opens network socket on network stack failed: %d\n", tls_rc);
+                break;
+            }
+            printf("Opens network socket on network stack OK\n");
+
+            /* DNS resolution */
+            printf("DNS resolution for %s...\n", _domain);
+            SocketAddress sockaddr;
+            tls_rc = _net_iface->gethostbyname(_domain, &sockaddr);
+            if (tls_rc != NSAPI_ERROR_OK) {
+                printf("DNS resolution for %s failed with %d\n", _domain, tls_rc);
+                break;
+            }
+            sockaddr.set_port(_port);
+            printf("DNS resolution for %s: %s:%d\n", _domain, sockaddr.get_ip_address(), sockaddr.get_port());
+
             /* Connect to the server */
             /* Initialize TLS-related stuff */
             printf("Connecting with %s:%d\n", _domain, _port);
-            tls_rc = _tlssocket->connect(_domain, _port);
+            tls_rc = _tlssocket->connect(sockaddr);
             if (tls_rc != NSAPI_ERROR_OK) {
                 printf("Connects with %s:%d failed: %d\n", _domain, _port, tls_rc);
                 break;
             }
             printf("Connects with %s:%d OK\n", _domain, _port);
-            
+
             /* See the link below for AWS IoT support for MQTT:
              * http://docs.aws.amazon.com/iot/latest/developerguide/protocols.html */
-         
+
             /* MQTT connect */
             /* The message broker does not support persistent sessions (connections made with 
              * the cleanSession flag set to false. */
@@ -306,9 +339,9 @@
             conn_data.cleansession = 1;
             //conn_data.username.cstring = "USERNAME";
             //conn_data.password.cstring = "PASSWORD";
-        
+
             MQTT::connackData connack_data;
-        
+
             /* _tlssocket must connect to the network endpoint before calling this. */
             printf("MQTT connecting");
             if ((mqtt_rc = _mqtt_client->connect(conn_data, connack_data)) != 0) {
@@ -316,57 +349,56 @@
                 break;
             }
             printf("\rMQTT connects OK\n\n");
-            
+
             /* Subscribe/publish user topic */
             printf("Subscribing/publishing user topic\n");
             if (! sub_pub_topic(USER_MQTT_TOPIC, USER_MQTT_TOPIC_FILTERS, sizeof (USER_MQTT_TOPIC_FILTERS) / sizeof (USER_MQTT_TOPIC_FILTERS[0]), USER_MQTT_TOPIC_PUBLISH_MESSAGE)) {
                 break;
             }
             printf("Subscribes/publishes user topic OK\n\n");
-            
+
             /* Subscribe/publish UpdateThingShadow topic */
             printf("Subscribing/publishing UpdateThingShadow topic\n");
             if (! sub_pub_topic(UPDATETHINGSHADOW_MQTT_TOPIC, UPDATETHINGSHADOW_MQTT_TOPIC_FILTERS, sizeof (UPDATETHINGSHADOW_MQTT_TOPIC_FILTERS) / sizeof (UPDATETHINGSHADOW_MQTT_TOPIC_FILTERS[0]), UPDATETHINGSHADOW_MQTT_TOPIC_PUBLISH_MESSAGE)) {
                 break;
             }
             printf("Subscribes/publishes UpdateThingShadow topic OK\n\n");
-            
+
             /* Subscribe/publish GetThingShadow topic */
             printf("Subscribing/publishing GetThingShadow topic\n");
             if (! sub_pub_topic(GETTHINGSHADOW_MQTT_TOPIC, GETTHINGSHADOW_MQTT_TOPIC_FILTERS, sizeof (GETTHINGSHADOW_MQTT_TOPIC_FILTERS) / sizeof (GETTHINGSHADOW_MQTT_TOPIC_FILTERS[0]), GETTHINGSHADOW_MQTT_TOPIC_PUBLISH_MESSAGE)) {
                 break;
             }
             printf("Subscribes/publishes GetThingShadow topic OK\n\n");
-            
+
             /* Subscribe/publish DeleteThingShadow topic */
             printf("Subscribing/publishing DeleteThingShadow topic\n");
             if (! sub_pub_topic(DELETETHINGSHADOW_MQTT_TOPIC, DELETETHINGSHADOW_MQTT_TOPIC_FILTERS, sizeof (DELETETHINGSHADOW_MQTT_TOPIC_FILTERS) / sizeof (DELETETHINGSHADOW_MQTT_TOPIC_FILTERS[0]), DELETETHINGSHADOW_MQTT_TOPIC_PUBLISH_MESSAGE)) {
                 break;
             }
             printf("Subscribes/publishes DeleteThingShadow topic OK\n\n");
-            
+
         } while (0);
-        
+
         printf("MQTT disconnecting");
         if ((mqtt_rc = _mqtt_client->disconnect()) != 0) {
             printf("\rMQTT disconnects failed %d\n\n", mqtt_rc);
         }
         printf("\rMQTT disconnects OK\n\n");
-        
+
         _tlssocket->close();
     }
 
-    
 protected:
 
     /**
      * @brief   Subscribe/publish specific topic
      */
     bool sub_pub_topic(const char *topic, const char **topic_filters, size_t topic_filters_size, const char *publish_message_body) {
-        
+
         bool ret = false;
         int mqtt_rc;
-        
+
         do {
             const char **topic_filter;
             const char **topic_filter_end = topic_filters + topic_filters_size;
@@ -388,7 +420,7 @@
             MQTT::Message message;
 
             int _bpos;
-        
+
             _bpos = snprintf(_buffer, sizeof (_buffer) - 1, publish_message_body);
             if (_bpos < 0 || ((size_t) _bpos) > (sizeof (_buffer) - 1)) {
                 printf("snprintf failed: %d\n", _bpos);
@@ -411,7 +443,7 @@
                 break;
             }
             printf("\rMQTT publishes message to %s OK\n", topic);
-        
+
             /* Receive message with subscribed topic */
             printf("MQTT receives message with subscribed %s...\n", topic);
             Timer timer;
@@ -441,12 +473,12 @@
             }
 
             ret = true;
-        
+
         } while (0);
-        
+
         return ret;
     }
-    
+
 protected:
     MyTLSSocket *                                                           _tlssocket;
     MQTT::Client<MyTLSSocket, Countdown, MAX_MQTT_PACKET_SIZE> *            _mqtt_client;
@@ -454,7 +486,8 @@
     const char *_domain;                    /**< Domain name of the MQTT server */
     const uint16_t _port;                   /**< Port number of the MQTT server */
     char _buffer[MQTT_USER_BUFFER_SIZE];    /**< User buffer */
-    
+    NetworkInterface *_net_iface;
+
 private:
     static volatile uint16_t   _message_arrive_count;
 
@@ -465,7 +498,7 @@
         printf("%.*s\n", message.payloadlen, (char*)message.payload);
         ++ _message_arrive_count;
     }
-    
+
     static void clear_message_arrive_count() {
         _message_arrive_count = 0;
     }
@@ -493,13 +526,8 @@
      * @param[in] net_iface Network interface
      */
     AWS_IoT_HTTPS_Test(const char * domain, const uint16_t port, NetworkInterface *net_iface) :
-            _domain(domain), _port(port) {
-
-        _tlssocket = new MyTLSSocket(net_iface, SSL_CA_CERT_PEM, SSL_USER_CERT_PEM, SSL_USER_PRIV_KEY_PEM);
-        /* Non-blocking mode */
-        _tlssocket->set_blocking(false);
-        /* Print Mbed TLS handshake log */
-        _tlssocket->set_debug(true);
+        _domain(domain), _port(port), _net_iface(net_iface) {
+        _tlssocket = new MyTLSSocket;
     }
     /**
      * @brief AWS_IoT_HTTPS_Test Destructor
@@ -515,19 +543,62 @@
      * @param[in] path  The path of the file to fetch from the HTTPS server
      */
     void start_test() {
-        
+
         int tls_rc;
-         
+
         do {
+            /* Set host name of the remote host, used for certificate checking */
+            _tlssocket->set_hostname(_domain);
+
+            /* Set the certification of Root CA */
+            tls_rc = _tlssocket->set_root_ca_cert(SSL_CA_CERT_PEM);
+            if (tls_rc != NSAPI_ERROR_OK) {
+                printf("TLSSocket::set_root_ca_cert(...) returned %d\n", tls_rc);
+                break;
+            }
+
+            /* Set client certificate and client private key */
+            tls_rc = _tlssocket->set_client_cert_key(SSL_USER_CERT_PEM, SSL_USER_PRIV_KEY_PEM);
+            if (tls_rc != NSAPI_ERROR_OK) {
+                printf("TLSSocket::set_client_cert_key(...) returned %d\n", tls_rc);
+                break;
+            }
+
+            /* Open a network socket on the network stack of the given network interface */
+            printf("Opening network socket on network stack\n");
+            tls_rc = _tlssocket->open(_net_iface);
+            if (tls_rc != NSAPI_ERROR_OK) {
+                printf("Opens network socket on network stack failed: %d\n", tls_rc);
+                break;
+            }
+            printf("Opens network socket on network stack OK\n");
+
+            /* DNS resolution */
+            printf("DNS resolution for %s...\n", _domain);
+            SocketAddress sockaddr;
+            tls_rc = _net_iface->gethostbyname(_domain, &sockaddr);
+            if (tls_rc != NSAPI_ERROR_OK) {
+                printf("DNS resolution for %s failed with %d\n", _domain, tls_rc);
+                break;
+            }
+            sockaddr.set_port(_port);
+            printf("DNS resolution for %s: %s:%d\n", _domain, sockaddr.get_ip_address(), sockaddr.get_port());
+
             /* Connect to the server */
             /* Initialize TLS-related stuff */
             printf("Connecting with %s:%d\n", _domain, _port);
-            tls_rc = _tlssocket->connect(_domain, _port);
+            tls_rc = _tlssocket->connect(sockaddr);
             if (tls_rc != NSAPI_ERROR_OK) {
                 printf("Connects with %s:%d failed: %d\n", _domain, _port, tls_rc);
                 break;
             }
-            printf("Connects with %s:%d OK\n\n", _domain, _port);
+            printf("Connects with %s:%d OK\n", _domain, _port);
+
+            /* Non-blocking mode
+             *
+             * Don't change to non-blocking mode before connect; otherwise, we may meet NSAPI_ERROR_IN_PROGRESS.
+             */
+            _tlssocket->set_blocking(false);
 
             /* Publish to user topic through HTTPS/POST */
             printf("Publishing to user topic through HTTPS/POST\n");
@@ -535,42 +606,42 @@
                 break;
             }
             printf("Publishes to user topic through HTTPS/POST OK\n\n");
-        
+
             /* Update thing shadow by publishing to UpdateThingShadow topic through HTTPS/POST */
             printf("Updating thing shadow by publishing to Update Thing Shadow topic through HTTPS/POST\n");
             if (! run_req_resp(UPDATETHINGSHADOW_TOPIC_HTTPS_PATH, UPDATETHINGSHADOW_TOPIC_HTTPS_REQUEST_METHOD, UPDATETHINGSHADOW_TOPIC_HTTPS_REQUEST_MESSAGE_BODY)) {
                 break;
             }
             printf("Update thing shadow by publishing to Update Thing Shadow topic through HTTPS/POST OK\n\n");
-            
+
             /* Get thing shadow by publishing to GetThingShadow topic through HTTPS/POST */
             printf("Getting thing shadow by publishing to GetThingShadow topic through HTTPS/POST\n");
             if (! run_req_resp(GETTHINGSHADOW_TOPIC_HTTPS_PATH, GETTHINGSHADOW_TOPIC_HTTPS_REQUEST_METHOD, GETTHINGSHADOW_TOPIC_HTTPS_REQUEST_MESSAGE_BODY)) {
                 break;
             }
             printf("Get thing shadow by publishing to GetThingShadow topic through HTTPS/POST OK\n\n");
-            
+
             /* Delete thing shadow by publishing to DeleteThingShadow topic through HTTPS/POST */
             printf("Deleting thing shadow by publishing to DeleteThingShadow topic through HTTPS/POST\n");
             if (! run_req_resp(DELETETHINGSHADOW_TOPIC_HTTPS_PATH, DELETETHINGSHADOW_TOPIC_HTTPS_REQUEST_METHOD, DELETETHINGSHADOW_TOPIC_HTTPS_REQUEST_MESSAGE_BODY)) {
                 break;
             }
             printf("Delete thing shadow by publishing to DeleteThingShadow topic through HTTPS/POST OK\n\n");
-            
+
             /* Update thing shadow RESTfully through HTTPS/POST */
             printf("Updating thing shadow RESTfully through HTTPS/POST\n");
             if (! run_req_resp(UPDATETHINGSHADOW_THING_HTTPS_PATH, UPDATETHINGSHADOW_THING_HTTPS_REQUEST_METHOD, UPDATETHINGSHADOW_THING_HTTPS_REQUEST_MESSAGE_BODY)) {
                 break;
             }
             printf("Update thing shadow RESTfully through HTTPS/POST OK\n\n");
-            
+
             /* Get thing shadow RESTfully through HTTPS/GET */
             printf("Getting thing shadow RESTfully through HTTPS/GET\n");
             if (! run_req_resp(GETTHINGSHADOW_THING_HTTPS_PATH, GETTHINGSHADOW_THING_HTTPS_REQUEST_METHOD, GETTHINGSHADOW_THING_HTTPS_REQUEST_MESSAGE_BODY)) {
                 break;
             }
             printf("Get thing shadow RESTfully through HTTPS/GET OK\n\n");
-            
+
             /* Delete thing shadow RESTfully through HTTPS/DELETE */
             printf("Deleting thing shadow RESTfully through HTTPS/DELETE\n");
             if (! run_req_resp(DELETETHINGSHADOW_THING_HTTPS_PATH, DELETETHINGSHADOW_THING_HTTPS_REQUEST_METHOD, DELETETHINGSHADOW_THING_HTTPS_REQUEST_MESSAGE_BODY)) {
@@ -590,9 +661,9 @@
      * @brief   Run request/response through HTTPS
      */
     bool run_req_resp(const char *https_path, const char *https_request_method, const char *https_request_message_body) {
-        
+
         bool ret = false;
-        
+
         do {
             int tls_rc;
             bool _got200 = false;
@@ -611,7 +682,7 @@
             /* Print request message */
             printf("HTTPS: Request message:\n");
             printf("%s\n", _buffer);
-        
+
             int offset = 0;
             do {
                 tls_rc = _tlssocket->send((const unsigned char *) _buffer + offset, _bpos - offset);
@@ -619,8 +690,9 @@
                     offset += tls_rc;
                 }
             } while (offset < _bpos && 
-                    (tls_rc > 0 || tls_rc == MBEDTLS_ERR_SSL_WANT_READ || tls_rc == MBEDTLS_ERR_SSL_WANT_WRITE));
-            if (tls_rc < 0) {
+                    (tls_rc > 0 || tls_rc == NSAPI_ERROR_WOULD_BLOCK));
+            if (tls_rc < 0 &&
+                tls_rc != NSAPI_ERROR_WOULD_BLOCK) {
                 print_mbedtls_error("_tlssocket->send", tls_rc);
                 break;
             }
@@ -675,10 +747,9 @@
                     }
                 }
             } while ((offset_end == 0 || offset < offset_end) &&
-                    (tls_rc > 0 || tls_rc == MBEDTLS_ERR_SSL_WANT_READ || tls_rc == MBEDTLS_ERR_SSL_WANT_WRITE));
+                    (tls_rc > 0 || tls_rc == NSAPI_ERROR_WOULD_BLOCK));
             if (tls_rc < 0 && 
-                tls_rc != MBEDTLS_ERR_SSL_WANT_READ && 
-                tls_rc != MBEDTLS_ERR_SSL_WANT_WRITE) {
+                tls_rc != NSAPI_ERROR_WOULD_BLOCK) {
                 print_mbedtls_error("_tlssocket->read", tls_rc);
                 break;
             }
@@ -691,26 +762,27 @@
             printf("HTTPS: Received 200 OK status ... %s\n", _got200 ? "[OK]" : "[FAIL]");
             printf("HTTPS: Received message:\n");
             printf("%s\n", _buffer);
-        
+
             ret = true;
-            
+
         } while (0);
-        
+
         return ret;
     }
-     
+
 protected:
     MyTLSSocket *     _tlssocket;
 
     const char *_domain;                    /**< Domain name of the HTTPS server */
     const uint16_t _port;                   /**< Port number of the HTTPS server */
     char _buffer[HTTPS_USER_BUFFER_SIZE];   /**< User buffer */
+    NetworkInterface *_net_iface;
 };
 
 #endif  // End of AWS_IOT_HTTPS_TEST
 
 int main() {
-    
+
     /* The default 9600 bps is too slow to print full TLS debug info and could
      * cause the other party to time out. */
 
@@ -732,14 +804,20 @@
         printf("Connecting to the network failed %d!\n", status);
         return -1;
     }
-    printf("Connected to the network successfully. IP address: %s\n", net->get_ip_address());
-    
+    SocketAddress sockaddr;
+    status = net->get_ip_address(&sockaddr);
+    if (status != NSAPI_ERROR_OK) {
+        printf("Network interface get_ip_address(...) failed with %d", status);
+        return -1;
+    }
+    printf("Connected to the network successfully. IP address: %s\n", sockaddr.get_ip_address());
+
 #if AWS_IOT_MQTT_TEST
     AWS_IoT_MQTT_Test *mqtt_test = new AWS_IoT_MQTT_Test(AWS_IOT_MQTT_SERVER_NAME, AWS_IOT_MQTT_SERVER_PORT, net);
     mqtt_test->start_test();
     delete mqtt_test;
 #endif  // End of AWS_IOT_MQTT_TEST
-    
+
 #if AWS_IOT_HTTPS_TEST
     AWS_IoT_HTTPS_Test *https_test = new AWS_IoT_HTTPS_Test(AWS_IOT_HTTPS_SERVER_NAME, AWS_IOT_HTTPS_SERVER_PORT, net);
     https_test->start_test();