Skip to content

Commit

Permalink
Extract EvalCheck trait from Hal (risc0#252)
Browse files Browse the repository at this point in the history
  • Loading branch information
flaub authored Aug 24, 2022
1 parent 0d360c7 commit a34d350
Show file tree
Hide file tree
Showing 7 changed files with 129 additions and 89 deletions.
77 changes: 10 additions & 67 deletions risc0/zkp/rust/src/hal/cpu.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,27 +25,23 @@ use rayon::prelude::*;

use super::{Buffer, Hal};
use crate::{
adapter::{PolyFp, PolyFpContext},
core::{
fp::Fp,
fp4::{Fp4, EXT_SIZE},
log2_ceil,
ntt::{bit_rev_32, bit_reverse, evaluate_ntt, expand, interpolate_ntt},
rou::ROU_FWD,
sha::{Digest, Sha},
sha_cpu,
},
field::Elem,
FRI_FOLD, INV_RATE,
FRI_FOLD,
};

pub struct CpuHal<'a, C: PolyFp> {
circuit: &'a C,
}
pub struct CpuHal {}

impl<'a, C: PolyFp> CpuHal<'a, C> {
pub fn new(circuit: &'a C) -> Self {
CpuHal { circuit }
impl CpuHal {
pub fn new() -> Self {
CpuHal {}
}
}

Expand Down Expand Up @@ -92,15 +88,15 @@ impl<T: Default + Clone + Pod> CpuBuffer<T> {
}
}

fn as_slice<'a>(&'a self) -> Ref<'a, [T]> {
pub fn as_slice<'a>(&'a self) -> Ref<'a, [T]> {
let vec = self.buf.borrow();
Ref::map(vec, |vec| {
let slice = bytemuck::cast_slice(vec);
&slice[self.region.range()]
})
}

fn as_slice_mut<'a>(&'a self) -> RefMut<'a, [T]> {
pub fn as_slice_mut<'a>(&'a self) -> RefMut<'a, [T]> {
let vec = self.buf.borrow_mut();
RefMut::map(vec, |vec| {
let slice = bytemuck::cast_slice_mut(vec);
Expand Down Expand Up @@ -136,7 +132,7 @@ impl<T: Pod> Buffer<T> for CpuBuffer<T> {
}
}

impl<'a, E: PolyFp> Hal for CpuHal<'a, E> {
impl Hal for CpuHal {
type BufferFp = CpuBuffer<Fp>;
type BufferFp4 = CpuBuffer<Fp4>;
type BufferDigest = CpuBuffer<Digest>;
Expand Down Expand Up @@ -414,49 +410,6 @@ impl<'a, E: PolyFp> Hal for CpuHal<'a, E> {
*output = *sha.hash_pair(&input[0], &input[1]);
});
}

fn eval_check(
&self,
_circuit: &str,
check: &CpuBuffer<Fp>,
code: &CpuBuffer<Fp>,
data: &CpuBuffer<Fp>,
accum: &CpuBuffer<Fp>,
mix: &CpuBuffer<Fp>,
out: &CpuBuffer<Fp>,
poly_mix: Fp4,
po2: usize,
steps: usize,
) {
const EXP_PO2: usize = log2_ceil(INV_RATE);

let domain = steps * INV_RATE;
let code = code.as_slice();
let data = data.as_slice();
let accum = accum.as_slice();
let mix = mix.as_slice();
let out = out.as_slice();
let mut check = check.as_slice_mut();
// TODO: parallelize
for cycle in 0..domain {
let args: &[&[Fp]] = &[&code, &out, &data, &mix, &accum];
let cond = self.circuit.poly_fp(
&PolyFpContext {
size: domain,
cycle,
mix: poly_mix,
},
args,
);
let x = Fp::new(ROU_FWD[po2 + EXP_PO2]).pow(cycle);
// TODO: what is this magic number 3?
let y = (Fp::new(3) * x).pow(1 << po2);
let ret = cond.tot * (y - Fp::new(1)).inv();
for i in 0..EXT_SIZE {
check[i * domain + cycle] = ret.elems()[i];
}
}
}
}

