Skip to content

Commit

Permalink
Allow custom key types and address formats (cosmos#4232)
Browse files Browse the repository at this point in the history
Add additional parameter to NewAnteHandler for custom SignatureVerificationGasConsumer (the existing one is now called DefaultSigVerificationGasConsumer).

Add addressVerifier field to sdk.Config which allows for custom address verification (to override the current fixed 20 byte address format).

DefaultSigVerificationGasConsumer now uses type switching as opposed to string comparison.
Other zones like Ethermint can now concretely specify which key types they accept.

Closes: cosmos#3685
  • Loading branch information
aaronc authored and Alessio Treglia committed May 2, 2019
1 parent 67f1e12 commit 114de63
Show file tree
Hide file tree
Showing 11 changed files with 154 additions and 40 deletions.
1 change: 1 addition & 0 deletions .pending/breaking/sdk/The-default-signatur
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
#3685 The default signature verification gas logic (`DefaultSigVerificationGasConsumer`) now specifies explicit key types rather than string pattern matching. This means that zones that depended on string matching to allow other keys will need to write a custom `SignatureVerificationGasConsumer` function.
1 change: 1 addition & 0 deletions .pending/improvements/sdk/Add-SetAddressVerifi
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
#3685 Add `SetAddressVerifier` and `GetAddressVerifier` to `sdk.Config` to allow SDK users to configure custom address format verification logic (to override the default limitation of 20-byte addresses).
1 change: 1 addition & 0 deletions .pending/improvements/sdk/Add-an-additional-pa
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
#3685 Add an additional parameter to NewAnteHandler for a custom `SignatureVerificationGasConsumer` (the default logic is now in `DefaultSigVerificationGasConsumer). This allows SDK users to configure their own logic for which key types are accepted and how those key types consume gas.
2 changes: 1 addition & 1 deletion cmd/gaia/app/app.go
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ func NewGaiaApp(logger log.Logger, db dbm.DB, traceStore io.Writer, loadLatest b
)
app.SetInitChainer(app.initChainer)
app.SetBeginBlocker(app.BeginBlocker)
app.SetAnteHandler(auth.NewAnteHandler(app.accountKeeper, app.feeCollectionKeeper))
app.SetAnteHandler(auth.NewAnteHandler(app.accountKeeper, app.feeCollectionKeeper, auth.DefaultSigVerificationGasConsumer))
app.SetEndBlocker(app.EndBlocker)

if loadLatest {
Expand Down
2 changes: 1 addition & 1 deletion cmd/gaia/cmd/gaiadebug/hack.go
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ func NewGaiaApp(logger log.Logger, db dbm.DB, baseAppOptions ...func(*bam.BaseAp
app.SetInitChainer(app.initChainer)
app.SetBeginBlocker(app.BeginBlocker)
app.SetEndBlocker(app.EndBlocker)
app.SetAnteHandler(auth.NewAnteHandler(app.accountKeeper, app.feeCollectionKeeper))
app.SetAnteHandler(auth.NewAnteHandler(app.accountKeeper, app.feeCollectionKeeper, auth.DefaultSigVerificationGasConsumer))
app.MountStores(app.keyMain, app.keyAccount, app.keyStaking, app.keySlashing, app.keyParams)
app.MountStore(app.tkeyParams, sdk.StoreTypeTransient)
err := app.LoadLatestVersion(app.keyMain)
Expand Down
30 changes: 24 additions & 6 deletions types/address.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,21 @@ func AccAddressFromHex(address string) (addr AccAddress, err error) {
return AccAddress(bz), nil
}

// VerifyAddressFormat verifies that the provided bytes form a valid address
// according to the default address rules or a custom address verifier set by
// GetConfig().SetAddressVerifier()
func VerifyAddressFormat(bz []byte) error {
verifier := GetConfig().GetAddressVerifier()
if verifier != nil {
return verifier(bz)
} else {
if len(bz) != AddrLen {
return errors.New("Incorrect address length")
}
}
return nil
}

// AccAddressFromBech32 creates an AccAddress from a Bech32 string.
func AccAddressFromBech32(address string) (addr AccAddress, err error) {
if len(strings.TrimSpace(address)) == 0 {
Expand All @@ -99,8 +114,9 @@ func AccAddressFromBech32(address string) (addr AccAddress, err error) {
return nil, err
}

if len(bz) != AddrLen {
return nil, errors.New("Incorrect address length")
err = VerifyAddressFormat(bz)
if err != nil {
return nil, err
}

return AccAddress(bz), nil
Expand Down Expand Up @@ -229,8 +245,9 @@ func ValAddressFromBech32(address string) (addr ValAddress, err error) {
return nil, err
}

if len(bz) != AddrLen {
return nil, errors.New("Incorrect address length")
err = VerifyAddressFormat(bz)
if err != nil {
return nil, err
}

return ValAddress(bz), nil
Expand Down Expand Up @@ -360,8 +377,9 @@ func ConsAddressFromBech32(address string) (addr ConsAddress, err error) {
return nil, err
}

if len(bz) != AddrLen {
return nil, errors.New("Incorrect address length")
err = VerifyAddressFormat(bz)
if err != nil {
return nil, err
}

return ConsAddress(bz), nil
Expand Down
37 changes: 37 additions & 0 deletions types/address_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package types_test

import (
"encoding/hex"
"fmt"
"math/rand"
"testing"

Expand Down Expand Up @@ -290,3 +291,39 @@ func TestAddressInterface(t *testing.T) {
}

}

func TestCustomAddressVerifier(t *testing.T) {
// Create a 10 byte address
addr := []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9}
accBech := types.AccAddress(addr).String()
valBech := types.ValAddress(addr).String()
consBech := types.ConsAddress(addr).String()
// Verifiy that the default logic rejects this 10 byte address
err := types.VerifyAddressFormat(addr)
require.NotNil(t, err)
_, err = types.AccAddressFromBech32(accBech)
require.NotNil(t, err)
_, err = types.ValAddressFromBech32(valBech)
require.NotNil(t, err)
_, err = types.ConsAddressFromBech32(consBech)
require.NotNil(t, err)

// Set a custom address verifier that accepts 10 or 20 byte addresses
types.GetConfig().SetAddressVerifier(func(bz []byte) error {
n := len(bz)
if n == 10 || n == types.AddrLen {
return nil
}
return fmt.Errorf("incorrect address length %d", n)
})

// Verifiy that the custom logic accepts this 10 byte address
err = types.VerifyAddressFormat(addr)
require.Nil(t, err)
_, err = types.AccAddressFromBech32(accBech)
require.Nil(t, err)
_, err = types.ValAddressFromBech32(valBech)
require.Nil(t, err)
_, err = types.ConsAddressFromBech32(consBech)
require.Nil(t, err)
}
13 changes: 13 additions & 0 deletions types/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ type Config struct {
sealed bool
bech32AddressPrefix map[string]string
txEncoder TxEncoder
addressVerifier func([]byte) error
}

var (
Expand Down Expand Up @@ -73,6 +74,13 @@ func (config *Config) SetTxEncoder(encoder TxEncoder) {
config.txEncoder = encoder
}

// SetAddressVerifier builds the Config with the provided function for verifying that addresses
// have the correct format
func (config *Config) SetAddressVerifier(addressVerifier func([]byte) error) {
config.assertNotSealed()
config.addressVerifier = addressVerifier
}

// Seal seals the config such that the config state could not be modified further
func (config *Config) Seal() *Config {
config.mtx.Lock()
Expand Down Expand Up @@ -116,3 +124,8 @@ func (config *Config) GetBech32ConsensusPubPrefix() string {
func (config *Config) GetTxEncoder() TxEncoder {
return config.txEncoder
}

// GetAddressVerifier returns the function to verify that addresses have the correct format
func (config *Config) GetAddressVerifier() func([]byte) error {
return config.addressVerifier
}
39 changes: 19 additions & 20 deletions x/auth/ante.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import (
"bytes"
"encoding/hex"
"fmt"
"strings"
"github.com/tendermint/tendermint/crypto/ed25519"
"time"

"github.com/tendermint/tendermint/crypto"
Expand All @@ -27,10 +27,14 @@ func init() {
copy(simSecp256k1Pubkey[:], bz)
}

// SignatureVerificationGasConsumer is the type of function that is used to both consume gas when verifying signatures
// and also to accept or reject different types of PubKey's. This is where apps can define their own PubKey types.
type SignatureVerificationGasConsumer = func(meter sdk.GasMeter, sig []byte, pubkey crypto.PubKey, params Params) sdk.Result

// NewAnteHandler returns an AnteHandler that checks and increments sequence
// numbers, checks signatures & account numbers, and deducts fees from the first
// signer.
func NewAnteHandler(ak AccountKeeper, fck FeeCollectionKeeper) sdk.AnteHandler {
func NewAnteHandler(ak AccountKeeper, fck FeeCollectionKeeper, sigGasConsumer SignatureVerificationGasConsumer) sdk.AnteHandler {
return func(
ctx sdk.Context, tx sdk.Tx, simulate bool,
) (newCtx sdk.Context, res sdk.Result, abort bool) {
Expand Down Expand Up @@ -127,7 +131,7 @@ func NewAnteHandler(ak AccountKeeper, fck FeeCollectionKeeper) sdk.AnteHandler {

// check signature, return account with incremented nonce
signBytes := GetSignBytes(newCtx.ChainID(), stdTx, signerAccs[i], isGenesis)
signerAccs[i], res = processSig(newCtx, signerAccs[i], stdSigs[i], signBytes, simulate, params)
signerAccs[i], res = processSig(newCtx, signerAccs[i], stdSigs[i], signBytes, simulate, params, sigGasConsumer)
if !res.IsOK() {
return newCtx, res, true
}
Expand Down Expand Up @@ -168,6 +172,7 @@ func ValidateMemo(stdTx StdTx, params Params) sdk.Result {
// a pubkey, set it.
func processSig(
ctx sdk.Context, acc Account, sig StdSignature, signBytes []byte, simulate bool, params Params,
sigGasConsumer SignatureVerificationGasConsumer,
) (updatedAcc Account, res sdk.Result) {

pubKey, res := ProcessPubKey(acc, sig, simulate)
Expand All @@ -188,7 +193,7 @@ func processSig(
consumeSimSigGas(ctx.GasMeter(), pubKey, sig, params)
}

if res := consumeSigVerificationGas(ctx.GasMeter(), sig.Signature, pubKey, params); !res.IsOK() {
if res := sigGasConsumer(ctx.GasMeter(), sig.Signature, pubKey, params); !res.IsOK() {
return nil, res
}

Expand Down Expand Up @@ -254,36 +259,30 @@ func ProcessPubKey(acc Account, sig StdSignature, simulate bool) (crypto.PubKey,
return pubKey, sdk.Result{}
}

// consumeSigVerificationGas consumes gas for signature verification based upon
// the public key type. The cost is fetched from the given params and is matched
// DefaultSigVerificationGasConsumer is the default implementation of SignatureVerificationGasConsumer. It consumes gas
// for signature verification based upon the public key type. The cost is fetched from the given params and is matched
// by the concrete type.
//
// TODO: Design a cleaner and flexible way to match concrete public key types.
func consumeSigVerificationGas(
func DefaultSigVerificationGasConsumer(
meter sdk.GasMeter, sig []byte, pubkey crypto.PubKey, params Params,
) sdk.Result {

pubkeyType := strings.ToLower(fmt.Sprintf("%T", pubkey))

switch {
case strings.Contains(pubkeyType, "ed25519"):
switch pubkey := pubkey.(type) {
case ed25519.PubKeyEd25519:
meter.ConsumeGas(params.SigVerifyCostED25519, "ante verify: ed25519")
return sdk.ErrInvalidPubKey("ED25519 public keys are unsupported").Result()

case strings.Contains(pubkeyType, "secp256k1"):
case secp256k1.PubKeySecp256k1:
meter.ConsumeGas(params.SigVerifyCostSecp256k1, "ante verify: secp256k1")
return sdk.Result{}

case strings.Contains(pubkeyType, "multisigthreshold"):
case multisig.PubKeyMultisigThreshold:
var multisignature multisig.Multisignature
codec.Cdc.MustUnmarshalBinaryBare(sig, &multisignature)

multisigPubKey := pubkey.(multisig.PubKeyMultisigThreshold)
consumeMultisignatureVerificationGas(meter, multisignature, multisigPubKey, params)
consumeMultisignatureVerificationGas(meter, multisignature, pubkey, params)
return sdk.Result{}

default:
return sdk.ErrInvalidPubKey(fmt.Sprintf("unrecognized public key type: %s", pubkeyType)).Result()
return sdk.ErrInvalidPubKey(fmt.Sprintf("unrecognized public key type: %T", pubkey)).Result()
}
}

Expand All @@ -295,7 +294,7 @@ func consumeMultisignatureVerificationGas(meter sdk.GasMeter,
sigIndex := 0
for i := 0; i < size; i++ {
if sig.BitArray.GetIndex(i) {
consumeSigVerificationGas(meter, sig.Sigs[sigIndex], pubkey.PubKeys[i], params)
DefaultSigVerificationGasConsumer(meter, sig.Sigs[sigIndex], pubkey.PubKeys[i], params)
sigIndex++
}
}
Expand Down
Loading

0 comments on commit 114de63

Please sign in to comment.