Modified MQTT for Mbed OS.

Dependencies:   FP MQTTPacket

Dependents:   mbed-os-mqtt door_lock co657_IoT nucleo-f429zi-mbed-os-mqtt

Fork of MQTT by MQTT

Revision:
30:a4e3a97dabe3
Parent:
29:833386b16f3e
Parent:
28:8b2abe9bd814
Child:
31:a51dd239b78e
diff -r 833386b16f3e -r a4e3a97dabe3 MQTTClient.h
--- a/MQTTClient.h	Tue May 20 15:03:29 2014 +0000
+++ b/MQTTClient.h	Tue May 20 15:07:11 2014 +0000
@@ -18,13 +18,7 @@
  
  TODO: 
  
- log messages - use macros
- 
- define return code constants
- 
- call connectionLost at appropriate points - in sendPacket and readPacket
- 
- match wildcard topics
+ ensure publish packets are retried on reconnect
  
  updating usage of FP. Try to remove inclusion of FP.cpp in main. sg-
  
@@ -96,14 +90,7 @@
 {
     
 public:
-
-    typedef struct
-    {
-        Client* client;
-        Network* network;
-    } connectionLostInfo;
-    
-    typedef int (*connectionLostHandlers)(connectionLostInfo*);
+   
     typedef void (*messageHandler)(Message*);
 
     /** Construct the client
@@ -113,14 +100,6 @@
      */
     Client(Network& network, unsigned int command_timeout_ms = 30000); 
     
-    /** Set the connection lost callback - called whenever the connection is lost and we should be connected
-     *  @param clh - pointer to the callback function
-     */
-    void setConnectionLostHandler(connectionLostHandlers clh)
-    {
-        connectionLostHandler.attach(clh);
-    }
-    
     /** Set the default message handling callback - used for any message which does not match a subscription message handler
      *  @param mh - pointer to the callback function
      */
@@ -166,8 +145,9 @@
      *  yield can be called if no other MQTT operation is needed.  This will also allow messages to be 
      *  received.
      *  @param timeout_ms the time to wait, in milliseconds
+     *  @return success code - on failure, this means the client has disconnected
      */
-    void yield(int timeout_ms = 1000);
+    int yield(int timeout_ms = 1000);
     
 private:
 
@@ -178,7 +158,8 @@
     int decodePacket(int* value, int timeout);
     int readPacket(Timer& timer);
     int sendPacket(int length, Timer& timer);
-    int deliverMessage(MQTTString* topic, Message* message);
+    int deliverMessage(MQTTString& topicName, Message& message);
+    bool isTopicMatched(char* topicFilter, MQTTString& topicName);
     
     Network& ipstack;
     unsigned int command_timeout_ms;
@@ -192,19 +173,18 @@
     
     PacketId packetid;
     
-//    typedef FP<void, Message*> messageHandlerFP;
-//    FP<void, Message*> messageHandlerFP;
+    // typedef FP<void, Message*> messageHandlerFP;
     struct MessageHandlers
     {
-        const char* topic;
+        const char* topicFilter;
+        //messageHandlerFP fp; typedefs not liked?
         FP<void, Message*> fp;
     } messageHandlers[MAX_MESSAGE_HANDLERS];      // Message handlers are indexed by subscription topic
     
     FP<void, Message*> defaultMessageHandler;
-    
-    FP<int, connectionLostInfo*> connectionLostHandler;
-//    connectionLostFP connectionLostHandler;
-    
+     
+    bool isconnected;
+
 };
 
 }
@@ -216,8 +196,9 @@
     ping_timer = Timer();
     ping_outstanding = 0;
     for (int i = 0; i < MAX_MESSAGE_HANDLERS; ++i)
-        messageHandlers[i].topic = 0;
-    this->command_timeout_ms = command_timeout_ms;    
+        messageHandlers[i].topicFilter = 0;
+    this->command_timeout_ms = command_timeout_ms; 
+    isconnected = false;
 }
 
 