#[cfg(test)]
Expand All @@ -465,28 +418,18 @@ mod test {

use super::*;

struct PolyFpMock {}

impl PolyFp for PolyFpMock {
fn poly_fp(&self, _ctx: &PolyFpContext, _args: &[&[Fp]]) -> crate::adapter::MixState {
unimplemented!()
}
}

#[test]
#[should_panic]
fn check_req() {
let mock = PolyFpMock {};
let hal = CpuHal::new(&mock);
let hal = CpuHal::new();
let a = hal.alloc_fp(10);
let b = hal.alloc_fp(20);
hal.eltwise_add_fp(&a, &b, &b);
}

#[test]
fn fp() {
let mock = PolyFpMock {};
let hal = CpuHal::new(&mock);
let hal = CpuHal::new();
const COUNT: usize = 1024 * 1024;
test_binary(
&hal,
Expand Down
15 changes: 8 additions & 7 deletions risc0/zkp/rust/src/hal/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,17 +90,18 @@ pub trait Hal {
fn sha_rows(&self, output: &Self::BufferDigest, matrix: &Self::BufferFp);

fn sha_fold(&self, io: &Self::BufferDigest, input_size: usize, output_size: usize);
}

pub trait EvalCheck<H: Hal> {
/// Compute check polynomial.
fn eval_check(
&self,
circuit: &str,
check: &Self::BufferFp,
code: &Self::BufferFp,
data: &Self::BufferFp,
accum: &Self::BufferFp,
mix: &Self::BufferFp,
out: &Self::BufferFp,
check: &H::BufferFp,
code: &H::BufferFp,
data: &H::BufferFp,
accum: &H::BufferFp,
mix: &H::BufferFp,
out: &H::BufferFp,
poly_mix: Fp4,
po2: usize,
steps: usize,
Expand Down
14 changes: 9 additions & 5 deletions risc0/zkp/rust/src/prove/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ use crate::{
sha::Sha,
},
field::Elem,
hal::{Buffer, Hal},
hal::{Buffer, EvalCheck, Hal},
prove::{fri::fri_prove, poly_group::PolyGroup, write_iop::WriteIOP},
taps::{RegisterGroup, TapSet},
CHECK_SIZE, INV_RATE, MAX_CYCLES_PO2,
Expand Down Expand Up @@ -64,12 +64,17 @@ pub trait Circuit {
fn get_steps(&self) -> usize;
}

pub fn prove_without_seal<H: Hal, S: Sha, C: Circuit>(_hal: &H, sha: &S, circuit: &mut C) {
pub fn prove_without_seal<S: Sha, C: Circuit>(sha: &S, circuit: &mut C) {
let mut iop = WriteIOP::new(sha);
circuit.execute(&mut iop);
}

pub fn prove<H: Hal, S: Sha, C: Circuit>(hal: &H, sha: &S, circuit: &mut C) -> Vec<u32> {
pub fn prove<H: Hal, S: Sha, C: Circuit, E: EvalCheck<H>>(
hal: &H,
sha: &S,
circuit: &mut C,
eval: &E,
) -> Vec<u32> {
let taps = circuit.get_taps();
let code_size = taps.group_size(RegisterGroup::Code);
let data_size = taps.group_size(RegisterGroup::Data);
Expand Down Expand Up @@ -111,8 +116,7 @@ pub fn prove<H: Hal, S: Sha, C: Circuit>(hal: &H, sha: &S, circuit: &mut C) -> V
let check_poly = hal.alloc_fp(EXT_SIZE * domain);
let mix = hal.copy_fp_from(circuit.get_mix());
let out = hal.copy_fp_from(circuit.get_output());
hal.eval_check(
"rv32im",
eval.eval_check(
&check_poly,
&code_group.evaluated,
&data_group.evaluated,
Expand Down
5 changes: 2 additions & 3 deletions risc0/zkvm/sdk/rust/src/method_id.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,15 +65,14 @@ impl MethodId {

#[cfg(feature = "prove")]
pub fn compute_with_limit(elf_contents: &[u8], limit: u32) -> Result<Self> {
use crate::{elf::Program, prove::CIRCUIT, CODE_SIZE};
use crate::{elf::Program, CODE_SIZE};
use risc0_zkp::{
hal::{cpu::CpuHal, Hal},
prove::poly_group::PolyGroup,
};
use risc0_zkvm_circuit::CircuitImpl;
use risc0_zkvm_platform::memory::MEM_SIZE;

let hal = CpuHal::<CircuitImpl>::new(&CIRCUIT);
let hal = CpuHal::new();
let program = Program::load_elf(elf_contents, MEM_SIZE as u32)?;

// Start with an empty table
Expand Down
83 changes: 83 additions & 0 deletions risc0/zkvm/sdk/rust/src/prove/cpu_eval.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
// Copyright 2022 Risc0, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use risc0_zkp::{
adapter::{PolyFp, PolyFpContext},
core::{
fp::Fp,
fp4::{Fp4, EXT_SIZE},
log2_ceil,
rou::ROU_FWD,
},
field::Elem,
hal::{
cpu::{CpuBuffer, CpuHal},
EvalCheck,
},
INV_RATE,
};

pub struct CpuEvalCheck<'a, C: PolyFp> {
circuit: &'a C,
}

impl<'a, C: PolyFp> CpuEvalCheck<'a, C> {
pub fn new(circuit: &'a C) -> Self {
Self { circuit }
}
}

impl<'a, C: PolyFp> EvalCheck<CpuHal> for CpuEvalCheck<'a, C> {
fn eval_check(
&self,
check: &CpuBuffer<Fp>,
code: &CpuBuffer<Fp>,
data: &CpuBuffer<Fp>,
accum: &CpuBuffer<Fp>,
mix: &CpuBuffer<Fp>,
out: &CpuBuffer<Fp>,
poly_mix: Fp4,
po2: usize,
steps: usize,
) {
const EXP_PO2: usize = log2_ceil(INV_RATE);

let domain = steps * INV_RATE;
let code = code.as_slice();
let data = data.as_slice();
let accum = accum.as_slice();
let mix = mix.as_slice();
let out = out.as_slice();
let mut check = check.as_slice_mut();
// TODO: parallelize
for cycle in 0..domain {
let args: &[&[Fp]] = &[&code, &out, &data, &mix, &accum];
let cond = self.circuit.poly_fp(
&PolyFpContext {
size: domain,
cycle,
mix: poly_mix,
},
args,
);
let x = Fp::new(ROU_FWD[po2 + EXP_PO2]).pow(cycle);
// TODO: what is this magic number 3?
let y = (Fp::new(3) * x).pow(1 << po2);
let ret = cond.tot * (y - Fp::new(1)).inv();
for i in 0..EXT_SIZE {
check[i * domain + cycle] = ret.elems()[i];
}
}
}
}
2 changes: 0 additions & 2 deletions risc0/zkvm/sdk/rust/src/prove/exec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -267,8 +267,6 @@ impl MemoryState {
}
}

impl MemoryState {}

fn split_word(value: u32) -> (Fp, Fp) {
(Fp::new(value & 0xffff), Fp::new(value >> 16))
}
Expand Down
22 changes: 17 additions & 5 deletions risc0/zkvm/sdk/rust/src/prove/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,17 @@
// See the License for the specific language governing permissions and
// limitations under the License.

mod cpu_eval;
pub mod exec;

use std::io::Write;

use anyhow::Result;
use lazy_static::lazy_static;
use risc0_zkp::{
core::sha::default_implementation, hal::cpu::CpuHal, prove::adapter::ProveAdapter,
core::sha::default_implementation,
hal::{cpu::CpuHal, EvalCheck, Hal},
prove::adapter::ProveAdapter,
};
use risc0_zkvm_circuit::CircuitImpl;
use risc0_zkvm_platform::{
Expand All @@ -29,6 +32,8 @@ use risc0_zkvm_platform::{

use crate::{elf::Program, host::ProverOpts, method_id::MethodId, receipt::Receipt};

use self::cpu_eval::CpuEvalCheck;

lazy_static! {
pub static ref CIRCUIT: CircuitImpl = CircuitImpl::new();
}
Expand Down Expand Up @@ -67,20 +72,27 @@ impl<'a> Prover<'a> {
}

pub fn run(&mut self) -> Result<Receipt> {
let hal = CpuHal::new();
let circuit: &CircuitImpl = &CIRCUIT;
let eval = CpuEvalCheck::new(circuit);
self.run_with_hal(&hal, &eval)
}

pub fn run_with_hal<H: Hal, E: EvalCheck<H>>(&mut self, hal: &H, eval: &E) -> Result<Receipt> {
let skip_seal = self.inner.opts.skip_seal;

let mut executor = exec::RV32Executor::new(&CIRCUIT, &self.elf, &mut self.inner);
let circuit: &CircuitImpl = &CIRCUIT;
let mut executor = exec::RV32Executor::new(circuit, &self.elf, &mut self.inner);
executor.run()?;

let mut prover = ProveAdapter::new(&mut executor.executor);
let hal = CpuHal::<CircuitImpl>::new(&CIRCUIT);
let sha = default_implementation();

let seal = if skip_seal {
risc0_zkp::prove::prove_without_seal(&hal, sha, &mut prover);
risc0_zkp::prove::prove_without_seal(sha, &mut prover);
Vec::new()
} else {
risc0_zkp::prove::prove(&hal, sha, &mut prover)
risc0_zkp::prove::prove(hal, sha, &mut prover, eval)
};

// Attach the full version of the output journal & construct receipt object
Expand Down

0 comments on commit a34d350

Please sign in to comment.