Skip to content

Commit

Permalink
Add boundaries to wire format
Browse files Browse the repository at this point in the history
  • Loading branch information
malmaud authored and amitmurthy committed Jun 30, 2016
1 parent 844f284 commit 561db3b
Show file tree
Hide file tree
Showing 2 changed files with 130 additions and 55 deletions.
175 changes: 120 additions & 55 deletions base/multi.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,30 +37,47 @@ end
hash(r::RRID, h::UInt) = hash(r.whence, hash(r.id, h))
==(r::RRID, s::RRID) = (r.whence==s.whence && r.id==s.id)

## Wire format description
#
# Each message has three parts, which are written in order to the worker's stream.
# 1) A header of type MsgHeader is serialized to the stream (via `serialize`).
# 2) A message of type AbstractMsg is then serialized.
# 3) Finally, a fixed bounday of 10 bytes is written.

# Message header stored separately from body to be able to send back errors if
# a deserialization error occurs when reading the message body.
type MsgHeader
response_oid::RRID
notify_oid::RRID
end

# Special oid (0,0) uses to indicate a null ID.
# Used instead of Nullable to decrease wire size of header.
null_id(id) = id == RRID(0, 0)

MsgHeader(;response_oid::RRID=RRID(0,0), notify_oid::RRID=RRID(0,0)) =
MsgHeader(response_oid, notify_oid)

type CallMsg{Mode} <: AbstractMsg
f::Function
args::Tuple
kwargs::Array
response_oid::RRID
end
type CallWaitMsg <: AbstractMsg
f::Function
args::Tuple
kwargs::Array
response_oid::RRID
notify_oid::RRID
end
type RemoteDoMsg <: AbstractMsg
f::Function
args::Tuple
kwargs::Array
end
type ResultMsg <: AbstractMsg
response_oid::RRID
value::Any
end


# Worker initialization messages
type IdentifySocketMsg <: AbstractMsg
from_pid::Int
Expand All @@ -70,34 +87,32 @@ end
type JoinPGRPMsg <: AbstractMsg
self_pid::Int
other_workers::Array
notify_oid::RRID
topology::Symbol
worker_pool
end
type JoinCompleteMsg <: AbstractMsg
notify_oid::RRID
cpu_cores::Int
ospid::Int
end

function send_msg_unknown(s::IO, msg)
function send_msg_unknown(s::IO, header, msg)
error("attempt to send to unknown socket")
end

function send_msg(s::IO, msg)
function send_msg(s::IO, header, msg)
id = worker_id_from_socket(s)
if id > -1
return send_msg(worker_from_id(id), msg)
return send_msg(worker_from_id(id), header, msg)
end
send_msg_unknown(s, msg)
send_msg_unknown(s, header, msg)
end

function send_msg_now(s::IO, msg::AbstractMsg)
function send_msg_now(s::IO, msghdr, msg::AbstractMsg)
id = worker_id_from_socket(s)
if id > -1
return send_msg_now(worker_from_id(id), msg)
return send_msg_now(worker_from_id(id), msghdr, msg)
end
send_msg_unknown(s, msg)
send_msg_unknown(s, msghdr, msg)
end

abstract ClusterManager
Expand Down Expand Up @@ -197,12 +212,12 @@ function set_worker_state(w, state)
notify(w.c_state; all=true)
end

function send_msg_now(w::Worker, msg)
send_msg_(w, msg, true)
function send_msg_now(w::Worker, msghdr, msg)
send_msg_(w, msghdr, msg, true)
end

function send_msg(w::Worker, msg)
send_msg_(w, msg, false)
function send_msg(w::Worker, msghdr, msg)
send_msg_(w, msghdr, msg, false)
end

function flush_gc_msgs(w::Worker)
Expand Down Expand Up @@ -241,14 +256,20 @@ function check_worker_state(w::Worker)
end
end

# Boundary inserted between messages on the wire, used for recovering
# from deserialization errors. Picked arbitrarily.
# A size of 10 bytes indicates ~ ~1e24 possible boundaries, so chance of collision with message contents is trivial.
const MSG_BOUNDARY = UInt8[0x79, 0x8e, 0x8e, 0xf5, 0x6e, 0x9b, 0x2e, 0x97, 0xd5, 0x7d]

function send_msg_(w::Worker, msg, now::Bool)
function send_msg_(w::Worker, header, msg, now::Bool)
check_worker_state(w)
io = w.w_stream
lock(io.lock)
try
reset_state(w.w_serializer)
serialize(w.w_serializer, header)
serialize(w.w_serializer, msg) # io is wrapped in w_serializer
write(io, MSG_BOUNDARY)

if !now && w.gcflag
flush_gc_msgs(w)
Expand Down Expand Up @@ -768,7 +789,6 @@ function showerror(io::IO, re::RemoteException)
showerror(io, re.captured)
end


function run_work_thunk(thunk, print_error)
local result
try
Expand Down Expand Up @@ -811,7 +831,7 @@ end
function remotecall(f, w::Worker, args...; kwargs...)
rr = Future(w)
#println("$(myid()) asking for $rr")
send_msg(w, CallMsg{:call}(f, args, kwargs, remoteref_id(rr)))
send_msg(w, MsgHeader(response_oid=remoteref_id(rr)), CallMsg{:call}(f, args, kwargs))
rr
end

Expand All @@ -829,7 +849,7 @@ function remotecall_fetch(f, w::Worker, args...; kwargs...)
oid = RRID()
rv = lookup_ref(oid)
rv.waitingfor = w.id
send_msg(w, CallMsg{:call_fetch}(f, args, kwargs, oid))
send_msg(w, MsgHeader(response_oid=oid), CallMsg{:call_fetch}(f, args, kwargs))
v = take!(rv)
delete!(PGRP.refs, oid)
isa(v, RemoteException) ? throw(v) : v
Expand All @@ -846,7 +866,7 @@ function remotecall_wait(f, w::Worker, args...; kwargs...)
rv = lookup_ref(prid)
rv.waitingfor = w.id
rr = Future(w)
send_msg(w, CallWaitMsg(f, args, kwargs, remoteref_id(rr), prid))
send_msg(w, MsgHeader(response_oid=remoteref_id(rr), notify_oid=prid), CallWaitMsg(f, args, kwargs))
v = fetch(rv.c)
delete!(PGRP.refs, prid)
isa(v, RemoteException) && throw(v)
Expand All @@ -866,7 +886,7 @@ function remote_do(f, w::LocalProcess, args...; kwargs...)
end

function remote_do(f, w::Worker, args...; kwargs...)
send_msg(w, RemoteDoMsg(f, args, kwargs))
send_msg(w, MsgHeader(), RemoteDoMsg(f, args, kwargs))
nothing
end

Expand Down Expand Up @@ -952,13 +972,13 @@ close(rr::RemoteChannel) = call_on_owner(close_ref, rr)

function deliver_result(sock::IO, msg, oid, value)
#print("$(myid()) sending result $oid\n")
if is(msg,:call_fetch) || isa(value, RemoteException)
if is(msg, :call_fetch) || isa(value, RemoteException)
val = value
else
val = :OK
end
try
send_msg_now(sock, ResultMsg(oid, val))
send_msg_now(sock, MsgHeader(response_oid=oid), ResultMsg(val))
catch e
# terminate connection in case of serialization error
# otherwise the reading end would hang
Expand Down Expand Up @@ -996,28 +1016,73 @@ function process_messages(r_stream::IO, w_stream::IO, incoming=true)
end

function message_handler_loop(r_stream::IO, w_stream::IO, incoming::Bool)
wpid=0 # the worker r_stream is connected to.
boundary = similar(MSG_BOUNDARY)
try
version = process_hdr(r_stream, incoming)
serializer = ClusterSerializer(r_stream)

# The first message will associate wpid with r_stream
msghdr = deserialize(serializer)
msg = deserialize(serializer)
readbytes!(r_stream, boundary, length(MSG_BOUNDARY))

handle_msg(msg, msghdr, r_stream, w_stream, version)
wpid = worker_id_from_socket(r_stream)

@assert wpid > 0

while true
reset_state(serializer)
msg = deserialize(serializer)
# println("got msg: ", msg)
handle_msg(msg, r_stream, w_stream, version)
msghdr = deserialize(serializer)
# println("msghdr: ", msghdr)

try
msg = deserialize(serializer)
catch e
# Deserialization error; discard bytes in stream until boundary found
boundary_idx = 1
while true
# This may throw an EOF error if the terminal boundary was not written
# correctly, triggering the higher-scoped catch block below
byte = read(r_stream, UInt8)
if byte == MSG_BOUNDARY[boundary_idx]
boundary_idx += 1
if boundary_idx > length(MSG_BOUNDARY)
break
end
else
boundary_idx = 1
end
end
# println("Deserialization error.")
remote_err = RemoteException(myid(), CapturedException(e, catch_backtrace()))
if !null_id(msghdr.response_oid)
ref = lookup_ref(msghdr.response_oid)
put!(ref, remote_err)
end
if !null_id(msghdr.notify_oid)
deliver_result(w_stream, :call_fetch, msghdr.notify_oid, remote_err)
end
continue
end
readbytes!(r_stream, boundary, length(MSG_BOUNDARY))

# println("got msg: ", typeof(msg))
handle_msg(msg, msghdr, r_stream, w_stream, version)
end
catch e
# println(STDERR, "Process($(myid())) - Exception ", e)
iderr = worker_id_from_socket(r_stream)
if (iderr < 1)
if (wpid < 1)
println(STDERR, e)
println(STDERR, "Process($(myid())) - Unknown remote, closing connection.")
else
werr = worker_from_id(iderr)
werr = worker_from_id(wpid)
oldstate = werr.state
set_worker_state(werr, W_TERMINATED)

# If error occured talking to pid 1, commit harakiri
if iderr == 1
# If unhandleable error occured talking to pid 1, exit
if wpid == 1
if isopen(w_stream)
print(STDERR, "fatal error on ", myid(), ": ")
display_error(e, catch_backtrace())
Expand All @@ -1028,15 +1093,15 @@ function message_handler_loop(r_stream::IO, w_stream::IO, incoming::Bool)
# Will treat any exception as death of node and cleanup
# since currently we do not have a mechanism for workers to reconnect
# to each other on unhandled errors
deregister_worker(iderr)
deregister_worker(wpid)
end

isopen(r_stream) && close(r_stream)
isopen(w_stream) && close(w_stream)

if (myid() == 1) && (iderr > 1)
if (myid() == 1) && (wpid > 1)
if oldstate != W_TERMINATING
println(STDERR, "Worker $iderr terminated.")
println(STDERR, "Worker $wpid terminated.")
rethrow(e)
end
end
Expand Down Expand Up @@ -1071,44 +1136,44 @@ function process_hdr(s, validate_cookie)
return VersionNumber(strip(String(version)))
end

function handle_msg(msg::CallMsg{:call}, r_stream, w_stream, version)
schedule_call(msg.response_oid, ()->msg.f(msg.args...; msg.kwargs...))
function handle_msg(msg::CallMsg{:call}, msghdr, r_stream, w_stream, version)
schedule_call(msghdr.response_oid, ()->msg.f(msg.args...; msg.kwargs...))
end
function handle_msg(msg::CallMsg{:call_fetch}, r_stream, w_stream, version)
function handle_msg(msg::CallMsg{:call_fetch}, msghdr, r_stream, w_stream, version)
@schedule begin
v = run_work_thunk(()->msg.f(msg.args...; msg.kwargs...), false)
deliver_result(w_stream, :call_fetch, msg.response_oid, v)
deliver_result(w_stream, :call_fetch, msghdr.response_oid, v)
end
end

function handle_msg(msg::CallWaitMsg, r_stream, w_stream, version)
function handle_msg(msg::CallWaitMsg, msghdr, r_stream, w_stream, version)
@schedule begin
rv = schedule_call(msg.response_oid, ()->msg.f(msg.args...; msg.kwargs...))
deliver_result(w_stream, :call_wait, msg.notify_oid, fetch(rv.c))
rv = schedule_call(msghdr.response_oid, ()->msg.f(msg.args...; msg.kwargs...))
deliver_result(w_stream, :call_wait, msghdr.notify_oid, fetch(rv.c))
end
end

