Skip to content

Commit

Permalink
add tests for multiscalar
Browse files Browse the repository at this point in the history
  • Loading branch information
dignifiedquire committed Oct 27, 2020
1 parent 4a0e2df commit 0513fd7
Show file tree
Hide file tree
Showing 2 changed files with 124 additions and 29 deletions.
146 changes: 122 additions & 24 deletions src/groth16/multiscalar.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,39 +6,35 @@ use crate::bls::Engine;

pub const WINDOW_SIZE: usize = 8;

#[cfg(target_arch = "x86_64")]
fn prefetch<T>(p: *const T) {
unsafe {
core::arch::x86_64::_mm_prefetch(p as *const _, core::arch::x86_64::_MM_HINT_T0);
}
/// Abstraction over either a slice or a getter to produce a fixed number of scalars.
pub enum ScalarList<'a, E: Engine, F: Fn(usize) -> <E::Fr as PrimeField>::Repr + Sync + Send> {
Slice(&'a [<E::Fr as PrimeField>::Repr]),
Getter(F, usize),
}

#[cfg(target_arch = "aarch64")]
fn prefetch<T>(p: *const T) {
unsafe {
use std::arch::aarch64::*;
_prefetch(p, _PREFETCH_READ, _PREFETCH_LOCALITY3);
impl<'a, E: Engine, F: Fn(usize) -> <E::Fr as PrimeField>::Repr + Sync + Send>
ScalarList<'a, E, F>
{
pub fn len(&self) -> usize {
match self {
ScalarList::Slice(s) => s.len(),
ScalarList::Getter(_, len) => *len,
}
}
}

#[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
fn prefetch<T>(p: *const T) {}

pub enum PublicInputs<'a, E: Engine, F: Fn(usize) -> <E::Fr as PrimeField>::Repr + Sync + Send> {
Slice(&'a [<E::Fr as PrimeField>::Repr]),
Getter(F),
}

pub type Getter<E> =
dyn Fn(usize) -> <<E as ff::ScalarEngine>::Fr as PrimeField>::Repr + Sync + Send;

/// Abstraction over owned and referenced multiscalar precomputations.
pub trait MultiscalarPrecomp<E: Engine>: Send + Sync {
fn window_size(&self) -> usize;
fn window_mask(&self) -> u64;
fn tables(&self) -> &[Vec<E::G1Affine>];
fn at_point(&self, idx: usize) -> MultiscalarPrecompRef<'_, E>;
}

/// Owned variant of the multiscalar precomputations.
#[derive(Debug)]
pub struct MultiscalarPrecompOwned<E: Engine> {
num_points: usize,
Expand Down Expand Up @@ -72,6 +68,8 @@ impl<E: Engine> MultiscalarPrecomp<E> for MultiscalarPrecompOwned<E> {
}
}

/// Referenced version of the multiscalar precomputations.
#[derive(Debug)]
pub struct MultiscalarPrecompRef<'a, E: Engine> {
num_points: usize,
window_size: usize,
Expand Down Expand Up @@ -104,7 +102,7 @@ impl<E: Engine> MultiscalarPrecomp<E> for MultiscalarPrecompRef<'_, E> {
}
}

/// Precompute tables for fixed bases.
/// Precompute the tables for fixed bases.
pub fn precompute_fixed_window<E: Engine>(
points: &[E::G1Affine],
window_size: usize,
Expand Down Expand Up @@ -195,14 +193,15 @@ pub fn multiscalar<E: Engine>(

/// Perform a threaded multiscalar multiplication and accumulation.
pub fn par_multiscalar<F, E: Engine>(
k: &PublicInputs<'_, E, F>,
points: &ScalarList<'_, E, F>,
precomp_table: &dyn MultiscalarPrecomp<E>,
num_points: usize,
nbits: usize,
) -> E::G1
where
F: Fn(usize) -> <E::Fr as PrimeField>::Repr + Sync + Send,
{
let num_points = points.len();

// The granularity of work, in points. When a thread gets work it will
// gather chunk_size points, perform muliscalar on them, and accumulate
// the result. This is more efficient than evenly dividing the work among
Expand Down Expand Up @@ -234,9 +233,9 @@ where
}

let subset = precomp_table.at_point(start_idx);
let scalars = match k {
PublicInputs::Slice(ref s) => &s[start_idx..end_idx],
PublicInputs::Getter(ref getter) => {
let scalars = match points {
ScalarList::Slice(ref s) => &s[start_idx..end_idx],
ScalarList::Getter(ref getter, _) => {
for i in start_idx..end_idx {
scalar_storage[i - start_idx] = getter(i);
}
Expand All @@ -254,3 +253,102 @@ where
},
)
}

#[cfg(target_arch = "x86_64")]
fn prefetch<T>(p: *const T) {
unsafe {
core::arch::x86_64::_mm_prefetch(p as *const _, core::arch::x86_64::_MM_HINT_T0);
}
}

#[cfg(target_arch = "aarch64")]
fn prefetch<T>(p: *const T) {
unsafe {
use std::arch::aarch64::*;
_prefetch(p, _PREFETCH_READ, _PREFETCH_LOCALITY3);
}
}

#[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
fn prefetch<T>(p: *const T) {}

#[cfg(test)]
mod tests {
use super::*;

use crate::bls::{Bls12, Fr, FrRepr, G1Affine, G1Projective};

use ff::Field;
use rand_core::SeedableRng;
use rand_xorshift::XorShiftRng;

fn multiscalar_naive(points: &[G1Affine], scalars: &[FrRepr]) -> G1Projective {
let mut acc = G1Projective::zero();
for (scalar, point) in scalars.iter().zip(points.iter()) {
acc.add_assign(&point.mul(*scalar));
}
acc
}

#[test]
fn test_multiscalar_single() {
let mut rng = XorShiftRng::from_seed([
0x59, 0x62, 0xbe, 0x5d, 0x76, 0x3d, 0x31, 0x8d, 0x17, 0xdb, 0x37, 0x32, 0x54, 0x06,
0xbc, 0xe5,
]);

for _ in 0..50 {
for (num_inputs, window_size) in &[(8, 4), (12, 1), (10, 1), (20, 2)] {
let points: Vec<G1Affine> = (0..*num_inputs)
.map(|_| G1Projective::random(&mut rng).into_affine())
.collect();

let scalars: Vec<FrRepr> = (0..*num_inputs)
.map(|_| Fr::random(&mut rng).into_repr())
.collect();

let table = precompute_fixed_window::<Bls12>(&points, *window_size);

let naive_result = multiscalar_naive(&points, &scalars);
let fast_result = multiscalar::<Bls12>(
&scalars,
&table,
std::mem::size_of::<<Fr as PrimeField>::Repr>() * 8,
);

assert_eq!(naive_result, fast_result);
}
}
}

#[test]
fn test_multiscalar_par() {
let mut rng = XorShiftRng::from_seed([
0x59, 0x62, 0xbe, 0x5d, 0x76, 0x3d, 0x31, 0x8d, 0x17, 0xdb, 0x37, 0x32, 0x54, 0x06,
0xbc, 0xe5,
]);

for _ in 0..50 {
for (num_inputs, window_size) in &[(8, 4), (12, 1), (10, 1), (20, 2)] {
let points: Vec<G1Affine> = (0..*num_inputs)
.map(|_| G1Projective::random(&mut rng).into_affine())
.collect();

let scalars: Vec<FrRepr> = (0..*num_inputs)
.map(|_| Fr::random(&mut rng).into_repr())
.collect();

let table = precompute_fixed_window::<Bls12>(&points, *window_size);

let naive_result = multiscalar_naive(&points, &scalars);
let fast_result = par_multiscalar::<&Getter<Bls12>, Bls12>(
&ScalarList::Slice(&scalars),
&table,
std::mem::size_of::<<Fr as PrimeField>::Repr>() * 8,
);

assert_eq!(naive_result, fast_result);
}
}
}
}
7 changes: 2 additions & 5 deletions src/groth16/verifier.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@ pub fn verify_proof<'a, E: Engine>(
if (public_inputs.len() + 1) != pvk.ic.len() {
return Err(SynthesisError::MalformedVerifyingKey);
}
let num_inputs = public_inputs.len();

// The original verification equation is:
// A * B = alpha * beta + inputs * gamma + C * delta
Expand Down Expand Up @@ -74,9 +73,8 @@ pub fn verify_proof<'a, E: Engine>(
public_inputs.iter().map(PrimeField::into_repr).collect();

let mut acc = multiscalar::par_multiscalar::<&multiscalar::Getter<E>, E>(
&multiscalar::PublicInputs::Slice(&public_inputs_repr),
&multiscalar::ScalarList::Slice(&public_inputs_repr),
&subset,
num_inputs,
std::mem::size_of::<<E::Fr as PrimeField>::Repr>() * 8,
);

Expand Down Expand Up @@ -193,9 +191,8 @@ where

// \sum Accum_Gamma
let acc_g_psi = multiscalar::par_multiscalar::<_, E>(
&multiscalar::PublicInputs::Getter(scalar_getter),
&multiscalar::ScalarList::Getter(scalar_getter, num_inputs + 1),
&pvk.multiscalar,
num_inputs + 1,
256,
);

Expand Down

0 comments on commit 0513fd7

Please sign in to comment.