Skip to content

Commit

Permalink
[authority] Authority follower interface (MystenLabs#636)
Browse files Browse the repository at this point in the history
Features:
* Modernize authority server network to use tokio::select
* Added broadcast pair to authority
* Allow AuthorityState users to subscribe to broadcasts
* Moved base structures in base_types
* Modernize network & transport
* Expose streaming interface to authority logic
* Create startup logic for batch subsystem
* Store batch retrieval + tests
* Added server logic for handling batch requests and subscriptions

Tests:
* Added server stop test
* Infra to test server

Co-authored-by: George Danezis <[email protected]>
  • Loading branch information
gdanezis and George Danezis authored Mar 8, 2022
1 parent 945dd9d commit 35054cc
Show file tree
Hide file tree
Showing 20 changed files with 1,182 additions and 380 deletions.
2 changes: 2 additions & 0 deletions network_utils/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,12 @@ edition = "2021"
[dependencies]
bytes = "1.1.0"
futures = "0.3.21"
async-trait = "0.1.52"
log = "0.4.14"
net2 = "0.2.37"
tokio = { version = "1.17.0", features = ["full"] }
tracing = { version = "0.1.31", features = ["log"] }
tokio-util = { version = "0.7.0", features = ["codec"] }

sui-types = { path = "../sui_types" }

Expand Down
24 changes: 15 additions & 9 deletions network_utils/src/network.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,21 +41,24 @@ impl NetworkClient {
}
}

