Skip to content

Commit

Permalink
Improve executor perf (risc0#536)
Browse files Browse the repository at this point in the history
Co-authored-by: Parker Thompson <[email protected]>
  • Loading branch information
flaub and mothran authored Apr 27, 2023
1 parent 4341c84 commit da02f36
Show file tree
Hide file tree
Showing 5 changed files with 127 additions and 47 deletions.
92 changes: 61 additions & 31 deletions risc0/zkvm/benches/fib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,55 +12,85 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use std::rc::Rc;

use criterion::{
black_box, criterion_group, criterion_main, BatchSize, BenchmarkId, Criterion, Throughput,
};
use risc0_zkvm::{
prove::{default_prover, Prover},
Executor, ExecutorEnv,
};
use risc0_zkvm::{prove::default_prover, Executor, ExecutorEnv};
use risc0_zkvm_methods::FIB_ELF;

fn setup(iterations: u32) -> Executor<'static> {
let env = ExecutorEnv::builder().add_input(&[iterations]).build();
Executor::from_elf(env, FIB_ELF).unwrap()
}

fn run(prover: Rc<dyn Prover>, exec: &mut Executor, with_seal: bool) {
let session = exec.run().unwrap();
if with_seal {
prover.prove_session(&session).unwrap();
}
enum Scope {
Execute,
Prove,
Total,
}

pub fn bench(c: &mut Criterion) {
let mut group = c.benchmark_group("fib");

for with_seal in [true, false] {
for iterations in [100, 200] {
let prover = default_prover();
let prover = default_prover();
for scope in [Scope::Execute, Scope::Prove, Scope::Total] {
for iterations in [100, 1000, 10_000] {
let mut exec = setup(iterations);
let session = exec.run().unwrap();
let po2 = session.segments[0].po2;
let cycles = 1 << po2;
let (exec_cycles, prove_cycles) =
session
.segments
.iter()
.fold((0, 0), |(exec_cycles, prove_cycles), segment| {
(
exec_cycles + segment.insn_cycles,
prove_cycles + (1 << segment.po2),
)
});
group.sample_size(10);
group.throughput(Throughput::Elements(cycles as u64));
group.bench_with_input(
BenchmarkId::from_parameter(format!(
"{iterations}/{}",
if with_seal { "proof" } else { "run" }
)),
&iterations,
|b, &iterations| {
b.iter_batched(
|| setup(iterations),
|mut exec| black_box(run(prover.clone(), &mut exec, with_seal)),
BatchSize::SmallInput,
)
},
);
match scope {
Scope::Execute => {
let id = BenchmarkId::from_parameter(format!("{iterations}/execute"));
group.throughput(Throughput::Elements(exec_cycles as u64));
group.bench_with_input(id, &iterations, |b, &iterations| {
b.iter_batched(
|| setup(iterations),
|mut exec| black_box(exec.run().unwrap()),
BatchSize::SmallInput,
)
});
}
Scope::Prove => {
let id = BenchmarkId::from_parameter(format!("{iterations}/prove"));
group.throughput(Throughput::Elements(prove_cycles as u64));
group.bench_with_input(id, &iterations, |b, &iterations| {
b.iter_batched(
|| {
let mut exec = setup(iterations);
exec.run().unwrap()
},
|session| black_box(prover.prove_session(&session).unwrap()),
BatchSize::SmallInput,
)
});
}
Scope::Total => {
let id = BenchmarkId::from_parameter(format!("{iterations}/total"));
group.throughput(Throughput::Elements(exec_cycles as u64));
group.bench_with_input(id, &iterations, |b, &iterations| {
b.iter_batched(
|| setup(iterations),
|mut exec| {
black_box({
let session = exec.run().unwrap();
prover.prove_session(&session).unwrap()
})
},
BatchSize::SmallInput,
)
});
}
};
}
}

Expand Down
59 changes: 46 additions & 13 deletions risc0/zkvm/examples/fib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,29 @@ use std::rc::Rc;
use clap::Parser;
use risc0_zkvm::{
prove::{default_prover, Prover},
Executor, ExecutorEnv, Session, SessionReceipt,
Executor, ExecutorEnv,
};
use risc0_zkvm_methods::FIB_ELF;
use tracing_subscriber::{prelude::*, EnvFilter};

#[derive(Parser)]
#[clap()]
#[command()]
struct Args {
/// Number of iterations.
#[clap(long)]
#[arg(short, long)]
iterations: u32,

#[arg(short, long, default_value_t = false)]
skip_prover: bool,
}

#[derive(Debug)]
#[allow(unused)]
struct Metrics {
segments: usize,
insn_cycles: usize,
cycles: usize,
seal: usize,
}

fn main() {
Expand All @@ -38,20 +50,41 @@ fn main() {

let args = Args::parse();
let prover = default_prover();

let (session, receipt) = top(prover, args.iterations);
let po2 = session.segments[0].po2;
let seal = receipt.segments[0].get_seal_bytes().len();
let journal = receipt.journal.len();
let total = seal + journal;
println!("Po2: {po2}, Seal: {seal} bytes, Journal: {journal} bytes, Total: {total} bytes");
let metrics = top(prover, args.iterations, args.skip_prover);
println!("{metrics:?}");
}

#[tracing::instrument(skip_all)]
fn top(prover: Rc<dyn Prover>, iterations: u32) -> (Session, SessionReceipt) {
fn top(prover: Rc<dyn Prover>, iterations: u32, skip_prover: bool) -> Metrics {
let env = ExecutorEnv::builder().add_input(&[iterations]).build();
let mut exec = Executor::from_elf(env, FIB_ELF).unwrap();
let session = exec.run().unwrap();
let receipt = prover.prove_session(&session).unwrap();
(session, receipt)

let (cycles, insn_cycles) =
session
.segments
.iter()
.fold((0, 0), |(cycles, insn_cycles), segment| {
(
cycles + (1 << segment.po2),
insn_cycles + segment.insn_cycles,
)
});

let seal = if skip_prover {
0
} else {
let receipt = prover.prove_session(&session).unwrap();
receipt
.segments
.iter()
.fold(0, |acc, segment| acc + segment.get_seal_bytes().len())
};

Metrics {
segments: session.segments.len(),
insn_cycles,
cycles,
seal,
}
}
1 change: 1 addition & 0 deletions risc0/zkvm/src/exec/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,7 @@ impl<'a> Executor<'a> {
.len()
.try_into()
.context("Too many segment to fit in u32")?,
self.body_cycles,
));
match exit_code {
ExitCode::SystemSplit(_) => self.split(),
Expand Down
17 changes: 14 additions & 3 deletions risc0/zkvm/src/exec/monitor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,10 @@ impl MemoryMonitor {
}

pub fn load_registers<const N: usize>(&mut self, idxs: [usize; N]) -> [u32; N] {
idxs.map(|idx| self.load_register(idx))
idxs.map(|idx| {
let addr = get_register_addr(idx);
u32::from_le_bytes(array::from_fn(|idx| self.image.buf[addr as usize + idx]))
})
}

pub fn load_string(&mut self, mut addr: u32) -> Result<String> {
Expand Down Expand Up @@ -282,8 +285,16 @@ impl PageFaults {
let page_idx = info.get_page_index(addr);
let entry_addr = info.get_page_entry_addr(page_idx);
match dir {
IncludeDir::Read => self.reads.insert(page_idx),
IncludeDir::Write => self.writes.insert(page_idx),
IncludeDir::Read => {
if !self.reads.insert(page_idx) {
break;
}
}
IncludeDir::Write => {
if !self.writes.insert(page_idx) {
break;
}
}
};
if page_idx == info.root_idx {
break;
Expand Down
5 changes: 5 additions & 0 deletions risc0/zkvm/src/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,9 @@ pub struct Segment {

/// The index of this [Segment] within the [Session]
pub index: u32,

/// The number of cycles used to execute instructions.
pub insn_cycles: usize,
}

impl Session {
Expand All @@ -112,6 +115,7 @@ impl Segment {
exit_code: ExitCode,
po2: usize,
index: u32,
insn_cycles: usize,
) -> Self {
Self {
pre_image,
Expand All @@ -121,6 +125,7 @@ impl Segment {
exit_code,
po2,
index,
insn_cycles,
}
}
}

0 comments on commit da02f36

Please sign in to comment.