Skip to content

Commit

Permalink
Working hashes covered with tests
Browse files Browse the repository at this point in the history
  • Loading branch information
vks4git committed Apr 14, 2024
1 parent 0ca6e86 commit 054a8c6
Show file tree
Hide file tree
Showing 9 changed files with 536 additions and 119 deletions.
14 changes: 12 additions & 2 deletions examples/Examples/ByteString.hs
Original file line number Diff line number Diff line change
@@ -1,16 +1,18 @@
{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeOperators #-}

module Examples.ByteString (
exampleByteStringAnd,
exampleByteStringOr
exampleByteStringOr,
exampleByteStringGrow
) where

import Data.Data (Proxy (Proxy))
import Data.Function (($))
import Data.List ((++))
import Data.String (String)
import GHC.TypeNats (KnownNat, natVal)
import GHC.TypeNats (KnownNat, natVal, type (<=))
import System.IO (IO, putStrLn)
import Text.Show (show)

Expand All @@ -26,6 +28,14 @@ exampleByteStringAnd = makeExample @n "*" "and" (&&)
exampleByteStringOr :: forall n . KnownNat n => IO ()
exampleByteStringOr = makeExample @n "+" "or" (||)

exampleByteStringGrow :: forall n k . (KnownNat n, KnownNat k, n <= k) => IO ()
exampleByteStringGrow = do
let n = show $ natVal (Proxy @n)
let k = show $ natVal (Proxy @k)
putStrLn $ "\nExample: Extending a bytestring of length " ++ n ++ " to length " ++ k
let file = "compiled_scripts/bytestring" ++ n ++ "_to_" ++ k ++ ".json"
compileIO @(Zp BLS12_381_Scalar) file $ grow @(ByteString n (ArithmeticCircuit (Zp BLS12_381_Scalar))) @(ByteString k (ArithmeticCircuit (Zp BLS12_381_Scalar)))

type Binary a = a -> a -> a

type UBinary n = Binary (ByteString n (ArithmeticCircuit (Zp BLS12_381_Scalar)))
Expand Down
3 changes: 2 additions & 1 deletion examples/Main.hs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

module Main where

import Examples.ByteString (exampleByteStringAnd, exampleByteStringOr)
import Examples.ByteString (exampleByteStringAnd, exampleByteStringOr, exampleByteStringGrow)
import Examples.Conditional (exampleConditional)
import Examples.Eq (exampleEq)
import Examples.Fibonacci (exampleFibonacci)
Expand Down Expand Up @@ -35,3 +35,4 @@ main = do
exampleByteStringAnd @500
exampleByteStringOr @32
exampleByteStringOr @500
exampleByteStringGrow @1 @512
272 changes: 199 additions & 73 deletions src/ZkFold/Symbolic/Algorithms/Hash/SHA2.hs

Large diffs are not rendered by default.

20 changes: 18 additions & 2 deletions src/ZkFold/Symbolic/Algorithms/Hash/SHA2/Constants.hs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ module ZkFold.Symbolic.Algorithms.Hash.SHA2.Constants
, word32RoundConstants
, sha512InitialHashes
, sha384InitialHashes
, sha512_224InitialHashes
, sha512_256InitialHashes
, word64RoundConstants
) where

Expand All @@ -15,6 +17,11 @@ import Prelude (($), (<$>))

import ZkFold.Base.Algebra.Basic.Class (FromConstant (..))

-- | SHA2 family algorithms differ in constants and parameters used in the mostly identical internal loop.
-- This module stores initial hashes and round constants.
-- They were taken from the official SHA2 description:
-- https://nvlpubs.nist.gov/nistpubs/FIPS/NIST.FIPS.180-4.pdf, pages 11 to 17

sha256InitialHashes :: FromConstant Natural a => V.Vector a
sha256InitialHashes = V.fromList $ fromConstant @Natural <$>
[0x6a09e667, 0xbb67ae85, 0x3c6ef372, 0xa54ff53a, 0x510e527f, 0x9b05688c, 0x1f83d9ab, 0x5be0cd19]
Expand All @@ -39,14 +46,23 @@ word32RoundConstants = V.fromList $ fromConstant @Natural <$>
sha512InitialHashes :: FromConstant Natural a => V.Vector a
sha512InitialHashes = V.fromList $ fromConstant @Natural <$>
[ 0x6a09e667f3bcc908, 0xbb67ae8584caa73b, 0x3c6ef372fe94f82b, 0xa54ff53a5f1d36f1,
0x510e527fade682d1, 0x9b05688c2b3e6c1f, 0x1f83d9abfb41bd6b, 0x5be0cd19137e2179
]
0x510e527fade682d1, 0x9b05688c2b3e6c1f, 0x1f83d9abfb41bd6b, 0x5be0cd19137e2179 ]

sha384InitialHashes :: FromConstant Natural a => V.Vector a
sha384InitialHashes = V.fromList $ fromConstant @Natural <$>
[ 0xcbbb9d5dc1059ed8, 0x629a292a367cd507, 0x9159015a3070dd17, 0x152fecd8f70e5939,
0x67332667ffc00b31, 0x8eb44a8768581511, 0xdb0c2e0d64f98fa7, 0x47b5481dbefa4fa4 ]

sha512_256InitialHashes :: FromConstant Natural a => V.Vector a
sha512_256InitialHashes = V.fromList $ fromConstant @Natural <$>
[ 0x22312194fc2bf72c, 0x9f555fa3c84c64c2, 0x2393b86b6f53b151, 0x963877195940eabd
, 0x96283ee2a88effe3, 0xbe5e1e2553863992, 0x2b0199fc2c85b8aa, 0x0eb72ddc81c52ca2 ]

sha512_224InitialHashes :: FromConstant Natural a => V.Vector a
sha512_224InitialHashes = V.fromList $ fromConstant @Natural <$>
[ 0x8c3d37c819544da2, 0x73e1996689dcd4d6, 0x1dfab7ae32ff9c82, 0x679dd514582f9fcf
, 0x0f6d2b697bd44da8, 0x77e36f7304c48942, 0x3f9d85a86a1d36c8, 0x1112e6ad91d692a1 ]

