Skip to content

Commit

Permalink
feat: add montgomery/regular form conversion functions
Browse files Browse the repository at this point in the history
  • Loading branch information
keyvank committed May 7, 2020
1 parent 2e5c4af commit ea3bd90
Show file tree
Hide file tree
Showing 5 changed files with 52 additions and 6 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ repository = "https://github.com/filecoin-project/ff-cl-gen"
[dependencies]
ff = { version = "0.2.0", package = "fff" }
itertools = { version = "0.8.0" }
num-bigint = "0.2"

[dev-dependencies]
ocl = { version = "0.19.4", package = "fil-ocl"}
Expand Down
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ FIELD FIELD_sqr(FIELD a); // Modular squaring
FIELD FIELD_double(FIELD a); // Modular doubling
FIELD FIELD_pow(FIELD base, uint exponent); // Modular power
FIELD FIELD_pow_lookup(global FIELD *bases, uint exponent); // Modular power with lookup table for bases
FIELD FIELD_mont(FIELD a); // To montgomery form
FIELD FIELD_unmont(FIELD a); // To regular form
bool FIELD_get_bit(FIELD l, uint i); // Get `i`th bit (From most significant digit)
uint FIELD_get_bits(FIELD l, uint skip, uint window); // Get `window` consecutive bits, (Starting from `skip`th bit from most significant digit)
```
Expand Down
10 changes: 10 additions & 0 deletions src/cl/field.cl
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,16 @@ FIELD FIELD_pow_lookup(__global FIELD *bases, uint exponent) {
return res;
}

FIELD FIELD_mont(FIELD a) {
return FIELD_mul(a, FIELD_R2);
}

FIELD FIELD_unmont(FIELD a) {
FIELD one = FIELD_ZERO;
one.val[0] = 1;
return FIELD_mul(a, one);
}

// Get `i`th bit (From most significant digit) of the field.
bool FIELD_get_bit(FIELD l, uint i) {
return (l.val[FIELD_LIMBS - 1 - i / LIMB_BITS] >> (LIMB_BITS - 1 - (i % LIMB_BITS))) & 1;
Expand Down
3 changes: 3 additions & 0 deletions src/cl/test.cl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@ __kernel void test(__global uint *result) {

Fr a = Fr_pow(two, 123456);
Fr b = Fr_pow(eight, 41152);

a = Fr_unmont(a);
a = Fr_mont(a);

*result = Fr_eq(a, b);
}
42 changes: 36 additions & 6 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
use ff::PrimeField;
use itertools::join;
use itertools::*;
use num_bigint::BigUint;

static COMMON_SRC: &str = include_str!("cl/common.cl");
static FIELD_SRC: &str = include_str!("cl/field.cl");

/// Divide anything into 64bit chunks
fn limbs_of<T>(value: T) -> Vec<u64> {
fn u64_limbs_of<T>(value: T) -> Vec<u64> {
unsafe {
std::slice::from_raw_parts(
&value as *const T as *const u64,
Expand All @@ -15,6 +16,17 @@ fn limbs_of<T>(value: T) -> Vec<u64> {
}
}

/// Divide anything into 32bit chunks
fn u32_limbs_of<T>(value: T) -> Vec<u32> {
unsafe {
std::slice::from_raw_parts(
&value as *const T as *const u32,
std::mem::size_of::<T>() / 4,
)
.to_vec()
}
}

/// Calculate the `INV` parameter of Montgomery reduction algorithm for 64bit limbs
/// * `a` - Is the first limb of modulus
fn calc_inv(a: u64) -> u64 {
Expand All @@ -34,22 +46,40 @@ fn define_field(postfix: &str, limbs: Vec<u64>) -> String {
)
}

fn calculate_r2<F: PrimeField>() -> Vec<u64> {
// R ^ 2 mod P
let r2 = BigUint::new(u32_limbs_of(F::one()))
.modpow(
&BigUint::from_slice(&[2]), // ^ 2
&BigUint::new(u32_limbs_of(F::char())), // mod P
)
.to_u32_digits();
r2.iter()
.tuples()
.map(|(lo, hi)| ((*hi as u64) << 32) + (*lo as u64))
.collect()
}

fn params<F>() -> String
where
F: PrimeField,
{
let one = limbs_of(F::one()); // Get Montomery form of F::one()
let p = limbs_of(F::char()); // Get regular form of field modulus
let one = u64_limbs_of(F::one()); // Get Montomery form of F::one()
let p = u64_limbs_of(F::char()); // Get regular form of field modulus
let r2 = calculate_r2::<F>();
let limbs = one.len(); // Number of limbs
let inv = calc_inv(p[0]);
let limbs_def = format!("#define FIELD_LIMBS {}", limbs);
let p_def = define_field("P", p);
let r2_def = define_field("R2", r2);
let one_def = define_field("ONE", one);
let zero_def = define_field("ZERO", vec![0u64; limbs]);
let inv_def = format!("#define FIELD_INV {}", inv);
let typedef = format!("typedef struct {{ limb val[FIELD_LIMBS]; }} FIELD;");
join(
&[limbs_def, one_def, p_def, zero_def, inv_def, typedef],
&[
limbs_def, one_def, p_def, r2_def, zero_def, inv_def, typedef,
],
"\n",
)
}
Expand All @@ -61,7 +91,7 @@ where
let mut result = String::new();

for op in &["sub", "add"] {
let len = limbs_of(F::one()).len();
let len = u64_limbs_of(F::one()).len();

let mut src = format!("FIELD FIELD_{}_(FIELD a, FIELD b) {{\n", op);
if len > 1 {
Expand Down

0 comments on commit ea3bd90

Please sign in to comment.