Skip to content

Commit

Permalink
socks server refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
jgraef committed Jul 10, 2024
1 parent 31f3f74 commit 9caaad7
Show file tree
Hide file tree
Showing 7 changed files with 205 additions and 177 deletions.
58 changes: 47 additions & 11 deletions skunk-cli/src/app.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,15 @@ use lazy_static::lazy_static;
use semver::Version;
use skunk::{
address::TcpAddress,
connect::{
Connect,
ConnectTcp,
},
protocol::{
http,
tls,
},
proxy::{
fn_proxy,
pcap::{
self,
interface::Interface,
Expand All @@ -25,7 +28,10 @@ use skunk::{
Passthrough,
Proxy,
},
util::CancellationToken,
util::{
error::ResultExt,
CancellationToken,
},
};
use tokio::{
net::TcpStream,
Expand Down Expand Up @@ -151,18 +157,46 @@ impl App {

if args.socks.enabled {
let shutdown = shutdown.clone();

join_set.spawn(async move {
// run the SOCKS server. `proxy` will handle connections. The default
// [`Connect`][skunk::connect::Connect] (i.e.
// [`ConnectTcp`][skunk::connect::ConnectTcp]) is used.
args.socks
.builder()?
.with_graceful_shutdown(shutdown)
.with_proxy(fn_proxy(move |incoming, outgoing| {
proxy(tls.clone(), filter.clone(), incoming, outgoing)
}))
.serve()
.await?;
let mut listener = args.socks.builder()?.listen().await?;

let mut join_set = JoinSet::default();

loop {
let request = tokio::select! {
_ = shutdown.cancelled() => break,
request_res = listener.next() => request_res?,
};

match ConnectTcp.connect(request.destination_address()).await {
Ok(outgoing) => {
let bind_address = outgoing.local_addr().unwrap().into();
let incoming = request.accept(bind_address).await?;
let tls = tls.clone();
let filter = filter.clone();
let shutdown = shutdown.clone();

join_set.spawn(async move {
tokio::select! {
_ = shutdown.cancelled() => {},
result = proxy(tls, filter, incoming, outgoing) => {
let _ = result.log_error();
}
}
});
}
Err(_) => {
request.reject(None);
}
}
}

while let Some(_) = join_set.join_next().await {}

Ok::<(), Error>(())
});
}
Expand Down Expand Up @@ -216,7 +250,9 @@ impl App {
}

// join all tasks
while let Some(()) = join_set.join_next().await.transpose()?.transpose()? {}
while let Some(result) = join_set.join_next().await {
let _ = result.log_error();
}

Ok(())
}
Expand Down
10 changes: 10 additions & 0 deletions skunk/src/address.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ use std::{
IpAddr,
Ipv4Addr,
Ipv6Addr,
SocketAddr,
},
ops::RangeInclusive,
str::FromStr,
Expand Down Expand Up @@ -141,6 +142,15 @@ impl<'de> Deserialize<'de> for TcpAddress {
}
}

impl From<SocketAddr> for TcpAddress {
fn from(value: SocketAddr) -> Self {
Self {
host: value.ip().into(),
port: value.port(),
}
}
}

/// A [`HostAddress`] and a port, used for UDP.
#[derive(Clone, Debug, PartialEq, PartialOrd, Eq, Ord, Hash)]
pub struct UdpAddress {
Expand Down
19 changes: 2 additions & 17 deletions skunk/src/connect.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,26 +57,11 @@ pub trait Listen {
fn accept(&self) -> impl Future<Output = Result<Self::Connection, std::io::Error>> + Send;
}

/// Accept connections by listening for connections using Tokio.
pub struct ListenTcp {
listener: TcpListener,
}

impl ListenTcp {
pub fn new(listener: TcpListener) -> Self {
Self { listener }
}

pub fn into_inner(self) -> TcpListener {
self.listener
}
}

impl Listen for ListenTcp {
impl Listen for TcpListener {
type Connection = TcpStream;

async fn accept(&self) -> Result<Self::Connection, std::io::Error> {
let (conn, _) = self.listener.accept().await?;
let (conn, _) = TcpListener::accept(self).await?;
Ok(conn)
}
}
Loading

0 comments on commit 9caaad7

Please sign in to comment.