word64RoundConstants :: FromConstant Natural a => V.Vector a
word64RoundConstants = V.fromList $ fromConstant @Natural <$>
[
Expand Down
125 changes: 96 additions & 29 deletions src/ZkFold/Symbolic/Data/ByteString.hs
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,14 @@ module ZkFold.Symbolic.Data.ByteString
( ByteString(..)
, ShiftBits (..)
, ToWords (..)
, Append (..)
, Concat (..)
, Truncate (..)
, Grow (..)
, Extend (..)
) where

import Control.Monad (forM, mapM, replicateM, zipWithM)
import Data.Bits as B
import Data.List (concat, foldl, reverse, splitAt, unfoldr)
import Data.List (foldl, reverse, splitAt, unfoldr)
import Data.List.Split (chunksOf)
import Data.Maybe (Maybe (..))
import Data.Proxy (Proxy (..))
Expand All @@ -37,12 +37,21 @@ import ZkFold.Symbolic.Compiler.ArithmeticCircuit.MonadBlueprint (Clos
import ZkFold.Symbolic.Data.Bool (BoolType (..))
import ZkFold.Symbolic.Data.Combinators


-- | A ByteString which stores @n@ bits and uses elements of @a@ as registers.
-- Bit layout is Big-endian. @a@ is the higher register defined separately as it may store less bits than the lower registers.
--
data ByteString (n :: Natural) a = ByteString !a ![a]
deriving (Haskell.Show, Haskell.Eq)


-- | A class for data types that support bit shift and bit cyclic shift (rotation) operations.
--
class ShiftBits a where
{-# MINIMAL (shiftBits | (shiftBitsL, shiftBitsR)), (rotateBits | (rotateBitsL, rotateBitsR)) #-}

-- | shiftBits performs a left shift when its agrument is greater than zero and a right shift otherwise.
--
shiftBits :: a -> Integer -> a
shiftBits a s
| s < 0 = shiftBitsR a (Haskell.fromIntegral . negate $ s)
Expand All @@ -54,6 +63,8 @@ class ShiftBits a where
shiftBitsR :: a -> Natural -> a
shiftBitsR a s = shiftBits a (negate . Haskell.fromIntegral $ s)

-- | rotateBits performs a left cyclic shift when its agrument is greater than zero and a right cyclic shift otherwise.
--
rotateBits :: a -> Integer -> a
rotateBits a s
| s < 0 = rotateBitsR a (Haskell.fromIntegral . negate $ s)
Expand All @@ -65,42 +76,64 @@ class ShiftBits a where
rotateBitsR :: a -> Natural -> a
rotateBitsR a s = rotateBits a (negate . Haskell.fromIntegral $ s)


-- | Describes types which can be split into words of equal size.
-- Parameters have to be of different types as ByteString store their lengths on type level and hence after splitting they chage types.
--
class ToWords a b where
toWords :: a -> [b]

class Append a b where
append :: [a] -> b

-- | Describes types which can be made by concatenating several words of equal length.
--
class Concat a b where
concat :: [a] -> b


-- | Describes types that can be truncated by dropping several bits from the end (i.e. stored in the lower registers)
--
class Truncate a b where
truncate :: a -> b

class Grow a b where
grow :: a -> b

-- | Describes types that can increase their capacity by adding zero bits to the beginning (i.e. before the higher register).
--
class Extend a b where
extend :: a -> b


instance (Finite a, ToConstant a Natural, KnownNat n) => ToConstant (ByteString n a) Natural where
toConstant (ByteString x xs) = Haskell.foldl (\y p -> toConstant p + base * y) 0 (x:xs)
where base = 2 Haskell.^ maxBitsPerRegister @a @n
toConstant (ByteString x xs) = Haskell.foldl (\y p -> toConstant p + base * y) 0 (x:xs)
where base = 2 Haskell.^ maxBitsPerRegister @a @n


instance (FromConstant Natural a, Finite a, KnownNat n) => FromConstant Natural (ByteString n a) where
-- | fromConstant discards bits after @n@.
-- If the constant is greater than 2^@n@, only the part modulo 2^@n@ will be converted into ByteString.
fromConstant n = case reverse bits of
[] -> error "FromConstant: unreachable"
(r:rs) -> ByteString r rs
where
base = 2 Haskell.^ maxBitsPerRegister @a @n

availableBits = unfoldr (toBase base) (n `Haskell.mod` (2 Haskell.^ (getNatural @n))) <> Haskell.repeat (fromConstant @Natural 0)
-- | Pack a ByteString as tightly as possible, allocating the largest possible number of bits to each register.
-- @fromConstant@ discards bits after @n@.
-- If the constant is greater than @2^n@, only the part modulo @2^n@ will be converted into a ByteString.
--
fromConstant n = case reverse bits of
[] -> error "FromConstant: unreachable"
(r:rs) -> ByteString r rs
where
base = 2 Haskell.^ maxBitsPerRegister @a @n

availableBits = unfoldr (toBase base) (n `Haskell.mod` (2 Haskell.^ (getNatural @n))) <> Haskell.repeat (fromConstant @Natural 0)

bits = take (Haskell.fromIntegral $ minNumberOfRegisters @a @n) availableBits
bits = take (Haskell.fromIntegral $ minNumberOfRegisters @a @n) availableBits

-- | Convert a number into @base@-ary system.
--
toBase :: FromConstant Natural a => Natural -> Natural -> Maybe (a, Natural)
toBase _ 0 = Nothing
toBase base b = let (d, m) = b `divMod` base in Just (fromConstant m, d)


instance (FromConstant Natural a, Finite a, KnownNat n) => FromConstant Integer (ByteString n a) where
fromConstant = fromConstant . naturalFromInteger . (`Haskell.mod` (2 Haskell.^ getNatural @n))


instance (KnownNat p, KnownNat n) => Arbitrary (ByteString n (Zp p)) where
arbitrary = ByteString
<$> toss (highRegisterBits @(Zp p) @n)
Expand All @@ -112,10 +145,11 @@ instance (KnownNat p, KnownNat n) => Arbitrary (ByteString n (Zp p)) where
instance (KnownNat p, KnownNat n) => AdditiveSemigroup (ByteString n (Zp p)) where
x + y = fromConstant $ toConstant x + (toConstant @_ @Natural) y


instance (KnownNat p, KnownNat n) => ShiftBits (ByteString n (Zp p)) where
shiftBits b s = fromConstant $ shift (toConstant @_ @Natural b) (Haskell.fromIntegral s) `Haskell.mod` (2 Haskell.^ (getNatural @n))

-- @Data.Bits.rotate@ works exactly as @Data.Bits.shift@ for @Natural@, we have to rotate bits manually.
-- | @Data.Bits.rotate@ works exactly as @Data.Bits.shift@ for @Natural@, we have to rotate bits manually.
rotateBitsL b s
| s == 0 = b
-- Rotations by k * n + p bits where n is the length of the ByteString are equivalent to rotations by p bits.
Expand Down Expand Up @@ -172,6 +206,13 @@ instance (KnownNat p, KnownNat n) => BoolType (ByteString n (Zp p)) where
-- | Bitwise xor
xor x y = fromConstant @Natural $ toConstant x `B.xor` toConstant y


-- | A ByteString of length @n@ can only be split into words of length @wordSize@ if all of the following conditions are met:
-- 1. @wordSize@ is not greater than @n@;
-- 2. @wordSize@ is not zero;
-- 3. The bytestring is not empty;
-- 4. @wordSize@ divides @n@.
--
instance
( KnownNat wordSize
, KnownNat n
Expand All @@ -196,18 +237,26 @@ instance
natWords :: [Natural]
natWords = unfoldr (toBase (2 Haskell.^ wordSize)) asNat <> Haskell.repeat (fromConstant @Natural 0)


-- | Unfortunately, Haskell does not support dependent types yet,
-- so we have no possibility to infer the exact type of the result
-- (the list can contain an arbitrary number of words).
-- We can only impose some restrictions on @n@ and @m@.
--
instance
( KnownNat n
, KnownNat m
, m <= n
, Mod n m ~ 0
, KnownNat p
) => Append (ByteString m (Zp p)) (ByteString n (Zp p)) where
) => Concat (ByteString m (Zp p)) (ByteString n (Zp p)) where

append = fromConstant @Natural . foldl (\p y -> toConstant y + p `shift` m) 0
concat = fromConstant @Natural . foldl (\p y -> toConstant y + p `shift` m) 0
where
m = Haskell.fromIntegral $ getNatural @m


-- | Only a bigger ByteString can be truncated into a smaller one.
--
instance
( KnownNat m
, KnownNat n
Expand All @@ -220,17 +269,21 @@ instance
diff :: Haskell.Int
diff = Haskell.fromIntegral $ getNatural @m Haskell.- getNatural @n

-- | Only a smaller ByteString can be extended into a bigger one.
--
instance
( KnownNat m
, KnownNat n
, m <= n
, KnownNat p
) => Grow (ByteString m (Zp p)) (ByteString n (Zp p)) where
) => Extend (ByteString m (Zp p)) (ByteString n (Zp p)) where

grow = fromConstant @Natural . toConstant
extend = fromConstant @Natural . toConstant

--------------------------------------------------------------------------------

-- | Convert an @ArithmeticCircuit@ to bits and return their corresponding variables.
--
toBits
:: forall n a
. Arithmetic a
Expand All @@ -247,6 +300,9 @@ toBits hi lo = do

pure $ bitsHigh <> bitsLow


-- | The inverse of @toBits@.
--
fromBits
:: forall n a
. Arithmetic a
Expand All @@ -262,6 +318,7 @@ fromBits bits = do
pure $ highNew : lowsNew



instance (Arithmetic a, KnownNat n) => Arithmetizable a (ByteString n (ArithmeticCircuit a)) where
arithmetize (ByteString a as) = forM (a:as) runCircuit

Expand All @@ -271,7 +328,9 @@ instance (Arithmetic a, KnownNat n) => Arithmetizable a (ByteString n (Arithmeti

typeSize = minNumberOfRegisters @a @n


-- TODO: I really don't like that summation is implemented here and in UInt. Can we do something about it?
-- Converting ByteStrings to UInt and back will flood the circuit with new constraints because of different bit layouts in these types.
--
instance (Arithmetic a, KnownNat n) => AdditiveSemigroup (ByteString n (ArithmeticCircuit a)) where
ByteString x xs + ByteString y ys =
Expand Down Expand Up @@ -302,6 +361,8 @@ instance (Arithmetic a, KnownNat n) => AdditiveSemigroup (ByteString n (Arithmet
(r, c') <- f c a b
(r:) <$> zipWithCarryM f c' as bs

-- | Perform some operation on a list of bits.
--
moveBits
:: forall n a
. Arithmetic a
Expand Down Expand Up @@ -353,6 +414,9 @@ instance (Arithmetic a, KnownNat n) => ShiftBits (ByteString n (ArithmeticCircui
rotateList lst = drop (Haskell.fromIntegral s) lst <> take (Haskell.fromIntegral s) lst


-- | A generic bitwise operation on two ByteStrings.
-- TODO: Shall we expose it to users? Can they do something malicious having such function? AFAIK there are checks that constrain each bit to 0 or 1.
--
bitwiseOperation
:: forall a n
. Arithmetic a
Expand Down Expand Up @@ -387,6 +451,7 @@ bitwiseOperation (ByteString x xs) (ByteString y ys) cons =
False -> maxBitsPerRegister @a @n
True -> highRegisterBits @a @n


instance (Arithmetic a, KnownNat n) => BoolType (ByteString n (ArithmeticCircuit a)) where
false = ByteString zero (replicate (minNumberOfRegisters @a @n -! 1) zero)

Expand Down Expand Up @@ -433,18 +498,20 @@ instance
instance
( KnownNat m
, KnownNat n
, m <= n
, Mod n m ~ 0
, Arithmetic a
) => Append (ByteString m (ArithmeticCircuit a)) (ByteString n (ArithmeticCircuit a)) where
) => Concat (ByteString m (ArithmeticCircuit a)) (ByteString n (ArithmeticCircuit a)) where

append bs =
concat bs =
case circuits solve of
[] -> error "<+> :: Unreachable"
(r:rs) -> ByteString r rs
where
solve :: forall i m'. MonadBlueprint i a m' => m' [i]
solve = do
bits <- mapM (\(ByteString x xs) -> toBits @m @a x xs) bs
fromBits @n @a $ concat bits
fromBits @n @a $ Haskell.concat bits

instance
( KnownNat m
Expand All @@ -468,9 +535,9 @@ instance
, KnownNat n
, m <= n
, Arithmetic a
) => Grow (ByteString m (ArithmeticCircuit a)) (ByteString n (ArithmeticCircuit a)) where
) => Extend (ByteString m (ArithmeticCircuit a)) (ByteString n (ArithmeticCircuit a)) where

grow (ByteString x xs) =
extend (ByteString x xs) =
case circuits solve of
[] -> error "truncate :: Unreachable"
(r:rs) -> ByteString r rs
Expand Down
Loading

0 comments on commit 054a8c6

Please sign in to comment.