Skip to content

Commit

Permalink
Improve merkle proof locking (ava-labs#2761)
Browse files Browse the repository at this point in the history
  • Loading branch information
dboehm-avalabs authored Mar 20, 2023
1 parent aec596e commit 98790f4
Show file tree
Hide file tree
Showing 4 changed files with 68 additions and 76 deletions.
71 changes: 30 additions & 41 deletions x/merkledb/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,7 @@ func (db *Database) GetValues(ctx context.Context, keys [][]byte) ([][]byte, []e
errors := make([]error, len(keys))
for i, key := range keys {
path := newPath(key)
values[i], errors[i] = db.getValue(path)
values[i], errors[i] = db.getValue(path, false)
}
return values, errors
}
Expand All @@ -330,14 +330,15 @@ func (db *Database) GetValue(ctx context.Context, key []byte) ([]byte, error) {
_, span := db.tracer.Start(ctx, "MerkleDB.GetValue")
defer span.End()

db.lock.RLock()
defer db.lock.RUnlock()

return db.getValue(newPath(key))
return db.getValue(newPath(key), true)
}

// Assumes [db.lock] is read locked.
func (db *Database) getValue(key path) ([]byte, error) {
func (db *Database) getValue(key path, lock bool) ([]byte, error) {
if lock {
db.lock.RLock()
defer db.lock.RUnlock()
}

if db.closed {
return nil, database.ErrClosed
}
Expand All @@ -352,17 +353,6 @@ func (db *Database) getValue(key path) ([]byte, error) {
return clonedVal.value, nil
}

// Returns a view of the trie as it was when the merkle root was [rootID].
func (db *Database) GetHistoricalView(ctx context.Context, rootID ids.ID) (ReadOnlyTrie, error) {
_, span := db.tracer.Start(ctx, "MerkleDB.GetHistoricalView")
defer span.End()

db.lock.RLock()
defer db.lock.RUnlock()

return db.getHistoricalViewForRangeProof(rootID, nil, nil)
}

// Returns the ID of the root node of the merkle trie.
func (db *Database) GetMerkleRoot(ctx context.Context) (ids.ID, error) {
_, span := db.tracer.Start(ctx, "MerkleDB.GetMerkleRoot")
Expand All @@ -382,14 +372,14 @@ func (db *Database) getMerkleRoot() ids.ID {

// Returns a proof of the existence/non-existence of [key] in this trie.
func (db *Database) GetProof(ctx context.Context, key []byte) (*Proof, error) {
db.lock.RLock()
defer db.lock.RUnlock()
db.commitLock.RLock()
defer db.commitLock.RUnlock()

return db.getProof(ctx, key)
}

// Returns a proof of the existence/non-existence of [key] in this trie.
// Assumes [db.lock] is read locked.
// Assumes [db.commitLock] is read locked.
func (db *Database) getProof(ctx context.Context, key []byte) (*Proof, error) {
view, err := db.newUntrackedView(defaultPreallocationSize)
if err != nil {
Expand All @@ -407,8 +397,8 @@ func (db *Database) GetRangeProof(
end []byte,
maxLength int,
) (*RangeProof, error) {
db.lock.RLock()
defer db.lock.RUnlock()
db.commitLock.RLock()
defer db.commitLock.RUnlock()

return db.getRangeProofAtRoot(ctx, db.getMerkleRoot(), start, end, maxLength)
}
Expand All @@ -422,13 +412,13 @@ func (db *Database) GetRangeProofAtRoot(
end []byte,
maxLength int,
) (*RangeProof, error) {
db.lock.RLock()
defer db.lock.RUnlock()
db.commitLock.RLock()
defer db.commitLock.RUnlock()

return db.getRangeProofAtRoot(ctx, rootID, start, end, maxLength)
}

// Assumes [db.lock] is read locked.
// Assumes [db.commitLock] is read locked.
func (db *Database) getRangeProofAtRoot(
ctx context.Context,
rootID ids.ID,
Expand All @@ -440,11 +430,11 @@ func (db *Database) getRangeProofAtRoot(
return nil, fmt.Errorf("%w but was %d", ErrInvalidMaxLength, maxLength)
}

historicalView, err := db.getHistoricalViewForRangeProof(rootID, start, end)
historicalView, err := db.getHistoricalViewForRange(rootID, start, end)
if err != nil {
return nil, err
}
return historicalView.getRangeProof(ctx, start, end, maxLength)
return historicalView.GetRangeProof(ctx, start, end, maxLength)
}

// Returns a proof for a subset of the key/value changes in key range
Expand All @@ -465,8 +455,8 @@ func (db *Database) GetChangeProof(
return nil, errSameRoot
}

db.lock.RLock()
defer db.lock.RUnlock()
db.commitLock.RLock()
defer db.commitLock.RUnlock()

result := &ChangeProof{
HadRootsInHistory: true,
Expand Down Expand Up @@ -507,9 +497,9 @@ func (db *Database) GetChangeProof(
}
largestKey := result.getLargestKey(end)

// Since we hold [db.lock] we must still have sufficient
// Since we hold [db.commitlock] we must still have sufficient
// history to recreate the trie at [endRootID].
historicalView, err := db.getHistoricalViewForRangeProof(endRootID, start, largestKey)
historicalView, err := db.getHistoricalViewForRange(endRootID, start, largestKey)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -588,7 +578,7 @@ func (db *Database) Has(k []byte) (bool, error) {
return false, database.ErrClosed
}

_, err := db.getValue(newPath(k))
_, err := db.getValue(newPath(k), true)
if err == database.ErrNotFound {
return false, nil
}
Expand Down Expand Up @@ -901,8 +891,8 @@ func (db *Database) initializeRootIfNeeded() (ids.ID, error) {
}

// Returns a view of the trie as it was when it had root [rootID] for keys within range [start, end].
// Assumes [db.lock] is read locked.
func (db *Database) getHistoricalViewForRangeProof(
// Assumes [db.commitLock] is read locked.
func (db *Database) getHistoricalViewForRange(
rootID ids.ID,
start []byte,
end []byte,
Expand All @@ -924,8 +914,10 @@ func (db *Database) getHistoricalViewForRangeProof(
// Returns all of the keys in range [start, end] that aren't in [keySet].
// If [start] is nil, then the range has no lower bound.
// If [end] is nil, then the range has no upper bound.
// Assumes [db.lock] is read locked.
func (db *Database) getKeysNotInSet(start, end []byte, keySet set.Set[string]) ([][]byte, error) {
db.lock.RLock()
defer db.lock.RUnlock()

it := db.NewIteratorWithStart(start)
defer it.Release()

Expand Down Expand Up @@ -1093,11 +1085,13 @@ func (db *Database) prepareChangeProofView(proof *ChangeProof) (*trieView, error

// Returns a new view atop [db] with the key/value pairs in [proof.KeyValues] added and
// any existing key-value pairs in the proof's range but not in the proof removed.
// assumes [db.commitLock] is held
func (db *Database) prepareRangeProofView(start []byte, proof *RangeProof) (*trieView, error) {
// Don't need to lock [view] because nobody else has a reference to it.
db.lock.RLock()
view, err := db.newUntrackedView(len(proof.KeyValues))
db.lock.RUnlock()

if err != nil {
return nil, err
}
Expand All @@ -1113,7 +1107,6 @@ func (db *Database) prepareRangeProofView(start []byte, proof *RangeProof) (*tri
if len(proof.KeyValues) > 0 {
largestKey = proof.KeyValues[len(proof.KeyValues)-1].Key
}

keysToDelete, err := db.getKeysNotInSet(start, largestKey, keys)
if err != nil {
return nil, err
Expand All @@ -1126,10 +1119,6 @@ func (db *Database) prepareRangeProofView(start []byte, proof *RangeProof) (*tri
return view, nil
}

// Assumes [db.lock] is read locked.
// This is required because putting a node in [db.nodeCache] can cause an eviction,
// which puts a node in [db.nodeDB], and we don't want to put anything in [db.nodeDB]
// after [db] is closed.
// Non-nil error is fatal -- [db] will close.
func (db *Database) putNodeInCache(key path, n *node) error {
// TODO Cache metrics
Expand Down
2 changes: 1 addition & 1 deletion x/merkledb/node.go
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ func (n *node) asProofNode() ProofNode {
pn := ProofNode{
KeyPath: n.key.Serialize(),
Children: make(map[byte]ids.ID, len(n.children)),
ValueOrHash: n.valueDigest,
ValueOrHash: Clone(n.valueDigest),
}
for index, entry := range n.children {
pn.Children[index] = entry.id
Expand Down
2 changes: 1 addition & 1 deletion x/merkledb/trie.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ type ReadOnlyTrie interface {

// get the value associated with the key in path form
// database.ErrNotFound if the key is not present
getValue(key path) ([]byte, error)
getValue(key path, lock bool) ([]byte, error)

// GetMerkleRoot returns the merkle root of the Trie
GetMerkleRoot(ctx context.Context) (ids.ID, error)
Expand Down
69 changes: 36 additions & 33 deletions x/merkledb/trieview.go
Original file line number Diff line number Diff line change
Expand Up @@ -315,12 +315,21 @@ func (t *trieView) GetProof(ctx context.Context, key []byte) (*Proof, error) {
_, span := t.db.tracer.Start(ctx, "MerkleDB.trieview.GetProof")
defer span.End()

t.lock.Lock()
defer t.lock.Unlock()
t.lock.RLock()
defer t.lock.RUnlock()

if err := t.calculateNodeIDs(ctx); err != nil {
return nil, err
// only need full lock if nodes ids need to be calculated
// looped to ensure that the value didn't change after the lock was released
for t.needsRecalculation {
t.lock.RUnlock()
t.lock.Lock()
if err := t.calculateNodeIDs(ctx); err != nil {
return nil, err
}
t.lock.Unlock()
t.lock.RLock()
}

return t.getProof(ctx, key)
}

Expand Down Expand Up @@ -388,24 +397,6 @@ func (t *trieView) GetRangeProof(
ctx, span := t.db.tracer.Start(ctx, "MerkleDB.trieview.GetRangeProof")
defer span.End()

t.lock.Lock()
defer t.lock.Unlock()

return t.getRangeProof(ctx, start, end, maxLength)
}

// Returns a range proof for (at least part of) the key range [start, end].
// The returned proof's [KeyValues] has at most [maxLength] values.
// [maxLength] must be > 0.
// Assumes [t.lock] is held.
func (t *trieView) getRangeProof(
ctx context.Context,
start, end []byte,
maxLength int,
) (*RangeProof, error) {
ctx, span := t.db.tracer.Start(ctx, "MerkleDB.trieview.getRangeProof")
defer span.End()

if len(end) > 0 && bytes.Compare(start, end) == 1 {
return nil, ErrStartAfterEnd
}
Expand All @@ -414,8 +405,19 @@ func (t *trieView) getRangeProof(
return nil, fmt.Errorf("%w but was %d", ErrInvalidMaxLength, maxLength)
}

if err := t.calculateNodeIDs(ctx); err != nil {
return nil, err
t.lock.RLock()
defer t.lock.RUnlock()

// only need full lock if nodes ids need to be calculated
// looped to ensure that the value didn't change after the lock was released
for t.needsRecalculation {
t.lock.RUnlock()
t.lock.Lock()
if err := t.calculateNodeIDs(ctx); err != nil {
return nil, err
}
t.lock.Unlock()
t.lock.RLock()
}

var (
Expand Down Expand Up @@ -857,22 +859,23 @@ func (t *trieView) GetValues(_ context.Context, keys [][]byte) ([][]byte, []erro
valueErrors := make([]error, len(keys))

for i, key := range keys {
results[i], valueErrors[i] = t.getValue(newPath(key))
results[i], valueErrors[i] = t.getValue(newPath(key), false)
}
return results, valueErrors
}

// GetValue returns the value for the given [key].
// Returns database.ErrNotFound if it doesn't exist.
func (t *trieView) GetValue(_ context.Context, key []byte) ([]byte, error) {
t.lock.RLock()
defer t.lock.RUnlock()

return t.getValue(newPath(key))
return t.getValue(newPath(key), true)
}

// Assumes [t.lock] read lock is held
func (t *trieView) getValue(key path) ([]byte, error) {
func (t *trieView) getValue(key path, lock bool) ([]byte, error) {
if lock {
t.lock.RLock()
defer t.lock.RUnlock()
}

if t.isInvalid() {
return nil, ErrInvalid
}
Expand All @@ -887,7 +890,7 @@ func (t *trieView) getValue(key path) ([]byte, error) {
t.db.metrics.ViewValueCacheMiss()

// if we don't have local copy of the key, then grab a copy from the parent trie
value, err := t.getParentTrie().getValue(key)
value, err := t.getParentTrie().getValue(key, true)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -1289,7 +1292,7 @@ func (t *trieView) recordValueChange(key path, value Maybe[[]byte]) error {

// grab the before value
var beforeMaybe Maybe[[]byte]
before, err := t.getParentTrie().getValue(key)
before, err := t.getParentTrie().getValue(key, true)
switch err {
case nil:
beforeMaybe = Some(before)
Expand Down

0 comments on commit 98790f4

Please sign in to comment.