Skip to content

Commit

Permalink
keccak: Reset CUDA stack limit (risc0#2671)
Browse files Browse the repository at this point in the history
  • Loading branch information
flaub authored Dec 19, 2024
1 parent 0231629 commit 3190365
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 2 deletions.
11 changes: 11 additions & 0 deletions risc0/circuit/keccak-sys/kernels/cuda/ffi.cu
Original file line number Diff line number Diff line change
Expand Up @@ -233,4 +233,15 @@ const char* risc0_circuit_keccak_cuda_scatter(Fp* into,
return nullptr;
}

const char* risc0_circuit_keccak_cuda_reset() {
try {
CUDA_OK(cudaDeviceSetLimit(cudaLimit::cudaLimitStackSize, 0));
} catch (const std::exception& err) {
return strdup(err.what());
} catch (...) {
return strdup("Generic exception");
}
return nullptr;
}

} // extern "C"
2 changes: 2 additions & 0 deletions risc0/circuit/keccak-sys/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -95,4 +95,6 @@ extern "C" {
domain: u32,
poly_mix_pows: *const u32,
) -> *const std::os::raw::c_char;

pub fn risc0_circuit_keccak_cuda_reset() -> *const std::os::raw::c_char;
}
15 changes: 13 additions & 2 deletions risc0/circuit/keccak/src/prove/hal/cuda.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,9 @@ use std::rc::Rc;

use anyhow::Result;
use risc0_circuit_keccak_sys::{
risc0_circuit_keccak_cuda_eval_check, risc0_circuit_keccak_cuda_scatter,
risc0_circuit_keccak_cuda_witgen, RawBuffer, RawExecBuffers, RawPreflightTrace, ScatterInfo,
risc0_circuit_keccak_cuda_eval_check, risc0_circuit_keccak_cuda_reset,
risc0_circuit_keccak_cuda_scatter, risc0_circuit_keccak_cuda_witgen, RawBuffer, RawExecBuffers,
RawPreflightTrace, ScatterInfo,
};
use risc0_core::{
field::{
Expand Down Expand Up @@ -58,6 +59,12 @@ impl<CH: CudaHash> CudaCircuitHal<CH> {
}
}

impl<CH: CudaHash> Drop for CudaCircuitHal<CH> {
fn drop(&mut self) {
cuda_reset();
}
}

pub(crate) struct CudaPreflightOrder;

// Reorder our processing so that we process each mux arm together.
Expand Down Expand Up @@ -222,6 +229,10 @@ pub fn keccak_prover() -> Result<Box<dyn KeccakProver>> {
Ok(Box::new(KeccakProverImpl { hal, circuit_hal }))
}

fn cuda_reset() {
ffi_wrap(|| unsafe { risc0_circuit_keccak_cuda_reset() }).unwrap();
}

#[cfg(test)]
mod tests {
use std::rc::Rc;
Expand Down

0 comments on commit 3190365

Please sign in to comment.