From d4ac7ced110f61a55e2440b46fcc49a8d16f2321 Mon Sep 17 00:00:00 2001 From: pool2win Date: Fri, 22 Nov 2024 11:57:03 +0100 Subject: [PATCH] Add support for round2 packages in dkg state --- src/node/protocol/dkg/state.rs | 115 +++++++++++++++++++++++++++++++++ src/node/test_helpers.rs | 51 +++++++++++++++ 2 files changed, 166 insertions(+) diff --git a/src/node/protocol/dkg/state.rs b/src/node/protocol/dkg/state.rs index 9a3220b..d425714 100644 --- a/src/node/protocol/dkg/state.rs +++ b/src/node/protocol/dkg/state.rs @@ -27,6 +27,7 @@ pub(crate) struct State { pub in_progress: bool, pub pub_key: Option, pub received_round1_packages: Round1Map, + pub received_round2_packages: Round2Map, pub round1_secret_package: Option, pub round2_secret_package: Option, } @@ -41,6 +42,7 @@ impl State { in_progress: false, pub_key: None, received_round1_packages: Round1Map::new(), + received_round2_packages: Round2Map::new(), round1_secret_package: None, round2_secret_package: None, } @@ -66,6 +68,16 @@ pub(crate) enum StateMessage { /// Get a received round2 secret package from state GetRound2SecretPackage(oneshot::Sender>), + + /// Get a received round2 packages from state + GetReceivedRound2Packages(oneshot::Sender), + + /// Add a received round2 package to state + AddRound2Package( + frost::Identifier, + frost::keys::dkg::round2::Package, + oneshot::Sender<()>, + ), } pub(crate) struct Actor { @@ -105,6 +117,13 @@ impl Actor { StateMessage::GetRound2SecretPackage(respond_to) => { self.get_round2_secret_package(respond_to); } + StateMessage::GetReceivedRound2Packages(respond_to) => { + let received_round2_packages = self.state.received_round2_packages.clone(); + let _ = respond_to.send(received_round2_packages); + } + StateMessage::AddRound2Package(identifier, package, respond_to) => { + self.add_round2_package(identifier, package, respond_to); + } } } } @@ -146,6 +165,23 @@ impl Actor { let secret_package = self.state.round2_secret_package.clone(); let _ = respond_to.send(secret_package); } + + fn add_round2_package( + &mut self, + identifier: Identifier, + package: frost::keys::dkg::round2::Package, + respond_to: oneshot::Sender<()>, + ) { + self.state + .received_round2_packages + .insert(identifier, package); + let _ = respond_to.send(()); + } + + fn get_received_round2_packages(&self, respond_to: oneshot::Sender) { + let received_round2_packages = self.state.received_round2_packages.clone(); + let _ = respond_to.send(received_round2_packages); + } } #[derive(Clone, Debug)] @@ -230,11 +266,38 @@ impl StateHandle { let _ = self.sender.send(message).await; rx.await } + + /// Get a received round2 packages from state + pub async fn get_received_round2_packages( + &self, + ) -> Result { + let (tx, rx) = oneshot::channel(); + let message = StateMessage::GetReceivedRound2Packages(tx); + let _ = self.sender.send(message).await; + rx.await + } + + /// Add a received round2 package to state + pub async fn add_round2_package( + &self, + identifier: Identifier, + package: frost::keys::dkg::round2::Package, + ) -> Result<(), oneshot::error::RecvError> { + let (tx, rx) = oneshot::channel(); + let message = StateMessage::AddRound2Package(identifier, package, tx); + let _ = self.sender.send(message).await; + rx.await + } } #[cfg(test)] mod dkg_state_tests { use super::*; + use crate::node::protocol::message_id_generator::MessageIdGenerator; + #[mockall_double::double] + use crate::node::reliable_sender::ReliableSenderHandle; + use crate::node::{test_helpers::support::build_round2_state, MembershipHandle}; + use futures::FutureExt; use rand::thread_rng; use std::collections::BTreeMap; use tokio::sync::oneshot; @@ -292,6 +355,58 @@ mod dkg_state_tests { actor.add_secret_package(secret_package.clone(), tx1); assert_eq!(actor.state.round1_secret_package, Some(secret_package)); } + + #[tokio::test] + async fn test_actor_add_round2_package() { + let (_tx, rx) = mpsc::channel(1); + let mut actor = Actor::new(rx); + + let membership_handle = MembershipHandle::start("localhost".to_string()).await; + for i in 1..3 { + let mut mock_reliable_sender = ReliableSenderHandle::default(); + mock_reliable_sender.expect_clone().returning(|| { + let mut mock = ReliableSenderHandle::default(); + mock.expect_clone().returning(ReliableSenderHandle::default); + mock.expect_send() + //.times(1) + .returning(|_| futures::future::ok(()).boxed()); + mock + }); + let _ = membership_handle + .add_member(format!("localhost{}", i), mock_reliable_sender) + .await; + } + let state = crate::node::state::State::new( + membership_handle, + MessageIdGenerator::new("local".to_string()), + ); + let (state, round1_packages) = build_round2_state(state).await; + + // Generate round2 packages + let (round2_secret, round2_packages) = frost::keys::dkg::part2( + state + .dkg_state + .get_round1_secret_package() + .await + .unwrap() + .unwrap(), + &round1_packages, + ) + .unwrap(); + + // Add each round2 package to state + for (identifier, round2_package) in round2_packages.iter() { + let (tx, _rx) = oneshot::channel(); + actor.add_round2_package(*identifier, round2_package.clone(), tx); + } + + assert_eq!(actor.state.received_round2_packages.len(), 2); + + let (tx, rx) = oneshot::channel(); + actor.get_received_round2_packages(tx); + let received_packages = rx.await.unwrap(); + assert_eq!(received_packages.len(), 2); + } } #[cfg(test)] diff --git a/src/node/test_helpers.rs b/src/node/test_helpers.rs index 6379efd..7d3bbb3 100644 --- a/src/node/test_helpers.rs +++ b/src/node/test_helpers.rs @@ -21,6 +21,9 @@ pub(crate) mod support { use crate::node::membership::MembershipHandle; #[mockall_double::double] use crate::node::reliable_sender::ReliableSenderHandle; + use crate::node::{self, protocol::dkg::state::Round1Map}; + use frost_secp256k1 as frost; + use rand::thread_rng; /// Builds a membership with the given number of nodes /// Do not add the local node to the membership, therefore it loops from 1 to num @@ -39,4 +42,52 @@ pub(crate) mod support { } membership_handle } + + pub async fn build_round2_state(state: node::state::State) -> (node::state::State, Round1Map) { + let rng = thread_rng(); + let mut round1_packages = Round1Map::new(); + + // generate our round1 secret and package + let (secret_package, round1_package) = frost::keys::dkg::part1( + frost::Identifier::derive(b"node1").unwrap(), + 3, + 2, + rng.clone(), + ) + .unwrap(); + log::debug!("Secret package {:?}", secret_package); + + // add our secret package to state + state + .dkg_state + .add_round1_secret_package(secret_package) + .await + .unwrap(); + + // Add packages for other nodes + let (_, round1_package2) = frost::keys::dkg::part1( + frost::Identifier::derive(b"node2").unwrap(), + 3, + 2, + rng.clone(), + ) + .unwrap(); + round1_packages.insert( + frost::Identifier::derive(b"node2").unwrap(), + round1_package2, + ); + + let (_, round1_package3) = frost::keys::dkg::part1( + frost::Identifier::derive(b"node3").unwrap(), + 3, + 2, + rng.clone(), + ) + .unwrap(); + round1_packages.insert( + frost::Identifier::derive(b"node3").unwrap(), + round1_package3, + ); + (state, round1_packages) + } }