Skip to content

Commit

Permalink
Fix WebSocket when HTTP server is not running
Browse files Browse the repository at this point in the history
  • Loading branch information
Jarred-Sumner committed Jun 23, 2022
1 parent 20249b9 commit f05428e
Show file tree
Hide file tree
Showing 14 changed files with 103 additions and 38 deletions.
4 changes: 4 additions & 0 deletions src/deps/uws.zig
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,10 @@ pub const Loop = opaque {
};
}

pub fn run(this: *Loop) void {
us_loop_run(this);
}

extern fn uws_loop_defer(loop: *Loop, ctx: *anyopaque, cb: fn (ctx: *anyopaque) callconv(.C) void) void;

extern fn uws_get_loop() ?*Loop;
Expand Down
26 changes: 24 additions & 2 deletions src/http/websocket_http_client.zig
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ pub fn NewHTTPUpgradeClient(comptime ssl: bool) type {
body_buf: ?*BodyBuf = null,
body_written: usize = 0,
websocket_protocol: u64 = 0,
event_loop_ref: bool = false,

pub const name = if (ssl) "WebSocketHTTPSClient" else "WebSocketHTTPClient";

Expand All @@ -140,10 +141,14 @@ pub fn NewHTTPUpgradeClient(comptime ssl: bool) type {
if (vm.uws_event_loop) |other| {
std.debug.assert(other == loop);
}
const is_new_loop = vm.uws_event_loop == null;

vm.uws_event_loop = loop;

Socket.configure(ctx, HTTPClient, handleOpen, handleClose, handleData, handleWritable, handleTimeout, handleConnectError, handleEnd);
if (is_new_loop) {
vm.prepareLoop();
}
}

pub fn connect(
Expand All @@ -167,11 +172,15 @@ pub fn NewHTTPUpgradeClient(comptime ssl: bool) type {
};
var host_ = host.toSlice(bun.default_allocator);
defer host_.deinit();
var vm = global.bunVM();
vm.us_loop_reference_count +|= 1;
client.event_loop_ref = true;

if (Socket.connect(host_.slice(), port, @ptrCast(*uws.us_socket_context_t, socket_ctx), HTTPClient, client, "tcp")) |out| {
out.tcp.timeout(120);
return out;
}
vm.us_loop_reference_count -|= 1;

client.clearData();

Expand All @@ -183,6 +192,10 @@ pub fn NewHTTPUpgradeClient(comptime ssl: bool) type {
this.input_body_buf.len = 0;
}
pub fn clearData(this: *HTTPClient) void {
if (this.event_loop_ref) {
this.event_loop_ref = false;
JSC.VirtualMachine.vm.us_loop_reference_count -|= 1;
}
this.clearInput();
if (this.body_buf) |buf| {
this.body_buf = null;
Expand Down Expand Up @@ -777,6 +790,7 @@ pub fn NewWebSocketClient(comptime ssl: bool) type {
send_buffer: bun.LinearFifo(u8, .Dynamic),

globalThis: *JSC.JSGlobalObject,
event_loop_ref: bool = false,

pub const name = if (ssl) "WebSocketClientTLS" else "WebSocketClient";

Expand Down Expand Up @@ -1436,6 +1450,8 @@ pub fn NewWebSocketClient(comptime ssl: bool) type {
) orelse return null;
adopted.send_buffer.ensureTotalCapacity(2048) catch return null;
adopted.receive_buffer.ensureTotalCapacity(2048) catch return null;
adopted.event_loop_ref = true;
adopted.globalThis.bunVM().us_loop_reference_count +|= 1;
_ = globalThis.bunVM().eventLoop().ready_tasks_count.fetchAdd(1, .Monotonic);
return @ptrCast(
*anyopaque,
Expand All @@ -1446,12 +1462,18 @@ pub fn NewWebSocketClient(comptime ssl: bool) type {
pub fn finalize(this: *WebSocket) callconv(.C) void {
this.clearData();

if (this.event_loop_ref) {
this.event_loop_ref = false;
this.globalThis.bunVM().us_loop_reference_count -|= 1;
_ = this.globalThis.bunVM().eventLoop().ready_tasks_count.fetchSub(1, .Monotonic);
}

this.outgoing_websocket = null;

if (this.tcp.isClosed())
return;

this.tcp.close(0, null);
this.outgoing_websocket = null;
_ = this.globalThis.bunVM().eventLoop().ready_tasks_count.fetchSub(1, .Monotonic);
}

pub const Export = shim.exportFunctions(.{
Expand Down
56 changes: 33 additions & 23 deletions src/javascript/jsc/api/bun.zig
Original file line number Diff line number Diff line change
Expand Up @@ -173,30 +173,40 @@ pub fn inspect(
false,
false,
);
buffered_writer.flush() catch {
return JSC.C.JSValueMakeUndefined(ctx);
};

// when it's a small thing, rely on GC to manage the memory
if (writer.context.pos < 2048 and array.list.items.len == 0) {
var slice = writer.context.buffer[0..writer.context.pos];
if (slice.len == 0) {
return ZigString.Empty.toValue(ctx.ptr()).asObjectRef();
}

var zig_str = ZigString.init(slice).withEncoding();
return zig_str.toValueGC(ctx.ptr()).asObjectRef();
}

// when it's a big thing, we will manage it
{
writer.context.flush() catch {};
var slice = writer.context.context.toOwnedSlice();

var zig_str = ZigString.init(slice).withEncoding();
if (!zig_str.isUTF8()) {
return zig_str.toExternalValue(ctx.ptr()).asObjectRef();
} else {
return zig_str.toValueGC(ctx.ptr()).asObjectRef();
}
}
// we are going to always clone to keep things simple for now
// the common case here will be stack-allocated, so it should be fine
var out = ZigString.init(array.toOwnedSliceLeaky()).withEncoding();
const ret = out.toValueGC(ctx);
array.deinit();
return ret.asObjectRef();

// // when it's a small thing, rely on GC to manage the memory
// if (writer.context.pos < 2048 and array.list.items.len == 0) {
// var slice = writer.context.buffer[0..writer.context.pos];
// if (slice.len == 0) {
// return ZigString.Empty.toValue(ctx.ptr()).asObjectRef();
// }

// var zig_str =
// return zig_str.toValueGC(ctx.ptr()).asObjectRef();
// }

// // when it's a big thing, we will manage it
// {
// writer.context.flush() catch {};
// var slice = writer.context.context.toOwnedSlice();

// var zig_str = ZigString.init(slice).withEncoding();
// if (!zig_str.isUTF8()) {
// return zig_str.toExternalValue(ctx.ptr()).asObjectRef();
// } else {
// return zig_str.toValueGC(ctx.ptr()).asObjectRef();
// }
// }
}

pub fn registerMacro(
Expand Down
12 changes: 7 additions & 5 deletions src/javascript/jsc/api/server.zig
Original file line number Diff line number Diff line change
Expand Up @@ -1523,6 +1523,7 @@ pub fn NewServer(comptime ssl_enabled_: bool, comptime debug_mode_: bool) type {
if (this.listener) |listener| {
listener.close();
this.listener = null;
this.vm.disable_run_us_loop = false;
}

this.deinitIfWeCan();
Expand All @@ -1535,11 +1536,6 @@ pub fn NewServer(comptime ssl_enabled_: bool, comptime debug_mode_: bool) type {
}
}

// if you run multiple servers simultaneously, this could break it
if (this.vm.uws_event_loop != null and uws.Loop.get().? == this.vm.uws_event_loop.?) {
this.vm.uws_event_loop = null;
}

this.app.destroy();
const allocator = this.allocator;
allocator.destroy(this);
Expand Down Expand Up @@ -1646,6 +1642,12 @@ pub fn NewServer(comptime ssl_enabled_: bool, comptime debug_mode_: bool) type {

pub fn run(this: *ThisServer) void {
// this.app.addServerName(hostname_pattern: [*:0]const u8)

// we do not increment the reference count here
// uWS manages running the loop, so it is unnecessary
// this.vm.us_loop_reference_count +|= 1;
this.vm.disable_run_us_loop = true;

this.app.run();
}

Expand Down
5 changes: 5 additions & 0 deletions src/javascript/jsc/bindings/ScriptExecutionContext.h
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,11 @@ class ScriptExecutionContext : public CanMakeWeakPtr<ScriptExecutionContext> {
reinterpret_cast<Zig::GlobalObject*>(m_globalObject)->queueTask(task);
} // Executes the task on context's thread asynchronously.

void postTask(EventLoopTask* task)
{
reinterpret_cast<Zig::GlobalObject*>(m_globalObject)->queueTask(task);
} // Executes the task on context's thread asynchronously.

template<typename... Arguments>
void postCrossThreadTask(Arguments&&... arguments)
{
Expand Down
2 changes: 1 addition & 1 deletion src/javascript/jsc/bindings/headers-cpp.h
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
//-- AUTOGENERATED FILE -- 1655637924
//-- AUTOGENERATED FILE -- 1655942279
// clang-format off
#pragma once

Expand Down
2 changes: 1 addition & 1 deletion src/javascript/jsc/bindings/headers.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// clang-format off
//-- AUTOGENERATED FILE -- 1655893003
//-- AUTOGENERATED FILE -- 1655942279
#pragma once

#include <stddef.h>
Expand Down
3 changes: 2 additions & 1 deletion src/javascript/jsc/bindings/webcore/JSAbortAlgorithm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,8 @@ JSAbortAlgorithm::~JSAbortAlgorithm()
if (!context || context->isContextThread())
delete m_data;
else
context->postTask(DeleteCallbackDataTask(m_data));

context->postTask(new DeleteCallbackDataTask(m_data));
#ifndef NDEBUG
m_data = nullptr;
#endif
Expand Down
3 changes: 2 additions & 1 deletion src/javascript/jsc/bindings/webcore/JSCallbackData.h
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,8 @@ class DeleteCallbackDataTask : public EventLoopTask {
public:
template<typename CallbackDataType>
explicit DeleteCallbackDataTask(CallbackDataType* data)
: EventLoopTask(EventLoopTask::CleanupTask, [data = std::unique_ptr<CallbackDataType>(data)](ScriptExecutionContext&) {
: EventLoopTask(EventLoopTask::CleanupTask, [data](ScriptExecutionContext&) mutable {
delete data;
})
{
}
Expand Down
2 changes: 1 addition & 1 deletion src/javascript/jsc/bindings/webcore/JSErrorCallback.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ JSErrorCallback::~JSErrorCallback()
if (!context || context->isContextThread())
delete m_data;
else
context->postTask(DeleteCallbackDataTask(m_data));
context->postTask(new DeleteCallbackDataTask(m_data));
#ifndef NDEBUG
m_data = nullptr;
#endif
Expand Down
4 changes: 2 additions & 2 deletions src/javascript/jsc/bindings/webcore/WebSocket.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -280,9 +280,9 @@ ExceptionOr<void> WebSocket::connect(const String& url, const Vector<String>& pr
return Exception { SyntaxError, makeString("Invalid url for WebSocket "_s, m_url.stringCenterEllipsizedToLength()) };
}

bool is_secure = m_url.protocolIs("wss");
bool is_secure = m_url.protocolIs("wss"_s);

if (!m_url.protocolIs("ws") && !is_secure) {
if (!m_url.protocolIs("ws"_s) && !is_secure) {
// context.addConsoleMessage(MessageSource::JS, MessageLevel::Error, );
m_state = CLOSED;
return Exception { SyntaxError, makeString("Wrong url scheme for WebSocket "_s, m_url.stringCenterEllipsizedToLength()) };
Expand Down
7 changes: 7 additions & 0 deletions src/javascript/jsc/event_loop.zig
Original file line number Diff line number Diff line change
Expand Up @@ -407,6 +407,7 @@ pub const EventLoop = struct {
// TODO: fix this technical debt
pub fn tick(this: *EventLoop) void {
var poller = &this.virtual_machine.poller;
var ctx = this.virtual_machine;
while (true) {
this.tickConcurrent();

Expand All @@ -421,6 +422,12 @@ pub const EventLoop = struct {
}

this.global.vm().releaseWeakRefs();

if (!ctx.disable_run_us_loop and ctx.us_loop_reference_count > 0 and !ctx.is_us_loop_entered) {
ctx.is_us_loop_entered = true;
ctx.enterUWSLoop();
ctx.is_us_loop_entered = false;
}
}

// TODO: fix this technical debt
Expand Down
13 changes: 13 additions & 0 deletions src/javascript/jsc/javascript.zig
Original file line number Diff line number Diff line change
Expand Up @@ -332,6 +332,9 @@ pub const VirtualMachine = struct {

rare_data: ?*JSC.RareData = null,
poller: JSC.Poller = JSC.Poller{},
us_loop_reference_count: usize = 0,
disable_run_us_loop: bool = false,
is_us_loop_entered: bool = false,

pub fn io(this: *VirtualMachine) *IO {
if (this.io_ == null) {
Expand Down Expand Up @@ -361,6 +364,16 @@ pub const VirtualMachine = struct {
return this.event_loop;
}

pub fn prepareLoop(this: *VirtualMachine) void {
var loop = this.uws_event_loop.?;
_ = loop.addPostHandler(*JSC.EventLoop, this.eventLoop(), JSC.EventLoop.tick);
}

pub fn enterUWSLoop(this: *VirtualMachine) void {
var loop = this.uws_event_loop.?;
loop.run();
}

pub fn onExit(this: *VirtualMachine) void {
var rare_data = this.rare_data orelse return;
var hook = rare_data.cleanup_hook orelse return;
Expand Down
2 changes: 1 addition & 1 deletion src/network_thread.zig
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ pub fn getAddressList(allocator: std.mem.Allocator, name: []const u8, port: u16)

pub var has_warmed = false;
pub fn warmup() !void {
if (has_warmed) return;
if (has_warmed or global_loaded.load(.Monotonic) > 0) return;
has_warmed = true;
try init();
global.pool.forceSpawn();
Expand Down

0 comments on commit f05428e

Please sign in to comment.