Skip to content

Commit

Permalink
Merge pull request #384 from zkFold/eitan-generic-eq
Browse files Browse the repository at this point in the history
Generic Symbolic Eq
  • Loading branch information
TurtlePU authored Dec 1, 2024
2 parents 889c50a + e7224ec commit 39ed260
Show file tree
Hide file tree
Showing 8 changed files with 104 additions and 105 deletions.
67 changes: 33 additions & 34 deletions symbolic-base/src/ZkFold/Symbolic/Data/ByteString.hs
Original file line number Diff line number Diff line change
Expand Up @@ -23,44 +23,43 @@ module ZkFold.Symbolic.Data.ByteString
, toBsBits
) where

import Control.DeepSeq (NFData)
import Control.Monad (replicateM)
import Data.Aeson (FromJSON (..), ToJSON (..))
import qualified Data.Bits as B
import qualified Data.ByteString as Bytes
import Data.Foldable (foldlM)
import Data.Kind (Type)
import Data.List (reverse, unfoldr)
import Data.Maybe (Maybe (..))
import Data.String (IsString (..))
import Data.Traversable (for, mapM)
import GHC.Generics (Generic, Par1 (..))
import GHC.Natural (naturalFromInteger)
import Numeric (readHex, showHex)
import Prelude (Integer, const, drop, fmap, otherwise, pure, return, take,
type (~), ($), (.), (<$>), (<), (<>), (==), (>=))
import qualified Prelude as Haskell
import Test.QuickCheck (Arbitrary (..), chooseInteger)
import Control.DeepSeq (NFData)
import Control.Monad (replicateM)
import Data.Aeson (FromJSON (..), ToJSON (..))
import qualified Data.Bits as B
import qualified Data.ByteString as Bytes
import Data.Foldable (foldlM)
import Data.Kind (Type)
import Data.List (reverse, unfoldr)
import Data.Maybe (Maybe (..))
import Data.String (IsString (..))
import Data.Traversable (for, mapM)
import GHC.Generics (Generic, Par1 (..))
import GHC.Natural (naturalFromInteger)
import Numeric (readHex, showHex)
import Prelude (Integer, const, drop, fmap, otherwise, pure, return, take, type (~),
($), (.), (<$>), (<), (<>), (==), (>=))
import qualified Prelude as Haskell
import Test.QuickCheck (Arbitrary (..), chooseInteger)

import ZkFold.Base.Algebra.Basic.Class
import ZkFold.Base.Algebra.Basic.Field (Zp)
import ZkFold.Base.Algebra.Basic.Field (Zp)
import ZkFold.Base.Algebra.Basic.Number
import ZkFold.Base.Data.HFunctor (HFunctor (..))
import ZkFold.Base.Data.Package (packWith, unpackWith)
import ZkFold.Base.Data.Utils (zipWithM)
import qualified ZkFold.Base.Data.Vector as V
import ZkFold.Base.Data.Vector (Vector (..))
import ZkFold.Prelude (replicateA, (!!))
import ZkFold.Base.Data.HFunctor (HFunctor (..))
import ZkFold.Base.Data.Package (packWith, unpackWith)
import ZkFold.Base.Data.Utils (zipWithM)
import qualified ZkFold.Base.Data.Vector as V
import ZkFold.Base.Data.Vector (Vector (..))
import ZkFold.Prelude (replicateA, (!!))
import ZkFold.Symbolic.Class
import ZkFold.Symbolic.Data.Bool (Bool (..), BoolType (..))
import ZkFold.Symbolic.Data.Class (SymbolicData)
import ZkFold.Symbolic.Data.Bool (Bool (..), BoolType (..))
import ZkFold.Symbolic.Data.Class (SymbolicData)
import ZkFold.Symbolic.Data.Combinators
import ZkFold.Symbolic.Data.Eq (Eq)
import ZkFold.Symbolic.Data.Eq.Structural
import ZkFold.Symbolic.Data.FieldElement (FieldElement)
import ZkFold.Symbolic.Data.Input (SymbolicInput, isValid)
import ZkFold.Symbolic.Interpreter (Interpreter (..))
import ZkFold.Symbolic.MonadCircuit (ClosedPoly, newAssigned)
import ZkFold.Symbolic.Data.Eq (Eq)
import ZkFold.Symbolic.Data.FieldElement (FieldElement)
import ZkFold.Symbolic.Data.Input (SymbolicInput, isValid)
import ZkFold.Symbolic.Interpreter (Interpreter (..))
import ZkFold.Symbolic.MonadCircuit (ClosedPoly, newAssigned)

