From 09c3292740beef9eca7464265f58505516333e43 Mon Sep 17 00:00:00 2001 From: shkoo Date: Tue, 16 Aug 2022 20:37:21 -0700 Subject: [PATCH] Progress towards making zkp field-agnostic (#238) * Define "field::Elem" and "field::ExtElem" types * Add test_roots_of_unity and test_field_ops test utils to run basic tests on fields (It sounds like kalen and @tzerell will be working more on tests and documentation) * Rename baby bear field (15*2^27) from {Fp, Fp4} to field::baby_bear::{Elem, ExtElem}; Fp and Fp4 are now aliases * Change a bunch of Fp and Fp4 calls to use trait items. * Remove Fp::invalid, so we don't have to keep this complication for all our fields. (and we weren't using it in rust anyways) --- risc0/zkp/rust/benches/ntt.rs | 3 +- risc0/zkp/rust/src/adapter.rs | 5 +- risc0/zkp/rust/src/core/fp.rs | 259 ------------- risc0/zkp/rust/src/core/fp4.rs | 287 -------------- risc0/zkp/rust/src/core/mod.rs | 17 +- risc0/zkp/rust/src/core/ntt.rs | 44 ++- risc0/zkp/rust/src/core/poly.rs | 13 +- risc0/zkp/rust/src/field/baby_bear.rs | 532 ++++++++++++++++++++++++++ risc0/zkp/rust/src/field/mod.rs | 161 ++++++++ risc0/zkp/rust/src/hal/cpu.rs | 16 +- risc0/zkp/rust/src/lib.rs | 2 + risc0/zkp/rust/src/prove/adapter.rs | 10 +- risc0/zkp/rust/src/prove/executor.rs | 23 +- risc0/zkp/rust/src/prove/fri.rs | 2 +- risc0/zkp/rust/src/prove/mod.rs | 14 +- risc0/zkp/rust/src/verify/adapter.rs | 4 +- risc0/zkp/rust/src/verify/fri.rs | 14 +- risc0/zkp/rust/src/verify/merkle.rs | 3 +- risc0/zkp/rust/src/verify/mod.rs | 22 +- risc0/zkvm/circuit/lib.rs | 3 +- risc0/zkvm/sdk/rust/src/prove/exec.rs | 3 +- 21 files changed, 798 insertions(+), 639 deletions(-) delete mode 100644 risc0/zkp/rust/src/core/fp.rs delete mode 100644 risc0/zkp/rust/src/core/fp4.rs create mode 100644 risc0/zkp/rust/src/field/baby_bear.rs create mode 100644 risc0/zkp/rust/src/field/mod.rs diff --git a/risc0/zkp/rust/benches/ntt.rs b/risc0/zkp/rust/benches/ntt.rs index 78caa8ad8c..90d1354928 100644 --- 
a/risc0/zkp/rust/benches/ntt.rs +++ b/risc0/zkp/rust/benches/ntt.rs @@ -14,7 +14,8 @@ use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion}; use rand::thread_rng; -use risc0_zkp::core::{fp::Fp, ntt::interpolate_ntt, Random}; +use risc0_zkp::core::{fp::Fp, ntt::interpolate_ntt}; +use risc0_zkp::field::Elem; pub fn ntt(c: &mut Criterion) { let mut group = c.benchmark_group("interpolate_ntt"); diff --git a/risc0/zkp/rust/src/adapter.rs b/risc0/zkp/rust/src/adapter.rs index ed44a93b63..c15a60ff1b 100644 --- a/risc0/zkp/rust/src/adapter.rs +++ b/risc0/zkp/rust/src/adapter.rs @@ -18,6 +18,7 @@ use anyhow::{bail, Result}; use crate::{ core::{fp::Fp, fp4::Fp4}, + field::Elem, taps::TapSet, INV_RATE, }; @@ -55,7 +56,7 @@ impl CircuitStepContext { pub fn _set(&self, base: &mut [Fp], value: Fp, offset: usize, _loc: &str) { let reg = &mut base[offset * self.size + self.cycle]; - assert!(*reg == Fp::invalid() || *reg == Fp::new(0) || *reg == value); + assert!(*reg == Fp::new(0) || *reg == value); *reg = value; } @@ -322,7 +323,7 @@ impl CircuitStep { CircuitStep::Set(base, value, offset, _loc) => { let value = stack[*value]; let reg = &mut args[*base][offset * ctx.size + ctx.cycle]; - assert!(*reg == Fp::invalid() || *reg == Fp::new(0) || *reg == value); + assert!(*reg == Fp::ZERO || *reg == value); *reg = value; } CircuitStep::GetGlobal(base, offset, _loc) => { diff --git a/risc0/zkp/rust/src/core/fp.rs b/risc0/zkp/rust/src/core/fp.rs deleted file mode 100644 index 317d79d8e0..0000000000 --- a/risc0/zkp/rust/src/core/fp.rs +++ /dev/null @@ -1,259 +0,0 @@ -// Copyright 2022 Risc0, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -//! Support for the base finite field modulo 15*2^27 + 1 - -use core::ops; - -use bytemuck::{Pod, Zeroable}; -use rand::Rng; - -use super::Random; - -/// The modulus of the field. -pub const P: u32 = 15 * (1 << 27) + 1; -/// The modulus of the field as a u64. -pub const P_U64: u64 = P as u64; - -/// The Fp class is an element of the finite field F_p, where P is the prime -/// number 15*2^27 + 1. Put another way, Fp is basically integer arithmetic -/// modulo P. -/// -/// The `Fp` datatype is the core type of all of the operations done within the -/// zero knowledge proofs, and is the smallest 'addressable' datatype, and the -/// base type of which all composite types are built. In many ways, one can -/// imagine it as the word size of a very strange architecture. -/// -/// This specific prime P was chosen to: -/// - Be less than 2^31 so that it fits within a 32 bit word and doesn't -/// overflow on addition. -/// - Otherwise have as large a power of 2 in the factors of P-1 as possible. -/// -/// This last property is useful for number theoretical transforms (the fast -/// fourier transform equivelant on finite fields). See NTT.h for details. -/// -/// The Fp class wraps all the standard arithmetic operations to make the finite -/// field elements look basically like ordinary numbers (which they mostly are). -#[derive(Clone, Copy, Debug, Default, Eq, PartialEq, PartialOrd, Zeroable, Pod)] -#[repr(transparent)] -pub struct Fp(u32); - -impl Fp { - /// Create a new [Fp] from a raw integer. 
- pub const fn new(x: u32) -> Self { - Self(x % P) - } - - /// Create a new [Fp] with an 'invalid' value. - pub const fn invalid() -> Self { - Self(0xffffffff) - } - - /// Return the maximum value that an [Fp] can take. - pub fn max() -> Self { - Self(P - 1) - } - - /// Raise an `Fp` value to the power of `n`. - pub fn pow(self, n: usize) -> Self { - let mut n = n; - let mut tot = Fp(1); - let mut x = self; - while n != 0 { - if n % 2 == 1 { - tot *= x; - } - n = n / 2; - x *= x; - } - tot - } - - /// Compute the multiplicative inverse of `x`, or `1 / x` in finite field - /// terms. Since `x ^ (P - 1) == 1 % P` for any `x != 0` (as a - /// consequence of Fermat's little theorem), it follows that `x * - /// x ^ (P - 2) == 1 % P` for `x != 0`. That is, `x ^ (P - 2)` is the - /// multiplicative inverse of `x`. Computed this way, the *inverse* of - /// zero comes out as zero, which is convenient in many cases, so we - /// leave it. - pub fn inv(self) -> Self { - self.pow((P - 2) as usize) - } -} - -/// Provides support for multiplying by a factor with an `Fp` type. -pub trait FpMul { - /// Multiply `self` by a factor of `Fp` type. - fn fp_mul(self, x: Fp) -> Self; -} - -impl FpMul for Fp { - fn fp_mul(self, x: Fp) -> Self { - self * x - } -} - -impl Random for Fp { - fn random(rng: &mut R) -> Self { - // Reject the last modulo-P region of possible uint32_t values, since it's - // uneven and will only return random values less than (2^32 % P). 
- const REJECT_CUTOFF: u32 = (u32::MAX / P) * P; - let mut val: u32 = rng.gen(); - - while val >= REJECT_CUTOFF { - val = rng.gen(); - } - Fp::from(val) - } -} - -impl ops::Add for Fp { - type Output = Self; - fn add(self, rhs: Self) -> Self { - Fp(add(self.0, rhs.0)) - } -} - -impl ops::AddAssign for Fp { - fn add_assign(&mut self, rhs: Self) { - self.0 = add(self.0, rhs.0) - } -} - -impl ops::Sub for Fp { - type Output = Self; - fn sub(self, rhs: Self) -> Self { - Fp(sub(self.0, rhs.0)) - } -} - -impl ops::SubAssign for Fp { - fn sub_assign(&mut self, rhs: Self) { - self.0 = sub(self.0, rhs.0) - } -} - -impl ops::Mul for Fp { - type Output = Self; - fn mul(self, rhs: Self) -> Self { - Fp(mul(self.0, rhs.0)) - } -} - -impl ops::MulAssign for Fp { - fn mul_assign(&mut self, rhs: Self) { - self.0 = mul(self.0, rhs.0) - } -} - -impl ops::Neg for Fp { - type Output = Self; - fn neg(self) -> Self { - Fp(0) - self - } -} - -impl From for u32 { - fn from(x: Fp) -> Self { - x.0 - } -} - -impl From<&Fp> for u32 { - fn from(x: &Fp) -> Self { - x.0 - } -} - -impl From for u64 { - fn from(x: Fp) -> Self { - x.0.into() - } -} - -impl From for Fp { - fn from(value: bool) -> Self { - if value { - Fp(1) - } else { - Fp(0) - } - } -} - -impl From for Fp { - fn from(x: u32) -> Self { - Fp(x % P) - } -} - -impl From for Fp { - fn from(x: u64) -> Self { - Fp((x % P_U64) as u32) - } -} - -fn add(lhs: u32, rhs: u32) -> u32 { - let x = lhs + rhs; - return if x >= P { x - P } else { x }; -} - -fn sub(lhs: u32, rhs: u32) -> u32 { - let x = lhs.wrapping_sub(rhs); - return if x > P { x.wrapping_add(P) } else { x }; -} - -fn mul(lhs: u32, rhs: u32) -> u32 { - (((lhs as u64) * (rhs as u64)) % P_U64) as u32 -} - -#[cfg(test)] -mod tests { - use super::Random; - use super::{Fp, P, P_U64}; - use rand::SeedableRng; - - #[test] - fn inv() { - // Smoke test for inv - assert_eq!(Fp(5).inv() * Fp(5), Fp(1)); - } - - #[test] - fn pow() { - // Smoke tests for pow - assert_eq!(Fp(5).pow(0), Fp(1)); - 
assert_eq!(Fp(5).pow(1), Fp(5)); - assert_eq!(Fp(5).pow(2), Fp(25)); - // Mathematica says PowerMod[5, 1000, 15*2^27 + 1] == 589699054 - assert_eq!(Fp(5).pow(1000), Fp(589699054)); - assert_eq!(Fp(5).pow((P - 2) as usize) * Fp(5), Fp(1)); - assert_eq!(Fp(5).pow((P - 1) as usize), Fp(1)); - } - - #[test] - fn compare_native() { - // Compare core operations against simple % P implementations - let mut rng = rand::rngs::SmallRng::seed_from_u64(2); - for _ in 0..100_000 { - let fa = Fp::random(&mut rng); - let fb = Fp::random(&mut rng); - let a: u64 = fa.into(); - let b: u64 = fb.into(); - assert_eq!(fa + fb, Fp::from(a + b)); - assert_eq!(fa - fb, Fp::from(a + (P_U64 - b))); - assert_eq!(fa * fb, Fp::from(a * b)); - } - } -} diff --git a/risc0/zkp/rust/src/core/fp4.rs b/risc0/zkp/rust/src/core/fp4.rs deleted file mode 100644 index f8cfe1f5a1..0000000000 --- a/risc0/zkp/rust/src/core/fp4.rs +++ /dev/null @@ -1,287 +0,0 @@ -// Copyright 2022 Risc0, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -//! Support for the rank 4 extension field of the base field. - -use core::ops; - -use bytemuck::{Pod, Zeroable}; -use rand::Rng; - -use super::{ - fp::{Fp, FpMul, P}, - Random, -}; - -const BETA: Fp = Fp::new(11); -const NBETA: Fp = Fp::new(P - 11); - -/// The size of the extension field in elements, 4 in this case. -pub const EXT_SIZE: usize = 4; - -/// Instances of `Fp4` are elements of a finite field `F_p^4`. 
They are -/// represented as elements of `F_p[X] / (X^4 - 11)`. Basically, this is a *big* -/// finite field (about `2^128` elements), which is used when the security of -/// various operations depends on the size of the field. It has the field -/// `Fp` as a subfield, which means operations by the two are compatable, which -/// is important. The irreducible polynomial was choosen to be the most simple -/// possible one, `x^4 - B`, where `11` is the smallest `B` which makes the -/// polynomial irreducable. -#[derive(Clone, Copy, Debug, Default, Eq, PartialEq, PartialOrd, Zeroable, Pod)] -#[repr(transparent)] -pub struct Fp4([Fp; EXT_SIZE]); - -impl Fp4 { - /// Explicitly construct an Fp4 from parts. - pub fn new(x0: Fp, x1: Fp, x2: Fp, x3: Fp) -> Self { - Self([x0, x1, x2, x3]) - } - - /// Create a [Fp4] from a [Fp]. - pub fn from_fp(x: Fp) -> Self { - Self([x, Fp::new(0), Fp::new(0), Fp::new(0)]) - } - - /// Create a [Fp4] from a raw integer. - pub fn from_u32(x0: u32) -> Self { - Self([Fp::new(x0), Fp::new(0), Fp::new(0), Fp::new(0)]) - } - - /// Returns the value zero. - pub fn zero() -> Self { - Self::from_u32(0) - } - - /// Returns the value one. - pub fn one() -> Self { - Self::from_u32(1) - } - - /// Returns the constant portion of a [Fp]. - pub fn const_part(self) -> Fp { - self.0[0] - } - - /// Returns the elements of a [Fp]. - pub fn elems(&self) -> &[Fp] { - &self.0 - } - - /// Raise a [Fp4] to a power of `n`. - pub fn pow(self, n: usize) -> Self { - let mut n = n; - let mut tot = Fp4::from(1); - let mut x = self; - while n != 0 { - if n % 2 == 1 { - tot *= x; - } - n = n / 2; - x *= x; - } - tot - } - - /// Compute the multiplicative inverse of an `Fp4`. - pub fn inv(self) -> Self { - let a = &self.0; - // Compute the multiplicative inverse by looking at `Fp4` as a composite field - // and using the same basic methods used to invert complex numbers. We imagine - // that initially we have a numerator of `1`, and a denominator of `a`. 
- // `out = 1 / a`; We set `a'` to be a with the first and third components - // negated. We then multiply the numerator and the denominator by `a'`, - // producing `out = a' / (a * a')`. By construction `(a * a')` has `0`s - // in its first and third elements. We call this number, `b` and compute - // it as follows. - let mut b0 = a[0] * a[0] + BETA * (a[1] * (a[3] + a[3]) - a[2] * a[2]); - let mut b2 = a[0] * (a[2] + a[2]) - a[1] * a[1] + BETA * (a[3] * a[3]); - // Now, we make `b'` by inverting `b2`. When we muliply both sizes by `b'`, we - // get `out = (a' * b') / (b * b')`. But by construction `b * b'` is in - // fact an element of `Fp`, call it `c`. - let c = b0 * b0 + BETA * b2 * b2; - // But we can now invert `C` direcly, and multiply by `a' * b'`: - // `out = a' * b' * inv(c)` - let ic = c.inv(); - // Note: if c == 0 (really should only happen if in == 0), our 'safe' version of - // inverse results in ic == 0, and thus out = 0, so we have the same 'safe' - // behavior for Fp4. Oh, and since we want to multiply everything by ic, it's - // slightly faster to pre-multiply the two parts of b by ic (2 multiplies - // instead of 4). - b0 *= ic; - b2 *= ic; - Fp4([ - a[0] * b0 + BETA * a[2] * b2, - -a[1] * b0 + NBETA * a[3] * b2, - -a[0] * b2 + a[2] * b0, - a[1] * b2 - a[3] * b0, - ]) - } -} - -impl FpMul for Fp4 { - fn fp_mul(self, x: Fp) -> Self { - self * x - } -} - -impl Random for Fp4 { - /// Generate a random field element uniformly. 
- fn random(rng: &mut R) -> Self { - Self([ - Fp::random(rng), - Fp::random(rng), - Fp::random(rng), - Fp::random(rng), - ]) - } -} - -impl ops::Add for Fp4 { - type Output = Self; - fn add(self, rhs: Self) -> Self { - let mut lhs = self; - lhs += rhs; - lhs - } -} - -impl ops::AddAssign for Fp4 { - fn add_assign(&mut self, rhs: Self) { - for i in 0..self.0.len() { - self.0[i] += rhs.0[i]; - } - } -} - -impl ops::Sub for Fp4 { - type Output = Self; - fn sub(self, rhs: Self) -> Self { - let mut lhs = self; - lhs -= rhs; - lhs - } -} - -impl ops::SubAssign for Fp4 { - fn sub_assign(&mut self, rhs: Self) { - for i in 0..self.0.len() { - self.0[i] -= rhs.0[i]; - } - } -} - -/// Implement the simple multiplication case by the subfield Fp. -impl ops::MulAssign for Fp4 { - fn mul_assign(&mut self, rhs: Fp) { - for i in 0..self.0.len() { - self.0[i] *= rhs; - } - } -} - -impl ops::Mul for Fp4 { - type Output = Self; - fn mul(self, rhs: Fp) -> Self { - let mut lhs = self; - lhs *= rhs; - lhs - } -} - -impl ops::Mul for Fp { - type Output = Fp4; - fn mul(self, rhs: Fp4) -> Fp4 { - rhs * self - } -} - -// Now we get to the interesting case of multiplication. Basically, multiply -// out the polynomial representations, and then reduce module `x^4 - B`, which -// means powers >= 4 get shifted back 4 and multiplied by `-beta`. We could -// write this as a double loops with some `if`s and hope it gets unrolled -// properly, but it's small enough to just hand write. -impl ops::MulAssign for Fp4 { - fn mul_assign(&mut self, rhs: Self) { - // Rename the element arrays to something small for readability. 
- let a = &self.0; - let b = &rhs.0; - self.0 = [ - a[0] * b[0] + NBETA * (a[1] * b[3] + a[2] * b[2] + a[3] * b[1]), - a[0] * b[1] + a[1] * b[0] + NBETA * (a[2] * b[3] + a[3] * b[2]), - a[0] * b[2] + a[1] * b[1] + a[2] * b[0] + NBETA * (a[3] * b[3]), - a[0] * b[3] + a[1] * b[2] + a[2] * b[1] + a[3] * b[0], - ]; - } -} - -impl ops::Mul for Fp4 { - type Output = Fp4; - fn mul(self, rhs: Fp4) -> Fp4 { - let mut lhs = self; - lhs *= rhs; - lhs - } -} - -impl ops::Neg for Fp4 { - type Output = Self; - fn neg(self) -> Self { - Fp4::default() - self - } -} - -impl From for Fp4 { - fn from(x: u32) -> Self { - Self([Fp::from(x), Fp::default(), Fp::default(), Fp::default()]) - } -} - -impl From for Fp4 { - fn from(x: Fp) -> Self { - Self([x, Fp::default(), Fp::default(), Fp::default()]) - } -} - -#[cfg(test)] -mod tests { - use super::Fp4; - use super::Random; - use rand::SeedableRng; - - #[test] - fn isa_field() { - let mut rng = rand::rngs::SmallRng::seed_from_u64(2); - // Pick random sets of 3 elements of Fp4, and verify they meet the requirements - // of a field. 
- for _ in 0..1_000 { - let a = Fp4::random(&mut rng); - let b = Fp4::random(&mut rng); - let c = Fp4::random(&mut rng); - // Addition + multiplication commute - assert_eq!(a + b, b + a); - assert_eq!(a * b, b * a); - // Addition + multiplication are associative - assert_eq!(a + (b + c), (a + b) + c); - assert_eq!(a * (b * c), (a * b) * c); - // Distributive property - assert_eq!(a * (b + c), a * b + a * c); - // Inverses - if a != Fp4::default() { - assert_eq!(a.inv() * a, Fp4::from(1)); - } - assert_eq!(Fp4::default() - a, -a); - assert_eq!(a + (-a), Fp4::default()); - } - } -} diff --git a/risc0/zkp/rust/src/core/mod.rs b/risc0/zkp/rust/src/core/mod.rs index 54de2150fc..ecd7fa2a88 100644 --- a/risc0/zkp/rust/src/core/mod.rs +++ b/risc0/zkp/rust/src/core/mod.rs @@ -19,8 +19,21 @@ extern crate alloc; use rand::Rng; -pub mod fp; -pub mod fp4; +/// Transitional "fp" module until ZKP has been genericized to work +/// with multiple fields. +pub mod fp { + pub use crate::field::baby_bear::Elem as Fp; +} +/// Transitional "fp4" module until ZKP has been genericized to work +/// with multiple fields. +pub mod fp4 { + pub use crate::field::baby_bear::ExtElem as Fp4; + use crate::field::ExtElem; + + /// Transitional reexport until ZKP has been genericized to work + /// with multiple fields. + pub const EXT_SIZE: usize = Fp4::EXT_SIZE; +} pub mod ntt; pub mod poly; pub mod rou; diff --git a/risc0/zkp/rust/src/core/ntt.rs b/risc0/zkp/rust/src/core/ntt.rs index 196b0438e6..e9198b2c7a 100644 --- a/risc0/zkp/rust/src/core/ntt.rs +++ b/risc0/zkp/rust/src/core/ntt.rs @@ -14,16 +14,18 @@ //! An implementation of a Numeric Theoretic Transform (NTT). 
-use core::ops::{Add, Sub}; +use core::ops::{Add, Mul, Sub}; use paste::paste; use super::{ - fp::{Fp, FpMul}, + fp::Fp, log2_ceil, rou::{ROU_FWD, ROU_REV}, }; +use crate::field::Elem; + /// Reverses the bits in a 32 bit number /// For example 1011...0100 becomes 0010...1101 pub fn bit_rev_32(mut x: u32) -> u32 { @@ -65,7 +67,7 @@ macro_rules! butterfly { #[inline] fn [](io: &mut [T], expand_bits: usize) where - T: Copy + FpMul + Add + Sub, + T: Copy + Mul + Add + Sub, { if $n == expand_bits { return; @@ -74,10 +76,10 @@ macro_rules! butterfly { [](&mut io[..half], expand_bits); [](&mut io[half..], expand_bits); let step = Fp::new(ROU_FWD[$n]); - let mut cur = Fp::new(1); + let mut cur = Fp::ONE; for i in 0..half { let a = io[i]; - let b = io[i + half].fp_mul(cur); + let b = io[i + half] * cur; io[i] = a + b; io[i + half] = a - b; cur *= step; @@ -87,16 +89,16 @@ macro_rules! butterfly { #[inline] fn [](io: &mut [T]) where - T: Copy + FpMul + Add + Sub, + T: Copy + Mul + Add + Sub, { let half = 1 << ($n - 1); let step = Fp::new(ROU_REV[$n]); - let mut cur = Fp::new(1); + let mut cur = Fp::ONE; for i in 0..half { let a = io[i]; let b = io[i + half]; io[i] = a + b; - io[i + half] = (a - b).fp_mul(cur); + io[i + half] = (a - b) * cur; cur *= step; } [](&mut io[..half]); @@ -197,7 +199,7 @@ butterfly!(1, 0); /// which is 6. So i' in the exponent of the index-3 value is 6. pub fn interpolate_ntt(io: &mut [T]) where - T: Copy + FpMul + Add + Sub, + T: Copy + Mul + Add + Sub, { let size = io.len(); let n = log2_ceil(size); @@ -235,14 +237,14 @@ where } let norm = Fp::new(size as u32).inv(); for i in 0..size { - io[i] = io[i].fp_mul(norm); + io[i] = io[i] * norm; } } /// Perform a forward butterfly transform of a buffer of (1 << n) numbers. 
pub fn evaluate_ntt(io: &mut [T], expand_bits: usize) where - T: Copy + FpMul + Add + Sub, + T: Copy + Mul + Add + Sub, { // do_ntt::(io, expand_bits); let size = io.len(); @@ -295,13 +297,13 @@ where #[cfg(test)] mod tests { + use crate::field::Elem; use rand::thread_rng; use crate::core::{ fp::Fp, ntt::{bit_reverse, evaluate_ntt, interpolate_ntt}, rou::ROU_FWD, - Random, }; // Compare the complex version to the naive version @@ -313,13 +315,13 @@ mod tests { // Randomly fill input let mut buf = [Fp::random(&mut rng); SIZE]; // Compute the hard way - let mut goal = [Fp::default(); SIZE]; + let mut goal = [Fp::ZERO; SIZE]; // Compute polynomial at each ROU power (starting at 0, i.e. x = 1) - let mut x = Fp::new(1); + let mut x = Fp::ONE; for i in 0..SIZE { // Compute the polynomial - let mut tot = Fp::new(0); - let mut xn = Fp::new(1); + let mut tot = Fp::ZERO; + let mut xn = Fp::ONE; for j in 0..SIZE { tot += buf[j] * xn; xn *= x; @@ -362,7 +364,7 @@ mod tests { const SIZE_OUT: usize = 1 << N; let mut rng = thread_rng(); let mut cmp = [Fp::random(&mut rng); SIZE_IN]; - let mut buf = [Fp::default(); SIZE_OUT]; + let mut buf = [Fp::ZERO; SIZE_OUT]; // Do plain interpolate on cmp interpolate_ntt(&mut cmp); // Expand to buf @@ -372,13 +374,13 @@ mod tests { // Order cmp nicely for the check bit_reverse(&mut cmp); // Now verify by comparing with the slow way - let mut goal = [Fp::default(); SIZE_OUT]; + let mut goal = [Fp::ZERO; SIZE_OUT]; // Compute polynomial at each ROU power (starting at 0, i.e. 
x = 1) - let mut x = Fp::new(1); + let mut x = Fp::ONE; for i in 0..SIZE_OUT { // Compute the polynomial - let mut tot = Fp::new(0); - let mut xn = Fp::new(1); + let mut tot = Fp::ZERO; + let mut xn = Fp::ONE; for j in 0..SIZE_IN { tot += cmp[j] * xn; xn *= x; diff --git a/risc0/zkp/rust/src/core/poly.rs b/risc0/zkp/rust/src/core/poly.rs index cfbf3ffb07..0df811bc14 100644 --- a/risc0/zkp/rust/src/core/poly.rs +++ b/risc0/zkp/rust/src/core/poly.rs @@ -17,12 +17,13 @@ use alloc::vec; use super::fp4::Fp4; +use crate::field::Elem; /// Evaluate a polynomial whose coeffients are in the extension field at a /// point. pub fn poly_eval(coeffs: &[Fp4], x: Fp4) -> Fp4 { - let mut mul = Fp4::one(); - let mut tot = Fp4::zero(); + let mut mul = Fp4::ONE; + let mut tot = Fp4::ZERO; for i in 0..coeffs.len() { tot += coeffs[i] * mul; mul *= x; @@ -46,8 +47,8 @@ pub fn poly_interpolate(out: &mut [Fp4], x: &[Fp4], fx: &[Fp4], size: usize) { return; } // Compute ft = product of (x - x_i) for all i - let mut ft = vec![Fp4::default(); size + 1]; - ft[0] = Fp4::one(); + let mut ft = vec![Fp4::ZERO; size + 1]; + ft[0] = Fp4::ONE; for i in 0..size { for j in (0..i + 1).rev() { let value = ft[j]; @@ -57,7 +58,7 @@ pub fn poly_interpolate(out: &mut [Fp4], x: &[Fp4], fx: &[Fp4], size: usize) { } // Clear output for i in 0..size { - out[i] = Fp4::default(); + out[i] = Fp4::ZERO; } for i in 0..size { // Compute fr = ft / (x - x_i) @@ -79,7 +80,7 @@ pub fn poly_interpolate(out: &mut [Fp4], x: &[Fp4], fx: &[Fp4], size: usize) { /// Take the coefficients in P, and divide by (X - z) for some z, return the /// remainder. 
pub fn poly_divide(p: &mut [Fp4], z: Fp4) -> Fp4 { - let mut cur = Fp4::default(); + let mut cur = Fp4::ZERO; for i in (0..p.len()).rev() { let next = z * cur + p[i]; p[i] = cur; diff --git a/risc0/zkp/rust/src/field/baby_bear.rs b/risc0/zkp/rust/src/field/baby_bear.rs new file mode 100644 index 0000000000..73106f6168 --- /dev/null +++ b/risc0/zkp/rust/src/field/baby_bear.rs @@ -0,0 +1,532 @@ +// Copyright 2022 Risc0, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +/// ! Baby bear field. +/// ! Support for the base finite field modulo 15*2^27 + 1 +use crate::field::{self, Elem as FieldElem}; + +use core::ops; + +use bytemuck::{Pod, Zeroable}; + +/// The BabyBear class is an element of the finite field F_p, where P is the +/// prime number 15*2^27 + 1. Put another way, Fp is basically integer +/// arithmetic modulo P. +/// +/// The `Fp` datatype is the core type of all of the operations done within the +/// zero knowledge proofs, and is the smallest 'addressable' datatype, and the +/// base type of which all composite types are built. In many ways, one can +/// imagine it as the word size of a very strange architecture. +/// +/// This specific prime P was chosen to: +/// - Be less than 2^31 so that it fits within a 32 bit word and doesn't +/// overflow on addition. +/// - Otherwise have as large a power of 2 in the factors of P-1 as possible. 
+/// +/// This last property is useful for number theoretical transforms (the fast +/// fourier transform equivalent on finite fields). See NTT.h for details. +/// +/// The Fp class wraps all the standard arithmetic operations to make the finite +/// field elements look basically like ordinary numbers (which they mostly are). +#[derive(Eq, PartialEq, Clone, Copy, Debug, Pod, Zeroable)] +#[repr(transparent)] +pub struct Elem(u32); + +impl Default for Elem { + fn default() -> Self { + Self::ZERO + } +} + +/// The modulus of the field. +const P: u32 = 15 * (1 << 27) + 1; +/// The modulus of the field as a u64. +const P_U64: u64 = P as u64; + +impl field::Elem for Elem { + const ZERO: Self = Elem::new(0); + + const ONE: Self = Elem::new(1); + + /// Compute the multiplicative inverse of `x`, or `1 / x` in finite field + /// terms. Since `x ^ (P - 1) == 1 % P` for any `x != 0` (as a + /// consequence of Fermat's little theorem), it follows that `x * + /// x ^ (P - 2) == 1 % P` for `x != 0`. That is, `x ^ (P - 2)` is the + /// multiplicative inverse of `x`. Computed this way, the *inverse* of + /// zero comes out as zero, which is convenient in many cases, so we + /// leave it. + fn inv(self) -> Self { + self.pow((P - 2) as usize) + } + + fn random(rng: &mut impl rand::Rng) -> Self { + // Reject the last modulo-P region of possible uint32_t values, since it's + // uneven and will only return random values less than (2^32 % P). + const REJECT_CUTOFF: u32 = (u32::MAX / P) * P; + let mut val: u32 = rng.gen(); + + while val >= REJECT_CUTOFF { + val = rng.gen(); + } + Elem::from(val) + } +} + +macro_rules! rou_array { + [$($x:literal),* $(,)?] 
=> { + [$(Elem::new($x)),* ] + } +} + +impl field::RootsOfUnity for Elem { + const MAX_ROU_PO2: usize = 27; + + const ROU_FWD: &'static [Elem] = &rou_array![ + 1, 2013265920, 284861408, 1801542727, 567209306, 740045640, 918899846, 1881002012, + 1453957774, 65325759, 1538055801, 515192888, 483885487, 157393079, 1695124103, 2005211659, + 1540072241, 88064245, 1542985445, 1269900459, 1461624142, 825701067, 682402162, 1311873874, + 1164520853, 352275361, 18769, 137 + ]; + + const ROU_REV: &'static [Elem] = &rou_array![ + 1, 2013265920, 1728404513, 1592366214, 196396260, 1253260071, 72041623, 1091445674, + 145223211, 1446820157, 1030796471, 2010749425, 1827366325, 1239938613, 246299276, + 596347512, 1893145354, 246074437, 1525739923, 1194341128, 1463599021, 704606912, 95395244, + 15672543, 647517488, 584175179, 137728885, 749463956 + ]; +} + +impl Elem { + /// Create a new [BabyBear] from a raw integer. + pub const fn new(x: u32) -> Self { + Self(x % P) + } +} + +impl ops::Add for Elem { + type Output = Self; + fn add(self, rhs: Self) -> Self { + Elem(add(self.0, rhs.0)) + } +} + +impl ops::AddAssign for Elem { + fn add_assign(&mut self, rhs: Self) { + self.0 = add(self.0, rhs.0) + } +} + +impl ops::Sub for Elem { + type Output = Self; + fn sub(self, rhs: Self) -> Self { + Elem(sub(self.0, rhs.0)) + } +} + +impl ops::SubAssign for Elem { + fn sub_assign(&mut self, rhs: Self) { + self.0 = sub(self.0, rhs.0) + } +} + +impl ops::Mul for Elem { + type Output = Self; + fn mul(self, rhs: Self) -> Self { + Elem(mul(self.0, rhs.0)) + } +} + +impl ops::MulAssign for Elem { + fn mul_assign(&mut self, rhs: Self) { + self.0 = mul(self.0, rhs.0) + } +} + +impl ops::Neg for Elem { + type Output = Self; + fn neg(self) -> Self { + Elem(0) - self + } +} + +impl From for u32 { + fn from(x: Elem) -> Self { + x.0 + } +} + +impl From<&Elem> for u32 { + fn from(x: &Elem) -> Self { + x.0 + } +} + +impl From for u64 { + fn from(x: Elem) -> Self { + x.0.into() + } +} + +impl From for Elem { + 
fn from(x: u32) -> Self { + Elem(x % P) + } +} + +impl From for Elem { + fn from(x: u64) -> Self { + Elem((x % P_U64) as u32) + } +} + +fn add(lhs: u32, rhs: u32) -> u32 { + let x = lhs + rhs; + return if x >= P { x - P } else { x }; +} + +fn sub(lhs: u32, rhs: u32) -> u32 { + let x = lhs.wrapping_sub(rhs); + return if x > P { x.wrapping_add(P) } else { x }; +} + +fn mul(lhs: u32, rhs: u32) -> u32 { + (((lhs as u64) * (rhs as u64)) % P_U64) as u32 +} + +/// The size of the extension field in elements, 4 in this case. +const EXT_SIZE: usize = 4; + +/// Instances of `ExtElem` are elements of a finite field `F_p^4`. They are +/// represented as elements of `F_p[X] / (X^4 - 11)`. Basically, this is a *big* +/// finite field (about `2^128` elements), which is used when the security of +/// various operations depends on the size of the field. It has the field +/// `Elem` as a subfield, which means operations by the two are compatible, +/// which is important. The irreducible polynomial was chosen to be the most +/// simple possible one, `x^4 - B`, where `11` is the smallest `B` which makes +/// the polynomial irreducible. +#[derive(Eq, PartialEq, Clone, Copy, Debug, Pod, Zeroable)] +#[repr(transparent)] +pub struct ExtElem([Elem; EXT_SIZE]); + +impl Default for ExtElem { + fn default() -> Self { + Self::ZERO + } +} + +impl field::Elem for ExtElem { + const ZERO: ExtElem = ExtElem::zero(); + const ONE: ExtElem = ExtElem::one(); + + /// Generate a random field element uniformly. + fn random(rng: &mut impl rand::Rng) -> Self { + Self([ + Elem::random(rng), + Elem::random(rng), + Elem::random(rng), + Elem::random(rng), + ]) + } + + /// Raise a [ExtElem] to a power of `n`. + fn pow(self, n: usize) -> Self { + let mut n = n; + let mut tot = ExtElem::from(1); + let mut x = self; + while n != 0 { + if n % 2 == 1 { + tot *= x; + } + n = n / 2; + x *= x; + } + tot + } + + /// Compute the multiplicative inverse of an `ExtElem`.
+ fn inv(self) -> Self { + let a = &self.0; + // Compute the multiplicative inverse by looking at `ExtElem` as a composite + // field and using the same basic methods used to invert complex + // numbers. We imagine that initially we have a numerator of `1`, and a + // denominator of `a`. `out = 1 / a`; We set `a'` to be a with the first + // and third components negated. We then multiply the numerator and the + // denominator by `a'`, producing `out = a' / (a * a')`. By construction + // `(a * a')` has `0`s in its first and third elements. We call this + // number, `b` and compute it as follows. + let mut b0 = a[0] * a[0] + BETA * (a[1] * (a[3] + a[3]) - a[2] * a[2]); + let mut b2 = a[0] * (a[2] + a[2]) - a[1] * a[1] + BETA * (a[3] * a[3]); + // Now, we make `b'` by inverting `b2`. When we multiply both sides by `b'`, we + // get `out = (a' * b') / (b * b')`. But by construction `b * b'` is in + // fact an element of `Elem`, call it `c`. + let c = b0 * b0 + BETA * b2 * b2; + // But we can now invert `c` directly, and multiply by `a' * b'`: + // `out = a' * b' * inv(c)` + let ic = c.inv(); + // Note: if c == 0 (really should only happen if in == 0), our + // 'safe' version of inverse results in ic == 0, and thus out + // = 0, so we have the same 'safe' behavior for ExtElem. Oh, + // and since we want to multiply everything by ic, it's + // slightly faster to pre-multiply the two parts of b by ic (2 + // multiplies instead of 4).
+ b0 *= ic; + b2 *= ic; + ExtElem([ + a[0] * b0 + BETA * a[2] * b2, + -a[1] * b0 + NBETA * a[3] * b2, + -a[0] * b2 + a[2] * b0, + a[1] * b2 - a[3] * b0, + ]) + } +} + +impl field::ExtElem for ExtElem { + const EXT_SIZE: usize = EXT_SIZE; + + type SubElem = Elem; + + fn from_subfield(elem: &Elem) -> Self { + Self::from([elem.clone(), Elem::ZERO, Elem::ZERO, Elem::ZERO]) + } +} + +impl From<[Elem; EXT_SIZE]> for ExtElem { + fn from(val: [Elem; EXT_SIZE]) -> Self { + ExtElem(val) + } +} + +const BETA: Elem = Elem::new(11); +const NBETA: Elem = Elem::new(P - 11); + +impl ExtElem { + /// Explicitly construct an ExtElem from parts. + pub fn new(x0: Elem, x1: Elem, x2: Elem, x3: Elem) -> Self { + Self([x0, x1, x2, x3]) + } + + /// Create a [ExtElem] from a [Elem]. + pub fn from_fp(x: Elem) -> Self { + Self([x, Elem::new(0), Elem::new(0), Elem::new(0)]) + } + + /// Create a [ExtElem] from a raw integer. + pub const fn from_u32(x0: u32) -> Self { + Self([Elem::new(x0), Elem::new(0), Elem::new(0), Elem::new(0)]) + } + + /// Returns the value zero. + const fn zero() -> Self { + Self::from_u32(0) + } + + /// Returns the value one. + const fn one() -> Self { + Self::from_u32(1) + } + + /// Returns the constant portion of an [ExtElem]. + pub fn const_part(self) -> Elem { + self.0[0] + } + + /// Returns the elements of an [ExtElem].
+ pub fn elems(&self) -> &[Elem] { + &self.0 + } +} + +impl ops::Add for ExtElem { + type Output = Self; + fn add(self, rhs: Self) -> Self { + let mut lhs = self; + lhs += rhs; + lhs + } +} + +impl ops::AddAssign for ExtElem { + fn add_assign(&mut self, rhs: Self) { + for i in 0..self.0.len() { + self.0[i] += rhs.0[i]; + } + } +} + +impl ops::Sub for ExtElem { + type Output = Self; + fn sub(self, rhs: Self) -> Self { + let mut lhs = self; + lhs -= rhs; + lhs + } +} + +impl ops::SubAssign for ExtElem { + fn sub_assign(&mut self, rhs: Self) { + for i in 0..self.0.len() { + self.0[i] -= rhs.0[i]; + } + } +} + +/// Implement the simple multiplication case by the subfield Elem. +impl ops::MulAssign for ExtElem { + fn mul_assign(&mut self, rhs: Elem) { + for i in 0..self.0.len() { + self.0[i] *= rhs; + } + } +} + +impl ops::Mul for ExtElem { + type Output = Self; + fn mul(self, rhs: Elem) -> Self { + let mut lhs = self; + lhs *= rhs; + lhs + } +} + +impl ops::Mul for Elem { + type Output = ExtElem; + fn mul(self, rhs: ExtElem) -> ExtElem { + rhs * self + } +} + +// Now we get to the interesting case of multiplication. Basically, +// multiply out the polynomial representations, and then reduce modulo +// `x^4 - B`, which means powers >= 4 get shifted back 4 and +// multiplied by `-beta`. We could write this as a double loop with +// some `if`s and hope it gets unrolled properly, but it's small +// enough to just hand write. +impl ops::MulAssign for ExtElem { + fn mul_assign(&mut self, rhs: Self) { + // Rename the element arrays to something small for readability.
+ let a = &self.0; + let b = &rhs.0; + self.0 = [ + a[0] * b[0] + NBETA * (a[1] * b[3] + a[2] * b[2] + a[3] * b[1]), + a[0] * b[1] + a[1] * b[0] + NBETA * (a[2] * b[3] + a[3] * b[2]), + a[0] * b[2] + a[1] * b[1] + a[2] * b[0] + NBETA * (a[3] * b[3]), + a[0] * b[3] + a[1] * b[2] + a[2] * b[1] + a[3] * b[0], + ]; + } +} + +impl ops::Mul for ExtElem { + type Output = ExtElem; + fn mul(self, rhs: ExtElem) -> ExtElem { + let mut lhs = self; + lhs *= rhs; + lhs + } +} + +impl ops::Neg for ExtElem { + type Output = Self; + fn neg(self) -> Self { + ExtElem::ZERO - self + } +} + +impl From for ExtElem { + fn from(x: u32) -> Self { + Self([Elem::from(x), Elem::ZERO, Elem::ZERO, Elem::ZERO]) + } +} + +impl From for ExtElem { + fn from(x: Elem) -> Self { + Self([x, Elem::ZERO, Elem::ZERO, Elem::ZERO]) + } +} + +#[cfg(test)] +mod tests { + use super::field; + use super::{Elem, ExtElem, P, P_U64}; + use crate::field::Elem as FieldElem; + use rand::SeedableRng; + + #[test] + pub fn roots_of_unity() { + field::test::test_roots_of_unity::(); + } + + #[test] + pub fn field_ops() { + field::test::test_field_ops::(P_U64); + } + + #[test] + fn isa_field() { + let mut rng = rand::rngs::SmallRng::seed_from_u64(2); + // Pick random sets of 3 elements of ExtElem, and verify they meet the + // requirements of a field. 
+ for _ in 0..1_000 { + let a = ExtElem::random(&mut rng); + let b = ExtElem::random(&mut rng); + let c = ExtElem::random(&mut rng); + // Addition + multiplication commute + assert_eq!(a + b, b + a); + assert_eq!(a * b, b * a); + // Addition + multiplication are associative + assert_eq!(a + (b + c), (a + b) + c); + assert_eq!(a * (b * c), (a * b) * c); + // Distributive property + assert_eq!(a * (b + c), a * b + a * c); + // Inverses + if a != ExtElem::ZERO { + assert_eq!(a.inv() * a, ExtElem::from(1)); + } + assert_eq!(ExtElem::ZERO - a, -a); + assert_eq!(a + (-a), ExtElem::ZERO); + } + } + + #[test] + fn inv() { + // Smoke test for inv + assert_eq!(Elem(5).inv() * Elem(5), Elem(1)); + } + + #[test] + fn pow() { + // Smoke tests for pow + assert_eq!(Elem(5).pow(0), Elem(1)); + assert_eq!(Elem(5).pow(1), Elem(5)); + assert_eq!(Elem(5).pow(2), Elem(25)); + // Mathematica says PowerMod[5, 1000, 15*2^27 + 1] == 589699054 + assert_eq!(Elem(5).pow(1000), Elem(589699054)); + assert_eq!(Elem(5).pow((P - 2) as usize) * Elem(5), Elem(1)); + assert_eq!(Elem(5).pow((P - 1) as usize), Elem(1)); + } + + #[test] + fn compare_native() { + // Compare core operations against simple % P implementations + let mut rng = rand::rngs::SmallRng::seed_from_u64(2); + for _ in 0..100_000 { + let fa = Elem::random(&mut rng); + let fb = Elem::random(&mut rng); + let a: u64 = fa.into(); + let b: u64 = fb.into(); + assert_eq!(fa + fb, Elem::from(a + b)); + assert_eq!(fa - fb, Elem::from(a + (P_U64 - b))); + assert_eq!(fa * fb, Elem::from(a * b)); + } + } +} diff --git a/risc0/zkp/rust/src/field/mod.rs b/risc0/zkp/rust/src/field/mod.rs new file mode 100644 index 0000000000..2d8bcbd7ed --- /dev/null +++ b/risc0/zkp/rust/src/field/mod.rs @@ -0,0 +1,161 @@ +// Copyright 2022 Risc0, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// TODO: Document better + +use core::{cmp, ops}; + +/// A field with field elements. +pub trait Elem: + ops::Mul + + ops::MulAssign + + ops::Add + + ops::AddAssign + + ops::Sub + + ops::SubAssign + + cmp::PartialEq + + cmp::Eq + + core::clone::Clone + + core::marker::Copy + + Sized +{ + /// Zero, the additive identity. + const ZERO: Self; + + /// One, the multiplicative identity. + const ONE: Self; + + /// Compute the multiplicative inverse of `x`, or `1 / x` in finite field + /// terms. + fn inv(self) -> Self; + + /// Returns this element raised to the given power. + fn pow(self, exp: usize) -> Self { + let mut n = exp; + let mut tot = Self::ONE; + let mut x = self; + while n != 0 { + if n % 2 == 1 { + tot *= x; + } + n = n / 2; + x *= x; + } + tot + } + + /// Returns a random valid field element. + fn random(rng: &mut impl rand::Rng) -> Self; +} + +/// A field extension. +pub trait ExtElem: Elem + ops::Mul { + type SubElem: Elem; + + const EXT_SIZE: usize; + + fn from_subfield(elem: &Self::SubElem) -> Self; +} + +pub trait RootsOfUnity: Sized + 'static { + /// Maximum root of unity which is a power of 2, i.e. there is + /// 2^MAX_ROU_PO2th root of unity, but no 2^(MAX_ROU_PO2+1)th. + const MAX_ROU_PO2: usize; + + /// For each power of 2, what is the 'forward' root of unity for + /// the po2. That is, this list satisfies ROU_FWD\[i+1\] ^ 2 = + /// ROU_FWD\[i\] in the prime field which implies ROU_FWD\[i\] ^ + /// (2 ^ i) = 1.
+ const ROU_FWD: &'static [Self]; + + /// For each power of 2, what is the 'reverse' root of unity for + /// the po2. This list satisfies ROU_FWD\[i\] * ROU_REV\[i\] = 1 + /// in the prime field F_2013265921. + const ROU_REV: &'static [Self]; +} + +#[cfg(test)] +pub mod test { + use super::{Elem, RootsOfUnity}; + use core::fmt::Debug; + use rand::Rng; + + pub fn test_roots_of_unity() { + let mut cur: Option = None; + + for &rou in F::ROU_FWD.iter().rev() { + if let Some(ref mut curval) = &mut cur { + *curval *= *curval; + assert_eq!(*curval, rou); + } else { + cur = Some(rou); + } + } + assert_eq!(cur, Some(F::ONE)); + + for (&fwd, &rev) in F::ROU_FWD.iter().zip(F::ROU_REV.iter()) { + assert_eq!(fwd * rev, F::ONE); + } + } + + fn non_zero_rand(r: &mut impl Rng) -> F { + loop { + let val = F::random(r); + if val != F::ZERO { + return val; + } + } + } + + pub fn test_field_ops(p_u64: u64) + where + F: Into + From + Debug, + { + // We do 128-bit arithmetic so we don't have to worry about overflows. + let p: u128 = p_u64 as _; + + assert_eq!(F::from(0), F::ZERO); + assert_eq!(F::from(p_u64), F::ZERO); + assert_eq!(F::from(1), F::ONE); + assert_eq!(F::from(p_u64 - 1) + F::from(1), F::ZERO); + + assert_eq!(F::ZERO.inv(), F::ZERO); + assert_eq!(F::ONE.inv(), F::ONE); + + // Compare against a bunch of numbers to make sure it matches + // with regular modulo arithmetic. 
+ let mut rng = rand::thread_rng(); + + for _ in 0..1000 { + let x: F = non_zero_rand(&mut rng); + let y: F = non_zero_rand(&mut rng); + + let xi: u128 = x.into() as _; + let yi: u128 = y.into() as _; + + assert_eq!((x + y).into() as u128, (&xi + &yi) % p); + assert_eq!((x * y).into() as u128, (&xi * &yi) % p); + assert_eq!((x - y).into() as u128, (&xi + p - &yi) % p); + + let xinv = x.inv(); + if x != F::ONE { + assert!(xinv != x); + } + assert_eq!(xinv * x, F::ONE); + } + } +} + +/// Fields available for use with zkp: +pub mod baby_bear; diff --git a/risc0/zkp/rust/src/hal/cpu.rs b/risc0/zkp/rust/src/hal/cpu.rs index f92cc46ede..c56707f7ad 100644 --- a/risc0/zkp/rust/src/hal/cpu.rs +++ b/risc0/zkp/rust/src/hal/cpu.rs @@ -21,13 +21,14 @@ use std::{cell::RefCell, rc::Rc}; use crate::{ core::{ - fp::{Fp, FpMul}, + fp::Fp, fp4::{Fp4, EXT_SIZE}, log2_ceil, ntt::{bit_rev_32, bit_reverse, evaluate_ntt, expand, interpolate_ntt}, sha::{Digest, Sha}, sha_cpu, }, + field::Elem, FRI_FOLD, }; #[allow(unused_imports)] @@ -202,13 +203,13 @@ impl Hal for CpuHal { (&which[..], &xs[..], &mut out[..]) .into_par_iter() .for_each(|(id, x, out)| { - let mut tot = Fp4::zero(); - let mut cur = Fp4::new(Fp::new(1), Fp::new(0), Fp::new(0), Fp::new(0)); + let mut tot = Fp4::ZERO; + let mut cur = Fp4::new(Fp::ONE, Fp::ZERO, Fp::ZERO, Fp::ZERO); let id = *id as usize; let count = 1 << po2; let local = &coeffs[count * id..count * id + count]; for coeff in local { - tot += cur.fp_mul(*coeff); + tot += cur * *coeff; cur *= *x; } *out = tot; @@ -288,7 +289,7 @@ impl Hal for CpuHal { let input = ArrayView::from_shape((to_add, count), &input).unwrap(); let input = input.axis_iter(Axis(1)).into_par_iter(); output.zip(input).for_each(|(mut output, input)| { - let mut sum = Fp4::zero(); + let mut sum = Fp4::ZERO; for i in input { sum += *i; } @@ -343,7 +344,7 @@ impl Hal for CpuHal { // TODO: parallelize for idx in 0..count { - let mut tot = Fp4::default(); + let mut tot = Fp4::ZERO; let mut cur_mix 
= Fp4::from_u32(1); for i in 0..FRI_FOLD { let rev_i = bit_rev_32(i as u32) >> (32 - log2_ceil(FRI_FOLD)); @@ -411,7 +412,6 @@ mod test { use rand::thread_rng; use super::*; - use crate::core::Random; #[test] #[should_panic] @@ -438,7 +438,7 @@ mod test { fn test_binary(hal: &H, hal_fn: HF, cpu_fn: CF, count: usize) where - T: 'static + Random + Default + Clone + PartialEq + Debug, + T: Elem + Default + Debug + 'static, H: Hal, HF: Fn(&Buffer, &Buffer, &Buffer), CF: Fn(&T, &T) -> T, diff --git a/risc0/zkp/rust/src/lib.rs b/risc0/zkp/rust/src/lib.rs index 3807bd13a0..e4c13c75c1 100644 --- a/risc0/zkp/rust/src/lib.rs +++ b/risc0/zkp/rust/src/lib.rs @@ -28,6 +28,8 @@ pub mod taps; #[cfg(feature = "verify")] pub mod verify; +pub mod field; + pub const MIN_CYCLES: usize = 512; pub const MAX_CYCLES_PO2: usize = 20; pub const MAX_CYCLES: usize = 1 << MAX_CYCLES_PO2; diff --git a/risc0/zkp/rust/src/prove/adapter.rs b/risc0/zkp/rust/src/prove/adapter.rs index d1e7359758..b6962e0bc5 100644 --- a/risc0/zkp/rust/src/prove/adapter.rs +++ b/risc0/zkp/rust/src/prove/adapter.rs @@ -24,8 +24,8 @@ use crate::{ log2_ceil, rou::ROU_FWD, sha::Sha, - Random, }, + field::Elem, hal::Buffer, prove::{executor::Executor, write_iop::WriteIOP, Circuit}, taps::{RegisterGroup, TapSet}, @@ -73,7 +73,7 @@ impl<'a, C: CircuitDef, CS: CustomStep> Circuit for ProveAdapter<'a, C, CS> .circuit .get_taps() .group_size(RegisterGroup::Accum); - self.accum.resize(self.steps * accum_size, Fp::invalid()); + self.accum.resize(self.steps * accum_size, Fp::ZERO); let args: &mut [&mut [Fp]] = &mut [ &mut self.exec.code, &mut self.exec.output, @@ -101,12 +101,6 @@ impl<'a, C: CircuitDef, CS: CustomStep> Circuit for ProveAdapter<'a, C, CS> self.accum[j * self.steps + i] = Fp::random(&mut rng); } } - // Zero out 'invalid' entries in accum - for x in self.accum.iter_mut() { - if *x == Fp::invalid() { - *x = Fp::new(0); - } - } } fn eval_check( diff --git a/risc0/zkp/rust/src/prove/executor.rs 
b/risc0/zkp/rust/src/prove/executor.rs index 4bf8170d7a..f3c98a7280 100644 --- a/risc0/zkp/rust/src/prove/executor.rs +++ b/risc0/zkp/rust/src/prove/executor.rs @@ -20,7 +20,8 @@ use rand::thread_rng; use crate::{ adapter::{CircuitDef, CircuitStepContext, CustomStep}, - core::{fp::Fp, Random}, + core::fp::Fp, + field::Elem, taps::RegisterGroup, MIN_PO2, ZK_CYCLES, }; @@ -52,11 +53,11 @@ impl, S: CustomStep> Executor { circuit, custom, // Initialize trace to min_po2 size - code: vec![Fp::default(); steps * code_size], + code: vec![Fp::ZERO; steps * code_size], code_size, - data: vec![Fp::default(); steps * data_size], + data: vec![Fp::ZERO; steps * data_size], data_size, - output: vec![Fp::invalid(); output_size], + output: vec![Fp::ZERO; output_size], po2, steps, halted: false, @@ -95,7 +96,7 @@ impl, S: CustomStep> Executor { ]; let result = self.circuit.step_exec(&ctx, &mut self.custom, args)?; // debug!("result: {:?}", result); - self.halted = self.halted || result == Fp::new(0); + self.halted = self.halted || result == Fp::ZERO; self.cycle += 1; Ok(true) } @@ -105,8 +106,8 @@ impl, S: CustomStep> Executor { if self.steps >= (1 << self.max_po2) { bail!("Cannot expand, max po2 of {} reached.", self.max_po2); } - let mut new_code = vec![Fp::default(); self.code.len() * 2]; - let mut new_data = vec![Fp::default(); self.data.len() * 2]; + let mut new_code = vec![Fp::ZERO; self.code.len() * 2]; + let mut new_data = vec![Fp::ZERO; self.data.len() * 2]; for i in 0..self.code_size { let idx = i * self.steps; let src = &self.code[idx..idx + self.cycle]; @@ -134,7 +135,7 @@ impl, S: CustomStep> Executor { // Make code be all zeros of zk cycles, and data be random for i in self.cycle..self.steps { for j in 0..self.code_size { - self.code[j * self.steps + i] = Fp::new(0); + self.code[j * self.steps + i] = Fp::ZERO; } for j in 0..self.data_size { self.data[j * self.steps + i] = Fp::random(&mut rng); @@ -157,12 +158,6 @@ impl, S: CustomStep> Executor { .step_verify(&ctx, &mut 
self.custom, args) .unwrap(); } - // Zero out 'invalid' entries in data - for value in self.data.iter_mut() { - if *value == Fp::invalid() { - *value = Fp::new(0); - } - } } pub fn get_code(&self, cycle: usize, offset: usize) -> Fp { diff --git a/risc0/zkp/rust/src/prove/fri.rs b/risc0/zkp/rust/src/prove/fri.rs index 8969d977b5..64069d322f 100644 --- a/risc0/zkp/rust/src/prove/fri.rs +++ b/risc0/zkp/rust/src/prove/fri.rs @@ -23,8 +23,8 @@ use crate::{ fp4::{Fp4, EXT_SIZE}, log2_ceil, sha::Sha, - Random, }, + field::Elem, hal::{Buffer, Hal}, prove::{merkle::MerkleTreeProver, write_iop::WriteIOP}, FRI_FOLD, FRI_MIN_DEGREE, INV_RATE, QUERIES, diff --git a/risc0/zkp/rust/src/prove/mod.rs b/risc0/zkp/rust/src/prove/mod.rs index 7c9da9a117..929f167ec6 100644 --- a/risc0/zkp/rust/src/prove/mod.rs +++ b/risc0/zkp/rust/src/prove/mod.rs @@ -31,8 +31,8 @@ use crate::{ poly::{poly_divide, poly_interpolate}, rou::ROU_REV, sha::Sha, - Random, }, + field::Elem, hal::{Buffer, Hal}, prove::{fri::fri_prove, poly_group::PolyGroup, write_iop::WriteIOP}, taps::{RegisterGroup, TapSet}, @@ -203,7 +203,7 @@ pub fn prove(hal: &H, sha: &S, circuit: &mut C) -> V // Now, convert the values to coefficients via interpolation let mut pos = 0; - let mut coeff_u = vec![Fp4::default(); eval_u.len()]; + let mut coeff_u = vec![Fp4::ZERO; eval_u.len()]; for reg in taps.regs() { poly_interpolate( &mut coeff_u[pos..], @@ -238,9 +238,9 @@ pub fn prove(hal: &H, sha: &S, circuit: &mut C) -> V // Do the coefficent mixing // Begin by making a zeroed output buffer let combo_count = taps.combos_size(); - let combos = vec![Fp4::default(); size * (combo_count + 1)]; + let combos = vec![Fp4::ZERO; size * (combo_count + 1)]; let combos = hal.copy_from(combos.as_slice()); - let mut cur_mix = Fp4::one(); + let mut cur_mix = Fp4::ONE; let mut mix_group = |id: RegisterGroup, pg: &PolyGroup| { let mut which = Vec::new(); @@ -275,7 +275,7 @@ pub fn prove(hal: &H, sha: &S, circuit: &mut C) -> V combos.view_mut(&mut 
|combos| { // Subtract the U coeffs from the combos let mut cur_pos = 0; - let mut cur = Fp4::one(); + let mut cur = Fp4::ONE; for reg in taps.regs() { for i in 0..reg.size() { combos[size * reg.combo_id() + i] -= cur * coeff_u[cur_pos + i]; @@ -297,7 +297,7 @@ pub fn prove(hal: &H, sha: &S, circuit: &mut C) -> V &mut combos[combo * size..combo * size + size], z * back_one.pow((*back).into()) ), - Fp4::zero() + Fp4::ZERO ); } } @@ -307,7 +307,7 @@ pub fn prove(hal: &H, sha: &S, circuit: &mut C) -> V &mut combos[combo_count * size..combo_count * size + size], z4 ), - Fp4::zero() + Fp4::ZERO ); }); diff --git a/risc0/zkp/rust/src/verify/adapter.rs b/risc0/zkp/rust/src/verify/adapter.rs index acd370f268..79a2f7f195 100644 --- a/risc0/zkp/rust/src/verify/adapter.rs +++ b/risc0/zkp/rust/src/verify/adapter.rs @@ -20,8 +20,8 @@ use crate::{ fp::Fp, fp4::Fp4, sha::{Digest, Sha}, - Random, }, + field::Elem, taps::TapSet, verify::{read_iop::ReadIOP, Circuit, VerificationError}, }; @@ -53,7 +53,7 @@ impl<'a, C: CircuitInfo + PolyExt + TapsProvider> Circuit for VerifyAdapter<'a, fn execute(&mut self, iop: &mut ReadIOP) { // Read the outputs + size - self.out.resize(self.circuit.output_size(), Fp::default()); + self.out.resize(self.circuit.output_size(), Fp::ZERO); iop.read_fps(&mut self.out); let mut slice = [0u32; 1]; iop.read_u32s(&mut slice); diff --git a/risc0/zkp/rust/src/verify/fri.rs b/risc0/zkp/rust/src/verify/fri.rs index a80746dcb4..2b92671bf6 100644 --- a/risc0/zkp/rust/src/verify/fri.rs +++ b/risc0/zkp/rust/src/verify/fri.rs @@ -24,8 +24,8 @@ use crate::{ ntt::{bit_reverse, interpolate_ntt}, rou::{ROU_FWD, ROU_REV}, sha::Sha, - Random, }, + field::Elem, verify::{merkle::MerkleTreeVerifier, read_iop::ReadIOP}, FRI_FOLD, FRI_MIN_DEGREE, INV_RATE, QUERIES, }; @@ -44,9 +44,9 @@ fn fold_eval(values: &mut [Fp4], mix: Fp4, s: usize, j: usize) -> Fp4 { bit_reverse(values); let root_po2 = log2_ceil(FRI_FOLD * s); let inv_wk: Fp = Fp::new(ROU_REV[root_po2]).pow(j); - let mut 
mul = Fp::new(1); - let mut tot = Fp4::zero(); - let mut mix_pow = Fp4::one(); + let mut mul = Fp::ONE; + let mut tot = Fp4::ZERO; + let mut mix_pow = Fp4::ONE; for i in 0..FRI_FOLD { tot += values[i] * mul * mix_pow; mul *= inv_wk; @@ -101,7 +101,7 @@ where degree /= FRI_FOLD; } // Grab the final coeffs + commit - let mut final_coeffs = vec![Fp::default(); EXT_SIZE * degree]; + let mut final_coeffs = vec![Fp::ZERO; EXT_SIZE * degree]; iop.read_fps(&mut final_coeffs); let final_digest = iop.get_sha().hash_fps(&final_coeffs); // padding? iop.commit(&final_digest); @@ -119,8 +119,8 @@ where } // Do final verification let x = gen.pow(pos); - let mut fx = Fp4::zero(); - let mut cur = Fp::new(1); + let mut fx = Fp4::ZERO; + let mut cur = Fp::ONE; for i in 0..degree { let coeff = Fp4::new( final_coeffs[0 * degree + i], diff --git a/risc0/zkp/rust/src/verify/merkle.rs b/risc0/zkp/rust/src/verify/merkle.rs index e54d720ddb..633b1c0a38 100644 --- a/risc0/zkp/rust/src/verify/merkle.rs +++ b/risc0/zkp/rust/src/verify/merkle.rs @@ -19,6 +19,7 @@ use crate::{ fp::Fp, sha::{Digest, Sha}, }, + field::Elem, merkle::MerkleTreeParams, verify::read_iop::ReadIOP, }; @@ -65,7 +66,7 @@ impl MerkleTreeVerifier { pub fn verify(&self, iop: &mut ReadIOP, mut idx: usize) -> Vec { assert!(idx < self.params.row_size); // Initialize a vector to hold field elements. - let mut out = vec![Fp::new(0); self.params.col_size]; + let mut out = vec![Fp::ZERO; self.params.col_size]; // Read out field elements from IOP. iop.read_fps(&mut out); // Get the hash at the leaf of the tree by hashing these field elements. 
diff --git a/risc0/zkp/rust/src/verify/mod.rs b/risc0/zkp/rust/src/verify/mod.rs index 04d6fb8f66..48feef1a91 100644 --- a/risc0/zkp/rust/src/verify/mod.rs +++ b/risc0/zkp/rust/src/verify/mod.rs @@ -29,8 +29,8 @@ use crate::{ poly::poly_eval, rou::{ROU_FWD, ROU_REV}, sha::{Digest, Sha}, - Random, }, + field::Elem, taps::{RegisterGroup, TapSet}, verify::{fri::fri_verify, merkle::MerkleTreeVerifier, read_iop::ReadIOP}, INV_RATE, MAX_CYCLES_PO2, QUERIES, @@ -118,7 +118,7 @@ where // Read the U coeffs + commit their hash let num_taps = taps.tap_size(); - let mut coeff_u = vec![Fp4::default(); num_taps + CHECK_SIZE]; + let mut coeff_u = vec![Fp4::ZERO; num_taps + CHECK_SIZE]; iop.read_fp4s(&mut coeff_u); let hash_u = *sha.hash_fp4s(&coeff_u); iop.commit(&hash_u); @@ -140,7 +140,7 @@ where // debug!("Result = {result:?}"); // Now generate the check polynomial - let mut check = Fp4::zero(); + let mut check = Fp4::ZERO; let remap = [0, 2, 1, 3]; let fp0 = Fp::from(0 as u32); let fp1 = Fp::from(1 as u32); @@ -151,7 +151,7 @@ where check += coeff_u[num_taps + rmi + 8] * z.pow(i) * Fp4::new(fp0, fp0, fp1, fp0); check += coeff_u[num_taps + rmi + 12] * z.pow(i) * Fp4::new(fp0, fp0, fp0, fp1); } - check *= (Fp4::from_u32(3) * z).pow(size) - Fp4::one(); + check *= (Fp4::from_u32(3) * z).pow(size) - Fp4::ONE; // debug!("Check = {check:?}"); assert_eq!(check, result); @@ -162,9 +162,9 @@ where // Make the mixed U polynomials let mut combo_u = vec![]; for i in 0..combo_count { - combo_u.push(vec![Fp4::zero(); taps.get_combo(i).size()]); + combo_u.push(vec![Fp4::ZERO; taps.get_combo(i).size()]); } - let mut cur_mix = Fp4::one(); + let mut cur_mix = Fp4::ONE; cur_pos = 0; for reg in taps.regs() { for i in 0..reg.size() { @@ -175,7 +175,7 @@ where } // debug!("cur_mix: {cur_mix:?}, cur_pos: {cur_pos}"); // Handle check group - combo_u.push(vec![Fp4::zero()]); + combo_u.push(vec![Fp4::ZERO]); for _ in 0..CHECK_SIZE { combo_u[combo_count][0] += cur_mix * coeff_u[cur_pos]; cur_pos += 1; 
@@ -192,8 +192,8 @@ where rows.push(code_merkle.verify(iop, idx)); rows.push(data_merkle.verify(iop, idx)); let check_row = check_merkle.verify(iop, idx); - let mut cur = Fp4::one(); - let mut tot = vec![Fp4::zero(); combo_count + 1]; + let mut cur = Fp4::ONE; + let mut tot = vec![Fp4::ZERO; combo_count + 1]; for reg in taps.regs() { tot[reg.combo_id()] += cur * rows[reg.group() as usize][reg.offset()]; cur *= mix; @@ -202,10 +202,10 @@ where tot[combo_count] += cur * check_row[i]; cur *= mix; } - let mut ret = Fp4::zero(); + let mut ret = Fp4::ZERO; for i in 0..combo_count { let num = tot[i] - poly_eval(&combo_u[i], x); - let mut divisor = Fp4::one(); + let mut divisor = Fp4::ONE; for back in taps.get_combo(i).slice() { divisor *= x - z * back_one.pow(*back as usize); } diff --git a/risc0/zkvm/circuit/lib.rs b/risc0/zkvm/circuit/lib.rs index 59833daa17..91c53e4c1f 100644 --- a/risc0/zkvm/circuit/lib.rs +++ b/risc0/zkvm/circuit/lib.rs @@ -26,7 +26,8 @@ mod ffi { } } -/// Produces a machine generated .h file that implements the RISC-V circuit and writes it to a file. +/// Produces a machine generated .h file that implements the RISC-V circuit and +/// writes it to a file. pub fn make_circuit(path: &str) -> Result<()> { let_cxx_string!(path = path); Ok(ffi::make_circuit(&path)?) diff --git a/risc0/zkvm/sdk/rust/src/prove/exec.rs b/risc0/zkvm/sdk/rust/src/prove/exec.rs index 66a6b2ebc6..116e09fa15 100644 --- a/risc0/zkvm/sdk/rust/src/prove/exec.rs +++ b/risc0/zkvm/sdk/rust/src/prove/exec.rs @@ -25,6 +25,7 @@ use risc0_zkp::core::sha::Sha; use risc0_zkp::{ adapter::{CircuitDef, CustomStep}, core::{fp::Fp, log2_ceil, sha::DIGEST_WORDS}, + field::Elem, prove::executor::Executor, MAX_CYCLES_PO2, ZK_CYCLES, }; @@ -342,7 +343,7 @@ impl<'a, H: IoHandler> MachineContext<'a, H> { ( event.cycle.into(), event.addr.into(), - event.is_write.into(), + if event.is_write { Fp::ONE } else { Fp::ZERO }, parts.0, parts.1, )