Skip to content

Commit

Permalink
Merge pull request zkcrypto#83 from zkcrypto/multiexp-exponent-caching
Browse files Browse the repository at this point in the history
Refactor `multiexp` to cache exponent chunks
  • Loading branch information
ebfull authored May 7, 2022
2 parents 2759d93 + 46b5a6e commit d886340
Show file tree
Hide file tree
Showing 4 changed files with 100 additions and 38 deletions.
11 changes: 11 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,17 @@ and this project adheres to Rust's notion of
[Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## [Unreleased]
### Added
- `bellman::multiexp::Exponent`

### Changed
- `bellman::multiexp::multiexp` now takes exponents as `Arc<Vec<Exponent<_>>>`
instead of `Arc<Vec<FieldBits<_>>>`.

### Fixed
- Migrating from `bitvec 0.22` to `bitvec 1.0` caused a performance regression
in `bellman::multiexp::multiexp`, slowing down proof creation. Some of that
performance has been regained by refactoring `multiexp`.

## [0.12.0] - 2022-05-04
### Changed
Expand Down
4 changes: 2 additions & 2 deletions benches/slow.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use bellman::{
};
use bls12_381::{Bls12, Scalar};
use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion};
use ff::{Field, PrimeFieldBits};
use ff::Field;
use group::{Curve, Group};
use pairing::Engine;
use rand_core::SeedableRng;
Expand All @@ -20,7 +20,7 @@ fn bench_parts(c: &mut Criterion) {
.map(|_| Scalar::random(&mut rng))
.collect::<Vec<_>>(),
);
let v_bits = Arc::new(v.iter().map(|e| e.to_le_bits()).collect::<Vec<_>>());
let v_bits = Arc::new(v.iter().map(|e| e.into()).collect::<Vec<_>>());
let g = Arc::new(
(0..samples)
.map(|_| <Bls12 as Engine>::G1::random(&mut rng).to_affine())
Expand Down
6 changes: 3 additions & 3 deletions src/groth16/prover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ where
let a_len = a.len() - 1;
a.truncate(a_len);
// TODO: parallelize if it's even helpful
let a = Arc::new(a.into_iter().map(|s| s.0.to_le_bits()).collect::<Vec<_>>());
let a = Arc::new(a.into_iter().map(|s| s.0.into()).collect::<Vec<_>>());

multiexp(&worker, params.get_h(a.len())?, FullDensity, a)
};
Expand All @@ -240,14 +240,14 @@ where
prover
.input_assignment
.into_iter()
.map(|s| s.to_le_bits())
.map(|s| s.into())
.collect::<Vec<_>>(),
);
let aux_assignment = Arc::new(
prover
.aux_assignment
.into_iter()
.map(|s| s.to_le_bits())
.map(|s| s.into())
.collect::<Vec<_>>(),
);

Expand Down
117 changes: 84 additions & 33 deletions src/multiexp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -156,10 +156,61 @@ impl DensityTracker {
}
}

/// An exponent pre-split into `c`-bit little-endian windows, with the trivial
/// values zero and one kept as distinct cases so the bucket loop can skip a
/// base or accumulate it directly without scanning any bits.
enum ChunkedExponent {
    /// The exponent is zero; the corresponding base contributes nothing.
    Zero,
    /// The exponent is one; the base is added directly (in the first window only).
    One,
    /// One `u64` per `c`-bit window of the exponent, least-significant window first.
    Chunks(Vec<u64>),
}

/// A multiexp exponent, classified at construction time so that the trivial
/// scalars zero and one can later be handled without examining bit patterns.
pub enum Exponent<F: PrimeFieldBits> {
    /// The scalar is zero.
    Zero,
    /// The scalar is one.
    One,
    /// Any other scalar, stored as its little-endian bit representation.
    Bits(FieldBits<F::ReprBits>),
}

impl<F: PrimeFieldBits> From<&F> for Exponent<F> {
    /// Classifies a borrowed scalar, special-casing zero and one (checked in
    /// variable time) so `multiexp` can avoid bucket work for trivial values.
    fn from(exp: &F) -> Self {
        if exp.is_zero_vartime() {
            return Exponent::Zero;
        }
        if exp == &F::one() {
            return Exponent::One;
        }
        // Non-trivial scalar: keep its full little-endian bit decomposition.
        Exponent::Bits(exp.to_le_bits())
    }
}

impl<F: PrimeFieldBits> From<F> for Exponent<F> {
fn from(exp: F) -> Self {
(&exp).into()
}
}

