Skip to content

Commit

Permalink
Use private buffer instead of global buffer (dmlc#511)
Browse files Browse the repository at this point in the history
  • Loading branch information
aksnzhy authored Apr 26, 2019
1 parent 8c79885 commit 99bc3ab
Show file tree
Hide file tree
Showing 4 changed files with 89 additions and 17 deletions.
34 changes: 17 additions & 17 deletions src/graph/network.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,6 @@ using dgl::runtime::NDArray;
namespace dgl {
namespace network {

static char* SEND_BUFFER = nullptr;
static char* RECV_BUFFER = nullptr;

// Wrapper for Send api
static void SendData(network::Sender* sender,
const char* data,
Expand All @@ -46,12 +43,13 @@ static void RecvData(network::Receiver* receiver,

DGL_REGISTER_GLOBAL("network._CAPI_DGLSenderCreate")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
network::Sender* sender = new network::SocketSender();
try {
SEND_BUFFER = new char[kMaxBufferSize];
char* buffer = new char[kMaxBufferSize];
sender->SetBuffer(buffer);
} catch (const std::bad_alloc&) {
LOG(FATAL) << "Not enough memory for sender buffer: " << kMaxBufferSize;
}
network::Sender* sender = new network::SocketSender();
CommunicatorHandle chandle = static_cast<CommunicatorHandle>(sender);
*rv = chandle;
});
Expand All @@ -61,7 +59,6 @@ DGL_REGISTER_GLOBAL("network._CAPI_DGLFinalizeSender")
CommunicatorHandle chandle = args[0];
network::Sender* sender = static_cast<network::Sender*>(chandle);
sender->Finalize();
delete [] SEND_BUFFER;
});

DGL_REGISTER_GLOBAL("network._CAPI_DGLSenderAddReceiver")
Expand Down Expand Up @@ -96,10 +93,11 @@ DGL_REGISTER_GLOBAL("network._CAPI_SenderSendSubgraph")
network::Sender* sender = static_cast<network::Sender*>(chandle);
auto csr = ptr->GetInCSR();
// Write control message
*SEND_BUFFER = CONTROL_NODEFLOW;
char* buffer = sender->GetBuffer();
*buffer = CONTROL_NODEFLOW;
// Serialize nodeflow to data buffer
int64_t data_size = network::SerializeSampledSubgraph(
SEND_BUFFER+sizeof(CONTROL_NODEFLOW),
buffer+sizeof(CONTROL_NODEFLOW),
csr,
node_mapping,
edge_mapping,
Expand All @@ -108,27 +106,29 @@ DGL_REGISTER_GLOBAL("network._CAPI_SenderSendSubgraph")
CHECK_GT(data_size, 0);
data_size += sizeof(CONTROL_NODEFLOW);
// Send msg via network
SendData(sender, SEND_BUFFER, data_size, recv_id);
SendData(sender, buffer, data_size, recv_id);
});

DGL_REGISTER_GLOBAL("network._CAPI_SenderSendEndSignal")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
CommunicatorHandle chandle = args[0];
int recv_id = args[1];
network::Sender* sender = static_cast<network::Sender*>(chandle);
*SEND_BUFFER = CONTROL_END_SIGNAL;
char* buffer = sender->GetBuffer();
*buffer = CONTROL_END_SIGNAL;
// Send msg via network
SendData(sender, SEND_BUFFER, sizeof(CONTROL_END_SIGNAL), recv_id);
SendData(sender, buffer, sizeof(CONTROL_END_SIGNAL), recv_id);
});

DGL_REGISTER_GLOBAL("network._CAPI_DGLReceiverCreate")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
network::Receiver* receiver = new network::SocketReceiver();
try {
RECV_BUFFER = new char[kMaxBufferSize];
char* buffer = new char[kMaxBufferSize];
receiver->SetBuffer(buffer);
} catch (const std::bad_alloc&) {
LOG(FATAL) << "Not enough memory for receiver buffer: " << kMaxBufferSize;
}
network::Receiver* receiver = new network::SocketReceiver();
CommunicatorHandle chandle = static_cast<CommunicatorHandle>(receiver);
*rv = chandle;
});
Expand All @@ -138,7 +138,6 @@ DGL_REGISTER_GLOBAL("network._CAPI_DGLFinalizeReceiver")
CommunicatorHandle chandle = args[0];
network::Receiver* receiver = static_cast<network::SocketReceiver*>(chandle);
receiver->Finalize();
delete [] RECV_BUFFER;
});

