Skip to content

Commit

Permalink
Improve zkp performance (risc0#265)
Browse files Browse the repository at this point in the history
* Add reg_count to taps so verify knows how big to allocate structures
* Change most uses of Vec::new in verify to use Vec::with_capacity to avoid reallocations during filling
* Avoid duplicating stack during proving
* Compute mix powers once instead of over and over again
  • Loading branch information
shkoo authored Aug 31, 2022
1 parent 54f3ff4 commit a4aa78e
Show file tree
Hide file tree
Showing 5 changed files with 91 additions and 36 deletions.
24 changes: 18 additions & 6 deletions risc0/zkp/rust/src/adapter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -335,10 +335,11 @@ impl CircuitStep {
}
CircuitStep::If(cond, block, _loc) => {
if stack[*cond] != Fp::new(0) {
let mut stack = stack.clone();
let stacklen = stack.len();
for op in block.iter() {
op.step(&mut stack, ctx, custom, args)?;
op.step(stack, ctx, custom, args)?;
}
stack.truncate(stacklen);
}
}
CircuitStep::Add(x1, x2, _loc) => {
Expand Down Expand Up @@ -376,10 +377,11 @@ impl CircuitStep {
stack.extend(custom.call(name, extra, &args)?);
}
CircuitStep::Nondet(block, _loc) => {
let mut stack = stack.clone();
let stacklen = stack.len();
for op in block.iter() {
op.step(&mut stack, ctx, custom, args)?;
op.step(stack, ctx, custom, args)?;
}
stack.truncate(stacklen);
}
})
}
Expand Down Expand Up @@ -559,11 +561,21 @@ impl PolyExtStep {

impl PolyExtStepDef {
pub fn step(&self, ctx: &PolyExtContext, u: &[Fp4], args: &[&[Fp]]) -> MixState {
let mut fp_vars = Vec::new();
let mut mix_vars = Vec::new();
let mut fp_vars = Vec::with_capacity(self.block.len() - (self.ret + 1));
let mut mix_vars = Vec::with_capacity(self.ret + 1);
for op in self.block.iter() {
op.step(&mut fp_vars, &mut mix_vars, ctx, u, args);
}
assert_eq!(
fp_vars.len(),
self.block.len() - (self.ret + 1),
"Miscalculated capacity for fp_vars"
);
assert_eq!(
mix_vars.len(),
self.ret + 1,
"Miscalculated capacity for mix_vars"
);
mix_vars[self.ret]
}
}
11 changes: 10 additions & 1 deletion risc0/zkp/rust/src/taps.rs
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ pub struct TapSet<'a> {
pub combo_begin: &'a [u16],
pub group_begin: [usize; REGISTER_GROUPS.len() + 1],
pub combos_count: usize,
pub reg_count: usize,
}

