Skip to content

Commit

Permalink
refactor api client socket
Browse files Browse the repository at this point in the history
  • Loading branch information
jgraef committed Jul 8, 2024
1 parent 5dda01f commit 390682f
Show file tree
Hide file tree
Showing 8 changed files with 315 additions and 226 deletions.
248 changes: 35 additions & 213 deletions skunk-api-client/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,96 +12,75 @@ use std::{
use futures_util::{
Future,
FutureExt,
SinkExt,
TryStreamExt,
};
use lazy_static::lazy_static;
use reqwest_websocket::{
Message,
RequestBuilderExt,
};
use serde::{
Deserialize,
Serialize,
};
use skunk_api_protocol::{
socket::{
ClientHello,
ClientMessage,
ServerHello,
ServerMessage,
Version,
},
PROTOCOL_VERSION,
};
use skunk_util::trigger;
use tokio::sync::{
mpsc,
watch,
};
use tokio::sync::watch;
use tracing::Instrument;
use url::Url;

use super::Error;

pub const USER_AGENT: &'static str = std::env!("CARGO_PKG_NAME");
lazy_static! {
pub static ref CLIENT_VERSION: Version = std::env!("CARGO_PKG_VERSION").parse().unwrap();
}
use crate::{
socket::{
Command,
Reactor,
ReactorHandle,
},
Status,
};

#[derive(Clone, Debug)]
pub struct Client {
client: reqwest::Client,
base_url: UrlBuilder,
command_tx: mpsc::Sender<Command>,
reload_rx: trigger::Receiver,
status_rx: watch::Receiver<Status>,
reactor: ReactorHandle,
}

impl Client {
pub fn new(base_url: Url) -> (Self, Connection) {
let client = reqwest::Client::new();
let base_url = UrlBuilder { url: base_url };
let (command_tx, command_rx) = mpsc::channel(4);
let (reload_tx, reload_rx) = trigger::new();
let (status_tx, status_rx) = watch::channel(Default::default());

let (reactor, reactor_handle) =
Reactor::new(client.clone(), base_url.clone().push("ws").finish());
let span = tracing::info_span!("socket");

let connection = Connection {
inner: Box::pin({
let client = client.clone();
let base_url = base_url.clone();
let span = tracing::info_span!("connection");
async move {
let reactor =
Reactor::new(client, base_url, command_rx, reload_tx, status_tx).await?;
reactor.run().await
}
.instrument(span)
}),
inner: Box::pin(reactor.run().instrument(span)),
};

let client = Self {
client,
base_url,
command_tx,
reload_rx,
status_rx,
reactor: reactor_handle,
};

(client, connection)
}

async fn send_command(&mut self, command: Command) {
self.reactor
.command_tx
.send(command)
.await
.expect("Reactor died");
}

pub fn reload_ui(&self) -> trigger::Receiver {
self.reload_rx.clone()
self.reactor.reload_rx.clone()
}

pub fn status(&self) -> watch::Receiver<Status> {
self.status_rx.clone()
self.reactor.status_rx.clone()
}

pub async fn ping(&mut self) {
let mut pong_rx = self.reactor.pong_rx.clone();
self.send_command(Command::Ping).await;
pong_rx.triggered().await;
}
}

#[derive(Clone, Debug)]
struct UrlBuilder {
pub(crate) struct UrlBuilder {
url: Url,
}

Expand All @@ -120,11 +99,11 @@ impl UrlBuilder {
///
/// This must be polled to drive the connection for a [`Client`].
pub struct Connection {
inner: Pin<Box<dyn Future<Output = Result<(), Error>>>>,
inner: Pin<Box<dyn Future<Output = ()>>>,
}

impl Future for Connection {
type Output = Result<(), Error>;
type Output = ();

fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
self.inner.poll_unpin(cx)
Expand All @@ -136,160 +115,3 @@ impl Debug for Connection {
f.debug_struct("Connection").finish_non_exhaustive()
}
}

/// Reactor that handles the websocket connection to the server.
///
/// The client can send commands through the sender half of `command_rx`.
struct Reactor {
socket: WebSocket,
command_rx: mpsc::Receiver<Command>,
reload_tx: trigger::Sender,
status_tx: watch::Sender<Status>,
flows_tx: Option<mpsc::Sender<()>>,
}

impl Reactor {
async fn new(
client: reqwest::Client,
base_url: UrlBuilder,
command_rx: mpsc::Receiver<Command>,
reload_tx: trigger::Sender,
status_tx: watch::Sender<Status>,
) -> Result<Self, Error> {
let websocket = client
.get(base_url.push("ws").finish())
.upgrade()
.send()
.await?
.into_websocket()
.await?
.into();

Ok(Self {
socket: websocket,
command_rx,
reload_tx,
status_tx,
flows_tx: None,
})
}

async fn run(mut self) -> Result<(), Error> {
self.socket
.send(&ClientHello {
user_agent: USER_AGENT.into(),
app_version: CLIENT_VERSION.clone(),
protocol_version: PROTOCOL_VERSION,
})
.await?;

let _server_hello: ServerHello = self
.socket
.receive()
.await?
.ok_or_else(|| Error::Protocol)?;

let _ = self.status_tx.send(Status::Connected);

loop {
tokio::select! {
message_res = self.socket.receive() => {
let Some(message) = message_res? else {
tracing::debug!("Connection closed");
break;
};
self.handle_message(message).await?;
}
command_opt = self.command_rx.recv() => {
let Some(command) = command_opt else {
tracing::debug!("Command sender dropped");
break;
};
self.handle_command(command).await?;
}
}
}

let _ = self.status_tx.send(Status::Disconnected);

Ok(())
}

async fn handle_message(&mut self, message: ServerMessage) -> Result<(), Error> {
tracing::debug!(?message, "received");

match message {
ServerMessage::HotReload => {
let _ = self.reload_tx.trigger();
}
ServerMessage::Interrupt { message_id } => {
// todo: for now we'll just send a Continue back
// eventually we want to send the interrupt to the user with a oneshot channel.
self.socket
.send(&ClientMessage::Continue { message_id })
.await?;
}
ServerMessage::Flow { .. } => {
if let Some(flows_tx) = &mut self.flows_tx {
if let Err(_) = flows_tx.send(()).await {
// the flows receiver has been dropped.
self.flows_tx = None;
}
}
}
}

Ok(())
}

async fn handle_command(&mut self, _command: Command) -> Result<(), Error> {
todo!();
}
}

enum Command {
// todo
}

/// Wrapper around [`reqwest_websocket::WebSocket`] that sends and receives
/// msgpack-encoded messages.
#[derive(Debug)]
struct WebSocket {
inner: reqwest_websocket::WebSocket,
}

impl From<reqwest_websocket::WebSocket> for WebSocket {
fn from(inner: reqwest_websocket::WebSocket) -> Self {
Self { inner }
}
}

impl WebSocket {
async fn receive<T: for<'de> Deserialize<'de>>(&mut self) -> Result<Option<T>, Error> {
while let Some(message) = self.inner.try_next().await? {
match message {
Message::Binary(data) => {
let item: T = rmp_serde::from_slice(&data)?;
return Ok(Some(item));
}
Message::Close { .. } => return Ok(None),
_ => {}
}
}

Ok(None)
}

async fn send<T: Serialize>(&mut self, item: &T) -> Result<(), Error> {
let data = rmp_serde::to_vec(item)?;
self.inner.send(Message::Binary(data)).await?;
Ok(())
}
}

#[derive(Clone, Copy, Debug, PartialEq, Eq, Default)]
pub enum Status {
#[default]
Disconnected,
Connected,
}
2 changes: 0 additions & 2 deletions skunk-api-client/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,6 @@ pub enum Error {
Websocket(#[from] reqwest_websocket::Error),
Decode(#[from] rmp_serde::decode::Error),
Encode(#[from] rmp_serde::encode::Error),
#[error("protocol error")]
Protocol,
ApiError {
status_code: StatusCode,
#[source]
Expand Down
9 changes: 8 additions & 1 deletion skunk-api-client/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,19 @@
mod client;
mod error;
mod socket;
mod util;

pub use self::{
client::{
Client,
Connection,
Status,
},
error::Error,
};

#[derive(Clone, Copy, Debug, PartialEq, Eq, Default)]
pub enum Status {
#[default]
Disconnected,
Connected,
}
Loading

0 comments on commit 390682f

Please sign in to comment.