Skip to content

Commit

Permalink
refactor: implement WebSocket handler and proxy initialization logic
Browse files Browse the repository at this point in the history
  • Loading branch information
Xerxes-2 committed Nov 23, 2024
1 parent 2f528ee commit f3fd723
Show file tree
Hide file tree
Showing 3 changed files with 226 additions and 181 deletions.
130 changes: 130 additions & 0 deletions src/handler.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
use anyhow::Result;
use bytes::Bytes;
use hudsucker::{
futures::{Sink, SinkExt, Stream, StreamExt},
tokio_tungstenite::tungstenite::{self, Message},
*,
};
use std::sync::Arc;
use tokio::sync::{mpsc::Sender, RwLock};
use tracing::*;

use crate::{
modder::Modder,
parser::{LiqiMessage, Parser},
settings::Settings,
};

#[derive(Clone)]
pub struct Handler {
sender: Option<Sender<(LiqiMessage, char)>>,
modder: Option<Arc<Modder>>,
inject_msg: Option<Message>,
parser: Arc<RwLock<Parser>>,
}

impl Handler {
pub fn new(
sender: Option<Sender<(LiqiMessage, char)>>,
modder: Option<Arc<Modder>>,
settings: &'static Settings,
) -> Self {
Self {
sender,
modder,
inject_msg: None,
parser: Arc::new(RwLock::new(Parser::new(
&settings.proto_json,
&settings.desc,
))),
}
}
}

impl WebSocketHandler for Handler {
async fn handle_websocket(
mut self,
ctx: WebSocketContext,
mut stream: impl Stream<Item = Result<Message, tungstenite::Error>> + Unpin + Send + 'static,
mut sink: impl Sink<Message, Error = tungstenite::Error> + Unpin + Send + 'static,
) {
if let WebSocketContext::ServerToClient { .. } = ctx {
if let Some(msg) = self.inject_msg.take() {
if let Err(e) = sink.send(msg).await {
error!("Failed to send injected message: {e}");
}
}
}
while let Some(message) = stream.next().await {
match message {
Ok(message) => {
let Some(message) = self.handle_message(&ctx, message).await else {
continue;
};

match sink.send(message).await {
Err(tungstenite::Error::ConnectionClosed) => (),
Err(e) => error!("WebSocket send error: {e}"),
_ => (),
}
}
Err(e) => {
error!("WebSocket message error: {e}");

match sink.send(Message::Close(None)).await {
Err(tungstenite::Error::ConnectionClosed) => (),
Err(e) => error!("WebSocket close error: {e}"),
_ => (),
};

break;
}
}
}
}

async fn handle_message(&mut self, _ctx: &WebSocketContext, msg: Message) -> Option<Message> {
let (direction_char, uri) = match _ctx {
WebSocketContext::ServerToClient { src, .. } => ('\u{2193}', src),
WebSocketContext::ClientToServer { dst, .. } => ('\u{2191}', dst),
};

if uri.path() == "/ob" {
// ignore ob messages
return Some(msg);
}

debug!("{direction_char} {uri}");

let Message::Binary(buf) = msg else {
return Some(msg);
};

let buf: Bytes = buf.into();
let mut parser = self.parser.write().await;
let Ok(parsed) = parser.parse(buf.clone()) else {
error!("Failed to parse message");
return Some(Message::Binary(buf.into()));
};
drop(parser);

let method_name = parsed.method_name.clone();
if let Some(tx) = &self.sender {
if let Err(e) = tx.send((parsed, direction_char)).await {
error!("Failed to send message to channel: {e}");
}
}
let Some(ref modder) = self.modder else {
return Some(Message::Binary(buf.into()));
};
let parser = self.parser.read().await;
let res = modder
.modify(buf, direction_char == '\u{2191}', method_name)
.await;
drop(parser);
if let Some(inj) = res.inject_msg {
self.inject_msg = Some(Message::Binary(inj.into()));
}
res.msg.map(|msg| Message::Binary(msg.into()))
}
}
81 changes: 81 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,19 @@
use anyhow::{Context, Result};
use clap::Parser;
use handler::Handler;
use helper::helper_worker;
use hudsucker::{
certificate_authority::RcgenAuthority,
rcgen::{CertificateParams, KeyPair},
rustls, Proxy,
};
use modder::Modder;
use settings::Settings;
use std::{future::Future, net::SocketAddr, str::FromStr, sync::Arc};
use tokio::sync::mpsc::channel;
use tracing::info;

pub mod handler;
pub mod helper;
pub mod modder;
pub mod parser;
Expand All @@ -13,3 +27,70 @@ pub struct Arg {
#[clap(short, long, default_value = "./liqi_config/")]
config_dir: String,
}

pub fn init_trace() {
let timer = tracing_subscriber::fmt::time::ChronoLocal::new("%H:%M:%S%.3f".to_string());
let filter = tracing_subscriber::EnvFilter::builder()
.with_default_directive(tracing_subscriber::filter::LevelFilter::WARN.into())
.from_env()
.unwrap_or_default()
.add_directive("majsoul_max_rs=info".parse().unwrap_or_default());
tracing_subscriber::fmt()
.with_env_filter(filter)
.with_timer(timer)
.compact()
.init();
}

fn generate_ca() -> Result<RcgenAuthority> {
const KEY_PAIR: &str = include_str!("./ca/hudsucker.key");
const CA_CERT: &str = include_str!("./ca/hudsucker.cer");
let key_pair = KeyPair::from_pem(KEY_PAIR).context("Failed to parse key pair")?;
let ca_cert = CertificateParams::from_ca_cert_pem(CA_CERT)
.context("Failed to parse CA certificate")?
.self_signed(&key_pair)
.context("Failed to sign CA certificate")?;

let ca = RcgenAuthority::new(
key_pair,
ca_cert,
1_000,
rustls::crypto::aws_lc_rs::default_provider(),
);
Ok(ca)
}

pub async fn build_and_start_proxy<F>(
settings: &'static Settings,
modder: Option<Modder>,
graceful_shutdown: F,
) -> Result<()>
where
F: Future<Output = ()> + Send + 'static,
{
let ca = generate_ca()?;

let proxy_addr = SocketAddr::from_str(settings.proxy_addr.as_str())
.context("Failed to parse proxy address")?;
let modder = modder.map(Arc::new);

let tx = if settings.helper_on() {
let (tx, rx) = channel(32);
// start helper worker
info!("Helper worker started");
tokio::spawn(helper_worker(rx, settings));
Some(tx)
} else {
None
};
let proxy = Proxy::builder()
.with_addr(proxy_addr)
.with_ca(ca)
.with_rustls_client(rustls::crypto::aws_lc_rs::default_provider())
.with_websocket_handler(Handler::new(tx, modder, settings))
.with_graceful_shutdown(graceful_shutdown)
.build()
.context("Failed to build proxy")?;

proxy.start().await.context("Failed to start proxy")
}
Loading

0 comments on commit f3fd723

Please sign in to comment.