diff --git a/skunk-api-client/src/client.rs b/skunk-api-client/src/client.rs index e50a491..c780896 100644 --- a/skunk-api-client/src/client.rs +++ b/skunk-api-client/src/client.rs @@ -12,96 +12,75 @@ use std::{ use futures_util::{ Future, FutureExt, - SinkExt, - TryStreamExt, -}; -use lazy_static::lazy_static; -use reqwest_websocket::{ - Message, - RequestBuilderExt, -}; -use serde::{ - Deserialize, - Serialize, -}; -use skunk_api_protocol::{ - socket::{ - ClientHello, - ClientMessage, - ServerHello, - ServerMessage, - Version, - }, - PROTOCOL_VERSION, }; use skunk_util::trigger; -use tokio::sync::{ - mpsc, - watch, -}; +use tokio::sync::watch; use tracing::Instrument; use url::Url; -use super::Error; - -pub const USER_AGENT: &'static str = std::env!("CARGO_PKG_NAME"); -lazy_static! { - pub static ref CLIENT_VERSION: Version = std::env!("CARGO_PKG_VERSION").parse().unwrap(); -} +use crate::{ + socket::{ + Command, + Reactor, + ReactorHandle, + }, + Status, +}; #[derive(Clone, Debug)] pub struct Client { client: reqwest::Client, base_url: UrlBuilder, - command_tx: mpsc::Sender, - reload_rx: trigger::Receiver, - status_rx: watch::Receiver, + reactor: ReactorHandle, } impl Client { pub fn new(base_url: Url) -> (Self, Connection) { let client = reqwest::Client::new(); let base_url = UrlBuilder { url: base_url }; - let (command_tx, command_rx) = mpsc::channel(4); - let (reload_tx, reload_rx) = trigger::new(); - let (status_tx, status_rx) = watch::channel(Default::default()); + + let (reactor, reactor_handle) = + Reactor::new(client.clone(), base_url.clone().push("ws").finish()); + let span = tracing::info_span!("socket"); let connection = Connection { - inner: Box::pin({ - let client = client.clone(); - let base_url = base_url.clone(); - let span = tracing::info_span!("connection"); - async move { - let reactor = - Reactor::new(client, base_url, command_rx, reload_tx, status_tx).await?; - reactor.run().await - } - .instrument(span) - }), + inner: Box::pin(reactor.run().instrument(span)), }; let client = Self { client, base_url, - command_tx, - reload_rx, - status_rx, + reactor: reactor_handle, }; (client, connection) } + async fn send_command(&mut self, command: Command) { + self.reactor + .command_tx + .send(command) + .await + .expect("Reactor died"); + } + pub fn reload_ui(&self) -> trigger::Receiver { - self.reload_rx.clone() + self.reactor.reload_rx.clone() } pub fn status(&self) -> watch::Receiver { - self.status_rx.clone() + self.reactor.status_rx.clone() + } + + pub async fn ping(&mut self) { + let mut pong_rx = self.reactor.pong_rx.clone(); + self.send_command(Command::Ping).await; + pong_rx.triggered().await; } } #[derive(Clone, Debug)] -struct UrlBuilder { +pub(crate) struct UrlBuilder { url: Url, } @@ -120,11 +99,11 @@ impl UrlBuilder { /// /// This must be polled to drive the connection for a [`Client`]. pub struct Connection { - inner: Pin>>>, + inner: Pin>>, } impl Future for Connection { - type Output = Result<(), Error>; + type Output = (); fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { self.inner.poll_unpin(cx) @@ -136,160 +115,3 @@ impl Debug for Connection { f.debug_struct("Connection").finish_non_exhaustive() } } - -/// Reactor that handles the websocket connection to the server. -/// -/// The client can send commands through the sender half of `command_rx`. -struct Reactor { - socket: WebSocket, - command_rx: mpsc::Receiver, - reload_tx: trigger::Sender, - status_tx: watch::Sender, - flows_tx: Option>, -} - -impl Reactor { - async fn new( - client: reqwest::Client, - base_url: UrlBuilder, - command_rx: mpsc::Receiver, - reload_tx: trigger::Sender, - status_tx: watch::Sender, - ) -> Result { - let websocket = client - .get(base_url.push("ws").finish()) - .upgrade() - .send() - .await? - .into_websocket() - .await? - .into(); - - Ok(Self { - socket: websocket, - command_rx, - reload_tx, - status_tx, - flows_tx: None, - }) - } - - async fn run(mut self) -> Result<(), Error> { - self.socket - .send(&ClientHello { - user_agent: USER_AGENT.into(), - app_version: CLIENT_VERSION.clone(), - protocol_version: PROTOCOL_VERSION, - }) - .await?; - - let _server_hello: ServerHello = self - .socket - .receive() - .await? - .ok_or_else(|| Error::Protocol)?; - - let _ = self.status_tx.send(Status::Connected); - - loop { - tokio::select! { - message_res = self.socket.receive() => { - let Some(message) = message_res? else { - tracing::debug!("Connection closed"); - break; - }; - self.handle_message(message).await?; - } - command_opt = self.command_rx.recv() => { - let Some(command) = command_opt else { - tracing::debug!("Command sender dropped"); - break; - }; - self.handle_command(command).await?; - } - } - } - - let _ = self.status_tx.send(Status::Disconnected); - - Ok(()) - } - - async fn handle_message(&mut self, message: ServerMessage) -> Result<(), Error> { - tracing::debug!(?message, "received"); - - match message { - ServerMessage::HotReload => { - let _ = self.reload_tx.trigger(); - } - ServerMessage::Interrupt { message_id } => { - // todo: for now we'll just send a Continue back - // eventually we want to send the interrupt to the user with a oneshot channel. - self.socket - .send(&ClientMessage::Continue { message_id }) - .await?; - } - ServerMessage::Flow { .. } => { - if let Some(flows_tx) = &mut self.flows_tx { - if let Err(_) = flows_tx.send(()).await { - // the flows receiver has been dropped. - self.flows_tx = None; - } - } - } - } - - Ok(()) - } - - async fn handle_command(&mut self, _command: Command) -> Result<(), Error> { - todo!(); - } -} - -enum Command { - // todo -} - -/// Wrapper around [`reqwest_websocket::WebSocket`] that sends and receives -/// msgpack-encoded messages. -#[derive(Debug)] -struct WebSocket { - inner: reqwest_websocket::WebSocket, -} - -impl From for WebSocket { - fn from(inner: reqwest_websocket::WebSocket) -> Self { - Self { inner } - } -} - -impl WebSocket { - async fn receive Deserialize<'de>>(&mut self) -> Result, Error> { - while let Some(message) = self.inner.try_next().await? { - match message { - Message::Binary(data) => { - let item: T = rmp_serde::from_slice(&data)?; - return Ok(Some(item)); - } - Message::Close { .. } => return Ok(None), - _ => {} - } - } - - Ok(None) - } - - async fn send(&mut self, item: &T) -> Result<(), Error> { - let data = rmp_serde::to_vec(item)?; - self.inner.send(Message::Binary(data)).await?; - Ok(()) - } -} - -#[derive(Clone, Copy, Debug, PartialEq, Eq, Default)] -pub enum Status { - #[default] - Disconnected, - Connected, -} diff --git a/skunk-api-client/src/error.rs b/skunk-api-client/src/error.rs index 4b6350f..0a3b351 100644 --- a/skunk-api-client/src/error.rs +++ b/skunk-api-client/src/error.rs @@ -8,8 +8,6 @@ pub enum Error { Websocket(#[from] reqwest_websocket::Error), Decode(#[from] rmp_serde::decode::Error), Encode(#[from] rmp_serde::encode::Error), - #[error("protocol error")] - Protocol, ApiError { status_code: StatusCode, #[source] diff --git a/skunk-api-client/src/lib.rs b/skunk-api-client/src/lib.rs index 93d621d..6a83647 100644 --- a/skunk-api-client/src/lib.rs +++ b/skunk-api-client/src/lib.rs @@ -1,12 +1,19 @@ mod client; mod error; +mod socket; mod util; pub use self::{ client::{ Client, Connection, - Status, }, error::Error, }; + +#[derive(Clone, Copy, Debug, PartialEq, Eq, Default)] +pub enum Status { + #[default] + Disconnected, + Connected, +} diff --git a/skunk-api-client/src/socket.rs b/skunk-api-client/src/socket.rs new file mode 100644 index 0000000..96d43dd --- /dev/null +++ b/skunk-api-client/src/socket.rs @@ -0,0 +1,265 @@ +use std::fmt::Debug; + +use futures_util::{ + SinkExt, + TryStreamExt, +}; +use lazy_static::lazy_static; +use reqwest_websocket::{ + Message, + RequestBuilderExt, +}; +use serde::{ + Deserialize, + Serialize, +}; +use skunk_api_protocol::{ + socket::{ + ClientHello, + ClientMessage, + ServerHello, + ServerMessage, + Version, + }, + PROTOCOL_VERSION, +}; +use skunk_util::trigger; +use tokio::sync::{ + mpsc, + watch, +}; +use url::Url; + +use crate::Status; + +pub const USER_AGENT: &'static str = std::env!("CARGO_PKG_NAME"); +lazy_static! { + pub static ref CLIENT_VERSION: Version = std::env!("CARGO_PKG_VERSION").parse().unwrap(); +} + +#[derive(Debug, thiserror::Error)] +#[error("Reactor error")] +enum Error { + #[error("Websocket disconnected")] + Disconnected, + #[error("Handshake failed")] + Handshake, + Reqwest(#[from] reqwest::Error), + Websocket(#[from] reqwest_websocket::Error), + Decode(#[from] rmp_serde::decode::Error), + Encode(#[from] rmp_serde::encode::Error), +} + +#[derive(Clone, Debug)] +pub(crate) struct ReactorHandle { + pub command_tx: mpsc::Sender, + pub reload_rx: trigger::Receiver, + pub status_rx: watch::Receiver, + pub pong_rx: trigger::Receiver, +} + +/// Reactor that handles the websocket connection to the server. +/// +/// The client can send commands through the sender half of `command_rx`. +pub(crate) struct Reactor { + client: reqwest::Client, + url: Url, + command_rx: mpsc::Receiver, + reload_tx: trigger::Sender, + status_tx: watch::Sender, + pong_tx: trigger::Sender, + flows_tx: Option>, +} + +impl Reactor { + pub fn new(client: reqwest::Client, url: Url) -> (Self, ReactorHandle) { + let (command_tx, command_rx) = mpsc::channel(16); + let (reload_tx, reload_rx) = trigger::new(); + let (status_tx, status_rx) = watch::channel(Default::default()); + let (pong_tx, pong_rx) = trigger::new(); + + let this = Self { + client, + url, + command_rx, + reload_tx, + status_tx, + pong_tx, + flows_tx: None, + }; + + let handle = ReactorHandle { + command_tx, + reload_rx, + status_rx, + pong_rx, + }; + + (this, handle) + } + + pub async fn run(mut self) { + loop { + let Ok(connection) = ReactorConnection::connect(&mut self).await + else { + // connection failed + // we should retry after some time, but since we don't have a reliable sleep, + // we'll panic for now + todo!(); + }; + + match connection.run().await { + Ok(()) => { + // ReactorConnection returns Ok(()) when the command sender has been dropped, so + // we should terminate + let _ = self.status_tx.send(Status::Disconnected); + break; + } + Err(Error::Disconnected) => { + // the websocket connection was disconnected for some reason. so, we'll try to + // reconnect todo: wait for some time + let _ = self.status_tx.send(Status::Disconnected); + } + Err(e) => { + tracing::error!("Reactor failed: {e}"); + break; + } + } + } + } +} + +struct ReactorConnection<'a> { + socket: WebSocket, + reactor: &'a mut Reactor, +} + +impl<'a> ReactorConnection<'a> { + async fn connect(reactor: &'a mut Reactor) -> Result { + let mut socket: WebSocket = reactor + .client + .get(reactor.url.clone()) + .upgrade() + .send() + .await? + .into_websocket() + .await? + .into(); + + socket + .send(&ClientHello { + user_agent: USER_AGENT.into(), + app_version: CLIENT_VERSION.clone(), + protocol_version: PROTOCOL_VERSION, + }) + .await?; + + let _server_hello: ServerHello = socket.receive().await?.ok_or_else(|| Error::Handshake)?; + + let _ = reactor.status_tx.send(Status::Connected); + + Ok(Self { socket, reactor }) + } + + async fn run(mut self) -> Result<(), Error> { + loop { + tokio::select! { + message_res = self.socket.receive() => { + let Some(message) = message_res? else { + tracing::debug!("Connection closed"); + break Err(Error::Disconnected); + }; + self.handle_message(message).await?; + } + command_opt = self.reactor.command_rx.recv() => { + let Some(command) = command_opt else { + tracing::debug!("Command sender dropped"); + break Ok(()); + }; + self.handle_command(command).await?; + } + } + } + } + + async fn handle_message(&mut self, message: ServerMessage) -> Result<(), Error> { + tracing::debug!(?message, "received"); + + match message { + ServerMessage::HotReload => { + self.reactor.reload_tx.trigger(); + } + ServerMessage::Pong => { + self.reactor.pong_tx.trigger(); + } + ServerMessage::Interrupt { message_id } => { + // todo: for now we'll just send a Continue back + // eventually we want to send the interrupt to the user with a oneshot channel. + self.socket + .send(&ClientMessage::Continue { message_id }) + .await?; + } + ServerMessage::Flow { .. } => { + if let Some(flows_tx) = &mut self.reactor.flows_tx { + if let Err(_) = flows_tx.send(()).await { + // the flows receiver has been dropped. + self.reactor.flows_tx = None; + } + } + } + } + + Ok(()) + } + + async fn handle_command(&mut self, command: Command) -> Result<(), Error> { + match command { + Command::Ping => { + self.socket.send(&ClientMessage::Ping).await?; + } + } + + Ok(()) + } +} + +#[derive(Debug)] +pub(crate) enum Command { + Ping, +} + +/// Wrapper around [`reqwest_websocket::WebSocket`] that sends and receives +/// msgpack-encoded messages. +#[derive(Debug)] +struct WebSocket { + inner: reqwest_websocket::WebSocket, +} + +impl From for WebSocket { + fn from(inner: reqwest_websocket::WebSocket) -> Self { + Self { inner } + } +} + +impl WebSocket { + async fn receive Deserialize<'de>>(&mut self) -> Result, Error> { + while let Some(message) = self.inner.try_next().await? { + match message { + Message::Binary(data) => { + let item: T = rmp_serde::from_slice(&data)?; + return Ok(Some(item)); + } + Message::Close { .. } => return Ok(None), + _ => {} + } + } + + Ok(None) + } + + async fn send(&mut self, item: &T) -> Result<(), Error> { + let data = rmp_serde::to_vec(item)?; + self.inner.send(Message::Binary(data)).await?; + Ok(()) + } +} diff --git a/skunk-api-protocol/src/socket.rs b/skunk-api-protocol/src/socket.rs index 2e049ab..9a0765d 100644 --- a/skunk-api-protocol/src/socket.rs +++ b/skunk-api-protocol/src/socket.rs @@ -35,6 +35,7 @@ pub struct ServerHello { #[derive(Clone, Debug, Serialize, Deserialize)] pub enum ServerMessage { HotReload, + Pong, // todo Interrupt { message_id: Uuid, @@ -47,6 +48,7 @@ pub enum ServerMessage { #[derive(Clone, Debug, Serialize, Deserialize)] pub enum ClientMessage { + Ping, SubscribeFlows, Start, Stop, diff --git a/skunk-cli/src/api/socket.rs b/skunk-cli/src/api/socket.rs index ba70ae4..eb146e3 100644 --- a/skunk-cli/src/api/socket.rs +++ b/skunk-cli/src/api/socket.rs @@ -129,13 +129,16 @@ impl Reactor { async fn handle_message(&mut self, message: ClientMessage) -> Result<(), Error> { match message { + ClientMessage::Ping => { + self.socket.send(&ServerMessage::Pong).await?; + } ClientMessage::SubscribeFlows => todo!(), ClientMessage::Start => todo!(), ClientMessage::Stop => todo!(), ClientMessage::Continue { .. } => todo!(), } - //Ok(()) + Ok(()) } } diff --git a/skunk-ui/Cargo.toml b/skunk-ui/Cargo.toml index 716d29b..1285aec 100644 --- a/skunk-ui/Cargo.toml +++ b/skunk-ui/Cargo.toml @@ -27,7 +27,6 @@ lipsum = "0.9.1" getrandom = { version = "0.2.15", features = ["js"] } rand = "0.8.5" url = "2.5.2" -gloo-timers = { version = "0.3.0", features = ["futures"] } [package.metadata.stylance] output_file = "../target/app.scss" diff --git a/skunk-ui/src/app/mod.rs b/skunk-ui/src/app/mod.rs index 30d4699..5c4aca4 100644 --- a/skunk-ui/src/app/mod.rs +++ b/skunk-ui/src/app/mod.rs @@ -2,9 +2,6 @@ mod flows; mod home; mod settings; -use std::time::Duration; - -use gloo_timers::future::sleep; use leptos::{ component, create_node_ref, @@ -124,11 +121,7 @@ impl Context { let (client, connection) = Client::new(api_url().expect("Could not determine API url")); // poll the connection in a separate task - leptos::spawn_local(async move { - if let Err(e) = connection.await { - tracing::error!("client connection failed: {e}"); - } - }); + leptos::spawn_local(connection); // reload page on hot-reload signal let mut reload_ui = client.reload_ui();