Better timing

Dependencies:   FP MQTTPacket

Fork of MQTT by MQTT

Revision:
46:e335fcc1a663
Parent:
44:c299463ae853
Child:
53:15b5a280d22d
--- a/MQTTClient.h	Mon Aug 03 12:40:57 2015 +0000
+++ b/MQTTClient.h	Tue Aug 18 09:57:19 2015 +0000
@@ -1,5 +1,5 @@
 /*******************************************************************************
- * Copyright (c) 2014 IBM Corp.
+ * Copyright (c) 2014, 2015 IBM Corp.
  *
  * All rights reserved. This program and the accompanying materials
  * are made available under the terms of the Eclipse Public License v1.0
@@ -12,6 +12,10 @@
  *
  * Contributors:
  *    Ian Craggs - initial API and implementation and/or initial documentation
+ *    Ian Craggs - fix for bug 458512 - QoS 2 messages
+ *    Ian Craggs - fix for bug 460389 - send loop uses wrong length
+ *    Ian Craggs - fix for bug 464169 - clearing subscriptions
+ *    Ian Craggs - fix for bug 464551 - enums and ints can be different size
  *******************************************************************************/
 
 #if !defined(MQTTCLIENT_H)
@@ -118,7 +122,7 @@
      */
     int connect();
     
-    /** MQTT Connect - send an MQTT connect packet down the network and wait for a Connack
+        /** MQTT Connect - send an MQTT connect packet down the network and wait for a Connack
      *  The nework object must be connected to the network endpoint before calling this
      *  @param options - connect options
      *  @return success code -
@@ -190,6 +194,7 @@
 
 private:
 
+    void cleanSession();
     int cycle(Timer& timer);
     int waitfor(int packet_type, Timer& timer);
     int keepalive();
@@ -239,6 +244,7 @@
     unsigned short incomingQoS2messages[MAX_INCOMING_QOS2_MESSAGES];
     bool isQoS2msgidFree(unsigned short id);
     bool useQoS2msgid(unsigned short id);
+    void freeQoS2msgid(unsigned short id);
 #endif
 
 };
@@ -247,28 +253,35 @@
 
 
 template<class Network, class Timer, int a, int MAX_MESSAGE_HANDLERS>
+void MQTT::Client<Network, Timer, a, MAX_MESSAGE_HANDLERS>::cleanSession() 
+{
+    ping_outstanding = false;
+    for (int i = 0; i < MAX_MESSAGE_HANDLERS; ++i)
+        messageHandlers[i].topicFilter = 0;
+    isconnected = false;
+
+#if MQTTCLIENT_QOS1 || MQTTCLIENT_QOS2
+    inflightMsgid = 0;
+    inflightQoS = QOS0;
+#endif
+
+#if MQTTCLIENT_QOS2
+    pubrel = false;
+    for (int i = 0; i < MAX_INCOMING_QOS2_MESSAGES; ++i)
+        incomingQoS2messages[i] = 0;
+#endif
+}
+
+
+template<class Network, class Timer, int a, int MAX_MESSAGE_HANDLERS>
 MQTT::Client<Network, Timer, a, MAX_MESSAGE_HANDLERS>::Client(Network& network, unsigned int command_timeout_ms)  : ipstack(network), packetid()
 {
     last_sent = Timer();
     last_received = Timer();
-    ping_outstanding = false;
-    for (int i = 0; i < MAX_MESSAGE_HANDLERS; ++i)
-        messageHandlers[i].topicFilter = 0;
     this->command_timeout_ms = command_timeout_ms;
-    isconnected = false;
-    
-#if MQTTCLIENT_QOS1 || MQTTCLIENT_QOS2
-    inflightMsgid = 0;
-    inflightQoS = QOS0;
-#endif
+    cleanSession();
+}
 
-    
-#if MQTTCLIENT_QOS2
-    pubrel = false;
-    for (int i = 0; i < MAX_INCOMING_QOS2_MESSAGES; ++i)
-        incomingQoS2messages[i] = 0;
-#endif
-}
 
 #if MQTTCLIENT_QOS2
 template<class Network, class Timer, int a, int b>
@@ -296,6 +309,20 @@
     }
     return false;
 }
+
+
+template<class Network, class Timer, int a, int b>
+void MQTT::Client<Network, Timer, a, b>::freeQoS2msgid(unsigned short id)
+{
+    for (int i = 0; i < MAX_INCOMING_QOS2_MESSAGES; ++i)
+    {
+        if (incomingQoS2messages[i] == id)
+        {
+            incomingQoS2messages[i] = 0;
+            return;
+        }
+    }
+}
 #endif
 
 
@@ -307,7 +334,7 @@
 
     while (sent < length && !timer.expired())
     {
-        rc = ipstack.write(&sendbuf[sent], length, timer.left_ms());
+        rc = ipstack.write(&sendbuf[sent], length - sent, timer.left_ms());
         if (rc < 0)  // there was an error writing the data
             break;
         sent += rc;
@@ -322,8 +349,8 @@
         rc = FAILURE;
         
 #if defined(MQTT_DEBUG)
-    char printbuf[50];
-    DEBUG("Rc %d from sending packet %s\n", rc, MQTTPacket_toString(printbuf, sizeof(printbuf), sendbuf, length));
+    char printbuf[150];
+    DEBUG("Rc %d from sending packet %s\n", rc, MQTTFormat_toServerString(printbuf, sizeof(printbuf), sendbuf, length));
 #endif
     return rc;
 }
@@ -364,8 +391,8 @@
  * @param timeout the max time to wait for the packet read to complete, in milliseconds
  * @return the MQTT packet type, or -1 if none
  */
