Skip to content

Commit

Permalink
Use channel to signal round end
Browse files Browse the repository at this point in the history
  • Loading branch information
pool2win committed Nov 28, 2024
1 parent 3d84c12 commit 57e3775
Show file tree
Hide file tree
Showing 5 changed files with 38 additions and 14 deletions.
13 changes: 11 additions & 2 deletions src/node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -126,10 +126,19 @@ impl Node {
log::debug!("Starting... {}", self.bind_address);
let node_id = self.get_node_id().clone();
let state = self.state.clone();
let (round_tx, round_rx) = mpsc::channel::<()>(1);
self.state.round_tx = Some(round_tx.clone());
let echo_broadcast_handle = self.echo_broadcast_handle.clone();
// let interval = tokio::time::interval(tokio::time::Duration::from_secs(15));
tokio::spawn(async move {
dkg::trigger::run_dkg_trigger(15000, node_id, state, echo_broadcast_handle, None).await;
dkg::trigger::run_dkg_trigger(
15000,
node_id,
state,
echo_broadcast_handle,
None,
round_rx,
)
.await;
});

if self.connect_to_seeds().await.is_err() {
Expand Down
6 changes: 5 additions & 1 deletion src/node/protocol/dkg/round_one.rs
Original file line number Diff line number Diff line change
Expand Up @@ -140,11 +140,15 @@ impl Service<Message> for Package {
message
);
let identifier = frost::Identifier::derive(from_sender_id.as_bytes()).unwrap();
state
let finished = state
.dkg_state
.add_round1_package(identifier, message)
.await
.unwrap();
if finished {
log::debug!("Round one finished, sending signal");
let _ = state.round_tx.unwrap().send(()).await;
}
Ok(None)
}
_ => {
Expand Down
7 changes: 6 additions & 1 deletion src/node/protocol/dkg/round_two.rs
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ impl Service<Message> for Package {
})) => {
match build_round2_packages(sender_id, state.clone()).await {
Ok((round2_secret_package, round2_packages)) => {
log::debug!("Building round2 packages succeeded");
// Store the round2 secret package
if let Err(e) = state
.dkg_state
Expand Down Expand Up @@ -155,11 +156,15 @@ impl Service<Message> for Package {
})) => {
// Received round2 message and save it in state
let identifier = frost::Identifier::derive(from_sender_id.as_bytes()).unwrap();
state
let finished = state
.dkg_state
.add_round2_package(identifier, message)
.await
.unwrap();
if finished {
log::debug!("Round two finished, sending signal");
let _ = state.round_tx.unwrap().send(()).await;
}
Ok(None)
}
_ => {
Expand Down
22 changes: 13 additions & 9 deletions src/node/protocol/dkg/trigger.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ use crate::node::reliable_sender::ReliableSenderHandle;
use crate::node::State;
use crate::node::{echo_broadcast::service::EchoBroadcast, protocol::Message};
use frost_secp256k1 as frost;
use frost_secp256k1::keys::{KeyPackage, PublicKeyPackage};
use tokio::sync::mpsc;
use tokio::time::{Duration, Instant};
use tower::{BoxError, ServiceExt};

Expand All @@ -39,6 +39,7 @@ pub async fn run_dkg_trigger(
mut state: State,
echo_broadcast_handle: EchoBroadcastHandle,
reliable_sender_handle: Option<ReliableSenderHandle>,
round_rx: mpsc::Receiver<()>,
) {
let period = Duration::from_millis(duration_millis);
let start = Instant::now() + period;
Expand All @@ -52,6 +53,7 @@ pub async fn run_dkg_trigger(
state.clone(),
echo_broadcast_handle.clone(),
reliable_sender_handle.clone(),
round_rx,
)
.await;

Expand Down Expand Up @@ -109,6 +111,7 @@ pub(crate) async fn trigger_dkg(
state: State,
echo_broadcast_handle: EchoBroadcastHandle,
reliable_sender_handle: Option<ReliableSenderHandle>,
mut round_rx: mpsc::Receiver<()>,
) -> Result<(), BoxError> {
let protocol_service: Protocol =
Protocol::new(node_id.clone(), state.clone(), reliable_sender_handle);
Expand All @@ -122,17 +125,15 @@ pub(crate) async fn trigger_dkg(

let round2_future = build_round2_future(node_id.clone(), protocol_service.clone());

let sleep_future = tokio::time::sleep(Duration::from_secs(5));

// TODO Improve this to allow round1 to finish as soon as all other parties have sent their round1 message
// This will mean moving the timeout into round1 service

// Wait for round1 to finish, give it 5 seconds
let (_, round1_result) = tokio::join!(sleep_future, round1_future);
if round1_result.is_err() {
if round1_future.await.is_err() {
log::error!("Error running round 1");
round1_result?
return Err("Error running round 1".into());
}
round_rx.recv().await.unwrap();
log::info!("Round 1 finished");

log::debug!(
Expand All @@ -144,12 +145,12 @@ pub(crate) async fn trigger_dkg(
.unwrap()
);

let sleep_future = tokio::time::sleep(Duration::from_secs(5));
// start round2
let (_, round2_result) = tokio::join!(sleep_future, round2_future);
if round2_result.is_err() {
if round2_future.await.is_err() {
log::error!("Error running round 2");
return Err("Error running round 2".into());
}
round_rx.recv().await.unwrap();
log::info!("Round 2 finished");

// Get packages required to run part3
Expand Down Expand Up @@ -246,6 +247,8 @@ mod dkg_trigger_tests {
mock
});

let (round_tx, round_rx) = mpsc::channel::<()>(1);

// Wait for just over one interval to ensure we get at least one trigger
let result: Result<(), time::error::Elapsed> = timeout(
Duration::from_millis(10),
Expand All @@ -255,6 +258,7 @@ mod dkg_trigger_tests {
state,
mock_echo_broadcast_handle,
Some(mock_reliable_sender_handle),
round_rx,
),
)
.await;
Expand Down
4 changes: 3 additions & 1 deletion src/node/state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,14 @@
use crate::node::membership::MembershipHandle;
use crate::node::protocol::dkg;
use crate::node::protocol::message_id_generator::MessageIdGenerator;

use tokio::sync::mpsc;
/// Handlers to query/update node state
#[derive(Clone)]
pub(crate) struct State {
pub(crate) membership_handle: MembershipHandle,
pub(crate) message_id_generator: MessageIdGenerator,
pub(crate) dkg_state: dkg::state::StateHandle,
pub(crate) round_tx: Option<mpsc::Sender<()>>,
}

impl State {
Expand All @@ -38,6 +39,7 @@ impl State {
membership_handle,
message_id_generator,
dkg_state: dkg::state::StateHandle::new(Some(expected_members)),
round_tx: None,
}
}

Expand Down

0 comments on commit 57e3775

Please sign in to comment.