Skip to content

Commit

Permalink
Merge pull request #207 from zkFold/turtlepu-symbolic-ord
Browse files Browse the repository at this point in the history
Migrated Ord to Symbolic API
  • Loading branch information
echatav authored Aug 5, 2024
2 parents 8ee94a7 + 63f2fdc commit 353ab13
Show file tree
Hide file tree
Showing 10 changed files with 75 additions and 85 deletions.
2 changes: 1 addition & 1 deletion examples/Examples/Eq.hs
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@ import Prelude hiding (Bool, Eq (.

import ZkFold.Base.Algebra.Basic.Field (Zp)
import ZkFold.Base.Algebra.EllipticCurve.BLS12_381 (BLS12_381_Scalar)
import ZkFold.Symbolic.Class (Symbolic)
import ZkFold.Symbolic.Compiler
import ZkFold.Symbolic.Data.Bool (Bool (..))
import ZkFold.Symbolic.Data.Eq (Eq (..))
import ZkFold.Symbolic.Class (Symbolic)
import ZkFold.Symbolic.Data.FieldElement (FieldElement)

-- | (==) operation
Expand Down
7 changes: 4 additions & 3 deletions examples/Examples/LEQ.hs
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,14 @@ import Prelude hiding (Bool, Eq (.

import ZkFold.Base.Algebra.Basic.Field (Zp)
import ZkFold.Base.Algebra.EllipticCurve.BLS12_381 (BLS12_381_Scalar)
import ZkFold.Symbolic.Class (Symbolic)
import ZkFold.Symbolic.Compiler
import ZkFold.Symbolic.Data.Bool (Bool (..))
import ZkFold.Symbolic.Data.Bool (Bool)
import ZkFold.Symbolic.Data.FieldElement (FieldElement)
import ZkFold.Symbolic.Data.Ord (Ord (..))
import ZkFold.Symbolic.Data.Ord ((<=))

-- | (<=) operation
leq :: Ord (Bool c) (FieldElement c) => FieldElement c -> FieldElement c -> Bool c
leq :: Symbolic c => FieldElement c -> FieldElement c -> Bool c
leq x y = x <= y

exampleLEQ :: IO ()
Expand Down
6 changes: 1 addition & 5 deletions src/ZkFold/Base/Protocol/ARK/Plonk.hs
Original file line number Diff line number Diff line change
Expand Up @@ -106,11 +106,7 @@ instance forall n l c1 c2 t plonk f g1.
, KnownNat l
, KnownNat (PlonkPermutationSize n)
, KnownNat (PlonkPolyExtendedLength n)
, Eq (ScalarField c1)
, Scale (ScalarField c1) (ScalarField c1)
, BinaryExpansion (ScalarField c1)
, Bits (ScalarField c1) ~ [ScalarField c1]
, FiniteField (ScalarField c1)
, Arithmetic f
, AdditiveGroup (BaseField c1)
, Pairing c1 c2
, ToTranscript t (ScalarField c1)
Expand Down
6 changes: 3 additions & 3 deletions src/ZkFold/Symbolic/Compiler/ArithmeticCircuit/Combinators.hs
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ embedVar x = newAssigned $ const (fromConstant x)
embedAll :: forall a n . (Arithmetic a, KnownNat n) => a -> ArithmeticCircuit a (Vector n)
embedAll x = circuitF $ Vector <$> replicateM (fromIntegral $ value @n) (newAssigned $ const (fromConstant x))

expansion :: MonadBlueprint i a m => Natural -> i -> m [i]
expansion :: MonadCircuit i a m => Natural -> i -> m [i]
-- ^ @expansion n k@ computes a binary expansion of @k@ if it fits in @n@ bits.
expansion n k = do
bits <- bitsOf n k
Expand All @@ -88,7 +88,7 @@ splitExpansion n1 n2 k = do
constraint (\x -> x k - x l - scale (2 ^ n1 :: Natural) (x h))
return (l, h)

bitsOf :: MonadBlueprint i a m => Natural -> i -> m [i]
bitsOf :: MonadCircuit i a m => Natural -> i -> m [i]
-- ^ @bitsOf n k@ creates @n@ bits and sets their witnesses equal to @n@ smaller
-- bits of @k@.
bitsOf n k = for [0 .. n -! 1] $ \j ->
Expand All @@ -97,7 +97,7 @@ bitsOf n k = for [0 .. n -! 1] $ \j ->
repr :: forall b . (BinaryExpansion b, Bits b ~ [b], Finite b) => b -> [b]
repr = padBits (numberOfBits @b) . binaryExpansion

horner :: MonadBlueprint i a m => [i] -> m i
horner :: MonadCircuit i a m => [i] -> m i
-- ^ @horner [b0,...,bn]@ computes the sum @b0 + 2 b1 + ... + 2^n bn@ using
-- Horner's scheme.
horner xs = case reverse xs of
Expand Down
2 changes: 1 addition & 1 deletion src/ZkFold/Symbolic/Data/Eq/Structural.hs
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@ module ZkFold.Symbolic.Data.Eq.Structural where

import Prelude (type (~))

import ZkFold.Symbolic.Class
import ZkFold.Symbolic.Data.Bool
import ZkFold.Symbolic.Data.Class
import ZkFold.Symbolic.Class
import ZkFold.Symbolic.Data.Eq

newtype Structural a = Structural a
Expand Down
120 changes: 55 additions & 65 deletions src/ZkFold/Symbolic/Data/Ord.hs
Original file line number Diff line number Diff line change
Expand Up @@ -3,28 +3,29 @@
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeOperators #-}

module ZkFold.Symbolic.Data.Ord (Ord (..), Lexicographical (..), blueprintGE, circuitGE, circuitGT, getBitsBE) where

import Control.Monad (foldM)
import qualified Data.Bool as Haskell
import Data.Foldable (Foldable)
import Data.Function ((.))
import qualified Data.Zip as Z
import GHC.Generics (Par1 (..))
import Prelude (type (~), ($))
import qualified Prelude as Haskell
module ZkFold.Symbolic.Data.Ord (Ord (..), Lexicographical (..), blueprintGE, bitwiseGE, bitwiseGT, getBitsBE) where

import Control.Monad (foldM)
import qualified Data.Bool as Haskell
import Data.Foldable (Foldable, toList)
import Data.Function ((.))
import Data.Functor ((<$>))
import qualified Data.Zip as Z
import GHC.Generics (Par1 (..))
import Prelude (type (~), ($))
import qualified Prelude as Haskell

import ZkFold.Base.Algebra.Basic.Class
import ZkFold.Base.Data.HFunctor (hmap)
import qualified ZkFold.Base.Data.Vector as V
import ZkFold.Symbolic.Compiler
import ZkFold.Symbolic.Compiler.ArithmeticCircuit.MonadBlueprint (MonadBlueprint (..), circuit)
import ZkFold.Symbolic.Data.Bool (Bool (..), BoolType (..))
import ZkFold.Base.Data.HFunctor (hmap)
import qualified ZkFold.Base.Data.Vector as V
import ZkFold.Base.Data.Vector (unsafeToVector)
import ZkFold.Symbolic.Class (Symbolic (BaseField, symbolicF), symbolic2F)
import ZkFold.Symbolic.Compiler.ArithmeticCircuit.Combinators (expansion)
import ZkFold.Symbolic.Data.Bool (Bool (..))
import ZkFold.Symbolic.Data.Class
import ZkFold.Symbolic.Data.Conditional (Conditional (..))
import ZkFold.Symbolic.Data.FieldElement (FieldElement (..))
import ZkFold.Symbolic.Interpreter (Interpreter (..))
import ZkFold.Symbolic.MonadCircuit (Arithmetic, newAssigned)
import ZkFold.Symbolic.Data.Conditional (Conditional (..))
import ZkFold.Symbolic.Data.FieldElement (FieldElement (..))
import ZkFold.Symbolic.MonadCircuit (MonadCircuit, newAssigned)

-- TODO (Issue #23): add `compare`
class Ord b a where
Expand Down Expand Up @@ -55,79 +56,68 @@ instance Haskell.Ord a => Ord Haskell.Bool a where

min = Haskell.min

toValue :: Interpreter a Par1 -> a
toValue (Interpreter (Par1 v)) = v

fromValue :: a -> Interpreter a Par1
fromValue = Interpreter Haskell.. Par1

instance (Arithmetic a, Haskell.Ord a) => Ord (Bool (Interpreter a)) (Interpreter a Par1) where
(toValue -> x) <= (toValue -> y) = Haskell.bool false true (x Haskell.<= y)
(toValue -> x) < (toValue -> y) = Haskell.bool false true (x Haskell.< y)
(toValue -> x) >= (toValue -> y) = Haskell.bool false true (x Haskell.>= y)
(toValue -> x) > (toValue -> y) = Haskell.bool false true (x Haskell.> y)
(toValue -> x) `max` (toValue -> y) = fromValue $ Haskell.max x y
(toValue -> x) `min` (toValue -> y) = fromValue $ Haskell.min x y

newtype Lexicographical a = Lexicographical a
-- ^ A newtype wrapper for easy definition of Ord instances
-- (though not necessarily a most effective one)

deriving newtype instance SymbolicData c x => SymbolicData c (Lexicographical x)

deriving via (Lexicographical (ArithmeticCircuit a Par1))
instance Arithmetic a => Ord (Bool (ArithmeticCircuit a)) (ArithmeticCircuit a Par1)
deriving newtype instance SymbolicData c a => SymbolicData c (Lexicographical a)

deriving newtype instance (Arithmetic a, Haskell.Ord a) => Ord (Bool (Interpreter a)) (FieldElement (Interpreter a))
deriving newtype instance Arithmetic a => Ord (Bool (ArithmeticCircuit a)) (FieldElement (ArithmeticCircuit a))
deriving via (Lexicographical (FieldElement c))
instance Symbolic c => Ord (Bool c) (FieldElement c)

-- | Every @SymbolicData@ type can be compared lexicographically.
instance
( Arithmetic a
, SymbolicData (ArithmeticCircuit a) x
, Support (ArithmeticCircuit a) x ~ ()
, TypeSize (ArithmeticCircuit a) x ~ 1
) => Ord (Bool (ArithmeticCircuit a)) (Lexicographical x) where
( Symbolic c
, SymbolicData c x
, Support c x ~ ()
, TypeSize c x ~ 1
) => Ord (Bool c) (Lexicographical x) where

x <= y = y >= x

x < y = y > x

x >= y = circuitGE (getBitsBE x) (getBitsBE y)
x >= y = bitwiseGE (getBitsBE x) (getBitsBE y)

x > y = circuitGT (getBitsBE x) (getBitsBE y)
x > y = bitwiseGT (getBitsBE x) (getBitsBE y)

max x y = bool @(Bool (ArithmeticCircuit a)) x y $ x < y
max x y = bool @(Bool c) x y $ x < y

min x y = bool @(Bool (ArithmeticCircuit a)) x y $ x > y
min x y = bool @(Bool c) x y $ x > y

getBitsBE :: forall c a x . (Arithmetic a, c ~ ArithmeticCircuit a, SymbolicData c x, Support c x ~ (), TypeSize c x ~ 1) => x -> c (V.Vector (NumberOfBits a))
getBitsBE ::
forall c x .
(Symbolic c, SymbolicData c x, Support c x ~ (), TypeSize c x ~ 1) =>
x -> c (V.Vector (NumberOfBits (BaseField c)))
-- ^ @getBitsBE x@ returns a list of circuits computing bits of @x@, eldest to
-- youngest.
getBitsBE x = let expansion = binaryExpansion $ hmap (Par1 . V.item) (pieces @c @x x ())
in expansion { acOutput = V.reverse $ acOutput expansion }
getBitsBE x =
hmap unsafeToVector
$ symbolicF (pieces x ()) (binaryExpansion . V.item)
$ expansion (numberOfBits @(BaseField c)) . V.item

circuitGE :: forall a f . (Arithmetic a, Z.Zip f, Foldable f) => ArithmeticCircuit a f -> ArithmeticCircuit a f -> Bool (ArithmeticCircuit a)
bitwiseGE :: forall c f . (Symbolic c, Z.Zip f, Foldable f) => c f -> c f -> Bool c
-- ^ Given two lists of bits of equal length, compares them lexicographically.
circuitGE xs ys = Bool $ circuit $ do
is <- runCircuit xs
js <- runCircuit ys
blueprintGE is js
bitwiseGE xs ys = Bool $
symbolic2F xs ys
(\us vs -> Par1 $ Haskell.bool zero one (toList us Haskell.>= toList vs))
$ \is js -> Par1 <$> blueprintGE is js

blueprintGE :: (MonadBlueprint i a m, Z.Zip f, Foldable f) => f i -> f i -> m i
blueprintGE :: (MonadCircuit i a m, Z.Zip f, Foldable f) => f i -> f i -> m i
blueprintGE xs ys = do
(_, hasNegOne) <- circuitDelta xs ys
newAssigned $ \p -> one - p hasNegOne

circuitGT :: forall a f . (Arithmetic a, Z.Zip f, Foldable f) => ArithmeticCircuit a f -> ArithmeticCircuit a f -> Bool (ArithmeticCircuit a)
bitwiseGT :: forall c f . (Symbolic c, Z.Zip f, Foldable f) => c f -> c f -> Bool c
-- ^ Given two lists of bits of equal length, compares them lexicographically.
circuitGT xs ys = Bool $ circuit $ do
is <- runCircuit xs
js <- runCircuit ys
(hasOne, hasNegOne) <- circuitDelta is js
newAssigned $ \p -> p hasOne * (one - p hasNegOne)

circuitDelta :: forall i a m f . (MonadBlueprint i a m, Z.Zip f, Foldable f) => f i -> f i -> m (i, i)
bitwiseGT xs ys = Bool $
symbolic2F xs ys
(\us vs -> Par1 $ Haskell.bool zero one (toList us Haskell.> toList vs))
$ \is js -> do
(hasOne, hasNegOne) <- circuitDelta is js
Par1 <$> newAssigned (\p -> p hasOne * (one - p hasNegOne))

circuitDelta :: forall i a m f . (MonadCircuit i a m, Z.Zip f, Foldable f) => f i -> f i -> m (i, i)
circuitDelta l r = do
z1 <- newAssigned (Haskell.const zero)
z2 <- newAssigned (Haskell.const zero)
Expand Down
4 changes: 2 additions & 2 deletions src/ZkFold/Symbolic/Data/UInt.hs
Original file line number Diff line number Diff line change
Expand Up @@ -293,12 +293,12 @@ instance (Arithmetic a, KnownNat n, KnownRegisterSize r, KnownNat (NumberOfRegis
u1 >= u2 =
let ByteString rs1 = from u1 :: ByteString n (ArithmeticCircuit a)
ByteString rs2 = from u2 :: ByteString n (ArithmeticCircuit a)
in circuitGE rs1 rs2
in bitwiseGE rs1 rs2

u1 > u2 =
let ByteString rs1 = from u1 :: ByteString n (ArithmeticCircuit a)
ByteString rs2 = from u2 :: ByteString n (ArithmeticCircuit a)
in circuitGT rs1 rs2
in bitwiseGT rs1 rs2

max x y = bool @(Bool (ArithmeticCircuit a)) x y $ x < y

Expand Down
6 changes: 4 additions & 2 deletions src/ZkFold/Symbolic/MonadCircuit.hs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import Data.Eq (Eq)
import Data.Function (id)
import Data.Functor (Functor)
import Data.Functor.Identity (Identity (..))
import Data.Ord (Ord)
import Data.Type.Equality (type (~))

import ZkFold.Base.Algebra.Basic.Class
Expand Down Expand Up @@ -105,8 +106,9 @@ class Monad m => MonadCircuit i a m | m -> i, m -> a where
newAssigned :: ClosedPoly i a -> m i
newAssigned p = newConstrained (\x i -> p x - x i) p

-- | Field of witnesses with decidable equality is called an ``arithmetic'' field.
type Arithmetic a = (WitnessField a, Eq a)
-- | Field of witnesses with decidable equality and ordering
-- is called an ``arithmetic'' field.
type Arithmetic a = (WitnessField a, Eq a, Ord a)

-- | An example implementation of a @'MonadCircuit'@ which computes witnesses
-- immediately and drops the constraints.
Expand Down
5 changes: 3 additions & 2 deletions tests/Tests/Arithmetization/Test3.hs
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,17 @@ import Test.Hspec

import ZkFold.Base.Algebra.Basic.Class (fromConstant)
import ZkFold.Base.Algebra.Basic.Field (Zp)
import ZkFold.Symbolic.Class (Symbolic)
import ZkFold.Symbolic.Compiler
import ZkFold.Symbolic.Data.Bool (Bool (..))
import ZkFold.Symbolic.Data.FieldElement (FieldElement)
import ZkFold.Symbolic.Data.Ord (Ord (..))
import ZkFold.Symbolic.Data.Ord ((<=))
import ZkFold.Symbolic.Interpreter (Interpreter (Interpreter))

type R = ArithmeticCircuit (Zp 97)

-- A comparison test
testFunc :: Ord (Bool c) (FieldElement c) => FieldElement c -> FieldElement c -> Bool c
testFunc :: Symbolic c => FieldElement c -> FieldElement c -> Bool c
testFunc x y = x <= y

specArithmetization3 :: Spec
Expand Down
2 changes: 1 addition & 1 deletion tests/Tests/Arithmetization/Test4.hs
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,11 @@ import ZkFold.Base.Protocol.ARK.Plonk (Plonk (..), PlonkP
plonkVerifierInput)
import ZkFold.Base.Protocol.ARK.Plonk.Internal (getParams)
import ZkFold.Base.Protocol.NonInteractiveProof (NonInteractiveProof (..))
import ZkFold.Symbolic.Class
import ZkFold.Symbolic.Compiler (ArithmeticCircuit (..), acValue, applyArgs, compile)
import ZkFold.Symbolic.Data.Bool (Bool (..))
import ZkFold.Symbolic.Data.Eq (Eq (..))
import ZkFold.Symbolic.Data.FieldElement (FieldElement)
import ZkFold.Symbolic.Class

type N = 1

Expand Down

0 comments on commit 353ab13

Please sign in to comment.