impl<'a> TapSet<'a> {
Expand Down Expand Up @@ -149,6 +150,10 @@ impl<'a> TapSet<'a> {
self.combos_count
}

pub fn reg_count(&self) -> usize {
self.reg_count
}

pub fn combos(&self) -> ComboIter {
ComboIter {
data: ComboData {
Expand Down Expand Up @@ -177,6 +182,7 @@ pub struct TapSetOwned {
combo_begin: Vec<u16>,
group_begin: [usize; REGISTER_GROUPS.len() + 1],
combos_count: usize,
reg_count: usize,
}

impl TapSetOwned {
Expand All @@ -198,7 +204,7 @@ impl TapSetOwned {
let mut combo_begin = Vec::new();
let mut combo_taps = Vec::new();
let mut taps = Vec::new();

let mut tot_reg_count = 0;
// Pre-insert the 'only self' combo
let myself = BTreeSet::from([0_usize]);
let mut combos = vec![&myself];
Expand All @@ -211,6 +217,7 @@ impl TapSetOwned {
group_begin[group_id] = taps.len();
let regs = all.get(group).unwrap();
let reg_count = regs.keys().last().unwrap() + 1;
tot_reg_count += reg_count;
for reg in 0..reg_count {
// Make sure all registers have at least one tap
assert!(regs.contains_key(&reg));
Expand Down Expand Up @@ -247,6 +254,7 @@ impl TapSetOwned {
combo_begin,
group_begin,
combos_count: combos.len(),
reg_count: tot_reg_count,
}
}
}
Expand All @@ -259,6 +267,7 @@ impl<'a> From<&'a TapSetOwned> for TapSet<'a> {
combo_begin: owned.combo_begin.as_slice(),
group_begin: owned.group_begin,
combos_count: owned.combos_count,
reg_count: owned.reg_count,
}
}
}
Expand Down
36 changes: 24 additions & 12 deletions risc0/zkp/rust/src/verify/fri.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use alloc::vec;
use alloc::vec::Vec;

use rand::RngCore;

Expand All @@ -27,7 +27,7 @@ use crate::{
},
field::{baby_bear::BabyBear, Elem},
verify::{merkle::MerkleTreeVerifier, read_iop::ReadIOP, VerificationError},
FRI_FOLD, FRI_MIN_DEGREE, INV_RATE, QUERIES,
FRI_FOLD, FRI_FOLD_PO2, FRI_MIN_DEGREE, INV_RATE, QUERIES,
};

/// VerifyRoundInfo contains the data against which the queries for a particular
Expand Down Expand Up @@ -75,15 +75,16 @@ impl<'a, S: Sha> VerifyRoundInfo<'a, S> {
let group = *pos % self.domain;
// Get the column data
let data = self.merkle.verify::<BabyBear>(iop, group)?;
let mut data4 = vec![];
for i in 0..FRI_FOLD {
data4.push(Fp4::new(
data[0 * FRI_FOLD + i],
data[1 * FRI_FOLD + i],
data[2 * FRI_FOLD + i],
data[3 * FRI_FOLD + i],
));
}
let mut data4: Vec<_> = (0..FRI_FOLD)
.map(|i| {
Fp4::new(
data[0 * FRI_FOLD + i],
data[1 * FRI_FOLD + i],
data[2 * FRI_FOLD + i],
data[3 * FRI_FOLD + i],
)
})
.collect();
// Check the existing goal
if data4[quot] != *goal {
return Err(VerificationError::InvalidProof);
Expand All @@ -106,12 +107,23 @@ where
let orig_domain = INV_RATE * degree;
let mut domain = orig_domain;
// Prep the folding verfiers
let mut rounds = vec![];
let rounds_capacity =
(log2_ceil((degree + FRI_FOLD - 1) / FRI_FOLD) + FRI_FOLD_PO2 - 1) / FRI_FOLD_PO2;
let mut rounds = Vec::with_capacity(rounds_capacity);
while degree > FRI_MIN_DEGREE {
rounds.push(VerifyRoundInfo::new(iop, domain));
domain /= FRI_FOLD;
degree /= FRI_FOLD;
}
// We want to minimize reallocation in verify, so make sure we
// didn't have to reallocate.
assert!(
rounds.len() < rounds_capacity,
"Did not allocate enough rounds; needed {} for degree {} but only allocated {}",
rounds.len(),
degree,
rounds_capacity
);
// Grab the final coeffs + commit
let final_coeffs = iop.read_pod_slice(EXT_SIZE * degree);
let final_digest = iop.get_sha().hash_raw_pod_slice(final_coeffs);
Expand Down
55 changes: 38 additions & 17 deletions risc0/zkp/rust/src/verify/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,9 @@ mod fri;
pub(crate) mod merkle;
pub mod read_iop;

use alloc::vec;
use alloc::{vec, vec::Vec};
use core::fmt;
use core::iter::zip;
// use log::debug;

use crate::{
Expand Down Expand Up @@ -137,7 +138,7 @@ where

// Now, convert to evaluated values
let mut cur_pos = 0;
let mut eval_u = vec![];
let mut eval_u = Vec::with_capacity(num_taps);
for reg in taps.regs() {
for i in 0..reg.size() {
let x = back_one.pow(reg.back(i)) * z;
Expand All @@ -146,6 +147,7 @@ where
}
cur_pos += reg.size();
}
assert_eq!(eval_u.len(), num_taps, "Miscalculated capacity for eval_us");

// Compute the core polynomial
let result = circuit.compute_polynomial(&eval_u, poly_mix);
Expand Down Expand Up @@ -174,27 +176,48 @@ where
// debug!("mix = {mix:?}");

// Make the mixed U polynomials
let mut combo_u = vec![];
for i in 0..combo_count {
combo_u.push(vec![Fp4::ZERO; taps.get_combo(i).size()]);
}
let mut combo_u: Vec<Vec<Fp4>> = Vec::with_capacity(combo_count + 1);
combo_u.extend(
(0..combo_count)
.into_iter()
.map(|i| vec![Fp4::ZERO; taps.get_combo(i).size()]),
);
let mut cur_mix = Fp4::ONE;
cur_pos = 0;
let mut tap_mix_pows = Vec::with_capacity(taps.reg_count());
for reg in taps.regs() {
for i in 0..reg.size() {
combo_u[reg.combo_id()][i] += cur_mix * coeff_u[cur_pos + i];
}
tap_mix_pows.push(cur_mix);
cur_mix *= mix;
cur_pos += reg.size();
}
assert_eq!(
tap_mix_pows.len(),
taps.reg_count(),
"Miscalculated capacity for tap_mix_pows"
);
// debug!("cur_mix: {cur_mix:?}, cur_pos: {cur_pos}");
// Handle check group
combo_u.push(vec![Fp4::ZERO]);
assert_eq!(
combo_u.len(),
combo_count + 1,
"Miscalculated capacity for combo_u"
);
let mut check_mix_pows = Vec::with_capacity(CHECK_SIZE);
for _ in 0..CHECK_SIZE {
combo_u[combo_count][0] += cur_mix * coeff_u[cur_pos];
cur_pos += 1;
check_mix_pows.push(cur_mix);
cur_mix *= mix;
}
assert_eq!(
check_mix_pows.len(),
CHECK_SIZE,
"Miscalculated capacity for check_mix_pows"
);
// debug!("cur_mix: {cur_mix:?}");

let gen = Fp::new(ROU_FWD[log2_ceil(domain)]);
Expand All @@ -204,20 +227,18 @@ where
size,
|iop: &mut ReadIOP<S>, idx: usize| -> Result<Fp4, VerificationError> {
let x = Fp4::from_fp(gen.pow(idx));
let mut rows = vec![];
rows.push(accum_merkle.verify::<BabyBear>(iop, idx)?);
rows.push(code_merkle.verify::<BabyBear>(iop, idx)?);
rows.push(data_merkle.verify::<BabyBear>(iop, idx)?);
let rows = [
accum_merkle.verify::<BabyBear>(iop, idx)?,
code_merkle.verify::<BabyBear>(iop, idx)?,
data_merkle.verify::<BabyBear>(iop, idx)?,
];
let check_row = check_merkle.verify::<BabyBear>(iop, idx)?;
let mut cur = Fp4::ONE;
let mut tot = vec![Fp4::ZERO; combo_count + 1];
for reg in taps.regs() {
tot[reg.combo_id()] += cur * rows[reg.group() as usize][reg.offset()];
cur *= mix;
for (reg, cur) in zip(taps.regs(), tap_mix_pows.iter()) {
tot[reg.combo_id()] += *cur * rows[reg.group() as usize][reg.offset()];
}
for i in 0..CHECK_SIZE {
tot[combo_count] += cur * check_row[i];
cur *= mix;
for (i, cur) in zip(0..CHECK_SIZE, check_mix_pows.iter()) {
tot[combo_count] += *cur * check_row[i];
}
let mut ret = Fp4::ZERO;
for i in 0..combo_count {
Expand Down
1 change: 1 addition & 0 deletions risc0/zkvm/sdk/rust/circuit/src/taps.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5949,4 +5949,5 @@ pub(crate) const TAPSET: &'static TapSet = &TapSet::<'static> {
combo_begin: &[0, 1, 3, 9, 16, 18],
group_begin: [0, 18, 34, 742],
combos_count: 5,
reg_count: 188,
};

0 comments on commit a4aa78e

Please sign in to comment.