Skip to content

Commit

Permalink
Improve eval check performance by precomputing powers of poly_mix (ri…
Browse files Browse the repository at this point in the history
…sc0#1537)

CUDA:
```
fib/100/prove
-------------
base        1.00     672.7±3.76ms  95.1 KElem/sec
changes     1.65    1107.9±3.14ms  57.8 KElem/sec

fib/1000/prove
--------------
base        1.00     703.6±3.33ms  91.0 KElem/sec
changes     1.61    1133.8±4.53ms  56.4 KElem/sec

fib/10000/prove
---------------
changes     1.00        2.1±0.01s  120.9 KElem/sec
base        1.03        2.2±0.01s  117.3 KElem/sec
```
CPU:
```
fib/100/prove
-------------
changes     1.00        2.9±0.01s  22.4 KElem/sec
base        1.46        4.2±0.01s  15.3 KElem/sec

fib/1000/prove
--------------
changes     1.00        2.9±0.01s  22.1 KElem/sec
base        1.46        4.2±0.01s  15.2 KElem/sec

fib/10000/prove
---------------
changes     1.00       11.6±0.03s  22.2 KElem/sec
base        1.47       17.0±0.03s  15.1 KElem/sec

```

Co-authored-by: nils <[email protected]>
  • Loading branch information
shkoo and shkoo authored Mar 13, 2024
1 parent 53e6e71 commit fbf5294
Show file tree
Hide file tree
Showing 28 changed files with 49,273 additions and 49,430 deletions.
2,221 changes: 1,108 additions & 1,113 deletions risc0/circuit/recursion-sys/cxx/poly_fp.cpp

Large diffs are not rendered by default.

8,309 changes: 4,041 additions & 4,268 deletions risc0/circuit/recursion-sys/cxx/step_exec.cpp

Large diffs are not rendered by default.

2,226 changes: 1,110 additions & 1,116 deletions risc0/circuit/recursion-sys/kernels/cuda/eval_check.cu

Large diffs are not rendered by default.

2,223 changes: 1,109 additions & 1,114 deletions risc0/circuit/recursion-sys/kernels/metal/eval_check.metal

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions risc0/circuit/recursion-sys/src/ffi.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright 2023 RISC Zero, Inc.
// Copyright 2024 RISC Zero, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -103,7 +103,7 @@ extern "C" {
pub fn risc0_circuit_recursion_poly_fp(
cycle: usize,
steps: usize,
poly_mix: *const BabyBearExtElem,
poly_mixs: *const BabyBearExtElem,
args_ptr: *const *const BabyBearElem,
args_len: usize,
) -> BabyBearExtElem;
Expand Down
12 changes: 3 additions & 9 deletions risc0/circuit/recursion/src/cpp.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright 2023 RISC Zero, Inc.
// Copyright 2024 RISC Zero, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -126,18 +126,12 @@ impl PolyFp<BabyBear> for CircuitImpl {
&self,
cycle: usize,
steps: usize,
mix: &BabyBearExtElem,
mix: &[BabyBearExtElem],
args: &[&[BabyBearElem]],
) -> BabyBearExtElem {
let args: Vec<*const BabyBearElem> = args.iter().map(|x| (*x).as_ptr()).collect();
unsafe {
risc0_circuit_recursion_poly_fp(
cycle,
steps,
mix as *const BabyBearExtElem,
args.as_ptr(),
args.len(),
)
risc0_circuit_recursion_poly_fp(cycle, steps, mix.as_ptr(), args.as_ptr(), args.len())
}
}
}
Expand Down
9 changes: 6 additions & 3 deletions risc0/circuit/recursion/src/cpu.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright 2023 RISC Zero, Inc.
// Copyright 2024 RISC Zero, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
Expand All @@ -18,7 +18,7 @@ use risc0_zkp::{
core::log2_ceil,
field::{
baby_bear::{BabyBear, BabyBearElem, BabyBearExtElem},
Elem, ExtElem, RootsOfUnity,
map_pow, Elem, ExtElem, RootsOfUnity,
},
hal::{cpu::CpuBuffer, CircuitHal, Hal},
INV_RATE,
Expand Down Expand Up @@ -59,6 +59,8 @@ where
const EXP_PO2: usize = log2_ceil(INV_RATE);
let domain = steps * INV_RATE;

let poly_mix_pows = map_pow(poly_mix, crate::info::POLY_MIX_POWERS);

// SAFETY: Convert a borrow of a cell into a raw const slice so that we can pass
// it over the thread boundary. This should be safe because the scope of the
// usage is within this function and each thread access will not overlap with
Expand All @@ -76,11 +78,12 @@ where
let out = unsafe { std::slice::from_raw_parts(out.as_ptr(), out.len()) };
let check = check.as_slice();
let check = unsafe { std::slice::from_raw_parts(check.as_ptr(), check.len()) };
let poly_mix_pows = poly_mix_pows.as_slice();

let args: &[&[BabyBearElem]] = &[&code, &out, &data, &mix, &accum];

(0..domain).into_par_iter().for_each(|cycle| {
let tot = self.circuit.poly_fp(cycle, domain, &poly_mix, args);
let tot = self.circuit.poly_fp(cycle, domain, poly_mix_pows, args);
let x = BabyBearElem::ROU_FWD[po2 + EXP_PO2].pow(cycle);
// TODO: what is this magic number 3?
let y = (BabyBearElem::new(3) * x).pow(1 << po2);
Expand Down
19 changes: 15 additions & 4 deletions risc0/circuit/recursion/src/cuda.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright 2023 RISC Zero, Inc.
// Copyright 2024 RISC Zero, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
Expand All @@ -19,7 +19,7 @@ use risc0_zkp::{
core::log2_ceil,
field::{
baby_bear::{BabyBearElem, BabyBearExtElem},
RootsOfUnity,
map_pow, Elem, ExtElem, RootsOfUnity,
},
hal::{
cuda::{
Expand Down Expand Up @@ -84,11 +84,23 @@ impl<'a, CH: CudaHash> CircuitHal<CudaHal<CH>> for CudaCircuitHal<CH> {
let domain = steps * INV_RATE;
let rou = BabyBearElem::ROU_FWD[po2 + EXP_PO2];

let poly_mix = CudaBuffer::copy_from("poly_mix", &[poly_mix]);
let rou = CudaBuffer::copy_from("rou", &[rou]);
let po2 = CudaBuffer::copy_from("po2", &[po2 as u32]);
let size = CudaBuffer::copy_from("size", &[domain as u32]);

let poly_mix_pows = map_pow(poly_mix, crate::info::POLY_MIX_POWERS);
let poly_mix_pows: &[u32; BabyBearExtElem::EXT_SIZE * crate::info::NUM_POLY_MIX_POWERS] =
BabyBearExtElem::as_u32_slice(poly_mix_pows.as_slice())
.try_into()
.unwrap();

let mix_pows_name = std::ffi::CString::new("poly_mix").unwrap();
self.module
.get_global(&mix_pows_name)
.unwrap()
.copy_from(poly_mix_pows)
.unwrap();

let stream = Stream::new(StreamFlags::DEFAULT, None).unwrap();

let kernel = self.module.get_function("eval_check").unwrap();
Expand All @@ -101,7 +113,6 @@ impl<'a, CH: CudaHash> CircuitHal<CudaHal<CH>> for CudaCircuitHal<CH> {
accum.as_device_ptr(),
mix.as_device_ptr(),
out.as_device_ptr(),
poly_mix.as_device_ptr(),
rou.as_device_ptr(),
po2.as_device_ptr(),
size.as_device_ptr()
Expand Down
14 changes: 14 additions & 0 deletions risc0/circuit/recursion/src/info.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,17 @@ impl CircuitInfo for CircuitImpl {
#[rustfmt::skip]
const MIX_SIZE: usize = 20;
}

/// Number of entries in [`POLY_MIX_POWERS`].
///
/// Used as a compile-time array length by callers that need a fixed-size
/// buffer holding one extension-field element per precomputed power.
// NOTE(review): `dead_code` is presumably allowed because not every build
// configuration references this constant — confirm against the feature gates.
#[allow(dead_code)]
pub const NUM_POLY_MIX_POWERS: usize = 149;

/// Exponents of `poly_mix` that are precomputed once per eval check.
///
/// The table is handed to `map_pow` together with `poly_mix`; the resulting
/// slice is what the `poly_fp` implementations receive instead of the bare
/// `poly_mix` value. Entries are listed in strictly increasing order.
// NOTE(review): the specific exponents look circuit-generated; do not edit by
// hand without regenerating the matching kernels — confirm with the generator.
#[allow(dead_code)]
pub const POLY_MIX_POWERS: &[usize] = &[
    0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25,
    26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49,
    50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73,
    74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 103,
    110, 121, 122, 132, 137, 144, 151, 158, 168, 177, 188, 203, 216, 224, 228, 231, 242, 243, 258,
    265, 272, 279, 289, 298, 309, 324, 337, 345, 352, 363, 420, 501, 516, 533, 629, 696, 760, 805,
    905, 950, 955, 960, 1005, 1017, 1065, 1077, 1081, 1085, 1097, 1101, 1105, 1141,
];

// Fail the build, rather than a runtime `try_into().unwrap()`, if the count
// and the table ever drift apart.
const _: () = assert!(
    POLY_MIX_POWERS.len() == NUM_POLY_MIX_POWERS,
    "NUM_POLY_MIX_POWERS must equal POLY_MIX_POWERS.len()"
);
14 changes: 9 additions & 5 deletions risc0/circuit/recursion/src/metal.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright 2023 RISC Zero, Inc.
// Copyright 2024 RISC Zero, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
Expand All @@ -19,7 +19,7 @@ use risc0_zkp::{
core::log2_ceil,
field::{
baby_bear::{BabyBearElem, BabyBearExtElem},
RootsOfUnity,
map_pow, RootsOfUnity,
},
hal::{
metal::{BufferImpl as MetalBuffer, MetalHal, MetalHash},
Expand Down Expand Up @@ -63,10 +63,14 @@ impl<MH: MetalHash> CircuitHal<MetalHal<MH>> for MetalCircuitHal<MH> {
) {
const EXP_PO2: usize = log2_ceil(INV_RATE);
let domain = steps * INV_RATE;
let poly_mix =
MetalBuffer::copy_from(&self.hal.device, self.hal.cmd_queue.clone(), &[poly_mix]);
let rou = BabyBearElem::ROU_FWD[po2 + EXP_PO2];
let rou = MetalBuffer::copy_from(&self.hal.device, self.hal.cmd_queue.clone(), &[rou]);
let poly_mix_pows = map_pow(poly_mix, crate::info::POLY_MIX_POWERS);
let poly_mix_pows = MetalBuffer::copy_from(
&self.hal.device,
self.hal.cmd_queue.clone(),
poly_mix_pows.as_slice(),
);
let po2 =
MetalBuffer::copy_from(&self.hal.device, self.hal.cmd_queue.clone(), &[po2 as u32]);
let size = MetalBuffer::copy_from(
Expand All @@ -81,7 +85,7 @@ impl<MH: MetalHash> CircuitHal<MetalHal<MH>> for MetalCircuitHal<MH> {
groups[REGISTER_GROUP_ACCUM].as_arg(),
globals[GLOBAL_MIX].as_arg(),
globals[GLOBAL_OUT].as_arg(),
poly_mix.as_arg(),
poly_mix_pows.as_arg(),
rou.as_arg(),
po2.as_arg(),
size.as_arg(),
Expand Down
4 changes: 2 additions & 2 deletions risc0/circuit/recursion/src/recursion_zkr.zip
Git LFS file not shown
Loading

0 comments on commit fbf5294

Please sign in to comment.