A Threaded Secure MQTT Client example. Uses MBED TLS for SSL/TLS connection. QoS0 only for now. Example has been tested with K64F connected via Ethernet.
Fork of HelloMQTT by
Diff: MQTTThreadedClient.cpp
- Revision:
- 25:326f00faa092
- Parent:
- 23:06fac173529e
- Child:
- 26:4b21de8043a5
diff -r 9d5f0300d7ed -r 326f00faa092 MQTTThreadedClient.cpp --- a/MQTTThreadedClient.cpp Sun Mar 26 04:38:36 2017 +0000 +++ b/MQTTThreadedClient.cpp Mon Mar 27 03:53:18 2017 +0000 @@ -1,27 +1,280 @@ #include "mbed.h" #include "rtos.h" #include "MQTTThreadedClient.h" - +#include "mbedtls/platform.h" +#include "mbedtls/ssl.h" +#include "mbedtls/entropy.h" +#include "mbedtls/ctr_drbg.h" +#include "mbedtls/error.h" static MemoryPool<PubMessage, 16> mpool; static Queue<PubMessage, 16> mqueue; +// SSL/TLS variables +mbedtls_entropy_context _entropy; +mbedtls_ctr_drbg_context _ctr_drbg; +mbedtls_x509_crt _cacert; +mbedtls_ssl_context _ssl; +mbedtls_ssl_config _ssl_conf; +mbedtls_ssl_session saved_session; + +/** + * Receive callback for mbed TLS + */ +static int ssl_recv(void *ctx, unsigned char *buf, size_t len) +{ + int recv = -1; + TCPSocket *socket = static_cast<TCPSocket *>(ctx); + socket->set_timeout(DEFAULT_SOCKET_TIMEOUT); + recv = socket->recv(buf, len); + + if (NSAPI_ERROR_WOULD_BLOCK == recv) { + return MBEDTLS_ERR_SSL_WANT_READ; + } else if (recv < 0) { + return -1; + } else { + return recv; + } +} + +/** + * Send callback for mbed TLS + */ +static int ssl_send(void *ctx, const unsigned char *buf, size_t len) +{ + int sent = -1; + TCPSocket *socket = static_cast<TCPSocket *>(ctx); + socket->set_timeout(DEFAULT_SOCKET_TIMEOUT); + sent = socket->send(buf, len); + + if(NSAPI_ERROR_WOULD_BLOCK == sent) { + return MBEDTLS_ERR_SSL_WANT_WRITE; + } else if (sent < 0) { + return -1; + } else { + return sent; + } +} + +#if DEBUG_LEVEL > 0 +/** + * Debug callback for mbed TLS + * Just prints on the USB serial port + */ +static void my_debug(void *ctx, int level, const char *file, int line, + const char *str) +{ + const char *p, *basename; + (void) ctx; + + /* Extract basename from file */ + for(p = basename = file; *p != '\0'; p++) { + if(*p == '/' || *p == '\\') { + basename = p + 1; + } + } + + if (_debug) { + mbedtls_printf("%s:%04d: |%d| %s", basename, line, level, str); + } +} + +/** + * Certificate verification callback for mbed TLS + * Here we only use it to display information on each cert in the chain + */ +static int my_verify(void *data, mbedtls_x509_crt *crt, int depth, uint32_t *flags) +{ + const uint32_t buf_size = 1024; + char *buf = new char[buf_size]; + (void) data; + + if (_debug) mbedtls_printf("\nVerifying certificate at depth %d:\n", depth); + mbedtls_x509_crt_info(buf, buf_size - 1, " ", crt); + if (_debug) mbedtls_printf("%s", buf); + + if (*flags == 0) + if (_debug) mbedtls_printf("No verification issue for this certificate\n"); + else { + mbedtls_x509_crt_verify_info(buf, buf_size, " ! ", *flags); + if (_debug) mbedtls_printf("%s\n", buf); + } + + delete[] buf; + return 0; +} +#endif + + +void MQTTThreadedClient::setupTLS() +{ + if (ssl_ca_pem != NULL) + { + mbedtls_entropy_init(&_entropy); + mbedtls_ctr_drbg_init(&_ctr_drbg); + mbedtls_x509_crt_init(&_cacert); + mbedtls_ssl_init(&_ssl); + mbedtls_ssl_config_init(&_ssl_conf); + memset( &saved_session, 0, sizeof( mbedtls_ssl_session ) ); + } +} + +void MQTTThreadedClient::freeTLS() +{ + if (ssl_ca_pem != NULL) + { + mbedtls_entropy_free(&_entropy); + mbedtls_ctr_drbg_free(&_ctr_drbg); + mbedtls_x509_crt_free(&_cacert); + mbedtls_ssl_free(&_ssl); + mbedtls_ssl_config_free(&_ssl_conf); + } +} + +int MQTTThreadedClient::initTLS() +{ + int ret; + + printf("Initializing TLS ...\r\n"); + printf("mbedtls_ctr_drdbg_seed ...\r\n"); + if ((ret = mbedtls_ctr_drbg_seed(&_ctr_drbg, mbedtls_entropy_func, &_entropy, + (const unsigned char *) DRBG_PERS, + sizeof (DRBG_PERS))) != 0) { + printf("error [%d] mbedtls_crt_drbg_init", ret); + _error = ret; + return -1; + } + printf("mbedtls_x509_crt_parse ...\r\n"); + if ((ret = mbedtls_x509_crt_parse(&_cacert, (const unsigned char *) ssl_ca_pem, + strlen(ssl_ca_pem) + 1)) != 0) { + printf("error [%d] mbedtls_x509_crt_parse", ret); + _error = ret; + return -1; + } + + printf("mbedtls_ssl_config_defaults ...\r\n"); + if ((ret = mbedtls_ssl_config_defaults(&_ssl_conf, + MBEDTLS_SSL_IS_CLIENT, + MBEDTLS_SSL_TRANSPORT_STREAM, + MBEDTLS_SSL_PRESET_DEFAULT)) != 0) { + printf("error [%d] mbedtls_ssl_config_defaults", ret); + _error = ret; + return -1; + } + + printf("mbedtls_ssl_config_ca_chain ...\r\n"); + mbedtls_ssl_conf_ca_chain(&_ssl_conf, &_cacert, NULL); + printf("mbedtls_ssl_conf_rng ...\r\n"); + mbedtls_ssl_conf_rng(&_ssl_conf, mbedtls_ctr_drbg_random, &_ctr_drbg); + + /* It is possible to disable authentication by passing + * MBEDTLS_SSL_VERIFY_NONE in the call to mbedtls_ssl_conf_authmode() + */ + printf("mbedtls_ssl_conf_authmode ...\r\n"); + mbedtls_ssl_conf_authmode(&_ssl_conf, MBEDTLS_SSL_VERIFY_REQUIRED); + +#if DEBUG_LEVEL > 0 + mbedtls_ssl_conf_verify(&_ssl_conf, my_verify, NULL); + mbedtls_ssl_conf_dbg(&_ssl_conf, my_debug, NULL); + mbedtls_debug_set_threshold(DEBUG_LEVEL); +#endif + + printf("mbedtls_ssl_setup ...\r\n"); + if ((ret = mbedtls_ssl_setup(&_ssl, &_ssl_conf)) != 0) { + printf("error [%d] mbedtls_ssl_setup", ret); + _error = ret; + return -1; + } + + return 0; +} + +int MQTTThreadedClient::doTLSHandshake() +{ + int ret; + + /* Start the handshake, the rest will be done in onReceive() */ + printf("Starting the TLS handshake...\r\n"); + ret = mbedtls_ssl_handshake(&_ssl); + if (ret < 0) + { + if (ret != MBEDTLS_ERR_SSL_WANT_READ && + ret != MBEDTLS_ERR_SSL_WANT_WRITE) + { + printf("error [%d] mbedtls_ssl_handshake", ret); + tcpSocket->close(); + _error = -1; + } + else + { + // do not close the socket if timed out + _error = ret; + } + return -1; + } + + /* Handshake done, time to print info */ + printf("TLS connection to %s:%d established\r\n", + host.c_str(), port); + + const uint32_t buf_size = 1024; + char *buf = new char[buf_size]; + mbedtls_x509_crt_info(buf, buf_size, "\r ", + mbedtls_ssl_get_peer_cert(&_ssl)); + + printf("Server certificate:\r\n%s\r", buf); + // Verify server cert ... + uint32_t flags = mbedtls_ssl_get_verify_result(&_ssl); + if( flags != 0 ) + { + mbedtls_x509_crt_verify_info(buf, buf_size, "\r ! ", flags); + printf("Certificate verification failed:\r\n%s\r\r\n", buf); + // free server cert ... before error return + delete [] buf; + return -1; + } + + printf("Certificate verification passed\r\n\r\n"); + // delete server cert after verification + delete [] buf; + + // TODO: Save the session here for reconnect. + if( ( ret = mbedtls_ssl_get_session( &_ssl, &saved_session ) ) != 0 ) + { + printf( "mbedtls_ssl_get_session returned -0x%x\n\n", -ret ); + return -1; + } + + printf("Session saved for reconnect ...\r\n"); + return 0; +} + int MQTTThreadedClient::readBytesToBuffer(char * buffer, size_t size, int timeout) { int rc; - + if (tcpSocket == NULL) return -1; - - // non-blocking socket ... - tcpSocket->set_timeout(timeout); - rc = tcpSocket->recv( (void *) buffer, size); - - // return 0 bytes if timeout ... - if (NSAPI_ERROR_WOULD_BLOCK == rc) - return TIMEOUT; - else - return rc; // return the number of bytes received or error + + if (ssl_ca_pem != NULL) + { + // Do SSL/TLS read + rc = mbedtls_ssl_read(&_ssl, (unsigned char *) buffer, size); + if (MBEDTLS_ERR_SSL_WANT_READ == rc) + return TIMEOUT; + else + return rc; + } else { + // non-blocking socket ... + tcpSocket->set_timeout(timeout); + rc = tcpSocket->recv( (void *) buffer, size); + + // return 0 bytes if timeout ... + if (NSAPI_ERROR_WOULD_BLOCK == rc) + return TIMEOUT; + else + return rc; // return the number of bytes received or error + } } int MQTTThreadedClient::sendBytesFromBuffer(char * buffer, size_t size, int timeout) @@ -31,14 +284,24 @@ if (tcpSocket == NULL) return -1; - // set the write timeout - tcpSocket->set_timeout(timeout); - rc = tcpSocket->send(buffer, size); - - if ( NSAPI_ERROR_WOULD_BLOCK == rc) - return TIMEOUT; - else - return rc; + if (ssl_ca_pem != NULL) { + // Do SSL/TLS write + rc = mbedtls_ssl_write(&_ssl, (const unsigned char *) buffer, size); + if (MBEDTLS_ERR_SSL_WANT_WRITE == rc) + return TIMEOUT; + else + return rc; + } else { + + // set the write timeout + tcpSocket->set_timeout(timeout); + rc = tcpSocket->send(buffer, size); + + if ( NSAPI_ERROR_WOULD_BLOCK == rc) + return TIMEOUT; + else + return rc; + } } int MQTTThreadedClient::readPacketLength(int* value) @@ -227,19 +490,51 @@ return rc; } -int MQTTThreadedClient::connect(const char * host, uint16_t port, MQTTPacket_connectData & options) + +void MQTTThreadedClient::disconnect() +{ + if (isConnected) + { + // TODO: Send unsubscribe message ... + + isConnected = false; + tcpSocket->close(); + } +} + +int MQTTThreadedClient::connect(const char * chost, uint16_t cport, MQTTPacket_connectData & options) { - int ret; - + int ret = FAILURE; + + // Copy the settings for reconnection + host = chost; + port = cport; + connect_options = options; + tcpSocket->open(network); - if (( ret = tcpSocket->connect(host, port)) < 0 ) + if (ssl_ca_pem != NULL) + { + printf("mbedtls_ssl_set_hostname ...\r\n"); + mbedtls_ssl_set_hostname(&_ssl, host.c_str()); + printf("mbedtls_ssl_set_bio ...\r\n"); + mbedtls_ssl_set_bio(&_ssl, static_cast<void *>(tcpSocket), + ssl_send, ssl_recv, NULL ); + } + + if (( ret = tcpSocket->connect(host.c_str(), port)) < 0 ) { - printf("Error connecting to %s:%d with %d\r\n", host, port, ret); + printf("Error connecting to %s:%d with %d\r\n", host.c_str(), port, ret); return ret; } - - return connect(options); + + if ((ssl_ca_pem != NULL) && (doTLSHandshake() < 0)) + { + printf("TLS Handshake failed! \r\n"); + return FAILURE; + } + + return connect(connect_options); } int MQTTThreadedClient::publish(PubMessage& msg) @@ -257,7 +552,7 @@ *message = msg; // Push the data to the thread - printf("Pushing data to consumer thread ...\r\n"); + printf("[Thread:%d] Pushing data to consumer thread ...\r\n", Thread::gettid()); mqueue.put(message); return SUCCESS; @@ -269,7 +564,7 @@ if (!isConnected) { - printf("Not connected!!! ...\r\n"); + printf("[Thread:%d] Not connected!!! ...\r\n", Thread::gettid()); return FAILURE; } @@ -278,20 +573,86 @@ topicString, (unsigned char*) &message.payload[0], (int) message.payloadlen); if (len <= 0) { - printf("Failed serializing message ...\r\n"); + printf("[Thread:%d]Failed serializing message ...\r\n", Thread::gettid()); return FAILURE; } if (sendPacket(len) == SUCCESS) { - printf("Successfully sent publish packet to server ...\r\n"); + printf("[Thread:%d]Successfully sent publish packet to server ...\r\n", Thread::gettid()); return SUCCESS; } - printf("Failed to send publish packet to server ...\r\n"); + printf("[Thread:%d]Failed to send publish packet to server ...\r\n", Thread::gettid()); return FAILURE; } + +void MQTTThreadedClient::addTopicHandler(const char * topicstr, void (*function)(MessageData &)) +{ + // Push the subscription into the map ... + FP<void,MessageData &> fp; + fp.attach(function); + topicCBMap_t.insert(std::pair<std::string, FP<void,MessageData &> >(std::string(topicstr),fp)); +} + +int MQTTThreadedClient::processSubscriptions() +{ + int numsubscribed = 0; + + if (!isConnected) + { + printf("Session not connected!!\r\n"); + return 0; + } + + std::map<std::string, FP<void, MessageData &> >::iterator it; + for(it = topicCBMap_t.begin(); it != topicCBMap_t.end(); it++) + { + int rc = FAILURE; + int len = 0; + //TODO: We only subscribe to QoS = 0 for now + QoS qos = QOS0; + + MQTTString topic = {(char*)it->first.c_str(), {0, 0}}; + printf("Subscribing to topic [%s]\r\n", topic.cstring); + + + len = MQTTSerialize_subscribe(sendbuf, MAX_MQTT_PACKET_SIZE, 0, packetid.getNext(), 1, &topic, (int*)&qos); + if (len <= 0) { + printf("Error serializing subscribe packet ...\r\n"); + continue; + } + + if ((rc = sendPacket(len)) != SUCCESS) { + printf("Error sending subscribe packet [%d]\r\n", rc); + continue; + } + + printf("Waiting for subscription ack ...\r\n"); + // Wait for SUBACK, dropping packets read along the way ... + if (readUntil(SUBACK, COMMAND_TIMEOUT) == SUBACK) { // wait for suback + int count = 0, grantedQoS = -1; + unsigned short mypacketid; + if (MQTTDeserialize_suback(&mypacketid, 1, &count, &grantedQoS, readbuf, MAX_MQTT_PACKET_SIZE) == 1) + rc = grantedQoS; // 0, 1, 2 or 0x80 + // For as long as we do not get 0x80 .. + if (rc != 0x80) + { + // Reset connection timers here ... + resetConnectionTimer(); + printf("Successfully subscribed to %s ...\r\n", it->first.c_str()); + numsubscribed++; + } else { + printf("Failed to subscribe to topic %s ... (not authorized?)\r\n", it->first.c_str()); + } + } else + printf("Failed to subscribe to topic %s (ack not received) ...\r\n", it->first.c_str()); + } // end for loop + + return numsubscribed; +} + int MQTTThreadedClient::subscribe(const char * topicstr, QoS qos, void (*function)(MessageData &)) { int rc = FAILURE; @@ -389,7 +750,7 @@ MQTTString topicName = MQTTString_initializer; Message msg; int intQoS; - printf("Deserializing publish message ...\r\n"); + printf("[Thread:%d]Deserializing publish message ...\r\n", Thread::gettid()); if (MQTTDeserialize_publish((unsigned char*)&msg.dup, &intQoS, (unsigned char*)&msg.retained, @@ -398,7 +759,7 @@ (unsigned char**)&msg.payload, (int*)&msg.payloadlen, readbuf, MAX_MQTT_PACKET_SIZE) != 1) { - printf("Error deserializing published message ...\r\n"); + printf("[Thread:%d]Error deserializing published message ...\r\n", Thread::gettid()); return -1; } @@ -409,7 +770,7 @@ }else topic = (const char *) topicName.cstring; - printf("Got message for topic [%s], QoS [%d] ...\r\n", topic.c_str(), intQoS); + printf("[Thread:%d]Got message for topic [%s], QoS [%d] ...\r\n", Thread::gettid(), topic.c_str(), intQoS); msg.qos = (QoS) intQoS; @@ -420,7 +781,7 @@ // Call the callback function if (topicCBMap[topic].attached()) { - printf("Invoking function handler for topic ...\r\n"); + printf("[Thread:%d]Invoking function handler for topic ...\r\n", Thread::gettid()); MessageData md(topicName, msg); topicCBMap[topic](md); @@ -475,7 +836,7 @@ int len = MQTTSerialize_pingreq(sendbuf, MAX_MQTT_PACKET_SIZE); if (len > 0 && (sendPacket(len) == SUCCESS)) // send the ping packet { - printf("Ping request sent successfully ...\r\n"); + printf("[Thread:%d]Ping request sent successfully ...\r\n", Thread::gettid()); } } @@ -484,81 +845,93 @@ int pType; // Continuesly listens for packets and dispatch // message handlers ... - while(true) - { - pType = readPacket(); - switch(pType) + do { + // Connect to server ... + if (!isConnected) { - case TIMEOUT: - // No data available from the network ... - break; - case FAILURE: - case BUFFER_OVERFLOW: - { - // TODO: Network error, do we disconnect and reconnect? - printf("Failure or buffer overflow problem ... \r\n"); - MBED_ASSERT(false); - } - break; - /** - * The rest of the return codes below (all positive) is about MQTT - * response codes - **/ - case CONNACK: - case PUBACK: - case SUBACK: - break; - case PUBLISH: - { - printf("Publish received!....\r\n"); - // We receive data from the MQTT server .. - if (handlePublishMsg() < 0) - { - printf("Error handling PUBLISH message ... \r\n"); - break; - } - } - break; - case PINGRESP: - { - printf("Got ping response ...\r\n"); - resetConnectionTimer(); - } - break; - default: - printf("Unknown/Not handled message from server pType[%d]\r\n", pType); + // Attempt to reconnect ... } - // Check if its time to send a keepAlive packet - if (hasConnectionTimedOut()) - { - // Queue the ping request so that other - // pending operations queued above will go first - queue.call(this, &MQTTThreadedClient::sendPingRequest); - } - - // Check if we have messages on the message queue - osEvent evt = mqueue.get(10); - if (evt.status == osEventMessage) { - - printf("Got message to publish! ... \r\n"); - - // Unpack the message - PubMessage * message = (PubMessage *)evt.value.p; - - // Send the packet, do not queue the call - // like the ping above .. - if ( sendPublish(*message) == SUCCESS) - // Reset timers if we have been able to send successfully - resetConnectionTimer(); - - // Free the message from mempool after using - mpool.free(message); - } - - // Dispatch any queued events ... - queue.dispatch(100); - } - + while(true) { + pType = readPacket(); + switch(pType) { + case TIMEOUT: + // No data available from the network ... + break; + case FAILURE: + goto reconnect; + case BUFFER_OVERFLOW: + { + // TODO: Network error, do we disconnect and reconnect? + printf("[Thread:%d]Failure or buffer overflow problem ... \r\n", Thread::gettid()); + MBED_ASSERT(false); + } + break; + /** + * The rest of the return codes below (all positive) is about MQTT + * response codes + **/ + case CONNACK: + case PUBACK: + case SUBACK: + break; + case PUBLISH: + { + printf("[Thread:%d]Publish received!....\r\n", Thread::gettid()); + // We receive data from the MQTT server .. + if (handlePublishMsg() < 0) { + printf("[Thread:%d]Error handling PUBLISH message ... \r\n", Thread::gettid()); + break; + } + } + break; + case PINGRESP: + { + printf("[Thread:%d]Got ping response ...\r\n", Thread::gettid()); + resetConnectionTimer(); + } + break; + default: + printf("[Thread:%d]Unknown/Not handled message from server pType[%d]\r\n", Thread::gettid(), pType); + } + + // Check if its time to send a keepAlive packet + if (hasConnectionTimedOut()) { + // Queue the ping request so that other + // pending operations queued above will go first + queue.call(this, &MQTTThreadedClient::sendPingRequest); + } + + // Check if we have messages on the message queue + osEvent evt = mqueue.get(10); + if (evt.status == osEventMessage) { + + printf("[Thread:%d]Got message to publish! ... \r\n", Thread::gettid()); + + // Unpack the message + PubMessage * message = (PubMessage *)evt.value.p; + + // Send the packet, do not queue the call + // like the ping above .. + if ( sendPublish(*message) == SUCCESS) { + // Reset timers if we have been able to send successfully + resetConnectionTimer(); + } else { + // Disconnected? + goto reconnect; + } + + // Free the message from mempool after using + mpool.free(message); + } + + // Dispatch any queued events ... + queue.dispatch(100); + } // end while loop + +reconnect: + disconnect(); + // reconnect? + } while(true); } \ No newline at end of file