Skip to content

Commit

Permalink
style: make headers more rusty
Browse files Browse the repository at this point in the history
  • Loading branch information
p4gefau1t committed Mar 5, 2021
1 parent 195dbac commit ee63c8f
Show file tree
Hide file tree
Showing 9 changed files with 163 additions and 259 deletions.
12 changes: 6 additions & 6 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

34 changes: 14 additions & 20 deletions src/protocol/mod.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
use async_trait::async_trait;
use bytes::{Buf, BufMut, BytesMut};
use bytes::{Buf, BufMut};
use fmt::Debug;
use std::{
fmt::{self, Formatter},
io::{self, Cursor},
net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6, ToSocketAddrs},
net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6, ToSocketAddrs},
str::FromStr,
vec,
};
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite};

use crate::error::Error;

Expand Down Expand Up @@ -77,9 +77,14 @@ impl FromStr for Address {
}
}
impl Address {
pub const ADDR_TYPE_IPV4: u8 = 1;
pub const ADDR_TYPE_DOMAIN_NAME: u8 = 3;
pub const ADDR_TYPE_IPV6: u8 = 4;
const ADDR_TYPE_IPV4: u8 = 1;
const ADDR_TYPE_DOMAIN_NAME: u8 = 3;
const ADDR_TYPE_IPV6: u8 = 4;

#[inline]
fn new_dummy_address() -> Address {
Address::SocketAddress(SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), 0))
}

#[inline]
fn serialized_len(&self) -> usize {
Expand All @@ -90,7 +95,7 @@ impl Address {
}
}

pub async fn read_from_stream<R>(stream: &mut R) -> Result<Address, Error>
async fn read_from_stream<R>(stream: &mut R) -> Result<Address, Error>
where
R: AsyncRead + Unpin,
{
Expand Down Expand Up @@ -165,7 +170,7 @@ impl Address {
}
}

