A threaded, secure MQTT client example. It uses Mbed TLS for the SSL/TLS connection and currently supports QoS 0 only. The example has been tested on a K64F board connected via Ethernet.

Dependencies:   FP MQTTPacket

Fork of HelloMQTT by MQTT

MQTTThreadedClient.cpp

Committer:
vpcola
Date:
2017-03-26
Revision:
23:06fac173529e
Child:
25:326f00faa092

File content as of revision 23:06fac173529e:

#include "mbed.h"
#include "rtos.h"
#include "MQTTThreadedClient.h"


// Pool of 16 PubMessage slots: publish() allocates a slot and copies the
// caller's message into it; startListener() frees the slot after sending.
static MemoryPool<PubMessage, 16> mpool;
// Queue through which publish() passes the pooled message pointers to the
// loop in startListener().
static Queue<PubMessage, 16> mqueue;

/**
 * Receives up to `size` bytes from the TCP socket into `buffer`.
 *
 * @param buffer   destination for the received bytes
 * @param size     maximum number of bytes to receive
 * @param timeout  value passed to the socket's set_timeout() before receiving
 * @return -1 when there is no socket, TIMEOUT when recv() reports
 *         NSAPI_ERROR_WOULD_BLOCK, otherwise the value recv() returned
 *         (a byte count or a socket error code).
 */
int MQTTThreadedClient::readBytesToBuffer(char * buffer, size_t size, int timeout)
{
    if (tcpSocket == NULL)
        return -1;

    // Apply the caller's timeout, then attempt a single receive.
    tcpSocket->set_timeout(timeout);
    const int received = tcpSocket->recv( (void *) buffer, size);

    // "Would block" is how the socket signals that the timeout elapsed.
    return (received == NSAPI_ERROR_WOULD_BLOCK) ? TIMEOUT : received;
}

/**
 * Sends up to `size` bytes from `buffer` over the TCP socket.
 *
 * @param buffer   bytes to transmit
 * @param size     number of bytes available in `buffer`
 * @param timeout  value passed to the socket's set_timeout() before sending
 * @return -1 when there is no socket, TIMEOUT when send() reports
 *         NSAPI_ERROR_WOULD_BLOCK, otherwise the value send() returned
 *         (a byte count or a socket error code).
 */
int MQTTThreadedClient::sendBytesFromBuffer(char * buffer, size_t size, int timeout)
{
    if (tcpSocket == NULL)
        return -1;

    // Apply the caller's timeout, then attempt a single send.
    tcpSocket->set_timeout(timeout);
    const int written = tcpSocket->send(buffer, size);

    if (written == NSAPI_ERROR_WOULD_BLOCK)
        return TIMEOUT;

    return written;
}

/**
 * Decodes the MQTT "remaining length" field by reading it from the socket
 * one byte at a time. Each byte contributes its low 7 bits; a set high bit
 * means another length byte follows.
 *
 * @param value  receives the decoded remaining length
 * @return the number of bytes the field occupied (1 to 4), or -1 when a byte
 *         could not be read or the field would need more than 4 bytes.
 */
int MQTTThreadedClient::readPacketLength(int* value)
{
    const int maxLengthBytes = 4;
    int bytesRead = 0;
    int scale = 1;
    unsigned char digit = 0;

    *value = 0;
    do
    {
        ++bytesRead;
        if (bytesRead > maxLengthBytes)
            return -1;  /* bad data */

        if (readBytesToBuffer((char *) &digit, 1, DEFAULT_SOCKET_TIMEOUT) != 1)
            return -1;

        *value += (digit & 127) * scale;
        scale *= 128;
    } while ((digit & 128) != 0);

    return bytesRead;
}

/**
 * Transmits the first `length` bytes of sendbuf, calling
 * sendBytesFromBuffer() repeatedly until everything has been written.
 *
 * @param length  number of bytes in sendbuf to send
 * @return SUCCESS when all `length` bytes were sent, FAILURE otherwise.
 */
