Skip to content

Commit

Permalink
Add support for round2 packages in dkg state
Browse files Browse the repository at this point in the history
  • Loading branch information
pool2win committed Nov 22, 2024
1 parent d6b81be commit d4ac7ce
Show file tree
Hide file tree
Showing 2 changed files with 166 additions and 0 deletions.
115 changes: 115 additions & 0 deletions src/node/protocol/dkg/state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ pub(crate) struct State {
pub in_progress: bool,
pub pub_key: Option<frost::keys::PublicKeyPackage>,
pub received_round1_packages: Round1Map,
pub received_round2_packages: Round2Map,
pub round1_secret_package: Option<frost::keys::dkg::round1::SecretPackage>,
pub round2_secret_package: Option<frost::keys::dkg::round2::SecretPackage>,
}
Expand All @@ -41,6 +42,7 @@ impl State {
in_progress: false,
pub_key: None,
received_round1_packages: Round1Map::new(),
received_round2_packages: Round2Map::new(),
round1_secret_package: None,
round2_secret_package: None,
}
Expand All @@ -66,6 +68,16 @@ pub(crate) enum StateMessage {

/// Get a received round2 secret package from state
GetRound2SecretPackage(oneshot::Sender<Option<frost::keys::dkg::round2::SecretPackage>>),

/// Get a received round2 packages from state
GetReceivedRound2Packages(oneshot::Sender<Round2Map>),

/// Add a received round2 package to state
AddRound2Package(
frost::Identifier,
frost::keys::dkg::round2::Package,
oneshot::Sender<()>,
),
}

pub(crate) struct Actor {
Expand Down Expand Up @@ -105,6 +117,13 @@ impl Actor {
StateMessage::GetRound2SecretPackage(respond_to) => {
self.get_round2_secret_package(respond_to);
}
StateMessage::GetReceivedRound2Packages(respond_to) => {
let received_round2_packages = self.state.received_round2_packages.clone();
let _ = respond_to.send(received_round2_packages);
}
StateMessage::AddRound2Package(identifier, package, respond_to) => {
self.add_round2_package(identifier, package, respond_to);
}
}
}
}
Expand Down Expand Up @@ -146,6 +165,23 @@ impl Actor {
let secret_package = self.state.round2_secret_package.clone();
let _ = respond_to.send(secret_package);
}

fn add_round2_package(
&mut self,
identifier: Identifier,
package: frost::keys::dkg::round2::Package,
respond_to: oneshot::Sender<()>,
) {
self.state
.received_round2_packages
.insert(identifier, package);
let _ = respond_to.send(());
}

fn get_received_round2_packages(&self, respond_to: oneshot::Sender<Round2Map>) {
let received_round2_packages = self.state.received_round2_packages.clone();
let _ = respond_to.send(received_round2_packages);
}
}

#[derive(Clone, Debug)]
Expand Down Expand Up @@ -230,11 +266,38 @@ impl StateHandle {
let _ = self.sender.send(message).await;
rx.await
}

/// Get a received round2 packages from state
pub async fn get_received_round2_packages(
&self,
) -> Result<Round2Map, oneshot::error::RecvError> {
let (tx, rx) = oneshot::channel();
let message = StateMessage::GetReceivedRound2Packages(tx);
let _ = self.sender.send(message).await;
rx.await
}

/// Add a received round2 package to state
pub async fn add_round2_package(
&self,
identifier: Identifier,
package: frost::keys::dkg::round2::Package,
) -> Result<(), oneshot::error::RecvError> {
let (tx, rx) = oneshot::channel();
let message = StateMessage::AddRound2Package(identifier, package, tx);
let _ = self.sender.send(message).await;
rx.await
}
}

