my customized library

Dependencies:   FP MQTTPacket

Fork of MQTT by MQTT

Revision:
53:15b5a280d22d
Parent:
46:e335fcc1a663
Child:
54:ff9e5c4b52d0
diff -r 3f9919941b86 -r 15b5a280d22d MQTTClient.h
--- a/MQTTClient.h	Mon Sep 25 11:12:23 2017 +0000
+++ b/MQTTClient.h	Mon Sep 25 12:06:28 2017 +0000
@@ -1,5 +1,5 @@
 /*******************************************************************************
- * Copyright (c) 2014, 2015 IBM Corp.
+ * Copyright (c) 2014, 2017 IBM Corp.
  *
  * All rights reserved. This program and the accompanying materials
  * are made available under the terms of the Eclipse Public License v1.0
@@ -16,6 +16,9 @@
  *    Ian Craggs - fix for bug 460389 - send loop uses wrong length
  *    Ian Craggs - fix for bug 464169 - clearing subscriptions
  *    Ian Craggs - fix for bug 464551 - enums and ints can be different size
+ *    Mark Sonnentag - fix for bug 475204 - inefficient instantiation of Timer
+ *    Ian Craggs - fix for bug 475749 - packetid modified twice
+ *    Ian Craggs - add ability to set message handler separately #6
  *******************************************************************************/
 
 #if !defined(MQTTCLIENT_H)
@@ -23,7 +26,7 @@
 
 #include "FP.h"
 #include "MQTTPacket.h"
-#include "stdio.h"
+#include <stdio.h>
 #include "MQTTLogging.h"
 
 #if !defined(MQTTCLIENT_QOS1)
@@ -64,6 +67,19 @@
 };
 
 
