Skip to content

Commit

Permalink
Handle formula assumptions.
Browse files Browse the repository at this point in the history
Fixes J08nY#4.
  • Loading branch information
J08nY committed Dec 17, 2020
1 parent ebf81fd commit e74b0a6
Show file tree
Hide file tree
Showing 8 changed files with 138 additions and 41 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ codestyle:
flake8 --ignore=E501,F405,F403,F401,E126 pyecsca

doc-coverage:
interrogate -vv -nmps pyecsca
interrogate -vv -nmps pyecsca

docs:
$(MAKE) -C docs apidoc
Expand Down
31 changes: 31 additions & 0 deletions pyecsca/ec/error.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from public import public


@public
class NonInvertibleError(ArithmeticError):
pass


@public
class NonInvertibleWarning(UserWarning):
pass


@public
class NonResidueError(ArithmeticError):
pass


@public
class NonResidueWarning(UserWarning):
pass


@public
class UnsatisfiedAssumptionError(ValueError):
pass


@public
class UnsatisfiedAssumptionWarning(UserWarning):
pass
73 changes: 61 additions & 12 deletions pyecsca/ec/formula.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
from abc import ABC, abstractmethod
from ast import parse, Expression
from astunparse import unparse
from itertools import product
from typing import List, Set, Any, ClassVar, MutableMapping, Tuple, Union
from typing import List, Set, Any, ClassVar, MutableMapping, Tuple, Union, Dict

from pkg_resources import resource_stream
from public import public
from sympy import sympify, FF, symbols, Poly

from .context import ResultAction, getcontext, NullContext
from .error import UnsatisfiedAssumptionError
from .mod import Mod
from .op import CodeOp, OpType

Expand Down Expand Up @@ -38,11 +41,17 @@ def __repr__(self):
class FormulaAction(ResultAction):
"""An execution of a formula, on some input points and parameters, with some outputs."""
formula: "Formula"
"""The formula that was executed."""
inputs: MutableMapping[str, Mod]
"""The input variables (point coordinates and parameters)."""
input_points: List[Any]
"""The input points."""
intermediates: MutableMapping[str, List[OpResult]]
"""Intermediates computed during execution."""
outputs: MutableMapping[str, OpResult]
"""The output variables."""
output_points: List[Any]
"""The output points."""

def __init__(self, formula: "Formula", *points: Any, **inputs: Mod):
super().__init__()
Expand All @@ -62,8 +71,8 @@ def add_operation(self, op: CodeOp, value: Mod):
parents.append(self.intermediates[parent][-1])
elif parent in self.inputs:
parents.append(self.inputs[parent])
l = self.intermediates.setdefault(op.result, list())
l.append(OpResult(op.result, value, op.operator, *parents))
li = self.intermediates.setdefault(op.result, list())
li.append(OpResult(op.result, value, op.operator, *parents))

def add_result(self, point: Any, **outputs: Mod):
if isinstance(getcontext(), NullContext):
Expand All @@ -79,18 +88,29 @@ def __repr__(self):
return f"{self.__class__.__name__}({self.formula}, {self.input_points}) = {self.output_points}"


@public
class Formula(ABC):
"""A formula operating on points."""
name: str
"""Name of the formula."""
shortname: ClassVar[str]
"""A shortname for the type of the formula."""
coordinate_model: Any
"""Coordinate model of the formula."""
meta: MutableMapping[str, Any]
"""Meta information about the formula, such as its source."""
parameters: List[str]
"""Formula parameters (i.e. new parameters introduced by the formula, like `half = 1/2`)."""
assumptions: List[Expression]
"""Assumptions of the formula (e.g. `Z1 == 1` or `2*half == 1`)."""
code: List[CodeOp]
shortname: ClassVar[str]
"""The collection of ops that constitute the code of the formula."""
num_inputs: ClassVar[int]
"""Number of inputs (points) of the formula."""
num_outputs: ClassVar[int]
"""Number of outputs (points) of the formula."""
unified: bool
"""Whether the formula is specifies that it is unified."""