-template<class Network, class Timer, int a, int b>
-int MQTT::Client<Network, Timer, a, b>::readPacket(Timer& timer)
+template<class Network, class Timer, int MAX_MQTT_PACKET_SIZE, int b>
+int MQTT::Client<Network, Timer, MAX_MQTT_PACKET_SIZE, b>::readPacket(Timer& timer)
 {
     int rc = FAILURE;
     MQTTHeader header = {0};
@@ -381,6 +408,12 @@
     decodePacket(&rem_len, timer.left_ms());
     len += MQTTPacket_encode(readbuf + 1, rem_len); /* put the original remaining length into the buffer */
 
+    if (rem_len > (MAX_MQTT_PACKET_SIZE - len))
+    {
+        rc = BUFFER_OVERFLOW;
+        goto exit;
+    }
+
     /* 3. read the rest of the buffer using a callback to supply the rest of the data */
     if (rem_len > 0 && (ipstack.read(readbuf + len, rem_len, timer.left_ms()) != rem_len))
         goto exit;
@@ -392,8 +425,11 @@
 exit:
         
 #if defined(MQTT_DEBUG)
-    char printbuf[50];
-    DEBUG("Rc %d from receiving packet %s\n", rc, MQTTPacket_toString(printbuf, sizeof(printbuf), readbuf, len));
+    if (rc >= 0)
+    {
+        char printbuf[50];
+        DEBUG("Rc %d from receiving packet %s\n", rc, MQTTFormat_toClientString(printbuf, sizeof(printbuf), readbuf, len));
+    }
 #endif
     return rc;
 }
