Skip to content

Commit

Permalink
NFC: Rename Fp4 -> FpExt (risc0#1304)
Browse files Browse the repository at this point in the history
  • Loading branch information
flaub authored Jan 9, 2024
1 parent ddedd18 commit 852f37e
Show file tree
Hide file tree
Showing 45 changed files with 361 additions and 422 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
#pragma once

/// \file
/// Defines Fp4, a finite field F_p^4, based on Fp via the irreducable polynomial x^4 - 11.
/// Defines FpExt, a finite field F_p^4, based on Fp via the irreducable polynomial x^4 - 11.

#include "fp.h"

Expand All @@ -24,83 +24,83 @@
#define BETA Fp(11)
#define NBETA Fp(Fp::P - 11)

/// Intstances of Fp4 are element of a finite field F_p^4. They are represented as elements of
/// Intstances of FpExt are element of a finite field F_p^4. They are represented as elements of
/// F_p[X] / (X^4 - 11). Basically, this is a 'big' finite field (about 2^128 elements), which is
/// used when the security of various operations depends on the size of the field. It has the field
/// Fp as a subfield, which means operations by the two are compatable, which is important. The
/// irreducible polynomial was choosen to be the simpilest possible one, x^4 - B, where 11 is the
/// smallest B which makes the polynomial irreducable.
struct Fp4 {
/// The elements of Fp4, elems[0] + elems[1]*X + elems[2]*X^2 + elems[3]*x^4
struct FpExt {
/// The elements of FpExt, elems[0] + elems[1]*X + elems[2]*X^2 + elems[3]*x^4
Fp elems[4];

/// Default constructor makes the zero elements
__device__ constexpr Fp4() {}
__device__ constexpr FpExt() {}

/// Initialize from uint32_t
__device__ explicit constexpr Fp4(uint32_t x) {
__device__ explicit constexpr FpExt(uint32_t x) {
elems[0] = x;
elems[1] = 0;
elems[2] = 0;
elems[3] = 0;
}

/// Convert from Fp to Fp4.
__device__ explicit constexpr Fp4(Fp x) {
/// Convert from Fp to FpExt.
__device__ explicit constexpr FpExt(Fp x) {
elems[0] = x;
elems[1] = 0;
elems[2] = 0;
elems[3] = 0;
}

/// Explicitly construct an Fp4 from parts
__device__ constexpr Fp4(Fp a, Fp b, Fp c, Fp d) {
/// Explicitly construct an FpExt from parts
__device__ constexpr FpExt(Fp a, Fp b, Fp c, Fp d) {
elems[0] = a;
elems[1] = b;
elems[2] = c;
elems[3] = d;
}

// Implement the addition/subtraction overloads
__device__ constexpr Fp4 operator+=(Fp4 rhs) {
__device__ constexpr FpExt operator+=(FpExt rhs) {
for (uint32_t i = 0; i < 4; i++) {
elems[i] += rhs.elems[i];
}
return *this;
}

__device__ constexpr Fp4 operator-=(Fp4 rhs) {
__device__ constexpr FpExt operator-=(FpExt rhs) {
for (uint32_t i = 0; i < 4; i++) {
elems[i] -= rhs.elems[i];
}
return *this;
}

__device__ constexpr Fp4 operator+(Fp4 rhs) const {
Fp4 result = *this;
__device__ constexpr FpExt operator+(FpExt rhs) const {
FpExt result = *this;
result += rhs;
return result;
}

__device__ constexpr Fp4 operator-(Fp4 rhs) const {
Fp4 result = *this;
__device__ constexpr FpExt operator-(FpExt rhs) const {
FpExt result = *this;
result -= rhs;
return result;
}

__device__ constexpr Fp4 operator-() const { return Fp4() - *this; }
__device__ constexpr FpExt operator-() const { return FpExt() - *this; }

// Implement the simple multiplication case by the subfield Fp
// Fp * Fp4 is done as a free function due to C++'s operator overloading rules.
__device__ constexpr Fp4 operator*=(Fp rhs) {
// Fp * FpExt is done as a free function due to C++'s operator overloading rules.
__device__ constexpr FpExt operator*=(Fp rhs) {
for (uint32_t i = 0; i < 4; i++) {
elems[i] *= rhs;
}
return *this;
}

__device__ constexpr Fp4 operator*(Fp rhs) const {
Fp4 result = *this;
__device__ constexpr FpExt operator*(Fp rhs) const {
FpExt result = *this;
result *= rhs;
return result;
}
Expand All @@ -109,24 +109,24 @@ struct Fp4 {
// representations, and then reduce module x^4 - B, which means powers >= 4 get shifted back 4 and
// multiplied by -beta. We could write this as a double loops with some if's and hope it gets
// unrolled properly, but it'a small enough to just hand write.
__device__ constexpr Fp4 operator*(Fp4 rhs) const {
__device__ constexpr FpExt operator*(FpExt rhs) const {
// Rename the element arrays to something small for readability
#define a elems
#define b rhs.elems
return Fp4(a[0] * b[0] + NBETA * (a[1] * b[3] + a[2] * b[2] + a[3] * b[1]),
a[0] * b[1] + a[1] * b[0] + NBETA * (a[2] * b[3] + a[3] * b[2]),
a[0] * b[2] + a[1] * b[1] + a[2] * b[0] + NBETA * (a[3] * b[3]),
a[0] * b[3] + a[1] * b[2] + a[2] * b[1] + a[3] * b[0]);
return FpExt(a[0] * b[0] + NBETA * (a[1] * b[3] + a[2] * b[2] + a[3] * b[1]),
a[0] * b[1] + a[1] * b[0] + NBETA * (a[2] * b[3] + a[3] * b[2]),
a[0] * b[2] + a[1] * b[1] + a[2] * b[0] + NBETA * (a[3] * b[3]),
a[0] * b[3] + a[1] * b[2] + a[2] * b[1] + a[3] * b[0]);
#undef a
#undef b
}
__device__ constexpr Fp4 operator*=(Fp4 rhs) {
__device__ constexpr FpExt operator*=(FpExt rhs) {
*this = *this * rhs;
return *this;
}

// Equality
__device__ constexpr bool operator==(Fp4 rhs) const {
__device__ constexpr bool operator==(FpExt rhs) const {
for (uint32_t i = 0; i < 4; i++) {
if (elems[i] != rhs.elems[i]) {
return false;
Expand All @@ -135,23 +135,19 @@ struct Fp4 {
return true;
}

__device__ constexpr bool operator!=(Fp4 rhs) const {
return !(*this == rhs);
}
__device__ constexpr bool operator!=(FpExt rhs) const { return !(*this == rhs); }

__device__ constexpr Fp constPart() const {
return elems[0];
}
__device__ constexpr Fp constPart() const { return elems[0]; }
};

/// Overload for case where LHS is Fp (RHS case is handled as a method)
__device__ constexpr inline Fp4 operator*(Fp a, Fp4 b) {
__device__ constexpr inline FpExt operator*(Fp a, FpExt b) {
return b * a;
}

/// Raise an Fp4 to a power
__device__ constexpr inline Fp4 pow(Fp4 x, size_t n) {
Fp4 tot(1);
/// Raise an FpExt to a power
__device__ constexpr inline FpExt pow(FpExt x, size_t n) {
FpExt tot(1);
while (n != 0) {
if (n % 2 == 1) {
tot *= x;
Expand All @@ -162,11 +158,11 @@ __device__ constexpr inline Fp4 pow(Fp4 x, size_t n) {
return tot;
}

/// Compute the multiplicative inverse of an Fp4.
__device__ constexpr inline Fp4 inv(Fp4 in) {
/// Compute the multiplicative inverse of an FpExt.
__device__ constexpr inline FpExt inv(FpExt in) {
#define a in.elems
// Compute the multiplicative inverse by basicly looking at Fp4 as a composite field and using the
// same basic methods used to invert complex numbers. We imagine that initially we have a
// Compute the multiplicative inverse by basicly looking at FpExt as a composite field and using
// the same basic methods used to invert complex numbers. We imagine that initially we have a
// numerator of 1, and an denominator of a. i.e out = 1 / a; We set a' to be a with the first and
// third components negated. We then multiply the numerator and the denominator by a', producing
// out = a' / (a * a'). By construction (a * a') has 0's in it's first and third elements. We
Expand All @@ -179,15 +175,15 @@ __device__ constexpr inline Fp4 inv(Fp4 in) {
// But we can now invert C direcly, and multiply by a'*b', out = a'*b'*inv(c)
Fp ic = inv(c);
// Note: if c == 0 (really should only happen if in == 0), our 'safe' version of inverse results
// in ic == 0, and thus out = 0, so we have the same 'safe' behavior for Fp4. Oh, and since we
// in ic == 0, and thus out = 0, so we have the same 'safe' behavior for FpExt. Oh, and since we
// want to multiply everything by ic, it's slightly faster to premultiply the two parts of b by ic
// (2 multiplies instead of 4)
b0 *= ic;
b2 *= ic;
return Fp4(a[0] * b0 + BETA * a[2] * b2,
-a[1] * b0 + NBETA * a[3] * b2,
-a[0] * b2 + a[2] * b0,
a[1] * b2 - a[3] * b0);
return FpExt(a[0] * b0 + BETA * a[2] * b2,
-a[1] * b0 + NBETA * a[3] * b2,
-a[0] * b2 + a[2] * b0,
a[1] * b2 - a[3] * b0);
#undef a
}

Expand Down
Loading

0 comments on commit 852f37e

Please sign in to comment.