Skip to content

Commit

Permalink
update the code
Browse files Browse the repository at this point in the history
  • Loading branch information
weikengchen committed Jan 19, 2024
1 parent cfbd4ce commit 0e609d8
Show file tree
Hide file tree
Showing 11 changed files with 375 additions and 46 deletions.
10 changes: 8 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,10 @@ crate-type = ["cdylib"]
pyo3 = { version = "0.20.2", features = ["extension-module", "anyhow", "auto-initialize"] }
risc0-binfmt = { git = "https://github.com/l2iterative/risc0/", branch="no-rust-runtime-for-host" }
risc0-zkvm-platform = { git = "https://github.com/l2iterative/risc0/", branch="no-rust-runtime-for-host" }
risc0-zkvm = { git = "https://github.com/l2iterative/risc0/", branch="no-rust-runtime-for-host", features = ["prove", "metal"] }
risc0-zkvm = { git = "https://github.com/l2iterative/risc0/", branch="no-rust-runtime-for-host", features = ["prove"] }
anyhow = "1.0.79"
serde = "1.0"
bincode = "1.3.3"

[profile.dev]
opt-level = 3
Expand All @@ -22,7 +24,11 @@ opt-level = 3

[profile.release]
lto = true
debug = true

[profile.release.build-override]
opt-level = 3

[features]
default = []
metal = ["risc0-zkvm/metal"]
cuda = ["risc0-zkvm/cuda"]
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
# r0prover-python
Python bindings for RISC0 prover
## Python/Ray wrapper for RISC Zero prover

5 changes: 5 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,11 @@ build-backend = "maturin"