DGL_REGISTER_GLOBAL("network._CAPI_DGLReceiverWait")
Expand All @@ -156,13 +155,14 @@ DGL_REGISTER_GLOBAL("network._CAPI_ReceiverRecvSubgraph")
CommunicatorHandle chandle = args[0];
network::Receiver* receiver = static_cast<network::SocketReceiver*>(chandle);
// Recv data from network
RecvData(receiver, RECV_BUFFER, kMaxBufferSize);
int control = *RECV_BUFFER;
char* buffer = receiver->GetBuffer();
RecvData(receiver, buffer, kMaxBufferSize);
int control = *buffer;
if (control == CONTROL_NODEFLOW) {
NodeFlow* nf = new NodeFlow();
ImmutableGraph::CSR::Ptr csr;
// Deserialize nodeflow from recv_data_buffer
network::DeserializeSampledSubgraph(RECV_BUFFER+sizeof(CONTROL_NODEFLOW),
network::DeserializeSampledSubgraph(buffer+sizeof(CONTROL_NODEFLOW),
&(csr),
&(nf->node_mapping),
&(nf->edge_mapping),
Expand Down
22 changes: 22 additions & 0 deletions src/graph/network/communicator.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,17 @@ class Sender {
* \brief Finalize Sender
*/
virtual void Finalize() = 0;

/*!
* \brief Get data buffer
* \return buffer pointer
*/
virtual char* GetBuffer() = 0;

/*!
* \brief Set data buffer
*/
virtual void SetBuffer(char* buffer) = 0;
};

/*!
Expand Down Expand Up @@ -90,6 +101,17 @@ class Receiver {
* \brief Finalize Receiver
*/
virtual void Finalize() = 0;

/*!
* \brief Get data buffer
* \return buffer pointer
*/
virtual char* GetBuffer() = 0;

/*!
* \brief Set data buffer
*/
virtual void SetBuffer(char* buffer) = 0;
};

} // namespace network
Expand Down
18 changes: 18 additions & 0 deletions src/graph/network/socket_communicator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,15 @@ void SocketSender::Finalize() {
client = nullptr;
}
}
delete buffer_;
}

char* SocketSender::GetBuffer() {
return buffer_;
}

void SocketSender::SetBuffer(char* buffer) {
buffer_ = buffer;
}

bool SocketReceiver::Wait(const char* ip,
Expand Down Expand Up @@ -190,6 +199,15 @@ void SocketReceiver::Finalize() {
socket_[i] = nullptr;
}
}
delete buffer_;
}

char* SocketReceiver::GetBuffer() {
return buffer_;
}

void SocketReceiver::SetBuffer(char* buffer) {
buffer_ = buffer;
}

} // namespace network
Expand Down
32 changes: 32 additions & 0 deletions src/graph/network/socket_communicator.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,17 @@ class SocketSender : public Sender {
*/
void Finalize();

/*!
* \brief Get data buffer
* \return buffer pointer
*/
char* GetBuffer();

/*!
* \brief Set data buffer
*/
void SetBuffer(char* buffer);

private:
/*!
* \brief socket map
Expand All @@ -81,6 +92,11 @@ class SocketSender : public Sender {
* \brief receiver address map
*/
std::unordered_map<int, Addr> receiver_addr_map_;

/*!
* \brief data buffer
*/
char* buffer_;
};

/*!
Expand Down Expand Up @@ -118,6 +134,17 @@ class SocketReceiver : public Receiver {
*/
void Finalize();

/*!
* \brief Get data buffer
* \return buffer pointer
*/
char* GetBuffer();

/*!
* \brief Set data buffer
*/
void SetBuffer(char* buffer);

private:
/*!
* \brief number of sender
Expand All @@ -144,6 +171,11 @@ class SocketReceiver : public Receiver {
*/
MessageQueue* queue_;

/*!
* \brief data buffer
*/
char* buffer_;

/*!
* \brief Process received message in independent threads
* \param socket new accpeted socket
Expand Down

0 comments on commit 99bc3ab

Please sign in to comment.