impl<F: PrimeFieldBits> Exponent<F> {
    /// Splits this exponent into `c`-bit little-endian windows.
    ///
    /// The trivial variants pass through unchanged; for a full bit pattern,
    /// each window of `c` bits is packed into a `u64` (bit `i` of the window
    /// becomes bit `i` of the value), least-significant window first.
    fn chunks(&self, c: usize) -> ChunkedExponent {
        match self {
            Self::Zero => ChunkedExponent::Zero,
            Self::One => ChunkedExponent::One,
            Self::Bits(bits) => {
                let windows = bits
                    .chunks(c)
                    .map(|window| {
                        // Pack the window's bits into a u64; positions are
                        // disjoint, so OR-ing is equivalent to summation.
                        let mut value = 0u64;
                        for (i, bit) in window.iter().by_vals().enumerate() {
                            value |= (bit as u64) << i;
                        }
                        value
                    })
                    .collect();
                ChunkedExponent::Chunks(windows)
            }
        }
    }
}

fn multiexp_inner<Q, D, G, S>(
bases: S,
density_map: D,
exponents: Arc<Vec<FieldBits<<G::Scalar as PrimeFieldBits>::ReprBits>>>,
exponents: Arc<Vec<Exponent<G::Scalar>>>,
c: u32,
) -> Result<G, SynthesisError>
where
Expand All @@ -172,8 +223,8 @@ where
// Perform this region of the multiexp
let this = move |bases: S,
density_map: D,
exponents: Arc<Vec<FieldBits<<G::Scalar as PrimeFieldBits>::ReprBits>>>,
skip: u32|
exponents: Arc<Vec<ChunkedExponent>>,
chunk: usize|
-> Result<_, SynthesisError> {
// Accumulate the result
let mut acc = G::identity();
Expand All @@ -185,38 +236,29 @@ where
let mut buckets = vec![G::identity(); (1 << c) - 1];

// only the first round uses this
let handle_trivial = skip == 0;
let handle_trivial = chunk == 0;

// Sort the bases into buckets
for (exp, density) in exponents.iter().zip(density_map.as_ref().iter()) {
if density {
let (exp_is_zero, exp_is_one) = {
let (first, rest) = exp.split_first().unwrap();
let rest_unset = rest.not_any();
(!*first && rest_unset, *first && rest_unset)
};

if exp_is_zero {
bases.skip(1)?;
} else if exp_is_one {
if handle_trivial {
acc.add_assign_from_source(&mut bases)?;
} else {
bases.skip(1)?;
match exp {
ChunkedExponent::Zero => bases.skip(1)?,
ChunkedExponent::One => {
if handle_trivial {
acc.add_assign_from_source(&mut bases)?;
} else {
bases.skip(1)?;
}
}
} else {
let exp = exp
.into_iter()
.by_vals()
.skip(skip as usize)
.take(c as usize)
.enumerate()
.fold(0u64, |acc, (i, b)| acc + ((b as u64) << i));

if exp != 0 {
(&mut buckets[(exp - 1) as usize]).add_assign_from_source(&mut bases)?;
} else {
bases.skip(1)?;
ChunkedExponent::Chunks(chunks) => {
let exp = chunks[chunk];

if exp != 0 {
(&mut buckets[(exp - 1) as usize])
.add_assign_from_source(&mut bases)?;
} else {
bases.skip(1)?;
}
}
}
}
Expand All @@ -235,10 +277,19 @@ where
Ok(acc)
};

// Split the exponents into chunks.
let exponents = Arc::new(
exponents
.iter()
.map(|exp| exp.chunks(c as usize))
.collect::<Vec<_>>(),
);

let parts = (0..G::Scalar::NUM_BITS)
.into_par_iter()
.step_by(c as usize)
.map(|skip| this(bases.clone(), density_map.clone(), exponents.clone(), skip))
.enumerate()
.map(|(chunk, _)| this(bases.clone(), density_map.clone(), exponents.clone(), chunk))
.collect::<Vec<Result<_, _>>>();

parts
Expand All @@ -255,7 +306,7 @@ pub fn multiexp<Q, D, G, S>(
pool: &Worker,
bases: S,
density_map: D,
exponents: Arc<Vec<FieldBits<<G::Scalar as PrimeFieldBits>::ReprBits>>>,
exponents: Arc<Vec<Exponent<G::Scalar>>>,
) -> Waiter<Result<G, SynthesisError>>
where
for<'a> &'a Q: QueryDensity,
Expand Down Expand Up @@ -311,7 +362,7 @@ fn test_with_bls12() {
.map(|_| Scalar::random(&mut rng))
.collect::<Vec<_>>(),
);
let v_bits = Arc::new(v.iter().map(|e| e.to_le_bits()).collect::<Vec<_>>());
let v_bits = Arc::new(v.iter().map(|e| e.into()).collect::<Vec<_>>());
let g = Arc::new(
(0..SAMPLES)
.map(|_| <Bls12 as Engine>::G1::random(&mut rng).to_affine())
Expand Down

0 comments on commit d886340

Please sign in to comment.