Skip to content

Commit

Permalink
Merge pull request #376 from zkFold/TurtlePU/Payload
Browse files Browse the repository at this point in the history
Added `SymbolicData.Payload` associated type
  • Loading branch information
TurtlePU authored Nov 26, 2024
2 parents 5139374 + b871c7e commit 04c6fbf
Show file tree
Hide file tree
Showing 31 changed files with 319 additions and 278 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ protostar :: forall a n k p i o ctx f pi m c .
, Support [f] ~ Proxy ctx
, Ring (PolyVec f 3)
, HomomorphicCommit m c
, p ~ Payload (IVCInstanceProof pi f c m) :*: U1
, i ~ Layout (IVCInstanceProof pi f c m) :*: U1
, o ~ (((Par1 :*: Layout [FieldElement ctx]) :*: (Par1 :*: (Par1 :*: Par1))) :*: (Par1 :*: Par1))
) => (forall ctx' . Symbolic ctx' => Vector n (FieldElement ctx') -> Vector k (FieldElement ctx') -> Vector n (FieldElement ctx'))
Expand All @@ -62,6 +63,7 @@ protostar func =
stepCircuit' =
hlmap (\(x :*: u :*: y) -> Comp1 (Par1 <$> x) :*: Comp1 (Par1 <$> u) :*: Comp1 (Par1 <$> y) :*: U1)
$ hmap (\(Comp1 x') -> unPar1 <$> x')
$ hpmap (\_ -> Comp1 (tabulate $ const U1) :*: Comp1 (tabulate $ const U1) :*: Comp1 (tabulate $ const U1) :*: U1)
$ compile @a stepFunction

-- The circuit for one step of the recursion with extra witness
Expand Down
2 changes: 1 addition & 1 deletion symbolic-base/src/ZkFold/Symbolic/Algorithms/Hash/MiMC.hs
Original file line number Diff line number Diff line change
Expand Up @@ -48,4 +48,4 @@ hash :: forall context x a .
, Support x ~ Proxy context
, Foldable (Layout x)
) => x -> FieldElement context
hash = mimcHashN mimcConstants (zero :: a) . fmap FieldElement . unpacked . hmap toList . flip pieces Proxy
hash = mimcHashN mimcConstants (zero :: a) . fmap FieldElement . unpacked . hmap toList . flip arithmetize Proxy
4 changes: 3 additions & 1 deletion symbolic-base/src/ZkFold/Symbolic/Class.hs
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,9 @@ type family FunBody (fs :: [Type -> Type]) (g :: Type -> Type) (i :: Type) (m ::

-- | A Symbolic DSL for performant pure computations with arithmetic circuits.
-- @c@ is a generic context in which computations are performed.
class (HApplicative c, Package c, Arithmetic (BaseField c)) => Symbolic c where
class ( HApplicative c, Package c, Arithmetic (BaseField c)
, ResidueField (Const (WitnessField c)) (WitnessField c)
) => Symbolic c where
-- | Base algebraic field over which computations are performed.
type BaseField c :: Type
-- | Type of witnesses usable inside circuit construction
Expand Down
59 changes: 38 additions & 21 deletions symbolic-base/src/ZkFold/Symbolic/Compiler.hs
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,13 @@ module ZkFold.Symbolic.Compiler (
import Data.Aeson (FromJSON, ToJSON, ToJSONKey)
import Data.Binary (Binary)
import Data.Function (const, id, (.))
import Data.Functor.Rep (Rep)
import Data.Functor.Rep (Rep, Representable)
import Data.Ord (Ord)
import Data.Proxy (Proxy (..))
import GHC.Generics (Par1 (Par1))
import Prelude (FilePath, IO, Show (..), putStrLn, return, type (~), ($),
(++))
import Data.Tuple (fst, snd)
import GHC.Generics (Par1 (Par1), U1 (..))
import Prelude (FilePath, IO, Show (..), Traversable, putStrLn, return,
type (~), ($), (++))

import ZkFold.Base.Algebra.Basic.Class
import ZkFold.Prelude (writeFileJSON)
Expand Down Expand Up @@ -44,35 +46,49 @@ type CompilesWith c s f =

-- | A constraint defining what it means
-- for data of type @y@ to be properly restorable.
type RestoresFrom c y = (SymbolicData y, Context y ~ c, Support y ~ Proxy c)
type RestoresFrom c y =
(SymbolicData y, Context y ~ c, Support y ~ Proxy c, Payload y ~ U1)

-- | @compileWith opts sLayout f@ compiles a function @f@ into an optimized
compileInternal ::
(CompilesWith c0 s f, RestoresFrom c1 y, c1 ~ ArithmeticCircuit a p i) =>
(c0 (Layout f) -> c1 (Layout y)) ->
c0 (Layout s) -> Payload s (WitnessField c0) -> f -> y
compileInternal opts sLayout sPayload f =
restore . const . (,U1) . optimize . opts $
fromCircuit2F (arithmetize f input) b $
\r (Par1 i) -> do
constraint (\x -> one - x i)
return r
where
Bool b = isValid input
input = restore $ const (sLayout, sPayload)

-- | @compileWith opts inputT@ compiles a function @f@ into an optimized
-- arithmetic circuit packed inside a suitable 'SymbolicData'.
compileWith ::
forall a y p i s f c0 c1.
(CompilesWith c0 s f, RestoresFrom c1 y, c1 ~ ArithmeticCircuit a p i) =>
forall a y p i q j s f c0 c1.
( CompilesWith c0 s f, c0 ~ ArithmeticCircuit a p i
, Representable p, Representable i, Traversable (Layout s)
, RestoresFrom c1 y, c1 ~ ArithmeticCircuit a q j
, Binary a, Binary (Rep p), Binary (Rep i), Ord (Rep i)) =>
-- | Circuit transformation to apply before optimization.
(c0 (Layout f) -> c1 (Layout y)) ->
-- | Basic "input" circuit used to solder @f@.
c0 (Layout s) ->
-- | An algorithm to prepare support argument from the circuit input.
(forall x. p x -> i x -> (Payload s x, Layout s x)) ->
-- | Function to compile.
f -> y
compileWith opts sLayout f =
restore . const . optimize . opts $ fromCircuit2F (pieces f input) b $
\r (Par1 i) -> do
constraint (\x -> one - x i)
return r
where
Bool b = isValid input
input = restore (const sLayout)
compileWith outputTransform inputTransform =
compileInternal outputTransform
(naturalCircuit $ \p i -> snd (inputTransform p i))
(inputPayload $ \p i -> fst (inputTransform p i))

-- | @compile f@ compiles a function @f@ into an optimized arithmetic circuit
-- packed inside a suitable 'SymbolicData'.
compile :: forall a y f c s p.
compile :: forall a y f c s.
( CompilesWith c s f, RestoresFrom c y, Layout y ~ Layout f
, c ~ ArithmeticCircuit a p (Layout s))
, c ~ ArithmeticCircuit a (Payload s) (Layout s))
=> f -> y
compile = compileWith id idCircuit
compile = compileInternal id idCircuit (inputPayload const)

-- | Compiles a function `f` into an arithmetic circuit. Writes the result to a file.
compileIO ::
Expand All @@ -88,6 +104,7 @@ compileIO ::
, SymbolicInput s
, Context s ~ c
, Layout s ~ l
, Payload s ~ p
, FromJSON (Rep l)
, ToJSON (Rep l)
, Arithmetic a, Binary a, Binary (Rep p)
Expand Down
35 changes: 26 additions & 9 deletions symbolic-base/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit.hs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@ module ZkFold.Symbolic.Compiler.ArithmeticCircuit (
desugarRanges,
emptyCircuit,
idCircuit,
payloadCircuit,
naturalCircuit,
inputPayload,
guessOutput,
-- low-level functions
eval,
Expand All @@ -26,6 +27,7 @@ module ZkFold.Symbolic.Compiler.ArithmeticCircuit (
acPrint,
-- Variable mapping functions
hlmap,
hpmap,
mapVarArithmeticCircuit,
-- Arithmetization type fields
acWitness,
Expand Down Expand Up @@ -66,10 +68,11 @@ import ZkFold.Prelude (length)
import ZkFold.Symbolic.Class (fromCircuit2F)
import ZkFold.Symbolic.Compiler.ArithmeticCircuit.Instance ()
import ZkFold.Symbolic.Compiler.ArithmeticCircuit.Internal (Arithmetic, ArithmeticCircuit (..), Constraint,
SysVar (..), Var (..), WitVar (WExVar), acInput,
crown, eval, eval1, exec, exec1, hlmap,
SysVar (..), Var (..), WitVar (..), acInput,
crown, eval, eval1, exec, exec1, hlmap, hpmap,
witnessGenerator)
import ZkFold.Symbolic.Compiler.ArithmeticCircuit.Map
import ZkFold.Symbolic.Compiler.ArithmeticCircuit.Witness (WitnessF)
import ZkFold.Symbolic.Data.Combinators (expansion)
import ZkFold.Symbolic.MonadCircuit (MonadCircuit (..))

Expand Down Expand Up @@ -108,15 +111,29 @@ desugarRanges c =
emptyCircuit :: ArithmeticCircuit a p i U1
emptyCircuit = ArithmeticCircuit empty M.empty empty U1

-- | Given a natural transformation
-- from payload @p@ and input @i@ to output @o@,
-- returns a corresponding arithmetic circuit
-- where outputs computing the payload are unconstrained.
naturalCircuit ::
( Arithmetic a, Representable p, Representable i, Traversable o
, Binary a, Binary (Rep p), Binary (Rep i), Ord (Rep i)) =>
(forall x. p x -> i x -> o x) -> ArithmeticCircuit a p i o
naturalCircuit f = uncurry crown $ swap $ flip runState emptyCircuit $
for (f (tabulate Left) (tabulate Right)) $
either (unconstrained . pure . WExVar) (return . SysVar . InVar)

-- | Identity circuit which returns its input @i@ and doesn't use the payload.
idCircuit :: Representable i => ArithmeticCircuit a p i i
idCircuit = emptyCircuit { acOutput = acInput }

payloadCircuit ::
( Representable p, Traversable p, Arithmetic a, Binary a
, Binary (Rep p), Binary (Rep l), Ord (Rep l)) => ArithmeticCircuit a p l p
payloadCircuit =
uncurry crown $ swap $ flip runState emptyCircuit $
for (tabulate id) $ unconstrained . pure . WExVar
-- | Payload of an input to arithmetic circuit.
-- To be used as an argument to 'compileWith'.
inputPayload ::
(Representable p, Representable i) =>
(forall x. p x -> i x -> o x) -> o (WitnessF a (WitVar p i))
inputPayload f =
f (tabulate $ pure . WExVar) (tabulate $ pure . WSysVar . InVar)

guessOutput ::
(Arithmetic a, Binary a, Binary (Rep p), Binary (Rep i), Binary (Rep o)) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ module ZkFold.Symbolic.Compiler.ArithmeticCircuit.Internal (
crown,
-- input mapping
hlmap,
hpmap,
-- evaluation functions
witnessGenerator,
eval,
Expand Down Expand Up @@ -119,6 +120,12 @@ imapWitVar ::
imapWitVar _ (WExVar r) = WExVar r
imapWitVar f (WSysVar v) = WSysVar (imapSysVar f v)

pmapWitVar ::
(Representable p, Representable q) =>
(forall x. q x -> p x) -> WitVar p i -> WitVar q i
pmapWitVar f (WExVar r) = index (f (tabulate WExVar)) r
pmapWitVar _ (WSysVar v) = WSysVar v

data Var a i
= SysVar (SysVar i)
| ConstVar a
Expand Down Expand Up @@ -172,6 +179,11 @@ hlmap f (ArithmeticCircuit s r w o) = ArithmeticCircuit
, acOutput = imapVar f <$> o
}

hpmap ::
(Representable p, Representable q) => (forall x. q x -> p x) ->
ArithmeticCircuit a p i o -> ArithmeticCircuit a q i o
hpmap f ac = ac { acWitness = fmap (pmapWitVar f) <$> acWitness ac }

--------------------------- Symbolic compiler context --------------------------

crown :: ArithmeticCircuit a p i g -> f (Var a i) -> ArithmeticCircuit a p i f
Expand Down
79 changes: 59 additions & 20 deletions symbolic-base/src/ZkFold/Symbolic/Data/Class.hs
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,12 @@ module ZkFold.Symbolic.Data.Class (
) where

import Control.Applicative ((<*>))
import Data.Bifunctor (bimap)
import Data.Function (flip, (.))
import Data.Functor ((<$>))
import Data.Functor.Rep (Representable (..))
import Data.Kind (Type)
import Data.Tuple (fst)
import Data.Type.Equality (type (~))
import Data.Typeable (Proxy (..))
import GHC.Generics (U1 (..), (:*:) (..), (:.:) (..))
Expand All @@ -22,6 +24,7 @@ import ZkFold.Base.Data.HFunctor (hmap)
import ZkFold.Base.Data.Package (Package, pack)
import ZkFold.Base.Data.Product (fstP, sndP)
import ZkFold.Base.Data.Vector (Vector)
import ZkFold.Symbolic.Class (Symbolic (WitnessField))

-- | A class for Symbolic data types.
class SymbolicData x where
Expand All @@ -35,44 +38,63 @@ class SymbolicData x where
type Layout x :: Type -> Type
type Layout x = GLayout (G.Rep x)

type Payload x :: Type -> Type
type Payload x = GPayload (G.Rep x)

-- | Returns the circuit that makes up `x`.
pieces :: x -> Support x -> Context x (Layout x)
default pieces
arithmetize :: x -> Support x -> Context x (Layout x)
default arithmetize
:: ( G.Generic x
, GSymbolicData (G.Rep x)
, Context x ~ GContext (G.Rep x)
, Support x ~ GSupport (G.Rep x)
, Layout x ~ GLayout (G.Rep x)
)
=> x -> Support x -> Context x (Layout x)
pieces x supp = gpieces (G.from x) supp
arithmetize x = garithmetize (G.from x)

-- | Restores `x` from the circuit's outputs.
restore :: (Support x -> Context x (Layout x)) -> x
default restore
payload :: x -> Support x -> Payload x (WitnessField (Context x))
default payload
:: ( G.Generic x
, GSymbolicData (G.Rep x)
, Context x ~ GContext (G.Rep x)
, Support x ~ GSupport (G.Rep x)
, Layout x ~ GLayout (G.Rep x)
, Payload x ~ GPayload (G.Rep x)
)
=> (Support x -> Context x (Layout x)) -> x
=> x -> Support x -> Payload x (WitnessField (Context x))
payload x = gpayload (G.from x)

-- | Restores `x` from the circuit's outputs.
restore ::
Context x ~ c =>
(Support x -> (c (Layout x), Payload x (WitnessField c))) -> x
default restore ::
( Context x ~ c, G.Generic x, GSymbolicData (G.Rep x)
, Context x ~ GContext (G.Rep x)
, Support x ~ GSupport (G.Rep x)
, Layout x ~ GLayout (G.Rep x)
, Payload x ~ GPayload (G.Rep x)) =>
(Support x -> (c (Layout x), Payload x (WitnessField c))) -> x
restore f = G.to (grestore f)

instance SymbolicData (c (f :: Type -> Type)) where
type Context (c f) = c
type Support (c f) = Proxy c
type Layout (c f) = f
type Payload (c f) = U1

pieces x _ = x
restore f = f Proxy
arithmetize x _ = x
payload _ _ = U1
restore f = fst (f Proxy)

instance HApplicative c => SymbolicData (Proxy (c :: (Type -> Type) -> Type)) where
type Context (Proxy c) = c
type Support (Proxy c) = Proxy c
type Layout (Proxy c) = U1
type Payload (Proxy c) = U1

pieces _ _ = hpure U1
arithmetize _ _ = hpure U1
payload _ _ = U1
restore _ = Proxy

instance
Expand Down Expand Up @@ -134,25 +156,34 @@ instance
type Context (Vector n x) = Context x
type Support (Vector n x) = Support x
type Layout (Vector n x) = Vector n :.: Layout x
type Payload (Vector n x) = Vector n :.: Payload x

pieces xs i = pack (flip pieces i <$> xs)
restore f = tabulate (\i -> restore (hmap (flip index i . unComp1) . f))
arithmetize xs s = pack (flip arithmetize s <$> xs)
payload xs s = Comp1 (flip payload s <$> xs)
restore f = tabulate (\i -> restore (bimap (hmap (ix i)) (ix i) . f))
where ix i = flip index i . unComp1

instance SymbolicData f => SymbolicData (x -> f) where
type Context (x -> f) = Context f
type Support (x -> f) = (x, Support f)
type Layout (x -> f) = Layout f
type Payload (x -> f) = Payload f

pieces f (x, i) = pieces (f x) i
arithmetize f (x, i) = arithmetize (f x) i
payload f (x, i) = payload (f x) i
restore f x = restore (f . (x,))

class GSymbolicData u where
type GContext u :: (Type -> Type) -> Type
type GSupport u :: Type
type GLayout u :: Type -> Type
type GPayload u :: Type -> Type

gpieces :: u x -> GSupport u -> GContext u (GLayout u)
grestore :: (GSupport u -> GContext u (GLayout u)) -> u x
garithmetize :: u x -> GSupport u -> GContext u (GLayout u)
gpayload :: u x -> GSupport u -> GPayload u (WitnessField (GContext u))
grestore ::
GContext u ~ c =>
(GSupport u -> (c (GLayout u), GPayload u (WitnessField c))) -> u x

instance
( GSymbolicData u
Expand All @@ -165,20 +196,28 @@ instance
type GContext (u :*: v) = GContext u
type GSupport (u :*: v) = GSupport u
type GLayout (u :*: v) = GLayout u :*: GLayout v
type GPayload (u :*: v) = GPayload u :*: GPayload v

gpieces (a :*: b) = hliftA2 (:*:) <$> gpieces a <*> gpieces b
grestore f = grestore (hmap fstP . f) :*: grestore (hmap sndP . f)
garithmetize (a :*: b) = hliftA2 (:*:) <$> garithmetize a <*> garithmetize b
gpayload (a :*: b) = (:*:) <$> gpayload a <*> gpayload b
grestore f =
grestore (bimap (hmap fstP) fstP . f)
:*: grestore (bimap (hmap sndP) sndP . f)

instance GSymbolicData f => GSymbolicData (G.M1 i c f) where
type GContext (G.M1 i c f) = GContext f
type GSupport (G.M1 i c f) = GSupport f
type GLayout (G.M1 i c f) = GLayout f
gpieces (G.M1 a) = gpieces a
type GPayload (G.M1 i c f) = GPayload f
garithmetize (G.M1 a) = garithmetize a
gpayload (G.M1 a) = gpayload a
grestore f = G.M1 (grestore f)

instance SymbolicData x => GSymbolicData (G.Rec0 x) where
type GContext (G.Rec0 x) = Context x
type GSupport (G.Rec0 x) = Support x
type GLayout (G.Rec0 x) = Layout x
gpieces (G.K1 x) = pieces x
type GPayload (G.Rec0 x) = Payload x
garithmetize (G.K1 x) = arithmetize x
gpayload (G.K1 x) = payload x
grestore f = G.K1 (restore f)
Loading

0 comments on commit 04c6fbf

Please sign in to comment.