Simple WebSocket server library.
Dependents: WebSocketServerTest
Revision 0:a816c25e83ed, committed 2015-03-16
- Comitter:
- flatbird
- Date:
- Mon Mar 16 10:13:30 2015 +0000
- Child:
- 1:db4114d55f83
- Commit message:
- WebSocketServer library
Changed in this revision
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/WebSocketConnection.cpp Mon Mar 16 10:13:30 2015 +0000 @@ -0,0 +1,239 @@ +#include "WebSocketConnection.h" +#include "WebSocketServer.h" +#include "sha1.h" + +#define UPGRADE_WEBSOCKET "Upgrade: websocket" +#define SEC_WEBSOCKET_KEY "Sec-WebSocket-Key:" +#define MAGIC_NUMBER "258EAFA5-E914-47DA-95CA-C5AB0DC85B11" +#define OP_CONT 0x0 +#define OP_TEXT 0x1 +#define OP_BINARY 0x2 +#define OP_CLOSE 0x8 +#define OP_PING 0x9 +#define OP_PONG 0xA + +WebSocketConnection::WebSocketConnection(WebSocketServer* server) +{ + mServer = server; +} + +WebSocketConnection::~WebSocketConnection() +{ +} + +void WebSocketConnection::run() +{ + char buf[1024]; + bool isWebSocket = false; + + // it doesn't work... + // mConnection.set_blocking(true); + + while (mConnection.is_connected()) { + int ret = mConnection.receive(buf, sizeof(buf) - 1); + if (ret == 0) { + // printf("No data to receive\r\n"); + continue; + } + if (ret < 0) { + printf("ERROR: Failed to receive %d\r\n", ret); + break; + } + if (!isWebSocket) { + if (this->handleHTTP(buf, ret)) { + isWebSocket = true; + } else { + printf("ERROR: Non websocket\r\n"); + break; + } + } else { + if (!this->handleWebSocket(buf, ret)) { + break; + } + } + } + // printf("Closed\r\n"); + mConnection.close(); +} + +bool WebSocketConnection::handleHTTP(char* buf, int size) +{ + char* line = &buf[0]; + char key[128]; + bool isUpgradeWebSocket = false; + bool isSecWebSocketKeyFound = false; + + for (int i = 0; i < size; i++) { + if (buf[i] == '\r' && i+1 < size && buf[i+1] == '\n') { + buf[i] = '\0'; + if (strlen(buf) <= 0) { + break; + } + printf("[%s]\r\n", line); + if (line == &buf[0]) { + char* method = strtok(buf, " "); + char* path = strtok(NULL, " "); + char* version = strtok(NULL, " "); + // printf("[%s] [%s] [%s]\r\n", method, path, version); + mHandler = mServer->getHandler(path); + if (!mHandler) { + printf("ERROR: Handler not found for %s\r\n", path); + return false; + } + } else if (strncmp(line, UPGRADE_WEBSOCKET, strlen(UPGRADE_WEBSOCKET)) == 0) { + isUpgradeWebSocket = true; + } else if (strncmp(line, SEC_WEBSOCKET_KEY, strlen(SEC_WEBSOCKET_KEY)) == 0) { + isSecWebSocketKeyFound = true; + char* ptr = line + strlen(SEC_WEBSOCKET_KEY); + while (*ptr == ' ') ++ptr; + strcpy(key, ptr); + } + i += 2; + line = &buf[i]; + } + } + + if (isUpgradeWebSocket && isSecWebSocketKeyFound) { + this->sendUpgradeResponse(key); + if (mHandler) { + mHandler->onOpen(); + } + mPrevFin = true; + return true; + } + + return false; +} + +bool WebSocketConnection::handleWebSocket(char* buf, int size) +{ + uint8_t* ptr = (uint8_t*)buf; + + bool fin = (*ptr & 0x80) == 0x80; + uint8_t opcode = *ptr & 0xF; + + if (opcode == OP_PING) { + *ptr = ((*ptr & 0xF0) | OP_PONG); + mConnection.send_all(buf, size); + return true; + } + if (opcode == OP_CLOSE) { + if (mHandler) { + mHandler->onClose(); + } + return false; + } + ptr++; + + if (!fin || !mPrevFin) { + printf("WARN: Data consists of multiple frame not supported\r\n"); + mPrevFin = fin; + return true; // not an error, just discard it + } + mPrevFin = fin; + + bool mask = (*ptr & 0x80) == 0x80; + uint8_t len = *ptr & 0x7F; + ptr++; + + if (len > 125) { + printf("WARN: Extended payload length not supported\r\n"); + return true; // not an error, just discard it + } + + char* data; + if (mask) { + char* maskingKey = (char*)ptr; + data = (char*)(ptr + 4); + for (int i = 0; i < len; i++) { + data[i] = data[i] ^ maskingKey[(i % 4)]; + } + } else { + data = (char*)ptr; + } + if (mHandler) { + if (opcode == OP_TEXT) { + data[len] = '\0'; + mHandler->onMessage(data); + } else if (opcode == OP_BINARY) { + mHandler->onMessage(data, len); + } + } + return true; +} + +char* base64Encode(const uint8_t* data, size_t size, + char* outputBuffer, size_t outputBufferSize) +{ + static char encodingTable[] = {'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', + 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', + 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', + 'Y', 'Z', 'a', 'b', 'c', 'd', 'e', 'f', + 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', + 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', + 'w', 'x', 'y', 'z', '0', '1', '2', '3', + '4', '5', '6', '7', '8', '9', '+', '/'}; + size_t outputLength = 4 * ((size + 2) / 3); + if (outputBufferSize - 1 < outputLength) { // -1 for NUL + return NULL; + } + + for (size_t i = 0, j = 0; i < size; /* nothing */) { + uint32_t octet1 = i < size ? (unsigned char)data[i++] : 0; + uint32_t octet2 = i < size ? (unsigned char)data[i++] : 0; + uint32_t octet3 = i < size ? (unsigned char)data[i++] : 0; + + uint32_t triple = (octet1 << 0x10) + (octet2 << 0x08) + octet3; + + outputBuffer[j++] = encodingTable[(triple >> 3 * 6) & 0x3F]; + outputBuffer[j++] = encodingTable[(triple >> 2 * 6) & 0x3F]; + outputBuffer[j++] = encodingTable[(triple >> 1 * 6) & 0x3F]; + outputBuffer[j++] = encodingTable[(triple >> 0 * 6) & 0x3F]; + } + + static int padTable[] = { 0, 2, 1 }; + int paddingCount = padTable[size % 3]; + + for (int i = 0; i < paddingCount; i++) { + outputBuffer[outputLength - 1 - i] = '='; + } + outputBuffer[outputLength] = '\0'; // NUL + + return outputBuffer; +} + +bool WebSocketConnection::sendUpgradeResponse(char* key) +{ + char buf[128]; + + if (strlen(key) + sizeof(MAGIC_NUMBER) > sizeof(buf)) { + return false; + } + strcpy(buf, key); + strcat(buf, MAGIC_NUMBER); + + uint8_t hash[20]; + SHA1Context sha; + SHA1Reset(&sha); + SHA1Input(&sha, (unsigned char*)buf, strlen(buf)); + SHA1Result(&sha, (uint8_t*)hash); + + char encoded[30]; + base64Encode(hash, 20, encoded, sizeof(encoded)); + + char resp[] = "HTTP/1.1 101 Switching Protocols\r\n" \ + "Upgrade: websocket\r\n" \ + "Connection: Upgrade\r\n" \ + "Sec-WebSocket-Accept: XXXXXXXXXXXXXXXXXXXXXXXXXXXXX\r\n\r\n"; + char* ptr = strstr(resp, "XXXXX"); + strcpy(ptr, encoded); + strcpy(ptr+strlen(encoded), "\r\n\r\n"); + + int ret = mConnection.send_all(resp, strlen(resp)); + if (ret < 0) { + printf("ERROR: Failed to send response\r\n"); + return false; + } + + return true; +}
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/WebSocketConnection.h Mon Mar 16 10:13:30 2015 +0000 @@ -0,0 +1,31 @@ +#ifndef _WEB_SOCKET_CONNECTION_H_ +#define _WEB_SOCKET_CONNECTION_H_ + +#include "TCPSocketServer.h" +#include "WebSocketHandler.h" +#include <string> +#include <map> + +class WebSocketServer; + +class WebSocketConnection +{ +public: + WebSocketConnection(WebSocketServer* server); + virtual ~WebSocketConnection(); + + void run(); + TCPSocketConnection& getTCPSocketConnection() { return mConnection; } + +private: + bool handleHTTP(char* buf, int size); + bool handleWebSocket(char* buf, int size); + bool sendUpgradeResponse(char* key); + + WebSocketServer* mServer; + TCPSocketConnection mConnection; + WebSocketHandler* mHandler; + bool mPrevFin; +}; + +#endif
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/WebSocketHandler.h Mon Mar 16 10:13:30 2015 +0000 @@ -0,0 +1,16 @@ +#ifndef _WEB_SOCKET_HANDLER_H_ +#define _WEB_SOCKET_HANDLER_H_ + +class WebSocketHandler +{ +public: + virtual void onOpen() {}; + virtual void onClose() {}; + // to receive text message + virtual void onMessage(char* text) {}; + // to receive binary message + virtual void onMessage(char* data, size_t size) {}; + virtual void onError() {}; +}; + +#endif
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/WebSocketServer.cpp Mon Mar 16 10:13:30 2015 +0000 @@ -0,0 +1,59 @@ +#include "WebSocketServer.h" +#include "WebSocketConnection.h" + +WebSocketServer::WebSocketServer() +{ +} + +WebSocketServer::~WebSocketServer() +{ +} + +bool WebSocketServer::init(int port) +{ + mTCPSocketServer.set_blocking(true); + + int ret = mTCPSocketServer.bind(port); + if (ret != 0) { + printf("ERROR: Failed to bind %d\r\n", ret); + return false; + } + ret = mTCPSocketServer.listen(); + if (ret != 0) { + printf("ERROR: Failed to listen %d\r\n", ret); + return false; + } + + return true; +} + +void WebSocketServer::run() +{ + WebSocketConnection connection(this); + + while (true) { + // printf("accepting\r\n"); + int ret = mTCPSocketServer.accept(connection.getTCPSocketConnection()); + if (ret != 0) { + continue; + } + connection.run(); + } +} + +void WebSocketServer::setHandler(const char* path, WebSocketHandler* handler) +{ + mHandlers[path] = handler; +} + +WebSocketHandler* WebSocketServer::getHandler(const char* path) +{ + WebSocketHandlerContainer::iterator it; + + it = mHandlers.find(path); + if (it != mHandlers.end()) { + return it->second; + } + return NULL; +} +
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/WebSocketServer.h Mon Mar 16 10:13:30 2015 +0000 @@ -0,0 +1,27 @@ +#ifndef _WEB_SOCKET_SERVER_H_ +#define _WEB_SOCKET_SERVER_H_ + +#include "TCPSocketServer.h" +#include "WebSocketHandler.h" +#include <string> +#include <map> + +class WebSocketServer +{ +public: + WebSocketServer(); + virtual ~WebSocketServer(); + + bool init(int port); + void run(); + void setHandler(const char* path, WebSocketHandler* handler); + WebSocketHandler* getHandler(const char* path); + +private: + typedef std::map<std::string, WebSocketHandler*> WebSocketHandlerContainer; + + TCPSocketServer mTCPSocketServer; + WebSocketHandlerContainer mHandlers; +}; + +#endif
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/sha1.cpp Mon Mar 16 10:13:30 2015 +0000 @@ -0,0 +1,386 @@ +/* + * From RFC3174 + */ +/* + * sha1.c + * + * Description: + * This file implements the Secure Hashing Algorithm 1 as + * defined in FIPS PUB 180-1 published April 17, 1995. + * + * The SHA-1, produces a 160-bit message digest for a given + * data stream. It should take about 2**n steps to find a + * message with the same digest as a given message and + * 2**(n/2) to find any two messages with the same digest, + * when n is the digest size in bits. Therefore, this + * algorithm can serve as a means of providing a + * "fingerprint" for a message. + * + * Portability Issues: + * SHA-1 is defined in terms of 32-bit "words". This code + * uses <stdint.h> (included via "sha1.h" to define 32 and 8 + * bit unsigned integer types. If your C compiler does not + * support 32 bit unsigned integers, this code is not + * appropriate. + * + * Caveats: + * SHA-1 is designed to work with messages less than 2^64 bits + * long. Although SHA-1 allows a message digest to be generated + * for messages of any number of bits less than 2^64, this + * implementation only works with messages with a length that is + * a multiple of the size of an 8-bit character. + * + */ + #include "sha1.h" + +/* + * Define the SHA1 circular left shift macro + */ +#define SHA1CircularShift(bits,word) \ + (((word) << (bits)) | ((word) >> (32-(bits)))) + +/* Local Function Prototyptes */ +void SHA1PadMessage(SHA1Context *); +void SHA1ProcessMessageBlock(SHA1Context *); + +/* + * SHA1Reset + * + * Description: + * This function will initialize the SHA1Context in preparation + * for computing a new SHA1 message digest. + * + * Parameters: + * context: [in/out] + * The context to reset. + * + * Returns: + * sha Error Code. + * + */ +int SHA1Reset(SHA1Context *context) +{ + if (!context) + { + return shaNull; + } + + context->Length_Low = 0; + context->Length_High = 0; + context->Message_Block_Index = 0; + + context->Intermediate_Hash[0] = 0x67452301; + context->Intermediate_Hash[1] = 0xEFCDAB89; + context->Intermediate_Hash[2] = 0x98BADCFE; + context->Intermediate_Hash[3] = 0x10325476; + context->Intermediate_Hash[4] = 0xC3D2E1F0; + + context->Computed = 0; + context->Corrupted = 0; + + return shaSuccess; +} + +/* + * SHA1Result + * + * Description: + * This function will return the 160-bit message digest into the + * Message_Digest array provided by the caller. + * NOTE: The first octet of hash is stored in the 0th element, + * the last octet of hash in the 19th element. + * + * Parameters: + * context: [in/out] + * The context to use to calculate the SHA-1 hash. + * Message_Digest: [out] + * Where the digest is returned. + * + * Returns: + * sha Error Code. + * + */ +int SHA1Result( SHA1Context *context, + uint8_t Message_Digest[SHA1HashSize]) +{ + int i; + + if (!context || !Message_Digest) + { + return shaNull; + } + + if (context->Corrupted) + { + return context->Corrupted; + } + + if (!context->Computed) + { + SHA1PadMessage(context); + for(i=0; i<64; ++i) + { + /* message may be sensitive, clear it out */ + context->Message_Block[i] = 0; + } + context->Length_Low = 0; /* and clear length */ + context->Length_High = 0; + context->Computed = 1; + } + + for(i = 0; i < SHA1HashSize; ++i) + { + Message_Digest[i] = context->Intermediate_Hash[i>>2] + >> 8 * ( 3 - ( i & 0x03 ) ); + } + + return shaSuccess; +} + +/* + * SHA1Input + * + * Description: + * This function accepts an array of octets as the next portion + * of the message. + * + * Parameters: + * context: [in/out] + * The SHA context to update + * message_array: [in] + * An array of characters representing the next portion of + * the message. + * length: [in] + * The length of the message in message_array + * + * Returns: + * sha Error Code. + * + */ +int SHA1Input( SHA1Context *context, + const uint8_t *message_array, + unsigned length) +{ + if (!length) + { + return shaSuccess; + } + + if (!context || !message_array) + { + return shaNull; + } + + if (context->Computed) + { + context->Corrupted = shaStateError; + return shaStateError; + } + + if (context->Corrupted) + { + return context->Corrupted; + } + while(length-- && !context->Corrupted) + { + context->Message_Block[context->Message_Block_Index++] = + (*message_array & 0xFF); + + context->Length_Low += 8; + if (context->Length_Low == 0) + { + context->Length_High++; + if (context->Length_High == 0) + { + /* Message is too long */ + context->Corrupted = 1; + } + } + + if (context->Message_Block_Index == 64) + { + SHA1ProcessMessageBlock(context); + } + + message_array++; + } + + return shaSuccess; +} + +/* + * SHA1ProcessMessageBlock + * + * Description: + * This function will process the next 512 bits of the message + * stored in the Message_Block array. + * + * Parameters: + * None. + * + * Returns: + * Nothing. + * + * Comments: + * Many of the variable names in this code, especially the + * single character names, were used because those were the + * names used in the publication. + * + * + */ +void SHA1ProcessMessageBlock(SHA1Context *context) +{ + const uint32_t K[] = { /* Constants defined in SHA-1 */ + 0x5A827999, + 0x6ED9EBA1, + 0x8F1BBCDC, + 0xCA62C1D6 + }; + int t; /* Loop counter */ + uint32_t temp; /* Temporary word value */ + uint32_t W[80]; /* Word sequence */ + uint32_t A, B, C, D, E; /* Word buffers */ + + /* + * Initialize the first 16 words in the array W + */ + for(t = 0; t < 16; t++) + { + W[t] = context->Message_Block[t * 4] << 24; + W[t] |= context->Message_Block[t * 4 + 1] << 16; + W[t] |= context->Message_Block[t * 4 + 2] << 8; + W[t] |= context->Message_Block[t * 4 + 3]; + } + + for(t = 16; t < 80; t++) + { + W[t] = SHA1CircularShift(1,W[t-3] ^ W[t-8] ^ W[t-14] ^ W[t-16]); + } + + A = context->Intermediate_Hash[0]; + B = context->Intermediate_Hash[1]; + C = context->Intermediate_Hash[2]; + D = context->Intermediate_Hash[3]; + E = context->Intermediate_Hash[4]; + + for(t = 0; t < 20; t++) + { + temp = SHA1CircularShift(5,A) + + ((B & C) | ((~B) & D)) + E + W[t] + K[0]; + E = D; + D = C; + C = SHA1CircularShift(30,B); + B = A; + A = temp; + } + + for(t = 20; t < 40; t++) + { + temp = SHA1CircularShift(5,A) + (B ^ C ^ D) + E + W[t] + K[1]; + E = D; + D = C; + C = SHA1CircularShift(30,B); + B = A; + A = temp; + } + + for(t = 40; t < 60; t++) + { + temp = SHA1CircularShift(5,A) + + ((B & C) | (B & D) | (C & D)) + E + W[t] + K[2]; + E = D; + D = C; + C = SHA1CircularShift(30,B); + B = A; + A = temp; + } + + for(t = 60; t < 80; t++) + { + temp = SHA1CircularShift(5,A) + (B ^ C ^ D) + E + W[t] + K[3]; + E = D; + D = C; + C = SHA1CircularShift(30,B); + B = A; + A = temp; + } + + context->Intermediate_Hash[0] += A; + context->Intermediate_Hash[1] += B; + context->Intermediate_Hash[2] += C; + context->Intermediate_Hash[3] += D; + context->Intermediate_Hash[4] += E; + + context->Message_Block_Index = 0; +} + + +/* + * SHA1PadMessage + * + * Description: + * According to the standard, the message must be padded to an even + * 512 bits. The first padding bit must be a '1'. The last 64 + * bits represent the length of the original message. All bits in + * between should be 0. This function will pad the message + * according to those rules by filling the Message_Block array + * accordingly. It will also call the ProcessMessageBlock function + * provided appropriately. When it returns, it can be assumed that + * the message digest has been computed. + * + * Parameters: + * context: [in/out] + * The context to pad + * ProcessMessageBlock: [in] + * The appropriate SHA*ProcessMessageBlock function + * Returns: + * Nothing. + * + */ + +void SHA1PadMessage(SHA1Context *context) +{ + /* + * Check to see if the current message block is too small to hold + * the initial padding bits and length. If so, we will pad the + * block, process it, and then continue padding into a second + * block. + */ + if (context->Message_Block_Index > 55) + { + context->Message_Block[context->Message_Block_Index++] = 0x80; + while(context->Message_Block_Index < 64) + { + context->Message_Block[context->Message_Block_Index++] = 0; + } + + SHA1ProcessMessageBlock(context); + + while(context->Message_Block_Index < 56) + { + context->Message_Block[context->Message_Block_Index++] = 0; + } + } + else + { + context->Message_Block[context->Message_Block_Index++] = 0x80; + while(context->Message_Block_Index < 56) + { + context->Message_Block[context->Message_Block_Index++] = 0; + } + } + + /* + * Store the message length as the last 8 octets + */ + context->Message_Block[56] = context->Length_High >> 24; + context->Message_Block[57] = context->Length_High >> 16; + context->Message_Block[58] = context->Length_High >> 8; + context->Message_Block[59] = context->Length_High; + context->Message_Block[60] = context->Length_Low >> 24; + context->Message_Block[61] = context->Length_Low >> 16; + context->Message_Block[62] = context->Length_Low >> 8; + context->Message_Block[63] = context->Length_Low; + + SHA1ProcessMessageBlock(context); +}
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/sha1.h Mon Mar 16 10:13:30 2015 +0000 @@ -0,0 +1,76 @@ +/* + * From RFC3174 + */ + +/* + * sha1.h + * + * Description: + * This is the header file for code which implements the Secure + * Hashing Algorithm 1 as defined in FIPS PUB 180-1 published + * April 17, 1995. + * + * Many of the variable names in this code, especially the + * single character names, were used because those were the names + * used in the publication. + * + * Please read the file sha1.c for more information. + * + */ +#ifndef _SHA1_H_ +#define _SHA1_H_ + +#include <stdint.h> +/* + * If you do not have the ISO standard stdint.h header file, then you + * must typdef the following: + * name meaning + * uint32_t unsigned 32 bit integer + * uint8_t unsigned 8 bit integer (i.e., unsigned char) + * int_least16_t integer of >= 16 bits + * + */ + +#ifndef _SHA_enum_ +#define _SHA_enum_ +enum +{ + shaSuccess = 0, + shaNull, /* Null pointer parameter */ + shaInputTooLong, /* input data too long */ + shaStateError /* called Input after Result */ +}; +#endif +#define SHA1HashSize 20 + +/* + * This structure will hold context information for the SHA-1 + * hashing operation + */ +typedef struct SHA1Context +{ + uint32_t Intermediate_Hash[SHA1HashSize/4]; /* Message Digest */ + + uint32_t Length_Low; /* Message length in bits */ + uint32_t Length_High; /* Message length in bits */ + + /* Index into message block array */ + int_least16_t Message_Block_Index; + uint8_t Message_Block[64]; /* 512-bit message blocks */ + + int Computed; /* Is the digest computed? */ + int Corrupted; /* Is the message digest corrupted? */ +} SHA1Context; + +/* + * Function Prototypes + */ + +int SHA1Reset( SHA1Context *); +int SHA1Input( SHA1Context *, + const uint8_t *, + unsigned int); +int SHA1Result( SHA1Context *, + uint8_t Message_Digest[SHA1HashSize]); + +#endif