-- | A ByteString which stores @n@ bits and uses elements of @a@ as registers, one element per register.
-- Bit layout is Big-endian.
Expand All @@ -74,7 +73,7 @@ deriving anyclass instance NFData (c (Vector n)) => NFData (ByteString n c)
deriving newtype instance SymbolicData (ByteString n c)


deriving via (Structural (ByteString n c))
deriving newtype
instance (Symbolic c, KnownNat n) => Eq (Bool c) (ByteString n c)

instance
Expand Down
37 changes: 36 additions & 1 deletion symbolic-base/src/ZkFold/Symbolic/Data/Eq.hs
Original file line number Diff line number Diff line change
@@ -1,19 +1,25 @@
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}

module ZkFold.Symbolic.Data.Eq (
Eq(..),
elem
elem,
GEq (..)
) where

import Data.Bool (bool)
import Data.Foldable (Foldable)
import Data.Functor.Rep (Representable, mzipRep, mzipWithRep)
import Data.Traversable (Traversable, for)
import qualified Data.Vector as V
import qualified GHC.Generics as G
import Prelude (return, ($))
import qualified Prelude as Haskell

import ZkFold.Base.Algebra.Basic.Class
import ZkFold.Base.Data.Package
import ZkFold.Base.Data.Vector
import ZkFold.Symbolic.Class
import ZkFold.Symbolic.Data.Bool (Bool (Bool), BoolType (..), all, any)
import ZkFold.Symbolic.Data.Combinators (runInvert)
Expand All @@ -22,9 +28,13 @@ import ZkFold.Symbolic.MonadCircuit
class Eq b a where
infix 4 ==
(==) :: a -> a -> b
default (==) :: (G.Generic a, GEq b (G.Rep a)) => a -> a -> b
x == y = geq (G.from x) (G.from y)

infix 4 /=
(/=) :: a -> a -> b
default (/=) :: (G.Generic a, GEq b (G.Rep a)) => a -> a -> b
x /= y = gneq (G.from x) (G.from y)

elem :: (BoolType b, Eq b a, Foldable t) => a -> t a -> b
elem x = any (== x)
Expand Down Expand Up @@ -62,3 +72,28 @@ instance (Symbolic c, Haskell.Eq (BaseField c), Representable f, Traversable f)
in
any Bool (unpacked result)

instance (BoolType b, Eq b x) => Eq b (Vector n x) where
u == v = V.foldl (&&) true (V.zipWith (==) (toV u) (toV v))
u /= v = V.foldl (||) false (V.zipWith (/=) (toV u) (toV v))

deriving newtype instance Symbolic c => Eq (Bool c) (Bool c)

instance (BoolType b, Eq b x0, Eq b x1) => Eq b (x0,x1)
instance (BoolType b, Eq b x0, Eq b x1, Eq b x2) => Eq b (x0,x1,x2)
instance (BoolType b, Eq b x0, Eq b x1, Eq b x2, Eq b x3) => Eq b (x0,x1,x2,x3)

class GEq b u where
geq :: u x -> u x -> b
gneq :: u x -> u x -> b

instance (BoolType b, GEq b u, GEq b v) => GEq b (u G.:*: v) where
geq (x0 G.:*: x1) (y0 G.:*: y1) = geq x0 y0 && geq x1 y1
gneq (x0 G.:*: x1) (y0 G.:*: y1) = gneq x0 y0 || gneq x1 y1

