Simple WebSocket server library.

Dependents:   WebSocketServerTest

WebSocketConnection.cpp

Committer:
flatbird
Date:
2015-03-16
Revision:
0:a816c25e83ed
Child:
1:db4114d55f83

File content as of revision 0:a816c25e83ed:

#include "WebSocketConnection.h"
#include "WebSocketServer.h"
#include "sha1.h"

// Header-line prefixes looked for in the client's HTTP upgrade request
// (compared against the start of each request line in handleHTTP()).
#define UPGRADE_WEBSOCKET	"Upgrade: websocket"
#define SEC_WEBSOCKET_KEY	"Sec-WebSocket-Key:"
// GUID appended to the client's Sec-WebSocket-Key before SHA-1 hashing to
// produce the Sec-WebSocket-Accept value (RFC 6455, section 1.3).
#define MAGIC_NUMBER		"258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
// WebSocket frame opcodes: the low nibble of the first frame header byte.
#define OP_CONT		0x0
#define OP_TEXT		0x1
#define OP_BINARY	0x2
#define OP_CLOSE	0x8
#define OP_PING		0x9
#define OP_PONG		0xA

/*
 * Creates a connection bound to `server`, which handleHTTP() later asks for
 * the handler that serves the requested path.
 *
 * Every member that is read before or during the handshake gets a defined
 * value here; previously mHandler and mPrevFin were left indeterminate.
 */
WebSocketConnection::WebSocketConnection(WebSocketServer* server)
{
	mServer = server;
	// Only assigned once handleHTTP() parses a request line; start out null
	// so the `if (mHandler)` checks never read an indeterminate pointer.
	mHandler = NULL;
	// No fragmented message is in progress yet.
	mPrevFin = true;
}

/*
 * Intentionally empty: this object releases nothing itself. The socket is
 * closed at the end of run(), and mServer is a pointer handed in by the
 * constructor. NOTE(review): confirm the server also owns the handler
 * returned by getHandler().
 */
WebSocketConnection::~WebSocketConnection()
{
	/* nothing to release */
}

/*
 * Per-connection receive loop.
 *
 * The first chunk received is treated as the HTTP opening handshake
 * (handleHTTP()). Every later chunk is treated as a WebSocket frame
 * (handleWebSocket()). The loop ends on a receive error, a rejected
 * handshake, a handler asking to stop, or the peer disconnecting. The
 * socket is always closed on the way out.
 */
void WebSocketConnection::run()
{
	char rxBuf[1024];
	bool handshakeDone = false;

	// Blocking mode would avoid polling on empty reads, but it
	// doesn't work...
	// mConnection.set_blocking(true);

	while (mConnection.is_connected()) {
		// Keep one spare byte so handlers may NUL-terminate the payload.
		const int received = mConnection.receive(rxBuf, sizeof(rxBuf) - 1);
		if (received < 0) {
			printf("ERROR: Failed to receive %d\r\n", received);
			break;
		}
		if (received == 0) {
			// printf("No data to receive\r\n");
			continue;
		}

		bool keepGoing;
		if (handshakeDone) {
			keepGoing = this->handleWebSocket(rxBuf, received);
		} else {
			handshakeDone = this->handleHTTP(rxBuf, received);
			if (!handshakeDone) {
				printf("ERROR: Non websocket\r\n");
			}
			keepGoing = handshakeDone;
		}
		if (!keepGoing) {
			break;
		}
	}
	// printf("Closed\r\n");
	mConnection.close();
}

/*
 * Case-insensitive ASCII prefix test: returns true when `str` starts with
 * `prefix`, ignoring letter case. HTTP header field names (and the
 * "websocket" upgrade token) are case-insensitive, so an exact-case
 * strncmp() would reject valid handshakes such as "upgrade: WebSocket".
 */
static bool startsWithIgnoreCase(const char* str, const char* prefix)
{
	for (; *prefix != '\0'; ++str, ++prefix) {
		char a = *str;
		char b = *prefix;
		if (a >= 'A' && a <= 'Z') {
			a = (char)(a - 'A' + 'a');
		}
		if (b >= 'A' && b <= 'Z') {
			b = (char)(b - 'A' + 'a');
		}
		if (a != b) {
			return false; // also covers `str` ending early ('\0' != b)
		}
	}
	return true;
}

/*
 * Parses the client's opening handshake held in buf[0..size).
 *
 * Each CRLF-terminated line is NUL-terminated in place, so buf is modified.
 * The request line ("<method> <path> <version>") selects the handler via
 * mServer->getHandler(path). The "Upgrade: websocket" and
 * "Sec-WebSocket-Key:" headers must both be present. On success the
 * 101 response is sent, the handler's onOpen() is called, and fragment
 * tracking is reset.
 *
 * Returns true when the connection was upgraded. Returns false for a
 * malformed request line, an unknown path, an oversized key, missing
 * headers, or a failure to send the response.
 */
