Skip to content

Commit

Permalink
Remove dependencies on baby bear from zkp sha code (risc0#255)
Browse files Browse the repository at this point in the history
  • Loading branch information
shkoo authored Aug 25, 2022
1 parent a656dc0 commit 66f687b
Show file tree
Hide file tree
Showing 9 changed files with 62 additions and 49 deletions.
19 changes: 10 additions & 9 deletions risc0/zkp/rust/src/core/sha.rs
Original file line number Diff line number Diff line change
Expand Up @@ -148,10 +148,12 @@ pub trait Sha: Clone + Debug {
/// implementation wants to manage its own memory.
type DigestPtr: Deref<Target = Digest> + Debug;

/// Generate a SHA from a slice of bytes.
/// Generate a SHA from a slice of bytes, padding to block size
/// and adding the SHA trailer.
fn hash_bytes(&self, bytes: &[u8]) -> Self::DigestPtr;

/// Generate a SHA from a slice of words.
/// Generate a SHA from a slice of words, padding to block size
/// and adding the SHA trailer.
fn hash_words(&self, words: &[u32]) -> Self::DigestPtr {
// Default implementation: reinterpret the word slice as its raw
// in-memory byte representation and delegate to hash_bytes, which
// performs the block padding and adds the SHA trailer.
self.hash_bytes(bytemuck::cast_slice(words) as &[u8])
}
Expand All @@ -163,11 +165,10 @@ pub trait Sha: Clone + Debug {
/// Generate a SHA from a pair of [Digests](Digest).
fn hash_pair(&self, a: &Digest, b: &Digest) -> Self::DigestPtr;

/// Generate a SHA from a slice of [Fps](Fp).
fn hash_fps(&self, fps: &[Fp]) -> Self::DigestPtr;

/// Generate a SHA from a slice of [Fp4s](Fp4).
fn hash_fp4s(&self, fp4s: &[Fp4]) -> Self::DigestPtr;
/// Generate a SHA from a slice of anything that can be
/// represented as plain old data. Pads up to the Sha block
/// boundary, but does not add the standard SHA trailer.
fn hash_raw_pod_slice<T: bytemuck::Pod>(&self, fps: &[T]) -> Self::DigestPtr;

/// Generate a new digest by mixing two digests together via XOR,
/// and storing into the first digest.
Expand Down Expand Up @@ -257,7 +258,7 @@ pub mod testutil {

fn hash_fpvec<S: Sha>(sha: &S, len: usize) -> Digest {
let items: Vec<Fp> = (0..len as u32).into_iter().map(|x| Fp::new(x)).collect();
*sha.hash_fps(items.as_slice())
*sha.hash_raw_pod_slice(items.as_slice())
}

fn hash_fp4vec<S: Sha>(sha: &S, len: usize) -> Digest {
Expand All @@ -272,7 +273,7 @@ pub mod testutil {
)
})
.collect();
*sha.hash_fp4s(items.as_slice())
*sha.hash_raw_pod_slice(items.as_slice())
}

fn test_fps<S: Sha>(sha: &S) {
Expand Down
58 changes: 36 additions & 22 deletions risc0/zkp/rust/src/core/sha_cpu.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

//! Simple wrappers for a CPU-based SHA-256 implementation.
use alloc::{boxed::Box, vec, vec::Vec};
use alloc::{boxed::Box, vec::Vec};
use core::slice;

use sha2::{
Expand All @@ -23,11 +23,7 @@ use sha2::{
Digest as ShaDigest, Sha256,
};

use super::{
fp::Fp,
fp4::Fp4,
sha::{Digest, Sha, DIGEST_WORDS},
};
use super::sha::{Digest, Sha, DIGEST_WORDS};

static INIT_256: [u32; DIGEST_WORDS] = [
0x6a09e667, 0xbb67ae85, 0x3c6ef372, 0xa54ff53a, 0x510e527f, 0x9b05688c, 0x1f83d9ab, 0x5be0cd19,
Expand All @@ -42,22 +38,37 @@ fn set_word(buf: &mut [u8], idx: usize, word: u32) {
}

impl Impl {
/// Compute the hash of an array of [Fp]s using the specified stride.
pub fn hash_fps_stride(
/// Compute the hash of a slice of plain-old-data using the
/// specified offset and stride. 'size' specifies the number of
/// elements to hash.
pub fn hash_pod_stride<T: bytemuck::Pod>(
&self,
fps: &[Fp],
pods: &[T],
offset: usize,
size: usize,
stride: usize,
) -> Box<Digest> {
let mut state = INIT_256;
let mut block: GenericArray<u8, U64> = GenericArray::default();

let mut u8s = pods
.iter()
.skip(offset)
.step_by(stride)
.take(size)
.flat_map(|pod| bytemuck::cast_slice(slice::from_ref(pod)) as &[u8])
.cloned()
.fuse();

let mut off = 0;
for i in 0..size {
while let Some(b1) = u8s.next() {
let b2 = u8s.next().unwrap_or(0);
let b3 = u8s.next().unwrap_or(0);
let b4 = u8s.next().unwrap_or(0);
set_word(
block.as_mut_slice(),
off,
fps[offset + i * stride].as_u32_montgomery(),
u32::from_le_bytes([b1, b2, b3, b4]),
);
off += 1;
if off == 16 {
Expand Down Expand Up @@ -112,18 +123,21 @@ impl Sha for Impl {
Box::new(Digest::new(state))
}

fn hash_fps(&self, fps: &[Fp]) -> Self::DigestPtr {
self.hash_fps_stride(fps, 0, fps.len(), 1)
}

fn hash_fp4s(&self, fp4s: &[Fp4]) -> Self::DigestPtr {
let mut flat: Vec<Fp> = vec![];
for i in 0..fp4s.len() {
for j in 0..4 {
flat.push(fp4s[i].elems()[j]);
}
fn hash_raw_pod_slice<T: bytemuck::Pod>(&self, pod: &[T]) -> Self::DigestPtr {
let u8s: &[u8] = bytemuck::cast_slice(pod);
let mut state = INIT_256;
let mut blocks = u8s.chunks_exact(64);
for block in blocks.by_ref() {
compress256(&mut state, slice::from_ref(GenericArray::from_slice(block)));
}
return self.hash_fps(&flat);
let remainder = blocks.remainder();
if remainder.len() > 0 {
let mut last_block: GenericArray<u8, U64> = GenericArray::default();
bytemuck::cast_slice_mut(last_block.as_mut_slice())[..remainder.len()]
.clone_from_slice(remainder);
compress256(&mut state, slice::from_ref(&last_block));
}
Box::new(Digest::new(state))
}

// Digest two digests into one
Expand Down
2 changes: 1 addition & 1 deletion risc0/zkp/rust/src/hal/cpu.rs
Original file line number Diff line number Diff line change
Expand Up @@ -394,7 +394,7 @@ impl<F: Field> Hal for CpuHal<F> {
let fp_matrix = CpuHal::<F>::to_baby_bear_fp_slice(matrix.as_slice());
let sha = sha_cpu::Impl {};
output.par_iter_mut().enumerate().for_each(|(idx, output)| {
*output = *sha.hash_fps_stride(fp_matrix, idx, col_size, count);
*output = *sha.hash_pod_stride(fp_matrix, idx, col_size, count);
});
}

Expand Down
2 changes: 1 addition & 1 deletion risc0/zkp/rust/src/prove/fri.rs
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ where
final_coeffs.view(|view| {
let view = H::to_baby_bear_fp_slice(view);
iop.write_fp_slice(view);
let digest = iop.get_sha().hash_fps(view);
let digest = iop.get_sha().hash_raw_pod_slice(view);
iop.commit(&digest);
});
// Do queries
Expand Down
2 changes: 1 addition & 1 deletion risc0/zkp/rust/src/prove/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ pub fn prove<H: Hal, S: Sha, C: Circuit, E: EvalCheck<H>>(

debug!("Size of U = {}", coeff_u.len());
iop.write_fp4_slice(&coeff_u);
let hash_u = sha.hash_fp4s(&coeff_u);
let hash_u = sha.hash_raw_pod_slice(coeff_u.as_slice());
iop.commit(&hash_u);

// Set the mix mix value
Expand Down
2 changes: 1 addition & 1 deletion risc0/zkp/rust/src/verify/fri.rs
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ where
// Grab the final coeffs + commit
let mut final_coeffs = vec![Fp::ZERO; EXT_SIZE * degree];
iop.read_fps(&mut final_coeffs);
let final_digest = iop.get_sha().hash_fps(&final_coeffs); // padding?
let final_digest = iop.get_sha().hash_raw_pod_slice(final_coeffs.as_slice());
iop.commit(&final_digest);
// Get the generator for the final polynomial evaluations
let gen = Fp::new(ROU_FWD[log2_ceil(domain)]);
Expand Down
2 changes: 1 addition & 1 deletion risc0/zkp/rust/src/verify/merkle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ impl MerkleTreeVerifier {
// Read out field elements from IOP.
iop.read_fps(&mut out);
// Get the hash at the leaf of the tree by hashing these field elements.
let mut cur = *iop.get_sha().hash_fps(&out);
let mut cur = *iop.get_sha().hash_raw_pod_slice(out.as_slice());
// Shift idx to start of the row
idx += self.params.row_size;
while idx >= 2 * self.params.top_size {
Expand Down
2 changes: 1 addition & 1 deletion risc0/zkp/rust/src/verify/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ where
let num_taps = taps.tap_size();
let mut coeff_u = vec![Fp4::ZERO; num_taps + CHECK_SIZE];
iop.read_fp4s(&mut coeff_u);
let hash_u = *sha.hash_fp4s(&coeff_u);
let hash_u = *sha.hash_raw_pod_slice(coeff_u.as_slice());
iop.commit(&hash_u);

// Now, convert to evaluated values
Expand Down
22 changes: 10 additions & 12 deletions risc0/zkvm/sdk/rust/guest/src/sha.rs
Original file line number Diff line number Diff line change
Expand Up @@ -206,23 +206,21 @@ impl risc0_zkp::core::sha::Sha for Impl {
raw_digest(words)
}

fn hash_fps(&self, fps: &[Fp]) -> Self::DigestPtr {
// Fps do not include the standard SHA header.
if fps.len() % CHUNK_SIZE == 0 {
raw_digest(bytemuck::cast_slice(fps))
fn hash_raw_pod_slice<T: bytemuck::Pod>(&self, pod: &[T]) -> Self::DigestPtr {
let u8s: &[u8] = bytemuck::cast_slice(pod);

if u8s.len() % (CHUNK_SIZE * WORD_SIZE) == 0 {
// Already padded; no need to copy it.
raw_digest(bytemuck::cast_slice(pod))
} else {
let size = align_up(fps.len(), CHUNK_SIZE);
let mut buf: Vec<u32> = Vec::with_capacity(size);
buf.extend(bytemuck::cast_slice(fps));
let size = align_up(u8s.len(), CHUNK_SIZE * WORD_SIZE);
let mut buf: Vec<u8> = Vec::with_capacity(size);
buf.extend(bytemuck::cast_slice(pod));
buf.resize(size, 0);
raw_digest(&buf)
raw_digest(bytemuck::cast_slice(buf.as_slice()))
}
}

fn hash_fp4s(&self, fp4s: &[Fp4]) -> Self::DigestPtr {
self.hash_fps(bytemuck::cast_slice(fp4s))
}

// Generate a new digest by mixing two digests together via XOR,
// and storing into the first digest.
fn mix(&self, pool: &mut Self::DigestPtr, val: &Digest) {
Expand Down

0 comments on commit 66f687b

Please sign in to comment.