int MQTTThreadedClient::sendPacket(size_t length)
{
    // Same unsigned type as `length`, so the comparisons below never mix
    // signed and unsigned values.
    size_t sent = 0;

    while (sent < length)
    {
        const int rc = sendBytesFromBuffer((char *) &sendbuf[sent], length - sent, DEFAULT_SOCKET_TIMEOUT);
        // A negative value is an error. A result of 0 means no progress was
        // made, and retrying on it could spin forever, so stop in both cases.
        if (rc <= 0)
            break;
        sent += static_cast<size_t>(rc);
    }

    return (sent == length) ? SUCCESS : FAILURE;
}
/**
 * Reads the entire packet to readbuf and returns
 * the type of packet when successful, otherwise
 * a negative error code is returned.
 **/
int MQTTThreadedClient::readPacket()
{
    int rc = FAILURE;
    MQTTHeader header = {0};
    int len = 0;
    int rem_len = 0;

    /* 1. read the header byte.  This has the packet type in it */
    if ( (rc = readBytesToBuffer((char *) &readbuf[0], 1, DEFAULT_SOCKET_TIMEOUT)) != 1)
        goto exit;

    len = 1;
    /* 2. read the remaining length.  This is variable in itself */
    if ( readPacketLength(&rem_len) < 0 )
        goto exit;
        
    len += MQTTPacket_encode(readbuf + 1, rem_len); /* put the original remaining length into the buffer */

    if (rem_len > (MAX_MQTT_PACKET_SIZE - len))
    {
        rc = BUFFER_OVERFLOW;
        goto exit;
    }

    /* 3. read the rest of the buffer using a callback to supply the rest of the data */
    if (rem_len > 0 && (readBytesToBuffer((char *) (readbuf + len), rem_len, DEFAULT_SOCKET_TIMEOUT) != rem_len))
        goto exit;

    // Convert the header to type
    // and update rc
    header.byte = readbuf[0];
    rc = header.bits.type;
    
exit:

    return rc;    
}

/**
 * Read until a specified packet type is received, or untill the specified
 * timeout dropping packets along the way.
 **/
int MQTTThreadedClient::readUntil(int packetType, int timeout)
{
    int pType = FAILURE;
    Timer timer;
    
    timer.start();
    do {
        pType = readPacket();
        if (pType < 0)
            break;
            
        if (timer.read_ms() > timeout)
        {
            pType = FAILURE;
            break;
        }
    }while(pType != packetType);
    
    return pType;    
}


/**
 * Performs the MQTT CONNECT handshake over the already connected socket.
 *
 * @param options  MQTT connect options (credentials, keep-alive, ...)
 * @return SUCCESS when the server accepted the connection; the CONNACK
 *         return code when the server refused it; FAILURE when the session
 *         is already connected, the packet could not be serialized or sent,
 *         or no valid CONNACK was received.
 */
int MQTTThreadedClient::connect(MQTTPacket_connectData& options)
{
    int rc = FAILURE;
    int len = 0;

    if (isConnected)
    {
        printf("Session already connected! \r\n");
        return rc;
    }

    // Copy the keepAliveInterval value to local
    // MQTT specifies in seconds, we have to multiply that
    // amount for our 32 bit timers which accepts ms.
    keepAliveInterval = (options.keepAliveInterval * 1000);

    printf("Connecting with: \r\n");
    // The C-string members may be NULL; never hand NULL to "%s".
    printf("\tUsername: [%s]\r\n", (options.username.cstring != NULL) ? options.username.cstring : "");
    // Never write the actual password to the console/log.
    printf("\tPassword: [%s]\r\n", (options.password.cstring != NULL) ? "********" : "");

    if ((len = MQTTSerialize_connect(sendbuf, MAX_MQTT_PACKET_SIZE, &options)) <= 0)
    {
        printf("Error serializing connect packet ...\r\n");
        return rc;
    }
    if ((rc = sendPacket((size_t) len)) != SUCCESS)  // send the connect packet
    {
        printf("Error sending the connect request packet ...\r\n");
        return rc;
    }

    // Wait for the CONNACK
    if (readUntil(CONNACK, COMMAND_TIMEOUT) == CONNACK)
    {
        unsigned char connack_rc = 255;
        // Declared with the exact type the deserializer writes to, instead of
        // reinterpreting a bool through an unsigned char pointer.
        unsigned char sessionPresent = 0;
        printf("Connection acknowledgement received ... deserializing respones ...\r\n");
        if (MQTTDeserialize_connack(&sessionPresent, &connack_rc, readbuf, MAX_MQTT_PACKET_SIZE) == 1)
            rc = connack_rc;
        else
            rc = FAILURE;
    }
    else
        rc = FAILURE;

    if (rc == SUCCESS)
    {
        printf("Connected!!! ... starting connection timers ...\r\n");
        isConnected = true;
        resetConnectionTimer();
    }
    else
    {
        // TODO: Call socket->disconnect()?
    }

    printf("Returning with rc = %d\r\n", rc);

    return rc;
}

