Skip to content

Commit

Permalink
Convert client peer resolving errors to service transient errors (cad…
Browse files Browse the repository at this point in the history
  • Loading branch information
Shaddoll authored Sep 9, 2022
1 parent 60f7b13 commit 5bb3bd7
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 13 deletions.
14 changes: 8 additions & 6 deletions client/history/peerResolver.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,32 +67,34 @@ func (pr PeerResolver) FromShardID(shardID int) (string, error) {
shardIDString := string(rune(shardID))
host, err := pr.resolver.Lookup(service.History, shardIDString)
if err != nil {
return "", err
return "", common.ToServiceTransientError(err)
}
return host.GetNamedAddress(pr.namedPort)
peer, err := host.GetNamedAddress(pr.namedPort)
return peer, common.ToServiceTransientError(err)
}

// FromHostAddress resolves the final history peer responsible for the given host address.
// The address is formed by adding port for specified transport
func (pr PeerResolver) FromHostAddress(hostAddress string) (string, error) {
host, err := pr.resolver.LookupByAddress(service.History, hostAddress)
if err != nil {
return "", err
return "", common.ToServiceTransientError(err)
}
return host.GetNamedAddress(pr.namedPort)
peer, err := host.GetNamedAddress(pr.namedPort)
return peer, common.ToServiceTransientError(err)
}

// GetAllPeers returns all history service peers in the cluster ring.
func (pr PeerResolver) GetAllPeers() ([]string, error) {
hosts, err := pr.resolver.Members(service.History)
if err != nil {
return nil, err
return nil, common.ToServiceTransientError(err)
}
peers := make([]string, 0, len(hosts))
for _, host := range hosts {
peer, err := host.GetNamedAddress(pr.namedPort)
if err != nil {
return nil, err
return nil, common.ToServiceTransientError(err)
}
peers = append(peers, peer)
}
Expand Down
16 changes: 9 additions & 7 deletions client/matching/peerResolver.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
package matching

import (
"github.com/uber/cadence/common"
"github.com/uber/cadence/common/membership"
"github.com/uber/cadence/common/service"
)
Expand All @@ -47,23 +48,24 @@ func NewPeerResolver(membership membership.Resolver, namedPort string) PeerResol
func (pr PeerResolver) FromTaskList(taskListName string) (string, error) {
host, err := pr.resolver.Lookup(service.Matching, taskListName)
if err != nil {
return "", err
return "", common.ToServiceTransientError(err)
}

return pr.FromHostAddress(host.GetAddress())
peer, err := host.GetNamedAddress(pr.namedPort)
return peer, common.ToServiceTransientError(err)
}

// GetAllPeers returns all matching service peers in the cluster ring.
func (pr PeerResolver) GetAllPeers() ([]string, error) {
hosts, err := pr.resolver.Members(service.Matching)
if err != nil {
return nil, err
return nil, common.ToServiceTransientError(err)
}
peers := make([]string, 0, len(hosts))
for _, host := range hosts {
peer, err := pr.FromHostAddress(host.GetAddress())
if err != nil {
return nil, err
return nil, common.ToServiceTransientError(err)
}
peers = append(peers, peer)
}
Expand All @@ -76,9 +78,9 @@ func (pr PeerResolver) GetAllPeers() ([]string, error) {
func (pr PeerResolver) FromHostAddress(hostAddress string) (string, error) {
host, err := pr.resolver.LookupByAddress(service.Matching, hostAddress)
if err != nil {
return "", err
return "", common.ToServiceTransientError(err)
}

return host.GetNamedAddress(pr.namedPort)

peer, err := host.GetNamedAddress(pr.namedPort)
return peer, common.ToServiceTransientError(err)
}
8 changes: 8 additions & 0 deletions common/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,14 @@ func CheckDecisionResultLimit(
return nil
}

// ToServiceTransientError converts an error to ServiceTransientError
func ToServiceTransientError(err error) error {
if err == nil || IsServiceTransientError(err) {
return err
}
return yarpcerrors.Newf(yarpcerrors.CodeUnavailable, err.Error())
}

// IsServiceTransientError checks if the error is a transient error.
func IsServiceTransientError(err error) bool {
switch err.(type) {
Expand Down
18 changes: 18 additions & 0 deletions common/util_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ import (
"time"

"github.com/pborman/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/yarpc/yarpcerrors"

Expand Down Expand Up @@ -341,3 +342,20 @@ func TestConvertErrToGetTaskFailedCause(t *testing.T) {
require.Equal(t, tc.expectedFailedCause, ConvertErrToGetTaskFailedCause(tc.err))
}
}

func TestToServiceTransientError(t *testing.T) {
t.Run("it converts nil", func(t *testing.T) {
assert.NoError(t, ToServiceTransientError(nil))
})

t.Run("it keeps transient errors", func(t *testing.T) {
err := &types.InternalServiceError{}
assert.Equal(t, err, ToServiceTransientError(err))
assert.True(t, IsServiceTransientError(ToServiceTransientError(err)))
})

t.Run("it converts errors to transient errors", func(t *testing.T) {
err := fmt.Errorf("error")
assert.True(t, IsServiceTransientError(ToServiceTransientError(err)))
})
}

0 comments on commit 5bb3bd7

Please sign in to comment.