A Threaded Secure MQTT Client example. Uses MBED TLS for SSL/TLS connection. QoS0 only for now. Example has been tested with K64F connected via Ethernet.
Fork of HelloMQTT by
Diff: MQTTThreadedClient.cpp
- Revision:
- 25:326f00faa092
- Parent:
- 23:06fac173529e
- Child:
- 26:4b21de8043a5
--- a/MQTTThreadedClient.cpp Sun Mar 26 04:38:36 2017 +0000
+++ b/MQTTThreadedClient.cpp Mon Mar 27 03:53:18 2017 +0000
@@ -1,27 +1,280 @@
#include "mbed.h"
#include "rtos.h"
#include "MQTTThreadedClient.h"
-
+#include "mbedtls/platform.h"
+#include "mbedtls/ssl.h"
+#include "mbedtls/entropy.h"
+#include "mbedtls/ctr_drbg.h"
+#include "mbedtls/error.h"
static MemoryPool<PubMessage, 16> mpool;
static Queue<PubMessage, 16> mqueue;
+// SSL/TLS variables
+mbedtls_entropy_context _entropy;
+mbedtls_ctr_drbg_context _ctr_drbg;
+mbedtls_x509_crt _cacert;
+mbedtls_ssl_context _ssl;
+mbedtls_ssl_config _ssl_conf;
+mbedtls_ssl_session saved_session;
+
+/**
+ * Receive callback for mbed TLS
+ */
+static int ssl_recv(void *ctx, unsigned char *buf, size_t len)
+{
+ int recv = -1;
+ TCPSocket *socket = static_cast<TCPSocket *>(ctx);
+ socket->set_timeout(DEFAULT_SOCKET_TIMEOUT);
+ recv = socket->recv(buf, len);
+
+ if (NSAPI_ERROR_WOULD_BLOCK == recv) {
+ return MBEDTLS_ERR_SSL_WANT_READ;
+ } else if (recv < 0) {
+ return -1;
+ } else {
+ return recv;
+ }
+}
+
+/**
+ * Send callback for mbed TLS
+ */
+static int ssl_send(void *ctx, const unsigned char *buf, size_t len)
+{
+ int sent = -1;
+ TCPSocket *socket = static_cast<TCPSocket *>(ctx);
+ socket->set_timeout(DEFAULT_SOCKET_TIMEOUT);
+ sent = socket->send(buf, len);
+
+ if(NSAPI_ERROR_WOULD_BLOCK == sent) {
+ return MBEDTLS_ERR_SSL_WANT_WRITE;
+ } else if (sent < 0) {
+ return -1;
+ } else {
+ return sent;
+ }
+}
+
+#if DEBUG_LEVEL > 0
+/**
+ * Debug callback for mbed TLS
+ * Just prints on the USB serial port
+ */
+static void my_debug(void *ctx, int level, const char *file, int line,
+ const char *str)
+{
+ const char *p, *basename;
+ (void) ctx;
+
+ /* Extract basename from file */
+ for(p = basename = file; *p != '\0'; p++) {
+ if(*p == '/' || *p == '\\') {
+ basename = p + 1;
+ }
+ }
+
+ if (_debug) {
+ mbedtls_printf("%s:%04d: |%d| %s", basename, line, level, str);
+ }
+}
+
+/**
+ * Certificate verification callback for mbed TLS
+ * Here we only use it to display information on each cert in the chain
+ */
+static int my_verify(void *data, mbedtls_x509_crt *crt, int depth, uint32_t *flags)
+{
+ const uint32_t buf_size = 1024;
+ char *buf = new char[buf_size];
+ (void) data;
+
+ if (_debug) mbedtls_printf("\nVerifying certificate at depth %d:\n", depth);
+ mbedtls_x509_crt_info(buf, buf_size - 1, " ", crt);
+ if (_debug) mbedtls_printf("%s", buf);
+
+ if (*flags == 0)
+ if (_debug) mbedtls_printf("No verification issue for this certificate\n");
+ else {
+ mbedtls_x509_crt_verify_info(buf, buf_size, " ! ", *flags);
+ if (_debug) mbedtls_printf("%s\n", buf);
+ }
+
+ delete[] buf;
+ return 0;
+}
+#endif
+
+
+void MQTTThreadedClient::setupTLS()
+{
+ if (ssl_ca_pem != NULL)
+ {
+ mbedtls_entropy_init(&_entropy);
+ mbedtls_ctr_drbg_init(&_ctr_drbg);
+ mbedtls_x509_crt_init(&_cacert);
+ mbedtls_ssl_init(&_ssl);
+ mbedtls_ssl_config_init(&_ssl_conf);
+ memset( &saved_session, 0, sizeof( mbedtls_ssl_session ) );
+ }
+}
+
+void MQTTThreadedClient::freeTLS()
+{
+ if (ssl_ca_pem != NULL)
+ {
+ mbedtls_entropy_free(&_entropy);
+ mbedtls_ctr_drbg_free(&_ctr_drbg);
+ mbedtls_x509_crt_free(&_cacert);
+ mbedtls_ssl_free(&_ssl);
+ mbedtls_ssl_config_free(&_ssl_conf);
+ }
+}
+
+int MQTTThreadedClient::initTLS()
+{
+ int ret;
+
+ printf("Initializing TLS ...\r\n");
+ printf("mbedtls_ctr_drdbg_seed ...\r\n");
+ if ((ret = mbedtls_ctr_drbg_seed(&_ctr_drbg, mbedtls_entropy_func, &_entropy,
+ (const unsigned char *) DRBG_PERS,
+ sizeof (DRBG_PERS))) != 0) {
+ printf("error [%d] mbedtls_crt_drbg_init", ret);
+ _error = ret;
+ return -1;
+ }
+ printf("mbedtls_x509_crt_parse ...\r\n");
+ if ((ret = mbedtls_x509_crt_parse(&_cacert, (const unsigned char *) ssl_ca_pem,
+ strlen(ssl_ca_pem) + 1)) != 0) {
+ printf("error [%d] mbedtls_x509_crt_parse", ret);
+ _error = ret;
+ return -1;
+ }
+
+ printf("mbedtls_ssl_config_defaults ...\r\n");
+ if ((ret = mbedtls_ssl_config_defaults(&_ssl_conf,
+ MBEDTLS_SSL_IS_CLIENT,
+ MBEDTLS_SSL_TRANSPORT_STREAM,
+ MBEDTLS_SSL_PRESET_DEFAULT)) != 0) {
+ printf("error [%d] mbedtls_ssl_config_defaults", ret);
+ _error = ret;
+ return -1;
+ }
+
+ printf("mbedtls_ssl_config_ca_chain ...\r\n");
+ mbedtls_ssl_conf_ca_chain(&_ssl_conf, &_cacert, NULL);
+ printf("mbedtls_ssl_conf_rng ...\r\n");
+ mbedtls_ssl_conf_rng(&_ssl_conf, mbedtls_ctr_drbg_random, &_ctr_drbg);
+
+ /* It is possible to disable authentication by passing
+ * MBEDTLS_SSL_VERIFY_NONE in the call to mbedtls_ssl_conf_authmode()
+ */
+ printf("mbedtls_ssl_conf_authmode ...\r\n");
+ mbedtls_ssl_conf_authmode(&_ssl_conf, MBEDTLS_SSL_VERIFY_REQUIRED);
+
+#if DEBUG_LEVEL > 0
+ mbedtls_ssl_conf_verify(&_ssl_conf, my_verify, NULL);
+ mbedtls_ssl_conf_dbg(&_ssl_conf, my_debug, NULL);
+ mbedtls_debug_set_threshold(DEBUG_LEVEL);
+#endif
+
+ printf("mbedtls_ssl_setup ...\r\n");
+ if ((ret = mbedtls_ssl_setup(&_ssl, &_ssl_conf)) != 0) {
+ printf("error [%d] mbedtls_ssl_setup", ret);
+ _error = ret;
+ return -1;
+ }
+
+ return 0;
+}
+
+int MQTTThreadedClient::doTLSHandshake()
+{
+ int ret;
+
+ /* Start the handshake, the rest will be done in onReceive() */
+ printf("Starting the TLS handshake...\r\n");
+ ret = mbedtls_ssl_handshake(&_ssl);
+ if (ret < 0)
+ {
+ if (ret != MBEDTLS_ERR_SSL_WANT_READ &&
+ ret != MBEDTLS_ERR_SSL_WANT_WRITE)
+ {
+ printf("error [%d] mbedtls_ssl_handshake", ret);
+ tcpSocket->close();
+ _error = -1;
+ }
+ else
+ {
+ // do not close the socket if timed out
+ _error = ret;
+ }
+ return -1;
+ }
+
+ /* Handshake done, time to print info */
+ printf("TLS connection to %s:%d established\r\n",
+ host.c_str(), port);
+
+ const uint32_t buf_size = 1024;
+ char *buf = new char[buf_size];
+ mbedtls_x509_crt_info(buf, buf_size, "\r ",
+ mbedtls_ssl_get_peer_cert(&_ssl));
+
+ printf("Server certificate:\r\n%s\r", buf);
+ // Verify server cert ...
+ uint32_t flags = mbedtls_ssl_get_verify_result(&_ssl);
+ if( flags != 0 )
+ {
+ mbedtls_x509_crt_verify_info(buf, buf_size, "\r ! ", flags);
+ printf("Certificate verification failed:\r\n%s\r\r\n", buf);
+ // free server cert ... before error return
+ delete [] buf;
+ return -1;
+ }
+
+ printf("Certificate verification passed\r\n\r\n");
+ // delete server cert after verification
+ delete [] buf;
+
+ // TODO: Save the session here for reconnect.
+ if( ( ret = mbedtls_ssl_get_session( &_ssl, &saved_session ) ) != 0 )
+ {
+ printf( "mbedtls_ssl_get_session returned -0x%x\n\n", -ret );
+ return -1;
+ }
+
+ printf("Session saved for reconnect ...\r\n");
+ return 0;
+}
+
int MQTTThreadedClient::readBytesToBuffer(char * buffer, size_t size, int timeout)
{
int rc;
-
+
if (tcpSocket == NULL)
return -1;
-
- // non-blocking socket ...
- tcpSocket->set_timeout(timeout);
- rc = tcpSocket->recv( (void *) buffer, size);
-
- // return 0 bytes if timeout ...
- if (NSAPI_ERROR_WOULD_BLOCK == rc)
- return TIMEOUT;
- else
- return rc; // return the number of bytes received or error
+
+ if (ssl_ca_pem != NULL)
+ {
+ // Do SSL/TLS read
+ rc = mbedtls_ssl_read(&_ssl, (unsigned char *) buffer, size);
+ if (MBEDTLS_ERR_SSL_WANT_READ == rc)
+ return TIMEOUT;
+ else
+ return rc;
+ } else {
+ // non-blocking socket ...
+ tcpSocket->set_timeout(timeout);
+ rc = tcpSocket->recv( (void *) buffer, size);
+
+ // return 0 bytes if timeout ...
+ if (NSAPI_ERROR_WOULD_BLOCK == rc)
+ return TIMEOUT;
+ else
+ return rc; // return the number of bytes received or error
+ }
}
int MQTTThreadedClient::sendBytesFromBuffer(char * buffer, size_t size, int timeout)
@@ -31,14 +284,24 @@
if (tcpSocket == NULL)
return -1;
- // set the write timeout
- tcpSocket->set_timeout(timeout);
- rc = tcpSocket->send(buffer, size);
-
- if ( NSAPI_ERROR_WOULD_BLOCK == rc)
- return TIMEOUT;
- else
- return rc;
+ if (ssl_ca_pem != NULL) {
+ // Do SSL/TLS write
+ rc = mbedtls_ssl_write(&_ssl, (const unsigned char *) buffer, size);
+ if (MBEDTLS_ERR_SSL_WANT_WRITE == rc)
+ return TIMEOUT;
+ else
+ return rc;
+ } else {
+
+ // set the write timeout
+ tcpSocket->set_timeout(timeout);
+ rc = tcpSocket->send(buffer, size);
+
+ if ( NSAPI_ERROR_WOULD_BLOCK == rc)
+ return TIMEOUT;
+ else
+ return rc;
+ }
}
int MQTTThreadedClient::readPacketLength(int* value)
@@ -227,19 +490,51 @@
return rc;
}
-int MQTTThreadedClient::connect(const char * host, uint16_t port, MQTTPacket_connectData & options)
+
+void MQTTThreadedClient::disconnect()
+{
+ if (isConnected)
+ {
+ // TODO: Send unsubscribe message ...
+
+ isConnected = false;
+ tcpSocket->close();
+ }
+}
+
+int MQTTThreadedClient::connect(const char * chost, uint16_t cport, MQTTPacket_connectData & options)
{
- int ret;
-
+ int ret = FAILURE;
+
+ // Copy the settings for reconnection
+ host = chost;
+ port = cport;
+ connect_options = options;
+
tcpSocket->open(network);
- if (( ret = tcpSocket->connect(host, port)) < 0 )
+ if (ssl_ca_pem != NULL)
+ {
+ printf("mbedtls_ssl_set_hostname ...\r\n");
+ mbedtls_ssl_set_hostname(&_ssl, host.c_str());
+ printf("mbedtls_ssl_set_bio ...\r\n");
+ mbedtls_ssl_set_bio(&_ssl, static_cast<void *>(tcpSocket),
+ ssl_send, ssl_recv, NULL );
+ }
+
+ if (( ret = tcpSocket->connect(host.c_str(), port)) < 0 )
{
- printf("Error connecting to %s:%d with %d\r\n", host, port, ret);
+ printf("Error connecting to %s:%d with %d\r\n", host.c_str(), port, ret);
return ret;
}
-
- return connect(options);
+
+ if ((ssl_ca_pem != NULL) && (doTLSHandshake() < 0))
+ {
+ printf("TLS Handshake failed! \r\n");
+ return FAILURE;
+ }
+
+ return connect(connect_options);
}
int MQTTThreadedClient::publish(PubMessage& msg)
@@ -257,7 +552,7 @@
*message = msg;
// Push the data to the thread
- printf("Pushing data to consumer thread ...\r\n");
+ printf("[Thread:%d] Pushing data to consumer thread ...\r\n", Thread::gettid());
mqueue.put(message);
return SUCCESS;
@@ -269,7 +564,7 @@
if (!isConnected)
{
- printf("Not connected!!! ...\r\n");
+ printf("[Thread:%d] Not connected!!! ...\r\n", Thread::gettid());
return FAILURE;
}
@@ -278,20 +573,86 @@
topicString, (unsigned char*) &message.payload[0], (int) message.payloadlen);
if (len <= 0)
{
- printf("Failed serializing message ...\r\n");
+ printf("[Thread:%d]Failed serializing message ...\r\n", Thread::gettid());
return FAILURE;
}
if (sendPacket(len) == SUCCESS)
{
- printf("Successfully sent publish packet to server ...\r\n");
+ printf("[Thread:%d]Successfully sent publish packet to server ...\r\n", Thread::gettid());
return SUCCESS;
}
- printf("Failed to send publish packet to server ...\r\n");
+ printf("[Thread:%d]Failed to send publish packet to server ...\r\n", Thread::gettid());
return FAILURE;
}
+
+void MQTTThreadedClient::addTopicHandler(const char * topicstr, void (*function)(MessageData &))
+{
+ // Push the subscription into the map ...
+ FP<void,MessageData &> fp;
+ fp.attach(function);
+ topicCBMap_t.insert(std::pair<std::string, FP<void,MessageData &> >(std::string(topicstr),fp));
+}
+
+int MQTTThreadedClient::processSubscriptions()
+{
+ int numsubscribed = 0;
+
+ if (!isConnected)
+ {
+ printf("Session not connected!!\r\n");
+ return 0;
+ }
+
+ std::map<std::string, FP<void, MessageData &> >::iterator it;
+ for(it = topicCBMap_t.begin(); it != topicCBMap_t.end(); it++)
+ {
+ int rc = FAILURE;
+ int len = 0;
+ //TODO: We only subscribe to QoS = 0 for now
+ QoS qos = QOS0;
+
+ MQTTString topic = {(char*)it->first.c_str(), {0, 0}};
+ printf("Subscribing to topic [%s]\r\n", topic.cstring);
+
+
+ len = MQTTSerialize_subscribe(sendbuf, MAX_MQTT_PACKET_SIZE, 0, packetid.getNext(), 1, &topic, (int*)&qos);
+ if (len <= 0) {
+ printf("Error serializing subscribe packet ...\r\n");
+ continue;
+ }
+
+ if ((rc = sendPacket(len)) != SUCCESS) {
+ printf("Error sending subscribe packet [%d]\r\n", rc);
+ continue;
+ }
+
+ printf("Waiting for subscription ack ...\r\n");
+ // Wait for SUBACK, dropping packets read along the way ...
+ if (readUntil(SUBACK, COMMAND_TIMEOUT) == SUBACK) { // wait for suback
+ int count = 0, grantedQoS = -1;
+ unsigned short mypacketid;
+ if (MQTTDeserialize_suback(&mypacketid, 1, &count, &grantedQoS, readbuf, MAX_MQTT_PACKET_SIZE) == 1)
+ rc = grantedQoS; // 0, 1, 2 or 0x80
+ // For as long as we do not get 0x80 ..
+ if (rc != 0x80)
+ {
+ // Reset connection timers here ...
+ resetConnectionTimer();
+ printf("Successfully subscribed to %s ...\r\n", it->first.c_str());
+ numsubscribed++;
+ } else {
+ printf("Failed to subscribe to topic %s ... (not authorized?)\r\n", it->first.c_str());
+ }
+ } else
+ printf("Failed to subscribe to topic %s (ack not received) ...\r\n", it->first.c_str());
+ } // end for loop
+
+ return numsubscribed;
+}
+
int MQTTThreadedClient::subscribe(const char * topicstr, QoS qos, void (*function)(MessageData &))
{
int rc = FAILURE;
@@ -389,7 +750,7 @@
MQTTString topicName = MQTTString_initializer;
Message msg;
int intQoS;
- printf("Deserializing publish message ...\r\n");
+ printf("[Thread:%d]Deserializing publish message ...\r\n", Thread::gettid());
if (MQTTDeserialize_publish((unsigned char*)&msg.dup,
&intQoS,
(unsigned char*)&msg.retained,
@@ -398,7 +759,7 @@
(unsigned char**)&msg.payload,
(int*)&msg.payloadlen, readbuf, MAX_MQTT_PACKET_SIZE) != 1)
{
- printf("Error deserializing published message ...\r\n");
+ printf("[Thread:%d]Error deserializing published message ...\r\n", Thread::gettid());
return -1;
}
@@ -409,7 +770,7 @@
}else
topic = (const char *) topicName.cstring;
- printf("Got message for topic [%s], QoS [%d] ...\r\n", topic.c_str(), intQoS);
+ printf("[Thread:%d]Got message for topic [%s], QoS [%d] ...\r\n", Thread::gettid(), topic.c_str(), intQoS);
msg.qos = (QoS) intQoS;
@@ -420,7 +781,7 @@
// Call the callback function
if (topicCBMap[topic].attached())
{
- printf("Invoking function handler for topic ...\r\n");
+ printf("[Thread:%d]Invoking function handler for topic ...\r\n", Thread::gettid());
MessageData md(topicName, msg);
topicCBMap[topic](md);
@@ -475,7 +836,7 @@
int len = MQTTSerialize_pingreq(sendbuf, MAX_MQTT_PACKET_SIZE);
if (len > 0 && (sendPacket(len) == SUCCESS)) // send the ping packet
{
- printf("Ping request sent successfully ...\r\n");
+ printf("[Thread:%d]Ping request sent successfully ...\r\n", Thread::gettid());
}
}
@@ -484,81 +845,93 @@
int pType;
// Continuesly listens for packets and dispatch
// message handlers ...
- while(true)
- {
- pType = readPacket();
- switch(pType)
+ do {
+ // Connect to server ...
+ if (!isConnected)
{
- case TIMEOUT:
- // No data available from the network ...
- break;
- case FAILURE:
- case BUFFER_OVERFLOW:
- {
- // TODO: Network error, do we disconnect and reconnect?
- printf("Failure or buffer overflow problem ... \r\n");
- MBED_ASSERT(false);
- }
- break;
- /**
- * The rest of the return codes below (all positive) is about MQTT
- * response codes
- **/
- case CONNACK:
- case PUBACK:
- case SUBACK:
- break;
- case PUBLISH:
- {
- printf("Publish received!....\r\n");
- // We receive data from the MQTT server ..
- if (handlePublishMsg() < 0)
- {
- printf("Error handling PUBLISH message ... \r\n");
- break;
- }
- }
- break;
- case PINGRESP:
- {
- printf("Got ping response ...\r\n");
- resetConnectionTimer();
- }
- break;
- default:
- printf("Unknown/Not handled message from server pType[%d]\r\n", pType);
+ // Attempt to reconnect ...
}
- // Check if its time to send a keepAlive packet
- if (hasConnectionTimedOut())
- {
- // Queue the ping request so that other
- // pending operations queued above will go first
- queue.call(this, &MQTTThreadedClient::sendPingRequest);
- }
-
- // Check if we have messages on the message queue
- osEvent evt = mqueue.get(10);
- if (evt.status == osEventMessage) {
-
- printf("Got message to publish! ... \r\n");
-
- // Unpack the message
- PubMessage * message = (PubMessage *)evt.value.p;
-
- // Send the packet, do not queue the call
- // like the ping above ..
- if ( sendPublish(*message) == SUCCESS)
- // Reset timers if we have been able to send successfully
- resetConnectionTimer();
-
- // Free the message from mempool after using
- mpool.free(message);
- }
-
- // Dispatch any queued events ...
- queue.dispatch(100);
- }
-
+ while(true) {
+ pType = readPacket();
+ switch(pType) {
+ case TIMEOUT:
+ // No data available from the network ...
+ break;
+ case FAILURE:
+ goto reconnect;
+ case BUFFER_OVERFLOW:
+ {
+ // TODO: Network error, do we disconnect and reconnect?
+ printf("[Thread:%d]Failure or buffer overflow problem ... \r\n", Thread::gettid());
+ MBED_ASSERT(false);
+ }
+ break;
+ /**
+ * The rest of the return codes below (all positive) is about MQTT
+ * response codes
+ **/
+ case CONNACK:
+ case PUBACK:
+ case SUBACK:
+ break;
+ case PUBLISH:
+ {
+ printf("[Thread:%d]Publish received!....\r\n", Thread::gettid());
+ // We receive data from the MQTT server ..
+ if (handlePublishMsg() < 0) {
+ printf("[Thread:%d]Error handling PUBLISH message ... \r\n", Thread::gettid());
+ break;
+ }
+ }
+ break;
+ case PINGRESP:
+ {
+ printf("[Thread:%d]Got ping response ...\r\n", Thread::gettid());
+ resetConnectionTimer();
+ }
+ break;
+ default:
+ printf("[Thread:%d]Unknown/Not handled message from server pType[%d]\r\n", Thread::gettid(), pType);
+ }
+
+ // Check if its time to send a keepAlive packet
+ if (hasConnectionTimedOut()) {
+ // Queue the ping request so that other
+ // pending operations queued above will go first
+ queue.call(this, &MQTTThreadedClient::sendPingRequest);
+ }
+
+ // Check if we have messages on the message queue
+ osEvent evt = mqueue.get(10);
+ if (evt.status == osEventMessage) {
+
+ printf("[Thread:%d]Got message to publish! ... \r\n", Thread::gettid());
+
+ // Unpack the message
+ PubMessage * message = (PubMessage *)evt.value.p;
+
+ // Send the packet, do not queue the call
+ // like the ping above ..
+ if ( sendPublish(*message) == SUCCESS) {
+ // Reset timers if we have been able to send successfully
+ resetConnectionTimer();
+ } else {
+ // Disconnected?
+ goto reconnect;
+ }
+
+ // Free the message from mempool after using
+ mpool.free(message);
+ }
+
+ // Dispatch any queued events ...
+ queue.dispatch(100);
+ } // end while loop
+
+reconnect:
+ disconnect();
+ // reconnect?
+ } while(true);
}
\ No newline at end of file