instance GEq b v => GEq b (G.M1 i c v) where
geq (G.M1 x) (G.M1 y) = geq x y
gneq (G.M1 x) (G.M1 y) = gneq x y

instance Eq b x => GEq b (G.Rec0 x) where
geq (G.K1 x) (G.K1 y) = x == y
gneq (G.K1 x) (G.K1 y) = x /= y
36 changes: 0 additions & 36 deletions symbolic-base/src/ZkFold/Symbolic/Data/Eq/Structural.hs

This file was deleted.

53 changes: 26 additions & 27 deletions symbolic-base/src/ZkFold/Symbolic/Data/UInt.hs
Original file line number Diff line number Diff line change
Expand Up @@ -19,42 +19,41 @@ module ZkFold.Symbolic.Data.UInt (
) where

import Control.DeepSeq
import Control.Monad.State (StateT (..))
import Data.Aeson hiding (Bool)
import Data.Foldable (foldlM, foldr, foldrM, for_)
import Data.Functor ((<$>))
import Data.Kind (Type)
import Data.List (unfoldr, zip)
import Data.Map (fromList, (!))
import Data.Traversable (for, traverse)
import Data.Tuple (swap)
import qualified Data.Zip as Z
import GHC.Generics (Generic, Par1 (..))
import GHC.Natural (naturalFromInteger)
import Prelude (Integer, const, error, flip, otherwise, return, type (~), ($),
(++), (.), (<>), (>>=))
import qualified Prelude as Haskell
import Test.QuickCheck (Arbitrary (..), chooseInteger)
import Control.Monad.State (StateT (..))
import Data.Aeson hiding (Bool)
import Data.Foldable (foldlM, foldr, foldrM, for_)
import Data.Functor ((<$>))
import Data.Kind (Type)
import Data.List (unfoldr, zip)
import Data.Map (fromList, (!))
import Data.Traversable (for, traverse)
import Data.Tuple (swap)
import qualified Data.Zip as Z
import GHC.Generics (Generic, Par1 (..))
import GHC.Natural (naturalFromInteger)
import Prelude (Integer, const, error, flip, otherwise, return, type (~), ($), (++),
(.), (<>), (>>=))
import qualified Prelude as Haskell
import Test.QuickCheck (Arbitrary (..), chooseInteger)

import ZkFold.Base.Algebra.Basic.Class
import ZkFold.Base.Algebra.Basic.Field (Zp)
import ZkFold.Base.Algebra.Basic.Field (Zp)
import ZkFold.Base.Algebra.Basic.Number
import qualified ZkFold.Base.Data.Vector as V
import ZkFold.Base.Data.Vector (Vector (..))
import ZkFold.Prelude (length, replicate, replicateA)
import qualified ZkFold.Base.Data.Vector as V
import ZkFold.Base.Data.Vector (Vector (..))
import ZkFold.Prelude (length, replicate, replicateA)
import ZkFold.Symbolic.Class
import ZkFold.Symbolic.Data.Bool
import ZkFold.Symbolic.Data.ByteString
import ZkFold.Symbolic.Data.Class (SymbolicData)
import ZkFold.Symbolic.Data.Class (SymbolicData)
import ZkFold.Symbolic.Data.Combinators
import ZkFold.Symbolic.Data.Conditional
import ZkFold.Symbolic.Data.Eq
import ZkFold.Symbolic.Data.Eq.Structural
import ZkFold.Symbolic.Data.FieldElement (FieldElement)
import ZkFold.Symbolic.Data.Input (SymbolicInput, isValid)
import ZkFold.Symbolic.Data.FieldElement (FieldElement)
import ZkFold.Symbolic.Data.Input (SymbolicInput, isValid)
import ZkFold.Symbolic.Data.Ord
import ZkFold.Symbolic.Interpreter (Interpreter (..))
import ZkFold.Symbolic.MonadCircuit (MonadCircuit, constraint, newAssigned)
import ZkFold.Symbolic.Interpreter (Interpreter (..))
import ZkFold.Symbolic.MonadCircuit (MonadCircuit, constraint, newAssigned)