pub fn read_from_buf(buf: &[u8]) -> io::Result<Self> {
fn read_from_buf(buf: &[u8]) -> io::Result<Self> {
let mut cur = Cursor::new(buf);
if cur.remaining() < 1 + 1 {
return Err(new_error("invalid address buffer"));
Expand Down Expand Up @@ -218,18 +223,7 @@ impl Address {
}
}

#[inline]
pub async fn write_to_stream<W>(&self, writer: &mut W) -> io::Result<()>
where
W: AsyncWrite + Unpin,
{
let mut buf = BytesMut::with_capacity(self.serialized_len());
self.write_to_buf(&mut buf);
writer.write(&buf).await?;
Ok(())
}

pub fn write_to_buf<B: BufMut>(&self, buf: &mut B) {
fn write_to_buf<B: BufMut>(&self, buf: &mut B) {
match self {
Self::SocketAddress(SocketAddr::V4(addr)) => {
buf.put_u8(Self::ADDR_TYPE_IPV4);
Expand Down
10 changes: 5 additions & 5 deletions src/protocol/mux/acceptor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use tokio::{
task::JoinHandle,
};

use super::{Command, MuxHandle, MuxStream, MuxUdpStream, RequestHeader, STREAM_CHANNEL_LEN};
use super::{MuxHandle, MuxStream, MuxUdpStream, RequestHeader, STREAM_CHANNEL_LEN};
use crate::protocol::{AcceptResult, Address, ProxyAcceptor};

#[derive(Deserialize)]
Expand Down Expand Up @@ -79,11 +79,11 @@ impl MuxAcceptor {
let mut stream = mux_handle.accept().await?;
log::debug!("new mux stream {:x} accepted", stream.stream_id());
let header = RequestHeader::read_from(&mut stream).await?;
let result = match header.command {
Command::TcpConnect => {
AcceptResult::Tcp((stream, header.address))
let result = match header {
RequestHeader::TcpConnect(addr) => {
AcceptResult::Tcp((stream, addr))
}
Command::UdpAssociate => {
RequestHeader::UdpAssociate => {
AcceptResult::Udp(MuxUdpStream { inner: stream })
}
};
Expand Down
23 changes: 12 additions & 11 deletions src/protocol/mux/connector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use async_trait::async_trait;
use serde::Deserialize;
use tokio::{sync::Mutex, task::JoinHandle, time::sleep};

use super::{new_key, Command, MuxHandle, MuxStream, MuxUdpStream, RequestHeader};
use super::{new_key, MuxHandle, MuxStream, MuxUdpStream, RequestHeader};
use crate::protocol::{Address, ProxyConnector};

#[derive(Deserialize)]
Expand Down Expand Up @@ -47,6 +47,7 @@ impl<T: ProxyConnector> MuxConnector<T> {
let cleaner_handle = {
let timeout = Duration::from_secs(config.timeout as u64);
let handlers = handlers.clone();
let concurrent = config.concurrent;
tokio::spawn(async move {
loop {
sleep(timeout).await;
Expand All @@ -59,7 +60,12 @@ impl<T: ProxyConnector> MuxConnector<T> {
if num_streams == 0 || closed {
inactive_handle_id.push(*handle_id);
}
log::debug!("handle {:x}: {:x}", *handle_id, num_streams);
log::debug!(
"mux handle {:x}: {}/{}",
*handle_id,
num_streams,
concurrent
);
}
for handle_id in inactive_handle_id.iter() {
handlers.remove(handle_id);
Expand Down Expand Up @@ -111,20 +117,15 @@ impl<T: ProxyConnector> ProxyConnector for MuxConnector<T> {

async fn connect_tcp(&self, addr: &Address) -> io::Result<Self::TS> {
let mut stream = self.spawn_mux_stream().await?;
RequestHeader::new(Command::TcpConnect, addr)
.write_to(&mut stream)
.await?;
let header = RequestHeader::TcpConnect(addr.clone());
header.write_to(&mut stream).await?;
return Ok(stream);
}

async fn connect_udp(&self) -> io::Result<Self::US> {
let mut stream = self.spawn_mux_stream().await?;
RequestHeader::new(
Command::UdpAssociate,
&Address::DomainNameAddress("UDP_CONN".to_string(), 0),
)
.write_to(&mut stream)
.await?;
let header = RequestHeader::UdpAssociate;
header.write_to(&mut stream).await?;
Ok(MuxUdpStream { inner: stream })
}
}
105 changes: 19 additions & 86 deletions src/protocol/mux/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ use std::{
task::{Context, Poll},
};

use super::{Address, ProxyTcpStream, ProxyUdpStream, UdpRead, UdpWrite};
use super::{trojan::UdpHeader, Address, ProxyTcpStream, ProxyUdpStream, UdpRead, UdpWrite};
use crate::error::Error;

pub mod acceptor;
Expand All @@ -54,109 +54,42 @@ const STREAM_CHANNEL_LEN: usize = 0x20;
const CMD_TCP_CONNECT: u8 = 0x01;
const CMD_UDP_ASSOCIATE: u8 = 0x03;

#[derive(Clone, Debug, Copy, Eq, PartialEq)]
pub enum Command {
/// CONNECT command (TCP tunnel)
TcpConnect,
/// UDP ASSOCIATE command
enum RequestHeader {
TcpConnect(Address),
UdpAssociate,
}

impl Command {
#[inline]
fn as_u8(self) -> u8 {
match self {
Command::TcpConnect => CMD_TCP_CONNECT,
Command::UdpAssociate => CMD_UDP_ASSOCIATE,
}
}

#[inline]
fn from_u8(code: u8) -> io::Result<Command> {
match code {
CMD_TCP_CONNECT => Ok(Command::TcpConnect),
CMD_UDP_ASSOCIATE => Ok(Command::UdpAssociate),
_ => Err(new_error(format!("invalid request command: {}", code))),
}
}
}

struct RequestHeader {
command: Command,
address: Address,
}

impl RequestHeader {
pub fn new(command: Command, address: &Address) -> Self {
Self {
command,
address: address.clone(),
}
}

pub async fn read_from<R>(stream: &mut R) -> io::Result<Self>
where
R: AsyncRead + Unpin,
{
let mut cmd = [0u8; 1];
stream.read_exact(&mut cmd).await?;
let command = Command::from_u8(cmd[0])?;
let address = Address::read_from_stream(stream).await?;
Ok(Self { command, address })
}

pub async fn write_to<W>(&self, w: &mut W) -> io::Result<()>
where
W: AsyncWrite + Unpin,
{
let cmd = [self.command.as_u8()];
w.write(&cmd).await?;
self.address.write_to_stream(w).await?;
Ok(())
}
}

/// ```plain
/// +------+----------+----------+--------+---------+----------+
/// | ATYP | DST.ADDR | DST.PORT | Length | CRLF | Payload |
/// +------+----------+----------+--------+---------+----------+
/// | 1 | Variable | 2 | 2 | X'0D0A' | Variable |
/// +------+----------+----------+--------+---------+----------+
/// ```
struct UdpHeader {
pub address: Address,
pub payload_len: u16,
}

impl UdpHeader {
pub fn new(address: &Address, payload_len: usize) -> Self {
Self {
address: address.clone(),
payload_len: payload_len as u16,
let addr = Address::read_from_stream(stream).await?;
match cmd[0] {
CMD_TCP_CONNECT => Ok(Self::TcpConnect(addr)),
CMD_UDP_ASSOCIATE => Ok(Self::UdpAssociate),
_ => Err(new_error("invalid cmd")),
}
}

pub async fn read_from<R>(stream: &mut R) -> io::Result<Self>
where
R: AsyncRead + Unpin,
{
let address = Address::read_from_stream(stream).await?;
let mut len_buf = [0u8; 2];
stream.read_exact(&mut len_buf).await?;
let len = ((len_buf[0] as u16) << 8) | (len_buf[1] as u16);
Ok(Self {
address,
payload_len: len,
})
}

pub async fn write_to<W>(&self, w: &mut W) -> io::Result<()>
where
W: AsyncWrite + Unpin,
{
self.address.write_to_stream(w).await?;
let dummy_addr = Address::new_dummy_address();
let (cmd, addr) = match self {
RequestHeader::TcpConnect(addr) => (CMD_TCP_CONNECT, addr),
RequestHeader::UdpAssociate => (CMD_UDP_ASSOCIATE, &dummy_addr),
};
let mut buf = Vec::with_capacity(1 + addr.serialized_len());
let cursor = &mut buf;

cursor.put_u8(cmd);
addr.write_to_buf(cursor);

w.write(&self.payload_len.to_be_bytes()).await?;
w.write(&buf).await?;
Ok(())
}
}
Expand Down
Loading

0 comments on commit ee63c8f

Please sign in to comment.