Skip to content

Commit

Permalink
Merge pull request cli#9797 from cli/phillmv/retry-getting-attestations
Browse files Browse the repository at this point in the history
`gh at verify` retries fetching attestations if it receives a 5xx
  • Loading branch information
phillmv authored Oct 23, 2024
2 parents e390874 + de4c05f commit afa4272
Show file tree
Hide file tree
Showing 3 changed files with 150 additions and 10 deletions.
56 changes: 51 additions & 5 deletions pkg/cmd/attestation/api/client.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
package api

import (
"errors"
"fmt"
"io"
"net/http"
"strings"
"time"

"github.com/cenkalti/backoff/v4"
"github.com/cli/cli/v2/api"
ioconfig "github.com/cli/cli/v2/pkg/cmd/attestation/io"
)
Expand Down Expand Up @@ -69,6 +72,9 @@ func (c *LiveClient) GetTrustDomain() (string, error) {
return c.getTrustDomain(MetaPath)
}

// Allow injecting backoff interval in tests.
var getAttestationRetryInterval = time.Millisecond * 200

func (c *LiveClient) getAttestations(url, name, digest string, limit int) ([]*Attestation, error) {
c.logger.VerbosePrintf("Fetching attestations for artifact digest %s\n\n", digest)

Expand All @@ -86,15 +92,31 @@ func (c *LiveClient) getAttestations(url, name, digest string, limit int) ([]*At

var attestations []*Attestation
var resp AttestationsResponse
var err error
bo := backoff.NewConstantBackOff(getAttestationRetryInterval)

// if no attestation or less than limit, then keep fetching
for url != "" && len(attestations) < limit {
url, err = c.api.RESTWithNext(c.host, http.MethodGet, url, nil, &resp)
err := backoff.Retry(func() error {
newURL, restErr := c.api.RESTWithNext(c.host, http.MethodGet, url, nil, &resp)

if restErr != nil {
if shouldRetry(restErr) {
return restErr
} else {
return backoff.Permanent(restErr)
}
}

url = newURL
attestations = append(attestations, resp.Attestations...)

return nil
}, backoff.WithMaxRetries(bo, 3))

// bail if RESTWithNext errored out
if err != nil {
return nil, err
}

attestations = append(attestations, resp.Attestations...)
}

if len(attestations) == 0 {
Expand All @@ -108,10 +130,34 @@ func (c *LiveClient) getAttestations(url, name, digest string, limit int) ([]*At
return attestations, nil
}

func shouldRetry(err error) bool {
var httpError api.HTTPError
if errors.As(err, &httpError) {
if httpError.StatusCode >= 500 && httpError.StatusCode <= 599 {
return true
}
}

return false
}

func (c *LiveClient) getTrustDomain(url string) (string, error) {
var resp MetaResponse

err := c.api.REST(c.host, http.MethodGet, url, nil, &resp)
bo := backoff.NewConstantBackOff(getAttestationRetryInterval)
err := backoff.Retry(func() error {
restErr := c.api.REST(c.host, http.MethodGet, url, nil, &resp)
if restErr != nil {
if shouldRetry(restErr) {
return restErr
} else {
return backoff.Permanent(restErr)
}
}

return nil
}, backoff.WithMaxRetries(bo, 3))

if err != nil {
return "", err
}
Expand Down
59 changes: 59 additions & 0 deletions pkg/cmd/attestation/api/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -204,3 +204,62 @@ func TestGetTrustDomain(t *testing.T) {
})

}

func TestGetAttestationsRetries(t *testing.T) {
getAttestationRetryInterval = 0

fetcher := mockDataGenerator{
NumAttestations: 5,
}

c := &LiveClient{
api: mockAPIClient{
OnRESTWithNext: fetcher.FlakyOnRESTSuccessWithNextPageHandler(),
},
logger: io.NewTestHandler(),
}

attestations, err := c.GetByRepoAndDigest(testRepo, testDigest, DefaultLimit)
require.NoError(t, err)

// assert the error path was executed; because this is a paged
// request, it should have errored twice
fetcher.AssertNumberOfCalls(t, "FlakyOnRESTSuccessWithNextPage:error", 2)

// but we still successfully got the right data
require.Equal(t, len(attestations), 10)
bundle := (attestations)[0].Bundle
require.Equal(t, bundle.GetMediaType(), "application/vnd.dev.sigstore.bundle.v0.3+json")

// same test as above, but for GetByOwnerAndDigest:
attestations, err = c.GetByOwnerAndDigest(testOwner, testDigest, DefaultLimit)
require.NoError(t, err)

// because we haven't reset the mock, we have added 2 more failed requests
fetcher.AssertNumberOfCalls(t, "FlakyOnRESTSuccessWithNextPage:error", 4)

require.Equal(t, len(attestations), 10)
bundle = (attestations)[0].Bundle
require.Equal(t, bundle.GetMediaType(), "application/vnd.dev.sigstore.bundle.v0.3+json")
}

// test total retries
func TestGetAttestationsMaxRetries(t *testing.T) {
getAttestationRetryInterval = 0

fetcher := mockDataGenerator{
NumAttestations: 5,
}

c := &LiveClient{
api: mockAPIClient{
OnRESTWithNext: fetcher.OnREST500ErrorHandler(),
},
logger: io.NewTestHandler(),
}

_, err := c.GetByRepoAndDigest(testRepo, testDigest, DefaultLimit)
require.Error(t, err)

fetcher.AssertNumberOfCalls(t, "OnREST500Error", 4)
}
45 changes: 40 additions & 5 deletions pkg/cmd/attestation/api/mock_apiClient_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@ import (
"fmt"
"io"
"strings"

cliAPI "github.com/cli/cli/v2/api"
ghAPI "github.com/cli/go-gh/v2/pkg/api"
"github.com/stretchr/testify/mock"
)

type mockAPIClient struct {
Expand All @@ -22,14 +26,15 @@ func (m mockAPIClient) REST(hostname, method, p string, body io.Reader, data int
}

type mockDataGenerator struct {
mock.Mock
NumAttestations int
}

func (m mockDataGenerator) OnRESTSuccess(hostname, method, p string, body io.Reader, data interface{}) (string, error) {
func (m *mockDataGenerator) OnRESTSuccess(hostname, method, p string, body io.Reader, data interface{}) (string, error) {
return m.OnRESTWithNextSuccessHelper(hostname, method, p, body, data, false)
}

func (m mockDataGenerator) OnRESTSuccessWithNextPage(hostname, method, p string, body io.Reader, data interface{}) (string, error) {
func (m *mockDataGenerator) OnRESTSuccessWithNextPage(hostname, method, p string, body io.Reader, data interface{}) (string, error) {
// if path doesn't contain after, it means first time hitting the mock server
// so return the first page and return the link header in the response
if !strings.Contains(p, "after") {
Expand All @@ -40,7 +45,37 @@ func (m mockDataGenerator) OnRESTSuccessWithNextPage(hostname, method, p string,
return m.OnRESTWithNextSuccessHelper(hostname, method, p, body, data, false)
}

func (m mockDataGenerator) OnRESTWithNextSuccessHelper(hostname, method, p string, body io.Reader, data interface{}, hasNext bool) (string, error) {
// Returns a func that just calls OnRESTSuccessWithNextPage but half the time
// it returns a 500 error.
func (m *mockDataGenerator) FlakyOnRESTSuccessWithNextPageHandler() func(hostname, method, p string, body io.Reader, data interface{}) (string, error) {
// set up the flake counter
m.On("FlakyOnRESTSuccessWithNextPage:error").Return()

count := 0
return func(hostname, method, p string, body io.Reader, data interface{}) (string, error) {
if count%2 == 0 {
m.MethodCalled("FlakyOnRESTSuccessWithNextPage:error")

count = count + 1
return "", cliAPI.HTTPError{HTTPError: &ghAPI.HTTPError{StatusCode: 500}}
} else {
count = count + 1
return m.OnRESTSuccessWithNextPage(hostname, method, p, body, data)
}
}
}

// always returns a 500
func (m *mockDataGenerator) OnREST500ErrorHandler() func(hostname, method, p string, body io.Reader, data interface{}) (string, error) {
m.On("OnREST500Error").Return()
return func(hostname, method, p string, body io.Reader, data interface{}) (string, error) {
m.MethodCalled("OnREST500Error")

return "", cliAPI.HTTPError{HTTPError: &ghAPI.HTTPError{StatusCode: 500}}
}
}

func (m *mockDataGenerator) OnRESTWithNextSuccessHelper(hostname, method, p string, body io.Reader, data interface{}, hasNext bool) (string, error) {
atts := make([]*Attestation, m.NumAttestations)
for j := 0; j < m.NumAttestations; j++ {
att := makeTestAttestation()
Expand Down Expand Up @@ -70,7 +105,7 @@ func (m mockDataGenerator) OnRESTWithNextSuccessHelper(hostname, method, p strin
return "", nil
}

func (m mockDataGenerator) OnRESTWithNextNoAttestations(hostname, method, p string, body io.Reader, data interface{}) (string, error) {
func (m *mockDataGenerator) OnRESTWithNextNoAttestations(hostname, method, p string, body io.Reader, data interface{}) (string, error) {
resp := AttestationsResponse{
Attestations: make([]*Attestation, 0),
}
Expand All @@ -89,7 +124,7 @@ func (m mockDataGenerator) OnRESTWithNextNoAttestations(hostname, method, p stri
return "", nil
}

func (m mockDataGenerator) OnRESTWithNextError(hostname, method, p string, body io.Reader, data interface{}) (string, error) {
func (m *mockDataGenerator) OnRESTWithNextError(hostname, method, p string, body io.Reader, data interface{}) (string, error) {
return "", errors.New("failed to get attestations")
}

Expand Down

0 comments on commit afa4272

Please sign in to comment.