def __call__(self, *points: Any, **params: Mod) -> Tuple[Any, ...]:
"""
Expand All @@ -101,13 +121,42 @@ def __call__(self, *points: Any, **params: Mod) -> Tuple[Any, ...]:
:return: The resulting point(s).
"""
from .point import Point
# Validate number of inputs.
if len(points) != self.num_inputs:
raise ValueError(f"Wrong number of inputs for {self}.")
# Validate input points and unroll them into input params.
for i, point in enumerate(points):
if point.coordinate_model != self.coordinate_model:
raise ValueError(f"Wrong coordinate model of point {point}.")
for coord, value in point.coords.items():
params[coord + str(i + 1)] = value
# Validate assumptions and compute formula parameters.
for assumption in self.assumptions:
assumption_string = unparse(assumption)[1:-2]
lhs, rhs = assumption_string.split(" == ")
if lhs in params:
# Handle an assumption check on value of input points.
alocals: Dict[str, Union[Mod, int]] = {**params}
compiled = compile(assumption, "", mode="eval")
holds = eval(compiled, None, alocals)
if not holds:
raise UnsatisfiedAssumptionError(f"Unsatisfied assumption in the formula ({assumption_string}).")
else:
field = int(params[next(iter(params.keys()))].n) # This is nasty...
k = FF(field)
expr = sympify(f"{rhs} - {lhs}")
for curve_param, value in params.items():
expr = expr.subs(curve_param, k(value))
if len(expr.free_symbols) > 1 or (param := str(expr.free_symbols.pop())) not in self.parameters:
raise ValueError(
f"This formula couldn't be executed due to an unsupported asusmption ({assumption_string}).")
poly = Poly(expr, symbols(param), domain=k)
roots = poly.ground_roots()
for root in roots.keys():
params[param] = Mod(int(root), field)
break
else:
raise UnsatisfiedAssumptionError(f"Unsatisfied assumption in the formula ({assumption_string}).")
with FormulaAction(self, *points, **params) as action:
for op in self.code:
op_result = op(**params)
Expand Down Expand Up @@ -219,7 +268,7 @@ def __read_meta_file(self, path):
self.parameters.append(line[10:])
elif line.startswith("assume"):
self.assumptions.append(
parse(line[7:].replace("=", "==").replace("^", "**"), mode="eval"))
parse(line[7:].replace("=", "==").replace("^", "**"), mode="eval"))
elif line.startswith("unified"):
self.unified = True
line = f.readline().decode("ascii")
Expand Down Expand Up @@ -259,7 +308,7 @@ def __hash__(self):


@public
class AdditionFormula(Formula):
class AdditionFormula(Formula, ABC):
shortname = "add"
num_inputs = 2
num_outputs = 1
Expand All @@ -271,7 +320,7 @@ class AdditionEFDFormula(AdditionFormula, EFDFormula):


@public
class DoublingFormula(Formula):
class DoublingFormula(Formula, ABC):
shortname = "dbl"
num_inputs = 1
num_outputs = 1
Expand All @@ -283,7 +332,7 @@ class DoublingEFDFormula(DoublingFormula, EFDFormula):


@public
class TriplingFormula(Formula):
class TriplingFormula(Formula, ABC):
shortname = "tpl"
num_inputs = 1
num_outputs = 1
Expand All @@ -295,7 +344,7 @@ class TriplingEFDFormula(TriplingFormula, EFDFormula):


@public
class NegationFormula(Formula):
class NegationFormula(Formula, ABC):
shortname = "neg"
num_inputs = 1
num_outputs = 1
Expand All @@ -307,7 +356,7 @@ class NegationEFDFormula(NegationFormula, EFDFormula):


@public
class ScalingFormula(Formula):
class ScalingFormula(Formula, ABC):
shortname = "scl"
num_inputs = 1
num_outputs = 1
Expand All @@ -319,7 +368,7 @@ class ScalingEFDFormula(ScalingFormula, EFDFormula):


@public
class DifferentialAdditionFormula(Formula):
class DifferentialAdditionFormula(Formula, ABC):
shortname = "dadd"
num_inputs = 3
num_outputs = 1
Expand All @@ -331,7 +380,7 @@ class DifferentialAdditionEFDFormula(DifferentialAdditionFormula, EFDFormula):


@public
class LadderFormula(Formula):
class LadderFormula(Formula, ABC):
shortname = "ladd"
num_inputs = 3
num_outputs = 2
Expand Down
21 changes: 6 additions & 15 deletions pyecsca/ec/mod.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,11 @@
import secrets
from functools import wraps, lru_cache
from abc import ABC, abstractmethod
from typing import Type
from public import public

from .error import NonInvertibleError, NonResidueError
from .context import ResultAction


has_gmp = False
try:
Expand All @@ -12,10 +16,6 @@
except ImportError:
pass

from public import public

from .context import ResultAction


@public
def gcd(a, b):
Expand Down Expand Up @@ -91,16 +91,6 @@ def method(self, other):
return method


@public
class NonInvertibleError(ArithmeticError):
pass


@public
class NonResidueError(ArithmeticError):
pass


@public
class RandomModAction(ResultAction):
"""A random sampling from Z_n."""
Expand Down Expand Up @@ -471,6 +461,7 @@ def __pow__(self, n):
return GMPMod(self.x, self.n)
return GMPMod(gmpy2.powmod(self.x, gmpy2.mpz(n), self.n), self.n)


