Skip to content

Commit

Permalink
Add fee payer to protobuf definition (cosmos#7384)
Browse files Browse the repository at this point in the history
* Add fee payer to protobuf definition

* Compile new tx type

* Use FeePayer from Tx, add it to required signers

* Add unit tests on proper handling of FeePayer field

* Use string address for fee payer field

* Update logic for string feePayer
  • Loading branch information
ethanfrey authored Sep 30, 2020
1 parent ddaa3c5 commit d917520
Show file tree
Hide file tree
Showing 5 changed files with 202 additions and 53 deletions.
5 changes: 5 additions & 0 deletions proto/cosmos/tx/v1beta1/tx.proto
Original file line number Diff line number Diff line change
Expand Up @@ -169,4 +169,9 @@ message Fee {
// gas_limit is the maximum gas that can be used in transaction processing
// before an out of gas error occurs
uint64 gas_limit = 2;

// if unset, the first signer is responsible for paying the fees. If set, the specified account must pay the fees.
// the payer must be a tx signer (and thus have signed this field in AuthInfo).
// setting this field does *not* change the ordering of required signers for the transaction.
string payer = 3;
}
160 changes: 107 additions & 53 deletions types/tx/tx.pb.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

20 changes: 20 additions & 0 deletions types/tx/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,13 @@ func (t *Tx) ValidateBasic() error {
)
}

if fee.Payer != "" {
_, err := sdk.AccAddressFromBech32(fee.Payer)
if err != nil {
return sdkerrors.Wrapf(sdkerrors.ErrInvalidAddress, "Invalid fee payer address (%s)", err)
}
}

sigs := t.Signatures

if len(sigs) == 0 {
Expand All @@ -84,6 +91,8 @@ func (t *Tx) ValidateBasic() error {
}

// GetSigners retrieves all the signers of a tx.
// This includes all unique signers of the messages (in order),
// as well as the FeePayer (if specified and not already included).
func (t *Tx) GetSigners() []sdk.AccAddress {
var signers []sdk.AccAddress
seen := map[string]bool{}
Expand All @@ -97,6 +106,17 @@ func (t *Tx) GetSigners() []sdk.AccAddress {
}
}

// ensure any specified fee payer is included in the required signers (at the end)
feePayer := t.AuthInfo.Fee.Payer
if feePayer != "" && !seen[feePayer] {
payerAddr, err := sdk.AccAddressFromBech32(feePayer)
if err != nil {
panic(err)
}
signers = append(signers, payerAddr)
seen[feePayer] = true
}

return signers
}

Expand Down
20 changes: 20 additions & 0 deletions x/auth/tx/builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,15 @@ func (w *wrapper) GetFee() sdk.Coins {
}

func (w *wrapper) FeePayer() sdk.AccAddress {
feePayer := w.tx.AuthInfo.Fee.Payer
if feePayer != "" {
payerAddr, err := sdk.AccAddressFromBech32(feePayer)
if err != nil {
panic(err)
}
return payerAddr
}
// use first signer as default if no payer specified
return w.GetSigners()[0]
}

Expand Down Expand Up @@ -235,6 +244,17 @@ func (w *wrapper) SetFeeAmount(coins sdk.Coins) {
w.authInfoBz = nil
}

func (w *wrapper) SetFeePayer(feePayer sdk.AccAddress) {
if w.tx.AuthInfo.Fee == nil {
w.tx.AuthInfo.Fee = &tx.Fee{}
}

w.tx.AuthInfo.Fee.Payer = feePayer.String()

// set authInfoBz to nil because the cached authInfoBz no longer matches tx.AuthInfo
w.authInfoBz = nil
}

func (w *wrapper) SetSignatures(signatures ...signing.SignatureV2) error {
n := len(signatures)
signerInfos := make([]*tx.SignerInfo, n)
Expand Down
50 changes: 50 additions & 0 deletions x/auth/tx/builder_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -239,3 +239,53 @@ func TestBuilderValidateBasic(t *testing.T) {
err = txBuilder.ValidateBasic()
require.Error(t, err)
}

func TestBuilderFeePayer(t *testing.T) {
// keys and addresses
_, _, addr1 := testdata.KeyTestPubAddr()
_, _, addr2 := testdata.KeyTestPubAddr()
_, _, addr3 := testdata.KeyTestPubAddr()

// msg and signatures
msg1 := testdata.NewTestMsg(addr1, addr2)
feeAmount := testdata.NewTestFeeAmount()
msgs := []sdk.Msg{msg1}

cases := map[string]struct {
txFeePayer sdk.AccAddress
expectedSigners []sdk.AccAddress
expectedPayer sdk.AccAddress
}{
"no fee payer specified": {
expectedSigners: []sdk.AccAddress{addr1, addr2},
expectedPayer: addr1,
},
"secondary signer set as fee payer": {
txFeePayer: addr2,
expectedSigners: []sdk.AccAddress{addr1, addr2},
expectedPayer: addr2,
},
"outside signer set as fee payer": {
txFeePayer: addr3,
expectedSigners: []sdk.AccAddress{addr1, addr2, addr3},
expectedPayer: addr3,
},
}

for name, tc := range cases {
t.Run(name, func(t *testing.T) {
// setup basic tx
txBuilder := newBuilder()
err := txBuilder.SetMsgs(msgs...)
require.NoError(t, err)
txBuilder.SetGasLimit(200000)
txBuilder.SetFeeAmount(feeAmount)

// set fee payer
txBuilder.SetFeePayer(tc.txFeePayer)
// and check it updates fields properly
require.Equal(t, tc.expectedSigners, txBuilder.GetSigners())
require.Equal(t, tc.expectedPayer, txBuilder.FeePayer())
})
}
}

0 comments on commit d917520

Please sign in to comment.