Skip to content

Commit

Permalink
gcp client auth plugin: persist default cache on unauthorized
Browse files Browse the repository at this point in the history
The default cache for a cachedTokenSource is not always empty. In the
case of commandTokenSource, it contains calling details for the
external command that is used to generate refresh tokens. Persisting
a completely empty cache will thus break ability for the plugin to
obtain refresh tokens. This changes the roundtripper to persist
the default cache instead of assuming an empty map.
  • Loading branch information
j3ffml committed Sep 10, 2018
1 parent 25cbd1c commit 73e5e43
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 23 deletions.
25 changes: 22 additions & 3 deletions staging/src/k8s.io/client-go/plugin/pkg/client/auth/gcp/gcp.go
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,13 @@ func parseScopes(gcpConfig map[string]string) []string {
}

func (g *gcpAuthProvider) WrapTransport(rt http.RoundTripper) http.RoundTripper {
return &conditionalTransport{&oauth2.Transport{Source: g.tokenSource, Base: rt}, g.persister}
var resetCache map[string]string
if cts, ok := g.tokenSource.(*cachedTokenSource); ok {
resetCache = cts.baseCache()
} else {
resetCache = make(map[string]string)
}
return &conditionalTransport{&oauth2.Transport{Source: g.tokenSource, Base: rt}, g.persister, resetCache}
}

func (g *gcpAuthProvider) Login() error { return nil }
Expand Down Expand Up @@ -247,6 +253,19 @@ func (t *cachedTokenSource) update(tok *oauth2.Token) map[string]string {
return ret
}

// baseCache is the base configuration value for this TokenSource, without any cached ephemeral tokens.
func (t *cachedTokenSource) baseCache() map[string]string {
t.lk.Lock()
defer t.lk.Unlock()
ret := map[string]string{}
for k, v := range t.cache {
ret[k] = v
}
delete(ret, "access-token")
delete(ret, "expiry")
return ret
}

type commandTokenSource struct {
cmd string
args []string
Expand Down Expand Up @@ -337,6 +356,7 @@ func parseJSONPath(input interface{}, name, template string) (string, error) {
type conditionalTransport struct {
oauthTransport *oauth2.Transport
persister restclient.AuthProviderConfigPersister
resetCache map[string]string
}

var _ net.RoundTripperWrapper = &conditionalTransport{}
Expand All @@ -354,8 +374,7 @@ func (t *conditionalTransport) RoundTrip(req *http.Request) (*http.Response, err

if res.StatusCode == 401 {
glog.V(4).Infof("The credentials that were supplied are invalid for the target cluster")
emptyCache := make(map[string]string)
t.persister.Persist(emptyCache)
t.persister.Persist(t.resetCache)
}

return res, nil
Expand Down
70 changes: 50 additions & 20 deletions staging/src/k8s.io/client-go/plugin/pkg/client/auth/gcp/gcp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -442,55 +442,85 @@ func (t *MockTransport) RoundTrip(req *http.Request) (*http.Response, error) {
return t.res, nil
}

func TestClearingCredentials(t *testing.T) {
func Test_cmdTokenSource_roundTrip(t *testing.T) {

accessToken := "fakeToken"
fakeExpiry := time.Now().Add(time.Hour)

cache := map[string]string{
"access-token": "fakeToken",
"expiry": fakeExpiry.String(),
fakeExpiryStr := fakeExpiry.Format(time.RFC3339Nano)
fs := &fakeTokenSource{
token: &oauth2.Token{
AccessToken: accessToken,
Expiry: fakeExpiry,
},
}

cts := cachedTokenSource{
source: nil,
accessToken: cache["access-token"],
expiry: fakeExpiry,
persister: nil,
cache: nil,
cmdCache := map[string]string{
"cmd-path": "/path/to/tokensource/cmd",
"cmd-args": "--output=json",
}
cmdCacheUpdated := map[string]string{
"cmd-path": "/path/to/tokensource/cmd",
"cmd-args": "--output=json",
"access-token": accessToken,
"expiry": fakeExpiryStr,
}
simpleCacheUpdated := map[string]string{
"access-token": accessToken,
"expiry": fakeExpiryStr,
}

tests := []struct {
name string
res http.Response
cache map[string]string
name string
res http.Response
baseCache, expectedCache map[string]string
}{
{
"Unauthorized",
http.Response{StatusCode: 401},
make(map[string]string),
make(map[string]string),
},
{
"Unauthorized, nonempty defaultCache",
http.Response{StatusCode: 401},
cmdCache,
cmdCache,
},
{
"Authorized",
http.Response{StatusCode: 200},
cache,
make(map[string]string),
simpleCacheUpdated,
},
{
"Authorized, nonempty defaultCache",
http.Response{StatusCode: 200},
cmdCache,
cmdCacheUpdated,
},
}

persister := &fakePersister{}
req := http.Request{Header: http.Header{}}

for _, tc := range tests {
authProvider := gcpAuthProvider{&cts, persister}
cts, err := newCachedTokenSource(accessToken, fakeExpiry.String(), persister, fs, tc.baseCache)
if err != nil {
t.Fatalf("unexpected error from newCachedTokenSource: %v", err)
}
authProvider := gcpAuthProvider{cts, persister}

fakeTransport := MockTransport{&tc.res}

transport := (authProvider.WrapTransport(&fakeTransport))
persister.Persist(cache)
// call Token to persist/update cache
if _, err := cts.Token(); err != nil {
t.Fatalf("unexpected error from cachedTokenSource.Token(): %v", err)
}

transport.RoundTrip(&req)

if got := persister.read(); !reflect.DeepEqual(got, tc.cache) {
t.Errorf("got cache %v, want %v", got, tc.cache)
if got := persister.read(); !reflect.DeepEqual(got, tc.expectedCache) {
t.Errorf("got cache %v, want %v", got, tc.expectedCache)
}
}

Expand Down

0 comments on commit 73e5e43

Please sign in to comment.