int MQTTThreadedClient::connect(const char * host, uint16_t port, MQTTPacket_connectData & options)
{
    int ret;
        
    tcpSocket->open(network);
    if (( ret = tcpSocket->connect(host, port)) < 0 )
    {
         
         printf("Error connecting to %s:%d with %d\r\n", host, port, ret);
         return ret;
    } 
        
    return connect(options);
}

/**
 * Queues a message for publishing. The message is copied into a slot of the
 * message pool and its pointer is pushed to the message queue; the loop in
 * startListener() picks it up, sends it and frees the slot.
 * (An earlier approach that scheduled sendPublish() on the event queue was
 * disabled and has been removed.)
 *
 * @param msg  message to publish; it is copied, so the caller keeps ownership
 * @return SUCCESS when the message was queued, FAILURE when no pool slot was
 *         available or the queue did not accept the message.
 */
int MQTTThreadedClient::publish(PubMessage& msg)
{
    PubMessage *message = mpool.alloc();
    if (message == NULL)
    {
        // All pool slots are in use; copying into NULL would crash.
        printf("Failed to allocate a message from the pool ...\r\n");
        return FAILURE;
    }

    // Simple copy
    *message = msg;

    // Push the data to the thread
    printf("Pushing data to consumer thread ...\r\n");
    if (mqueue.put(message) != osOK)
    {
        // Nobody will ever dequeue this slot, so give it back right away.
        mpool.free(message);
        printf("Failed to push data to consumer thread ...\r\n");
        return FAILURE;
    }

    return SUCCESS;
}

/**
 * Serializes a PUBLISH packet for the given message into sendbuf and sends
 * it to the server.
 *
 * @param message  topic, payload, QoS and id of the message to publish
 * @return SUCCESS when the packet was sent; FAILURE when the session is not
 *         connected, serialization failed or the packet could not be sent.
 */
int MQTTThreadedClient::sendPublish(PubMessage& message)
{
    if (!isConnected)
    {
        printf("Not connected!!! ...\r\n");
        return FAILURE;
    }

    // The topic is passed to the serializer as a C string.
    MQTTString topicString = MQTTString_initializer;
    topicString.cstring = (char*) &message.topic[0];

    const int packetLen = MQTTSerialize_publish(sendbuf, MAX_MQTT_PACKET_SIZE, 0, message.qos, false, message.id,
            topicString, (unsigned char*) &message.payload[0], (int) message.payloadlen);
    if (packetLen <= 0)
    {
        printf("Failed serializing message ...\r\n");
        return FAILURE;
    }

    if (sendPacket(packetLen) != SUCCESS)
    {
        printf("Failed to send publish packet to server ...\r\n");
        return FAILURE;
    }

    printf("Successfully sent publish packet to server ...\r\n");
    return SUCCESS;
}
    
/**
 * Subscribes to a topic and, once the server has acknowledged the
 * subscription, registers `function` as the handler for that topic.
 *
 * @param topicstr  topic to subscribe to
 * @param qos       requested quality of service
 * @param function  callback stored in topicCBMap under `topicstr`
 * @return SUCCESS when the subscription was granted; 0x80 when the server
 *         rejected it; FAILURE when no topic was given, the session is not
 *         connected, the request could not be serialized, or no valid
 *         SUBACK was received; the sendPacket() result when sending failed.
 */
