Skip to content

Commit

Permalink
feat: add noise codec
Browse files Browse the repository at this point in the history
  • Loading branch information
niklaslong committed Aug 25, 2023
1 parent 23269ff commit a1c230c
Showing 1 changed file with 244 additions and 0 deletions.
244 changes: 244 additions & 0 deletions node/narwhal/src/helpers/codec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,3 +84,247 @@ impl<N: Network> Decoder for EventCodec<N> {
}
}
}

// NOISE CODEC //

use bytes::{Buf, Bytes};
use rayon::iter::{IndexedParallelIterator, IntoParallelIterator, ParallelIterator};
use snow::{HandshakeState, StatelessTransportState};

use std::{io, sync::Arc};

// The maximum message size for noise messages. If the data to be encrypted exceedes it, it is
// chunked.
const MAX_MESSAGE_LEN: usize = 65535;

#[repr(u8)]
pub enum MessageType {
Bytes = 0,
Event,
}

impl TryFrom<u8> for MessageType {
type Error = String;

fn try_from(value: u8) -> Result<Self, Self::Error> {
match value {
0 => Ok(MessageType::Bytes),
1 => Ok(MessageType::Event),
_ => Err(format!("u8 value: {value} doesn't correspond to a message variant")),
}
}
}

#[derive(Clone, Debug, PartialEq, Eq)]
pub enum EventOrBytes<N: Network> {
Bytes(Bytes),
Event(Event<N>),
}

impl<N: Network> EventOrBytes<N> {
fn message_type(&self) -> MessageType {
match self {
EventOrBytes::Bytes(_) => MessageType::Bytes,
EventOrBytes::Event(_) => MessageType::Event,
}
}
}

#[derive(Clone)]
pub struct PostHandshakeState {
state: Arc<StatelessTransportState>,
tx_nonce: u64,
rx_nonce: u64,
}

pub enum NoiseState {
Handshake(Box<HandshakeState>),
PostHandshake(PostHandshakeState),
}

impl Clone for NoiseState {
fn clone(&self) -> Self {
match self {
Self::Handshake(..) => unimplemented!(),
Self::PostHandshake(ph_state) => Self::PostHandshake(ph_state.clone()),
}
}
}

impl NoiseState {
pub fn into_post_handshake_state(self) -> Self {
if let Self::Handshake(noise_state) = self {
let noise_state = noise_state.into_stateless_transport_mode().expect("handshake isn't finished");
Self::PostHandshake(PostHandshakeState { state: Arc::new(noise_state), tx_nonce: 0, rx_nonce: 0 })
} else {
panic!()
}
}
}

pub struct NoiseCodec<N: Network> {
codec: LengthDelimitedCodec,
event_codec: EventCodec<N>,
noise_state: NoiseState,
}

impl<N: Network> NoiseCodec<N> {
pub fn new(noise_state: NoiseState) -> Self {
Self { codec: LengthDelimitedCodec::new(), event_codec: EventCodec::default(), noise_state }
}
}

