Skip to content

Commit

Permalink
Add SyncSlice for safe(r) multithreading use of slices (risc0#387)
Browse files Browse the repository at this point in the history
* Add SyncSlice to CPU buffer

This lets us share a writable CPU buffer between threads safely.

* Add SyncSlice for safe(r) multithreading use of slices

* CpuBuffer now uses SyncSlice underneath
* SyncSlice is now used for FFI instead of passing `&mut [&mut [...]]` around

---------

Co-authored-by: nils <[email protected]>
  • Loading branch information
shkoo and shkoo authored Feb 17, 2023
1 parent 754b89e commit 13cabf8
Show file tree
Hide file tree
Showing 15 changed files with 15,850 additions and 15,747 deletions.
19 changes: 11 additions & 8 deletions risc0/circuit/rv32im/src/cpp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,10 @@ use risc0_sys::ffi::{
risc0_circuit_rv32im_step_verify_bytes, risc0_circuit_rv32im_step_verify_mem,
risc0_circuit_string_free, risc0_circuit_string_ptr, Callback, RawError,
};
use risc0_zkp::adapter::{CircuitDef, CircuitStep, CircuitStepContext, CircuitStepHandler, PolyFp};
use risc0_zkp::{
adapter::{CircuitDef, CircuitStep, CircuitStepContext, CircuitStepHandler, PolyFp},
hal::cpu::SyncSlice,
};

use crate::CircuitImpl;

Expand All @@ -31,7 +34,7 @@ impl CircuitStep<BabyBearElem> for CircuitImpl {
&self,
ctx: &CircuitStepContext,
handler: &mut S,
args: &mut [&mut [BabyBearElem]],
args: &[SyncSlice<BabyBearElem>],
) -> Result<BabyBearElem> {
call_step(
ctx,
Expand All @@ -49,7 +52,7 @@ impl CircuitStep<BabyBearElem> for CircuitImpl {
&self,
ctx: &CircuitStepContext,
handler: &mut S,
args: &mut [&mut [BabyBearElem]],
args: &[SyncSlice<BabyBearElem>],
) -> Result<BabyBearElem> {
call_step(
ctx,
Expand All @@ -67,7 +70,7 @@ impl CircuitStep<BabyBearElem> for CircuitImpl {
&self,
ctx: &CircuitStepContext,
handler: &mut S,
args: &mut [&mut [BabyBearElem]],
args: &[SyncSlice<BabyBearElem>],
) -> Result<BabyBearElem> {
call_step(
ctx,
Expand All @@ -85,7 +88,7 @@ impl CircuitStep<BabyBearElem> for CircuitImpl {
&self,
ctx: &CircuitStepContext,
handler: &mut S,
args: &mut [&mut [BabyBearElem]],
args: &[SyncSlice<BabyBearElem>],
) -> Result<BabyBearElem> {
call_step(
ctx,
Expand All @@ -103,7 +106,7 @@ impl CircuitStep<BabyBearElem> for CircuitImpl {
&self,
ctx: &CircuitStepContext,
handler: &mut S,
args: &mut [&mut [BabyBearElem]],
args: &[SyncSlice<BabyBearElem>],
) -> Result<BabyBearElem> {
call_step(
ctx,
Expand Down Expand Up @@ -144,7 +147,7 @@ impl<'a> CircuitDef<BabyBear> for CircuitImpl {}
pub(crate) fn call_step<S, F>(
ctx: &CircuitStepContext,
handler: &mut S,
args: &mut [&mut [BabyBearElem]],
args: &[SyncSlice<BabyBearElem>],
inner: F,
) -> Result<BabyBearElem>
where
Expand Down Expand Up @@ -172,7 +175,7 @@ where
};
let trampoline = get_trampoline(&call);
let mut err = RawError::default();
let args: Vec<*mut BabyBearElem> = args.iter_mut().map(|x| (*x).as_mut_ptr()).collect();
let args: Vec<*mut BabyBearElem> = args.iter().map(SyncSlice::get_ptr).collect();
let result = inner(
&mut err,
&mut call as *mut _ as *mut c_void,
Expand Down
17 changes: 11 additions & 6 deletions risc0/circuit/rv32im/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,10 @@ impl TapsProvider for CircuitImpl {
#[cfg(test)]
mod tests {
use risc0_core::field::baby_bear::BabyBearElem;
use risc0_zkp::adapter::{CircuitStep, CircuitStepContext, CircuitStepHandler};
use risc0_zkp::{
adapter::{CircuitStep, CircuitStepContext, CircuitStepHandler},
hal::cpu::CpuBuffer,
};

use crate::CircuitImpl;

Expand Down Expand Up @@ -87,11 +90,13 @@ mod tests {
let circuit = CircuitImpl::new();
let mut custom = CustomStepMock {};
let ctx = CircuitStepContext { size: 0, cycle: 0 };
let mut args0 = vec![BabyBearElem::default(); 20];
let mut args2 = vec![BabyBearElem::default(); 20];
let args: &mut [&mut [BabyBearElem]] =
&mut [&mut args0, &mut [], &mut args2, &mut [], &mut []];
circuit.step_exec(&ctx, &mut custom, args).unwrap();
let args0 = CpuBuffer::from_fn(20, |_| BabyBearElem::default());
let args1 = CpuBuffer::from_fn(20, |_| BabyBearElem::default());
let args2 = CpuBuffer::from_fn(20, |_| BabyBearElem::default());
let args = [&args0, &args1, &args2].map(CpuBuffer::as_slice_sync);
circuit
.step_exec(&ctx, &mut custom, args.as_slice())
.unwrap();
}
}

Expand Down
11,698 changes: 5,849 additions & 5,849 deletions risc0/circuit/rv32im/src/poly_ext.rs

Large diffs are not rendered by default.

2,736 changes: 1,368 additions & 1,368 deletions risc0/sys/cxx/rv32im/poly_fp.cpp

Large diffs are not rendered by default.

2,900 changes: 1,450 additions & 1,450 deletions risc0/sys/cxx/rv32im/step_compute_accum.cpp

Large diffs are not rendered by default.

8,966 changes: 4,483 additions & 4,483 deletions risc0/sys/cxx/rv32im/step_exec.cpp

Large diffs are not rendered by default.

1,822 changes: 911 additions & 911 deletions risc0/sys/cxx/rv32im/step_verify_accum.cpp

Large diffs are not rendered by default.

314 changes: 157 additions & 157 deletions risc0/sys/cxx/rv32im/step_verify_bytes.cpp

Large diffs are not rendered by default.

2,782 changes: 1,391 additions & 1,391 deletions risc0/sys/cxx/rv32im/step_verify_mem.cpp

Large diffs are not rendered by default.

11 changes: 6 additions & 5 deletions risc0/zkp/src/adapter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ use alloc::vec::Vec;
use anyhow::Result;
use risc0_core::field::{Elem, ExtElem, Field};

use crate::hal::cpu::SyncSlice;
use crate::taps::TapSet;

// TODO: Remove references to these constants so we don't depend on a
Expand Down Expand Up @@ -57,35 +58,35 @@ pub trait CircuitStep<E: Elem> {
&self,
ctx: &CircuitStepContext,
custom: &mut S,
args: &mut [&mut [E]],
args: &[SyncSlice<E>],
) -> Result<E>;

fn step_verify_bytes<S: CircuitStepHandler<E>>(
&self,
ctx: &CircuitStepContext,
custom: &mut S,
args: &mut [&mut [E]],
args: &[SyncSlice<E>],
) -> Result<E>;

fn step_verify_mem<S: CircuitStepHandler<E>>(
&self,
ctx: &CircuitStepContext,
custom: &mut S,
args: &mut [&mut [E]],
args: &[SyncSlice<E>],
) -> Result<E>;

fn step_compute_accum<S: CircuitStepHandler<E>>(
&self,
ctx: &CircuitStepContext,
custom: &mut S,
args: &mut [&mut [E]],
args: &[SyncSlice<E>],
) -> Result<E>;

fn step_verify_accum<S: CircuitStepHandler<E>>(
&self,
ctx: &CircuitStepContext,
custom: &mut S,
args: &mut [&mut [E]],
args: &[SyncSlice<E>],
) -> Result<E>;
}

Expand Down
116 changes: 102 additions & 14 deletions risc0/zkp/src/hal/cpu.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ use core::{
cell::{Ref, RefMut},
marker::PhantomData,
ops::Range,
slice::{from_raw_parts, from_raw_parts_mut},
};
use std::{cell::RefCell, rc::Rc};

Expand Down Expand Up @@ -80,6 +79,71 @@ pub struct CpuBuffer<T> {
region: Region,
}

/// What a [`SyncSlice`] borrows its memory from.
///
/// A value of this type is stored in `SyncSlice::_buf` so that the
/// underlying borrow is held for as long as the slice itself exists.
enum SyncSliceRef<'a, T: Default + Clone + Pod> {
/// Used by `SyncSlice::new`: holds the `RefMut` guard over the elements.
FromBuf(RefMut<'a, [T]>),
/// Used by `SyncSlice::slice`: borrows the parent `SyncSlice` that the
/// sub-slice was carved out of.
FromSlice(&'a SyncSlice<'a, T>),
}

/// A buffer which can be used across multiple threads.
///
/// Elements are read and written through a raw pointer (`get` / `set`
/// take `&self`), so the borrow checker does not police individual
/// element accesses. Users are responsible for ensuring that no two
/// threads access the same element at the same time.
pub struct SyncSlice<'a, T: Default + Clone + Pod> {
// Holds the borrow this slice was created from; kept only so that the
// borrow lives as long as the slice.
_buf: SyncSliceRef<'a, T>,
// Pointer to the first element covered by this slice.
ptr: *mut T,
// Number of elements reachable from `ptr`.
size: usize,
}

// SAFETY: A SyncSlice created by `SyncSlice::new` holds the `RefMut`
// over the underlying elements, and one created by `SyncSlice::slice`
// borrows its parent SyncSlice. Either way the mutable borrow stays
// held while any SyncSlice over the data exists, so no other
// `as_slice` or `as_slice_mut` borrow can be active at the same time.
//
// Element accesses through `get` / `set` are NOT synchronized: the
// user of the SyncSlice is responsible for ensuring that no two
// threads access the same elements at the same time.
unsafe impl<'a, T: Default + Clone + Pod> Sync for SyncSlice<'a, T> {}

impl<'a, T: Default + Clone + Pod> SyncSlice<'a, T> {
    /// Wraps the mutable borrow `buf`, recording its base pointer and length.
    ///
    /// The `RefMut` is moved into the returned value, so the borrow stays
    /// held for as long as the `SyncSlice` exists.
    pub fn new(mut buf: RefMut<'a, [T]>) -> Self {
        let ptr = buf.as_mut_ptr();
        let size = buf.len();
        SyncSlice {
            ptr,
            size,
            _buf: SyncSliceRef::FromBuf(buf),
        }
    }

    /// Returns the raw pointer to the first element of this slice.
    pub fn get_ptr(&self) -> *mut T {
        self.ptr
    }

    /// Returns a copy of the element at `offset`.
    ///
    /// # Panics
    ///
    /// Panics if `offset >= self.size()`.
    pub fn get(&self, offset: usize) -> T {
        assert!(offset < self.size);
        // SAFETY: `offset < size` was just checked, and `ptr` addresses `size`
        // elements whose borrow is held through `_buf`. The caller must
        // guarantee that no other thread is writing this element concurrently.
        unsafe { self.ptr.add(offset).read() }
    }

    /// Overwrites the element at `offset` with `val`.
    ///
    /// # Panics
    ///
    /// Panics if `offset >= self.size()`.
    pub fn set(&self, offset: usize, val: T) {
        assert!(offset < self.size);
        // SAFETY: `offset < size` was just checked, and `ptr` addresses `size`
        // elements whose borrow is held through `_buf`. The caller must
        // guarantee that no other thread is accessing this element concurrently.
        unsafe { self.ptr.add(offset).write(val) }
    }

    /// Returns a sub-slice covering `size` elements starting at `offset`.
    ///
    /// The returned slice borrows `self`, so it cannot outlive it.
    ///
    /// # Panics
    ///
    /// Panics if `offset + size` overflows `usize` or exceeds `self.size()`.
    pub fn slice(&self, offset: usize, size: usize) -> SyncSlice<'_, T> {
        // Use checked arithmetic: a wrapping `offset + size` could otherwise
        // slip past the bounds check when overflow checks are disabled and
        // produce an out-of-bounds pointer below.
        let end = match offset.checked_add(size) {
            Some(end) => end,
            None => panic!(
                "Attempting to slice [{offset}, {offset} + {size}) overflows usize (slice length {})",
                self.size
            ),
        };
        assert!(
            end <= self.size,
            "Attempting to slice [{offset}, {offset} + {size} = {}) from a slice of length {}",
            end,
            self.size
        );
        SyncSlice {
            _buf: SyncSliceRef::FromSlice(self),
            // SAFETY: `offset <= offset + size <= self.size`, so the result is
            // at most one past the end of the `self.size` elements at `ptr`.
            ptr: unsafe { self.ptr.add(offset) },
            size,
        }
    }

    /// Returns the number of elements in this slice.
    pub fn size(&self) -> usize {
        self.size
    }
}

impl<T: Default + Clone + Pod> CpuBuffer<T> {
fn new(size: usize) -> Self {
let buf = vec![T::default(); size];
Expand All @@ -89,6 +153,10 @@ impl<T: Default + Clone + Pod> CpuBuffer<T> {
}
}

/// Returns a raw pointer to the first element of this buffer's region.
///
/// The pointer is taken from a temporary `SyncSlice`, which is dropped
/// before this method returns, so the returned pointer is not guarded by
/// any borrow of the buffer.
// NOTE(review): presumably panics if the buffer is already borrowed,
// since `as_slice_sync` goes through `as_slice_mut` — confirm.
pub fn get_ptr(&self) -> *mut T {
    let sync = self.as_slice_sync();
    sync.get_ptr()
}

fn copy_from(slice: &[T]) -> Self {
let bytes = bytemuck::cast_slice(slice);
CpuBuffer {
Expand All @@ -97,6 +165,16 @@ impl<T: Default + Clone + Pod> CpuBuffer<T> {
}
}

/// Builds a buffer of `size` elements, where element `i` is `f(i)`.
///
/// `f` is called once per index, in increasing order from `0`.
pub fn from_fn<F>(size: usize, f: F) -> Self
where
    F: FnMut(usize) -> T,
{
    let elems: Vec<T> = (0..size).map(f).collect();
    let buf = Rc::new(RefCell::new(elems));
    CpuBuffer {
        buf,
        region: Region(0, size),
    }
}

pub fn as_slice<'a>(&'a self) -> Ref<'a, [T]> {
let vec = self.buf.borrow();
Ref::map(vec, |vec| {
Expand All @@ -112,6 +190,20 @@ impl<T: Default + Clone + Pod> CpuBuffer<T> {
&mut slice[self.region.range()]
})
}

/// Mutably borrows this buffer's region and wraps the borrow in a
/// [`SyncSlice`], which holds it until the `SyncSlice` is dropped.
pub fn as_slice_sync<'a>(&'a self) -> SyncSlice<'a, T> {
    let guard = self.as_slice_mut();
    SyncSlice::new(guard)
}
}

/// Converts an owned vector into a buffer without copying its elements.
impl<T: Default + Clone + Pod> From<Vec<T>> for CpuBuffer<T> {
    /// Takes ownership of `vec`; the resulting region spans `0..vec.len()`.
    fn from(vec: Vec<T>) -> CpuBuffer<T> {
        // Capture the length before `vec` is moved into the cell.
        let region = Region(0, vec.len());
        let buf = Rc::new(RefCell::new(vec));
        CpuBuffer { buf, region }
    }
}

impl<T: Pod> Buffer<T> for CpuBuffer<T> {
Expand Down Expand Up @@ -435,20 +527,16 @@ impl<F: Field, HS: HashSuite<F>> Hal for CpuHal<F, HS> {
}

fn hash_fold(&self, io: &Self::BufferDigest, input_size: usize, output_size: usize) {
assert!(io.size() >= 2 * input_size);
assert_eq!(input_size, 2 * output_size);
let mut io = io.as_slice_mut();
let (output, input) = unsafe {
(
from_raw_parts_mut(io.as_mut_ptr().add(output_size), output_size),
from_raw_parts(io.as_ptr().add(input_size), input_size),
)
};
output
.par_iter_mut()
.zip(input.par_chunks_exact(2))
.for_each(|(output, input)| {
*output = *Self::Hash::hash_pair(&input[0], &input[1]);
});
let io = io.as_slice_sync();
let output = io.slice(output_size, output_size);
let input = io.slice(input_size, input_size);
(0..output.size()).into_par_iter().for_each(|idx| {
let in1 = input.get(2 * idx + 0);
let in2 = input.get(2 * idx + 1);
output.set(idx, *Self::Hash::hash_pair(&in1, &in2));
});
}
}

Expand Down
9 changes: 9 additions & 0 deletions risc0/zkp/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,15 @@ pub mod prove;
pub mod taps;
pub mod verify;

// Stand-in `hal::cpu` module, compiled only when the `prove` feature is
// disabled.
// NOTE(review): presumably exists so that code naming
// `hal::cpu::SyncSlice` still compiles without `prove` — confirm.
#[cfg(not(feature = "prove"))]
pub mod hal {
pub mod cpu {
use core::marker::PhantomData;
// TODO: Don't depend on SyncSlice in non-proving code.
// Data-less placeholder: it carries only the element type `T` and
// defines no methods here.
pub struct SyncSlice<T>(PhantomData<T>);
}
}

pub use risc0_core::field;

pub const MIN_CYCLES_PO2: usize = 11;
Expand Down
Loading

0 comments on commit 13cabf8

Please sign in to comment.