Skip to content

Commit

Permalink
Merge pull request ProvableHQ#1577 from niklaslong/serde
Browse files Browse the repository at this point in the history
Make message deserialisation generic over `Read`
  • Loading branch information
howardwu authored Jan 27, 2022
2 parents ba33a93 + 35195f0 commit e1c4e22
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 39 deletions.
92 changes: 56 additions & 36 deletions src/network/message.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,11 @@ use snarkvm::{dpc::posw::PoSWProof, prelude::*};
use ::bytes::{Buf, BufMut, Bytes, BytesMut};
use anyhow::{anyhow, Result};
use serde::{de::DeserializeOwned, Serialize};
use std::{io::Write, marker::PhantomData, net::SocketAddr};
use std::{
io::{Cursor, Seek, Write},
marker::PhantomData,
net::SocketAddr,
};
use tokio::task;
use tokio_util::codec::{Decoder, Encoder};

Expand Down Expand Up @@ -223,61 +227,77 @@ impl<N: Network, E: Environment> Message<N, E> {

/// Deserializes the given buffer into a message.
#[inline]
pub fn deserialize(buffer: &[u8]) -> Result<Self> {
// Ensure the buffer contains at least the length of an ID.
if buffer.len() < 2 {
return Err(anyhow!("Invalid message buffer"));
}
pub fn deserialize<R: Read + Seek>(reader: &mut R) -> Result<Self> {
// Read the message ID.
let id: u16 = bincode::deserialize_from(&mut *reader)?;

// Helper function to read all the remaining bytes from a reader.
let read_to_end = |reader: &mut R| -> Result<Bytes> {
let mut data = vec![];
reader.read_to_end(&mut data)?;

// Split the buffer into the ID and data portion.
let (id, data) = (u16::from_le_bytes([buffer[0], buffer[1]]), &buffer[2..]);
Ok(data.into())
};

// Deserialize the data field.
let message = match id {
0 => Self::BlockRequest(bincode::deserialize(&data[0..4])?, bincode::deserialize(&data[4..8])?),
1 => Self::BlockResponse(Data::Buffer(data.to_vec().into())),
0 => Self::BlockRequest(bincode::deserialize_from(&mut *reader)?, bincode::deserialize_from(&mut *reader)?),
1 => Self::BlockResponse(Data::Buffer(read_to_end(&mut *reader)?)),
2 => {
let (version, fork_depth, node_type, status, listener_port, nonce, cumulative_weight) = bincode::deserialize(data)?;
let (version, fork_depth, node_type, status, listener_port, nonce, cumulative_weight) =
bincode::deserialize_from(&mut *reader)?;

Self::ChallengeRequest(version, fork_depth, node_type, status, listener_port, nonce, cumulative_weight)
}
3 => Self::ChallengeResponse(Data::Buffer(data.to_vec().into())),
4 => match data.is_empty() {
true => Self::Disconnect,
false => return Err(anyhow!("Invalid 'Disconnect' message: {:?} {:?}", buffer, data)),
},
5 => match data.is_empty() {
true => Self::PeerRequest,
false => return Err(anyhow!("Invalid 'PeerRequest' message: {:?} {:?}", buffer, data)),
},
6 => Self::PeerResponse(bincode::deserialize(data)?),
3 => Self::ChallengeResponse(Data::Buffer(read_to_end(&mut *reader)?)),
4 => {
let data = read_to_end(&mut *reader)?;

match data.is_empty() {
true => Self::Disconnect,
false => return Err(anyhow!("Invalid 'Disconnect' message: {:?}", data)),
}
}
5 => {
let data = read_to_end(&mut *reader)?;

match data.is_empty() {
true => Self::PeerRequest,
false => return Err(anyhow!("Invalid 'PeerRequest' message: {:?}", data)),
}
}
6 => Self::PeerResponse(bincode::deserialize_from(&mut *reader)?),
7 => {
let (version, fork_depth, node_type, status, block_hash) = bincode::deserialize(&data[0..48])?;
let block_header = Data::Buffer(data[48..].to_vec().into());
let (version, fork_depth, node_type, status, block_hash) = bincode::deserialize_from(&mut *reader)?;
let block_header = Data::Buffer(read_to_end(&mut *reader)?);

Self::Ping(version, fork_depth, node_type, status, block_hash, block_header)
}
8 => {
let is_fork = match data[0] {
let fork_flag: u8 = bincode::deserialize_from(&mut *reader)?;
let data = read_to_end(&mut *reader)?;

let is_fork = match fork_flag {
0 => None,
1 => Some(true),
2 => Some(false),
_ => return Err(anyhow!("Invalid 'Pong' message: {:?} {:?}", buffer, data)),
_ => return Err(anyhow!("Invalid 'Pong' message: {:?}", data)),
};

Self::Pong(is_fork, Data::Buffer(data[1..].to_vec().into()))
Self::Pong(is_fork, Data::Buffer(data))
}
9 => Self::UnconfirmedBlock(
bincode::deserialize(&data[0..4])?,
bincode::deserialize(&data[4..36])?,
Data::Buffer(data[36..].to_vec().into()),
bincode::deserialize_from(&mut *reader)?,
bincode::deserialize_from(&mut *reader)?,
Data::Buffer(read_to_end(&mut *reader)?),
),
10 => Self::UnconfirmedTransaction(Data::Buffer(data.to_vec().into())),
11 => Self::PoolRegister(bincode::deserialize(data)?),
12 => Self::PoolRequest(bincode::deserialize(&data[0..8])?, Data::Buffer(data[8..].to_vec().into())),
10 => Self::UnconfirmedTransaction(Data::Buffer(read_to_end(&mut *reader)?)),
11 => Self::PoolRegister(bincode::deserialize_from(&mut *reader)?),
12 => Self::PoolRequest(bincode::deserialize_from(&mut *reader)?, Data::Buffer(read_to_end(&mut *reader)?)),
13 => Self::PoolResponse(
bincode::deserialize(&data[0..32])?,
bincode::deserialize(&data[32..64])?,
Data::Buffer(data[64..].to_vec().into()),
bincode::deserialize_from(&mut *reader)?,
bincode::deserialize_from(&mut *reader)?,
Data::Buffer(read_to_end(&mut *reader)?),
),
_ => return Err(anyhow!("Invalid message ID {}", id)),
};
Expand Down Expand Up @@ -342,7 +362,7 @@ impl<N: Network, E: Environment> Decoder for Message<N, E> {
}

// Convert the buffer to a message, or fail if it is not valid.
let message = match Message::deserialize(&source[4..][..length]) {
let message = match Message::deserialize(&mut Cursor::new(&source[4..][..length])) {
Ok(message) => Ok(Some(message)),
Err(error) => Err(std::io::Error::new(std::io::ErrorKind::InvalidData, error)),
};
Expand Down
6 changes: 3 additions & 3 deletions testing/src/test_node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ impl Handshake for TestNode {
connection.reader().read_exact(&mut buf[..MESSAGE_LENGTH_PREFIX_SIZE]).await?;
let len = u32::from_le_bytes(buf[..MESSAGE_LENGTH_PREFIX_SIZE].try_into().unwrap()) as usize;
connection.reader().read_exact(&mut buf[..len]).await?;
let peer_request = ClientMessage::deserialize(&buf[..len]);
let peer_request = ClientMessage::deserialize(&mut io::Cursor::new(&buf[..len]));

// Register peer's nonce.
let (peer_listening_addr, peer_nonce) = if let Ok(Message::ChallengeRequest(
Expand Down Expand Up @@ -263,7 +263,7 @@ impl Handshake for TestNode {
connection.reader().read_exact(&mut buf[..MESSAGE_LENGTH_PREFIX_SIZE]).await?;
let len = u32::from_le_bytes(buf[..MESSAGE_LENGTH_PREFIX_SIZE].try_into().unwrap()) as usize;
connection.reader().read_exact(&mut buf[..len]).await?;
let peer_response = ClientMessage::deserialize(&buf[..len]);
let peer_response = ClientMessage::deserialize(&mut io::Cursor::new(&buf[..len]));

if let Ok(Message::ChallengeResponse(block_header)) = peer_response {
let block_header = block_header.deserialize().await.unwrap();
Expand Down Expand Up @@ -306,7 +306,7 @@ impl Reading for TestNode {
return Ok(None);
}

match ClientMessage::deserialize(&buf[..len]) {
match ClientMessage::deserialize(&mut io::Cursor::new(&buf[..len])) {
Ok(msg) => {
info!(parent: self.node().span(), "received a {} from {}", msg.name(), source);
Ok(Some(msg))
Expand Down

0 comments on commit e1c4e22

Please sign in to comment.