diff --git a/api.go b/api.go
index f16d947731e..2ee9c10d0fd 100644
--- a/api.go
+++ b/api.go
@@ -374,6 +374,7 @@ func handleGetDetail(sessionKey, apiID string, byHash bool) (interface{}, int) {
if usedQuota, err := sessionManager.Store().GetRawKey(limQuotaKey); err == nil {
qInt, _ := strconv.Atoi(usedQuota)
remaining := access.Limit.QuotaMax - int64(qInt)
+
if remaining < 0 {
access.Limit.QuotaRemaining = 0
} else {
@@ -381,6 +382,9 @@ func handleGetDetail(sessionKey, apiID string, byHash bool) (interface{}, int) {
}
session.AccessRights[id] = access
} else {
+ access.Limit.QuotaRemaining = access.Limit.QuotaMax
+ session.AccessRights[id] = access
+
log.WithFields(logrus.Fields{
"prefix": "api",
"apiID": id,
diff --git a/auth_manager.go b/auth_manager.go
index 8c96dc06bc6..aae839fcb16 100644
--- a/auth_manager.go
+++ b/auth_manager.go
@@ -202,6 +202,11 @@ func (b *DefaultSessionManager) ResetQuota(keyName string, session *user.Session
// Fix the raw key
go b.store.DeleteRawKey(rawKey)
//go b.store.SetKey(rawKey, "0", session.QuotaRenewalRate)
+
+ for apiID := range session.AccessRights {
+ rawKey = QuotaKeyPrefix + apiID + "-" + keyName
+ go b.store.DeleteRawKey(rawKey)
+ }
}
// UpdateSession updates the session state in the storage engine
diff --git a/looping_test.go b/looping_test.go
index 9e9886de26b..48b9fa82139 100644
--- a/looping_test.go
+++ b/looping_test.go
@@ -4,27 +4,27 @@
package main
import (
- "encoding/json"
- "sync"
- "testing"
+ "encoding/json"
+ "sync"
+ "testing"
- "github.com/TykTechnologies/tyk/test"
- "github.com/TykTechnologies/tyk/user"
+ "github.com/TykTechnologies/tyk/test"
+ "github.com/TykTechnologies/tyk/user"
)
func TestLooping(t *testing.T) {
- ts := newTykTestServer()
- defer ts.Close()
-
- postAction := `data`
- getAction := `data`
-
- t.Run("Using advanced URL rewrite", func(t *testing.T) {
- // We defined internnal advanced rewrite based on body data
- // which rewrites to internal paths (marked as blacklist so they protected from outside world)
- buildAndLoadAPI(func(spec *APISpec) {
- version := spec.VersionData.Versions["v1"]
- json.Unmarshal([]byte(`{
+ ts := newTykTestServer()
+ defer ts.Close()
+
+ postAction := `data`
+ getAction := `data`
+
+ t.Run("Using advanced URL rewrite", func(t *testing.T) {
+ // We defined internnal advanced rewrite based on body data
+ // which rewrites to internal paths (marked as blacklist so they protected from outside world)
+ buildAndLoadAPI(func(spec *APISpec) {
+ version := spec.VersionData.Versions["v1"]
+ json.Unmarshal([]byte(`{
"use_extended_paths": true,
"extended_paths": {
"internal": [{
@@ -67,46 +67,46 @@ func TestLooping(t *testing.T) {
}
}`), &version)
- spec.VersionData.Versions["v1"] = version
+ spec.VersionData.Versions["v1"] = version
- spec.Proxy.ListenPath = "/"
- })
+ spec.Proxy.ListenPath = "/"
+ })
- ts.Run(t, []test.TestCase{
- {Method: "POST", Path: "/xml", Data: postAction, BodyMatch: `"Url":"/post_action`},
+ ts.Run(t, []test.TestCase{
+ {Method: "POST", Path: "/xml", Data: postAction, BodyMatch: `"Url":"/post_action`},
- // Should retain original query params
- {Method: "POST", Path: "/xml?a=b", Data: getAction, BodyMatch: `"Url":"/get_action`},
+ // Should retain original query params
+ {Method: "POST", Path: "/xml?a=b", Data: getAction, BodyMatch: `"Url":"/get_action`},
- // Should rewrite http method, if loop rewrite param passed
- {Method: "POST", Path: "/xml", Data: getAction, BodyMatch: `"Method":"GET"`},
+ // Should rewrite http method, if loop rewrite param passed
+ {Method: "POST", Path: "/xml", Data: getAction, BodyMatch: `"Method":"GET"`},
- // Internal endpoint can be accessed only via looping
- {Method: "GET", Path: "/get_action", Code: 403},
+ // Internal endpoint can be accessed only via looping
+ {Method: "GET", Path: "/get_action", Code: 403},
- {Method: "POST", Path: "/get_action", Code: 403},
- }...)
- })
+ {Method: "POST", Path: "/get_action", Code: 403},
+ }...)
+ })
- t.Run("Loop to another API", func(t *testing.T) {
- buildAndLoadAPI(func(spec *APISpec) {
- spec.APIID = "testid"
- spec.Name = "hidden api"
- spec.Proxy.ListenPath = "/somesecret"
- spec.Internal = true
- version := spec.VersionData.Versions["v1"]
- json.Unmarshal([]byte(`{
+ t.Run("Loop to another API", func(t *testing.T) {
+ buildAndLoadAPI(func(spec *APISpec) {
+ spec.APIID = "testid"
+ spec.Name = "hidden api"
+ spec.Proxy.ListenPath = "/somesecret"
+ spec.Internal = true
+ version := spec.VersionData.Versions["v1"]
+ json.Unmarshal([]byte(`{
"use_extended_paths": true,
"global_headers": {
"X-Name":"internal"
}
}`), &version)
- spec.VersionData.Versions["v1"] = version
- }, func(spec *APISpec) {
- spec.Proxy.ListenPath = "/test"
+ spec.VersionData.Versions["v1"] = version
+ }, func(spec *APISpec) {
+ spec.Proxy.ListenPath = "/test"
- version := spec.VersionData.Versions["v1"]
- json.Unmarshal([]byte(`{
+ version := spec.VersionData.Versions["v1"]
+ json.Unmarshal([]byte(`{
"use_extended_paths": true,
"extended_paths": {
"url_rewrites": [{
@@ -128,19 +128,19 @@ func TestLooping(t *testing.T) {
}
}`), &version)
- spec.VersionData.Versions["v1"] = version
- })
+ spec.VersionData.Versions["v1"] = version
+ })
- ts.Run(t, []test.TestCase{
- {Path: "/somesecret", Code: 404},
- {Path: "/test/by_name", Code: 200, BodyMatch: `"X-Name":"internal"`},
- {Path: "/test/by_id", Code: 200, BodyMatch: `"X-Name":"internal"`},
- {Path: "/test/wrong", Code: 500},
- }...)
- })
+ ts.Run(t, []test.TestCase{
+ {Path: "/somesecret", Code: 404},
+ {Path: "/test/by_name", Code: 200, BodyMatch: `"X-Name":"internal"`},
+ {Path: "/test/by_id", Code: 200, BodyMatch: `"X-Name":"internal"`},
+ {Path: "/test/wrong", Code: 500},
+ }...)
+ })
- t.Run("VirtualEndpoint or plugins", func(t *testing.T) {
- testPrepareVirtualEndpoint(`
+ t.Run("VirtualEndpoint or plugins", func(t *testing.T) {
+ testPrepareVirtualEndpoint(`
function testVirtData(request, session, config) {
var loopLocation = "/default"
@@ -160,21 +160,21 @@ func TestLooping(t *testing.T) {
}
`, "POST", "/virt", true)
- ts.Run(t, []test.TestCase{
- {Method: "POST", Path: "/virt", Data: postAction, BodyMatch: `"Url":"/post_action`},
+ ts.Run(t, []test.TestCase{
+ {Method: "POST", Path: "/virt", Data: postAction, BodyMatch: `"Url":"/post_action`},
- // Should retain original query params
- {Method: "POST", Path: "/virt?a=b", Data: getAction, BodyMatch: `"Url":"/get_action`},
+ // Should retain original query params
+ {Method: "POST", Path: "/virt?a=b", Data: getAction, BodyMatch: `"Url":"/get_action`},
- // Should rewrite http method, if loop rewrite param passed
- {Method: "POST", Path: "/virt", Data: getAction, BodyMatch: `"Method":"GET"`},
- }...)
- })
+ // Should rewrite http method, if loop rewrite param passed
+ {Method: "POST", Path: "/virt", Data: getAction, BodyMatch: `"Method":"GET"`},
+ }...)
+ })
- t.Run("Loop limit", func(t *testing.T) {
- buildAndLoadAPI(func(spec *APISpec) {
- version := spec.VersionData.Versions["v1"]
- json.Unmarshal([]byte(`{
+ t.Run("Loop limit", func(t *testing.T) {
+ buildAndLoadAPI(func(spec *APISpec) {
+ version := spec.VersionData.Versions["v1"]
+ json.Unmarshal([]byte(`{
"use_extended_paths": true,
"extended_paths": {
"url_rewrites": [{
@@ -186,19 +186,19 @@ func TestLooping(t *testing.T) {
}
}`), &version)
- spec.VersionData.Versions["v1"] = version
- spec.Proxy.ListenPath = "/"
- })
+ spec.VersionData.Versions["v1"] = version
+ spec.Proxy.ListenPath = "/"
+ })
- ts.Run(t, []test.TestCase{
- {Method: "GET", Path: "/recursion", Code: 500, BodyMatch: "Loop level too deep. Found more than 2 loops in single request"},
- }...)
- })
+ ts.Run(t, []test.TestCase{
+ {Method: "GET", Path: "/recursion", Code: 500, BodyMatch: "Loop level too deep. Found more than 2 loops in single request"},
+ }...)
+ })
- t.Run("Quota and rate limit calculation", func(t *testing.T) {
- buildAndLoadAPI(func(spec *APISpec) {
- version := spec.VersionData.Versions["v1"]
- json.Unmarshal([]byte(`{
+ t.Run("Quota and rate limit calculation", func(t *testing.T) {
+ buildAndLoadAPI(func(spec *APISpec) {
+ version := spec.VersionData.Versions["v1"]
+ json.Unmarshal([]byte(`{
"use_extended_paths": true,
"extended_paths": {
"url_rewrites": [{
@@ -210,42 +210,42 @@ func TestLooping(t *testing.T) {
}
}`), &version)
- spec.VersionData.Versions["v1"] = version
- spec.Proxy.ListenPath = "/"
- spec.UseKeylessAccess = false
- })
+ spec.VersionData.Versions["v1"] = version
+ spec.Proxy.ListenPath = "/"
+ spec.UseKeylessAccess = false
+ })
- keyID := createSession(func(s *user.SessionState) {
- s.QuotaMax = 2
- })
+ keyID := createSession(func(s *user.SessionState) {
+ s.QuotaMax = 2
+ })
- authHeaders := map[string]string{"authorization": keyID}
+ authHeaders := map[string]string{"authorization": keyID}
- ts.Run(t, []test.TestCase{
- {Method: "GET", Path: "/recursion", Headers: authHeaders, BodyNotMatch: "Quota exceeded"},
- }...)
- })
+ ts.Run(t, []test.TestCase{
+ {Method: "GET", Path: "/recursion", Headers: authHeaders, BodyNotMatch: "Quota exceeded"},
+ }...)
+ })
}
func TestConcurrencyReloads(t *testing.T) {
- var wg sync.WaitGroup
+ var wg sync.WaitGroup
- ts := newTykTestServer()
- defer ts.Close()
+ ts := newTykTestServer()
+ defer ts.Close()
- buildAndLoadAPI()
+ buildAndLoadAPI()
- for i := 0; i < 10; i++ {
- wg.Add(1)
- go func() {
- ts.Run(t, test.TestCase{Path: "/sample", Code: 200})
- wg.Done()
- }()
- }
+ for i := 0; i < 10; i++ {
+ wg.Add(1)
+ go func() {
+ ts.Run(t, test.TestCase{Path: "/sample", Code: 200})
+ wg.Done()
+ }()
+ }
- for j := 0; j < 5; j++ {
- buildAndLoadAPI()
- }
+ for j := 0; j < 5; j++ {
+ buildAndLoadAPI()
+ }
- wg.Wait()
+ wg.Wait()
}
diff --git a/policy_test.go b/policy_test.go
index a37a02b1405..b7b64425afb 100644
--- a/policy_test.go
+++ b/policy_test.go
@@ -628,4 +628,40 @@ func TestApplyPoliciesQuotaAPILimit(t *testing.T) {
},
},
}...)
+
+ // Reset quota
+ ts.Run(t, []test.TestCase{
+ {
+ Method: http.MethodPut,
+ Path: "/tyk/keys/" + key,
+ AdminAuth: true,
+ Code: http.StatusOK,
+ Data: session,
+ },
+ {
+ Method: http.MethodGet,
+ Path: "/tyk/keys/" + key,
+ AdminAuth: true,
+ Code: http.StatusOK,
+ BodyMatchFunc: func(data []byte) bool {
+ sessionData := user.SessionState{}
+ if err := json.Unmarshal(data, &sessionData); err != nil {
+ t.Log(err.Error())
+ return false
+ }
+ api1Limit := sessionData.AccessRights["api1"].Limit
+ if api1Limit == nil {
+ t.Error("api1 limit is not set")
+ return false
+ }
+
+ if api1Limit.QuotaRemaining != 100 {
+ t.Error("Should reset quota:", api1Limit.QuotaRemaining)
+ return false
+ }
+
+ return true
+ },
+ },
+ }...)
}