Skip to content

Commit

Permalink
refact: single generic limbs_of function
Browse files Browse the repository at this point in the history
  • Loading branch information
keyvank committed May 27, 2020
1 parent ea3bd90 commit 56200cc
Showing 1 changed file with 10 additions and 21 deletions.
31 changes: 10 additions & 21 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,23 +5,12 @@ use num_bigint::BigUint;
static COMMON_SRC: &str = include_str!("cl/common.cl");
static FIELD_SRC: &str = include_str!("cl/field.cl");

/// Divide anything into 64bit chunks
fn u64_limbs_of<T>(value: T) -> Vec<u64> {
/// Divide anything into limbs of type `E`
fn limbs_of<T, E: Clone>(value: T) -> Vec<E> {
unsafe {
std::slice::from_raw_parts(
&value as *const T as *const u64,
std::mem::size_of::<T>() / 8,
)
.to_vec()
}
}

/// Divide anything into 32bit chunks
fn u32_limbs_of<T>(value: T) -> Vec<u32> {
unsafe {
std::slice::from_raw_parts(
&value as *const T as *const u32,
std::mem::size_of::<T>() / 4,
&value as *const T as *const E,
std::mem::size_of::<T>() / std::mem::size_of::<E>(),
)
.to_vec()
}
Expand All @@ -48,10 +37,10 @@ fn define_field(postfix: &str, limbs: Vec<u64>) -> String {

fn calculate_r2<F: PrimeField>() -> Vec<u64> {
// R ^ 2 mod P
let r2 = BigUint::new(u32_limbs_of(F::one()))
let r2 = BigUint::new(limbs_of::<_, u32>(F::one()))
.modpow(
&BigUint::from_slice(&[2]), // ^ 2
&BigUint::new(u32_limbs_of(F::char())), // mod P
&BigUint::from_slice(&[2]), // ^ 2
&BigUint::new(limbs_of::<_, u32>(F::char())), // mod P
)
.to_u32_digits();
r2.iter()
Expand All @@ -64,8 +53,8 @@ fn params<F>() -> String
where
F: PrimeField,
{
let one = u64_limbs_of(F::one()); // Get Montomery form of F::one()
let p = u64_limbs_of(F::char()); // Get regular form of field modulus
let one = limbs_of::<_, u64>(F::one()); // Get Montomery form of F::one()
let p = limbs_of::<_, u64>(F::char()); // Get regular form of field modulus
let r2 = calculate_r2::<F>();
let limbs = one.len(); // Number of limbs
let inv = calc_inv(p[0]);
Expand All @@ -91,7 +80,7 @@ where
let mut result = String::new();

for op in &["sub", "add"] {
let len = u64_limbs_of(F::one()).len();
let len = limbs_of::<_, u64>(F::one()).len();

let mut src = format!("FIELD FIELD_{}_(FIELD a, FIELD b) {{\n", op);
if len > 1 {
Expand Down

0 comments on commit 56200cc

Please sign in to comment.