Skip to content

Commit

Permalink
implement prefetch for nightly
Browse files Browse the repository at this point in the history
  • Loading branch information
shamatar committed Jul 13, 2019
1 parent 2d9f552 commit 07ca2e8
Show file tree
Hide file tree
Showing 2 changed files with 101 additions and 7 deletions.
3 changes: 3 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ byteorder = "1"
futures-cpupool = {version = "0.1", optional = true}
num_cpus = {version = "1", optional = true}
crossbeam = {version = "0.7.1", optional = true}
prefetch = {version = "0.2", optional = true}

web-sys = {version = "0.3.17", optional = true, features = ["console", "Performance", "Window"]}

Expand All @@ -33,8 +34,10 @@ blake2-rfc = {version = "0.2.18", optional = true}

[features]
default = ["multicore"]
#default = ["multicore", "nightly"]
#default = ["wasm"]
multicore = ["futures-cpupool", "num_cpus", "crossbeam"]
sonic = ["tiny-keccak", "blake2-rfc"]
gm17 = []
wasm = ["web-sys"]
nightly = ["prefetch"]
105 changes: 98 additions & 7 deletions src/multiexp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ use super::worker::Worker;

use super::SynthesisError;

use cfg_if;

/// This genious piece of code works in the following way:
/// - choose `c` - the bit length of the region that one thread works on
/// - make `2^c - 1` buckets and initialize them with `G = infinity` (that's equivalent of zero)
Expand Down Expand Up @@ -47,6 +49,7 @@ use super::SynthesisError;
/// - accumulators over each set of buckets will have an implicit factor of `(2^c)^i`, so before summing thme up
/// "higher" accumulators must be doubled `c` times
///
#[cfg(not(feature = "nightly"))]
fn multiexp_inner<Q, D, G, S>(
pool: &Worker,
bases: S,
Expand All @@ -56,7 +59,7 @@ fn multiexp_inner<Q, D, G, S>(
mut skip: u32,
c: u32,
handle_trivial: bool
) -> Box<Future<Item=<G as CurveAffine>::Projective, Error=SynthesisError>>
) -> Box<dyn Future<Item=<G as CurveAffine>::Projective, Error=SynthesisError>>
where for<'a> &'a Q: QueryDensity,
D: Send + Sync + 'static + Clone + AsRef<Q>,
G: CurveAffine,
Expand Down Expand Up @@ -153,6 +156,53 @@ fn multiexp_inner<Q, D, G, S>(
}
}


cfg_if! {
if #[cfg(feature = "nightly")] {
#[inline(always)]
fn multiexp_inner_impl<Q, D, G, S>(
pool: &Worker,
bases: S,
density_map: D,
exponents: Arc<Vec<<G::Scalar as PrimeField>::Repr>>,
skip: u32,
c: u32,
handle_trivial: bool
) -> Box<dyn Future<Item=<G as CurveAffine>::Projective, Error=SynthesisError>>
where for<'a> &'a Q: QueryDensity,
D: Send + Sync + 'static + Clone + AsRef<Q>,
G: CurveAffine,
S: SourceBuilder<G>
{
multiexp_inner_with_prefetch(pool, bases, density_map, exponents, skip, c, handle_trivial)
}
} else {
#[inline(always)]
fn multiexp_inner_impl<Q, D, G, S>(
pool: &Worker,
bases: S,
density_map: D,
exponents: Arc<Vec<<G::Scalar as PrimeField>::Repr>>,
skip: u32,
c: u32,
handle_trivial: bool
) -> Box<dyn Future<Item=<G as CurveAffine>::Projective, Error=SynthesisError>>
where for<'a> &'a Q: QueryDensity,
D: Send + Sync + 'static + Clone + AsRef<Q>,
G: CurveAffine,
S: SourceBuilder<G>
{
multiexp_inner(pool, bases, density_map, exponents, skip, c, handle_trivial)
}
}
}



#[cfg(feature = "nightly")]
extern crate prefetch;

