diff --git a/api.go b/api.go index f16d947731e..2ee9c10d0fd 100644 --- a/api.go +++ b/api.go @@ -374,6 +374,7 @@ func handleGetDetail(sessionKey, apiID string, byHash bool) (interface{}, int) { if usedQuota, err := sessionManager.Store().GetRawKey(limQuotaKey); err == nil { qInt, _ := strconv.Atoi(usedQuota) remaining := access.Limit.QuotaMax - int64(qInt) + if remaining < 0 { access.Limit.QuotaRemaining = 0 } else { @@ -381,6 +382,9 @@ func handleGetDetail(sessionKey, apiID string, byHash bool) (interface{}, int) { } session.AccessRights[id] = access } else { + access.Limit.QuotaRemaining = access.Limit.QuotaMax + session.AccessRights[id] = access + log.WithFields(logrus.Fields{ "prefix": "api", "apiID": id, diff --git a/auth_manager.go b/auth_manager.go index 8c96dc06bc6..aae839fcb16 100644 --- a/auth_manager.go +++ b/auth_manager.go @@ -202,6 +202,11 @@ func (b *DefaultSessionManager) ResetQuota(keyName string, session *user.Session // Fix the raw key go b.store.DeleteRawKey(rawKey) //go b.store.SetKey(rawKey, "0", session.QuotaRenewalRate) + + for apiID := range session.AccessRights { + rawKey = QuotaKeyPrefix + apiID + "-" + keyName + go b.store.DeleteRawKey(rawKey) + } } // UpdateSession updates the session state in the storage engine diff --git a/looping_test.go b/looping_test.go index 9e9886de26b..48b9fa82139 100644 --- a/looping_test.go +++ b/looping_test.go @@ -4,27 +4,27 @@ package main import ( - "encoding/json" - "sync" - "testing" + "encoding/json" + "sync" + "testing" - "github.com/TykTechnologies/tyk/test" - "github.com/TykTechnologies/tyk/user" + "github.com/TykTechnologies/tyk/test" + "github.com/TykTechnologies/tyk/user" ) func TestLooping(t *testing.T) { - ts := newTykTestServer() - defer ts.Close() - - postAction := `data` - getAction := `data` - - t.Run("Using advanced URL rewrite", func(t *testing.T) { - // We defined internnal advanced rewrite based on body data - // which rewrites to internal paths (marked as blacklist so they protected from outside world) - buildAndLoadAPI(func(spec *APISpec) { - version := spec.VersionData.Versions["v1"] - json.Unmarshal([]byte(`{ + ts := newTykTestServer() + defer ts.Close() + + postAction := `data` + getAction := `data` + + t.Run("Using advanced URL rewrite", func(t *testing.T) { + // We defined internnal advanced rewrite based on body data + // which rewrites to internal paths (marked as blacklist so they protected from outside world) + buildAndLoadAPI(func(spec *APISpec) { + version := spec.VersionData.Versions["v1"] + json.Unmarshal([]byte(`{ "use_extended_paths": true, "extended_paths": { "internal": [{ @@ -67,46 +67,46 @@ func TestLooping(t *testing.T) { } }`), &version) - spec.VersionData.Versions["v1"] = version + spec.VersionData.Versions["v1"] = version - spec.Proxy.ListenPath = "/" - }) + spec.Proxy.ListenPath = "/" + }) - ts.Run(t, []test.TestCase{ - {Method: "POST", Path: "/xml", Data: postAction, BodyMatch: `"Url":"/post_action`}, + ts.Run(t, []test.TestCase{ + {Method: "POST", Path: "/xml", Data: postAction, BodyMatch: `"Url":"/post_action`}, - // Should retain original query params - {Method: "POST", Path: "/xml?a=b", Data: getAction, BodyMatch: `"Url":"/get_action`}, + // Should retain original query params + {Method: "POST", Path: "/xml?a=b", Data: getAction, BodyMatch: `"Url":"/get_action`}, - // Should rewrite http method, if loop rewrite param passed - {Method: "POST", Path: "/xml", Data: getAction, BodyMatch: `"Method":"GET"`}, + // Should rewrite http method, if loop rewrite param passed + {Method: "POST", Path: "/xml", Data: getAction, BodyMatch: `"Method":"GET"`}, - // Internal endpoint can be accessed only via looping - {Method: "GET", Path: "/get_action", Code: 403}, + // Internal endpoint can be accessed only via looping + {Method: "GET", Path: "/get_action", Code: 403}, - {Method: "POST", Path: "/get_action", Code: 403}, - }...) - }) + {Method: "POST", Path: "/get_action", Code: 403}, + }...) + }) - t.Run("Loop to another API", func(t *testing.T) { - buildAndLoadAPI(func(spec *APISpec) { - spec.APIID = "testid" - spec.Name = "hidden api" - spec.Proxy.ListenPath = "/somesecret" - spec.Internal = true - version := spec.VersionData.Versions["v1"] - json.Unmarshal([]byte(`{ + t.Run("Loop to another API", func(t *testing.T) { + buildAndLoadAPI(func(spec *APISpec) { + spec.APIID = "testid" + spec.Name = "hidden api" + spec.Proxy.ListenPath = "/somesecret" + spec.Internal = true + version := spec.VersionData.Versions["v1"] + json.Unmarshal([]byte(`{ "use_extended_paths": true, "global_headers": { "X-Name":"internal" } }`), &version) - spec.VersionData.Versions["v1"] = version - }, func(spec *APISpec) { - spec.Proxy.ListenPath = "/test" + spec.VersionData.Versions["v1"] = version + }, func(spec *APISpec) { + spec.Proxy.ListenPath = "/test" - version := spec.VersionData.Versions["v1"] - json.Unmarshal([]byte(`{ + version := spec.VersionData.Versions["v1"] + json.Unmarshal([]byte(`{ "use_extended_paths": true, "extended_paths": { "url_rewrites": [{ @@ -128,19 +128,19 @@ func TestLooping(t *testing.T) { } }`), &version) - spec.VersionData.Versions["v1"] = version - }) + spec.VersionData.Versions["v1"] = version + }) - ts.Run(t, []test.TestCase{ - {Path: "/somesecret", Code: 404}, - {Path: "/test/by_name", Code: 200, BodyMatch: `"X-Name":"internal"`}, - {Path: "/test/by_id", Code: 200, BodyMatch: `"X-Name":"internal"`}, - {Path: "/test/wrong", Code: 500}, - }...) - }) + ts.Run(t, []test.TestCase{ + {Path: "/somesecret", Code: 404}, + {Path: "/test/by_name", Code: 200, BodyMatch: `"X-Name":"internal"`}, + {Path: "/test/by_id", Code: 200, BodyMatch: `"X-Name":"internal"`}, + {Path: "/test/wrong", Code: 500}, + }...) + }) - t.Run("VirtualEndpoint or plugins", func(t *testing.T) { - testPrepareVirtualEndpoint(` + t.Run("VirtualEndpoint or plugins", func(t *testing.T) { + testPrepareVirtualEndpoint(` function testVirtData(request, session, config) { var loopLocation = "/default" @@ -160,21 +160,21 @@ func TestLooping(t *testing.T) { } `, "POST", "/virt", true) - ts.Run(t, []test.TestCase{ - {Method: "POST", Path: "/virt", Data: postAction, BodyMatch: `"Url":"/post_action`}, + ts.Run(t, []test.TestCase{ + {Method: "POST", Path: "/virt", Data: postAction, BodyMatch: `"Url":"/post_action`}, - // Should retain original query params - {Method: "POST", Path: "/virt?a=b", Data: getAction, BodyMatch: `"Url":"/get_action`}, + // Should retain original query params + {Method: "POST", Path: "/virt?a=b", Data: getAction, BodyMatch: `"Url":"/get_action`}, - // Should rewrite http method, if loop rewrite param passed - {Method: "POST", Path: "/virt", Data: getAction, BodyMatch: `"Method":"GET"`}, - }...) - }) + // Should rewrite http method, if loop rewrite param passed + {Method: "POST", Path: "/virt", Data: getAction, BodyMatch: `"Method":"GET"`}, + }...) + }) - t.Run("Loop limit", func(t *testing.T) { - buildAndLoadAPI(func(spec *APISpec) { - version := spec.VersionData.Versions["v1"] - json.Unmarshal([]byte(`{ + t.Run("Loop limit", func(t *testing.T) { + buildAndLoadAPI(func(spec *APISpec) { + version := spec.VersionData.Versions["v1"] + json.Unmarshal([]byte(`{ "use_extended_paths": true, "extended_paths": { "url_rewrites": [{ @@ -186,19 +186,19 @@ func TestLooping(t *testing.T) { } }`), &version) - spec.VersionData.Versions["v1"] = version - spec.Proxy.ListenPath = "/" - }) + spec.VersionData.Versions["v1"] = version + spec.Proxy.ListenPath = "/" + }) - ts.Run(t, []test.TestCase{ - {Method: "GET", Path: "/recursion", Code: 500, BodyMatch: "Loop level too deep. Found more than 2 loops in single request"}, - }...) - }) + ts.Run(t, []test.TestCase{ + {Method: "GET", Path: "/recursion", Code: 500, BodyMatch: "Loop level too deep. Found more than 2 loops in single request"}, + }...) + }) - t.Run("Quota and rate limit calculation", func(t *testing.T) { - buildAndLoadAPI(func(spec *APISpec) { - version := spec.VersionData.Versions["v1"] - json.Unmarshal([]byte(`{ + t.Run("Quota and rate limit calculation", func(t *testing.T) { + buildAndLoadAPI(func(spec *APISpec) { + version := spec.VersionData.Versions["v1"] + json.Unmarshal([]byte(`{ "use_extended_paths": true, "extended_paths": { "url_rewrites": [{ @@ -210,42 +210,42 @@ func TestLooping(t *testing.T) { } }`), &version) - spec.VersionData.Versions["v1"] = version - spec.Proxy.ListenPath = "/" - spec.UseKeylessAccess = false - }) + spec.VersionData.Versions["v1"] = version + spec.Proxy.ListenPath = "/" + spec.UseKeylessAccess = false + }) - keyID := createSession(func(s *user.SessionState) { - s.QuotaMax = 2 - }) + keyID := createSession(func(s *user.SessionState) { + s.QuotaMax = 2 + }) - authHeaders := map[string]string{"authorization": keyID} + authHeaders := map[string]string{"authorization": keyID} - ts.Run(t, []test.TestCase{ - {Method: "GET", Path: "/recursion", Headers: authHeaders, BodyNotMatch: "Quota exceeded"}, - }...) - }) + ts.Run(t, []test.TestCase{ + {Method: "GET", Path: "/recursion", Headers: authHeaders, BodyNotMatch: "Quota exceeded"}, + }...) + }) } func TestConcurrencyReloads(t *testing.T) { - var wg sync.WaitGroup + var wg sync.WaitGroup - ts := newTykTestServer() - defer ts.Close() + ts := newTykTestServer() + defer ts.Close() - buildAndLoadAPI() + buildAndLoadAPI() - for i := 0; i < 10; i++ { - wg.Add(1) - go func() { - ts.Run(t, test.TestCase{Path: "/sample", Code: 200}) - wg.Done() - }() - } + for i := 0; i < 10; i++ { + wg.Add(1) + go func() { + ts.Run(t, test.TestCase{Path: "/sample", Code: 200}) + wg.Done() + }() + } - for j := 0; j < 5; j++ { - buildAndLoadAPI() - } + for j := 0; j < 5; j++ { + buildAndLoadAPI() + } - wg.Wait() + wg.Wait() } diff --git a/policy_test.go b/policy_test.go index a37a02b1405..b7b64425afb 100644 --- a/policy_test.go +++ b/policy_test.go @@ -628,4 +628,40 @@ func TestApplyPoliciesQuotaAPILimit(t *testing.T) { }, }, }...) + + // Reset quota + ts.Run(t, []test.TestCase{ + { + Method: http.MethodPut, + Path: "/tyk/keys/" + key, + AdminAuth: true, + Code: http.StatusOK, + Data: session, + }, + { + Method: http.MethodGet, + Path: "/tyk/keys/" + key, + AdminAuth: true, + Code: http.StatusOK, + BodyMatchFunc: func(data []byte) bool { + sessionData := user.SessionState{} + if err := json.Unmarshal(data, &sessionData); err != nil { + t.Log(err.Error()) + return false + } + api1Limit := sessionData.AccessRights["api1"].Limit + if api1Limit == nil { + t.Error("api1 limit is not set") + return false + } + + if api1Limit.QuotaRemaining != 100 { + t.Error("Should reset quota:", api1Limit.QuotaRemaining) + return false + } + + return true + }, + }, + }...) }