#[cfg(test)]
mod dkg_state_tests {
use super::*;
use crate::node::protocol::message_id_generator::MessageIdGenerator;
#[mockall_double::double]
use crate::node::reliable_sender::ReliableSenderHandle;
use crate::node::{test_helpers::support::build_round2_state, MembershipHandle};
use futures::FutureExt;
use rand::thread_rng;
use std::collections::BTreeMap;
use tokio::sync::oneshot;
Expand Down Expand Up @@ -292,6 +355,58 @@ mod dkg_state_tests {
actor.add_secret_package(secret_package.clone(), tx1);
assert_eq!(actor.state.round1_secret_package, Some(secret_package));
}

#[tokio::test]
async fn test_actor_add_round2_package() {
let (_tx, rx) = mpsc::channel(1);
let mut actor = Actor::new(rx);

let membership_handle = MembershipHandle::start("localhost".to_string()).await;
for i in 1..3 {
let mut mock_reliable_sender = ReliableSenderHandle::default();
mock_reliable_sender.expect_clone().returning(|| {
let mut mock = ReliableSenderHandle::default();
mock.expect_clone().returning(ReliableSenderHandle::default);
mock.expect_send()
//.times(1)
.returning(|_| futures::future::ok(()).boxed());
mock
});
let _ = membership_handle
.add_member(format!("localhost{}", i), mock_reliable_sender)
.await;
}
let state = crate::node::state::State::new(
membership_handle,
MessageIdGenerator::new("local".to_string()),
);
let (state, round1_packages) = build_round2_state(state).await;

// Generate round2 packages
let (round2_secret, round2_packages) = frost::keys::dkg::part2(
state
.dkg_state
.get_round1_secret_package()
.await
.unwrap()
.unwrap(),
&round1_packages,
)
.unwrap();

// Add each round2 package to state
for (identifier, round2_package) in round2_packages.iter() {
let (tx, _rx) = oneshot::channel();
actor.add_round2_package(*identifier, round2_package.clone(), tx);
}

assert_eq!(actor.state.received_round2_packages.len(), 2);

let (tx, rx) = oneshot::channel();
actor.get_received_round2_packages(tx);
let received_packages = rx.await.unwrap();
assert_eq!(received_packages.len(), 2);
}
}

#[cfg(test)]
Expand Down
51 changes: 51 additions & 0 deletions src/node/test_helpers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@ pub(crate) mod support {
use crate::node::membership::MembershipHandle;
#[mockall_double::double]
use crate::node::reliable_sender::ReliableSenderHandle;
use crate::node::{self, protocol::dkg::state::Round1Map};
use frost_secp256k1 as frost;
use rand::thread_rng;

/// Builds a membership with the given number of nodes
/// Do not add the local node to the membership, therefore it loops from 1 to num
Expand All @@ -39,4 +42,52 @@ pub(crate) mod support {
}
membership_handle
}

pub async fn build_round2_state(state: node::state::State) -> (node::state::State, Round1Map) {
let rng = thread_rng();
let mut round1_packages = Round1Map::new();

// generate our round1 secret and package
let (secret_package, round1_package) = frost::keys::dkg::part1(
frost::Identifier::derive(b"node1").unwrap(),
3,
2,
rng.clone(),
)
.unwrap();
log::debug!("Secret package {:?}", secret_package);

// add our secret package to state
state
.dkg_state
.add_round1_secret_package(secret_package)
.await
.unwrap();

// Add packages for other nodes
let (_, round1_package2) = frost::keys::dkg::part1(
frost::Identifier::derive(b"node2").unwrap(),
3,
2,
rng.clone(),
)
.unwrap();
round1_packages.insert(
frost::Identifier::derive(b"node2").unwrap(),
round1_package2,
);

let (_, round1_package3) = frost::keys::dkg::part1(
frost::Identifier::derive(b"node3").unwrap(),
3,
2,
rng.clone(),
)
.unwrap();
round1_packages.insert(
frost::Identifier::derive(b"node3").unwrap(),
round1_package3,
);
(state, round1_packages)
}
}

0 comments on commit d4ac7ce

Please sign in to comment.