From 203c077895ac422b80e31f062d33eadb89e66768 Mon Sep 17 00:00:00 2001 From: mehrdadn Date: Sun, 5 Apr 2020 22:26:46 -0700 Subject: [PATCH] Switch to Boost generic sockets (#7656) * Use generic Boost sockets * Un-templatize server/client connections Co-authored-by: Mehrdad --- src/ray/common/client_connection.cc | 114 +++++------- src/ray/common/client_connection.h | 59 +++---- .../object_store_notification_manager.cc | 6 +- .../object_store_notification_manager.h | 2 +- src/ray/raylet/client_connection_test.cc | 145 ++++++++-------- src/ray/raylet/node_manager.cc | 37 ++-- src/ray/raylet/node_manager.h | 31 ++-- src/ray/raylet/raylet.cc | 25 ++- src/ray/raylet/raylet.h | 2 +- src/ray/raylet/raylet_client.cc | 15 +- src/ray/raylet/raylet_client.h | 13 +- src/ray/raylet/worker.cc | 6 +- src/ray/raylet/worker.h | 6 +- src/ray/raylet/worker_pool.cc | 6 +- src/ray/raylet/worker_pool.h | 4 +- src/ray/raylet/worker_pool_test.cc | 18 +- src/ray/util/url.cc | 163 ++++++++++++------ src/ray/util/url.h | 23 ++- src/ray/util/url_test.cc | 48 +++--- 19 files changed, 364 insertions(+), 359 deletions(-) diff --git a/src/ray/common/client_connection.cc b/src/ray/common/client_connection.cc index 6d8baec1bb26..7fbc6b672da6 100644 --- a/src/ray/common/client_connection.cc +++ b/src/ray/common/client_connection.cc @@ -15,39 +15,42 @@ #include "client_connection.h" #include + +#include +#include +#include +#include +#include #include #include #include "ray/common/ray_config.h" +#include "ray/util/url.h" #include "ray/util/util.h" namespace ray { -template -std::shared_ptr> ServerConnection::Create( - boost::asio::basic_stream_socket &&socket) { - std::shared_ptr> self(new ServerConnection(std::move(socket))); +std::shared_ptr ServerConnection::Create( + boost::asio::generic::stream_protocol::socket &&socket) { + std::shared_ptr self(new ServerConnection(std::move(socket))); return self; } -template -ServerConnection::ServerConnection(boost::asio::basic_stream_socket &&socket) +ServerConnection::ServerConnection(boost::asio::generic::stream_protocol::socket &&socket) : socket_(std::move(socket)), async_write_max_messages_(1), async_write_queue_(), async_write_in_flight_(false), async_write_broken_pipe_(false) {} -template -ServerConnection::~ServerConnection() { +ServerConnection::~ServerConnection() { // If there are any pending messages, invoke their callbacks with an IOError status. for (const auto &write_buffer : async_write_queue_) { write_buffer->handler(Status::IOError("Connection closed.")); } } -template -Status ServerConnection::WriteBuffer( +Status ServerConnection::WriteBuffer( const std::vector &buffer) { boost::system::error_code error; // Loop until all bytes are written while handling interrupts. @@ -71,8 +74,7 @@ Status ServerConnection::WriteBuffer( return ray::Status::OK(); } -template -Status ServerConnection::ReadBuffer( +Status ServerConnection::ReadBuffer( const std::vector &buffer) { boost::system::error_code error; // Loop until all bytes are read while handling interrupts. @@ -94,9 +96,8 @@ Status ServerConnection::ReadBuffer( return Status::OK(); } -template -ray::Status ServerConnection::WriteMessage(int64_t type, int64_t length, - const uint8_t *message) { +ray::Status ServerConnection::WriteMessage(int64_t type, int64_t length, + const uint8_t *message) { sync_writes_ += 1; bytes_written_ += length; @@ -109,8 +110,7 @@ ray::Status ServerConnection::WriteMessage(int64_t type, int64_t length, return WriteBuffer(message_buffers); } -template -void ServerConnection::WriteMessageAsync( +void ServerConnection::WriteMessageAsync( int64_t type, int64_t length, const uint8_t *message, const std::function &handler) { async_writes_ += 1; @@ -137,8 +137,7 @@ void ServerConnection::WriteMessageAsync( } } -template -void ServerConnection::DoAsyncWrites() { +void ServerConnection::DoAsyncWrites() { // Make sure we were not writing to the socket. RAY_CHECK(!async_write_in_flight_); async_write_in_flight_ = true; @@ -183,19 +182,19 @@ void ServerConnection::DoAsyncWrites() { } auto this_ptr = this->shared_from_this(); boost::asio::async_write( - ServerConnection::socket_, message_buffers, + ServerConnection::socket_, message_buffers, [this, this_ptr, num_messages, call_handlers]( const boost::system::error_code &error, size_t bytes_transferred) { ray::Status status = boost_to_ray_status(error); if (error.value() == boost::system::errc::errc_t::broken_pipe) { RAY_LOG(ERROR) << "Broken Pipe happened during calling " - << "ServerConnection::DoAsyncWrites."; + << "ServerConnection::DoAsyncWrites."; // From now on, calling DoAsyncWrites will directly call the handler // with this broken-pipe status. async_write_broken_pipe_ = true; } else if (!status.ok()) { RAY_LOG(ERROR) << "Error encountered during calling " - << "ServerConnection::DoAsyncWrites, message: " + << "ServerConnection::DoAsyncWrites, message: " << status.message() << ", error code: " << static_cast(error.value()); } @@ -203,12 +202,12 @@ void ServerConnection::DoAsyncWrites() { }); } -template -std::shared_ptr> ClientConnection::Create( - ClientHandler &client_handler, MessageHandler &message_handler, - boost::asio::basic_stream_socket &&socket, const std::string &debug_label, +std::shared_ptr ClientConnection::Create( + ClientHandler &client_handler, MessageHandler &message_handler, + boost::asio::generic::stream_protocol::socket &&socket, + const std::string &debug_label, const std::vector &message_type_enum_names, int64_t error_message_type) { - std::shared_ptr> self( + std::shared_ptr self( new ClientConnection(message_handler, std::move(socket), debug_label, message_type_enum_names, error_message_type)); // Let our manager process our new connection. @@ -216,26 +215,24 @@ std::shared_ptr> ClientConnection::Create( return self; } -template -ClientConnection::ClientConnection( - MessageHandler &message_handler, boost::asio::basic_stream_socket &&socket, +ClientConnection::ClientConnection( + MessageHandler &message_handler, + boost::asio::generic::stream_protocol::socket &&socket, const std::string &debug_label, const std::vector &message_type_enum_names, int64_t error_message_type) - : ServerConnection(std::move(socket)), + : ServerConnection(std::move(socket)), registered_(false), message_handler_(message_handler), debug_label_(debug_label), message_type_enum_names_(message_type_enum_names), error_message_type_(error_message_type) {} -template -void ClientConnection::Register() { +void ClientConnection::Register() { RAY_CHECK(!registered_); registered_ = true; } -template -void ClientConnection::ProcessMessages() { +void ClientConnection::ProcessMessages() { // Wait for a message header from the client. The message header includes the // protocol version, the message type, and the length of the message. std::vector header; @@ -243,13 +240,12 @@ void ClientConnection::ProcessMessages() { header.push_back(boost::asio::buffer(&read_type_, sizeof(read_type_))); header.push_back(boost::asio::buffer(&read_length_, sizeof(read_length_))); boost::asio::async_read( - ServerConnection::socket_, header, - boost::bind(&ClientConnection::ProcessMessageHeader, + ServerConnection::socket_, header, + boost::bind(&ClientConnection::ProcessMessageHeader, shared_ClientConnection_from_this(), boost::asio::placeholders::error)); } -template -void ClientConnection::ProcessMessageHeader(const boost::system::error_code &error) { +void ClientConnection::ProcessMessageHeader(const boost::system::error_code &error) { if (error) { // If there was an error, disconnect the client. read_type_ = error_message_type_; @@ -260,22 +256,21 @@ void ClientConnection::ProcessMessageHeader(const boost::system::error_code & // If there was no error, make sure the ray cookie matches. if (!CheckRayCookie()) { - ServerConnection::Close(); + ServerConnection::Close(); return; } // Resize the message buffer to match the received length. read_message_.resize(read_length_); - ServerConnection::bytes_read_ += read_length_; + ServerConnection::bytes_read_ += read_length_; // Wait for the message to be read. boost::asio::async_read( - ServerConnection::socket_, boost::asio::buffer(read_message_), - boost::bind(&ClientConnection::ProcessMessage, - shared_ClientConnection_from_this(), boost::asio::placeholders::error)); + ServerConnection::socket_, boost::asio::buffer(read_message_), + boost::bind(&ClientConnection::ProcessMessage, shared_ClientConnection_from_this(), + boost::asio::placeholders::error)); } -template -bool ClientConnection::CheckRayCookie() { +bool ClientConnection::CheckRayCookie() { if (read_cookie_ == RayConfig::instance().ray_cookie()) { return true; } @@ -303,21 +298,11 @@ bool ClientConnection::CheckRayCookie() { return false; } -template -std::string ClientConnection::RemoteEndpointInfo() { - return std::string(); +std::string ClientConnection::RemoteEndpointInfo() { + return endpoint_to_url(ServerConnection::socket_.remote_endpoint(), false); } -template <> -std::string ClientConnection::RemoteEndpointInfo() { - const auto &remote_endpoint = - ServerConnection::socket_.remote_endpoint(); - return remote_endpoint.address().to_string() + ":" + - std::to_string(remote_endpoint.port()); -} - -template -void ClientConnection::ProcessMessage(const boost::system::error_code &error) { +void ClientConnection::ProcessMessage(const boost::system::error_code &error) { if (error) { read_type_ = error_message_type_; } @@ -337,8 +322,7 @@ void ClientConnection::ProcessMessage(const boost::system::error_code &error) } } -template -std::string ServerConnection::DebugString() const { +std::string ServerConnection::DebugString() const { std::stringstream result; result << "\n- bytes read: " << bytes_read_; result << "\n- bytes written: " << bytes_written_; @@ -353,12 +337,4 @@ std::string ServerConnection::DebugString() const { return result.str(); } -#if defined(BOOST_ASIO_HAS_LOCAL_SOCKETS) -// We compile conditionally to prevent duplicate explicit instantiation error -template class ServerConnection; -template class ClientConnection; -#endif -template class ServerConnection; -template class ClientConnection; - } // namespace ray diff --git a/src/ray/common/client_connection.h b/src/ray/common/client_connection.h index 9b9b2c3f4e9b..9383b57e5402 100644 --- a/src/ray/common/client_connection.h +++ b/src/ray/common/client_connection.h @@ -18,8 +18,10 @@ #include #include -#include +#include +#include #include +#include #include #include "ray/common/id.h" @@ -27,12 +29,14 @@ namespace ray { +typedef boost::asio::generic::stream_protocol local_stream_protocol; +typedef boost::asio::basic_stream_socket local_stream_socket; + /// \typename ServerConnection /// /// A generic type representing a client connection to a server. This typename /// can be used to write messages synchronously to the server. -template -class ServerConnection : public std::enable_shared_from_this> { +class ServerConnection : public std::enable_shared_from_this { public: /// ServerConnection destructor. virtual ~ServerConnection(); @@ -41,8 +45,7 @@ class ServerConnection : public std::enable_shared_from_this /// /// \param socket A reference to the server socket. /// \return std::shared_ptr. - static std::shared_ptr> Create( - boost::asio::basic_stream_socket &&socket); + static std::shared_ptr Create(local_stream_socket &&socket); /// Write a message to the client. /// @@ -83,7 +86,7 @@ class ServerConnection : public std::enable_shared_from_this protected: /// A private constructor for a server connection. - ServerConnection(boost::asio::basic_stream_socket &&socket); + ServerConnection(local_stream_socket &&socket); /// A message that is queued for writing asynchronously. struct AsyncWriteBuffer { @@ -95,7 +98,7 @@ class ServerConnection : public std::enable_shared_from_this }; /// The socket connection to the server. - boost::asio::basic_stream_socket socket_; + local_stream_socket socket_; /// Max number of messages to write out at once. const int async_write_max_messages_; @@ -128,24 +131,20 @@ class ServerConnection : public std::enable_shared_from_this void DoAsyncWrites(); }; -template class ClientConnection; -template -using ClientHandler = std::function &)>; -template +using ClientHandler = std::function; using MessageHandler = - std::function>, int64_t, const uint8_t *)>; + std::function, int64_t, const uint8_t *)>; /// \typename ClientConnection /// /// A generic type representing a client connection on a server. In addition to /// writing messages to the client, like in ServerConnection, this typename can /// also be used to process messages asynchronously from client. -template -class ClientConnection : public ServerConnection { +class ClientConnection : public ServerConnection { public: - using std::enable_shared_from_this>::shared_from_this; + using std::enable_shared_from_this::shared_from_this; /// Allocate a new node client connection. /// @@ -157,14 +156,14 @@ class ClientConnection : public ServerConnection { /// \param message_type_enum_names A table of printable enum names for the /// message types received from this client, used for debug messages. /// \return std::shared_ptr. - static std::shared_ptr> Create( - ClientHandler &new_client_handler, MessageHandler &message_handler, - boost::asio::basic_stream_socket &&socket, const std::string &debug_label, + static std::shared_ptr Create( + ClientHandler &new_client_handler, MessageHandler &message_handler, + local_stream_socket &&socket, const std::string &debug_label, const std::vector &message_type_enum_names, int64_t error_message_type); - std::shared_ptr> shared_ClientConnection_from_this() { - return std::static_pointer_cast>(shared_from_this()); + std::shared_ptr shared_ClientConnection_from_this() { + return std::static_pointer_cast(shared_from_this()); } /// Register the client. @@ -177,8 +176,7 @@ class ClientConnection : public ServerConnection { private: /// A private constructor for a node client connection. - ClientConnection(MessageHandler &message_handler, - boost::asio::basic_stream_socket &&socket, + ClientConnection(MessageHandler &message_handler, local_stream_socket &&socket, const std::string &debug_label, const std::vector &message_type_enum_names, int64_t error_message_type); @@ -203,7 +201,7 @@ class ClientConnection : public ServerConnection { /// Whether the client has sent us a registration message yet. bool registered_; /// The handler for a message from the client. - MessageHandler message_handler_; + MessageHandler message_handler_; /// A label used for debug messages. const std::string debug_label_; /// A table of printable enum names for the message types, used for debug @@ -218,21 +216,6 @@ class ClientConnection : public ServerConnection { std::vector read_message_; }; -typedef -#if defined(BOOST_ASIO_HAS_LOCAL_SOCKETS) - boost::asio::local::stream_protocol -#else - boost::asio::ip::tcp -#endif - local_stream_protocol; - -typedef boost::asio::ip::tcp remote_stream_protocol; - -using LocalServerConnection = ServerConnection; -using TcpServerConnection = ServerConnection; -using LocalClientConnection = ClientConnection; -using TcpClientConnection = ClientConnection; - } // namespace ray #endif // RAY_COMMON_CLIENT_CONNECTION_H diff --git a/src/ray/object_manager/object_store_notification_manager.cc b/src/ray/object_manager/object_store_notification_manager.cc index d36abfba64b1..79301e1087d7 100644 --- a/src/ray/object_manager/object_store_notification_manager.cc +++ b/src/ray/object_manager/object_store_notification_manager.cc @@ -56,10 +56,10 @@ ObjectStoreNotificationManager::ObjectStoreNotificationManager( boost::asio::detail::socket_error_retval) { switch (pi.iAddressFamily) { case AF_INET: - socket_.assign(local_stream_protocol::v4(), c_socket, ec); + socket_.assign(boost::asio::ip::tcp::v4(), c_socket, ec); break; case AF_INET6: - socket_.assign(local_stream_protocol::v6(), c_socket, ec); + socket_.assign(boost::asio::ip::tcp::v6(), c_socket, ec); break; default: ec = boost::system::errc::make_error_code( @@ -68,7 +68,7 @@ ObjectStoreNotificationManager::ObjectStoreNotificationManager( } } #else - socket_.assign(local_stream_protocol(), fd, ec); + socket_.assign(boost::asio::local::stream_protocol(), fd, ec); #endif RAY_CHECK(!ec); NotificationWait(); diff --git a/src/ray/object_manager/object_store_notification_manager.h b/src/ray/object_manager/object_store_notification_manager.h index eec317191a94..7565d63e9da3 100644 --- a/src/ray/object_manager/object_store_notification_manager.h +++ b/src/ray/object_manager/object_store_notification_manager.h @@ -90,7 +90,7 @@ class ObjectStoreNotificationManager { int64_t num_adds_processed_; int64_t num_removes_processed_; std::vector notification_; - local_stream_protocol::socket socket_; + local_stream_socket socket_; /// Flag to indicate whether or not to exit the process when received socket /// error. When it is false, socket error will be ignored. This flag is needed diff --git a/src/ray/raylet/client_connection_test.cc b/src/ray/raylet/client_connection_test.cc index a6d1d18d462f..083f024bbdfa 100644 --- a/src/ray/raylet/client_connection_test.cc +++ b/src/ray/raylet/client_connection_test.cc @@ -35,18 +35,21 @@ class ClientConnectionTest : public ::testing::Test { ClientConnectionTest() : io_service_(), in_(io_service_), out_(io_service_), error_message_type_(1) { #if defined(BOOST_ASIO_HAS_LOCAL_SOCKETS) - boost::asio::local::connect_pair(in_, out_); + boost::asio::local::stream_protocol::socket input(io_service_), output(io_service_); + boost::asio::local::connect_pair(input, output); + in_ = std::move(input); + out_ = std::move(output); #else boost::asio::detail::socket_type pair[2] = {boost::asio::detail::invalid_socket, boost::asio::detail::invalid_socket}; - RAY_CHECK(socketpair(AF_INET, SOCK_STREAM, 0, pair) == 0); - in_.assign(local_stream_protocol::v4(), pair[0]); - out_.assign(local_stream_protocol::v4(), pair[1]); + RAY_CHECK(socketpair(boost::asio::ip::tcp::v4().family(), SOCK_STREAM, 0, pair) == 0); + in_.assign(boost::asio::ip::tcp::v4(), pair[0]); + out_.assign(boost::asio::ip::tcp::v4(), pair[1]); #endif } - ray::Status WriteBadMessage(std::shared_ptr conn, - int64_t type, int64_t length, const uint8_t *message) { + ray::Status WriteBadMessage(std::shared_ptr conn, int64_t type, + int64_t length, const uint8_t *message) { std::vector message_buffers; auto write_cookie = 123456; // incorrect version. message_buffers.push_back(boost::asio::buffer(&write_cookie, sizeof(write_cookie))); @@ -58,8 +61,8 @@ class ClientConnectionTest : public ::testing::Test { protected: boost::asio::io_service io_service_; - local_stream_protocol::socket in_; - local_stream_protocol::socket out_; + local_stream_socket in_; + local_stream_socket out_; int64_t error_message_type_; }; @@ -67,21 +70,20 @@ TEST_F(ClientConnectionTest, SimpleSyncWrite) { const uint8_t arr[5] = {1, 2, 3, 4, 5}; int num_messages = 0; - ClientHandler client_handler = - [](LocalClientConnection &client) {}; + ClientHandler client_handler = [](ClientConnection &client) {}; - MessageHandler message_handler = - [&arr, &num_messages](std::shared_ptr client, - int64_t message_type, const uint8_t *message) { - ASSERT_TRUE(!std::memcmp(arr, message, 5)); - num_messages += 1; - }; + MessageHandler message_handler = [&arr, &num_messages]( + std::shared_ptr client, + int64_t message_type, const uint8_t *message) { + ASSERT_TRUE(!std::memcmp(arr, message, 5)); + num_messages += 1; + }; - auto conn1 = LocalClientConnection::Create( - client_handler, message_handler, std::move(in_), "conn1", {}, error_message_type_); + auto conn1 = ClientConnection::Create(client_handler, message_handler, std::move(in_), + "conn1", {}, error_message_type_); - auto conn2 = LocalClientConnection::Create( - client_handler, message_handler, std::move(out_), "conn2", {}, error_message_type_); + auto conn2 = ClientConnection::Create(client_handler, message_handler, std::move(out_), + "conn2", {}, error_message_type_); RAY_CHECK_OK(conn1->WriteMessage(0, 5, arr)); RAY_CHECK_OK(conn2->WriteMessage(0, 5, arr)); @@ -97,37 +99,34 @@ TEST_F(ClientConnectionTest, SimpleAsyncWrite) { const uint8_t msg3[5] = {8, 8, 8, 8, 8}; int num_messages = 0; - ClientHandler client_handler = - [](LocalClientConnection &client) {}; - - MessageHandler noop_handler = - [](std::shared_ptr client, int64_t message_type, - const uint8_t *message) {}; - - std::shared_ptr reader = NULL; - - MessageHandler message_handler = - [&msg1, &msg2, &msg3, &num_messages, &reader]( - std::shared_ptr client, int64_t message_type, - const uint8_t *message) { - if (num_messages == 0) { - ASSERT_TRUE(!std::memcmp(msg1, message, 5)); - } else if (num_messages == 1) { - ASSERT_TRUE(!std::memcmp(msg2, message, 5)); - } else { - ASSERT_TRUE(!std::memcmp(msg3, message, 5)); - } - num_messages += 1; - if (num_messages < 3) { - reader->ProcessMessages(); - } - }; + ClientHandler client_handler = [](ClientConnection &client) {}; + + MessageHandler noop_handler = [](std::shared_ptr client, + int64_t message_type, const uint8_t *message) {}; + + std::shared_ptr reader = NULL; + + MessageHandler message_handler = [&msg1, &msg2, &msg3, &num_messages, &reader]( + std::shared_ptr client, + int64_t message_type, const uint8_t *message) { + if (num_messages == 0) { + ASSERT_TRUE(!std::memcmp(msg1, message, 5)); + } else if (num_messages == 1) { + ASSERT_TRUE(!std::memcmp(msg2, message, 5)); + } else { + ASSERT_TRUE(!std::memcmp(msg3, message, 5)); + } + num_messages += 1; + if (num_messages < 3) { + reader->ProcessMessages(); + } + }; - auto writer = LocalClientConnection::Create( - client_handler, noop_handler, std::move(in_), "writer", {}, error_message_type_); + auto writer = ClientConnection::Create(client_handler, noop_handler, std::move(in_), + "writer", {}, error_message_type_); - reader = LocalClientConnection::Create(client_handler, message_handler, std::move(out_), - "reader", {}, error_message_type_); + reader = ClientConnection::Create(client_handler, message_handler, std::move(out_), + "reader", {}, error_message_type_); std::function callback = [](const ray::Status &status) { RAY_CHECK_OK(status); @@ -144,15 +143,13 @@ TEST_F(ClientConnectionTest, SimpleAsyncWrite) { TEST_F(ClientConnectionTest, SimpleAsyncError) { const uint8_t msg1[5] = {1, 2, 3, 4, 5}; - ClientHandler client_handler = - [](LocalClientConnection &client) {}; + ClientHandler client_handler = [](ClientConnection &client) {}; - MessageHandler noop_handler = - [](std::shared_ptr client, int64_t message_type, - const uint8_t *message) {}; + MessageHandler noop_handler = [](std::shared_ptr client, + int64_t message_type, const uint8_t *message) {}; - auto writer = LocalClientConnection::Create( - client_handler, noop_handler, std::move(in_), "writer", {}, error_message_type_); + auto writer = ClientConnection::Create(client_handler, noop_handler, std::move(in_), + "writer", {}, error_message_type_); std::function callback = [](const ray::Status &status) { ASSERT_TRUE(!status.ok()); @@ -166,15 +163,13 @@ TEST_F(ClientConnectionTest, SimpleAsyncError) { TEST_F(ClientConnectionTest, CallbackWithSharedRefDoesNotLeakConnection) { const uint8_t msg1[5] = {1, 2, 3, 4, 5}; - ClientHandler client_handler = - [](LocalClientConnection &client) {}; + ClientHandler client_handler = [](ClientConnection &client) {}; - MessageHandler noop_handler = - [](std::shared_ptr client, int64_t message_type, - const uint8_t *message) {}; + MessageHandler noop_handler = [](std::shared_ptr client, + int64_t message_type, const uint8_t *message) {}; - auto writer = LocalClientConnection::Create( - client_handler, noop_handler, std::move(in_), "writer", {}, error_message_type_); + auto writer = ClientConnection::Create(client_handler, noop_handler, std::move(in_), + "writer", {}, error_message_type_); std::function callback = [writer](const ray::Status &status) { @@ -189,22 +184,20 @@ TEST_F(ClientConnectionTest, ProcessBadMessage) { const uint8_t arr[5] = {1, 2, 3, 4, 5}; int num_messages = 0; - ClientHandler client_handler = - [](LocalClientConnection &client) {}; + ClientHandler client_handler = [](ClientConnection &client) {}; - MessageHandler message_handler = - [&arr, &num_messages](std::shared_ptr client, - int64_t message_type, const uint8_t *message) { - ASSERT_TRUE(!std::memcmp(arr, message, 5)); - num_messages += 1; - }; + MessageHandler message_handler = [&arr, &num_messages]( + std::shared_ptr client, + int64_t message_type, const uint8_t *message) { + ASSERT_TRUE(!std::memcmp(arr, message, 5)); + num_messages += 1; + }; - auto writer = LocalClientConnection::Create( - client_handler, message_handler, std::move(in_), "writer", {}, error_message_type_); + auto writer = ClientConnection::Create(client_handler, message_handler, std::move(in_), + "writer", {}, error_message_type_); - auto reader = - LocalClientConnection::Create(client_handler, message_handler, std::move(out_), - "reader", {}, error_message_type_); + auto reader = ClientConnection::Create(client_handler, message_handler, std::move(out_), + "reader", {}, error_message_type_); // If client ID is set, bad message would crash the test. // reader->SetClientID(UniqueID::FromRandom()); diff --git a/src/ray/raylet/node_manager.cc b/src/ray/raylet/node_manager.cc index 997b7cd60fad..9a9e46cbd1e7 100644 --- a/src/ray/raylet/node_manager.cc +++ b/src/ray/raylet/node_manager.cc @@ -853,7 +853,7 @@ void NodeManager::HandleActorStateTransition(const ActorID &actor_id, } } -void NodeManager::ProcessNewClient(LocalClientConnection &client) { +void NodeManager::ProcessNewClient(ClientConnection &client) { // The new client is a worker, so begin listening for messages. client.ProcessMessages(); } @@ -921,9 +921,9 @@ void NodeManager::DispatchTasks( } } -void NodeManager::ProcessClientMessage( - const std::shared_ptr &client, int64_t message_type, - const uint8_t *message_data) { +void NodeManager::ProcessClientMessage(const std::shared_ptr &client, + int64_t message_type, + const uint8_t *message_data) { auto registered_worker = worker_pool_.GetRegisteredWorker(client); auto message_type_value = static_cast(message_type); RAY_LOG(DEBUG) << "[Worker] Message " @@ -1036,7 +1036,7 @@ void NodeManager::ProcessClientMessage( } void NodeManager::ProcessRegisterClientRequestMessage( - const std::shared_ptr &client, const uint8_t *message_data) { + const std::shared_ptr &client, const uint8_t *message_data) { client->Register(); flatbuffers::FlatBufferBuilder fbb; auto reply = @@ -1137,8 +1137,7 @@ void NodeManager::HandleDisconnectedActor(const ActorID &actor_id, bool was_loca } } -void NodeManager::HandleWorkerAvailable( - const std::shared_ptr &client) { +void NodeManager::HandleWorkerAvailable(const std::shared_ptr &client) { std::shared_ptr worker = worker_pool_.GetRegisteredWorker(client); HandleWorkerAvailable(worker); } @@ -1168,7 +1167,7 @@ void NodeManager::HandleWorkerAvailable(const std::shared_ptr &worker) { } void NodeManager::ProcessDisconnectClientMessage( - const std::shared_ptr &client, bool intentional_disconnect) { + const std::shared_ptr &client, bool intentional_disconnect) { std::shared_ptr worker = worker_pool_.GetRegisteredWorker(client); bool is_worker = false, is_driver = false; if (worker) { @@ -1301,7 +1300,7 @@ void NodeManager::ProcessDisconnectClientMessage( } void NodeManager::ProcessFetchOrReconstructMessage( - const std::shared_ptr &client, const uint8_t *message_data) { + const std::shared_ptr &client, const uint8_t *message_data) { auto message = flatbuffers::GetRoot(message_data); std::vector required_object_ids; for (int64_t i = 0; i < message->object_ids()->size(); ++i) { @@ -1330,7 +1329,7 @@ void NodeManager::ProcessFetchOrReconstructMessage( } void NodeManager::ProcessWaitRequestMessage( - const std::shared_ptr &client, const uint8_t *message_data) { + const std::shared_ptr &client, const uint8_t *message_data) { // Read the data. auto message = flatbuffers::GetRoot(message_data); std::vector object_ids = from_flatbuf(*message->object_ids()); @@ -1386,7 +1385,7 @@ void NodeManager::ProcessWaitRequestMessage( } void NodeManager::ProcessWaitForDirectActorCallArgsRequestMessage( - const std::shared_ptr &client, const uint8_t *message_data) { + const std::shared_ptr &client, const uint8_t *message_data) { // Read the data. auto message = flatbuffers::GetRoot(message_data); @@ -1428,7 +1427,7 @@ void NodeManager::ProcessPushErrorRequestMessage(const uint8_t *message_data) { } void NodeManager::ProcessPrepareActorCheckpointRequest( - const std::shared_ptr &client, const uint8_t *message_data) { + const std::shared_ptr &client, const uint8_t *message_data) { auto message = flatbuffers::GetRoot(message_data); ActorID actor_id = from_flatbuf(*message->actor_id()); @@ -1749,7 +1748,7 @@ void NodeManager::HandleForwardTask(const rpc::ForwardTaskRequest &request, } void NodeManager::ProcessSetResourceRequest( - const std::shared_ptr &client, const uint8_t *message_data) { + const std::shared_ptr &client, const uint8_t *message_data) { // Read the SetResource message auto message = flatbuffers::GetRoot(message_data); @@ -2186,10 +2185,10 @@ void NodeManager::HandleDirectCallTaskUnblocked(const std::shared_ptr &w worker->MarkUnblocked(); } -void NodeManager::AsyncResolveObjects( - const std::shared_ptr &client, - const std::vector &required_object_ids, const TaskID ¤t_task_id, - bool ray_get, bool mark_worker_blocked) { +void NodeManager::AsyncResolveObjects(const std::shared_ptr &client, + const std::vector &required_object_ids, + const TaskID ¤t_task_id, bool ray_get, + bool mark_worker_blocked) { std::shared_ptr worker = worker_pool_.GetRegisteredWorker(client); if (worker) { // The client is a worker. If the worker is not already blocked and the @@ -2242,7 +2241,7 @@ void NodeManager::AsyncResolveObjects( } void NodeManager::AsyncResolveObjectsFinish( - const std::shared_ptr &client, const TaskID ¤t_task_id, + const std::shared_ptr &client, const TaskID ¤t_task_id, bool was_blocked) { std::shared_ptr worker = worker_pool_.GetRegisteredWorker(client); @@ -3055,7 +3054,7 @@ void NodeManager::FinishAssignTask(const std::shared_ptr &worker, } void NodeManager::ProcessSubscribePlasmaReady( - const std::shared_ptr &client, const uint8_t *message_data) { + const std::shared_ptr &client, const uint8_t *message_data) { std::shared_ptr associated_worker = worker_pool_.GetRegisteredWorker(client); if (associated_worker == nullptr) { associated_worker = worker_pool_.GetRegisteredDriver(client); diff --git a/src/ray/raylet/node_manager.h b/src/ray/raylet/node_manager.h index 74a1fecd518f..13c71036231f 100644 --- a/src/ray/raylet/node_manager.h +++ b/src/ray/raylet/node_manager.h @@ -103,7 +103,7 @@ class NodeManager : public rpc::NodeManagerServiceHandler { /// /// \param client The client to process. /// \return Void. - void ProcessNewClient(LocalClientConnection &client); + void ProcessNewClient(ClientConnection &client); /// Process a message from a client. This method is responsible for /// explicitly listening for more messages from the client if the client is @@ -113,7 +113,7 @@ class NodeManager : public rpc::NodeManagerServiceHandler { /// \param message_type The message type (e.g., a flatbuffer enum). /// \param message_data A pointer to the message data. /// \return Void. - void ProcessClientMessage(const std::shared_ptr &client, + void ProcessClientMessage(const std::shared_ptr &client, int64_t message_type, const uint8_t *message_data); /// Subscribe to the relevant GCS tables and set up handlers. @@ -355,7 +355,7 @@ class NodeManager : public rpc::NodeManagerServiceHandler { /// \param mark_worker_blocked Whether to mark the worker as blocked. This /// should be False for direct calls. /// \return Void. - void AsyncResolveObjects(const std::shared_ptr &client, + void AsyncResolveObjects(const std::shared_ptr &client, const std::vector &required_object_ids, const TaskID ¤t_task_id, bool ray_get, bool mark_worker_blocked); @@ -371,7 +371,7 @@ class NodeManager : public rpc::NodeManagerServiceHandler { /// \param worker_was_blocked Whether we previously marked the worker as /// blocked in AsyncResolveObjects(). /// \return Void. - void AsyncResolveObjectsFinish(const std::shared_ptr &client, + void AsyncResolveObjectsFinish(const std::shared_ptr &client, const TaskID ¤t_task_id, bool was_blocked); /// Handle a direct call task that is blocked. Note that this callback may @@ -451,13 +451,13 @@ class NodeManager : public rpc::NodeManagerServiceHandler { /// \param message_data A pointer to the message data. /// \return Void. void ProcessRegisterClientRequestMessage( - const std::shared_ptr &client, const uint8_t *message_data); + const std::shared_ptr &client, const uint8_t *message_data); /// Handle the case that a worker is available. /// /// \param client The connection for the worker. /// \return Void. - void HandleWorkerAvailable(const std::shared_ptr &client); + void HandleWorkerAvailable(const std::shared_ptr &client); /// Handle the case that a worker is available. /// @@ -473,24 +473,23 @@ class NodeManager : public rpc::NodeManagerServiceHandler { /// \param client The client that sent the message. /// \param intentional_disconnect Whether the client was intentionally disconnected. /// \return Void. - void ProcessDisconnectClientMessage( - const std::shared_ptr &client, - bool intentional_disconnect = false); + void ProcessDisconnectClientMessage(const std::shared_ptr &client, + bool intentional_disconnect = false); /// Process client message of FetchOrReconstruct /// /// \param client The client that sent the message. /// \param message_data A pointer to the message data. /// \return Void. - void ProcessFetchOrReconstructMessage( - const std::shared_ptr &client, const uint8_t *message_data); + void ProcessFetchOrReconstructMessage(const std::shared_ptr &client, + const uint8_t *message_data); /// Process client message of WaitRequest /// /// \param client The client that sent the message. /// \param message_data A pointer to the message data. /// \return Void. - void ProcessWaitRequestMessage(const std::shared_ptr &client, + void ProcessWaitRequestMessage(const std::shared_ptr &client, const uint8_t *message_data); /// Process client message of WaitForDirectActorCallArgsRequest @@ -499,7 +498,7 @@ class NodeManager : public rpc::NodeManagerServiceHandler { /// \param message_data A pointer to the message data. /// \return Void. void ProcessWaitForDirectActorCallArgsRequestMessage( - const std::shared_ptr &client, const uint8_t *message_data); + const std::shared_ptr &client, const uint8_t *message_data); /// Process client message of PushErrorRequest /// @@ -512,7 +511,7 @@ class NodeManager : public rpc::NodeManagerServiceHandler { /// \param client The client that sent the message. /// \param message_data A pointer to the message data. void ProcessPrepareActorCheckpointRequest( - const std::shared_ptr &client, const uint8_t *message_data); + const std::shared_ptr &client, const uint8_t *message_data); /// Process client message of NotifyActorResumedFromCheckpoint. /// @@ -530,7 +529,7 @@ class NodeManager : public rpc::NodeManagerServiceHandler { /// \param client The client that sent the message. /// \param message_data A pointer to the message data. /// \return Void. - void ProcessSetResourceRequest(const std::shared_ptr &client, + void ProcessSetResourceRequest(const std::shared_ptr &client, const uint8_t *message_data); /// Handle the case where an actor is disconnected, determine whether this @@ -559,7 +558,7 @@ class NodeManager : public rpc::NodeManagerServiceHandler { /// \param client The client that sent the message. /// \param message_data A pointer to the message data. /// \return void. - void ProcessSubscribePlasmaReady(const std::shared_ptr &client, + void ProcessSubscribePlasmaReady(const std::shared_ptr &client, const uint8_t *message_data); /// Setup callback with Object Manager. diff --git a/src/ray/raylet/raylet.cc b/src/ray/raylet/raylet.cc index 2cd24c0e000d..9d33a10fba87 100644 --- a/src/ray/raylet/raylet.cc +++ b/src/ray/raylet/raylet.cc @@ -68,13 +68,7 @@ Raylet::Raylet(boost::asio::io_service &main_service, const std::string &socket_ node_manager_(main_service, self_node_id_, node_manager_config, object_manager_, gcs_client_, object_directory_), socket_name_(socket_name), - acceptor_(main_service, -#if defined(BOOST_ASIO_HAS_LOCAL_SOCKETS) - local_stream_protocol::endpoint(socket_name) -#else - parse_ip_tcp_endpoint(socket_name) -#endif - ), + acceptor_(main_service, parse_url_endpoint(socket_name)), socket_(main_service) { self_node_info_.set_node_id(self_node_id_.Binary()); self_node_info_.set_state(GcsNodeInfo::ALIVE); @@ -134,15 +128,16 @@ void Raylet::DoAccept() { void Raylet::HandleAccept(const boost::system::error_code &error) { if (!error) { // TODO: typedef these handlers. - ClientHandler client_handler = - [this](LocalClientConnection &client) { node_manager_.ProcessNewClient(client); }; - MessageHandler message_handler = - [this](std::shared_ptr client, int64_t message_type, - const uint8_t *message) { - node_manager_.ProcessClientMessage(client, message_type, message); - }; + ClientHandler client_handler = [this](ClientConnection &client) { + node_manager_.ProcessNewClient(client); + }; + MessageHandler message_handler = [this](std::shared_ptr client, + int64_t message_type, + const uint8_t *message) { + node_manager_.ProcessClientMessage(client, message_type, message); + }; // Accept a new local client and dispatch it to the node manager. - auto new_connection = LocalClientConnection::Create( + auto new_connection = ClientConnection::Create( client_handler, message_handler, std::move(socket_), "worker", node_manager_message_enum, static_cast(protocol::MessageType::DisconnectClient)); diff --git a/src/ray/raylet/raylet.h b/src/ray/raylet/raylet.h index da3f55e25f18..53bbff0532d2 100644 --- a/src/ray/raylet/raylet.h +++ b/src/ray/raylet/raylet.h @@ -97,7 +97,7 @@ class Raylet { /// An acceptor for new clients. boost::asio::basic_socket_acceptor acceptor_; /// The socket to listen on for new clients. - local_stream_protocol::socket socket_; + local_stream_socket socket_; }; } // namespace raylet diff --git a/src/ray/raylet/raylet_client.cc b/src/ray/raylet/raylet_client.cc index 1f670ea7a7f0..3a66cc852802 100644 --- a/src/ray/raylet/raylet_client.cc +++ b/src/ray/raylet/raylet_client.cc @@ -34,20 +34,20 @@ using MessageType = ray::protocol::MessageType; -static int read_bytes(local_stream_protocol::socket &conn, void *cursor, size_t length) { +namespace ray { + +static int read_bytes(local_stream_socket &conn, void *cursor, size_t length) { boost::system::error_code ec; size_t nread = boost::asio::read(conn, boost::asio::buffer(cursor, length), ec); return nread == length ? 0 : -1; } -static int write_bytes(local_stream_protocol::socket &conn, void *cursor, size_t length) { +static int write_bytes(local_stream_socket &conn, void *cursor, size_t length) { boost::system::error_code ec; size_t nread = boost::asio::write(conn, boost::asio::buffer(cursor, length), ec); return nread == length ? 0 : -1; } -namespace ray { - raylet::RayletConnection::RayletConnection(boost::asio::io_service &io_service, const std::string &raylet_socket, int num_retries, int64_t timeout) @@ -62,12 +62,7 @@ raylet::RayletConnection::RayletConnection(boost::asio::io_service &io_service, RAY_CHECK(!raylet_socket.empty()); boost::system::error_code ec; for (int num_attempts = 0; num_attempts < num_retries; ++num_attempts) { -#if defined(BOOST_ASIO_HAS_LOCAL_SOCKETS) - local_stream_protocol::endpoint endpoint(raylet_socket); -#else - local_stream_protocol::endpoint endpoint = parse_ip_tcp_endpoint(raylet_socket); -#endif - if (!conn_.connect(endpoint, ec)) { + if (!conn_.connect(parse_url_endpoint(raylet_socket), ec)) { break; } if (num_attempts > 0) { diff --git a/src/ray/raylet/raylet_client.h b/src/ray/raylet/raylet_client.h index 7f2a4660e498..822956f7d81f 100644 --- a/src/ray/raylet/raylet_client.h +++ b/src/ray/raylet/raylet_client.h @@ -43,16 +43,11 @@ using ResourceMappingType = std::unordered_map>>; using WaitResultPair = std::pair, std::vector>; -typedef -#if defined(BOOST_ASIO_HAS_LOCAL_SOCKETS) - boost::asio::local::stream_protocol -#else - boost::asio::ip::tcp -#endif - local_stream_protocol; - namespace ray { +typedef boost::asio::generic::stream_protocol local_stream_protocol; +typedef boost::asio::basic_stream_socket local_stream_socket; + /// Interface for leasing workers. Abstract for testing. class WorkerLeaseInterface { public: @@ -123,7 +118,7 @@ class RayletConnection { private: /// The Unix domain socket that connects to raylet. - local_stream_protocol::socket conn_; + local_stream_socket conn_; /// A mutex to protect stateful operations of the raylet client. std::mutex mutex_; /// A mutex to protect write operations of the raylet client. diff --git a/src/ray/raylet/worker.cc b/src/ray/raylet/worker.cc index d705b4f51caa..6d92c31c70bd 100644 --- a/src/ray/raylet/worker.cc +++ b/src/ray/raylet/worker.cc @@ -27,7 +27,7 @@ namespace raylet { /// A constructor responsible for initializing the state of a worker. Worker::Worker(const WorkerID &worker_id, const Language &language, int port, - std::shared_ptr connection, + std::shared_ptr connection, rpc::ClientCallManager &client_call_manager) : worker_id_(worker_id), language_(language), @@ -104,9 +104,7 @@ void Worker::MarkDetachedActor() { is_detached_actor_ = true; } bool Worker::IsDetachedActor() const { return is_detached_actor_; } -const std::shared_ptr Worker::Connection() const { - return connection_; -} +const std::shared_ptr Worker::Connection() const { return connection_; } void Worker::SetOwnerAddress(const rpc::Address &address) { owner_address_ = address; } const rpc::Address &Worker::GetOwnerAddress() const { return owner_address_; } diff --git a/src/ray/raylet/worker.h b/src/ray/raylet/worker.h index ee0a8a1eca93..9700cd9d9a90 100644 --- a/src/ray/raylet/worker.h +++ b/src/ray/raylet/worker.h @@ -37,7 +37,7 @@ class Worker { /// A constructor that initializes a worker object. /// NOTE: You MUST manually set the worker process. Worker(const WorkerID &worker_id, const Language &language, int port, - std::shared_ptr connection, + std::shared_ptr connection, rpc::ClientCallManager &client_call_manager); /// A destructor responsible for freeing all worker state. ~Worker() {} @@ -64,7 +64,7 @@ class Worker { const ActorID &GetActorId() const; void MarkDetachedActor(); bool IsDetachedActor() const; - const std::shared_ptr Connection() const; + const std::shared_ptr Connection() const; void SetOwnerAddress(const rpc::Address &address); const rpc::Address &GetOwnerAddress() const; @@ -104,7 +104,7 @@ class Worker { /// If port <= 0, this indicates that the worker will not listen to a port. int port_; /// Connection state of a worker. - std::shared_ptr connection_; + std::shared_ptr connection_; /// The worker's currently assigned task. TaskID assigned_task_id_; /// Job ID for the worker's current assigned task. diff --git a/src/ray/raylet/worker_pool.cc b/src/ray/raylet/worker_pool.cc index 8acf91c85697..88a52678b68e 100644 --- a/src/ray/raylet/worker_pool.cc +++ b/src/ray/raylet/worker_pool.cc @@ -30,7 +30,7 @@ namespace { // A helper function to get a worker from a list. std::shared_ptr GetWorker( const std::unordered_set> &worker_pool, - const std::shared_ptr &connection) { + const std::shared_ptr &connection) { for (auto it = worker_pool.begin(); it != worker_pool.end(); it++) { if ((*it)->Connection() == connection) { return (*it); @@ -317,7 +317,7 @@ Status WorkerPool::RegisterDriver(const std::shared_ptr &driver) { } std::shared_ptr WorkerPool::GetRegisteredWorker( - const std::shared_ptr &connection) const { + const std::shared_ptr &connection) const { for (const auto &entry : states_by_lang_) { auto worker = GetWorker(entry.second.registered_workers, connection); if (worker != nullptr) { @@ -328,7 +328,7 @@ std::shared_ptr WorkerPool::GetRegisteredWorker( } std::shared_ptr WorkerPool::GetRegisteredDriver( - const std::shared_ptr &connection) const { + const std::shared_ptr &connection) const { for (const auto &entry : states_by_lang_) { auto driver = GetWorker(entry.second.registered_drivers, connection); if (driver != nullptr) { diff --git a/src/ray/raylet/worker_pool.h b/src/ray/raylet/worker_pool.h index ec2a98b20b78..cf24256dc8fd 100644 --- a/src/ray/raylet/worker_pool.h +++ b/src/ray/raylet/worker_pool.h @@ -87,7 +87,7 @@ class WorkerPool { /// \return The Worker that owns the given client connection. Returns nullptr /// if the client has not registered a worker yet. std::shared_ptr GetRegisteredWorker( - const std::shared_ptr &connection) const; + const std::shared_ptr &connection) const; /// Get the client connection's registered driver. /// @@ -95,7 +95,7 @@ class WorkerPool { /// \return The Worker that owns the given client connection. Returns nullptr /// if the client has not registered a driver. std::shared_ptr GetRegisteredDriver( - const std::shared_ptr &connection) const; + const std::shared_ptr &connection) const; /// Disconnect a registered worker. /// diff --git a/src/ray/raylet/worker_pool_test.cc b/src/ray/raylet/worker_pool_test.cc index 7c5a44392595..e48987b61916 100644 --- a/src/ray/raylet/worker_pool_test.cc +++ b/src/ray/raylet/worker_pool_test.cc @@ -103,17 +103,17 @@ class WorkerPoolTest : public ::testing::Test { std::shared_ptr CreateWorker(Process proc, const Language &language = Language::PYTHON) { - std::function client_handler = - [this](LocalClientConnection &client) { HandleNewClient(client); }; - std::function, int64_t, const uint8_t *)> - message_handler = [this](std::shared_ptr client, + std::function client_handler = + [this](ClientConnection &client) { HandleNewClient(client); }; + std::function, int64_t, const uint8_t *)> + message_handler = [this](std::shared_ptr client, int64_t message_type, const uint8_t *message) { HandleMessage(client, message_type, message); }; - local_stream_protocol::socket socket(io_service_); + local_stream_socket socket(io_service_); auto client = - LocalClientConnection::Create(client_handler, message_handler, std::move(socket), - "worker", {}, error_message_type_); + ClientConnection::Create(client_handler, message_handler, std::move(socket), + "worker", {}, error_message_type_); std::shared_ptr worker = std::make_shared( WorkerID::FromRandom(), language, -1, client, client_call_manager_); if (!proc.IsNull()) { @@ -162,8 +162,8 @@ class WorkerPoolTest : public ::testing::Test { rpc::ClientCallManager client_call_manager_; private: - void HandleNewClient(LocalClientConnection &){}; - void HandleMessage(std::shared_ptr, int64_t, const uint8_t *){}; + void HandleNewClient(ClientConnection &){}; + void HandleMessage(std::shared_ptr, int64_t, const uint8_t *){}; }; static inline TaskSpecification ExampleTaskSpec( diff --git a/src/ray/util/url.cc b/src/ray/util/url.cc index 1f5e2e58a8c9..6c5f063fe2c5 100644 --- a/src/ray/util/url.cc +++ b/src/ray/util/url.cc @@ -16,64 +16,125 @@ #include +#ifdef _WIN32 +#include +#else +#include +#endif + #include -#include -#include +#include #include +#include +#ifndef _WIN32 +#include +#endif +#include + +#include "ray/util/filesystem.h" #include "ray/util/logging.h" +#include "ray/util/util.h" -boost::asio::ip::tcp::endpoint parse_ip_tcp_endpoint(const std::string &endpoint, - int default_port) { - const std::string scheme_sep = "://"; - size_t scheme_begin = 0, scheme_end = endpoint.find(scheme_sep, scheme_begin); - size_t host_begin; - if (scheme_end < endpoint.size()) { - host_begin = scheme_end + scheme_sep.size(); - } else { - scheme_end = scheme_begin; - host_begin = scheme_end; +namespace ray { + +/// Uses sscanf() to read a token matching from the string, advancing the iterator. +/// \param c_str A string iterator that is dereferenceable. (i.e.: c_str < string::end()) +/// \param format The pattern. It must not produce any output. (e.g., use %*d, not %d.) +/// \return The scanned prefix of the string, if any. +static std::string ScanToken(std::string::const_iterator &c_str, std::string format) { + int i = 0; + std::string result; + format += "%n"; + if (static_cast(sscanf(&*c_str, format.c_str(), &i)) <= 1) { + result.insert(result.end(), c_str, c_str + i); + c_str += i; + } + return result; +} + +std::string endpoint_to_url( + const boost::asio::generic::basic_endpoint &ep, + bool include_scheme) { + std::string result, scheme; + switch (ep.protocol().family()) { + case AF_INET: { + scheme = "tcp://"; + boost::asio::ip::tcp::endpoint e(boost::asio::ip::tcp::v4(), 0); + RAY_CHECK(e.size() == ep.size()); + const sockaddr *src = ep.data(); + sockaddr *dst = e.data(); + *reinterpret_cast(dst) = *reinterpret_cast(src); + std::ostringstream ss; + ss << e; + result = ss.str(); + break; + } + case AF_INET6: { + scheme = "tcp://"; + boost::asio::ip::tcp::endpoint e(boost::asio::ip::tcp::v6(), 0); + RAY_CHECK(e.size() == ep.size()); + const sockaddr *src = ep.data(); + sockaddr *dst = e.data(); + *reinterpret_cast(dst) = *reinterpret_cast(src); + std::ostringstream ss; + ss << e; + result = ss.str(); + break; + } + case AF_UNIX: + scheme = "unix://"; +#ifdef BOOST_ASIO_HAS_LOCAL_SOCKETS + result.append(reinterpret_cast(ep.data())->sun_path, + ep.size() - offsetof(sockaddr_un, sun_path)); +#else + RAY_LOG(FATAL) << "UNIX-domain socket endpoints are not supported"; +#endif + break; + default: + RAY_LOG(FATAL) << "unsupported protocol family: " << ep.protocol().family(); + break; + } + if (include_scheme) { + result.insert(0, scheme); } - std::string scheme = endpoint.substr(scheme_begin, scheme_end - scheme_begin); - RAY_CHECK(scheme_end == host_begin || scheme == "tcp"); - size_t port_end = endpoint.find('/', host_begin); - if (port_end >= endpoint.size()) { - port_end = endpoint.size(); + return result; +} + +boost::asio::generic::basic_endpoint +parse_url_endpoint(const std::string &endpoint, int default_port) { + // Syntax reference: https://en.wikipedia.org/wiki/URL#Syntax + // Note that we're a bit more flexible, to allow parsing "127.0.0.1" as a URL. + boost::asio::generic::stream_protocol::endpoint result; + std::string address = endpoint, scheme; + if (address.find("unix://") == 0) { + scheme = "unix://"; + address.erase(0, scheme.size()); + } else if (address.size() > 0 && ray::IsDirSep(address[0])) { + scheme = "unix://"; + } else if (address.find("tcp://") == 0) { + scheme = "tcp://"; + address.erase(0, scheme.size()); + } else { + scheme = "tcp://"; } - size_t host_end, port_begin; - if (endpoint.find('[', host_begin) == host_begin) { - // IPv6 with brackets (optional ports) - ++host_begin; - host_end = endpoint.find(']', host_begin); - if (host_end < port_end) { - port_begin = endpoint.find(':', host_end + 1); - if (port_begin < port_end) { - ++port_begin; - } else { - port_begin = port_end; - } - } else { - host_end = port_end; - port_begin = host_end; - } - } else if (std::count(endpoint.begin() + static_cast(host_begin), - endpoint.begin() + static_cast(port_end), ':') > 1) { - // IPv6 without brackets (no ports) - port_begin = port_end; - host_end = port_begin; + if (scheme == "unix://") { +#ifdef BOOST_ASIO_HAS_LOCAL_SOCKETS + result = boost::asio::local::stream_protocol::endpoint(address); +#else + RAY_LOG(FATAL) << "UNIX-domain socket endpoints are not supported: " << endpoint; +#endif + } else if (scheme == "tcp://") { + std::string::const_iterator i = address.begin(); + std::string host = ScanToken(i, "[%*[^][/]]"); + host = host.empty() ? ScanToken(i, "%*[^/:]") : host.substr(1, host.size() - 2); + std::string port_str = ScanToken(i, ":%*d"); + int port = port_str.empty() ? default_port : std::stoi(port_str.substr(1)); + result = boost::asio::ip::tcp::endpoint(boost::asio::ip::make_address(host), port); } else { - // IPv4 - host_end = endpoint.find(':', host_begin); - if (host_end < port_end) { - port_begin = host_end + 1; - } else { - host_end = port_end; - port_begin = host_end; - } + RAY_LOG(FATAL) << "Unable to parse socket endpoint: " << endpoint; } - std::string host = endpoint.substr(host_begin, host_end - host_begin); - std::string port_str = endpoint.substr(port_begin, port_end - port_begin); - boost::asio::ip::address address = boost::asio::ip::make_address(host); - int port = port_str.empty() ? default_port : atoi(port_str.c_str()); - return boost::asio::ip::tcp::endpoint(address, port); + return result; } + +} // namespace ray diff --git a/src/ray/util/url.h b/src/ray/util/url.h index 814924662736..b4e5c3845560 100644 --- a/src/ray/util/url.h +++ b/src/ray/util/url.h @@ -15,10 +15,25 @@ #ifndef RAY_UTIL_URL_H #define RAY_UTIL_URL_H -#include +#include +#include -// Parses the endpoint (host + port number) of a URL. -boost::asio::ip::tcp::endpoint parse_ip_tcp_endpoint(const std::string &endpoint, - int default_port = 0); +namespace ray { + +/// Converts the given endpoint (such as TCP or UNIX domain socket address) to a string. +/// \param include_scheme Whether to include the scheme prefix (such as tcp://). +/// This is recommended to avoid later ambiguity when parsing. +std::string endpoint_to_url( + const boost::asio::generic::basic_endpoint &ep, + bool include_scheme = true); + +/// Parses the endpoint socket address of a URL. +/// If a scheme:// prefix is absent, the address family is guessed automatically. +/// For TCP/IP, the endpoint comprises the IP address and port number in the URL. +/// For UNIX domain sockets, the endpoint comprises the socket path. +boost::asio::generic::basic_endpoint +parse_url_endpoint(const std::string &endpoint, int default_port = 0); + +} // namespace ray #endif diff --git a/src/ray/util/url_test.cc b/src/ray/util/url_test.cc index 931bf811fdb0..fbadaad98028 100644 --- a/src/ray/util/url_test.cc +++ b/src/ray/util/url_test.cc @@ -14,40 +14,36 @@ #include "ray/util/url.h" -#include - #include "gtest/gtest.h" namespace ray { template -static std::string to_str(const T &obj) { - std::ostringstream ss; - ss << obj; - return ss.str(); +static std::string to_str(const T &obj, bool include_scheme) { + return endpoint_to_url(obj, include_scheme); } TEST(UrlTest, UrlIpTcpParseTest) { - ASSERT_TRUE(to_str(parse_ip_tcp_endpoint("tcp://[::1]:1/", 0)) == "[::1]:1"); - ASSERT_TRUE(to_str(parse_ip_tcp_endpoint("tcp://[::1]/", 0)) == "[::1]:0"); - ASSERT_TRUE(to_str(parse_ip_tcp_endpoint("tcp://[::1]:1", 0)) == "[::1]:1"); - ASSERT_TRUE(to_str(parse_ip_tcp_endpoint("tcp://[::1]", 0)) == "[::1]:0"); - ASSERT_TRUE(to_str(parse_ip_tcp_endpoint("tcp://::1/", 0)) == "[::1]:0"); - ASSERT_TRUE(to_str(parse_ip_tcp_endpoint("tcp://::1", 0)) == "[::1]:0"); - ASSERT_TRUE(to_str(parse_ip_tcp_endpoint("tcp://127.0.0.1:1/", 0)) == "127.0.0.1:1"); - ASSERT_TRUE(to_str(parse_ip_tcp_endpoint("tcp://127.0.0.1/", 0)) == "127.0.0.1:0"); - ASSERT_TRUE(to_str(parse_ip_tcp_endpoint("tcp://127.0.0.1:1", 0)) == "127.0.0.1:1"); - ASSERT_TRUE(to_str(parse_ip_tcp_endpoint("tcp://127.0.0.1", 0)) == "127.0.0.1:0"); - ASSERT_TRUE(to_str(parse_ip_tcp_endpoint("[::1]:1/", 0)) == "[::1]:1"); - ASSERT_TRUE(to_str(parse_ip_tcp_endpoint("[::1]/", 0)) == "[::1]:0"); - ASSERT_TRUE(to_str(parse_ip_tcp_endpoint("[::1]:1", 0)) == "[::1]:1"); - ASSERT_TRUE(to_str(parse_ip_tcp_endpoint("[::1]", 0)) == "[::1]:0"); - ASSERT_TRUE(to_str(parse_ip_tcp_endpoint("::1/", 0)) == "[::1]:0"); - ASSERT_TRUE(to_str(parse_ip_tcp_endpoint("::1", 0)) == "[::1]:0"); - ASSERT_TRUE(to_str(parse_ip_tcp_endpoint("127.0.0.1:1/", 0)) == "127.0.0.1:1"); - ASSERT_TRUE(to_str(parse_ip_tcp_endpoint("127.0.0.1/", 0)) == "127.0.0.1:0"); - ASSERT_TRUE(to_str(parse_ip_tcp_endpoint("127.0.0.1:1", 0)) == "127.0.0.1:1"); - ASSERT_TRUE(to_str(parse_ip_tcp_endpoint("127.0.0.1", 0)) == "127.0.0.1:0"); + ASSERT_EQ(to_str(parse_url_endpoint("tcp://[::1]:1/", 0), false), "[::1]:1"); + ASSERT_EQ(to_str(parse_url_endpoint("tcp://[::1]/", 0), false), "[::1]:0"); + ASSERT_EQ(to_str(parse_url_endpoint("tcp://[::1]:1", 0), false), "[::1]:1"); + ASSERT_EQ(to_str(parse_url_endpoint("tcp://[::1]", 0), false), "[::1]:0"); + ASSERT_EQ(to_str(parse_url_endpoint("tcp://127.0.0.1:1/", 0), false), "127.0.0.1:1"); + ASSERT_EQ(to_str(parse_url_endpoint("tcp://127.0.0.1/", 0), false), "127.0.0.1:0"); + ASSERT_EQ(to_str(parse_url_endpoint("tcp://127.0.0.1:1", 0), false), "127.0.0.1:1"); + ASSERT_EQ(to_str(parse_url_endpoint("tcp://127.0.0.1", 0), false), "127.0.0.1:0"); + ASSERT_EQ(to_str(parse_url_endpoint("[::1]:1/", 0), false), "[::1]:1"); + ASSERT_EQ(to_str(parse_url_endpoint("[::1]/", 0), false), "[::1]:0"); + ASSERT_EQ(to_str(parse_url_endpoint("[::1]:1", 0), false), "[::1]:1"); + ASSERT_EQ(to_str(parse_url_endpoint("[::1]", 0), false), "[::1]:0"); + ASSERT_EQ(to_str(parse_url_endpoint("127.0.0.1:1/", 0), false), "127.0.0.1:1"); + ASSERT_EQ(to_str(parse_url_endpoint("127.0.0.1/", 0), false), "127.0.0.1:0"); + ASSERT_EQ(to_str(parse_url_endpoint("127.0.0.1:1", 0), false), "127.0.0.1:1"); + ASSERT_EQ(to_str(parse_url_endpoint("127.0.0.1", 0), false), "127.0.0.1:0"); +#ifndef _WIN32 + ASSERT_EQ(to_str(parse_url_endpoint("unix:///tmp/sock"), false), "/tmp/sock"); + ASSERT_EQ(to_str(parse_url_endpoint("/tmp/sock"), false), "/tmp/sock"); +#endif } } // namespace ray