Skip to content

Commit

Permalink
Use crypto-bigint instead of num-bigint (risc0#547)
Browse files Browse the repository at this point in the history
Use crypto-bigint instead of num-bigint in the executor and in tests.
  • Loading branch information
nategraf authored May 10, 2023
1 parent ca3f02d commit d83c7fe
Show file tree
Hide file tree
Showing 4 changed files with 82 additions and 62 deletions.
3 changes: 1 addition & 2 deletions examples/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

9 changes: 5 additions & 4 deletions risc0/zkvm/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,9 @@ getrandom = { version = "0.2", optional = true }
gimli = { version = "0.27", optional = true }
lazy-regex = { version = "2.3", optional = true }
log = "0.4"
num-bigint = { version = "0.4.3", default-features = false, features = ["rand"], optional = true }
crypto-bigint = { version = "0.5", default-features = false, features = ["rand"], optional = true }
num-derive = "0.3"
num-traits = { version = "0.2", default-features = false }
num-traits = { version = "0.2", default-features = false, optional = true }
prost = { version = "0.11", optional = true }
rand = { version = "0.8", optional = true }
rayon = { version = "1.5", optional = true }
Expand Down Expand Up @@ -85,10 +85,11 @@ profiler = [
]
prove = [
"binfmt",
"dep:num-traits",
"dep:generic-array",
"dep:getrandom",
"dep:lazy-regex",
"dep:num-bigint",
"dep:crypto-bigint",
"dep:rand",
"dep:rayon",
"dep:rrs-lib",
Expand All @@ -100,7 +101,7 @@ prove = [
]
std = [
"anyhow/std",
"num-traits/std",
"num-traits?/std",
"risc0-circuit-rv32im/std",
"risc0-zkp/std",
"serde/std",
Expand Down
39 changes: 21 additions & 18 deletions risc0/zkvm/src/exec/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,7 @@ mod tests;
use std::{array, cell::RefCell, fmt::Debug, io::Write, mem::take, rc::Rc};

use anyhow::{anyhow, bail, Context, Result};
use num_bigint::BigUint;
use num_traits::Zero;
use crypto_bigint::{CheckedMul, Encoding, NonZero, U256, U512};
use risc0_zkp::{
core::{
digest::{DIGEST_BYTES, DIGEST_WORDS},
Expand Down Expand Up @@ -466,36 +465,40 @@ impl<'a> Executor<'a> {
.monitor
.load_registers([REG_A0, REG_A1, REG_A2, REG_A3, REG_A4]);

let mut load_words = |ptr: u32| {
let mut load_bigint_le_bytes = |ptr: u32| -> [u8; bigint::WIDTH_BYTES] {
let mut arr = [0u32; bigint::WIDTH_WORDS];
for i in 0..bigint::WIDTH_WORDS {
arr[i] = self.monitor.load_u32(ptr + (i * WORD_SIZE) as u32);
arr[i] = self.monitor.load_u32(ptr + (i * WORD_SIZE) as u32).to_le();
}
arr
bytemuck::cast(arr)
};

if op != 0 {
anyhow::bail!("ecall_bigint preflight: op must be set to 0");
}

// Load inputs.
let x = BigUint::from_slice(&load_words(x_ptr));
let y = BigUint::from_slice(&load_words(y_ptr));
let n = BigUint::from_slice(&load_words(n_ptr));
let x = U256::from_le_bytes(load_bigint_le_bytes(x_ptr));
let y = U256::from_le_bytes(load_bigint_le_bytes(y_ptr));
let n = U256::from_le_bytes(load_bigint_le_bytes(n_ptr));

// Compute modular multiplication, or simply multiplication if n == 0.
let z = if n.is_zero() { x * y } else { (x * y) % n };

let mut z_vec = z.to_u32_digits();
if z_vec.len() > bigint::WIDTH_WORDS {
anyhow::bail!("ecall_bigint preflight: overflow in bigint multiplication");
}
// Add leading zeros, if necessary, to pad up to the ecall BigInt width.
z_vec.resize(bigint::WIDTH_WORDS, 0);
let z: U256 = if n == U256::ZERO {
x.checked_mul(&y).unwrap()
} else {
let (w_lo, w_hi) = x.mul_wide(&y);
let w = w_hi.concat(&w_lo);
let z = w.rem(&NonZero::<U512>::from_uint(n.resize()));
z.resize()
};

// Store result.
for (i, word) in z_vec.into_iter().enumerate() {
self.monitor.store_u32(z_ptr + (i * WORD_SIZE) as u32, word);
for (i, word) in bytemuck::cast::<_, [u32; bigint::WIDTH_WORDS]>(z.to_le_bytes())
.into_iter()
.enumerate()
{
self.monitor
.store_u32(z_ptr + (i * WORD_SIZE) as u32, word.to_le());
}

Ok(OpCodeResult::new(
Expand Down
93 changes: 55 additions & 38 deletions risc0/zkvm/src/testutils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,67 +14,84 @@

use core::mem;

use num_bigint::BigUint;
use num_traits::{One, Zero};
use rand::{
distributions::{Distribution, Standard, Uniform},
Rng,
use crypto_bigint::{
rand_core::CryptoRngCore, CheckedMul, Encoding, NonZero, Random, RandomMod, U256, U512,
};
use risc0_zkvm_platform::syscall::bigint;

// Convert to little-endian u32 array. Only reinterprettation on LE machines.
fn bigint_to_arr(num: &U256) -> [u32; bigint::WIDTH_WORDS] {
let mut arr: [u32; bigint::WIDTH_WORDS] = bytemuck::cast(num.to_le_bytes());
for x in arr.iter_mut() {
*x = x.to_le();
}
arr
}

// Convert from little-endian u32 array. Only reinterprettation on LE machines.
fn arr_to_bigint(mut arr: [u32; bigint::WIDTH_WORDS]) -> U256 {
for x in arr.iter_mut() {
*x = x.to_le();
}
U256::from_le_bytes(bytemuck::cast(arr))
}

#[derive(Debug)]
pub struct BigIntTestCase {
pub x: [u32; bigint::WIDTH_WORDS],
pub y: [u32; bigint::WIDTH_WORDS],
pub modulus: [u32; bigint::WIDTH_WORDS],
}

// NOTE: Testing here could be significantly improved by creating a less uniform
// test case generator. It is likely more important to test inputs of different
// byte-lengths, with zero and 0xff bytes, and other boundary values than
// testing values in the middle.
impl Distribution<BigIntTestCase> for Standard {
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> BigIntTestCase {
let bigint_max = BigUint::one() << bigint::WIDTH_BITS;
impl BigIntTestCase {
pub fn expected(&self) -> [u32; bigint::WIDTH_WORDS] {
// Load inputs.
let x = arr_to_bigint(self.x);
let y = arr_to_bigint(self.y);
let n = arr_to_bigint(self.modulus);

let modulus = Uniform::new(&BigUint::one(), &bigint_max).sample(rng);
let mut x = Uniform::new(&BigUint::zero(), &bigint_max).sample(rng);
let mut y = Uniform::new(&BigUint::zero(), &modulus).sample(rng);
// Compute modular multiplication, or simply multiplication if n == 0.
let z: U256 = if n == U256::ZERO {
x.checked_mul(&y).unwrap()
} else {
let (w_lo, w_hi) = x.mul_wide(&y);
let w = w_hi.concat(&w_lo);
let z = w.rem(&NonZero::<U512>::from_uint(n.resize()));
z.resize()
};

bigint_to_arr(&z)
}

// NOTE: Testing here could be significantly improved by creating a less uniform
// test case generator. It is likely more important to test inputs of different
// byte-lengths, with zero and 0xff bytes, and other boundary values than
// testing values in the middle.
fn sample(rng: &mut impl CryptoRngCore) -> BigIntTestCase {
let modulus = NonZero::<U256>::random(rng);
let mut x = U256::random(rng);
let mut y = U256::random_mod(rng, &modulus);

// x and y come from slightly different ranges because at least one input must
// be less than the modulus, but it doesn't matter which one. Randomly swap.
if rng.gen::<bool>() {
if (rng.next_u32() & 1) == 0 {
mem::swap(&mut x, &mut y);
}

BigIntTestCase {
x: x.to_u32_digits().try_into().unwrap(),
y: y.to_u32_digits().try_into().unwrap(),
modulus: modulus.to_u32_digits().try_into().unwrap(),
}
}
}

impl BigIntTestCase {
pub fn expected(&self) -> [u32; bigint::WIDTH_WORDS] {
let modulus = BigUint::from_slice(&self.modulus);
let z = if modulus.is_zero() {
BigUint::from_slice(&self.x) * BigUint::from_slice(&self.y)
} else {
(BigUint::from_slice(&self.x) * BigUint::from_slice(&self.y)) % modulus
};
let mut vec = z.to_u32_digits();
if vec.len() > bigint::WIDTH_WORDS {
panic!("modular multiplication result larger than input modulus");
x: bigint_to_arr(&x),
y: bigint_to_arr(&y),
modulus: bigint_to_arr(modulus.as_ref()),
}
vec.resize(bigint::WIDTH_WORDS, 0);
vec.try_into().unwrap()
}
}

/// Generate the test cases for the BigInt accelerator circuit that are applied
/// to both the simulator and circuit implementations.
pub fn generate_bigint_test_cases(rng: &mut impl Rng, rand_count: usize) -> Vec<BigIntTestCase> {
pub fn generate_bigint_test_cases(
rng: &mut impl CryptoRngCore,
rand_count: usize,
) -> Vec<BigIntTestCase> {
let zero = [0, 0, 0, 0, 0, 0, 0, 0];
let one = [1, 0, 0, 0, 0, 0, 0, 0];

Expand Down Expand Up @@ -106,6 +123,6 @@ pub fn generate_bigint_test_cases(rng: &mut impl Rng, rand_count: usize) -> Vec<
},
];

cases.extend(Standard.sample_iter(rng).take(rand_count));
cases.extend((0..rand_count).map(|_| BigIntTestCase::sample(rng)));
cases
}

0 comments on commit d83c7fe

Please sign in to comment.