Skip to content

Commit

Permalink
refactor: introduce per-connection write tasks
Browse files Browse the repository at this point in the history
Signed-off-by: ljedrz <[email protected]>
  • Loading branch information
ljedrz committed Apr 28, 2021
1 parent 0e22921 commit dd546f1
Show file tree
Hide file tree
Showing 5 changed files with 88 additions and 47 deletions.
37 changes: 26 additions & 11 deletions network/src/inbound/inbound.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ use snarkvm_objects::Storage;
use tokio::{
io::{AsyncReadExt, AsyncWriteExt},
net::{TcpListener, TcpStream},
sync::mpsc::channel,
task,
};

Expand Down Expand Up @@ -104,27 +105,43 @@ impl<S: Storage + Send + Sync + 'static> Node<S> {
.await;

match handshake_result {
Ok(Ok((channel, mut reader))) => {
Ok(Ok((mut writer, mut reader, remote_listener))) => {
// Update the remote address to be the peer's listening address.
let remote_address = channel.addr;
// Save the channel under the provided remote address.
node.outbound.channels.write().insert(remote_address, Arc::new(channel));
let remote_address = writer.addr;

// Create a channel dedicated to sending messages to the connection.
let (sender, receiver) = channel(1024);

// Listen for inbound messages.
let node_clone = node.clone();
let peer_listening_task = tokio::spawn(async move {
let peer_reading_task = tokio::spawn(async move {
node_clone.listen_for_inbound_messages(&mut reader).await;
});

// Listen for outbound messages.
let node_clone = node.clone();
let peer_writing_task = tokio::spawn(async move {
node_clone.listen_for_outbound_messages(receiver, &mut writer).await;
});

// Save the channel under the provided remote address.
node.outbound.channels.write().insert(remote_address, sender);

// Finally, mark the peer as connected.
node.peer_book.set_connected(remote_address, Some(remote_listener));

trace!("Connected to {}", remote_address);

// Immediately send a ping to provide the peer with our block height.
node.send_ping(remote_address).await;

if let Ok(ref peer) = node.peer_book.get_peer(remote_address) {
peer.register_task(peer_listening_task);
peer.register_task(peer_reading_task);
peer.register_task(peer_writing_task);
} else {
// If the related peer is not found, it means it's already been dropped.
peer_listening_task.abort();
peer_reading_task.abort();
peer_writing_task.abort();
}
}
Ok(Err(e)) => {
Expand Down Expand Up @@ -318,7 +335,7 @@ impl<S: Storage + Send + Sync + 'static> Node<S> {
listener_address: SocketAddr,
remote_address: SocketAddr,
stream: TcpStream,
) -> Result<(ConnWriter, ConnReader), NetworkError> {
) -> Result<(ConnWriter, ConnReader, SocketAddr), NetworkError> {
self.peer_book.set_connecting(remote_address)?;

let (mut reader, mut writer) = stream.into_split();
Expand Down Expand Up @@ -366,12 +383,10 @@ impl<S: Storage + Send + Sync + 'static> Node<S> {
// the remote listening address
let remote_listener = SocketAddr::from((remote_address.ip(), peer_version.listening_port));

self.peer_book.set_connected(remote_address, Some(remote_listener))?;

let noise = Arc::new(Mutex::new(noise.into_transport_mode()?));
let reader = ConnReader::new(remote_listener, reader, buffer.clone(), Arc::clone(&noise));
let writer = ConnWriter::new(remote_listener, writer, buffer, noise);

Ok((writer, reader))
Ok((writer, reader, remote_listener))
}
}
58 changes: 38 additions & 20 deletions network/src/outbound/outbound.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,16 +21,14 @@ use snarkvm_objects::Storage;
use std::{
collections::HashMap,
net::SocketAddr,
sync::{
atomic::{AtomicU64, Ordering},
Arc,
},
sync::atomic::{AtomicU64, Ordering},
};

use parking_lot::RwLock;
use tokio::sync::mpsc::{error::TrySendError, Receiver, Sender};

/// The map of remote addresses to their active write channels.
type Channels = HashMap<SocketAddr, Arc<ConnWriter>>;
type Channels = HashMap<SocketAddr, Sender<Message>>;

/// A core data structure for handling outbound network traffic.
#[derive(Debug, Default)]
Expand All @@ -54,22 +52,24 @@ impl Outbound {
pub async fn send_request(&self, request: Message) {
let target_addr = request.receiver();
// Fetch the outbound channel.
let channel = match self.outbound_channel(target_addr).await {
Ok(channel) => channel,
match self.outbound_channel(target_addr) {
Ok(channel) => match channel.try_send(request) {
Ok(()) => {}
Err(TrySendError::Full(request)) => {
warn!(
"Couldn't send a {} to {}: the send channel is full",
request, target_addr
);
}
Err(TrySendError::Closed(request)) => {
error!(
"Couldn't send a {} to {}: the send channel is closed",
request, target_addr
);
}
},
Err(_) => {
warn!("Failed to send a {}: peer is disconnected", request);
return;
}
};

// Write the request to the outbound channel.
match channel.write_message(&request.payload).await {
Ok(_) => {
self.send_success_count.fetch_add(1, Ordering::SeqCst);
}
Err(error) => {
warn!("Failed to send a {}: {}", request, error);
self.send_failure_count.fetch_add(1, Ordering::SeqCst);
}
}
}
Expand All @@ -78,7 +78,7 @@ impl Outbound {
/// Establishes an outbound channel to the given remote address, if it does not exist.
///
#[inline]
async fn outbound_channel(&self, remote_address: SocketAddr) -> Result<Arc<ConnWriter>, NetworkError> {
fn outbound_channel(&self, remote_address: SocketAddr) -> Result<Sender<Message>, NetworkError> {
Ok(self
.channels
.read()
Expand Down Expand Up @@ -106,4 +106,22 @@ impl<S: Storage + Send + Sync + 'static> Node<S> {
))
.await;
}

/// This method handles new outbound messages to a single connected node.
pub async fn listen_for_outbound_messages(&self, mut receiver: Receiver<Message>, writer: &mut ConnWriter) {
loop {
// Read the next message queued to be sent.
if let Some(message) = receiver.recv().await {
match writer.write_message(&message.payload).await {
Ok(_) => {
self.outbound.send_success_count.fetch_add(1, Ordering::SeqCst);
}
Err(error) => {
warn!("Failed to send a {}: {}", message, error);
self.outbound.send_failure_count.fetch_add(1, Ordering::SeqCst);
}
}
}
}
}
}
6 changes: 2 additions & 4 deletions network/src/peers/peer_book.rs
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ impl PeerBook {
///
/// Adds the given address to the connected peers in the `PeerBook`.
///
pub fn set_connected(&self, address: SocketAddr, listener: Option<SocketAddr>) -> Result<(), NetworkError> {
pub fn set_connected(&self, address: SocketAddr, listener: Option<SocketAddr>) {
// If listener.is_some(), then it's different than the address; otherwise it's just the address param.
let listener = if let Some(addr) = listener { addr } else { address };

Expand All @@ -205,14 +205,12 @@ impl PeerBook {
self.connecting_peers.write().remove(&address);

// Update the peer info to connected.
peer_info.set_connected()?;
peer_info.set_connected();

// Add the address into the connected peers.
let success = self.connected_peers.write().insert(listener, peer_info).is_none();
// On success, increment the connected peer count.
connected_peers_inc!(success);

Ok(())
}

///
Expand Down
6 changes: 2 additions & 4 deletions network/src/peers/peer_info.rs
Original file line number Diff line number Diff line change
Expand Up @@ -174,17 +174,15 @@ impl PeerInfo {
///
/// Updates the peer to connected.
///
pub(crate) fn set_connected(&mut self) -> Result<(), NetworkError> {
pub(crate) fn set_connected(&mut self) {
if self.status() != PeerStatus::Connected {
// Set the state of this peer to connected.
self.status = PeerStatus::Connected;

self.last_connected = Some(Utc::now());
self.connected_count += 1;

Ok(())
} else {
Err(NetworkError::PeerAlreadyConnected)
trace!("Peer {} was set as connected more than once", self.address);
}
}

Expand Down
28 changes: 20 additions & 8 deletions network/src/peers/peers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ use rand::seq::IteratorRandom;
use tokio::{
io::{AsyncReadExt, AsyncWriteExt},
net::TcpStream,
sync::mpsc::channel,
task,
};

Expand Down Expand Up @@ -208,25 +209,36 @@ impl<S: Storage + Send + Sync + 'static> Node<S> {
trace!("sent s, se, psk (XX handshake part 3/3) to {}", remote_address);

let noise = Arc::new(Mutex::new(noise.into_transport_mode()?));
let writer = ConnWriter::new(remote_address, writer, buffer.clone(), Arc::clone(&noise));
let mut writer = ConnWriter::new(remote_address, writer, buffer.clone(), Arc::clone(&noise));
let mut reader = ConnReader::new(remote_address, reader, buffer, noise);

// save the outbound channel
node.outbound.channels.write().insert(remote_address, Arc::new(writer));

node.peer_book.set_connected(remote_address, None)?;
// Create a channel dedicated to sending messages to the connection.
let (sender, receiver) = channel(1024);

// spawn the inbound loop
let node_clone = node.clone();
let conn_listening_task = tokio::spawn(async move {
let conn_reading_task = tokio::spawn(async move {
node_clone.listen_for_inbound_messages(&mut reader).await;
});

// Listen for outbound messages.
let node_clone = node.clone();
let conn_writing_task = tokio::spawn(async move {
node_clone.listen_for_outbound_messages(receiver, &mut writer).await;
});

// Save the channel under the provided remote address.
node.outbound.channels.write().insert(remote_address, sender);

node.peer_book.set_connected(remote_address, None);

if let Ok(ref peer) = node.peer_book.get_peer(remote_address) {
peer.register_task(conn_listening_task);
peer.register_task(conn_reading_task);
peer.register_task(conn_writing_task);
} else {
// if the related peer is not found, it means it's already been dropped
conn_listening_task.abort();
conn_reading_task.abort();
conn_writing_task.abort();
}

Ok(())
Expand Down

0 comments on commit dd546f1

Please sign in to comment.