bool WebSocketConnection::handleHTTP(char* buf, int size)
{
	char* line = &buf[0];
	char key[128];
	bool isUpgradeWebSocket = false;
	bool isSecWebSocketKeyFound = false;

	for (int i = 0; i < size; i++) {
		if (buf[i] == '\r' && i+1 < size && buf[i+1] == '\n') {
			buf[i] = '\0';
			if (*line == '\0') {
				// Empty line: end of the header section.
				break;
			}
			printf("[%s]\r\n", line);
			if (line == &buf[0]) {
				strtok(buf, " "); // method (not checked)
				char* path = strtok(NULL, " ");
				if (path == NULL) {
					printf("ERROR: Malformed request line\r\n");
					return false;
				}
				mHandler = mServer->getHandler(path);
				if (!mHandler) {
					printf("ERROR: Handler not found for %s\r\n", path);
					return false;
				}
			} else if (startsWithIgnoreCase(line, UPGRADE_WEBSOCKET)) {
				isUpgradeWebSocket = true;
			} else if (startsWithIgnoreCase(line, SEC_WEBSOCKET_KEY)) {
				char* ptr = line + strlen(SEC_WEBSOCKET_KEY);
				while (*ptr == ' ') ++ptr;
				// The key comes straight from the network: refuse anything
				// that does not fit in key[] instead of overflowing it.
				if (strlen(ptr) >= sizeof(key)) {
					printf("ERROR: Sec-WebSocket-Key too long\r\n");
					return false;
				}
				strcpy(key, ptr);
				isSecWebSocketKeyFound = true;
			}
			// The next line starts right after "\r\n". Step over the '\n'
			// only. The loop's i++ then lands on the next line's first
			// character, so an immediately following empty line is seen.
			line = &buf[i + 2];
			i++;
		}
	}

	if (isUpgradeWebSocket && isSecWebSocketKeyFound) {
		if (!this->sendUpgradeResponse(key)) {
			return false;
		}
		if (mHandler) {
			mHandler->onOpen();
		}
		mPrevFin = true;
		return true;
	}

	return false;
}

/*
 * Handles one WebSocket frame received from the client.
 *
 * buf[0..size) is expected to hold a single frame starting at its first
 * byte. Masked payloads are unmasked in place. For text frames a NUL is
 * written just past the payload, so buf needs one spare byte after the frame
 * (run() reserves it by receiving at most sizeof(buf) - 1 bytes).
 *
 * Supported:
 *   - single-frame text/binary messages with payloads of at most 125 bytes;
 *   - PING, answered with an unmasked PONG carrying the same payload;
 *   - CLOSE, answered with an unmasked CLOSE (echoing the status code),
 *     followed by onClose().
 * Fragmented messages, extended payload lengths and truncated frames are
 * logged and discarded.
 *
 * Returns false when the connection should be closed (CLOSE received),
 * true otherwise.
 */
bool WebSocketConnection::handleWebSocket(char* buf, int size)
{
	if (size < 2) {
		printf("WARN: Truncated frame header discarded\r\n");
		return true; // not an error, just discard it
	}
	uint8_t* ptr = (uint8_t*)buf;

	bool fin = (ptr[0] & 0x80) == 0x80;
	uint8_t opcode = ptr[0] & 0xF;
	bool mask = (ptr[1] & 0x80) == 0x80;
	uint8_t len = ptr[1] & 0x7F;

	// Payload follows the 2-byte header and, if masked, the 4-byte key.
	int headerSize = mask ? 6 : 2;
	char* data = buf + headerSize;
	bool payloadComplete = (len <= 125) && (size >= headerSize + len);
	if (payloadComplete && mask) {
		const uint8_t* maskingKey = ptr + 2;
		for (int i = 0; i < len; i++) {
			data[i] = (char)(data[i] ^ maskingKey[i % 4]);
		}
	}

	if (opcode == OP_PING) {
		if (!payloadComplete) {
			printf("WARN: Invalid ping frame discarded\r\n");
			return true; // not an error, just discard it
		}
		// A server must not mask frames it sends, so the client's masked
		// frame cannot be echoed as-is. Write an unmasked PONG header right
		// in front of the (now unmasked) payload; headerSize >= 2 leaves room.
		data[-2] = (char)(0x80 | OP_PONG);
		data[-1] = (char)len;
		mConnection.send_all(data - 2, len + 2);
		return true;
	}
	if (opcode == OP_CLOSE) {
		// Answer with an unmasked CLOSE, echoing the status code when the
		// client sent one, before the connection is torn down.
		char reply[4];
		int replySize = 2;
		reply[0] = (char)(0x80 | OP_CLOSE);
		reply[1] = 0;
		if (payloadComplete && len >= 2) {
			reply[1] = 2;
			reply[2] = data[0];
			reply[3] = data[1];
			replySize = 4;
		}
		mConnection.send_all(reply, replySize);
		if (mHandler) {
			mHandler->onClose();
		}
		return false;
	}

	if (!fin || !mPrevFin) {
		printf("WARN: Data consists of multiple frame not supported\r\n");
		mPrevFin = fin;
		return true; // not an error, just discard it
	}
	mPrevFin = fin;

	if (len > 125) {
		printf("WARN: Extended payload length not supported\r\n");
		return true; // not an error, just discard it
	}
	if (!payloadComplete) {
		printf("WARN: Incomplete frame discarded\r\n");
		return true; // not an error, just discard it
	}

	if (mHandler) {
		if (opcode == OP_TEXT) {
			data[len] = '\0';
			mHandler->onMessage(data);
		} else if (opcode == OP_BINARY) {
			mHandler->onMessage(data, len);
		}
	}
	return true;
}