[project]
name = "l2_r0prover"
version = "0.0.1"
authors = [
{ name = "Weikeng Chen", email = "[email protected]" }
]
description = ""
requires-python = ">=3.7"
classifiers = [
"Programming Language :: Rust",
Expand Down
30 changes: 26 additions & 4 deletions src/image.rs
Original file line number Diff line number Diff line change
@@ -1,24 +1,46 @@
use crate::serialization::Pickleable;
use anyhow::Result;
use pyo3::prelude::*;
use risc0_binfmt::{MemoryImage, Program};
use risc0_zkvm_platform::memory::GUEST_MAX_MEM;
use risc0_zkvm_platform::PAGE_SIZE;
use serde::{Deserialize, Serialize};

#[pyclass]
#[pyclass(module = "l2_r0prover")]
#[derive(Serialize, Deserialize, Clone)]
pub struct Image {
memory_image: MemoryImage,
memory_image: Option<MemoryImage>,
}

impl Image {
pub fn from_elf(elf: &[u8]) -> Result<Self> {
let program = Program::load_elf(elf, GUEST_MAX_MEM as u32)?;
let image = MemoryImage::new(&program, PAGE_SIZE as u32)?;
Ok(Self {
memory_image: image,
memory_image: Some(image),
})
}

pub fn get_image(&self) -> MemoryImage {
self.memory_image.clone()
self.memory_image.as_ref().unwrap().clone()
}
}

impl Pickleable for Image {}

#[pymethods]
impl Image {
#[new]
fn new_init() -> Self {
Self { memory_image: None }
}

fn __getstate__(&self, py: Python<'_>) -> PyResult<PyObject> {
self.to_bytes(py)
}

fn __setstate__(&mut self, py: Python<'_>, state: PyObject) -> PyResult<()> {
*self = Self::from_bytes(state, py)?;
Ok(())
}
}
82 changes: 74 additions & 8 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,29 +1,41 @@
mod image;
mod segment;
mod serialization;
mod session;
mod succinct;

use crate::image::Image;
use crate::segment::{Segment, SegmentReceipt};
use crate::session::SessionInfo;
use crate::session::{ExitCode, SessionInfo};
use crate::succinct::SuccinctReceipt;
use pyo3::prelude::*;
use pyo3::types::PyBytes;
use risc0_zkvm::{ExecutorEnv, ExecutorImpl, FileSegmentRef, SimpleSegmentRef, VerifierContext};
use risc0_zkvm::{
get_prover_server, ExecutorEnv, ExecutorImpl, ProverOpts, SimpleSegmentRef, VerifierContext,
};

#[pyfunction]
fn load_image_from_elf(elf: &PyBytes) -> PyResult<Image> {
Ok(Image::from_elf(elf.as_bytes())?)
}

#[pyfunction]
fn execute_with_input(image: &Image, input: &PyBytes) -> PyResult<(Vec<Segment>, SessionInfo)> {
let env = ExecutorEnv::builder()
.write_slice(input.as_bytes())
.build()?;
fn execute_with_input(
image: &Image,
input: &PyBytes,
segment_size_limit: Option<u32>,
) -> PyResult<(Vec<Segment>, SessionInfo)> {
let mut env_builder = ExecutorEnv::builder();
env_builder.write_slice(input.as_bytes());

if let Some(segment_size_limit) = segment_size_limit {
env_builder.segment_limit_po2(segment_size_limit);
}
let env = env_builder.build()?;

let mut exec = ExecutorImpl::new(env, image.get_image())?;

let time = std::time::Instant::now();
let session = exec.run()?;
let session = exec.run_with_callback(|segment| Ok(Box::new(SimpleSegmentRef::new(segment))))?;

let mut segments = vec![];
for segment_ref in session.segments.iter() {
Expand All @@ -41,10 +53,64 @@ fn prove_segment(segment: &Segment) -> PyResult<SegmentReceipt> {
Ok(res)
}

#[pyfunction]
fn lift_segment_receipt(segment_receipt: &SegmentReceipt) -> PyResult<SuccinctReceipt> {
let prover = get_prover_server(&ProverOpts::default())?;
Ok(SuccinctReceipt::new(
prover.lift(segment_receipt.get_segment_receipt_ref())?,
))
}

#[pyfunction]
fn join_succinct_receipts(receipts: Vec<PyRef<SuccinctReceipt>>) -> PyResult<SuccinctReceipt> {
let prover = get_prover_server(&ProverOpts::default())?;
assert!(receipts.len() > 0);

if receipts.len() == 1 {
Ok(receipts[0].clone())
} else {
let mut acc = prover.join(
receipts[0].get_succinct_receipt_ref(),
receipts[1].get_succinct_receipt_ref(),
)?;
for receipt in receipts.iter().skip(2) {
acc = prover.join(&acc, &receipt.get_succinct_receipt_ref())?;
}
Ok(SuccinctReceipt::new(acc))
}
}

#[pyfunction]
fn join_segment_receipts(receipts: Vec<PyRef<SegmentReceipt>>) -> PyResult<SuccinctReceipt> {
let prover = get_prover_server(&ProverOpts::default())?;
assert!(receipts.len() > 0);

if receipts.len() == 1 {
Ok(SuccinctReceipt::new(
prover.lift(receipts[0].get_segment_receipt_ref())?,
))
} else {
let mut acc = prover.lift(receipts[0].get_segment_receipt_ref())?;
for receipt in receipts.iter().skip(1) {
acc = prover.join(&acc, &prover.lift(receipt.get_segment_receipt_ref())?)?;
}
Ok(SuccinctReceipt::new(acc))
}
}

#[pymodule]
fn l2_r0prover(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
m.add_class::<Image>()?;
m.add_class::<Segment>()?;
m.add_class::<ExitCode>()?;
m.add_class::<SessionInfo>()?;
m.add_class::<SegmentReceipt>()?;
m.add_class::<SuccinctReceipt>()?;
m.add_function(wrap_pyfunction!(load_image_from_elf, m)?)?;
m.add_function(wrap_pyfunction!(execute_with_input, m)?)?;
m.add_function(wrap_pyfunction!(prove_segment, m)?)?;
m.add_function(wrap_pyfunction!(lift_segment_receipt, m)?)?;
m.add_function(wrap_pyfunction!(join_succinct_receipts, m)?)?;
m.add_function(wrap_pyfunction!(join_segment_receipts, m)?)?;
Ok(())
}
68 changes: 61 additions & 7 deletions src/segment.rs
Original file line number Diff line number Diff line change
@@ -1,29 +1,83 @@
use crate::serialization::Pickleable;
use anyhow::Result;
use pyo3::prelude::*;
use risc0_zkvm::VerifierContext;
use serde::{Deserialize, Serialize};

#[pyclass]
#[pyclass(module = "l2_r0prover")]
#[derive(Serialize, Deserialize, Clone)]
pub struct Segment {
segment: risc0_zkvm::Segment,
segment: Option<risc0_zkvm::Segment>,
}

impl Segment {
pub fn new(segment: risc0_zkvm::Segment) -> Self {
Self { segment }
Self {
segment: Some(segment),
}
}

pub fn prove(&self, verifier_context: &VerifierContext) -> Result<SegmentReceipt> {
Ok(SegmentReceipt::new(self.segment.prove(verifier_context)?))
Ok(SegmentReceipt::new(
self.segment.as_ref().unwrap().prove(verifier_context)?,
))
}
}

#[pyclass]
impl Pickleable for Segment {}

#[pymethods]
impl Segment {
#[new]
fn new_init() -> Self {
Self { segment: None }
}

fn __getstate__(&self, py: Python<'_>) -> PyResult<PyObject> {
self.to_bytes(py)
}

fn __setstate__(&mut self, py: Python<'_>, state: PyObject) -> PyResult<()> {
*self = Self::from_bytes(state, py)?;
Ok(())
}
}

#[pyclass(module = "l2_r0prover")]
#[derive(Serialize, Deserialize, Clone)]
pub struct SegmentReceipt {
segment_receipt: risc0_zkvm::SegmentReceipt,
segment_receipt: Option<risc0_zkvm::SegmentReceipt>,
}

impl SegmentReceipt {
pub fn new(segment_receipt: risc0_zkvm::SegmentReceipt) -> Self {
Self { segment_receipt }
Self {
segment_receipt: Some(segment_receipt),
}
}

pub fn get_segment_receipt_ref(&self) -> &risc0_zkvm::SegmentReceipt {
&self.segment_receipt.as_ref().unwrap()
}
}

impl Pickleable for SegmentReceipt {}

#[pymethods]
impl SegmentReceipt {
#[new]
fn new_init() -> Self {
Self {
segment_receipt: None,
}
}

fn __getstate__(&self, py: Python<'_>) -> PyResult<PyObject> {
self.to_bytes(py)
}

fn __setstate__(&mut self, py: Python<'_>, state: PyObject) -> PyResult<()> {
*self = Self::from_bytes(state, py)?;
Ok(())
}
}
30 changes: 30 additions & 0 deletions src/serialization.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
use anyhow::anyhow;
use pyo3::types::PyBytes;
use pyo3::{PyObject, PyResult, Python, ToPyObject};
use serde::de::DeserializeOwned;
use serde::Serialize;

// The code here that implements pickle for PyO3 classes comes from
// https://github.com/rth/vtext
// which is under Apache 2.0.
//
// And it is related to this issue in PyO3:
// https://github.com/PyO3/pyo3/issues/100

pub trait Pickleable: Serialize + DeserializeOwned + Clone {
fn to_bytes(&self, py: Python<'_>) -> PyResult<PyObject> {
let bytes = bincode::serialize(&self).map_err(|e| anyhow!("failed to serialize: {}", e))?;
Ok(PyBytes::new(py, &bytes).to_object(py))
}

fn from_bytes(state: PyObject, py: Python<'_>) -> PyResult<Self> {
match state.extract::<&PyBytes>(py) {
Ok(s) => {
let res: Self = bincode::deserialize(s.as_bytes())
.map_err(|e| anyhow!("failed to deserialize: {}", e))?;
Ok(res)
}
Err(e) => Err(anyhow!("failed to parse the pickled data as bytes: {}", e).into()),
}
}
}
Loading

0 comments on commit 0e609d8

Please sign in to comment.