int MQTTThreadedClient::subscribe(const char * topicstr, QoS qos, void (*function)(MessageData &))
{
    int rc = FAILURE;
    int len = 0;

    if (topicstr == NULL)
    {
        printf("Cannot subscribe without a topic ...\r\n");
        return rc;
    }

    MQTTString topic = {(char*)topicstr, {0, 0}};
    printf("Subscribing to topic [%s]\r\n", topicstr);

    if (!isConnected)
    {
        printf("Session not connected!!\r\n");
        return rc;
    }

    // Hand the serializer a real int instead of reinterpreting the enum
    // object through an int pointer.
    int requestedQoS = (int) qos;
    len = MQTTSerialize_subscribe(sendbuf, MAX_MQTT_PACKET_SIZE, 0, packetid.getNext(), 1, &topic, &requestedQoS);
    if (len <= 0)
    {
        printf("Error serializing subscribe packet ...\r\n");
        return rc;
    }

    if ((rc = sendPacket(len)) != SUCCESS)
    {
        printf("Error sending subscribe packet [%d]\r\n", rc);
        return rc;
    }

    printf("Waiting for subscription ack ...\r\n");
    // Wait for SUBACK, dropping packets read along the way ...
    if (readUntil(SUBACK, COMMAND_TIMEOUT) != SUBACK)
    {
        printf("Failed to subscribe to topic %s (ack not received) ...\r\n", topicstr);
        return FAILURE;
    }

    int count = 0;
    int grantedQoS = -1;
    unsigned short mypacketid = 0;
    if (MQTTDeserialize_suback(&mypacketid, 1, &count, &grantedQoS, readbuf, MAX_MQTT_PACKET_SIZE) != 1)
    {
        // An ack that cannot be decoded must not be treated as a grant.
        printf("Failed to subscribe to topic %s (invalid ack) ...\r\n", topicstr);
        return FAILURE;
    }

    // grantedQoS is 0, 1, 2 or 0x80; 0x80 means the server refused.
    if (grantedQoS == 0x80)
    {
        printf("Failed to subscribe to topic %s ... (not authorized?)\r\n", topicstr);
        return grantedQoS;
    }

    // Add message handlers to the map
    FP<void,MessageData &> fp;
    fp.attach(function);
    topicCBMap.insert(std::pair<std::string, FP<void,MessageData &> >(std::string(topicstr),fp));

    // Reset connection timers here ...
    resetConnectionTimer();

    printf("Successfully subscribed to %s ...\r\n", topicstr);
    return SUCCESS;
}


/**
 * Checks whether a topic name matches a topic filter. In the filter, '+'
 * stands for the characters up to the next '/' of the name and '#' stands
 * for everything up to the end of the name.
 *
 * @param topicFilter  NUL-terminated filter string
 * @param topicName    topic name, read from its length-delimited string
 * @return true when both the filter and the name were consumed completely.
 */
bool MQTTThreadedClient::isTopicMatched(char* topicFilter, MQTTString& topicName)
{
    char* filterPos = topicFilter;
    char* namePos = topicName.lenstring.data;
    char* const nameEnd = namePos + topicName.lenstring.len;

    for (; *filterPos != '\0' && namePos < nameEnd; ++filterPos, ++namePos)
    {
        const char f = *filterPos;
        const bool isWildcard = (f == '+') || (f == '#');

        // A level separator in the name must be matched by one in the filter.
        if (*namePos == '/' && f != '/')
            break;
        // Literal filter characters must match the name exactly.
        if (!isWildcard && f != *namePos)
            break;

        if (f == '+')
        {
            // Move to the last character before the next separator (or the
            // end of the name); the loop increment then steps past it.
            while ((namePos + 1) < nameEnd && namePos[1] != '/')
                ++namePos;
        }
        else if (f == '#')
        {
            // Skip to the last character of the name.
            namePos = nameEnd - 1;
        }
    }

    return (namePos == nameEnd) && (*filterPos == '\0');
}

