Commit 88e2c794 authored by Alex Hultman's avatar Alex Hultman

Merge WebSocketImpl with WebSocket, inline more of parser

parent fe9a7be3
......@@ -16,7 +16,6 @@ struct WIN32_EXPORT Group : uS::NodeData {
protected:
friend struct Hub;
friend struct WebSocket<isServer>;
friend class WebSocketProtocol<isServer>;
friend struct HttpSocket<false>;
friend struct HttpSocket<true>;
......
......@@ -77,7 +77,7 @@ uS::Socket *HttpSocket<isServer>::onData(uS::Socket *s, char *data, size_t lengt
return httpSocket;
}
httpSocket->httpBuffer.reserve(httpSocket->httpBuffer.length() + length + WebSocketProtocol<uWS::CLIENT>::CONSUME_POST_PADDING);
httpSocket->httpBuffer.reserve(httpSocket->httpBuffer.length() + length + WebSocketProtocol<uWS::CLIENT, WebSocket<uWS::CLIENT>>::CONSUME_POST_PADDING);
httpSocket->httpBuffer.append(data, length);
data = (char *) httpSocket->httpBuffer.data();
length = httpSocket->httpBuffer.length();
......@@ -169,7 +169,7 @@ uS::Socket *HttpSocket<isServer>::onData(uS::Socket *s, char *data, size_t lengt
webSocket->cork(true);
getGroup<isServer>(webSocket)->connectionHandler(webSocket, req);
if (!(webSocket->isClosed() || webSocket->isShuttingDown())) {
webSocket->consume(cursor, end - cursor, webSocket);
WebSocketProtocol<isServer, WebSocket<isServer>>::consume(cursor, end - cursor, webSocket);
}
webSocket->cork(false);
delete httpSocket;
......
......@@ -43,7 +43,7 @@ public:
void connect(std::string uri, void *user = nullptr, std::map<std::string, std::string> extraHeaders = {}, int timeoutMs = 5000, Group<CLIENT> *eh = nullptr);
void upgrade(uv_os_sock_t fd, const char *secKey, SSL *ssl, const char *extensions, size_t extensionsLength, const char *subprotocol, size_t subprotocolLength, Group<SERVER> *serverGroup = nullptr);
Hub(int extensionOptions = 0, bool useDefaultLoop = false) : uS::Node(LARGE_BUFFER_SIZE, WebSocketProtocol<SERVER>::CONSUME_PRE_PADDING, WebSocketProtocol<SERVER>::CONSUME_POST_PADDING, useDefaultLoop),
Hub(int extensionOptions = 0, bool useDefaultLoop = false) : uS::Node(LARGE_BUFFER_SIZE, WebSocketProtocol<SERVER, WebSocket<SERVER>>::CONSUME_PRE_PADDING, WebSocketProtocol<SERVER, WebSocket<SERVER>>::CONSUME_POST_PADDING, useDefaultLoop),
Group<SERVER>(extensionOptions, this, nodeData), Group<CLIENT>(0, this, nodeData) {
inflateInit2(&inflationStream, -15);
inflationBuffer = new char[LARGE_BUFFER_SIZE];
......@@ -88,8 +88,8 @@ public:
using Group<SERVER>::onHttpUpgrade;
using Group<SERVER>::onCancelledHttpRequest;
friend class WebSocketProtocol<true>;
friend class WebSocketProtocol<false>;
friend class WebSocket<true>;
friend class WebSocket<false>;
};
}
......
#include "WebSocket.h"
#include "Group.h"
#include "Hub.h"
namespace uWS {
......@@ -16,7 +17,7 @@ void WebSocket<isServer>::send(const char *message, size_t length, OpCode opCode
}
#endif
const int HEADER_LENGTH = WebSocketProtocol<!isServer>::LONG_MESSAGE_HEADER;
const int HEADER_LENGTH = WebSocketProtocol<!isServer, WebSocket<!isServer>>::LONG_MESSAGE_HEADER;
struct TransformData {
OpCode opCode;
......@@ -28,7 +29,7 @@ void WebSocket<isServer>::send(const char *message, size_t length, OpCode opCode
}
static size_t transform(const char *src, char *dst, size_t length, TransformData transformData) {
return WebSocketProtocol<isServer>::formatMessage(dst, src, length, transformData.opCode, length, false);
return WebSocketProtocol<isServer, WebSocket<isServer>>::formatMessage(dst, src, length, transformData.opCode, length, false);
}
};
......@@ -39,7 +40,7 @@ template <bool isServer>
typename WebSocket<isServer>::PreparedMessage *WebSocket<isServer>::prepareMessage(char *data, size_t length, OpCode opCode, bool compressed, void(*callback)(WebSocket<isServer> *webSocket, void *data, bool cancelled, void *reserved)) {
PreparedMessage *preparedMessage = new PreparedMessage;
preparedMessage->buffer = new char[length + 10];
preparedMessage->length = WebSocketProtocol<isServer>::formatMessage(preparedMessage->buffer, data, length, opCode, length, compressed);
preparedMessage->length = WebSocketProtocol<isServer, WebSocket<isServer>>::formatMessage(preparedMessage->buffer, data, length, opCode, length, compressed);
preparedMessage->references = 1;
preparedMessage->callback = (void(*)(void *, void *, bool, void *)) callback;
return preparedMessage;
......@@ -59,7 +60,7 @@ typename WebSocket<isServer>::PreparedMessage *WebSocket<isServer>::prepareMessa
int offset = 0;
for (size_t i = 0; i < messages.size(); i++) {
offset += WebSocketProtocol<isServer>::formatMessage(preparedMessage->buffer + offset, messages[i].data(), messages[i].length(), opCode, messages[i].length(), compressed);
offset += WebSocketProtocol<isServer, WebSocket<isServer>>::formatMessage(preparedMessage->buffer + offset, messages[i].data(), messages[i].length(), opCode, messages[i].length(), compressed);
}
preparedMessage->length = offset;
preparedMessage->references = 1;
......@@ -128,7 +129,7 @@ uS::Socket *WebSocket<isServer>::onData(uS::Socket *s, char *data, size_t length
webSocket->hasOutstandingPong = false;
if (!webSocket->isShuttingDown()) {
webSocket->cork(true);
webSocket->consume(data, length, s);
WebSocketProtocol<isServer, WebSocket<isServer>>::consume(data, length, webSocket);
if (!webSocket->isClosed()) {
webSocket->cork(false);
}
......@@ -166,7 +167,7 @@ void WebSocket<isServer>::close(int code, const char *message, size_t length) {
startTimeout<WebSocket<isServer>::onEnd>();
char closePayload[MAX_CLOSE_PAYLOAD + 2];
int closePayloadLength = WebSocketProtocol<isServer>::formatClosePayload(closePayload, code, message, length);
int closePayloadLength = WebSocketProtocol<isServer, WebSocket<isServer>>::formatClosePayload(closePayload, code, message, length);
send(closePayload, closePayloadLength, OpCode::CLOSE, [](WebSocket<isServer> *p, void *data, bool cancelled, void *reserved) {
if (!cancelled) {
p->shutdown();
......@@ -196,6 +197,116 @@ void WebSocket<isServer>::onEnd(uS::Socket *s) {
}
}
template <bool isServer>
bool WebSocket<isServer>::handleFragment(char *data, size_t length, unsigned int remainingBytes, int opCode, bool fin, void *user) {
WebSocket<isServer> *webSocket = (WebSocket<isServer> *) user;
if (opCode < 3) {
if (!remainingBytes && fin && !webSocket->fragmentBuffer.length()) {
if (webSocket->compressionStatus == WebSocket<isServer>::CompressionStatus::COMPRESSED_FRAME) {
webSocket->compressionStatus = WebSocket<isServer>::CompressionStatus::ENABLED;
Hub *hub = ((Group<isServer> *) webSocket->nodeData)->hub;
data = hub->inflate(data, length);
if (!data) {
forceClose(user);
return true;
}
}
if (opCode == 1 && !WebSocketProtocol<isServer, WebSocket<isServer>>::isValidUtf8((unsigned char *) data, length)) {
forceClose(user);
return true;
}
((Group<isServer> *) webSocket->nodeData)->messageHandler(webSocket, data, length, (OpCode) opCode);
if (webSocket->isClosed() || webSocket->isShuttingDown()) {
return true;
}
} else {
webSocket->fragmentBuffer.append(data, length);
if (!remainingBytes && fin) {
length = webSocket->fragmentBuffer.length();
if (webSocket->compressionStatus == WebSocket<isServer>::CompressionStatus::COMPRESSED_FRAME) {
webSocket->compressionStatus = WebSocket<isServer>::CompressionStatus::ENABLED;
Hub *hub = ((Group<isServer> *) webSocket->nodeData)->hub;
webSocket->fragmentBuffer.append("....");
data = hub->inflate((char *) webSocket->fragmentBuffer.data(), length);
if (!data) {
forceClose(user);
return true;
}
} else {
data = (char *) webSocket->fragmentBuffer.data();
}
if (opCode == 1 && !WebSocketProtocol<isServer, WebSocket<isServer>>::isValidUtf8((unsigned char *) data, length)) {
forceClose(user);
return true;
}
((Group<isServer> *) webSocket->nodeData)->messageHandler(webSocket, data, length, (OpCode) opCode);
if (webSocket->isClosed() || webSocket->isShuttingDown()) {
return true;
}
webSocket->fragmentBuffer.clear();
}
}
} else {
if (!remainingBytes && fin && !webSocket->controlTipLength) {
if (opCode == CLOSE) {
typename WebSocketProtocol<isServer, WebSocket<isServer>>::CloseFrame closeFrame = WebSocketProtocol<isServer, WebSocket<isServer>>::parseClosePayload(data, length);
webSocket->close(closeFrame.code, closeFrame.message, closeFrame.length);
return true;
} else {
if (opCode == PING) {
webSocket->send(data, length, (OpCode) OpCode::PONG);
((Group<isServer> *) webSocket->nodeData)->pingHandler(webSocket, data, length);
if (webSocket->isClosed() || webSocket->isShuttingDown()) {
return true;
}
} else if (opCode == PONG) {
((Group<isServer> *) webSocket->nodeData)->pongHandler(webSocket, data, length);
if (webSocket->isClosed() || webSocket->isShuttingDown()) {
return true;
}
}
}
} else {
webSocket->fragmentBuffer.append(data, length);
webSocket->controlTipLength += length;
if (!remainingBytes && fin) {
char *controlBuffer = (char *) webSocket->fragmentBuffer.data() + webSocket->fragmentBuffer.length() - webSocket->controlTipLength;
if (opCode == CLOSE) {
typename WebSocketProtocol<isServer, WebSocket<isServer>>::CloseFrame closeFrame = WebSocketProtocol<isServer, WebSocket<isServer>>::parseClosePayload(controlBuffer, webSocket->controlTipLength);
webSocket->close(closeFrame.code, closeFrame.message, closeFrame.length);
return true;
} else {
if (opCode == PING) {
webSocket->send(controlBuffer, webSocket->controlTipLength, (OpCode) OpCode::PONG);
((Group<isServer> *) webSocket->nodeData)->pingHandler(webSocket, controlBuffer, webSocket->controlTipLength);
if (webSocket->isClosed() || webSocket->isShuttingDown()) {
return true;
}
} else if (opCode == PONG) {
((Group<isServer> *) webSocket->nodeData)->pongHandler(webSocket, controlBuffer, webSocket->controlTipLength);
if (webSocket->isClosed() || webSocket->isShuttingDown()) {
return true;
}
}
}
webSocket->fragmentBuffer.resize(webSocket->fragmentBuffer.length() - webSocket->controlTipLength);
webSocket->controlTipLength = 0;
}
}
}
return false;
}
template struct WebSocket<SERVER>;
template struct WebSocket<CLIENT>;
......
......@@ -13,7 +13,7 @@ template <bool isServer>
struct HttpSocket;
template <const bool isServer>
struct WIN32_EXPORT WebSocket : uS::Socket, protected WebSocketProtocol<isServer> {
struct WIN32_EXPORT WebSocket : uS::Socket, WebSocketState<isServer> {
protected:
std::string fragmentBuffer;
enum CompressionStatus : char {
......@@ -31,6 +31,28 @@ protected:
static void onEnd(uS::Socket *s);
using uS::Socket::closeSocket;
static bool refusePayloadLength(void *user, uint64_t length) {
return length > 16777216;
}
static bool setCompressed(void *user) {
WebSocket<isServer> *webSocket = (WebSocket<isServer> *) user;
if (webSocket->compressionStatus == WebSocket<isServer>::CompressionStatus::ENABLED) {
webSocket->compressionStatus = WebSocket<isServer>::CompressionStatus::COMPRESSED_FRAME;
return true;
} else {
return false;
}
}
static void forceClose(void *user) {
WebSocket<isServer> *webSocket = (WebSocket<isServer> *) user;
webSocket->terminate();
}
static bool handleFragment(char *data, size_t length, unsigned int remainingBytes, int opCode, bool fin, void *user);
public:
struct PreparedMessage {
char *buffer;
......@@ -65,7 +87,7 @@ public:
friend struct Group<isServer>;
friend struct HttpSocket<isServer>;
friend struct uS::Socket;
friend class WebSocketProtocol<isServer>;
friend class WebSocketProtocol<isServer, WebSocket<isServer>>;
};
}
......
#include "Hub.h"
namespace uWS {
template <const bool isServer>
bool WebSocketProtocol<isServer>::setCompressed(void *user) {
WebSocket<isServer> *webSocket = (WebSocket<isServer> *) user;
if (webSocket->compressionStatus == WebSocket<isServer>::CompressionStatus::ENABLED) {
webSocket->compressionStatus = WebSocket<isServer>::CompressionStatus::COMPRESSED_FRAME;
return true;
} else {
return false;
}
}
template <const bool isServer>
bool WebSocketProtocol<isServer>::refusePayloadLength(void *user, uint64_t length) {
return length > 16777216;
}
template <const bool isServer>
void WebSocketProtocol<isServer>::forceClose(void *user) {
WebSocket<isServer> *webSocket = (WebSocket<isServer> *) user;
webSocket->terminate();
}
template <const bool isServer>
bool WebSocketProtocol<isServer>::handleFragment(char *data, size_t length, unsigned int remainingBytes, int opCode, bool fin, void *user) {
WebSocket<isServer> *webSocket = (WebSocket<isServer> *) user;
if (opCode < 3) {
if (!remainingBytes && fin && !webSocket->fragmentBuffer.length()) {
if (webSocket->compressionStatus == WebSocket<isServer>::CompressionStatus::COMPRESSED_FRAME) {
webSocket->compressionStatus = WebSocket<isServer>::CompressionStatus::ENABLED;
Hub *hub = ((Group<isServer> *) webSocket->nodeData)->hub;
data = hub->inflate(data, length);
if (!data) {
forceClose(user);
return true;
}
}
if (opCode == 1 && !isValidUtf8((unsigned char *) data, length)) {
forceClose(user);
return true;
}
((Group<isServer> *) webSocket->nodeData)->messageHandler(webSocket, data, length, (OpCode) opCode);
if (webSocket->isClosed() || webSocket->isShuttingDown()) {
return true;
}
} else {
webSocket->fragmentBuffer.append(data, length);
if (!remainingBytes && fin) {
length = webSocket->fragmentBuffer.length();
if (webSocket->compressionStatus == WebSocket<isServer>::CompressionStatus::COMPRESSED_FRAME) {
webSocket->compressionStatus = WebSocket<isServer>::CompressionStatus::ENABLED;
Hub *hub = ((Group<isServer> *) webSocket->nodeData)->hub;
webSocket->fragmentBuffer.append("....");
data = hub->inflate((char *) webSocket->fragmentBuffer.data(), length);
if (!data) {
forceClose(user);
return true;
}
} else {
data = (char *) webSocket->fragmentBuffer.data();
}
if (opCode == 1 && !isValidUtf8((unsigned char *) data, length)) {
forceClose(user);
return true;
}
((Group<isServer> *) webSocket->nodeData)->messageHandler(webSocket, data, length, (OpCode) opCode);
if (webSocket->isClosed() || webSocket->isShuttingDown()) {
return true;
}
webSocket->fragmentBuffer.clear();
}
}
} else {
if (!remainingBytes && fin && !webSocket->controlTipLength) {
if (opCode == CLOSE) {
CloseFrame closeFrame = parseClosePayload(data, length);
webSocket->close(closeFrame.code, closeFrame.message, closeFrame.length);
return true;
} else {
if (opCode == PING) {
webSocket->send(data, length, (OpCode) OpCode::PONG);
((Group<isServer> *) webSocket->nodeData)->pingHandler(webSocket, data, length);
if (webSocket->isClosed() || webSocket->isShuttingDown()) {
return true;
}
} else if (opCode == PONG) {
((Group<isServer> *) webSocket->nodeData)->pongHandler(webSocket, data, length);
if (webSocket->isClosed() || webSocket->isShuttingDown()) {
return true;
}
}
}
} else {
webSocket->fragmentBuffer.append(data, length);
webSocket->controlTipLength += length;
if (!remainingBytes && fin) {
char *controlBuffer = (char *) webSocket->fragmentBuffer.data() + webSocket->fragmentBuffer.length() - webSocket->controlTipLength;
if (opCode == CLOSE) {
CloseFrame closeFrame = parseClosePayload(controlBuffer, webSocket->controlTipLength);
webSocket->close(closeFrame.code, closeFrame.message, closeFrame.length);
return true;
} else {
if (opCode == PING) {
webSocket->send(controlBuffer, webSocket->controlTipLength, (OpCode) OpCode::PONG);
((Group<isServer> *) webSocket->nodeData)->pingHandler(webSocket, controlBuffer, webSocket->controlTipLength);
if (webSocket->isClosed() || webSocket->isShuttingDown()) {
return true;
}
} else if (opCode == PONG) {
((Group<isServer> *) webSocket->nodeData)->pongHandler(webSocket, controlBuffer, webSocket->controlTipLength);
if (webSocket->isClosed() || webSocket->isShuttingDown()) {
return true;
}
}
}
webSocket->fragmentBuffer.resize(webSocket->fragmentBuffer.length() - webSocket->controlTipLength);
webSocket->controlTipLength = 0;
}
}
}
return false;
}
template class WebSocketProtocol<SERVER>;
template class WebSocketProtocol<CLIENT>;
}
This diff is collapsed.
......@@ -1032,18 +1032,12 @@ void testReceivePerformance() {
exit(-1);
});
struct TestWebSocket : uWS::WebSocket<uWS::SERVER> {
void onData(char *data, size_t length) {
consume(data, length, this);
}
};
h.onConnection([originalBuffer, buffer, bufferLength, messages, &h](uWS::WebSocket<uWS::SERVER> *ws, uWS::HttpRequest req) {
for (int i = 0; i < 100; i++) {
memcpy(buffer, originalBuffer, bufferLength);
auto now = std::chrono::high_resolution_clock::now();
((TestWebSocket *) ws)->onData(buffer, bufferLength);
uWS::WebSocketProtocol<uWS::SERVER, uWS::WebSocket<uWS::SERVER>>::consume(buffer, bufferLength, ws);
int us = std::chrono::duration_cast<std::chrono::microseconds>(std::chrono::high_resolution_clock::now() - now).count();
std::cout << "Messages per microsecond: " << (double(messages) / double(us)) << std::endl;
......
......@@ -4,7 +4,6 @@ CONFIG -= app_bundle
CONFIG -= qt
SOURCES += main.cpp \
../src/WebSocketImpl.cpp \
../src/Networking.cpp \
../src/Hub.cpp \
../src/Node.cpp \
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment