Skip to content

Commit

Permalink
Switch to Boost generic sockets (ray-project#7656)
Browse files Browse the repository at this point in the history
* Use generic Boost sockets

* Un-templatize server/client connections

Co-authored-by: Mehrdad <[email protected]>
  • Loading branch information
mehrdadn and web-flow authored Apr 6, 2020
1 parent 82c2d9f commit 203c077
Show file tree
Hide file tree
Showing 19 changed files with 364 additions and 359 deletions.
114 changes: 45 additions & 69 deletions src/ray/common/client_connection.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,39 +15,42 @@
#include "client_connection.h"

#include <stdio.h>

#include <boost/asio/buffer.hpp>
#include <boost/asio/generic/stream_protocol.hpp>
#include <boost/asio/placeholders.hpp>
#include <boost/asio/read.hpp>
#include <boost/asio/write.hpp>
#include <boost/bind.hpp>
#include <sstream>

#include "ray/common/ray_config.h"
#include "ray/util/url.h"
#include "ray/util/util.h"

namespace ray {

template <class T>
std::shared_ptr<ServerConnection<T>> ServerConnection<T>::Create(
boost::asio::basic_stream_socket<T> &&socket) {
std::shared_ptr<ServerConnection<T>> self(new ServerConnection(std::move(socket)));
std::shared_ptr<ServerConnection> ServerConnection::Create(
boost::asio::generic::stream_protocol::socket &&socket) {
std::shared_ptr<ServerConnection> self(new ServerConnection(std::move(socket)));
return self;
}

template <class T>
ServerConnection<T>::ServerConnection(boost::asio::basic_stream_socket<T> &&socket)
ServerConnection::ServerConnection(boost::asio::generic::stream_protocol::socket &&socket)
: socket_(std::move(socket)),
async_write_max_messages_(1),
async_write_queue_(),
async_write_in_flight_(false),
async_write_broken_pipe_(false) {}

template <class T>
ServerConnection<T>::~ServerConnection() {
ServerConnection::~ServerConnection() {
// If there are any pending messages, invoke their callbacks with an IOError status.
for (const auto &write_buffer : async_write_queue_) {
write_buffer->handler(Status::IOError("Connection closed."));
}
}

template <class T>
Status ServerConnection<T>::WriteBuffer(
Status ServerConnection::WriteBuffer(
const std::vector<boost::asio::const_buffer> &buffer) {
boost::system::error_code error;
// Loop until all bytes are written while handling interrupts.
Expand All @@ -71,8 +74,7 @@ Status ServerConnection<T>::WriteBuffer(
return ray::Status::OK();
}

template <class T>
Status ServerConnection<T>::ReadBuffer(
Status ServerConnection::ReadBuffer(
const std::vector<boost::asio::mutable_buffer> &buffer) {
boost::system::error_code error;
// Loop until all bytes are read while handling interrupts.
Expand All @@ -94,9 +96,8 @@ Status ServerConnection<T>::ReadBuffer(
return Status::OK();
}

template <class T>
ray::Status ServerConnection<T>::WriteMessage(int64_t type, int64_t length,
const uint8_t *message) {
ray::Status ServerConnection::WriteMessage(int64_t type, int64_t length,
const uint8_t *message) {
sync_writes_ += 1;
bytes_written_ += length;

Expand All @@ -109,8 +110,7 @@ ray::Status ServerConnection<T>::WriteMessage(int64_t type, int64_t length,
return WriteBuffer(message_buffers);
}

template <class T>
void ServerConnection<T>::WriteMessageAsync(
void ServerConnection::WriteMessageAsync(
int64_t type, int64_t length, const uint8_t *message,
const std::function<void(const ray::Status &)> &handler) {
async_writes_ += 1;
Expand All @@ -137,8 +137,7 @@ void ServerConnection<T>::WriteMessageAsync(
}
}

template <class T>
void ServerConnection<T>::DoAsyncWrites() {
void ServerConnection::DoAsyncWrites() {
// Make sure we were not writing to the socket.
RAY_CHECK(!async_write_in_flight_);
async_write_in_flight_ = true;
Expand Down Expand Up @@ -183,73 +182,70 @@ void ServerConnection<T>::DoAsyncWrites() {
}
auto this_ptr = this->shared_from_this();
boost::asio::async_write(
ServerConnection<T>::socket_, message_buffers,
ServerConnection::socket_, message_buffers,
[this, this_ptr, num_messages, call_handlers](
const boost::system::error_code &error, size_t bytes_transferred) {
ray::Status status = boost_to_ray_status(error);
if (error.value() == boost::system::errc::errc_t::broken_pipe) {
RAY_LOG(ERROR) << "Broken Pipe happened during calling "
<< "ServerConnection<T>::DoAsyncWrites.";
<< "ServerConnection::DoAsyncWrites.";
// From now on, calling DoAsyncWrites will directly call the handler
// with this broken-pipe status.
async_write_broken_pipe_ = true;
} else if (!status.ok()) {
RAY_LOG(ERROR) << "Error encountered during calling "
<< "ServerConnection<T>::DoAsyncWrites, message: "
<< "ServerConnection::DoAsyncWrites, message: "
<< status.message()
<< ", error code: " << static_cast<int>(error.value());
}
call_handlers(status, num_messages);
});
}

template <class T>
std::shared_ptr<ClientConnection<T>> ClientConnection<T>::Create(
ClientHandler<T> &client_handler, MessageHandler<T> &message_handler,
boost::asio::basic_stream_socket<T> &&socket, const std::string &debug_label,
std::shared_ptr<ClientConnection> ClientConnection::Create(
ClientHandler &client_handler, MessageHandler &message_handler,
boost::asio::generic::stream_protocol::socket &&socket,
const std::string &debug_label,
const std::vector<std::string> &message_type_enum_names, int64_t error_message_type) {
std::shared_ptr<ClientConnection<T>> self(
std::shared_ptr<ClientConnection> self(
new ClientConnection(message_handler, std::move(socket), debug_label,
message_type_enum_names, error_message_type));
// Let our manager process our new connection.
client_handler(*self);
return self;
}

template <class T>
ClientConnection<T>::ClientConnection(
MessageHandler<T> &message_handler, boost::asio::basic_stream_socket<T> &&socket,
ClientConnection::ClientConnection(
MessageHandler &message_handler,
boost::asio::generic::stream_protocol::socket &&socket,
const std::string &debug_label,
const std::vector<std::string> &message_type_enum_names, int64_t error_message_type)
: ServerConnection<T>(std::move(socket)),
: ServerConnection(std::move(socket)),
registered_(false),
message_handler_(message_handler),
debug_label_(debug_label),
message_type_enum_names_(message_type_enum_names),
error_message_type_(error_message_type) {}

template <class T>
void ClientConnection<T>::Register() {
void ClientConnection::Register() {
RAY_CHECK(!registered_);
registered_ = true;
}

template <class T>
void ClientConnection<T>::ProcessMessages() {
void ClientConnection::ProcessMessages() {
// Wait for a message header from the client. The message header includes the
// protocol version, the message type, and the length of the message.
std::vector<boost::asio::mutable_buffer> header;
header.push_back(boost::asio::buffer(&read_cookie_, sizeof(read_cookie_)));
header.push_back(boost::asio::buffer(&read_type_, sizeof(read_type_)));
header.push_back(boost::asio::buffer(&read_length_, sizeof(read_length_)));
boost::asio::async_read(
ServerConnection<T>::socket_, header,
boost::bind(&ClientConnection<T>::ProcessMessageHeader,
ServerConnection::socket_, header,
boost::bind(&ClientConnection::ProcessMessageHeader,
shared_ClientConnection_from_this(), boost::asio::placeholders::error));
}

template <class T>
void ClientConnection<T>::ProcessMessageHeader(const boost::system::error_code &error) {
void ClientConnection::ProcessMessageHeader(const boost::system::error_code &error) {
if (error) {
// If there was an error, disconnect the client.
read_type_ = error_message_type_;
Expand All @@ -260,22 +256,21 @@ void ClientConnection<T>::ProcessMessageHeader(const boost::system::error_code &

// If there was no error, make sure the ray cookie matches.
if (!CheckRayCookie()) {
ServerConnection<T>::Close();
ServerConnection::Close();
return;
}

// Resize the message buffer to match the received length.
read_message_.resize(read_length_);
ServerConnection<T>::bytes_read_ += read_length_;
ServerConnection::bytes_read_ += read_length_;
// Wait for the message to be read.
boost::asio::async_read(
ServerConnection<T>::socket_, boost::asio::buffer(read_message_),
boost::bind(&ClientConnection<T>::ProcessMessage,
shared_ClientConnection_from_this(), boost::asio::placeholders::error));
ServerConnection::socket_, boost::asio::buffer(read_message_),
boost::bind(&ClientConnection::ProcessMessage, shared_ClientConnection_from_this(),
boost::asio::placeholders::error));
}

template <class T>
bool ClientConnection<T>::CheckRayCookie() {
bool ClientConnection::CheckRayCookie() {
if (read_cookie_ == RayConfig::instance().ray_cookie()) {
return true;
}
Expand Down Expand Up @@ -303,21 +298,11 @@ bool ClientConnection<T>::CheckRayCookie() {
return false;
}

template <class T>
std::string ClientConnection<T>::RemoteEndpointInfo() {
return std::string();
std::string ClientConnection::RemoteEndpointInfo() {
return endpoint_to_url(ServerConnection::socket_.remote_endpoint(), false);
}

template <>
std::string ClientConnection<remote_stream_protocol>::RemoteEndpointInfo() {
const auto &remote_endpoint =
ServerConnection<remote_stream_protocol>::socket_.remote_endpoint();
return remote_endpoint.address().to_string() + ":" +
std::to_string(remote_endpoint.port());
}

template <class T>
void ClientConnection<T>::ProcessMessage(const boost::system::error_code &error) {
void ClientConnection::ProcessMessage(const boost::system::error_code &error) {
if (error) {
read_type_ = error_message_type_;
}
Expand All @@ -337,8 +322,7 @@ void ClientConnection<T>::ProcessMessage(const boost::system::error_code &error)
}
}

template <class T>
std::string ServerConnection<T>::DebugString() const {
std::string ServerConnection::DebugString() const {
std::stringstream result;
result << "\n- bytes read: " << bytes_read_;
result << "\n- bytes written: " << bytes_written_;
Expand All @@ -353,12 +337,4 @@ std::string ServerConnection<T>::DebugString() const {
return result.str();
}

#if defined(BOOST_ASIO_HAS_LOCAL_SOCKETS)
// We compile conditionally to prevent duplicate explicit instantiation error
template class ServerConnection<local_stream_protocol>;
template class ClientConnection<local_stream_protocol>;
#endif
template class ServerConnection<remote_stream_protocol>;
template class ClientConnection<remote_stream_protocol>;

} // namespace ray
Loading

0 comments on commit 203c077

Please sign in to comment.