Skip to content

Commit

Permalink
Merge pull request #203 from zkFold/hov-newranged-uint
Browse files Browse the repository at this point in the history
acRange has been added to the arithmetic operations for UInt arithmetic circuits with a fixed register
  • Loading branch information
vlasin authored Aug 6, 2024
2 parents 97b3ff3 + 7d8cea1 commit a9a4507
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 28 deletions.
18 changes: 10 additions & 8 deletions src/ZkFold/Symbolic/Compiler/ArithmeticCircuit/Combinators.hs
Original file line number Diff line number Diff line change
Expand Up @@ -31,15 +31,15 @@ import Data.Traversable (for)
import qualified Data.Zip as Z
import GHC.Generics (Par1)
import GHC.IsList (IsList (..))
import Prelude hiding (Bool, Eq (..), length, negate,
splitAt, (!!), (*), (+), (-), (^))
import Prelude hiding (Bool, Eq (..), drop, length, negate,
splitAt, take, (!!), (*), (+), (-), (^))

import ZkFold.Base.Algebra.Basic.Class
import ZkFold.Base.Algebra.Basic.Number
import ZkFold.Base.Algebra.Polynomials.Multivariate (variables)
import qualified ZkFold.Base.Data.Vector as V
import ZkFold.Base.Data.Vector (Vector (..))
import ZkFold.Prelude (length, splitAt, (!!))
import ZkFold.Prelude (drop, length, take, (!!))
import ZkFold.Symbolic.Compiler.ArithmeticCircuit.Internal (ArithmeticCircuit (..), acInput)
import ZkFold.Symbolic.Compiler.ArithmeticCircuit.MonadBlueprint
import ZkFold.Symbolic.MonadCircuit
Expand Down Expand Up @@ -78,17 +78,19 @@ expansion n k = do
constraint (\x -> x k - x k')
return bits

splitExpansion :: MonadBlueprint i a m => Natural -> Natural -> i -> m (i, i)
splitExpansion :: (MonadBlueprint i a m, Arithmetic a) => Natural -> Natural -> i -> m (i, i)
-- ^ @splitExpansion n1 n2 k@ computes two values @(l, h)@ such that
-- @k = 2^n1 h + l@, @l@ fits in @n1@ bits and @h@ fits in n2 bits (if such
-- values exist).
splitExpansion n1 n2 k = do
bits <- bitsOf (n1 + n2) k
let (lo, hi) = splitAt n1 bits
l <- horner lo
h <- horner hi
let f x y = x + y + y
l <- newRanged (fromConstant $ (2 :: Natural) ^ n1 -! 1) $ foldr f zero . take n1 . repr . ($ k)
h <- newRanged (fromConstant $ (2 :: Natural) ^ n2 -! 1) $ foldr f zero . take n2 . drop n1 . repr . ($ k)
constraint (\x -> x k - x l - scale (2 ^ n1 :: Natural) (x h))
return (l, h)
where
repr :: forall b . (BinaryExpansion b, Bits b ~ [b]) => b -> [b]
repr = padBits (n1 + n2) . binaryExpansion

bitsOf :: MonadCircuit i a m => Natural -> i -> m [i]
-- ^ @bitsOf n k@ creates @n@ bits and sets their witnesses equal to @n@ smaller
Expand Down
2 changes: 1 addition & 1 deletion src/ZkFold/Symbolic/Data/FFA.hs
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ condSub m x = fst <$> condSubOF m x
smallCut :: forall i a m. (Arithmetic a, MonadBlueprint i a m) => Vector Size i -> m (Vector Size i)
smallCut = zipWithM condSub $ coprimes @a

bigSub :: MonadBlueprint i a m => Natural -> i -> m i
bigSub :: (Arithmetic a, MonadBlueprint i a m) => Natural -> i -> m i
bigSub m j = trimPow j >>= trimPow >>= condSub m
where
s = Haskell.ceiling (log2 m) :: Natural
Expand Down
44 changes: 25 additions & 19 deletions src/ZkFold/Symbolic/Data/UInt.hs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import Data.Functor ((<$>
import Data.Kind (Type)
import Data.List (unfoldr, zip)
import Data.Map (fromList, (!))
import Data.Traversable (Traversable, for, traverse)
import Data.Traversable (for, traverse)
import Data.Tuple (swap)
import qualified Data.Zip as Z
import GHC.Generics (Generic, Par1 (..))
Expand All @@ -39,7 +39,7 @@ import ZkFold.Base.Control.HApplicative (hlif
import qualified ZkFold.Base.Data.Vector as V
import ZkFold.Base.Data.Vector (Vector (..))
import ZkFold.Prelude (drop, length, replicate, replicateA)
import ZkFold.Symbolic.Class
import ZkFold.Symbolic.Class hiding (embed)
import ZkFold.Symbolic.Compiler hiding (forceZero)
import ZkFold.Symbolic.Compiler.ArithmeticCircuit.Combinators (embedV, expansion, splitExpansion)
import ZkFold.Symbolic.Compiler.ArithmeticCircuit.MonadBlueprint
Expand Down Expand Up @@ -376,19 +376,27 @@ instance
s <- newAssigned (\v -> v d + v b + fromConstant t)
splitExpansion (registerSize @a @n @r) 1 s

negate (UInt x) =
let y = 2 ^ registerSize @a @n @r
ys = replicate (numberOfRegisters @a @n @r -! 2) (2 ^ registerSize @a @n @r -! 1)
y' = 2 ^ highRegisterSize @a @n @r -! 1
ns
| numberOfRegisters @a @n @r Haskell.== 1 = V.unsafeToVector [y' + 1]
| otherwise = V.unsafeToVector $ (y : ys) <> [y']
in UInt (negateN ns x)
negate (UInt x) = UInt $ circuitF (V.unsafeToVector <$> solve)
where
solve :: MonadBlueprint i a m => m [i]
solve = do
j <- newAssigned (Haskell.const zero)

xs <- V.fromVector <$> runCircuit x
let y = 2 ^ registerSize @a @n @r
ys = replicate (numberOfRegisters @a @n @r -! 2) (2 ^ registerSize @a @n @r -! 1)
y' = 2 ^ highRegisterSize @a @n @r -! 1
ns
| numberOfRegisters @a @n @r Haskell.== 1 = [y' + 1]
| otherwise = (y : ys) <> [y']
(zs, _) <- flip runStateT j $ traverse StateT (Haskell.zipWith negateN ns xs)
return zs

negateN :: MonadBlueprint i a m => Natural -> i -> i -> m (i, i)
negateN n i b = do
r <- newAssigned (\v -> fromConstant n - v i + v b)
splitExpansion (registerSize @a @n @r) 1 r

negateN :: (Arithmetic a, Z.Zip f, Traversable f) => f Natural -> ArithmeticCircuit a f -> ArithmeticCircuit a f
negateN ns r = circuitF $ do
is <- runCircuit r
for (Z.zip is ns) $ \(i, n) -> newAssigned (\v -> fromConstant n - v i)

instance (Arithmetic a, KnownNat n, KnownRegisterSize rs, r ~ NumberOfRegisters a n rs) => MultiplicativeSemigroup (UInt n rs (ArithmeticCircuit a)) where
UInt x * UInt y = UInt (circuitF $ V.unsafeToVector <$> solve)
Expand Down Expand Up @@ -517,13 +525,11 @@ instance (Finite (Zp p), Prime p, KnownNat n, KnownRegisterSize r) => StrictConv
instance (Arithmetic a, KnownNat n, KnownRegisterSize r, NumberOfBits a <= n) => StrictConv (ArithmeticCircuit a Par1) (UInt n r (ArithmeticCircuit a)) where
strictConv a = UInt (circuitF $ V.unsafeToVector <$> solve)
where
requiredBits :: Natural
requiredBits = (numberOfRegisters @a @n @r -! 1) * (registerSize @a @n @r) + (highRegisterSize @a @n @r)

solve :: MonadBlueprint i a m => m [i]
solve = do
i <- unPar1 <$> runCircuit a
bits <- Haskell.reverse <$> expansion requiredBits i
let len = Haskell.min (getNatural @n) (numberOfBits @a)
bits <- Haskell.reverse <$> expansion len i
fromBits (highRegisterSize @a @n @r) (registerSize @a @n @r) bits


Expand Down Expand Up @@ -645,7 +651,7 @@ instance (Arithmetic a, KnownNat n, KnownRegisterSize r) => StrictNum (UInt n r

--------------------------------------------------------------------------------

fullAdder :: MonadBlueprint i a m => Natural -> i -> i -> i -> m (i, i)
fullAdder :: (Arithmetic a, MonadBlueprint i a m) => Natural -> i -> i -> i -> m (i, i)
fullAdder r xk yk c = fullAdded xk yk c >>= splitExpansion r 1

fullAdded :: MonadBlueprint i a m => i -> i -> i -> m i
Expand Down

0 comments on commit a9a4507

Please sign in to comment.