Commit 0b95f6b0 authored by Alex Hultman's avatar Alex Hultman

Initial send compression support

parent d487dbcf
......@@ -235,7 +235,7 @@ void HttpSocket<isServer>::upgrade(const char *secKey, const char *extensions, s
upgradeResponseLength += 26 + extensionsResponse.length() + 2;
}
// select first protocol
for (int i = 0; i < subprotocolLength; i++) {
for (unsigned int i = 0; i < subprotocolLength; i++) {
if (subprotocol[i] == ',') {
subprotocolLength = i;
break;
......
......@@ -5,40 +5,78 @@
namespace uWS {
char *Hub::deflate(char *data, size_t &length) {
dynamicZlibBuffer.clear();
deflationStream.next_in = (Bytef *) data;
deflationStream.avail_in = (unsigned int) length;
// note: zlib requires more than 6 bytes with Z_SYNC_FLUSH
const int DEFLATE_OUTPUT_CHUNK = LARGE_BUFFER_SIZE;
int err;
do {
deflationStream.next_out = (Bytef *) zlibBuffer;
deflationStream.avail_out = DEFLATE_OUTPUT_CHUNK;
err = ::deflate(&deflationStream, Z_SYNC_FLUSH);
if (Z_OK == err && deflationStream.avail_out == 0) {
dynamicZlibBuffer.append(zlibBuffer, DEFLATE_OUTPUT_CHUNK - deflationStream.avail_out);
continue;
} else {
break;
}
} while (true);
// note: should not change avail_out
deflateReset(&deflationStream);
if (dynamicZlibBuffer.length()) {
dynamicZlibBuffer.append(zlibBuffer, DEFLATE_OUTPUT_CHUNK - deflationStream.avail_out);
length = dynamicZlibBuffer.length();
return (char *) dynamicZlibBuffer.data();
}
length = DEFLATE_OUTPUT_CHUNK - deflationStream.avail_out;
return zlibBuffer;
}
// todo: let's go through this code once more some time!
char *Hub::inflate(char *data, size_t &length, size_t maxPayload) {
dynamicInflationBuffer.clear();
dynamicZlibBuffer.clear();
inflationStream.next_in = (Bytef *) data;
inflationStream.avail_in = (unsigned int) length;
int err;
do {
inflationStream.next_out = (Bytef *) inflationBuffer;
inflationStream.next_out = (Bytef *) zlibBuffer;
inflationStream.avail_out = LARGE_BUFFER_SIZE;
err = ::inflate(&inflationStream, Z_FINISH);
if (!inflationStream.avail_in) {
break;
}
dynamicInflationBuffer.append(inflationBuffer, LARGE_BUFFER_SIZE - inflationStream.avail_out);
} while (err == Z_BUF_ERROR && dynamicInflationBuffer.length() <= maxPayload);
dynamicZlibBuffer.append(zlibBuffer, LARGE_BUFFER_SIZE - inflationStream.avail_out);
} while (err == Z_BUF_ERROR && dynamicZlibBuffer.length() <= maxPayload);
inflateReset(&inflationStream);
if ((err != Z_BUF_ERROR && err != Z_OK) || dynamicInflationBuffer.length() > maxPayload) {
if ((err != Z_BUF_ERROR && err != Z_OK) || dynamicZlibBuffer.length() > maxPayload) {
length = 0;
return nullptr;
}
if (dynamicInflationBuffer.length()) {
dynamicInflationBuffer.append(inflationBuffer, LARGE_BUFFER_SIZE - inflationStream.avail_out);
if (dynamicZlibBuffer.length()) {
dynamicZlibBuffer.append(zlibBuffer, LARGE_BUFFER_SIZE - inflationStream.avail_out);
length = dynamicInflationBuffer.length();
return (char *) dynamicInflationBuffer.data();
length = dynamicZlibBuffer.length();
return (char *) dynamicZlibBuffer.data();
}
length = LARGE_BUFFER_SIZE - inflationStream.avail_out;
return inflationBuffer;
return zlibBuffer;
}
void Hub::onServerAccept(uS::Socket *s) {
......
......@@ -18,10 +18,11 @@ protected:
Group<CLIENT> *group;
};
z_stream inflationStream = {};
char *inflationBuffer;
z_stream inflationStream = {}, deflationStream = {};
char *deflate(char *data, size_t &length);
char *inflate(char *data, size_t &length, size_t maxPayload);
std::string dynamicInflationBuffer;
char *zlibBuffer;
std::string dynamicZlibBuffer;
static const int LARGE_BUFFER_SIZE = 300 * 1024;
static void onServerAccept(uS::Socket *s);
......@@ -46,7 +47,9 @@ public:
Hub(int extensionOptions = 0, bool useDefaultLoop = false, unsigned int maxPayload = 16777216) : uS::Node(LARGE_BUFFER_SIZE, WebSocketProtocol<SERVER, WebSocket<SERVER>>::CONSUME_PRE_PADDING, WebSocketProtocol<SERVER, WebSocket<SERVER>>::CONSUME_POST_PADDING, useDefaultLoop),
Group<SERVER>(extensionOptions, maxPayload, this, nodeData), Group<CLIENT>(0, maxPayload, this, nodeData) {
inflateInit2(&inflationStream, -15);
inflationBuffer = new char[LARGE_BUFFER_SIZE];
zlibBuffer = new char[LARGE_BUFFER_SIZE];
deflateInit2(&deflationStream, 1, Z_DEFLATED, -15, 8, Z_DEFAULT_STRATEGY);
#ifdef UWS_THREADSAFE
getLoop()->preCbData = nodeData;
......@@ -63,7 +66,7 @@ public:
~Hub() {
inflateEnd(&inflationStream);
delete [] inflationBuffer;
delete [] zlibBuffer;
}
using uS::Node::run;
......
......@@ -14,7 +14,7 @@ namespace uWS {
*
*/
template <bool isServer>
void WebSocket<isServer>::send(const char *message, size_t length, OpCode opCode, void(*callback)(WebSocket<isServer> *webSocket, void *data, bool cancelled, void *reserved), void *callbackData) {
void WebSocket<isServer>::send(const char *message, size_t length, OpCode opCode, void(*callback)(WebSocket<isServer> *webSocket, void *data, bool cancelled, void *reserved), void *callbackData, bool compress) {
#ifdef UWS_THREADSAFE
std::lock_guard<std::recursive_mutex> lockGuard(*nodeData->asyncMutex);
......@@ -30,7 +30,9 @@ void WebSocket<isServer>::send(const char *message, size_t length, OpCode opCode
struct TransformData {
OpCode opCode;
} transformData = {opCode};
bool compress;
Socket *s;
} transformData = {opCode, compress && compressionStatus == WebSocket<isServer>::CompressionStatus::ENABLED, this};
struct WebSocketTransformer {
static size_t estimate(const char *data, size_t length) {
......@@ -38,6 +40,11 @@ void WebSocket<isServer>::send(const char *message, size_t length, OpCode opCode
}
static size_t transform(const char *src, char *dst, size_t length, TransformData transformData) {
if (transformData.compress) {
char *deflated = Group<isServer>::from(transformData.s)->hub->deflate((char *) src, length);
return WebSocketProtocol<isServer, WebSocket<isServer>>::formatMessage(dst, deflated, length, transformData.opCode, length, true);
}
return WebSocketProtocol<isServer, WebSocket<isServer>>::formatMessage(dst, src, length, transformData.opCode, length, false);
}
};
......
......@@ -72,7 +72,7 @@ public:
void terminate();
void ping(const char *message) {send(message, OpCode::PING);}
void send(const char *message, OpCode opCode = OpCode::TEXT) {send(message, strlen(message), opCode);}
void send(const char *message, size_t length, OpCode opCode, void(*callback)(WebSocket<isServer> *webSocket, void *data, bool cancelled, void *reserved) = nullptr, void *callbackData = nullptr);
void send(const char *message, size_t length, OpCode opCode, void(*callback)(WebSocket<isServer> *webSocket, void *data, bool cancelled, void *reserved) = nullptr, void *callbackData = nullptr, bool compress = false);
static PreparedMessage *prepareMessage(char *data, size_t length, OpCode opCode, bool compressed, void(*callback)(WebSocket<isServer> *webSocket, void *data, bool cancelled, void *reserved) = nullptr);
static PreparedMessage *prepareMessageBatch(std::vector<std::string> &messages, std::vector<int> &excludedMessages,
OpCode opCode, bool compressed, void(*callback)(WebSocket<isServer> *webSocket, void *data, bool cancelled, void *reserved) = nullptr);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment