Skip to content

Commit

Permalink
performance improvements to the position functions and tree lookup (k…
Browse files Browse the repository at this point in the history
…eybase#19426)

* performance improvements to the position library and tree lookup

* typos and style
  • Loading branch information
AMarcedone authored Sep 10, 2019
1 parent 539960e commit 525491d
Show file tree
Hide file tree
Showing 5 changed files with 295 additions and 118 deletions.
6 changes: 5 additions & 1 deletion go/merkletree2/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,9 @@ type Config struct {
// bitsPerIndex divides keyByteLength*8
keysByteLength int

// The maximum depth of the tree. Should always equal keysByteLength*8/bitsPerIndex
maxDepth int

// valueConstructor is an interface to construct empty values to be used for
// deserialization.
valueConstructor ValueConstructor
Expand All @@ -65,7 +68,8 @@ func NewConfig(h Encoder, useBlindedValueHashes bool, logChildrenPerNode uint8,
if logChildrenPerNode < 1 {
return Config{}, NewInvalidConfigError(fmt.Sprintf("Need at least 2 children per node, but logChildrenPerNode = %v", logChildrenPerNode))
}
return Config{encoder: h, useBlindedValueHashes: useBlindedValueHashes, childrenPerNode: childrenPerNode, maxValuesPerLeaf: maxValuesPerLeaf, bitsPerIndex: logChildrenPerNode, keysByteLength: keysByteLength, valueConstructor: nil}, nil
maxDepth := keysByteLength * 8 / int(logChildrenPerNode)
return Config{encoder: h, useBlindedValueHashes: useBlindedValueHashes, childrenPerNode: childrenPerNode, maxValuesPerLeaf: maxValuesPerLeaf, bitsPerIndex: logChildrenPerNode, keysByteLength: keysByteLength, maxDepth: maxDepth, valueConstructor: nil}, nil
}

// MasterSecret is a secret used to hide wether a leaf value has changed between
Expand Down
142 changes: 91 additions & 51 deletions go/merkletree2/position.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,18 @@ func (p *Position) getBytes() []byte {
return (*big.Int)(p).Bytes()
}

// Set updates p to the value of q
func (p *Position) Set(q *Position) {
(*big.Int)(p).Set((*big.Int)(q))
}

// Clone returns a pointer to a deep copy of a position
func (p *Position) Clone() *Position {
var q Position
q.Set(p)
return &q
}

func (p *Position) isOnPathToKey(k Key) bool {
// If the Key is shorter than current prefix
if len(k)*8 < (*big.Int)(p).BitLen()-1 {
Expand All @@ -51,44 +63,56 @@ func (t *Config) getParent(p *Position) *Position {
return nil
}

var f big.Int
f.Rsh((*big.Int)(p), uint(t.bitsPerIndex))
f := p.Clone()
t.updateToParent(f)

return (*Position)(&f)
return f
}

// getAllSiblings returns nil,nil if p is the root
func (t *Config) getAllSiblings(p *Position) (siblings []Position, parent *Position) {
func (t *Config) updateToParent(p *Position) {
((*big.Int)(p)).Rsh((*big.Int)(p), uint(t.bitsPerIndex))
}

parent = t.getParent(p)
if parent == nil {
return nil, nil
// Behavior if p has no parent at the requested level is undefined.
func (t *Config) updateToParentAtLevel(p *Position, level uint) {
shift := (*big.Int)(p).BitLen() - 1 - int(t.bitsPerIndex)*int(level)
((*big.Int)(p)).Rsh((*big.Int)(p), uint(shift))
}

// updateToParentAndAllSiblings takes as input p and a slice of size
// t.cfg.childrenPerNode - 1. It populates the slice with the siblings of p, and
// updates p to be its parent.
func (t *Config) updateToParentAndAllSiblings(p *Position, sibs []Position) {
if (*big.Int)(p).BitLen() < 2 {
return
}

// Optimization for binary trees
if t.childrenPerNode == 2 {
var sib big.Int
sib.Xor((*big.Int)(p), big.NewInt(1))
return []Position{Position(sib)}, parent
}
sibs[0].Set(p)
lsBits := &(((*big.Int)(&sibs[0]).Bits())[0])
*lsBits = (*lsBits ^ 1)

} else {

siblings = make([]Position, t.childrenPerNode-1)
pChildIndex := big.Word(t.getDeepestChildIndex(p))

var child0 big.Int
child0.Lsh((*big.Int)(parent), uint(t.bitsPerIndex))
mask := ^((big.Word)((1 << t.bitsPerIndex) - 1))

var buff big.Int
pChildIndex := buff.Xor(&child0, (*big.Int)(p)).Int64()
for i, j := uint(0), big.Word(0); j < big.Word(t.childrenPerNode); j++ {
if j == pChildIndex {
continue
}

for i, j := int64(0), int64(0); j < int64(t.childrenPerNode); j++ {
if j == pChildIndex {
continue
sibs[i].Set(p)
// Set least significant bits to the j-th children
lsBits := &(((*big.Int)(&sibs[i]).Bits())[0])
*lsBits = (*lsBits & mask) | j
i++
}
(*big.Int)(&siblings[i]).Or(&child0, big.NewInt(j))
i++
}

return siblings, parent
t.updateToParent(p)
}

// getDeepestPositionForKey converts the key into the position the key would be
Expand All @@ -103,25 +127,45 @@ func (t *Config) getDeepestPositionForKey(k Key) (*Position, error) {
return &p, nil
}

// getSiblingPositionsOnPathToKey returns a slice of positions, in descending
// order by level (siblings farther from the root come first) and in
// lexicographic order within each level.
func (t *Config) getSiblingPositionsOnPathToKey(k Key) ([]Position, error) {
p, err := t.getDeepestPositionForKey(k)
if err != nil {
return nil, err
// getDeepestPositionAtLevelAndSiblingsOnPathToKey returns a slice of positions,
// in descending order by level (siblings farther from the root come first) and
// in lexicographic order within each level. The first position in the slice is
// the position at level lastLevel on a path from the root to k (or the deepest
// possible position for such key if latLevel is greater than that). The
// following positions are all the siblings of the nodes on the longest possible
// path from the root to the key k with are at levels from lastLevel (excluded)
// to firstLevel (included).
// See TestGetDeepestPositionAtLevelAndSiblingsOnPathToKey for sample outputs.
func (t *Config) getDeepestPositionAtLevelAndSiblingsOnPathToKey(k Key, lastLevel int, firstLevel int) (sibs []Position) {

maxLevel := t.keysByteLength * 8 / int(t.bitsPerIndex)
if lastLevel > maxLevel {
lastLevel = maxLevel
}

// first, shrink the key for efficiency
bytesNecessary := lastLevel * int(t.bitsPerIndex) / 8
if lastLevel*int(t.bitsPerIndex)%8 != 0 {
bytesNecessary++
}
maxPathLength := t.keysByteLength * 8 / int(t.bitsPerIndex)
positions := make([]Position, 0, maxPathLength*(t.childrenPerNode-1))
root := t.getRootPosition()
var sibs []Position
for i := 0; !p.equals(root); {
sibs, p = t.getAllSiblings(p)
positions = append(positions, sibs...)
i++
k = k[:bytesNecessary]

var buf Position
p := &buf
(*big.Int)(p).SetBytes(k)
(*big.Int)(p).SetBit((*big.Int)(p), len(k)*8, 1)

t.updateToParentAtLevel(p, uint(lastLevel))

sibs = make([]Position, (lastLevel-firstLevel+1)*(t.childrenPerNode-1)+1)
sibs[0].Set(p)
for i, j := lastLevel, 0; i >= firstLevel; i-- {
sibsToFill := sibs[1+(t.childrenPerNode-1)*j : 1+(t.childrenPerNode-1)*(j+1)]
t.updateToParentAndAllSiblings(p, sibsToFill)
j++
}

return positions, nil
return sibs
}

// getLevel returns the level of p. The root is at level 0, and each node has
Expand All @@ -138,25 +182,23 @@ func (t *Config) getParentAtLevel(p *Position, level uint) *Position {
return nil
}

var f big.Int
f.Rsh((*big.Int)(p), uint(shift))

return (*Position)(&f)
f := p.Clone()
t.updateToParentAtLevel(f, level)
return f
}

// positionToChildIndexPath returns the list of childIndexes to navigate from the
// root to p (in reverse order).
func (t *Config) positionToChildIndexPath(p *Position) (path []ChildIndex) {
path = make([]ChildIndex, t.getLevel(p))

bitMask := big.NewInt(int64(t.childrenPerNode - 1))
bitMask := big.Word(t.childrenPerNode - 1)

var buff, buff2 big.Int
buff2.Set((*big.Int)(p))
buff := p.Clone()

for i := range path {
path[i] = ChildIndex(buff.And(bitMask, &buff2).Int64())
buff2.Rsh(&buff2, uint(t.bitsPerIndex))
path[i] = ChildIndex(((*big.Int)(buff)).Bits()[0] & bitMask)
((*big.Int)(buff)).Rsh((*big.Int)(buff), uint(t.bitsPerIndex))
}

return path
Expand All @@ -168,7 +210,5 @@ func (t *Config) getDeepestChildIndex(p *Position) ChildIndex {
if (*big.Int)(p).BitLen() < 2 {
return ChildIndex(0)
}
bitMask := big.NewInt(int64(t.childrenPerNode - 1))
var buff big.Int
return ChildIndex(buff.And(bitMask, (*big.Int)(p)).Int64())
return ChildIndex(((*big.Int)(p).Bits())[0] & ((1 << t.bitsPerIndex) - 1))
}
104 changes: 91 additions & 13 deletions go/merkletree2/position_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ func TestEncoding(t *testing.T) {
}
}

func TestGetParentAndGetChild(t *testing.T) {
func TestGetAndUpdateParentAndGetChild(t *testing.T) {

config1bit, config2bits, config3bits := getTreeCfgsWith1_2_3BitsPerIndexUnblinded(t)

Expand Down Expand Up @@ -66,6 +66,79 @@ func TestGetParentAndGetChild(t *testing.T) {
require.NoError(t, err)
require.True(t, test.c.getParent(&child).equals(&parent))
require.True(t, test.c.getChild(&parent, test.i).equals(&child))
parentInPlace := child.Clone()
test.c.updateToParent(parentInPlace)
require.True(t, parentInPlace.equals(&parent))
})
}
}

func TestUpdateToParentAtLevel(t *testing.T) {

config1bit, config2bits, config3bits := getTreeCfgsWith1_2_3BitsPerIndexUnblinded(t)

tests := []struct {
c Config
parent string
child string
level uint
}{
{config1bit, "1", "1", 0},
{config1bit, "1", "10", 0},
{config1bit, "1", "1100100", 0},
{config2bits, "1", "1100", 0},
{config3bits, "1", "1111101", 0},
{config1bit, "111", "111", 2},
{config1bit, "110", "11010", 2},
{config2bits, "110", "11001", 1},
{config3bits, "1111", "1111101", 1},
{config3bits, "1111001000", "1111001000001000", 3},
}

for _, test := range tests {
t.Run(fmt.Sprintf("%v bits: %s -(%v)-> %s", test.c.bitsPerIndex, test.child, test.level, test.parent), func(t *testing.T) {
childToUpdate, err := makePositionFromStringForTesting(test.child)
require.NoError(t, err)
parent, err := makePositionFromStringForTesting(test.parent)
require.NoError(t, err)
test.c.updateToParentAtLevel(&childToUpdate, test.level)
require.True(t, parent.equals(&childToUpdate), "expected: %x actual: %x", parent, childToUpdate)
})
}

}

func TestUpdateToParentAndAllSiblings(t *testing.T) {

config1bit, config2bits, config3bits := getTreeCfgsWith1_2_3BitsPerIndexUnblinded(t)

tests := []struct {
c Config
pStr string
expParentStr string
expSiblings []string
}{
{config1bit, "1001111", "100111", []string{"1001110"}},
{config1bit, "1001", "100", []string{"1000"}},
{config2bits, "1001111", "10011", []string{"1001100", "1001101", "1001110"}},
{config3bits, "1001111", "1001", []string{"1001000", "1001001", "1001010", "1001011", "1001100", "1001101", "1001110"}},
}

for _, test := range tests {
t.Run(fmt.Sprintf("%v bits: %v", test.c.bitsPerIndex, test.pStr), func(t *testing.T) {
p, err := makePositionFromStringForTesting(test.pStr)
require.NoError(t, err)
parent, err := makePositionFromStringForTesting(test.expParentStr)
require.NoError(t, err)
siblings := make([]Position, test.c.childrenPerNode-1)

test.c.updateToParentAndAllSiblings(&p, siblings)
require.True(t, p.equals(&parent))
for i, expPosStr := range test.expSiblings {
expPos, err := makePositionFromStringForTesting(expPosStr)
require.NoError(t, err)
require.True(t, expPos.equals(&siblings[i]), "Error at sibling %v, got %v", expPosStr, siblings[i])
}
})
}
}
Expand Down Expand Up @@ -116,34 +189,39 @@ func TestPositionIsOnPathToKey(t *testing.T) {

}

func TestGetSiblingPositionsOnPathToKey(t *testing.T) {
func TestGetDeepestPositionAtLevelAndSiblingsOnPathToKey(t *testing.T) {

config1bit, config2bits, _ := getTreeCfgsWith1_2_3BitsPerIndexUnblinded(t)

tests := []struct {
c Config
lastLevel int
firstLevel int
k Key
expPosOnPath []string
}{
{config1bit, []byte{}, nil},
{config1bit, []byte{0xf0}, []string{"111110001", "11111001", "1111101", "111111", "11110", "1110", "110", "10"}},
{config2bits, []byte{0xf0}, []string{"111110001", "111110010", "111110011", "1111101", "1111110", "1111111", "11100", "11101", "11110", "100", "101", "110"}},
{config1bit, 8, 1, []byte{0xf0}, []string{"111110000", "111110001", "11111001", "1111101", "111111", "11110", "1110", "110", "10"}},
{config1bit, 2, 1, []byte{0xf0}, []string{"111", "110", "10"}},
{config1bit, 3, 1, []byte{0xf0}, []string{"1111", "1110", "110", "10"}},
{config1bit, 3, 2, []byte{0xf0}, []string{"1111", "1110", "110"}},
{config1bit, 4, 2, []byte{0xf0}, []string{"11111", "11110", "1110", "110"}},
{config1bit, 8, 1, []byte{0x00}, []string{"100000000", "100000001", "10000001", "1000001", "100001", "10001", "1001", "101", "11"}},
{config1bit, 2, 1, []byte{0x00}, []string{"100", "101", "11"}},
{config1bit, 3, 1, []byte{0x00}, []string{"1000", "1001", "101", "11"}},
{config1bit, 3, 2, []byte{0x00}, []string{"1000", "1001", "101"}},
{config1bit, 4, 2, []byte{0x00}, []string{"10000", "10001", "1001", "101"}},
{config1bit, 1, 1, []byte{0x00}, []string{"10", "11"}},
{config2bits, 4, 1, []byte{0xf1}, []string{"111110001", "111110000", "111110010", "111110011", "1111101", "1111110", "1111111", "11100", "11101", "11110", "100", "101", "110"}},
}

for _, test := range tests {
t.Run(fmt.Sprintf("%v bits: %v", test.c.bitsPerIndex, test.k), func(t *testing.T) {
posOnPath, err := test.c.getSiblingPositionsOnPathToKey(test.k)
if test.expPosOnPath == nil {
require.Error(t, err)
require.IsType(t, InvalidKeyError{}, err)
return
}
require.NoError(t, err)
posOnPath := test.c.getDeepestPositionAtLevelAndSiblingsOnPathToKey(test.k, test.lastLevel, test.firstLevel)
require.Equal(t, len(test.expPosOnPath), len(posOnPath))
for i, expPosStr := range test.expPosOnPath {
expPos, err := makePositionFromStringForTesting(expPosStr)
require.NoError(t, err)
require.True(t, expPos.equals(&posOnPath[i]), "Error at position %v", expPosStr)
require.True(t, expPos.equals(&posOnPath[i]), "Error at position %v, got %v", expPosStr, posOnPath[i])
}
})
}
Expand Down
Loading

0 comments on commit 525491d

Please sign in to comment.