@@ -230,13 +211,9 @@
     while (sent < length && !timer.expired())
     {
         rc = ipstack.write(&buf[sent], length, timer.left_ms());
-        if (rc == -1)
-        {
-            connectionLostInfo info = {this, &ipstack};
-            connectionLostHandler(&info);
-        }
-        else
-            sent += rc;
+        if (rc < 0)  // there was an error writing the data
+            break;
+        sent += rc;
     }
     if (sent == length)
     {
@@ -249,7 +226,8 @@
 }
 
 
-template<class Network, class Timer, int a, int b> int MQTT::Client<Network, Timer, a, b>::decodePacket(int* value, int timeout)
+template<class Network, class Timer, int a, int b> 
+int MQTT::Client<Network, Timer, a, b>::decodePacket(int* value, int timeout)
 {
     char c;
     int multiplier = 1;
@@ -286,7 +264,7 @@
 template<class Network, class Timer, int a, int b> 
 int MQTT::Client<Network, Timer, a, b>::readPacket(Timer& timer) 
 {
-    int rc = -1;
+    int rc = FAILURE;
     MQTTHeader header = {0};
     int len = 0;
     int rem_len = 0;
@@ -311,23 +289,63 @@
 }
 
 
+// assume topic filter and name is in correct format
+// # can only be at end
+// + and # can only be next to separator
+template<class Network, class Timer, int a, int b> 
+bool MQTT::Client<Network, Timer, a, b>::isTopicMatched(char* topicFilter, MQTTString& topicName)
+{
+    char* curf = topicFilter;
+    char* curn = topicName.lenstring.data;
+    char* curn_end = curn + topicName.lenstring.len;
+    
+    while (*curf && curn < curn_end)
+    {
+        if (*curn == '/' && *curf != '/')
+            break;
+        if (*curf != '+' && *curf != '#' && *curf != *curn)
+            break;
+        if (*curf == '+')
+        {   // skip until we meet the next separator, or end of string
+            char* nextpos = curn + 1;
+            while (nextpos < curn_end && *nextpos != '/')
+                nextpos = ++curn + 1;
+        }
+        else if (*curf == '#')
+            curn = curn_end - 1;    // skip until end of string
+        curf++;
+        curn++;
+    };
+    
+    return (curn == curn_end) && (*curf == '\0');
+}
+
+
+
 template<class Network, class Timer, int a, int MAX_MESSAGE_HANDLERS> 
-int MQTT::Client<Network, Timer, a, MAX_MESSAGE_HANDLERS>::deliverMessage(MQTTString* topic, Message* message)
+int MQTT::Client<Network, Timer, a, MAX_MESSAGE_HANDLERS>::deliverMessage(MQTTString& topicName, Message& message)
 {
-    int rc = -1;
+    int rc = FAILURE;
 
     // we have to find the right message handler - indexed by topic
     for (int i = 0; i < MAX_MESSAGE_HANDLERS; ++i)
     {
-        if (messageHandlers[i].topic != 0 && MQTTPacket_equals(topic, (char*)messageHandlers[i].topic))
+        if (messageHandlers[i].topicFilter != 0 && (MQTTPacket_equals(&topicName, (char*)messageHandlers[i].topicFilter) ||
+                isTopicMatched((char*)messageHandlers[i].topicFilter, topicName)))
         {
-            messageHandlers[i].fp(message);
-            rc = 0;
-            break;
+            if (messageHandlers[i].fp.attached())
+            {
+                messageHandlers[i].fp(&message);
+                rc = SUCCESS;
+            }
         }
     }
-    if (rc == -1)
-        defaultMessageHandler(message);
+    
+    if (rc == FAILURE && defaultMessageHandler.attached()) 
+    {
+        defaultMessageHandler(&message);
+        rc = SUCCESS;
+    }   
     
     return rc;
 }
@@ -335,13 +353,22 @@
 
 
 template<class Network, class Timer, int a, int b> 
-void MQTT::Client<Network, Timer, a, b>::yield(int timeout_ms)
+int MQTT::Client<Network, Timer, a, b>::yield(int timeout_ms)
 {
+    int rc = SUCCESS;
     Timer timer = Timer();
     
     timer.countdown_ms(timeout_ms);
     while (!timer.expired())
-        cycle(timer);
+    {
+        if (cycle(timer) == FAILURE)
+        {
+            rc = FAILURE;
+            break;
+        }
+    }
+        
+    return rc;
 }
 
 
@@ -353,7 +380,9 @@
     // read the socket, see what work is due
     int packet_type = readPacket(timer);
     
-    int len = 0, rc;
+    int len = 0,
+        rc = SUCCESS;
+
     switch (packet_type)
     {
         case CONNACK:
@@ -363,28 +392,34 @@
         case PUBLISH:
             MQTTString topicName;
             Message msg;
-            rc = MQTTDeserialize_publish((int*)&msg.dup, (int*)&msg.qos, (int*)&msg.retained, (int*)&msg.id, &topicName,
-                                 (char**)&msg.payload, (int*)&msg.payloadlen, readbuf, MAX_MQTT_PACKET_SIZE);;
-            rc = rc;    // make sure optimizer doesnt omit this
-            deliverMessage(&topicName, &msg);
+            if (MQTTDeserialize_publish((int*)&msg.dup, (int*)&msg.qos, (int*)&msg.retained, (int*)&msg.id, &topicName,
+                                 (char**)&msg.payload, (int*)&msg.payloadlen, readbuf, MAX_MQTT_PACKET_SIZE) != 1)
+                goto exit;
+            deliverMessage(topicName, msg);
             if (msg.qos != QOS0)
             {
                 if (msg.qos == QOS1)
                     len = MQTTSerialize_ack(buf, MAX_MQTT_PACKET_SIZE, PUBACK, 0, msg.id);
                 else if (msg.qos == QOS2)
                     len = MQTTSerialize_ack(buf, MAX_MQTT_PACKET_SIZE, PUBREC, 0, msg.id);
-                if ((rc = sendPacket(len, timer)) != SUCCESS)
+                if (len <= 0)
+                    rc = FAILURE;
+                else
+                    rc = sendPacket(len, timer);
+                if (rc == FAILURE)
                     goto exit; // there was a problem
             }
             break;
         case PUBREC:
             int type, dup, mypacketid;
-            if (MQTTDeserialize_ack(&type, &dup, &mypacketid, readbuf, MAX_MQTT_PACKET_SIZE) == 1)
-                ; 
-            len = MQTTSerialize_ack(buf, MAX_MQTT_PACKET_SIZE, PUBREL, 0, mypacketid);
-            if ((rc = sendPacket(len, timer)) != SUCCESS) // send the PUBREL packet
+            if (MQTTDeserialize_ack(&type, &dup, &mypacketid, readbuf, MAX_MQTT_PACKET_SIZE) != 1)
+                rc = FAILURE;
+            else if ((len = MQTTSerialize_ack(buf, MAX_MQTT_PACKET_SIZE, PUBREL, 0, mypacketid)) <= 0)
+                rc = FAILURE;
+            else if ((rc = sendPacket(len, timer)) != SUCCESS) // send the PUBREL packet
+                rc = FAILURE; // there was a problem
+            if (rc == FAILURE)
                 goto exit; // there was a problem
-
             break;
         case PUBCOMP:
             break;
@@ -394,30 +429,30 @@
     }
     keepalive();
 exit:
-    return packet_type;
+    if (rc == SUCCESS)
+        rc = packet_type;
+    return rc;
 }
 
 
 template<class Network, class Timer, int MAX_MQTT_PACKET_SIZE, int b>
 int MQTT::Client<Network, Timer, MAX_MQTT_PACKET_SIZE, b>::keepalive()
 {
-    int rc = 0;
+    int rc = FAILURE;
 
     if (keepAliveInterval == 0)
+    {
+        rc = SUCCESS;
         goto exit;
+    }
 
     if (ping_timer.expired())
     {
-        if (ping_outstanding)
-            rc = -1;
-        else
+        if (!ping_outstanding)
         {
             Timer timer = Timer(1000);
             int len = MQTTSerialize_pingreq(buf, MAX_MQTT_PACKET_SIZE);
-            rc = sendPacket(len, timer); // send the ping packet
-            if (rc != SUCCESS) 
-                rc = -1; // indicate there's a problem
-            else
+            if (len > 0 && (rc = sendPacket(len, timer)) == SUCCESS) // send the ping packet
                 ping_outstanding = true;
         }
     }
@@ -431,7 +466,7 @@
 template<class Network, class Timer, int a, int b> 
 int MQTT::Client<Network, Timer, a, b>::waitfor(int packet_type, Timer& timer)
 {
-    int rc = -1;
+    int rc = FAILURE;
     
     do
     {
@@ -448,6 +483,7 @@
 int MQTT::Client<Network, Timer, MAX_MQTT_PACKET_SIZE, b>::connect(MQTTPacket_connectData* options)
 {
     Timer connect_timer = Timer(command_timeout_ms);
+    int rc = FAILURE;
 
     MQTTPacket_connectData default_options = MQTTPacket_connectData_initializer;
     if (options == 0)
@@ -456,8 +492,9 @@
     this->keepAliveInterval = options->keepAliveInterval;
     ping_timer.countdown(this->keepAliveInterval);
     int len = MQTTSerialize_connect(buf, MAX_MQTT_PACKET_SIZE, options);
-    int rc = sendPacket(len, connect_timer); // send the connect packet
-    if (rc != SUCCESS) 
+    if (len <= 0)
+        goto exit;
+    if ((rc = sendPacket(len, connect_timer)) != SUCCESS)  // send the connect packet
         goto exit; // there was a problem
     
     // this will be a blocking call, wait for the connack
@@ -466,9 +503,15 @@
         int connack_rc = -1;
         if (MQTTDeserialize_connack(&connack_rc, readbuf, MAX_MQTT_PACKET_SIZE) == 1)
             rc = connack_rc;
+        else
+            rc = FAILURE;
     }
+    else
+        rc = FAILURE;
     
 exit:
+    if (rc == SUCCESS)
+        isconnected = true;
     return rc;
 }
 
@@ -476,17 +519,19 @@
 template<class Network, class Timer, int MAX_MQTT_PACKET_SIZE, int MAX_MESSAGE_HANDLERS> 
 int MQTT::Client<Network, Timer, MAX_MQTT_PACKET_SIZE, MAX_MESSAGE_HANDLERS>::subscribe(const char* topicFilter, enum QoS qos, messageHandler messageHandler)
 { 
-    int len = -1;
+    int rc = FAILURE;
     Timer timer = Timer(command_timeout_ms);
+    int len = 0;
     
     MQTTString topic = {(char*)topicFilter, 0, 0};
-    
-    int rc = MQTTSerialize_subscribe(buf, MAX_MQTT_PACKET_SIZE, 0, packetid.getNext(), 1, &topic, (int*)&qos);
-    if (rc <= 0)
+    if (!isconnected)
         goto exit;
-    len = rc;
+    
+    len = MQTTSerialize_subscribe(buf, MAX_MQTT_PACKET_SIZE, 0, packetid.getNext(), 1, &topic, (int*)&qos);
+    if (len <= 0)
+        goto exit;
     if ((rc = sendPacket(len, timer)) != SUCCESS) // send the subscribe packet
-        goto exit; // there was a problem
+        goto exit;             // there was a problem
     
     if (waitfor(SUBACK, timer) == SUBACK)      // wait for suback 
     {
@@ -497,9 +542,9 @@
         {
             for (int i = 0; i < MAX_MESSAGE_HANDLERS; ++i)
             {
-                if (messageHandlers[i].topic == 0)
+                if (messageHandlers[i].topicFilter == 0)
                 {
-                    messageHandlers[i].topic = topicFilter;
+                    messageHandlers[i].topicFilter = topicFilter;
                     messageHandlers[i].fp.attach(messageHandler);
                     rc = 0;
                     break;
@@ -507,8 +552,12 @@
             }
         }
     }
-    
+    else 
+        rc = FAILURE;
+        
 exit:
+    //if (rc == FAILURE)
+    //   closesession();
     return rc;
 }
 
@@ -516,15 +565,14 @@
 template<class Network, class Timer, int MAX_MQTT_PACKET_SIZE, int MAX_MESSAGE_HANDLERS> 
 int MQTT::Client<Network, Timer, MAX_MQTT_PACKET_SIZE, MAX_MESSAGE_HANDLERS>::unsubscribe(const char* topicFilter)
 {   
-    int len = -1;
+    int rc = FAILURE;
     Timer timer = Timer(command_timeout_ms);
     
     MQTTString topic = {(char*)topicFilter, 0, 0};
     
-    int rc = MQTTSerialize_unsubscribe(buf, MAX_MQTT_PACKET_SIZE, 0, packetid.getNext(), 1, &topic);
-    if (rc <= 0)
+    int len = MQTTSerialize_unsubscribe(buf, MAX_MQTT_PACKET_SIZE, 0, packetid.getNext(), 1, &topic);
+    if (len <= 0)
         goto exit;
-    len = rc;
     if ((rc = sendPacket(len, timer)) != SUCCESS) // send the subscribe packet
         goto exit; // there was a problem
     
@@ -534,6 +582,8 @@
         if (MQTTDeserialize_unsuback(&mypacketid, readbuf, MAX_MQTT_PACKET_SIZE) == 1)
             rc = 0; 
     }
+    else
+        rc = FAILURE;
     
 exit:
     return rc;
@@ -544,6 +594,7 @@
 template<class Network, class Timer, int MAX_MQTT_PACKET_SIZE, int b> 
 int MQTT::Client<Network, Timer, MAX_MQTT_PACKET_SIZE, b>::publish(const char* topicName, Message* message)
 {
+    int rc = FAILURE;
     Timer timer = Timer(command_timeout_ms);
     
     MQTTString topicString = {(char*)topicName, 0, 0};
@@ -553,8 +604,9 @@
     
     int len = MQTTSerialize_publish(buf, MAX_MQTT_PACKET_SIZE, 0, message->qos, message->retained, message->id, 
               topicString, (char*)message->payload, message->payloadlen);
-    int rc = sendPacket(len, timer); // send the subscribe packet
-    if (rc != SUCCESS) 
+    if (len <= 0)
+        goto exit;
+    if ((rc = sendPacket(len, timer)) != SUCCESS) // send the subscribe packet
         goto exit; // there was a problem
     
     if (message->qos == QOS1)
@@ -562,18 +614,22 @@
         if (waitfor(PUBACK, timer) == PUBACK)
         {
             int type, dup, mypacketid;
-            if (MQTTDeserialize_ack(&type, &dup, &mypacketid, readbuf, MAX_MQTT_PACKET_SIZE) == 1)
-                rc = 0; 
+            if (MQTTDeserialize_ack(&type, &dup, &mypacketid, readbuf, MAX_MQTT_PACKET_SIZE) != 1)
+                rc = FAILURE;
         }
+        else
+            rc = FAILURE;
     }
     else if (message->qos == QOS2)
     {
         if (waitfor(PUBCOMP, timer) == PUBCOMP)
         {
             int type, dup, mypacketid;
-            if (MQTTDeserialize_ack(&type, &dup, &mypacketid, readbuf, MAX_MQTT_PACKET_SIZE) == 1)
-                rc = 0; 
+            if (MQTTDeserialize_ack(&type, &dup, &mypacketid, readbuf, MAX_MQTT_PACKET_SIZE) != 1)
+                rc = FAILURE;
         }
+        else
+            rc = FAILURE;
     }
     
 exit:
@@ -584,10 +640,12 @@
 template<class Network, class Timer, int MAX_MQTT_PACKET_SIZE, int b> 
 int MQTT::Client<Network, Timer, MAX_MQTT_PACKET_SIZE, b>::disconnect()
 {  
+    int rc = FAILURE;
     Timer timer = Timer(command_timeout_ms);     // we might wait for incomplete incoming publishes to complete
     int len = MQTTSerialize_disconnect(buf, MAX_MQTT_PACKET_SIZE);
-    int rc = sendPacket(len, timer);   // send the disconnect packet
-    
+    if (len > 0)
+        rc = sendPacket(len, timer);            // send the disconnect packet
+
     return rc;
 }