-- TODO (Issue #18): hide this constructor
Expand Down Expand Up @@ -461,7 +460,7 @@ instance
, KnownRegisterSize r
) => Ring (UInt n r c)

deriving via (Structural (UInt n rs c))
deriving newtype
instance (Symbolic c, KnownNat (NumberOfRegisters (BaseField c) n rs)) =>
Eq (Bool c) (UInt n rs c)

Expand Down
1 change: 0 additions & 1 deletion symbolic-base/symbolic-base.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,6 @@ library
ZkFold.Symbolic.Data.DiscreteField
ZkFold.Symbolic.Data.Ed25519
ZkFold.Symbolic.Data.Eq
ZkFold.Symbolic.Data.Eq.Structural
ZkFold.Symbolic.Data.FFA
ZkFold.Symbolic.Data.Input
ZkFold.Symbolic.Data.List
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ import ZkFold.Symbolic.Cardano.Types.Basic
import ZkFold.Symbolic.Class (Symbolic)
import ZkFold.Symbolic.Data.Class
import ZkFold.Symbolic.Data.Eq (Eq)
import ZkFold.Symbolic.Data.Eq.Structural
import ZkFold.Symbolic.Data.Input

type AddressType context = ByteString 4 context
Expand All @@ -29,8 +28,7 @@ data Address context = Address {
deriving instance (Haskell.Eq (ByteString 4 context), Haskell.Eq (ByteString 224 context))
=> Haskell.Eq (Address context)

deriving via (Structural (Address context))
instance (Symbolic context) => Eq (Bool context) (Address context)
instance Symbolic context => Eq (Bool context) (Address context)

instance HApplicative context => SymbolicData (Address context)
instance Symbolic context => SymbolicInput (Address context) where
Expand Down
4 changes: 1 addition & 3 deletions symbolic-cardano/src/ZkFold/Symbolic/Cardano/Types/Output.hs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ import ZkFold.Symbolic.Class
import ZkFold.Symbolic.Data.Class
import ZkFold.Symbolic.Data.Combinators (NumberOfRegisters, RegisterSize (..))
import ZkFold.Symbolic.Data.Eq (Eq)
import ZkFold.Symbolic.Data.Eq.Structural
import ZkFold.Symbolic.Data.Input (SymbolicInput (..))

data Liability context
Expand Down Expand Up @@ -64,8 +63,7 @@ instance
) => SymbolicInput (Output tokens datum context) where
isValid (Output a t d) = isValid (a, t, d)

deriving via (Structural (Output tokens datum context))
instance
instance
( Symbolic context
, KnownNat tokens
, KnownNat (NumberOfRegisters (BaseField context) 64 'Auto)
Expand Down
7 changes: 7 additions & 0 deletions symbolic-cardano/src/ZkFold/Symbolic/Cardano/Types/Value.hs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE UndecidableInstances #-}
{-# OPTIONS_GHC -freduction-depth=0 #-} -- Avoid reduction overflow error caused by NumberOfRegisters
{-# OPTIONS_GHC -Wno-orphans #-}
Expand All @@ -16,6 +17,7 @@ import ZkFold.Symbolic.Cardano.Types.Basic
import ZkFold.Symbolic.Class (Symbolic (..))
import ZkFold.Symbolic.Data.Class
import ZkFold.Symbolic.Data.Combinators (NumberOfRegisters, RegisterSize (..))
import ZkFold.Symbolic.Data.Eq
import ZkFold.Symbolic.Data.Input

type PolicyId context = ByteString 224 context
Expand All @@ -32,6 +34,11 @@ deriving instance (Haskell.Ord (ByteString 224 context), Haskell.Ord (ByteString

deriving instance (Symbolic context, KnownNat n) => SymbolicData (Value n context)

deriving newtype instance
( Symbolic context
, KnownNat (NumberOfRegisters (BaseField context) 64 Auto)
) => Eq (Bool context) (Value n context)

instance
( Symbolic context
, KnownNat n
Expand Down

0 comments on commit 39ed260

Please sign in to comment.