Skip to content

Commit

Permalink
[network] Change exchange_identity API to take &mut socket
Browse files Browse the repository at this point in the history
There's no need to take owned socket and return it back in a Result
type. This was useful in the past when this socket was a yamux stream,
but is no longer needed.

Closes: aptos-labs#3298
Approved by: davidiw
  • Loading branch information
bothra90 authored and bors-libra committed Apr 7, 2020
1 parent 945e510 commit cf94351
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 24 deletions.
4 changes: 2 additions & 2 deletions network/src/peer_manager/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,8 @@ pub fn build_test_transport(
) -> BoxedTransport<(Identity, MemorySocket), impl ::std::error::Error + Sync + Send + 'static> {
let memory_transport = MemoryTransport::default();
memory_transport
.and_then(move |socket, _origin| async move {
let (identity, socket) = exchange_identity(&own_identity, socket).await?;
.and_then(move |mut socket, _origin| async move {
let identity = exchange_identity(&own_identity, &mut socket).await?;
Ok((identity, socket))
})
.boxed()
Expand Down
17 changes: 7 additions & 10 deletions network/src/protocols/identity.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,7 @@ impl Identity {
}

/// The Identity exchange protocol
pub async fn exchange_identity<T>(
own_identity: &Identity,
mut socket: T,
) -> io::Result<(Identity, T)>
pub async fn exchange_identity<T>(own_identity: &Identity, socket: &mut T) -> io::Result<Identity>
where
T: AsyncRead + AsyncWrite + Unpin,
{
Expand All @@ -58,19 +55,19 @@ where
format!("Failed to serialize identity msg: {}", e),
)
})?;
write_u16frame(&mut socket, &msg).await?;
write_u16frame(socket, &msg).await?;
socket.flush().await?;

// Read an IdentityMsg from the Remote
let mut response = BytesMut::new();
read_u16frame(&mut socket, &mut response).await?;
read_u16frame(socket, &mut response).await?;
let identity = lcs::from_bytes(&response).map_err(|e| {
io::Error::new(
io::ErrorKind::InvalidData,
format!("Failed to parse identity msg: {}", e),
)
})?;
Ok((identity, socket))
Ok(identity)
}

#[cfg(test)]
Expand All @@ -89,7 +86,7 @@ mod tests {

#[test]
fn simple_identify() {
let (outbound, inbound) = build_test_connection();
let (mut outbound, mut inbound) = build_test_connection();
let server_identity = Identity::new(
PeerId::random(),
vec![
Expand All @@ -105,15 +102,15 @@ mod tests {
let client_identity_config = client_identity.clone();

let server = async move {
let (identity, _connection) = exchange_identity(&server_identity_config, inbound)
let identity = exchange_identity(&server_identity_config, &mut inbound)
.await
.expect("Identity exchange fails");

assert_eq!(identity, client_identity);
};

let client = async move {
let (identity, _connection) = exchange_identity(&client_identity_config, outbound)
let identity = exchange_identity(&client_identity_config, &mut outbound)
.await
.expect("Identity exchange fails");

Expand Down
24 changes: 12 additions & 12 deletions network/src/transport.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,8 @@ pub fn build_memory_noise_transport(
Err(io::Error::new(io::ErrorKind::Other, "Not a trusted peer"))
}
})
.and_then(move |(peer_id, socket), _origin| async move {
let (identity, socket) = exchange_identity(&own_identity, socket).await?;
.and_then(move |(peer_id, mut socket), _origin| async move {
let identity = exchange_identity(&own_identity, &mut socket).await?;
match_peer_id(identity, peer_id).and_then(|identity| Ok((identity, socket)))
})
.with_timeout(TRANSPORT_TIMEOUT)
Expand Down Expand Up @@ -121,8 +121,8 @@ pub fn build_unauthenticated_memory_noise_transport(
Ok((peer_id, socket))
}
})
.and_then(move |(peer_id, socket), _origin| async move {
let (identity, socket) = exchange_identity(&own_identity, socket).await?;
.and_then(move |(peer_id, mut socket), _origin| async move {
let identity = exchange_identity(&own_identity, &mut socket).await?;
match_peer_id(identity, peer_id).and_then(|identity| Ok((identity, socket)))
})
.with_timeout(TRANSPORT_TIMEOUT)
Expand All @@ -134,8 +134,8 @@ pub fn build_memory_transport(
) -> boxed::BoxedTransport<(Identity, impl TSocket), impl ::std::error::Error> {
let memory_transport = memory::MemoryTransport::default();
memory_transport
.and_then(move |socket, _origin| async move {
Ok(exchange_identity(&own_identity, socket).await?)
.and_then(move |mut socket, _origin| async move {
Ok((exchange_identity(&own_identity, &mut socket).await?, socket))
})
.with_timeout(TRANSPORT_TIMEOUT)
.boxed()
Expand Down Expand Up @@ -164,8 +164,8 @@ pub fn build_tcp_noise_transport(
Err(io::Error::new(io::ErrorKind::Other, "Not a trusted peer"))
}
})
.and_then(move |(peer_id, socket), _origin| async move {
let (identity, socket) = exchange_identity(&own_identity, socket).await?;
.and_then(move |(peer_id, mut socket), _origin| async move {
let identity = exchange_identity(&own_identity, &mut socket).await?;
match_peer_id(identity, peer_id).and_then(|identity| Ok((identity, socket)))
})
.with_timeout(TRANSPORT_TIMEOUT)
Expand Down Expand Up @@ -195,8 +195,8 @@ pub fn build_unauthenticated_tcp_noise_transport(
Ok((peer_id, socket))
}
})
.and_then(move |(peer_id, socket), _origin| async move {
let (identity, socket) = exchange_identity(&own_identity, socket).await?;
.and_then(move |(peer_id, mut socket), _origin| async move {
let identity = exchange_identity(&own_identity, &mut socket).await?;
match_peer_id(identity, peer_id).and_then(|identity| Ok((identity, socket)))
})
.with_timeout(TRANSPORT_TIMEOUT)
Expand All @@ -207,8 +207,8 @@ pub fn build_tcp_transport(
own_identity: Identity,
) -> boxed::BoxedTransport<(Identity, impl TSocket), impl ::std::error::Error> {
LIBRA_TCP_TRANSPORT
.and_then(move |socket, _origin| async move {
Ok(exchange_identity(&own_identity, socket).await?)
.and_then(move |mut socket, _origin| async move {
Ok((exchange_identity(&own_identity, &mut socket).await?, socket))
})
.with_timeout(TRANSPORT_TIMEOUT)
.boxed()
Expand Down

0 comments on commit cf94351

Please sign in to comment.