@@ -473,7 +509,7 @@
     timer.countdown_ms(timeout_ms);
     while (!timer.expired())
     {
-        if (cycle(timer) == FAILURE)
+        if (cycle(timer) < 0)
         {
             rc = FAILURE;
             break;
@@ -490,23 +526,30 @@
     /* get one piece of work off the wire and one pass through */
 
     // read the socket, see what work is due
-    unsigned short packet_type = readPacket(timer);
+    int packet_type = readPacket(timer);
 
     int len = 0,
         rc = SUCCESS;
 
     switch (packet_type)
     {
+        case FAILURE:
+        case BUFFER_OVERFLOW:
+            rc = packet_type;
+            break;
         case CONNACK:
         case PUBACK:
         case SUBACK:
             break;
         case PUBLISH:
-            MQTTString topicName;
+        {
+            MQTTString topicName = MQTTString_initializer;
             Message msg;
-            if (MQTTDeserialize_publish((unsigned char*)&msg.dup, (int*)&msg.qos, (unsigned char*)&msg.retained, (unsigned short*)&msg.id, &topicName,
+            int intQoS;
+            if (MQTTDeserialize_publish((unsigned char*)&msg.dup, &intQoS, (unsigned char*)&msg.retained, (unsigned short*)&msg.id, &topicName,
                                  (unsigned char**)&msg.payload, (int*)&msg.payloadlen, readbuf, MAX_MQTT_PACKET_SIZE) != 1)
                 goto exit;
+            msg.qos = (enum QoS)intQoS;
 #if MQTTCLIENT_QOS2
             if (msg.qos != QOS2)
 #endif
@@ -518,7 +561,7 @@
                     deliverMessage(topicName, msg);
                 else
                     WARN("Maximum number of incoming QoS2 messages exceeded");
-            }   
+            }
 #endif
 #if MQTTCLIENT_QOS1 || MQTTCLIENT_QOS2
             if (msg.qos != QOS0)
@@ -536,19 +579,25 @@
             }
             break;
 #endif
+        }
 #if MQTTCLIENT_QOS2
         case PUBREC:
+        case PUBREL:
             unsigned short mypacketid;
             unsigned char dup, type;
             if (MQTTDeserialize_ack(&type, &dup, &mypacketid, readbuf, MAX_MQTT_PACKET_SIZE) != 1)
                 rc = FAILURE;
-            else if ((len = MQTTSerialize_ack(sendbuf, MAX_MQTT_PACKET_SIZE, PUBREL, 0, mypacketid)) <= 0)
+            else if ((len = MQTTSerialize_ack(sendbuf, MAX_MQTT_PACKET_SIZE, 
+                        (packet_type == PUBREC) ? PUBREL : PUBCOMP, 0, mypacketid)) <= 0)
                 rc = FAILURE;
             else if ((rc = sendPacket(len, timer)) != SUCCESS) // send the PUBREL packet
                 rc = FAILURE; // there was a problem
             if (rc == FAILURE)
                 goto exit; // there was a problem
+            if (packet_type == PUBREL)
+                freeQoS2msgid(mypacketid);
             break;
+            
         case PUBCOMP:
             break;
 #endif
@@ -579,7 +628,7 @@
     {
         if (!ping_outstanding)
         {
-            Timer timer = Timer(1000);
+            Timer timer(1000);
             int len = MQTTSerialize_pingreq(sendbuf, MAX_MQTT_PACKET_SIZE);
             if (len > 0 && (rc = sendPacket(len, timer)) == SUCCESS) // send the ping packet
                 ping_outstanding = true;
@@ -611,7 +660,7 @@
 template<class Network, class Timer, int MAX_MQTT_PACKET_SIZE, int b>
 int MQTT::Client<Network, Timer, MAX_MQTT_PACKET_SIZE, b>::connect(MQTTPacket_connectData& options)
 {
-    Timer connect_timer = Timer(command_timeout_ms);
+    Timer connect_timer(command_timeout_ms);
     int rc = FAILURE;
     int len = 0;
 
@@ -639,10 +688,10 @@
     }
     else
         rc = FAILURE;
-        
+
 #if MQTTCLIENT_QOS2
-    // resend an inflight publish
-    if (inflightMsgid >0 && inflightQoS == QOS2 && pubrel)
+    // resend any inflight publish
+    if (inflightMsgid > 0 && inflightQoS == QOS2 && pubrel)
     {
         if ((len = MQTTSerialize_ack(sendbuf, MAX_MQTT_PACKET_SIZE, PUBREL, 0, inflightMsgid)) <= 0)
             rc = FAILURE;
@@ -678,9 +727,9 @@
 int MQTT::Client<Network, Timer, MAX_MQTT_PACKET_SIZE, MAX_MESSAGE_HANDLERS>::subscribe(const char* topicFilter, enum QoS qos, messageHandler messageHandler)
 {
     int rc = FAILURE;
-    Timer timer = Timer(command_timeout_ms);
+    Timer timer(command_timeout_ms);
     int len = 0;
-    MQTTString topic = {(char*)topicFilter, 0, 0};
+    MQTTString topic = {(char*)topicFilter, {0, 0}};
 
     if (!isconnected)
         goto exit;
@@ -716,7 +765,7 @@
 
 exit:
     if (rc != SUCCESS)
-        isconnected = false;
+        cleanSession();
     return rc;
 }
 
@@ -725,8 +774,8 @@
 int MQTT::Client<Network, Timer, MAX_MQTT_PACKET_SIZE, MAX_MESSAGE_HANDLERS>::unsubscribe(const char* topicFilter)
 {
     int rc = FAILURE;
-    Timer timer = Timer(command_timeout_ms);
-    MQTTString topic = {(char*)topicFilter, 0, 0};
+    Timer timer(command_timeout_ms);
+    MQTTString topic = {(char*)topicFilter, {0, 0}};
     int len = 0;
 
     if (!isconnected)
@@ -741,14 +790,26 @@
     {
         unsigned short mypacketid;  // should be the same as the packetid above
         if (MQTTDeserialize_unsuback(&mypacketid, readbuf, MAX_MQTT_PACKET_SIZE) == 1)
+        {
             rc = 0;
+
+            // remove the subscription message handler associated with this topic, if there is one
+            for (int i = 0; i < MAX_MESSAGE_HANDLERS; ++i)
+            {
+                if (messageHandlers[i].topicFilter && strcmp(messageHandlers[i].topicFilter, topicFilter) == 0)
+                {
+                    messageHandlers[i].topicFilter = 0;
+                    break;
+                }
+            }
+        }
     }
     else
         rc = FAILURE;
 
 exit:
     if (rc != SUCCESS)
-        isconnected = false;
+        cleanSession();
     return rc;
 }
 
@@ -757,11 +818,11 @@
 int MQTT::Client<Network, Timer, MAX_MQTT_PACKET_SIZE, b>::publish(int len, Timer& timer, enum QoS qos)
 {
     int rc;
-    
+
     if ((rc = sendPacket(len, timer)) != SUCCESS) // send the publish packet
         goto exit; // there was a problem
 
-#if MQTTCLIENT_QOS1 
+#if MQTTCLIENT_QOS1
     if (qos == QOS1)
     {
         if (waitfor(PUBACK, timer) == PUBACK)
@@ -795,7 +856,7 @@
 
 exit:
     if (rc != SUCCESS)
-        isconnected = false;
+        cleanSession();
     return rc;
 }
 
@@ -805,13 +866,13 @@
 int MQTT::Client<Network, Timer, MAX_MQTT_PACKET_SIZE, b>::publish(const char* topicName, void* payload, size_t payloadlen, unsigned short& id, enum QoS qos, bool retained)
 {
     int rc = FAILURE;
-    Timer timer = Timer(command_timeout_ms);
+    Timer timer(command_timeout_ms);
     MQTTString topicString = MQTTString_initializer;
     int len = 0;
 
     if (!isconnected)
         goto exit;
-        
+
     topicString.cstring = (char*)topicName;
 
 #if MQTTCLIENT_QOS1 || MQTTCLIENT_QOS2
@@ -823,7 +884,7 @@
               topicString, (unsigned char*)payload, payloadlen);
     if (len <= 0)
         goto exit;
-        
+
 #if MQTTCLIENT_QOS1 || MQTTCLIENT_QOS2
     if (!cleansession)
     {
@@ -836,7 +897,7 @@
 #endif
     }
 #endif
-        
+
     rc = publish(len, timer, qos);
 exit:
     return rc;
@@ -862,14 +923,17 @@
 int MQTT::Client<Network, Timer, MAX_MQTT_PACKET_SIZE, b>::disconnect()
 {
     int rc = FAILURE;
-    Timer timer = Timer(command_timeout_ms);     // we might wait for incomplete incoming publishes to complete
+    Timer timer(command_timeout_ms);            // we might wait for incomplete incoming publishes to complete
     int len = MQTTSerialize_disconnect(sendbuf, MAX_MQTT_PACKET_SIZE);
     if (len > 0)
         rc = sendPacket(len, timer);            // send the disconnect packet
 
-    isconnected = false;
+    if (cleansession)
+        cleanSession();
+    else
+        isconnected = false;
     return rc;
 }
 
 
-#endif
+#endif
\ No newline at end of file