impl<N: Network> Encoder<EventOrBytes<N>> for NoiseCodec<N> {
type Error = std::io::Error;

fn encode(&mut self, message_or_bytes: EventOrBytes<N>, dst: &mut BytesMut) -> Result<(), Self::Error> {
let message_type = message_or_bytes.message_type();

let ciphertext = match self.noise_state {
NoiseState::Handshake(ref mut noise) => {
match message_or_bytes {
// Don't allow message sending before the noise handshake has completed.
EventOrBytes::Event(_) => unimplemented!(),
EventOrBytes::Bytes(bytes) => {
let mut buffer = [0u8; MAX_MESSAGE_LEN + 1];
let len = noise
.write_message(&bytes, &mut buffer[1..])
.map_err(|e| Self::Error::new(io::ErrorKind::InvalidInput, e))?;

// Set the message type flag.
buffer[0] = message_type as u8;

buffer[..len + 1].into()
}
}
}

NoiseState::PostHandshake(ref mut noise) => {
// Encode the message using the event codec.
let mut bytes = BytesMut::new();
match message_or_bytes {
// Don't allow sending raw bytes after the noise handshake has completed.
EventOrBytes::Bytes(_) => unimplemented!(),
EventOrBytes::Event(event) => self.event_codec.encode(event, &mut bytes)?,
}

// Chunk the payload if necessary.
//
// A Noise transport message is simply an AEAD ciphertext that is less than or
// equal to 65535 bytes in length, and that consists of an encrypted payload plus
// 16 bytes of authentication data.
//
// See: https://noiseprotocol.org/noise.html#the-handshakestate-object
let chunked_plaintext_msg: Vec<_> = bytes.chunks(MAX_MESSAGE_LEN - 16).collect();
let num_chunks = chunked_plaintext_msg.len() as u64;

// Encrypt the resulting bytes with Noise.
let encrypted_chunks: Vec<io::Result<Vec<u8>>> = chunked_plaintext_msg
.into_par_iter()
.enumerate()
.map(|(nonce_offset, plaintext_chunk)| {
let mut buffer = vec![0u8; MAX_MESSAGE_LEN];
let len = noise
.state
.write_message(noise.tx_nonce + nonce_offset as u64, plaintext_chunk, &mut buffer)
.map_err(|e| Self::Error::new(io::ErrorKind::InvalidInput, e))?;

buffer.truncate(len);

Ok(buffer)
})
.collect();

let mut buffer = BytesMut::new();
// Set the message type flag.
buffer.put_u8(message_type as u8);

for chunk in encrypted_chunks {
buffer.extend_from_slice(&chunk?)
}

noise.tx_nonce += num_chunks;

buffer
}
};

// Encode the resulting ciphertext using the length-delimited codec.
self.codec.encode(ciphertext.freeze(), dst)
}
}

impl<N: Network> Decoder for NoiseCodec<N> {
type Error = io::Error;
type Item = EventOrBytes<N>;

fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
// Decode the ciphertext with the length-delimited codec.
let (flag, bytes) = if let Some(mut bytes) = self.codec.decode(src)? {
let flag =
MessageType::try_from(bytes.get_u8()).map_err(|e| Self::Error::new(io::ErrorKind::InvalidData, e))?;
(flag, bytes)
} else {
return Ok(None);
};

let msg = match self.noise_state {
NoiseState::Handshake(ref mut noise) => {
// Ignore any messages before the noise handshake has completed.
if let MessageType::Event = flag {
return Ok(None);
}

// Decrypt the ciphertext in handshake mode.
let mut buffer = [0u8; MAX_MESSAGE_LEN];
let len = noise.read_message(&bytes, &mut buffer).map_err(|_| io::ErrorKind::InvalidData)?;

Some(EventOrBytes::Bytes(Bytes::copy_from_slice(&buffer[..len])))
}

NoiseState::PostHandshake(ref mut noise) => {
// Ignore raw bytes after the noise handshake has completed.
if let MessageType::Bytes = flag {
return Ok(None);
}

// Noise decryption.
let chunked_encrypted_msg: Vec<_> = bytes.chunks(MAX_MESSAGE_LEN).collect();
let num_chunks = chunked_encrypted_msg.len() as u64;

let decrypted_chunks: Vec<io::Result<Vec<u8>>> = chunked_encrypted_msg
.into_par_iter()
.enumerate()
.map(|(nonce_offset, encrypted_chunk)| {
let mut buffer = vec![0u8; MAX_MESSAGE_LEN];

// Decrypt the ciphertext in post-handshake mode.
let len = noise
.state
.read_message(noise.rx_nonce + nonce_offset as u64, encrypted_chunk, &mut buffer)
.map_err(|_| io::ErrorKind::InvalidData)?;

buffer.truncate(len);
Ok(buffer)
})
.collect();

noise.rx_nonce += num_chunks;

// Collect chunks into plaintext to be passed to the message codecs.
let mut plaintext = BytesMut::new();
for chunk in decrypted_chunks {
plaintext.extend_from_slice(&chunk?);
}

// Decode with message codecs.
match flag {
MessageType::Event => self.event_codec.decode(&mut plaintext)?.map(|msg| EventOrBytes::Event(msg)),
_ => unreachable!("bytes variant was handled as an early return"),
}
}
};

Ok(msg)
}
}

0 comments on commit a1c230c

Please sign in to comment.