Skip to content

Commit

Permalink
cli: refactor file watcher
Browse files Browse the repository at this point in the history
  • Loading branch information
jgraef committed Jul 7, 2024
1 parent 26c26dd commit bf69284
Show file tree
Hide file tree
Showing 8 changed files with 152 additions and 46 deletions.
1 change: 1 addition & 0 deletions skunk-cli/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ mod args;
mod config;
mod proxy;
mod serve_ui;
mod util;

use clap::Parser;
use color_eyre::eyre::Error;
Expand Down
45 changes: 12 additions & 33 deletions skunk-cli/src/serve_ui.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@ use std::{
Path,
PathBuf,
},
sync::Arc,
task::{
Context,
Poll,
},
time::Duration,
};

use axum::{
Expand All @@ -18,11 +18,6 @@ use axum::{
},
http::Request,
};
use notify::{
RecommendedWatcher,
RecursiveMode,
Watcher as _,
};
use tokio::sync::watch;
use tower_http::services::{
ServeDir,
Expand All @@ -33,33 +28,35 @@ use tower_service::Service;
use crate::{
api,
config::Config,
util::watch::watch_modified,
};

#[derive(Clone, Debug)]
pub struct ServeUi {
inner: ServeDir<ServeFile>,
watcher: Option<Arc<RecommendedWatcher>>,
}

impl ServeUi {
pub fn new(path: impl AsRef<Path>, hot_reload: Option<api::HotReload>) -> Self {
let path = path.as_ref();

let watcher = if let Some(hot_reload) = hot_reload {
Some(Arc::new(
setup_hot_reload(path, hot_reload).expect("Failed to setup hot-reload"),
))
if let Some(hot_reload) = hot_reload {
let mut watch = watch_modified(path, Duration::from_secs(2))
.expect("Failed to watch for file changes");
tokio::spawn(async move {
while let Ok(()) = watch.wait().await {
tracing::info!("UI modified. Triggering reload");
hot_reload.trigger();
}
});
}
else {
None
};

let inner = ServeDir::new(path).fallback(ServeFile::new_with_mime(
path.join("index.html"),
&mime::TEXT_HTML_UTF_8,
));

Self { inner, watcher }
Self { inner }
}

pub fn from_config(config: &Config, api_builder: &mut api::Builder) -> Self {
Expand Down Expand Up @@ -102,24 +99,6 @@ impl Service<Request<Body>> for ServeUi {
}
}

fn setup_hot_reload(
path: &Path,
hot_reload: api::HotReload,
) -> Result<RecommendedWatcher, notify::Error> {
// note: the watcher is shutdown when it's dropped.
let mut watcher = notify::recommended_watcher(move |result: notify::Result<notify::Event>| {
if let Ok(event) = result {
if event.kind.is_modify() {
//tracing::debug!("UI modified. Sending reload notification.");
hot_reload.trigger();
}
}
})?;
watcher.watch(path, RecursiveMode::Recursive)?;

Ok(watcher)
}

async fn reload_handler(mut socket: WebSocket, mut reload_rx: watch::Receiver<()>) {
let reload_message = "{\"reload\": true}";

Expand Down
1 change: 1 addition & 0 deletions skunk-cli/src/util/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
pub mod watch;
90 changes: 90 additions & 0 deletions skunk-cli/src/util/watch.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
use std::{
path::Path,
time::Duration,
};

pub use notify::{
Error,
Event,
RecursiveMode,
};
use notify::{
RecommendedWatcher,
Watcher,
};
use tokio::sync::mpsc;

/// Async wrapper around a [`notify`] filesystem watcher.
///
/// Events reported by the backend are forwarded through a channel so they
/// can be awaited from async code via [`FileWatcher::next_event`].
/// The underlying watcher is shut down when this struct is dropped.
#[derive(Debug)]
pub struct FileWatcher {
    // Keeps the backend watcher alive; dropping it stops event delivery.
    watcher: RecommendedWatcher,
    // Receiving end of the channel fed by the watcher's callback thread.
    event_rx: mpsc::Receiver<Event>,
}

impl FileWatcher {
    /// Creates a watcher that forwards filesystem events into an async channel.
    ///
    /// # Errors
    ///
    /// Returns an error if the platform watcher backend cannot be created.
    pub fn new() -> notify::Result<Self> {
        let (tx, rx) = mpsc::channel(1);

        // note: the watcher is shutdown when it's dropped.
        // Backend errors are silently discarded here: only successful events
        // are forwarded. A failed send means the receiver was dropped, which
        // is also fine to ignore.
        let watcher = notify::recommended_watcher(move |result: Result<Event, Error>| {
            if let Ok(event) = result {
                let _ = tx.blocking_send(event);
            }
        })?;

        Ok(Self {
            watcher,
            event_rx: rx,
        })
    }

    /// Starts watching `path` with the given recursion mode.
    ///
    /// # Errors
    ///
    /// Returns an error if the path cannot be watched (e.g. it doesn't exist).
    pub fn watch(
        &mut self,
        path: impl AsRef<Path>,
        recursive_mode: RecursiveMode,
    ) -> Result<(), Error> {
        self.watcher.watch(path.as_ref(), recursive_mode)?;
        Ok(())
    }

    /// Waits for the next filesystem event.
    ///
    /// # Errors
    ///
    /// Returns [`Closed`] once the watcher's channel has shut down.
    pub async fn next_event(&mut self) -> Result<Event, Closed> {
        match self.event_rx.recv().await {
            Some(event) => Ok(event),
            None => Err(Closed),
        }
    }

    /// Waits until a modification event arrives, discarding any other events.
    ///
    /// # Errors
    ///
    /// Returns [`Closed`] once the watcher's channel has shut down.
    pub async fn modified(&mut self) -> Result<(), Closed> {
        while !self.next_event().await?.kind.is_modify() {}
        Ok(())
    }
}

/// Debounced modification watcher.
///
/// Wraps a [`FileWatcher`] so that a burst of modification events within
/// the `debounce` window is reported as a single change by [`WatchModified::wait`].
#[derive(Debug)]
pub struct WatchModified {
    // Source of raw filesystem events.
    watcher: FileWatcher,
    // Quiet period that must elapse after the last modification before
    // `wait` reports a change.
    debounce: Duration,
}

impl WatchModified {
    /// Wraps `watcher` so that bursts of modifications within `debounce`
    /// collapse into a single reported change.
    pub fn new(watcher: FileWatcher, debounce: Duration) -> Result<Self, Error> {
        Ok(Self { watcher, debounce })
    }

    /// Waits for a (debounced) modification.
    ///
    /// Blocks until a modification event arrives, then keeps absorbing
    /// further modifications until `debounce` elapses without one.
    ///
    /// # Errors
    ///
    /// Returns [`Closed`] when the underlying watcher has shut down.
    pub async fn wait(&mut self) -> Result<(), Closed> {
        // The first modification opens the debounce window.
        self.watcher.modified().await?;

        loop {
            let next = tokio::time::timeout(self.debounce, self.watcher.modified()).await;
            match next {
                // Another modification inside the window: restart the timer.
                Ok(Ok(())) => continue,
                Ok(Err(Closed)) => return Err(Closed),
                // Quiet period elapsed: report the change.
                Err(_elapsed) => return Ok(()),
            }
        }
    }
}

/// Convenience constructor: recursively watches `path` and debounces
/// modification events by `debounce`.
///
/// # Errors
///
/// Returns an error if the watcher cannot be created or `path` cannot
/// be watched.
pub fn watch_modified(path: impl AsRef<Path>, debounce: Duration) -> Result<WatchModified, Error> {
    let mut file_watcher = FileWatcher::new()?;
    file_watcher.watch(path, RecursiveMode::Recursive)?;
    WatchModified::new(file_watcher, debounce)
}

/// Error returned when the watcher's event channel has shut down and no
/// further events will be delivered.
///
/// Implements [`std::error::Error`] so callers can propagate it with `?`
/// into boxed/dynamic error types instead of handling it manually.
#[derive(Debug)]
pub struct Closed;

impl std::fmt::Display for Closed {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "file watcher closed")
    }
}

impl std::error::Error for Closed {}
6 changes: 5 additions & 1 deletion skunk-flows-store/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ fn main() {

std::fs::write(&db_file, b"").unwrap();

Command::new("sqlx")
let exit_status = Command::new("sqlx")
.arg("migrate")
.arg("run")
.arg("--source")
Expand All @@ -31,5 +31,9 @@ fn main() {
.wait()
.unwrap();

if !exit_status.success() {
panic!("sqlx failed: {exit_status}");
}

println!("cargo::rustc-env=DATABASE_URL={}", db_url,);
}
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
DROP TABLE artifact_blob;
DROP TABLE artifact;
DROP TABLE message;
DROP TABLE flow;
Expand Down
21 changes: 15 additions & 6 deletions skunk-flows-store/migrations/20240707001024_initial.up.sql
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@ CREATE TABLE flow (
destination_address TEXT NOT NULL,
destination_port INT NOT NULL,
protocol SMALLINT NOT NULL,
timestamp DATETIME NOT NULL
timestamp DATETIME NOT NULL,
metadata JSONB
);

CREATE TABLE message (
Expand All @@ -17,20 +18,28 @@ CREATE TABLE message (
kind TINYINT NOT NULL,
timestamp DATETIME NOT NULL,
data JSONB NOT NULL,
metadata JSONB,

FOREIGN KEY(flow_id) REFERENCES flow(flow_id)
);

CREATE TABLE artifact (
artifact_id UUID NOT NULL PRIMARY KEY,
from_message UUID,
from_flow UUID,
message_id UUID,
flow_id UUID,
mime_type TEXT,
file_name TEXT,
timestamp DATETIME NOT NULL,
hash BLOB NOT NULL,

FOREIGN KEY(message_id) REFERENCES flow(message_id),
FOREIGN KEY(flow_id) REFERENCES flow(flow_id),
FOREIGN KEY(hash) REFERENCES artifact_blob(hash)
);

CREATE TABLE artifact_blob (
hash BLOB NOT NULL PRIMARY KEY,
size INT NOT NULL,
data BLOB NOT NULL,

FOREIGN KEY(from_message) REFERENCES flow(message_id),
FOREIGN KEY(from_flow) REFERENCES flow(flow_id)
metadata JSONB
);
33 changes: 27 additions & 6 deletions skunk-flows-store/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -93,12 +93,33 @@ impl<'a> Transaction<'a> {

pub async fn create_flow(
&mut self,
_flow_id: Uuid,
_destination_address: &str,
_destination_port: u16,
_protocol: u16,
_timestamp: DateTime<FixedOffset>,
flow_id: Uuid,
destination_address: &str,
destination_port: u16,
protocol: u16,
timestamp: DateTime<FixedOffset>,
metadata: Metadata,
) -> Result<(), Error> {
todo!();
let metadata = serde_json::to_value(metadata)?;
sqlx::query!(
r#"
INSERT INTO flow (flow_id, destination_address, destination_port, protocol, timestamp, metadata)
VALUES (?, ?, ?, ?, ?, ?)
"#,
flow_id,
destination_address,
destination_port,
protocol,
timestamp,
metadata,
)
.execute(&mut *self.transaction)
.await?;
Ok(())
}
}

/// Metadata attached to stored flows/messages, serialized to JSON before
/// being written to the store's `metadata` columns.
///
/// NOTE(review): currently an empty placeholder — fields are still to be
/// defined.
#[derive(Debug, Serialize, Deserialize)]
pub struct Metadata {
    // todo
}

0 comments on commit bf69284

Please sign in to comment.