Skip to content

Commit

Permalink
Add setter methods to MT (0xPolygonHermez#824)
Browse files Browse the repository at this point in the history
  • Loading branch information
fgimenez authored Jun 29, 2022
1 parent 7511164 commit 4ac22fc
Show file tree
Hide file tree
Showing 3 changed files with 227 additions and 5 deletions.
72 changes: 72 additions & 0 deletions merkletree/key.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package merkletree

import (
"math"
"math/big"

"github.com/ethereum/go-ethereum/common"
Expand Down Expand Up @@ -107,3 +108,74 @@ func KeyContractStorage(ethAddr common.Address, storagePos []byte) ([]byte, erro

return keyEthAddr(ethAddr, leafTypeStorage, hk0)
}

// hashContractBytecode computes the bytecode hash in order to add it to the
// state-tree.
func hashContractBytecode(code []byte) ([]uint64, error) {
const (
bytecodeElementsHash = 8
bytecodeBytesElement = 7

maxBytesToAdd = bytecodeElementsHash * bytecodeBytesElement
)

numHashes := int(math.Ceil(float64(len(code)) / float64(maxBytesToAdd)))

tmpHash := [4]uint64{}
var err error

bytesPointer := 0
for i := 0; i < numHashes; i++ {
elementsToHash := [12]uint64{}

if i != 0 {
for j := 0; j < 4; j++ {
elementsToHash[j] = tmpHash[j]
}
} else {
for j := 0; j < 4; j++ {
elementsToHash[j] = 0
}
}
subsetBytecode := code[bytesPointer : int(math.Min(float64(len(code)-1), float64(bytesPointer+maxBytesToAdd)))+1]
bytesPointer += maxBytesToAdd
tmpElem := [7]byte{}
counter := 0
index := 4
for j := 0; j < maxBytesToAdd; j++ {
byteToAdd := []byte{0}

if j < len(subsetBytecode) {
byteToAdd = subsetBytecode[j : j+1]
}
tmpElem[counter] = byteToAdd[0]
counter++

if counter == bytecodeBytesElement {
elementsToHash[index] = new(big.Int).SetBytes(tmpElem[:]).Uint64()
index++
tmpElem = [7]byte{}
counter = 0
}
}
tmpHash, err = poseidon.Hash([8]uint64{
elementsToHash[4],
elementsToHash[5],
elementsToHash[6],
elementsToHash[7],
elementsToHash[8],
elementsToHash[9],
elementsToHash[10],
elementsToHash[11],
}, [4]uint64{
elementsToHash[0],
elementsToHash[1],
elementsToHash[2],
elementsToHash[3],
})
if err != nil {
return nil, err
}
}
return tmpHash[:], nil
}
151 changes: 146 additions & 5 deletions merkletree/tree.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@ package merkletree

import (
"context"
"fmt"
"math/big"

"github.com/ethereum/go-ethereum/common"
"github.com/hermeznetwork/hermez-core/hex"
"github.com/hermeznetwork/hermez-core/merkletree/pb"
)

Expand All @@ -20,7 +22,7 @@ func NewStateTree(client pb.StateDBServiceClient) *StateTree {
}
}

// GetBalance returns balance
// GetBalance returns balance.
func (tree *StateTree) GetBalance(ctx context.Context, address common.Address, root []byte) (*big.Int, error) {
r := new(big.Int).SetBytes(root)

Expand All @@ -40,7 +42,7 @@ func (tree *StateTree) GetBalance(ctx context.Context, address common.Address, r
return fea2scalar(proof.Value), nil
}

// GetNonce returns nonce
// GetNonce returns nonce.
func (tree *StateTree) GetNonce(ctx context.Context, address common.Address, root []byte) (*big.Int, error) {
r := new(big.Int).SetBytes(root)

Expand All @@ -60,7 +62,7 @@ func (tree *StateTree) GetNonce(ctx context.Context, address common.Address, roo
return fea2scalar(proof.Value), nil
}

// GetCodeHash returns code hash
// GetCodeHash returns code hash.
func (tree *StateTree) GetCodeHash(ctx context.Context, address common.Address, root []byte) ([]byte, error) {
r := new(big.Int).SetBytes(root)

Expand All @@ -82,7 +84,7 @@ func (tree *StateTree) GetCodeHash(ctx context.Context, address common.Address,
return ScalarToFilledByteSlice(valueBi), nil
}

// GetCode returns code
// GetCode returns code.
func (tree *StateTree) GetCode(ctx context.Context, address common.Address, root []byte) ([]byte, error) {
scCodeHash, err := tree.GetCodeHash(ctx, address, root)
if err != nil {
Expand All @@ -98,7 +100,7 @@ func (tree *StateTree) GetCode(ctx context.Context, address common.Address, root
return scCode.Data, nil
}

// GetStorageAt returns Storage Value at specified position
// GetStorageAt returns Storage Value at specified position.
func (tree *StateTree) GetStorageAt(ctx context.Context, address common.Address, position *big.Int, root []byte) (*big.Int, error) {
r := new(big.Int).SetBytes(root)

Expand All @@ -118,6 +120,109 @@ func (tree *StateTree) GetStorageAt(ctx context.Context, address common.Address,
return fea2scalar(proof.Value), nil
}

// SetBalance sets balance.
func (tree *StateTree) SetBalance(ctx context.Context, address common.Address, balance *big.Int, root []byte) (newRoot []byte, proof *UpdateProof, err error) {
if balance.Cmp(big.NewInt(0)) == -1 {
return nil, nil, fmt.Errorf("invalid balance")
}

r := new(big.Int).SetBytes(root)
key, err := KeyEthAddrBalance(address)
if err != nil {
return nil, nil, err
}

k := new(big.Int).SetBytes(key)
balanceH8 := scalar2fea(balance)

updateProof, err := tree.set(ctx, scalarToh4(r), scalarToh4(k), balanceH8)
if err != nil {
return nil, nil, err
}

return h4ToFilledByteSlice(updateProof.NewRoot), updateProof, nil
}

// SetNonce sets nonce.
func (tree *StateTree) SetNonce(ctx context.Context, address common.Address, nonce *big.Int, root []byte) (newRoot []byte, proof *UpdateProof, err error) {
if nonce.Cmp(big.NewInt(0)) == -1 {
return nil, nil, fmt.Errorf("invalid nonce")
}

r := new(big.Int).SetBytes(root)
key, err := KeyEthAddrNonce(address)
if err != nil {
return nil, nil, err
}

k := new(big.Int).SetBytes(key[:])

nonceH8 := scalar2fea(nonce)

updateProof, err := tree.set(ctx, scalarToh4(r), scalarToh4(k), nonceH8)
if err != nil {
return nil, nil, err
}

return h4ToFilledByteSlice(updateProof.NewRoot), updateProof, nil
}

// SetCode sets smart contract code.
func (tree *StateTree) SetCode(ctx context.Context, address common.Address, code []byte, root []byte) (newRoot []byte, proof *UpdateProof, err error) {
// calculating smart contract code hash
scCodeHash4, err := hashContractBytecode(code)
if err != nil {
return nil, nil, err
}

// store smart contract code by its hash
err = tree.setProgram(ctx, scCodeHash4, code, true)
if err != nil {
return nil, nil, err
}

// set smart contract code hash as a leaf value in merkle tree
r := new(big.Int).SetBytes(root)
key, err := KeyContractCode(address)
if err != nil {
return nil, nil, err
}
k := new(big.Int).SetBytes(key[:])

scCodeHash, err := hex.DecodeHex(h4ToString(scCodeHash4))
if err != nil {
return nil, nil, err
}

scCodeHashBI := new(big.Int).SetBytes(scCodeHash[:])
scCodeHashH8 := scalar2fea(scCodeHashBI)

updateProof, err := tree.set(ctx, scalarToh4(r), scalarToh4(k), scCodeHashH8)
if err != nil {
return nil, nil, err
}

return h4ToFilledByteSlice(updateProof.NewRoot), updateProof, nil
}

// SetStorageAt sets storage value at specified position.
func (tree *StateTree) SetStorageAt(ctx context.Context, address common.Address, position *big.Int, value *big.Int, root []byte, txBundleID string) (newRoot []byte, proof *UpdateProof, err error) {
r := new(big.Int).SetBytes(root)
key, err := KeyContractStorage(address, position.Bytes())
if err != nil {
return nil, nil, err
}

k := new(big.Int).SetBytes(key[:])
valueH8 := scalar2fea(value)
updateProof, err := tree.set(ctx, scalarToh4(r), scalarToh4(k), valueH8)
if err != nil {
return nil, nil, err
}

return h4ToFilledByteSlice(updateProof.NewRoot), updateProof, nil
}

func (tree *StateTree) get(ctx context.Context, root, key []uint64) (*Proof, error) {
result, err := tree.grpcClient.Get(ctx, &pb.GetRequest{
Root: &pb.Fea{Fe0: root[0], Fe1: root[1], Fe2: root[2], Fe3: root[3]},
Expand Down Expand Up @@ -151,3 +256,39 @@ func (tree *StateTree) getProgram(ctx context.Context, hash string) (*ProgramPro
Data: result.Data,
}, nil
}

func (tree *StateTree) set(ctx context.Context, oldRoot, key, value []uint64) (*UpdateProof, error) {
h4Value := h4ToString(value)
result, err := tree.grpcClient.Set(ctx, &pb.SetRequest{
OldRoot: &pb.Fea{Fe0: oldRoot[0], Fe1: oldRoot[1], Fe2: oldRoot[2], Fe3: oldRoot[3]},
Key: &pb.Fea{Fe0: key[0], Fe1: key[1], Fe2: key[2], Fe3: key[3]},
Value: h4Value,
Persistent: true,
Details: false,
})
if err != nil {
return nil, err
}

newValue, err := stringToh4(result.NewValue)
if err != nil {
return nil, err
}

return &UpdateProof{
OldRoot: oldRoot,
NewRoot: []uint64{result.NewRoot.Fe0, result.NewRoot.Fe1, result.NewRoot.Fe2, result.NewRoot.Fe3},
Key: key,
NewValue: newValue,
}, nil
}

func (tree *StateTree) setProgram(ctx context.Context, hash []uint64, data []byte, persistent bool) error {
h4Hash := h4ToString(hash)
_, err := tree.grpcClient.SetProgram(ctx, &pb.SetProgramRequest{
Hash: h4Hash,
Data: data,
Persistent: persistent,
})
return err
}
9 changes: 9 additions & 0 deletions merkletree/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,15 @@ type Proof struct {
Value []uint64
}

// UpdateProof is a proof generated on Set operation.
type UpdateProof struct {
OldRoot []uint64
NewRoot []uint64
Key []uint64
NewValue []uint64
}

// ProgramProof is a proof generated on GetProgram operation.
type ProgramProof struct {
Data []byte
}

0 comments on commit 4ac22fc

Please sign in to comment.