int MQTTThreadedClient::handlePublishMsg()
{
    MQTTString topicName = MQTTString_initializer;
    Message msg;
    int intQoS;
    printf("Deserializing publish message ...\r\n");
    if (MQTTDeserialize_publish((unsigned char*)&msg.dup, 
            &intQoS, 
            (unsigned char*)&msg.retained, 
            (unsigned short*)&msg.id, 
            &topicName,
            (unsigned char**)&msg.payload, 
            (int*)&msg.payloadlen, readbuf, MAX_MQTT_PACKET_SIZE) != 1)
    {
        printf("Error deserializing published message ...\r\n");
        return -1;
    }

    std::string topic;
    if (topicName.lenstring.len > 0)
    {
        topic = std::string((const char *) topicName.lenstring.data, (size_t) topicName.lenstring.len);
    }else
        topic = (const char *) topicName.cstring;
    
    printf("Got message for topic [%s], QoS [%d] ...\r\n", topic.c_str(), intQoS);
    
    msg.qos = (QoS) intQoS;

    
    // Call the handlers for each topic 
    if (topicCBMap.find(topic) != topicCBMap.end())
    {
        // Call the callback function 
        if (topicCBMap[topic].attached())
        {
            printf("Invoking function handler for topic ...\r\n");
            MessageData md(topicName, msg);            
            topicCBMap[topic](md);
            
            return 1;
        }
    }
    
    // TODO: depending on the QoS
    // we send data to the server = PUBACK or PUBREC
    switch(intQoS)
    {
        case QOS0:
            // We send back nothing ...
            break;
        case QOS1:
            // TODO: implement
            break;
        case QOS2:
            // TODO: implement
            break;
        default:
            break;
    }
    
    return 0;
}

/**
 * Restarts the keep-alive timer from zero. Does nothing when no keep-alive
 * interval is configured.
 */
void MQTTThreadedClient::resetConnectionTimer()
{
    if (keepAliveInterval <= 0)
        return;

    comTimer.reset();
    comTimer.start();
}

/**
 * Tells whether the keep-alive interval has elapsed since the keep-alive
 * timer was last reset.
 *
 * @return true when a keep-alive interval is configured and the timer has
 *         run longer than that interval, false otherwise.
 */
bool MQTTThreadedClient::hasConnectionTimedOut()
{
    // The timer is only consulted when a keep-alive interval is set.
    return (keepAliveInterval > 0) && (comTimer.read_ms() > keepAliveInterval);
}
        
/**
 * Serializes a PINGREQ packet into sendbuf and sends it. A message is
 * printed only when the packet was sent successfully.
 */
void MQTTThreadedClient::sendPingRequest()
{
    const int pingLen = MQTTSerialize_pingreq(sendbuf, MAX_MQTT_PACKET_SIZE);
    if (pingLen <= 0)
        return;

    // send the ping packet
    if (sendPacket(pingLen) != SUCCESS)
        return;

    printf("Ping request sent successfully ...\r\n");
}

void MQTTThreadedClient::startListener()
{
    int pType;
    // Continuesly listens for packets and dispatch
    // message handlers ...
    while(true)
    {
        pType = readPacket();        
        switch(pType)
        {
            case TIMEOUT:
                // No data available from the network ... 
                break;
            case FAILURE:                
            case BUFFER_OVERFLOW:
                {
                    // TODO: Network error, do we disconnect and reconnect?
                    printf("Failure or buffer overflow problem ... \r\n");
                    MBED_ASSERT(false);
                }
                break;
            /**
            *  The rest of the return codes below (all positive) is about MQTT
             * response codes
             **/
            case CONNACK:
            case PUBACK:
            case SUBACK:
                break;
            case PUBLISH:
                {
                    printf("Publish received!....\r\n");
                    // We receive data from the MQTT server ..
                    if (handlePublishMsg() < 0)
                    {
                        printf("Error handling PUBLISH message ... \r\n");
                        break;
                    }
                }
                break;
            case PINGRESP: 
                {
                    printf("Got ping response ...\r\n");
                    resetConnectionTimer();
                }
                break;
            default:
                printf("Unknown/Not handled message from server pType[%d]\r\n", pType);
        }
        
        // Check if its time to send a keepAlive packet
        if (hasConnectionTimedOut())
        {
            // Queue the ping request so that other
            // pending operations queued above will go first
            queue.call(this, &MQTTThreadedClient::sendPingRequest);
        }
        
        // Check if we have messages on the message queue
        osEvent evt = mqueue.get(10);
        if (evt.status == osEventMessage) {
            
            printf("Got message to publish! ... \r\n");
            
            // Unpack the message
            PubMessage * message = (PubMessage *)evt.value.p;
            
            // Send the packet, do not queue the call
            // like the ping above ..
            if ( sendPublish(*message) == SUCCESS)
                // Reset timers if we have been able to send successfully
                resetConnectionTimer();
            
            // Free the message from mempool  after using
            mpool.free(message);
        }        
        
        // Dispatch any queued events ...
        queue.dispatch(100); 
    }
   
}