+struct connackData
+{
+    int rc;
+    bool sessionPresent;
+};
+
+
+struct subackData
+{
+    int grantedQoS;
+};
+
+
 class PacketId
 {
 public:
@@ -74,7 +90,7 @@
 
     int getNext()
     {
-        return next = (next == MAX_PACKET_ID) ? 1 : ++next;
+        return next = (next == MAX_PACKET_ID) ? 1 : next + 1;
     }
 
 private:
@@ -108,34 +124,51 @@
     Client(Network& network, unsigned int command_timeout_ms = 30000);
 
     /** Set the default message handling callback - used for any message which does not match a subscription message handler
-     *  @param mh - pointer to the callback function
+     *  @param mh - pointer to the callback function.  Set to 0 to remove.
      */
     void setDefaultMessageHandler(messageHandler mh)
     {
-        defaultMessageHandler.attach(mh);
+        if (mh != 0)
+            defaultMessageHandler.attach(mh);
+        else
+            defaultMessageHandler.detach();
     }
 
+    /** Set a message handling callback.  This can be used outside of the the subscribe method.
+     *  @param topicFilter - a topic pattern which can include wildcards
+     *  @param mh - pointer to the callback function. If 0, removes the callback if any
+     */
+    int setMessageHandler(const char* topicFilter, messageHandler mh);
+
     /** MQTT Connect - send an MQTT connect packet down the network and wait for a Connack
      *  The nework object must be connected to the network endpoint before calling this
      *  Default connect options are used
      *  @return success code -
      */
     int connect();
-    
-        /** MQTT Connect - send an MQTT connect packet down the network and wait for a Connack
+
+    /** MQTT Connect - send an MQTT connect packet down the network and wait for a Connack
      *  The nework object must be connected to the network endpoint before calling this
      *  @param options - connect options
      *  @return success code -
      */
     int connect(MQTTPacket_connectData& options);
 
+    /** MQTT Connect - send an MQTT connect packet down the network and wait for a Connack
+     *  The nework object must be connected to the network endpoint before calling this
+     *  @param options - connect options
+     *  @param connackData - connack data to be returned
+     *  @return success code -
+     */
+    int connect(MQTTPacket_connectData& options, connackData& data);
+
     /** MQTT Publish - send an MQTT publish packet and wait for all acks to complete for all QoSs
      *  @param topic - the topic to publish to
      *  @param message - the message to send
      *  @return success code -
      */
     int publish(const char* topicName, Message& message);
-    
+
     /** MQTT Publish - send an MQTT publish packet and wait for all acks to complete for all QoSs
      *  @param topic - the topic to publish to
      *  @param payload - the data to send
@@ -145,12 +178,12 @@
      *  @return success code -
      */
     int publish(const char* topicName, void* payload, size_t payloadlen, enum QoS qos = QOS0, bool retained = false);
-    
+
     /** MQTT Publish - send an MQTT publish packet and wait for all acks to complete for all QoSs
      *  @param topic - the topic to publish to
      *  @param payload - the data to send
      *  @param payloadlen - the length of the data
-     *  @param id - the packet id used - returned 
+     *  @param id - the packet id used - returned
      *  @param qos - the QoS to send the publish at
      *  @param retained - whether the message should be retained
      *  @return success code -
@@ -165,6 +198,15 @@
      */
     int subscribe(const char* topicFilter, enum QoS qos, messageHandler mh);
 
+    /** MQTT Subscribe - send an MQTT subscribe packet and wait for the suback
+     *  @param topicFilter - a topic pattern which can include wildcards
+     *  @param qos - the MQTT QoS to subscribe at©
+     *  @param mh - the callback function to be invoked when a message is received for this subscription
+     *  @param
+     *  @return success code -
+     */
+    int subscribe(const char* topicFilter, enum QoS qos, messageHandler mh, subackData &data);
+
     /** MQTT Unsubscribe - send an MQTT unsubscribe packet and wait for the unsuback
      *  @param topicFilter - a topic pattern which can include wildcards
      *  @return success code -
@@ -194,6 +236,7 @@
 
 private:
 
+    void closeSession();
     void cleanSession();
     int cycle(Timer& timer);
     int waitfor(int packet_type, Timer& timer);
@@ -253,12 +296,10 @@
 
 
 template<class Network, class Timer, int a, int MAX_MESSAGE_HANDLERS>
-void MQTT::Client<Network, Timer, a, MAX_MESSAGE_HANDLERS>::cleanSession() 
+void MQTT::Client<Network, Timer, a, MAX_MESSAGE_HANDLERS>::cleanSession()
 {
-    ping_outstanding = false;
     for (int i = 0; i < MAX_MESSAGE_HANDLERS; ++i)
         messageHandlers[i].topicFilter = 0;
-    isconnected = false;
 
 #if MQTTCLIENT_QOS1 || MQTTCLIENT_QOS2
     inflightMsgid = 0;
@@ -274,12 +315,21 @@
 
 
 template<class Network, class Timer, int a, int MAX_MESSAGE_HANDLERS>
+void MQTT::Client<Network, Timer, a, MAX_MESSAGE_HANDLERS>::closeSession()
+{
+    ping_outstanding = false;
+    isconnected = false;
+    if (cleansession)
+        cleanSession();
+}
+
+
+template<class Network, class Timer, int a, int MAX_MESSAGE_HANDLERS>
 MQTT::Client<Network, Timer, a, MAX_MESSAGE_HANDLERS>::Client(Network& network, unsigned int command_timeout_ms)  : ipstack(network), packetid()
 {
-    last_sent = Timer();
-    last_received = Timer();
     this->command_timeout_ms = command_timeout_ms;
-    cleanSession();
+    cleansession = true;
+      closeSession();
 }
 
 
@@ -347,7 +397,7 @@
     }
     else
         rc = FAILURE;
-        
+
 #if defined(MQTT_DEBUG)
     char printbuf[150];
     DEBUG("Rc %d from sending packet %s\n", rc, MQTTFormat_toServerString(printbuf, sizeof(printbuf), sendbuf, length));
@@ -389,7 +439,7 @@
  * If any read fails in this method, then we should disconnect from the network, as on reconnect
  * the packets can be retried.
  * @param timeout the max time to wait for the packet read to complete, in milliseconds
- * @return the MQTT packet type, or -1 if none
+ * @return the MQTT packet type, 0 if none, -1 if error
  */
 template<class Network, class Timer, int MAX_MQTT_PACKET_SIZE, int b>
 int MQTT::Client<Network, Timer, MAX_MQTT_PACKET_SIZE, b>::readPacket(Timer& timer)
@@ -400,7 +450,8 @@
     int rem_len = 0;
 
     /* 1. read the header byte.  This has the packet type in it */
-    if (ipstack.read(readbuf, 1, timer.left_ms()) != 1)
+    rc = ipstack.read(readbuf, 1, timer.left_ms());
+    if (rc != 1)
         goto exit;
 
     len = 1;
@@ -423,12 +474,13 @@
     if (this->keepAliveInterval > 0)
         last_received.countdown(this->keepAliveInterval); // record the fact that we have successfully received a packet
 exit:
-        
+
 #if defined(MQTT_DEBUG)
     if (rc >= 0)
     {
         char printbuf[50];
-        DEBUG("Rc %d from receiving packet %s\n", rc, MQTTFormat_toClientString(printbuf, sizeof(printbuf), readbuf, len));
+        DEBUG("Rc %d from receiving packet %s\n", rc,
+            MQTTFormat_toClientString(printbuf, sizeof(printbuf), readbuf, len));
     }
 #endif
     return rc;
@@ -504,7 +556,7 @@
 int MQTT::Client<Network, Timer, a, b>::yield(unsigned long timeout_ms)
 {
     int rc = SUCCESS;
-    Timer timer = Timer();
+    Timer timer;
 
     timer.countdown_ms(timeout_ms);
     while (!timer.expired())
@@ -523,19 +575,19 @@
 template<class Network, class Timer, int MAX_MQTT_PACKET_SIZE, int b>
 int MQTT::Client<Network, Timer, MAX_MQTT_PACKET_SIZE, b>::cycle(Timer& timer)
 {
-    /* get one piece of work off the wire and one pass through */
-
-    // read the socket, see what work is due
-    int packet_type = readPacket(timer);
-
+    // get one piece of work off the wire and one pass through
     int len = 0,
         rc = SUCCESS;
 
+    int packet_type = readPacket(timer);    // read the socket, see what work is due
+
     switch (packet_type)
     {
-        case FAILURE:
-        case BUFFER_OVERFLOW:
+        default:
+            // no more data to read, unrecoverable. Or read packet fails due to unexpected network error
             rc = packet_type;
+            goto exit;
+        case 0: // timed out reading packet
             break;
         case CONNACK:
         case PUBACK:
@@ -546,6 +598,7 @@
             MQTTString topicName = MQTTString_initializer;
             Message msg;
             int intQoS;
+            msg.payloadlen = 0; /* this is a size_t, but deserialize publish sets this as int */
             if (MQTTDeserialize_publish((unsigned char*)&msg.dup, &intQoS, (unsigned char*)&msg.retained, (unsigned short*)&msg.id, &topicName,
                                  (unsigned char**)&msg.payload, (int*)&msg.payloadlen, readbuf, MAX_MQTT_PACKET_SIZE) != 1)
                 goto exit;
@@ -587,8 +640,8 @@
             unsigned char dup, type;
             if (MQTTDeserialize_ack(&type, &dup, &mypacketid, readbuf, MAX_MQTT_PACKET_SIZE) != 1)
                 rc = FAILURE;
-            else if ((len = MQTTSerialize_ack(sendbuf, MAX_MQTT_PACKET_SIZE, 
-                        (packet_type == PUBREC) ? PUBREL : PUBCOMP, 0, mypacketid)) <= 0)
+            else if ((len = MQTTSerialize_ack(sendbuf, MAX_MQTT_PACKET_SIZE,
+                                 (packet_type == PUBREC) ? PUBREL : PUBCOMP, 0, mypacketid)) <= 0)
                 rc = FAILURE;
             else if ((rc = sendPacket(len, timer)) != SUCCESS) // send the PUBREL packet
                 rc = FAILURE; // there was a problem
@@ -597,7 +650,7 @@
             if (packet_type == PUBREL)
                 freeQoS2msgid(mypacketid);
             break;
-            
+
         case PUBCOMP:
             break;
 #endif
@@ -605,10 +658,16 @@
             ping_outstanding = false;
             break;
     }
-    keepalive();
+
+    if (keepalive() != SUCCESS)
+        //check only keepalive FAILURE status so that previous FAILURE status can be considered as FAULT
+        rc = FAILURE;
+
 exit:
     if (rc == SUCCESS)
         rc = packet_type;
+    else if (isconnected)
+        closeSession();
     return rc;
 }
 
@@ -616,17 +675,22 @@
 template<class Network, class Timer, int MAX_MQTT_PACKET_SIZE, int b>
 int MQTT::Client<Network, Timer, MAX_MQTT_PACKET_SIZE, b>::keepalive()
 {
-    int rc = FAILURE;
+    int rc = SUCCESS;
 
     if (keepAliveInterval == 0)
-    {
-        rc = SUCCESS;
         goto exit;
-    }
 
     if (last_sent.expired() || last_received.expired())
     {
-        if (!ping_outstanding)
+        if (ping_outstanding)
+        {
+            rc = FAILURE; // session failure
+            #if defined(MQTT_DEBUG)
+                char printbuf[150];
+                DEBUG("PINGRESP not received in keepalive interval\n");
+            #endif
+        }
+        else
         {
             Timer timer(1000);
             int len = MQTTSerialize_pingreq(sendbuf, MAX_MQTT_PACKET_SIZE);
@@ -650,15 +714,16 @@
     {
         if (timer.expired())
             break; // we timed out
+        rc = cycle(timer);
     }
-    while ((rc = cycle(timer)) != packet_type);
+    while (rc != packet_type && rc >= 0);
 
     return rc;
 }
 
 
 template<class Network, class Timer, int MAX_MQTT_PACKET_SIZE, int b>
-int MQTT::Client<Network, Timer, MAX_MQTT_PACKET_SIZE, b>::connect(MQTTPacket_connectData& options)
+int MQTT::Client<Network, Timer, MAX_MQTT_PACKET_SIZE, b>::connect(MQTTPacket_connectData& options, connackData& data)
 {
     Timer connect_timer(command_timeout_ms);
     int rc = FAILURE;
@@ -679,10 +744,11 @@
     // this will be a blocking call, wait for the connack
     if (waitfor(CONNACK, connect_timer) == CONNACK)
     {
-        unsigned char connack_rc = 255;
-        bool sessionPresent = false;
-        if (MQTTDeserialize_connack((unsigned char*)&sessionPresent, &connack_rc, readbuf, MAX_MQTT_PACKET_SIZE) == 1)
-            rc = connack_rc;
+        data.rc = 0;
+        data.sessionPresent = false;
+        if (MQTTDeserialize_connack((unsigned char*)&data.sessionPresent,
+                            (unsigned char*)&data.rc, readbuf, MAX_MQTT_PACKET_SIZE) == 1)
+            rc = data.rc;
         else
             rc = FAILURE;
     }
@@ -710,12 +776,23 @@
 
 exit:
     if (rc == SUCCESS)
+    {
         isconnected = true;
+        ping_outstanding = false;
+    }
     return rc;
 }
 
 
 template<class Network, class Timer, int MAX_MQTT_PACKET_SIZE, int b>
+int MQTT::Client<Network, Timer, MAX_MQTT_PACKET_SIZE, b>::connect(MQTTPacket_connectData& options)
+{
+    connackData data;
+    return connect(options, data);
+}
+
+
+template<class Network, class Timer, int MAX_MQTT_PACKET_SIZE, int b>
 int MQTT::Client<Network, Timer, MAX_MQTT_PACKET_SIZE, b>::connect()
 {
     MQTTPacket_connectData default_options = MQTTPacket_connectData_initializer;
@@ -724,7 +801,51 @@
 
 
 template<class Network, class Timer, int MAX_MQTT_PACKET_SIZE, int MAX_MESSAGE_HANDLERS>
-int MQTT::Client<Network, Timer, MAX_MQTT_PACKET_SIZE, MAX_MESSAGE_HANDLERS>::subscribe(const char* topicFilter, enum QoS qos, messageHandler messageHandler)
+int MQTT::Client<Network, Timer, MAX_MQTT_PACKET_SIZE, MAX_MESSAGE_HANDLERS>::setMessageHandler(const char* topicFilter, messageHandler messageHandler)
+{
+    int rc = FAILURE;
+    int i = -1;
+
+    // first check for an existing matching slot
+    for (i = 0; i < MAX_MESSAGE_HANDLERS; ++i)
+    {
+        if (messageHandlers[i].topicFilter != 0 && strcmp(messageHandlers[i].topicFilter, topicFilter) == 0)
+        {
+            if (messageHandler == 0) // remove existing
+            {
+                messageHandlers[i].topicFilter = 0;
+                messageHandlers[i].fp.detach();
+            }
+            rc = SUCCESS; // return i when adding new subscription
+            break;
+        }
+    }
+    // if no existing, look for empty slot (unless we are removing)
+    if (messageHandler != 0) {
+        if (rc == FAILURE)
+        {
+            for (i = 0; i < MAX_MESSAGE_HANDLERS; ++i)
+            {
+                if (messageHandlers[i].topicFilter == 0)
+                {
+                    rc = SUCCESS;
+                    break;
+                }
+            }
+        }
+        if (i < MAX_MESSAGE_HANDLERS)
+        {
+            messageHandlers[i].topicFilter = topicFilter;
+            messageHandlers[i].fp.attach(messageHandler);
+        }
+    }
+    return rc;
+}
+
+
+template<class Network, class Timer, int MAX_MQTT_PACKET_SIZE, int MAX_MESSAGE_HANDLERS>
+int MQTT::Client<Network, Timer, MAX_MQTT_PACKET_SIZE, MAX_MESSAGE_HANDLERS>::subscribe(const char* topicFilter,
+     enum QoS qos, messageHandler messageHandler, subackData& data)
 {
     int rc = FAILURE;
     Timer timer(command_timeout_ms);
@@ -742,35 +863,34 @@
 
     if (waitfor(SUBACK, timer) == SUBACK)      // wait for suback
     {
-        int count = 0, grantedQoS = -1;
+        int count = 0;
         unsigned short mypacketid;
-        if (MQTTDeserialize_suback(&mypacketid, 1, &count, &grantedQoS, readbuf, MAX_MQTT_PACKET_SIZE) == 1)
-            rc = grantedQoS; // 0, 1, 2 or 0x80
-        if (rc != 0x80)
+        data.grantedQoS = 0;
+        if (MQTTDeserialize_suback(&mypacketid, 1, &count, &data.grantedQoS, readbuf, MAX_MQTT_PACKET_SIZE) == 1)
         {
-            for (int i = 0; i < MAX_MESSAGE_HANDLERS; ++i)
-            {
-                if (messageHandlers[i].topicFilter == 0)
-                {
-                    messageHandlers[i].topicFilter = topicFilter;
-                    messageHandlers[i].fp.attach(messageHandler);
-                    rc = 0;
-                    break;
-                }
-            }
+            if (data.grantedQoS != 0x80)
+                rc = setMessageHandler(topicFilter, messageHandler);
         }
     }
     else
         rc = FAILURE;
 
 exit:
-    if (rc != SUCCESS)
-        cleanSession();
+    if (rc == FAILURE)
+        closeSession();
     return rc;
 }
 
 
 template<class Network, class Timer, int MAX_MQTT_PACKET_SIZE, int MAX_MESSAGE_HANDLERS>
+int MQTT::Client<Network, Timer, MAX_MQTT_PACKET_SIZE, MAX_MESSAGE_HANDLERS>::subscribe(const char* topicFilter, enum QoS qos, messageHandler messageHandler)
+{
+    subackData data;
+    return subscribe(topicFilter, qos, messageHandler, data);
+}
+
+
+template<class Network, class Timer, int MAX_MQTT_PACKET_SIZE, int MAX_MESSAGE_HANDLERS>
 int MQTT::Client<Network, Timer, MAX_MQTT_PACKET_SIZE, MAX_MESSAGE_HANDLERS>::unsubscribe(const char* topicFilter)
 {
     int rc = FAILURE;
@@ -791,17 +911,8 @@
         unsigned short mypacketid;  // should be the same as the packetid above
         if (MQTTDeserialize_unsuback(&mypacketid, readbuf, MAX_MQTT_PACKET_SIZE) == 1)
         {
-            rc = 0;
-
             // remove the subscription message handler associated with this topic, if there is one
-            for (int i = 0; i < MAX_MESSAGE_HANDLERS; ++i)
-            {
-                if (messageHandlers[i].topicFilter && strcmp(messageHandlers[i].topicFilter, topicFilter) == 0)
-                {
-                    messageHandlers[i].topicFilter = 0;
-                    break;
-                }
-            }
+            setMessageHandler(topicFilter, 0);
         }
     }
     else
@@ -809,7 +920,7 @@
 
 exit:
     if (rc != SUCCESS)
-        cleanSession();
+        closeSession();
     return rc;
 }
 
@@ -837,7 +948,8 @@
         else
             rc = FAILURE;
     }
-#elif MQTTCLIENT_QOS2
+#endif
+#if MQTTCLIENT_QOS2
     else if (qos == QOS2)
     {
         if (waitfor(PUBCOMP, timer) == PUBCOMP)
@@ -856,7 +968,7 @@
 
 exit:
     if (rc != SUCCESS)
-        cleanSession();
+        closeSession();
     return rc;
 }
 
@@ -923,17 +1035,12 @@
 int MQTT::Client<Network, Timer, MAX_MQTT_PACKET_SIZE, b>::disconnect()
 {
     int rc = FAILURE;
-    Timer timer(command_timeout_ms);            // we might wait for incomplete incoming publishes to complete
+    Timer timer(command_timeout_ms);     // we might wait for incomplete incoming publishes to complete
     int len = MQTTSerialize_disconnect(sendbuf, MAX_MQTT_PACKET_SIZE);
     if (len > 0)
         rc = sendPacket(len, timer);            // send the disconnect packet
-
-    if (cleansession)
-        cleanSession();
-    else
-        isconnected = false;
+    closeSession();
     return rc;
 }
 
-
-#endif
\ No newline at end of file
+#endif