#[cfg(feature = "nightly")]
fn multiexp_inner_with_prefetch<Q, D, G, S>(
pool: &Worker,
bases: S,
Expand All @@ -161,12 +211,13 @@ fn multiexp_inner_with_prefetch<Q, D, G, S>(
mut skip: u32,
c: u32,
handle_trivial: bool
) -> Box<Future<Item=<G as CurveAffine>::Projective, Error=SynthesisError>>
) -> Box<dyn Future<Item=<G as CurveAffine>::Projective, Error=SynthesisError>>
where for<'a> &'a Q: QueryDensity,
D: Send + Sync + 'static + Clone + AsRef<Q>,
G: CurveAffine,
S: SourceBuilder<G>
{
use prefetch::prefetch::*;
// Perform this region of the multiexp
let this = {
let bases = bases.clone();
Expand All @@ -191,12 +242,23 @@ fn multiexp_inner_with_prefetch<Q, D, G, S>(
let one = <G::Engine as ScalarEngine>::Fr::one().into_repr();
let padding = Arc::new(vec![zero]);

let mask = 1 << c;

// Sort the bases into buckets
for ((&exp, &next_exp), density) in exponents.iter()
.zip(exponents.iter().skip(1).chain(padding.iter()))
.zip(density_map.as_ref().iter()) {
// no matter what happens - prefetch next bucket

if next_exp != zero && next_exp != one {
let mut next_exp = next_exp;
next_exp.shr(skip);
let next_exp = next_exp.as_ref()[0] % mask;
if next_exp != 0 {
let p: *const <G as CurveAffine>::Projective = &buckets[(next_exp - 1) as usize];
prefetch::<Write, High, Data, _>(p);
}

}
// Go over density and exponents
if density {
if exp == zero {
Expand All @@ -215,7 +277,7 @@ fn multiexp_inner_with_prefetch<Q, D, G, S>(
// then add with (s mod 2^c) P parts
let mut exp = exp;
exp.shr(skip);
let exp = exp.as_ref()[0] % (1 << c);
let exp = exp.as_ref()[0] % mask;

if exp != 0 {
bases.add_assign_mixed(&mut buckets[(exp - 1) as usize])?;
Expand Down Expand Up @@ -249,7 +311,7 @@ fn multiexp_inner_with_prefetch<Q, D, G, S>(
// There's another region more significant. Calculate and join it with
// this region recursively.
Box::new(
this.join(multiexp_inner(pool, bases, density_map, exponents, skip, c, false))
this.join(multiexp_inner_with_prefetch(pool, bases, density_map, exponents, skip, c, false))
.map(move |(this, mut higher)| {
for _ in 0..c {
higher.double();
Expand All @@ -270,7 +332,7 @@ pub fn multiexp<Q, D, G, S>(
bases: S,
density_map: D,
exponents: Arc<Vec<<<G::Engine as ScalarEngine>::Fr as PrimeField>::Repr>>
) -> Box<Future<Item=<G as CurveAffine>::Projective, Error=SynthesisError>>
) -> Box<dyn Future<Item=<G as CurveAffine>::Projective, Error=SynthesisError>>
where for<'a> &'a Q: QueryDensity,
D: Send + Sync + 'static + Clone + AsRef<Q>,
G: CurveAffine,
Expand All @@ -289,7 +351,7 @@ pub fn multiexp<Q, D, G, S>(
assert!(query_size == exponents.len());
}

multiexp_inner(pool, bases, density_map, exponents, 0, c, true)
multiexp_inner_impl(pool, bases, density_map, exponents, 0, c, true)
}


Expand Down Expand Up @@ -525,4 +587,33 @@ fn test_dense_multiexp() {
println!("{} ns for sparse for {} samples", duration_ns, SAMPLES);

assert_eq!(dense, sparse);
}


#[test]
fn test_bench_sparse_multiexp() {
use rand::{XorShiftRng, SeedableRng, Rand, Rng};
use crate::pairing::bn256::Bn256;
use num_cpus;

const SAMPLES: usize = 1 << 22;
let rng = &mut XorShiftRng::from_seed([0x3dbe6259, 0x8d313d76, 0x3237db17, 0xe5bc0654]);

let v = (0..SAMPLES).map(|_| <Bn256 as ScalarEngine>::Fr::rand(rng).into_repr()).collect::<Vec<_>>();
let g = (0..SAMPLES).map(|_| <Bn256 as Engine>::G1::rand(rng).into_affine()).collect::<Vec<_>>();

println!("Done generating test points and scalars");

let pool = Worker::new();
let start = std::time::Instant::now();

let _sparse = multiexp(
&pool,
(Arc::new(g), 0),
FullDensity,
Arc::new(v)
).wait().unwrap();

let duration_ns = start.elapsed().as_nanos() as f64;
println!("{} ms for sparse for {} samples", duration_ns/1000.0f64, SAMPLES);
}

0 comments on commit 07ca2e8

Please sign in to comment.