async fn send_recv_bytes_internal(&self, buf: Vec<u8>) -> Result<Vec<u8>, io::Error> {
async fn send_recv_bytes_internal(&self, buf: Vec<u8>) -> Result<Option<Vec<u8>>, io::Error> {
let address = format!("{}:{}", self.base_address, self.base_port);
let mut stream = connect(address, self.buffer_size).await?;
// Send message
time::timeout(self.send_timeout, stream.write_data(&buf)).await??;
// Wait for reply
time::timeout(self.recv_timeout, stream.read_data()).await?
time::timeout(self.recv_timeout, async {
stream.read_data().await.transpose()
})
.await?
}

pub async fn send_recv_bytes(&self, buf: Vec<u8>) -> Result<SerializedMessage, SuiError> {
match self.send_recv_bytes_internal(buf).await {
Err(error) => Err(SuiError::ClientIoError {
error: format!("{}", error),
}),
Ok(response) => {
Ok(Some(response)) => {
// Parse reply
match deserialize_message(&response[..]) {
Ok(SerializedMessage::Error(error)) => Err(*error),
Expand All @@ -64,6 +67,9 @@ impl NetworkClient {
// _ => Err(SuiError::UnexpectedMessage),
}
}
Ok(None) => Err(SuiError::ClientIoError {
error: "Empty response from authority.".to_string(),
}),
}
}

Expand Down Expand Up @@ -101,17 +107,17 @@ impl NetworkClient {
info!("In flight {} Remaining {}", in_flight, requests.len());
}
match time::timeout(self.recv_timeout, stream.read_data()).await {
Ok(Ok(buffer)) => {
Ok(Some(Ok(buffer))) => {
in_flight -= 1;
responses.push(Bytes::from(buffer));
}
Ok(Err(error)) => {
if error.kind() == io::ErrorKind::UnexpectedEof {
info!("Socket closed by server");
return Ok(responses);
}
Ok(Some(Err(error))) => {
error!("Received error response: {}", error);
}
Ok(None) => {
info!("Socket closed by server");
return Ok(responses);
}
Err(error) => {
error!(
"Timeout while receiving response: {} (in flight: {})",
Expand Down
191 changes: 68 additions & 123 deletions network_utils/src/transport.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,19 @@
// Copyright (c) 2022, Mysten Labs, Inc.
// SPDX-License-Identifier: Apache-2.0

use futures::future;
use futures::{Sink, SinkExt, Stream, StreamExt};
use std::io::ErrorKind;
use std::{collections::HashMap, convert::TryInto, io, sync::Arc};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use std::sync::Arc;
use tokio::net::TcpSocket;
use tokio::{
io::{AsyncRead, AsyncWrite},
net::{TcpListener, TcpStream},
};
use tokio::net::{TcpListener, TcpStream};

use async_trait::async_trait;

use tracing::*;

use bytes::{Bytes, BytesMut};
use tokio_util::codec::{Framed, LengthDelimitedCodec};

#[cfg(test)]
#[path = "unit_tests/transport_tests.rs"]
mod transport_tests;
Expand All @@ -22,8 +24,28 @@ pub const DEFAULT_MAX_DATAGRAM_SIZE: usize = 65507;
pub const DEFAULT_MAX_DATAGRAM_SIZE_STR: &str = "65507";

/// The handler required to create a service.
pub trait MessageHandler {
fn handle_message<'a>(&'a self, buffer: &'a [u8]) -> future::BoxFuture<'a, Option<Vec<u8>>>;
#[async_trait]
pub trait MessageHandler<A> {
/// Called with a channel of type `A`, which it takes by value; resolves to `()`.
///
/// The TCP server in this file instantiates `A` as `TcpDataStream` and calls this
/// once per accepted connection, from inside a spawned task.
async fn handle_messages(&self, channel: A) -> ();
}

/*
The RwChannel connects the low-level networking code here, that handles
TCP streams, ports, accept/connect, and sockets that provide AsyncRead /
AsyncWrite on byte streams, with the higher level logic in AuthorityServer
that handles sequences of Bytes / BytesMut, as framed messages, through
exposing a standard Stream and Sink trait.
This separation allows us to change the details of the network, transport
and framing, without changing the authority code. It also allows us to test
the authority without using a real network.
*/
/// Exposes a framed-message reader (`R`) and writer (`W`) for one connection.
pub trait RwChannel<'a> {
/// Incoming side: a stream whose items are `BytesMut` frames or `std::io::Error`s.
type R: 'a + Stream<Item = Result<BytesMut, std::io::Error>> + Unpin + Send;
/// Outgoing side: a sink that accepts `Bytes` and reports failures as `std::io::Error`.
type W: 'a + Sink<Bytes, Error = std::io::Error> + Unpin + Send;

/// Returns a mutable reference to the outgoing sink.
fn sink(&mut self) -> &mut Self::W;
/// Returns a mutable reference to the incoming stream.
fn stream(&mut self) -> &mut Self::R;
}

/// The result of spawning a server is a oneshot channel to kill it and a handle to track completion.
Expand Down Expand Up @@ -54,19 +76,14 @@ pub async fn connect(
TcpDataStream::connect(address, max_data_size).await
}

/// Create a DataStreamPool for this protocol.
pub async fn make_outgoing_connection_pool() -> Result<TcpDataStreamPool, std::io::Error> {
TcpDataStreamPool::new().await
}

/// Run a server for this protocol and the given message handler.
pub async fn spawn_server<S>(
address: &str,
state: S,
buffer_size: usize,
) -> Result<SpawnedServer, std::io::Error>
where
S: MessageHandler + Send + Sync + 'static,
S: MessageHandler<TcpDataStream> + Send + Sync + 'static,
{
let (complete, receiver) = futures::channel::oneshot::channel();
let handle = {
Expand All @@ -89,8 +106,7 @@ where

/// An implementation of DataStream based on TCP.
pub struct TcpDataStream {
stream: TcpStream,
max_data_size: usize,
framed: Framed<TcpStream, LengthDelimitedCodec>,
}

impl TcpDataStream {
Expand All @@ -103,95 +119,41 @@ impl TcpDataStream {
socket.set_recv_buffer_size(max_data_size as u32)?;

let stream = socket.connect(addr).await?;
Ok(Self {
stream,
max_data_size,
})
Ok(TcpDataStream::from_tcp_stream(stream, max_data_size))
}

async fn tcp_write_data<S>(stream: &mut S, buffer: &[u8]) -> Result<(), std::io::Error>
where
S: AsyncWrite + Unpin,
{
stream
.write_all(&u32::to_le_bytes(
buffer
.len()
.try_into()
.expect("length must not exceed u32::MAX"),
))
.await?;
stream.write_all(buffer).await
}
fn from_tcp_stream(stream: TcpStream, max_data_size: usize) -> TcpDataStream {
let framed = Framed::new(
stream,
LengthDelimitedCodec::builder()
.max_frame_length(max_data_size)
.new_codec(),
);

async fn tcp_read_data<S>(stream: &mut S, max_size: usize) -> Result<Vec<u8>, std::io::Error>
where
S: AsyncRead + Unpin,
{
let mut size_buf = [0u8; 4];
stream.read_exact(&mut size_buf).await?;
let size = u32::from_le_bytes(size_buf);
if size as usize > max_size {
return Err(io::Error::new(
io::ErrorKind::Other,
"Message size exceeds buffer size",
));
}
let mut buf = vec![0u8; size as usize];
stream.read_exact(&mut buf).await?;
Ok(buf)
Self { framed }
}
}

impl TcpDataStream {
// TODO: Eliminate vecs and use Byte, ByteBuf

pub async fn write_data<'a>(&'a mut self, buffer: &'a [u8]) -> Result<(), std::io::Error> {
Self::tcp_write_data(&mut self.stream, buffer).await
self.framed.send(buffer.to_vec().into()).await
}

pub async fn read_data(&mut self) -> Result<Vec<u8>, std::io::Error> {
Self::tcp_read_data(&mut self.stream, self.max_data_size).await
pub async fn read_data(&mut self) -> Option<Result<Vec<u8>, std::io::Error>> {
let result = self.framed.next().await;
result.map(|v| v.map(|w| w.to_vec()))
}
}

/// An implementation of DataStreamPool based on TCP.
pub struct TcpDataStreamPool {
streams: HashMap<String, TcpStream>,
}

impl TcpDataStreamPool {
async fn new() -> Result<Self, std::io::Error> {
let streams = HashMap::new();
Ok(Self { streams })
}
impl<'a> RwChannel<'a> for TcpDataStream {
type W = Framed<TcpStream, LengthDelimitedCodec>;
type R = Framed<TcpStream, LengthDelimitedCodec>;

async fn get_stream(&mut self, address: &str) -> Result<&mut TcpStream, io::Error> {
if !self.streams.contains_key(address) {
match TcpStream::connect(address).await {
Ok(s) => {
self.streams.insert(address.to_string(), s);
}
Err(error) => {
error!("Failed to open connection to {}: {}", address, error);
return Err(error);
}
};
};
Ok(self.streams.get_mut(address).unwrap())
fn sink(&mut self) -> &mut Self::W {
&mut self.framed
}
}

impl TcpDataStreamPool {
pub async fn send_data_to<'a>(
&'a mut self,
buffer: &'a [u8],
address: &'a str,
) -> Result<(), std::io::Error> {
let stream = self.get_stream(address).await?;
let result = TcpDataStream::tcp_write_data(stream, buffer).await;
if result.is_err() {
self.streams.remove(address);
}
result
fn stream(&mut self) -> &mut Self::R {
&mut self.framed
}
}

Expand All @@ -200,44 +162,27 @@ async fn run_tcp_server<S>(
listener: TcpListener,
state: S,
mut exit_future: futures::channel::oneshot::Receiver<()>,
buffer_size: usize,
_buffer_size: usize,
) -> Result<(), std::io::Error>
where
S: MessageHandler + Send + Sync + 'static,
S: MessageHandler<TcpDataStream> + Send + Sync + 'static,
{
let guarded_state = Arc::new(Box::new(state));
let guarded_state = Arc::new(state);
loop {
let (mut stream, _) = match future::select(exit_future, Box::pin(listener.accept())).await {
future::Either::Left(_) => break,
future::Either::Right((value, new_exit_future)) => {
exit_future = new_exit_future;
value?
let stream;

tokio::select! {
_ = &mut exit_future => { break },
result = listener.accept() => {
let (value, _addr) = result?;
stream = value;
}
};
}

let guarded_state = guarded_state.clone();
tokio::spawn(async move {
loop {
let buffer = match TcpDataStream::tcp_read_data(&mut stream, buffer_size).await {
Ok(buffer) => buffer,
Err(err) => {
// We expect some EOF or disconnect error at the end.
if err.kind() != io::ErrorKind::UnexpectedEof
&& err.kind() != io::ErrorKind::ConnectionReset
{
error!("Error while reading TCP stream: {}", err);
}
break;
}
};

if let Some(reply) = guarded_state.handle_message(&buffer[..]).await {
let status = TcpDataStream::tcp_write_data(&mut stream, &reply[..]).await;
if let Err(error) = status {
error!("Failed to send query response: {}", error);
}
};
}
let framed = TcpDataStream::from_tcp_stream(stream, _buffer_size);
guarded_state.handle_messages(framed).await
});
}
Ok(())
Expand Down
Loading

0 comments on commit 35054cc

Please sign in to comment.