Skip to content

Commit

Permalink
auth refresh: preserve existing scopes when requesting new ones
Browse files Browse the repository at this point in the history
When there was a previously valid token that was granted some scopes,
ensure all those scopes will be re-requested when doing the
authentication flow for the new token.
  • Loading branch information
mislav committed Oct 14, 2021
1 parent 64a19ee commit 89ad870
Show file tree
Hide file tree
Showing 3 changed files with 69 additions and 13 deletions.
18 changes: 16 additions & 2 deletions pkg/cmd/auth/refresh/refresh.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ package refresh
import (
"errors"
"fmt"
"net/http"
"strings"

"github.com/AlecAivazis/survey/v2"
"github.com/MakeNowJust/heredoc"
Expand All @@ -16,8 +18,9 @@ import (
)

type RefreshOptions struct {
IO *iostreams.IOStreams
Config func() (config.Config, error)
IO *iostreams.IOStreams
Config func() (config.Config, error)
httpClient *http.Client

MainExecutable string

Expand All @@ -36,6 +39,7 @@ func NewCmdRefresh(f *cmdutil.Factory, runF func(*RefreshOptions) error) *cobra.
_, err := authflow.AuthFlowWithConfig(cfg, io, hostname, "", scopes)
return err
},
httpClient: http.DefaultClient,
}

cmd := &cobra.Command{
Expand Down Expand Up @@ -128,6 +132,16 @@ func refreshRun(opts *RefreshOptions) error {
}

var additionalScopes []string
if oldToken, _ := cfg.Get(hostname, "oauth_token"); oldToken != "" {
if oldScopes, err := shared.GetScopes(opts.httpClient, hostname, oldToken); err == nil {
for _, s := range strings.Split(oldScopes, ",") {
s = strings.TrimSpace(s)
if s != "" {
additionalScopes = append(additionalScopes, s)
}
}
}
}

credentialFlow := &shared.GitCredentialFlow{
Executable: opts.MainExecutable,
Expand Down
46 changes: 40 additions & 6 deletions pkg/cmd/auth/refresh/refresh_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@ package refresh

import (
"bytes"
"io/ioutil"
"net/http"
"strings"
"testing"

"github.com/cli/cli/v2/internal/config"
Expand Down Expand Up @@ -134,6 +137,7 @@ func Test_refreshRun(t *testing.T) {
opts *RefreshOptions
askStubs func(*prompt.AskStubber)
cfgHosts []string
oldScopes string
wantErr string
nontty bool
wantAuthArgs authArgs
Expand Down Expand Up @@ -211,6 +215,20 @@ func Test_refreshRun(t *testing.T) {
scopes: []string{"repo:invite", "public_key:read"},
},
},
{
name: "scopes provided",
cfgHosts: []string{
"github.com",
},
oldScopes: "delete_repo, codespace",
opts: &RefreshOptions{
Scopes: []string{"repo:invite", "public_key:read"},
},
wantAuthArgs: authArgs{
hostname: "github.com",
scopes: []string{"repo:invite", "public_key:read", "delete_repo", "codespace"},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
Expand All @@ -234,10 +252,26 @@ func Test_refreshRun(t *testing.T) {
for _, hostname := range tt.cfgHosts {
_ = cfg.Set(hostname, "oauth_token", "abc123")
}
reg := &httpmock.Registry{}
reg.Register(
httpmock.GraphQL(`query UserCurrent\b`),
httpmock.StringResponse(`{"data":{"viewer":{"login":"cybilb"}}}`))

httpReg := &httpmock.Registry{}
httpReg.Register(
httpmock.REST("GET", ""),
func(req *http.Request) (*http.Response, error) {
statusCode := 200
if req.Header.Get("Authorization") != "token abc123" {
statusCode = 400
}
return &http.Response{
Request: req,
StatusCode: statusCode,
Body: ioutil.NopCloser(strings.NewReader(``)),
Header: http.Header{
"X-Oauth-Scopes": {tt.oldScopes},
},
}, nil
},
)
tt.opts.httpClient = &http.Client{Transport: httpReg}

mainBuf := bytes.Buffer{}
hostsBuf := bytes.Buffer{}
Expand All @@ -258,8 +292,8 @@ func Test_refreshRun(t *testing.T) {
assert.NoError(t, err)
}

assert.Equal(t, aa.hostname, tt.wantAuthArgs.hostname)
assert.Equal(t, aa.scopes, tt.wantAuthArgs.scopes)
assert.Equal(t, tt.wantAuthArgs.hostname, aa.hostname)
assert.Equal(t, tt.wantAuthArgs.scopes, aa.scopes)
})
}
}
18 changes: 13 additions & 5 deletions pkg/cmd/auth/shared/oauth_scopes.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,19 +32,19 @@ type httpClient interface {
Do(*http.Request) (*http.Response, error)
}

func HasMinimumScopes(httpClient httpClient, hostname, authToken string) error {
func GetScopes(httpClient httpClient, hostname, authToken string) (string, error) {
apiEndpoint := ghinstance.RESTPrefix(hostname)

req, err := http.NewRequest("GET", apiEndpoint, nil)
if err != nil {
return err
return "", err
}

req.Header.Set("Authorization", "token "+authToken)

res, err := httpClient.Do(req)
if err != nil {
return err
return "", err
}

defer func() {
Expand All @@ -55,10 +55,18 @@ func HasMinimumScopes(httpClient httpClient, hostname, authToken string) error {
}()

if res.StatusCode != 200 {
return api.HandleHTTPError(res)
return "", api.HandleHTTPError(res)
}

return res.Header.Get("X-Oauth-Scopes"), nil
}

func HasMinimumScopes(httpClient httpClient, hostname, authToken string) error {
scopesHeader, err := GetScopes(httpClient, hostname, authToken)
if err != nil {
return err
}

scopesHeader := res.Header.Get("X-Oauth-Scopes")
if scopesHeader == "" {
// if the token reports no scopes, assume that it's an integration token and give up on
// detecting its capabilities
Expand Down

0 comments on commit 89ad870

Please sign in to comment.