Skip to content

Commit

Permalink
Improve stability by exiting immediately on common errors (ekzhang#2)
Browse files Browse the repository at this point in the history
* Kill connections immediately on missing or close

* Add timeout to initial protocol messages

* Add low-level tracing for JSON messages

* Add timeout to initial TCP connections
  • Loading branch information
ekzhang authored Apr 8, 2022
1 parent c1efefe commit 2d0dcf9
Show file tree
Hide file tree
Showing 4 changed files with 51 additions and 29 deletions.
6 changes: 3 additions & 3 deletions src/auth.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use sha2::{Digest, Sha256};
use tokio::io::{AsyncBufRead, AsyncWrite};
use uuid::Uuid;

use crate::shared::{recv_json, send_json, ClientMessage, ServerMessage};
use crate::shared::{recv_json_timeout, send_json, ClientMessage, ServerMessage};

/// Wrapper around a MAC used for authenticating clients that have a secret.
pub struct Authenticator(Hmac<Sha256>);
Expand Down Expand Up @@ -54,7 +54,7 @@ impl Authenticator {
) -> Result<()> {
let challenge = Uuid::new_v4();
send_json(stream, ServerMessage::Challenge(challenge)).await?;
match recv_json(stream, &mut Vec::new()).await? {
match recv_json_timeout(stream).await? {
Some(ClientMessage::Authenticate(tag)) => {
ensure!(self.validate(&challenge, &tag), "invalid secret");
Ok(())
Expand All @@ -68,7 +68,7 @@ impl Authenticator {
&self,
stream: &mut (impl AsyncBufRead + AsyncWrite + Unpin),
) -> Result<()> {
let challenge = match recv_json(stream, &mut Vec::new()).await? {
let challenge = match recv_json_timeout(stream).await? {
Some(ServerMessage::Challenge(challenge)) => challenge,
_ => bail!("expected authentication challenge, but no secret was required"),
};
Expand Down
36 changes: 19 additions & 17 deletions src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,15 @@
use std::sync::Arc;

use anyhow::{bail, Context, Result};
use tokio::{io::BufReader, net::TcpStream};
use tokio::{io::BufReader, net::TcpStream, time::timeout};
use tracing::{error, info, info_span, warn, Instrument};
use uuid::Uuid;

use crate::auth::Authenticator;
use crate::shared::{proxy, recv_json, send_json, ClientMessage, ServerMessage, CONTROL_PORT};
use crate::shared::{
proxy, recv_json, recv_json_timeout, send_json, ClientMessage, ServerMessage, CONTROL_PORT,
NETWORK_TIMEOUT,
};

/// State structure for the client.
pub struct Client {
Expand All @@ -31,18 +34,15 @@ pub struct Client {
impl Client {
/// Create a new client.
pub async fn new(local_port: u16, to: &str, port: u16, secret: Option<&str>) -> Result<Self> {
let stream = TcpStream::connect((to, CONTROL_PORT))
.await
.with_context(|| format!("could not connect to {to}:{CONTROL_PORT}"))?;
let mut stream = BufReader::new(stream);
let mut stream = BufReader::new(connect_with_timeout(to, CONTROL_PORT).await?);

let auth = secret.map(Authenticator::new);
if let Some(auth) = &auth {
auth.client_handshake(&mut stream).await?;
}

send_json(&mut stream, ClientMessage::Hello(port)).await?;
let remote_port = match recv_json(&mut stream, &mut Vec::new()).await? {
let remote_port = match recv_json_timeout(&mut stream).await? {
Some(ServerMessage::Hello(remote_port)) => remote_port,
Some(ServerMessage::Error(message)) => bail!("server error: {message}"),
Some(ServerMessage::Challenge(_)) => {
Expand Down Expand Up @@ -99,21 +99,23 @@ impl Client {
}

async fn handle_connection(&self, id: Uuid) -> Result<()> {
let local_conn = TcpStream::connect(("localhost", self.local_port))
.await
.context("failed TCP connection to local port")?;
let mut remote_conn = BufReader::new(
TcpStream::connect((&self.to[..], CONTROL_PORT))
.await
.context("failed TCP connection to remote port")?,
);

let mut remote_conn =
BufReader::new(connect_with_timeout(&self.to[..], CONTROL_PORT).await?);
if let Some(auth) = &self.auth {
auth.client_handshake(&mut remote_conn).await?;
}

send_json(&mut remote_conn, ClientMessage::Accept(id)).await?;

let local_conn = connect_with_timeout("localhost", self.local_port).await?;
proxy(local_conn, remote_conn).await?;
Ok(())
}
}

async fn connect_with_timeout(to: &str, port: u16) -> Result<TcpStream> {
match timeout(NETWORK_TIMEOUT, TcpStream::connect((to, port))).await {
Ok(res) => res,
Err(err) => Err(err.into()),
}
.with_context(|| format!("could not connect to {to}:{port}"))
}
9 changes: 4 additions & 5 deletions src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@ use tracing::{info, info_span, warn, Instrument};
use uuid::Uuid;

use crate::auth::Authenticator;
use crate::shared::{proxy, recv_json, send_json, ClientMessage, ServerMessage, CONTROL_PORT};
use crate::shared::{
proxy, recv_json_timeout, send_json, ClientMessage, ServerMessage, CONTROL_PORT,
};

/// State structure for the server.
pub struct Server {
Expand Down Expand Up @@ -71,10 +73,7 @@ impl Server {
}
}

let mut buf = Vec::new();
let msg = recv_json(&mut stream, &mut buf).await?;

match msg {
match recv_json_timeout(&mut stream).await? {
Some(ClientMessage::Authenticate(_)) => {
warn!("unexpected authenticate");
Ok(())
Expand Down
29 changes: 25 additions & 4 deletions src/shared.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,21 @@
//! Shared data structures, utilities, and protocol definitions.
use std::time::Duration;

use anyhow::{Context, Result};
use serde::de::DeserializeOwned;
use serde::{Deserialize, Serialize};
use tokio::io::{self, AsyncBufRead, AsyncBufReadExt, AsyncRead, AsyncWrite, AsyncWriteExt};
use tokio::time::timeout;
use tracing::trace;
use uuid::Uuid;

/// TCP port used for control connections with the server.
pub const CONTROL_PORT: u16 = 7835;

/// Timeout for network connections and initial protocol messages.
pub const NETWORK_TIMEOUT: Duration = Duration::from_secs(3);

/// A message from the client on the control connection.
#[derive(Debug, Serialize, Deserialize)]
pub enum ClientMessage {
Expand Down Expand Up @@ -49,10 +56,10 @@ where
{
let (mut s1_read, mut s1_write) = io::split(stream1);
let (mut s2_read, mut s2_write) = io::split(stream2);
tokio::try_join!(
io::copy(&mut s1_read, &mut s2_write),
io::copy(&mut s2_read, &mut s1_write),
)?;
tokio::select! {
res = io::copy(&mut s1_read, &mut s2_write) => res,
res = io::copy(&mut s2_read, &mut s1_write) => res,
}?;
Ok(())
}

Expand All @@ -61,6 +68,7 @@ pub async fn recv_json<T: DeserializeOwned>(
reader: &mut (impl AsyncBufRead + Unpin),
buf: &mut Vec<u8>,
) -> Result<Option<T>> {
trace!("waiting to receive json message");
buf.clear();
reader.read_until(0, buf).await?;
if buf.is_empty() {
Expand All @@ -72,8 +80,21 @@ pub async fn recv_json<T: DeserializeOwned>(
Ok(serde_json::from_slice(buf).context("failed to parse JSON")?)
}

/// Read the next null-delimited JSON instruction, with a default timeout.
///
/// This is useful for parsing the initial message of a stream for handshake or
/// other protocol purposes, where we do not want to wait indefinitely.
pub async fn recv_json_timeout<T: DeserializeOwned>(
reader: &mut (impl AsyncBufRead + Unpin),
) -> Result<Option<T>> {
timeout(NETWORK_TIMEOUT, recv_json(reader, &mut Vec::new()))
.await
.context("timed out waiting for initial message")?
}

/// Send a null-terminated JSON instruction on a stream.
pub async fn send_json<T: Serialize>(writer: &mut (impl AsyncWrite + Unpin), msg: T) -> Result<()> {
trace!("sending json message");
let msg = serde_json::to_vec(&msg)?;
writer.write_all(&msg).await?;
writer.write_all(&[0]).await?;
Expand Down

0 comments on commit 2d0dcf9

Please sign in to comment.