Skip to content

Commit

Permalink
Improve GPU performance (risc0#2211)
Browse files Browse the repository at this point in the history
  • Loading branch information
flaub authored Aug 13, 2024
1 parent c71a847 commit c0642da
Show file tree
Hide file tree
Showing 46 changed files with 966 additions and 434 deletions.
44 changes: 40 additions & 4 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

20 changes: 16 additions & 4 deletions benchmarks/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

20 changes: 16 additions & 4 deletions examples/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion risc0/build/src/docker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,7 @@ mod test {
build("../../risc0/zkvm/methods/guest/Cargo.toml");
compare_image_id(
"risc0_zkvm_methods_guest/hello_commit",
- "dcec193a7e790dc71b2b4b89ff6a173eb963b9b8d0c7b3b0bef9a9aa8dad86fa",
+ "e1ea2ba980dd6d82ba5047b39821a7403d63ee9e5bd30da54d705b2097ce2444",
);
}
}
6 changes: 5 additions & 1 deletion risc0/build_kernel/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,10 @@ impl KernelBuild {
}

fn compile_cpp(&mut self, output: &str) {
+ if env::var("RISC0_SKIP_BUILD_KERNELS").is_ok() {
+ return;
+ }
+
// It's *highly* recommended to install `sccache` and use this combined with
// `RUSTC_WRAPPER=/path/to/sccache` to speed up rebuilds of C++ kernels
cc::Build::new()
Expand Down Expand Up @@ -341,7 +345,7 @@ impl KernelBuild {
let out_path = out_dir.join(output).with_extension(extension);
let sys_inc_dir = out_dir.join("_sys_");

- if env::var("RISC0_SKIP_BUILD_KERNELS").is_ok() || env::var("RISC0_SKIP_BUILD").is_ok() {
+ if env::var("RISC0_SKIP_BUILD_KERNELS").is_ok() {
fs::OpenOptions::new()
.create(true)
.truncate(true)
Expand Down
3 changes: 1 addition & 2 deletions risc0/circuit/recursion/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ metal = { workspace = true }
bytemuck = "1.12"
cust = { version = "0.3", optional = true }
lazy-regex = { version = "3.2", optional = true }
- nvtx = { version = "1.3", optional = true }
rand = { version = "0.8", optional = true }
rayon = { version = "1.5", optional = true }
risc0-circuit-recursion-sys = { workspace = true, optional = true }
Expand Down Expand Up @@ -66,13 +65,13 @@ prove = [
"dep:cfg-if",
"dep:downloader",
"dep:lazy-regex",
- "dep:nvtx",
"dep:rand",
"dep:rayon",
"dep:risc0-sys",
"dep:serde",
"dep:sha2",
"dep:zip",
+ "risc0-core/perf",
"risc0-zkp/prove",
"risc0-circuit-recursion-sys",
"std",
Expand Down
7 changes: 4 additions & 3 deletions risc0/circuit/recursion/src/cpu.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
use std::sync::Mutex;

use rayon::prelude::*;
+ use risc0_core::scope;
use risc0_zkp::{
adapter::{CircuitStep, CircuitStepContext, PolyFp},
core::log2_ceil,
Expand Down Expand Up @@ -122,7 +123,7 @@ where
];

let accumulator: Mutex<Accum<BabyBearExtElem>> = Mutex::new(Accum::new(steps));
- tracing::info_span!("step_compute_accum").in_scope(|| {
+ scope!("step_compute_accum", {
(0..steps - ZK_CYCLES).into_par_iter().for_each_init(
|| Handler::<BabyBear>::new(&accumulator),
|handler, cycle| {
Expand All @@ -136,10 +137,10 @@ where
},
);
});
- tracing::info_span!("calc_prefix_products").in_scope(|| {
+ scope!("calc_prefix_products", {
accumulator.lock().unwrap().calc_prefix_products();
});
- tracing::info_span!("step_verify_accum").in_scope(|| {
+ scope!("step_verify_accum", {
(0..steps - ZK_CYCLES).into_par_iter().for_each_init(
|| Handler::<BabyBear>::new(&accumulator),
|handler, cycle| {
Expand Down
9 changes: 5 additions & 4 deletions risc0/circuit/recursion/src/cuda.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
use std::rc::Rc;

use cust::{memory::GpuBuffer as _, prelude::*};
+ use risc0_core::scope;
use risc0_sys::{cuda::SpparkError, CppError};
use risc0_zkp::{
core::log2_ceil,
Expand Down Expand Up @@ -138,7 +139,7 @@ impl<CH: CudaHash> CircuitHal<CudaHal<CH>> for CudaCircuitHal<CH> {
let wom = vec![DeviceExtElem(BabyBearExtElem::ONE); steps];
let wom = DeviceBuffer::from_slice(&wom).unwrap();

- tracing::info_span!("step_compute_accum").in_scope(|| {
+ scope!("step_compute_accum", {
extern "C" {
fn risc0_circuit_recursion_cuda_step_compute_accum(
ctrl: DevicePointer<u8>,
Expand All @@ -163,7 +164,7 @@ impl<CH: CudaHash> CircuitHal<CudaHal<CH>> for CudaCircuitHal<CH> {
}
});

- tracing::info_span!("prefix_products").in_scope(|| {
+ scope!("prefix_products", {
extern "C" {
fn sppark_calc_prefix_operation(
d_elems: DevicePointer<DeviceExtElem>,
Expand All @@ -184,7 +185,7 @@ impl<CH: CudaHash> CircuitHal<CudaHal<CH>> for CudaCircuitHal<CH> {
}
});

- tracing::info_span!("step_verify_accum").in_scope(|| {
+ scope!("step_verify_accum", {
extern "C" {
fn risc0_circuit_recursion_cuda_step_verify_accum(
ctrl: DevicePointer<u8>,
Expand All @@ -211,7 +212,7 @@ impl<CH: CudaHash> CircuitHal<CudaHal<CH>> for CudaCircuitHal<CH> {
}
});

- tracing::info_span!("zeroize").in_scope(|| {
+ scope!("zeroize", {
self.hal.eltwise_zeroize_elem(accum);
self.hal.eltwise_zeroize_elem(io);
});
Expand Down
4 changes: 3 additions & 1 deletion risc0/circuit/recursion/src/metal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
use std::{collections::HashMap, rc::Rc};

use metal::ComputePipelineDescriptor;
+ use risc0_core::scope;
use risc0_zkp::{
core::log2_ceil,
field::{
Expand Down Expand Up @@ -57,7 +58,6 @@ impl<MH: MetalHash> MetalCircuitHal<MH> {
}

impl<MH: MetalHash> CircuitHal<MetalHal<MH>> for MetalCircuitHal<MH> {
- #[tracing::instrument(skip_all)]
fn eval_check(
&self,
check: &MetalBuffer<BabyBearElem>,
Expand All @@ -67,6 +67,8 @@ impl<MH: MetalHash> CircuitHal<MetalHal<MH>> for MetalCircuitHal<MH> {
po2: usize,
steps: usize,
) {
+ scope!("eval_check");
+
const EXP_PO2: usize = log2_ceil(INV_RATE);
let domain = steps * INV_RATE;
let rou = BabyBearElem::ROU_FWD[po2 + EXP_PO2];
Expand Down
Loading

0 comments on commit c0642da

Please sign in to comment.