function handle_msg(msg::RemoteDoMsg, r_stream, w_stream, version)
function handle_msg(msg::RemoteDoMsg, msghdr, r_stream, w_stream, version)
@schedule run_work_thunk(()->msg.f(msg.args...; msg.kwargs...), true)
end

function handle_msg(msg::ResultMsg, r_stream, w_stream, version)
put!(lookup_ref(msg.response_oid), msg.value)
function handle_msg(msg::ResultMsg, msghdr, r_stream, w_stream, version)
put!(lookup_ref(msghdr.response_oid), msg.value)
end

function handle_msg(msg::IdentifySocketMsg, r_stream, w_stream, version)
function handle_msg(msg::IdentifySocketMsg, msghdr, r_stream, w_stream, version)
# register a new peer worker connection
w=Worker(msg.from_pid, r_stream, w_stream, cluster_manager; version=version)
send_connection_hdr(w, false)
send_msg_now(w, IdentifySocketAckMsg())
send_msg_now(w, MsgHeader(), IdentifySocketAckMsg())
end

function handle_msg(msg::IdentifySocketAckMsg, r_stream, w_stream, version)
function handle_msg(msg::IdentifySocketAckMsg, msghdr, r_stream, w_stream, version)
w = map_sock_wrkr[r_stream]
w.version = version
end

function handle_msg(msg::JoinPGRPMsg, r_stream, w_stream, version)
function handle_msg(msg::JoinPGRPMsg, msghdr, r_stream, w_stream, version)
LPROC.id = msg.self_pid
controller = Worker(1, r_stream, w_stream, cluster_manager; version=version)
register_worker(LPROC)
Expand All @@ -1129,7 +1194,7 @@ function handle_msg(msg::JoinPGRPMsg, r_stream, w_stream, version)

set_default_worker_pool(msg.worker_pool)
send_connection_hdr(controller, false)
send_msg_now(controller, JoinCompleteMsg(msg.notify_oid, Sys.CPU_CORES, getpid()))
send_msg_now(controller, MsgHeader(notify_oid=msghdr.notify_oid), JoinCompleteMsg(Sys.CPU_CORES, getpid()))
end

function connect_to_peer(manager::ClusterManager, rpid::Int, wconfig::WorkerConfig)
Expand All @@ -1138,23 +1203,23 @@ function connect_to_peer(manager::ClusterManager, rpid::Int, wconfig::WorkerConf
w = Worker(rpid, r_s, w_s, manager; config=wconfig)
process_messages(w.r_stream, w.w_stream, false)
send_connection_hdr(w, true)
send_msg_now(w, IdentifySocketMsg(myid()))
send_msg_now(w, MsgHeader(), IdentifySocketMsg(myid()))
catch e
display_error(e, catch_backtrace())
println(STDERR, "Error [$e] on $(myid()) while connecting to peer $rpid. Exiting.")
exit(1)
end
end

function handle_msg(msg::JoinCompleteMsg, r_stream, w_stream, version)
function handle_msg(msg::JoinCompleteMsg, msghdr, r_stream, w_stream, version)
w = map_sock_wrkr[r_stream]
environ = get(w.config.environ, Dict())
environ[:cpu_cores] = msg.cpu_cores
w.config.environ = environ
w.config.ospid = msg.ospid
w.version = version

ntfy_channel = lookup_ref(msg.notify_oid)
ntfy_channel = lookup_ref(msghdr.notify_oid)
put!(ntfy_channel, w.id)

push!(default_worker_pool(), w)
Expand Down Expand Up @@ -1478,7 +1543,7 @@ function create_worker(manager, wconfig)

all_locs = map(x -> isa(x, Worker) ? (get(x.config.connect_at, ()), x.id) : ((), x.id, true), join_list)
send_connection_hdr(w, true)
send_msg_now(w, JoinPGRPMsg(w.id, all_locs, ntfy_oid, PGRP.topology, default_worker_pool()))
send_msg_now(w, MsgHeader(notify_oid=ntfy_oid), JoinPGRPMsg(w.id, all_locs, PGRP.topology, default_worker_pool()))

@schedule manage(w.manager, w.id, w.config, :register)
wait(rr_ntfy_join)
Expand Down
Loading

0 comments on commit 561db3b

Please sign in to comment.