Skip to content

Commit

Permalink
Use a flags field on address a mark used
Browse files Browse the repository at this point in the history
  • Loading branch information
tuxcanfly committed May 25, 2015
1 parent fbf744b commit 0b5290b
Show file tree
Hide file tree
Showing 5 changed files with 157 additions and 152 deletions.
10 changes: 5 additions & 5 deletions waddrmgr/address.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ type ManagedAddress interface {
Compressed() bool

// Used returns true if the backing address has been used in a transaction.
Used() (bool, error)
Used() bool
}

// ManagedPubKeyAddress extends ManagedAddress and additionally provides the
Expand Down Expand Up @@ -191,8 +191,8 @@ func (a *managedAddress) Compressed() bool {
// Used returns true if the address has been used in a transaction.
//
// This is part of the ManagedAddress interface implementation.
func (a *managedAddress) Used() (bool, error) {
return a.manager.fetchUsed(a.AddrHash())
func (a *managedAddress) Used() bool {
return a.used
}

// PubKey returns the public key associated with the address.
Expand Down Expand Up @@ -456,8 +456,8 @@ func (a *scriptAddress) Compressed() bool {
// Used returns true if the address has been used in a transaction.
//
// This is part of the ManagedAddress interface implementation.
func (a *scriptAddress) Used() (bool, error) {
return a.manager.fetchUsed(a.AddrHash())
func (a *scriptAddress) Used() bool {
return a.used
}

// Script returns the script associated with the address.
Expand Down
148 changes: 104 additions & 44 deletions waddrmgr/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ import (

const (
// LatestMgrVersion is the most recent manager version.
LatestMgrVersion = 4
LatestMgrVersion = 5
)

var (
Expand Down Expand Up @@ -62,6 +62,15 @@ func maybeConvertDbError(err error) error {
// database.
type syncStatus uint8

// addressFlags holds flags associated with a address stored in the database.
type addressFlags uint32

// These constants define the various supported address flags.
const (
addrNone addressFlags = 0
addrUsed addressFlags = addressFlags(byte(1) << 0)
)

// These constants define the various supported sync status types.
//
// NOTE: These are currently unused but are being defined for the possibility of
Expand Down Expand Up @@ -114,6 +123,7 @@ type dbAddressRow struct {
account uint32
addTime uint64
syncStatus syncStatus
addrFlags addressFlags
rawData []byte // Varies based on address type field.
}

Expand Down Expand Up @@ -267,6 +277,35 @@ func putManagerVersion(tx walletdb.Tx, version uint32) error {
return nil
}

// markAddressUsed flags the provided address id as used in the database.
func markAddressUsed(tx walletdb.Tx, addressID []byte) error {
bucket := tx.RootBucket().Bucket(addrBucketName)

addrHash := fastsha256.Sum256(addressID)
val := bucket.Get(addrHash[:])
if len(val) < 22 {
str := "malformed serialized address"
return managerError(ErrDatabase, str, nil)
}
// Check existing flag and return early if already flagged
addrFlags := addressFlags(binary.LittleEndian.Uint32(val[14:18]))
if addrFlags&addrUsed == 0x1 {
return nil
}
// Update flag field with used flag
addrFlags |= addrUsed
row := make([]byte, len(val))
copy(row[:14], val[:14])
binary.LittleEndian.PutUint32(row[14:18], uint32(addrFlags))
copy(row[18:], val[18:])
err := bucket.Put(addrHash[:], row)
if err != nil {
str := fmt.Sprintf("failed to mark address used %x", addressID)
return managerError(ErrDatabase, str, err)
}
return nil
}

// fetchMasterKeyParams loads the master key parameters needed to derive them
// (when given the correct user-supplied passphrase) from the database. Either
// returned value can be nil, but in practice only the private key params will
Expand Down Expand Up @@ -799,14 +838,14 @@ func putLastAccount(tx walletdb.Tx, account uint32) error {
// the common parts.
func deserializeAddressRow(serializedAddress []byte) (*dbAddressRow, error) {
// The serialized address format is:
// <addrType><account><addedTime><syncStatus><rawdata>
// <addrType><account><addedTime><syncStatus><addrFlags><rawdata>
//
// 1 byte addrType + 4 bytes account + 8 bytes addTime + 1 byte
// syncStatus + 4 bytes raw data length + raw data
// syncStatus + 4 bytes flags + 4 bytes raw data length + raw data

// Given the above, the length of the entry must be at a minimum
// the constant value sizes.
if len(serializedAddress) < 18 {
if len(serializedAddress) < 22 {
str := "malformed serialized address"
return nil, managerError(ErrDatabase, str, nil)
}
Expand All @@ -816,29 +855,30 @@ func deserializeAddressRow(serializedAddress []byte) (*dbAddressRow, error) {
row.account = binary.LittleEndian.Uint32(serializedAddress[1:5])
row.addTime = binary.LittleEndian.Uint64(serializedAddress[5:13])
row.syncStatus = syncStatus(serializedAddress[13])
rdlen := binary.LittleEndian.Uint32(serializedAddress[14:18])
row.addrFlags = addressFlags(binary.LittleEndian.Uint32(serializedAddress[14:18]))
rdlen := binary.LittleEndian.Uint32(serializedAddress[18:22])
row.rawData = make([]byte, rdlen)
copy(row.rawData, serializedAddress[18:18+rdlen])
copy(row.rawData, serializedAddress[22:22+rdlen])

return &row, nil
}

// serializeAddressRow returns the serialization of the passed address row.
func serializeAddressRow(row *dbAddressRow) []byte {
// The serialized address format is:
// <addrType><account><addedTime><syncStatus><commentlen><comment>
// <rawdata>
// <addrType><account><addedTime><syncStatus><addrFlags><rawdata>
//
// 1 byte addrType + 4 bytes account + 8 bytes addTime + 1 byte
// syncStatus + 4 bytes raw data length + raw data
// syncStatus + 4 bytes flags + 4 bytes raw data length + raw data
rdlen := len(row.rawData)
buf := make([]byte, 18+rdlen)
buf := make([]byte, 22+rdlen)
buf[0] = byte(row.addrType)
binary.LittleEndian.PutUint32(buf[1:5], row.account)
binary.LittleEndian.PutUint64(buf[5:13], row.addTime)
buf[13] = byte(row.syncStatus)
binary.LittleEndian.PutUint32(buf[14:18], uint32(rdlen))
copy(buf[18:18+rdlen], row.rawData)
binary.LittleEndian.PutUint32(buf[14:18], uint32(row.addrFlags))
binary.LittleEndian.PutUint32(buf[18:22], uint32(rdlen))
copy(buf[22:22+rdlen], row.rawData)
return buf
}

Expand Down Expand Up @@ -1014,31 +1054,6 @@ func fetchAddressByHash(tx walletdb.Tx, addrHash []byte) (interface{}, error) {
return nil, managerError(ErrDatabase, str, nil)
}

// fetchAddressUsed returns true if the provided address id was flagged as used.
func fetchAddressUsed(tx walletdb.Tx, addressID []byte) bool {
bucket := tx.RootBucket().Bucket(usedAddrBucketName)

addrHash := fastsha256.Sum256(addressID)
return bucket.Get(addrHash[:]) != nil
}

// markAddressUsed flags the provided address id as used in the database.
func markAddressUsed(tx walletdb.Tx, addressID []byte) error {
bucket := tx.RootBucket().Bucket(usedAddrBucketName)

addrHash := fastsha256.Sum256(addressID)
val := bucket.Get(addrHash[:])
if val != nil {
return nil
}
err := bucket.Put(addrHash[:], []byte{0})
if err != nil {
str := fmt.Sprintf("failed to mark address used %x", addressID)
return managerError(ErrDatabase, str, err)
}
return nil
}

// fetchAddress loads address information for the provided address id from the
// database. The returned value is one of the address rows for the specific
// address type. The caller should use type assertions to ascertain the type.
Expand Down Expand Up @@ -1568,13 +1583,6 @@ func createManagerNS(namespace walletdb.Namespace) error {
return managerError(ErrDatabase, str, err)
}

// usedAddrBucketName bucket was added after manager version 1 release
_, err = rootBucket.CreateBucket(usedAddrBucketName)
if err != nil {
str := "failed to create used addresses bucket"
return managerError(ErrDatabase, str, err)
}

_, err = rootBucket.CreateBucket(acctNameIdxBucketName)
if err != nil {
str := "failed to create an account name index bucket"
Expand Down Expand Up @@ -1732,6 +1740,16 @@ func upgradeManager(namespace walletdb.Namespace, pubPassPhrase []byte, chainPar
version = 4
}

if version < 5 {
// Upgrade from version 4 to 5.
if err := upgradeToVersion5(namespace); err != nil {
return err
}

// The manager is now at version 5.
version = 5
}

// Ensure the manager is upraded to the latest version. This check is
// to intentionally cause a failure if the manager version is updated
// without writing code to handle the upgrade.
Expand Down Expand Up @@ -1956,3 +1974,45 @@ func upgradeToVersion4(namespace walletdb.Namespace, pubPassPhrase []byte) error
}
return nil
}

// upgradeToVersion5 upgrades the database from version 4 to version 5.
// Instead of a bucket storing keys of used addrs, a flag field addrFlags was
// added to the address row in the address bucket.
func upgradeToVersion5(namespace walletdb.Namespace) error {
err := namespace.Update(func(tx walletdb.Tx) error {
// Write new manager version.
err := putManagerVersion(tx, 5)
if err != nil {
return err
}

addrBucket := tx.RootBucket().Bucket(addrBucketName)
usedAddrBucket := tx.RootBucket().Bucket(usedAddrBucketName)

// Iterate addresses and check for used, then update the flags field
err = addrBucket.ForEach(func(k, v []byte) error {
if len(v) < 18 {
str := "malformed serialized address"
return managerError(ErrDatabase, str, nil)
}
row := make([]byte, len(v)+4)
copy(row[:14], v[:14])
if usedAddrBucket.Get(k) != nil {
binary.LittleEndian.PutUint32(row[14:18], uint32(addrUsed))
} else {
binary.LittleEndian.PutUint32(row[14:18], uint32(addrNone))
}
copy(row[18:], v[14:])
return addrBucket.Put(k, row)
})
if err != nil {
return err
}
// Delete old used addr bucket
return tx.RootBucket().DeleteBucket(usedAddrBucketName)
})
if err != nil {
return maybeConvertDbError(err)
}
return nil
}
19 changes: 5 additions & 14 deletions waddrmgr/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -384,7 +384,7 @@ func (m *Manager) Close() error {
// The passed derivedKey is zeroed after the new address is created.
//
// This function MUST be called with the manager lock held for writes.
func (m *Manager) keyToManaged(derivedKey *hdkeychain.ExtendedKey, account, branch, index uint32) (ManagedAddress, error) {
func (m *Manager) keyToManaged(derivedKey *hdkeychain.ExtendedKey, account, branch, index uint32, flags addressFlags) (ManagedAddress, error) {
// Create a new managed address based on the public or private key
// depending on whether the passed key is private. Also, zero the
// key after creating the managed address from it.
Expand All @@ -407,6 +407,7 @@ func (m *Manager) keyToManaged(derivedKey *hdkeychain.ExtendedKey, account, bran
if branch == internalBranch {
ma.internal = true
}
ma.used = (flags&addrUsed == 0x1)

return ma, nil
}
Expand Down Expand Up @@ -521,7 +522,7 @@ func (m *Manager) loadAccountInfo(account uint32) (*accountInfo, error) {
if err != nil {
return nil, err
}
lastExtAddr, err := m.keyToManaged(lastExtKey, account, branch, index)
lastExtAddr, err := m.keyToManaged(lastExtKey, account, branch, index, addrNone)
if err != nil {
return nil, err
}
Expand All @@ -536,7 +537,7 @@ func (m *Manager) loadAccountInfo(account uint32) (*accountInfo, error) {
if err != nil {
return nil, err
}
lastIntAddr, err := m.keyToManaged(lastIntKey, account, branch, index)
lastIntAddr, err := m.keyToManaged(lastIntKey, account, branch, index, addrNone)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -572,7 +573,7 @@ func (m *Manager) chainAddressRowToManaged(row *dbChainAddressRow) (ManagedAddre
return nil, err
}

return m.keyToManaged(addressKey, row.account, row.branch, row.index)
return m.keyToManaged(addressKey, row.account, row.branch, row.index, row.addrFlags)
}

// importedAddressRowToManaged returns a new managed address based on imported
Expand Down Expand Up @@ -1371,16 +1372,6 @@ func (m *Manager) Unlock(passphrase []byte) error {
return nil
}

// fetchUsed returns true if the provided address id was flagged used.
func (m *Manager) fetchUsed(addressID []byte) (bool, error) {
var used bool
err := m.namespace.View(func(tx walletdb.Tx) error {
used = fetchAddressUsed(tx, addressID)
return nil
})
return used, err
}

// MarkUsed updates the used flag for the provided address.
func (m *Manager) MarkUsed(address btcutil.Address) error {
addressID := address.ScriptAddress()
Expand Down
Loading

0 comments on commit 0b5290b

Please sign in to comment.