/*
 * Base64-encodes `size` bytes from `data` into `outputBuffer` as a
 * NUL-terminated string. It uses the standard alphabet with '=' padding
 * (RFC 4648, section 4).
 *
 * outputBufferSize must be at least 4 * ceil(size / 3) + 1 bytes.
 * Returns outputBuffer on success, or NULL without writing anything when
 * the buffer is too small (including a zero-sized buffer).
 */
char* base64Encode(const uint8_t* data, size_t size,
                   char* outputBuffer, size_t outputBufferSize)
{
	static const char encodingTable[] =
		"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
	size_t outputLength = 4 * ((size + 2) / 3);
	// Compare without subtracting: `outputBufferSize - 1` wraps around for a
	// zero-sized buffer and would let the writes below overflow it.
	if (outputBufferSize <= outputLength) { // need one extra byte for NUL
		return NULL;
	}

	size_t j = 0;
	for (size_t i = 0; i < size; i += 3) {
		// Missing trailing bytes are treated as zero; their output
		// characters are overwritten with '=' below.
		uint32_t octet1 = data[i];
		uint32_t octet2 = i + 1 < size ? data[i + 1] : 0;
		uint32_t octet3 = i + 2 < size ? data[i + 2] : 0;
		uint32_t triple = (octet1 << 16) | (octet2 << 8) | octet3;

		outputBuffer[j++] = encodingTable[(triple >> 18) & 0x3F];
		outputBuffer[j++] = encodingTable[(triple >> 12) & 0x3F];
		outputBuffer[j++] = encodingTable[(triple >> 6) & 0x3F];
		outputBuffer[j++] = encodingTable[triple & 0x3F];
	}

	// 1 leftover input byte -> 2 '=', 2 leftover bytes -> 1 '='.
	static const int padTable[] = { 0, 2, 1 };
	int paddingCount = padTable[size % 3];
	for (int i = 0; i < paddingCount; i++) {
		outputBuffer[outputLength - 1 - i] = '=';
	}
	outputBuffer[outputLength] = '\0'; // NUL

	return outputBuffer;
}

bool WebSocketConnection::sendUpgradeResponse(char* key)
{
	char buf[128];

	if (strlen(key) + sizeof(MAGIC_NUMBER) > sizeof(buf)) {
		return false;
	}
	strcpy(buf, key);
	strcat(buf, MAGIC_NUMBER);

    uint8_t hash[20];
	SHA1Context sha;
    SHA1Reset(&sha);
    SHA1Input(&sha, (unsigned char*)buf, strlen(buf));
    SHA1Result(&sha, (uint8_t*)hash);

	char encoded[30];
    base64Encode(hash, 20, encoded, sizeof(encoded));

    char resp[] = "HTTP/1.1 101 Switching Protocols\r\n" \
	    "Upgrade: websocket\r\n" \
    	"Connection: Upgrade\r\n" \
    	"Sec-WebSocket-Accept: XXXXXXXXXXXXXXXXXXXXXXXXXXXXX\r\n\r\n";
    char* ptr = strstr(resp, "XXXXX");
    strcpy(ptr, encoded);
    strcpy(ptr+strlen(encoded), "\r\n\r\n");

    int ret = mConnection.send_all(resp, strlen(resp));
    if (ret < 0) {
    	printf("ERROR: Failed to send response\r\n");
    	return false;
    }

    return true;
}