Mod = GMPMod
else:
Mod = RawMod
Expand Down
1 change: 1 addition & 0 deletions pyecsca/ec/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from .coordinates import EFDCoordinateModel, CoordinateModel


@public
class CurveModel(object):
"""A model(form) of an elliptic curve."""
name: str
Expand Down
30 changes: 18 additions & 12 deletions pyecsca/ec/params.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import json
from sympy import Poly, PythonFiniteField, symbols, sympify
from sympy import Poly, FF, symbols, sympify
from astunparse import unparse
from io import RawIOBase, BufferedIOBase
from os.path import join
Expand All @@ -11,6 +11,7 @@

from .coordinates import AffineCoordinateModel, CoordinateModel
from .curve import EllipticCurve
from .error import UnsatisfiedAssumptionError
from .mod import Mod
from .model import (CurveModel, ShortWeierstrassModel, MontgomeryModel, EdwardsModel,
TwistedEdwardsModel)
Expand Down Expand Up @@ -133,25 +134,25 @@ def _create_params(curve, coords, infty):
exec(compiled, None, alocals)
for param, value in alocals.items():
if params[param] != value:
raise ValueError(
raise UnsatisfiedAssumptionError(
f"Coordinate model {coord_model} has an unsatisifed assumption on the {param} parameter (= {value}).")
except NameError:
k = PythonFiniteField(field)
k = FF(field)
assumption_string = unparse(assumption)
lhs, rhs = assumption_string.split(" = ")
expr = sympify(f"{rhs} - {lhs}")
for curve_param, value in params.items():
expr = expr.subs(curve_param, k(value))
if len(expr.free_symbols) > 1 or (param := str(expr.free_symbols.pop())) not in coord_model.parameters:
raise ValueError(f"This coordinate model couldn't be loaded due to unsupported asusmption ({assumption_string}).")
raise ValueError(f"This coordinate model couldn't be loaded due to an unsupported assumption ({assumption_string}).")
poly = Poly(expr, symbols(param), domain=k)
roots = poly.ground_roots()
for root in roots.keys():
if root >= 0:
params[param] = Mod(int(root), field)
break
else:
raise ValueError(f"Coordinate model {coord_model} has an unsatisifed assumption on the {param} parameter (0 = {expr}).")
raise UnsatisfiedAssumptionError(f"Coordinate model {coord_model} has an unsatisifed assumption on the {param} parameter (0 = {expr}).")

# Construct the point at infinity
infinity: Point
Expand Down Expand Up @@ -238,20 +239,25 @@ def load_params(file: Union[str, Path, BinaryIO], coords: str, infty: bool = Tru

return _create_params(curve, coords, infty)


@public
def get_category(category: str, coords: Union[str, Callable[[str], str]],
infty: Union[bool, Callable[[str], bool]] = True) -> DomainParameterCategory:
"""
Retrieve a category from the std-curves database at https://github.com/J08nY/std-curves.
:param category:
:param coords:
:param infty:
:return:
:param category: The category to retrieve.
:param coords: The name of the coordinate system to use. Can be a callable that takes
as argument the name of the curve and produces the coordinate system to use for that curve.
:param infty: Whether to use the special :py:class:InfinityPoint (`True`) or try to use the
point at infinity of the coordinate system. Can be a callable that takes
as argument the name of the curve and returns the infinity option to use for that curve.
:return: The category.
"""
listing = resource_listdir(__name__, "std")
categories = list(entry for entry in listing if resource_isdir(__name__, join("std", entry)))
if category not in categories:
raise ValueError("Category {} not found.".format(category))
raise ValueError(f"Category {category} not found.")
json_path = join("std", category, "curves.json")
with resource_stream(__name__, json_path) as f:
return load_category(f, coords, infty)
Expand All @@ -273,14 +279,14 @@ def get_params(category: str, name: str, coords: str, infty: bool = True) -> Dom
listing = resource_listdir(__name__, "std")
categories = list(entry for entry in listing if resource_isdir(__name__, join("std", entry)))
if category not in categories:
raise ValueError("Category {} not found.".format(category))
raise ValueError(f"Category {category} not found.")
json_path = join("std", category, "curves.json")
with resource_stream(__name__, json_path) as f:
category_json = json.load(f)
for curve in category_json["curves"]:
if curve["name"] == name:
break
else:
raise ValueError("Curve {} not found in category {}.".format(name, category))
raise ValueError(f"Curve {name} not found in category {category}.")

return _create_params(curve, coords, infty)
Loading

0 comments on commit e74b0a6

Please sign in to comment.