Skip to content

Commit

Permalink
all: use helper funcs for ctx's SessionData (TykTechnologies#734)
Browse files Browse the repository at this point in the history
Also use a pointer type for more arguments and return values, as we need
a pointer for the context value. This is because we need to be able to
return nil (non-existing) in ctxGetSession. Besides, sticking
SessionState into an interface{} will need a dereference anyway.

Updates TykTechnologies#683.
  • Loading branch information
mvdan authored and buger committed May 12, 2017
1 parent 922b4e4 commit ac8ddfc
Show file tree
Hide file tree
Showing 33 changed files with 346 additions and 335 deletions.
61 changes: 37 additions & 24 deletions api.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ func checkAndApplyTrialPeriod(keyName, apiId string, newSession *SessionState) {
}
}

func doAddOrUpdate(keyName string, newSession SessionState, dontReset bool) error {
func doAddOrUpdate(keyName string, newSession *SessionState, dontReset bool) error {
newSession.LastUpdated = strconv.Itoa(int(time.Now().Unix()))

if len(newSession.AccessRights) > 0 {
Expand All @@ -117,7 +117,7 @@ func doAddOrUpdate(keyName string, newSession SessionState, dontReset bool) erro
}).Error("Could not add key for this API ID, API doesn't exist.")
return errors.New("API must be active to add keys")
}
checkAndApplyTrialPeriod(keyName, apiId, &newSession)
checkAndApplyTrialPeriod(keyName, apiId, newSession)

// Lets reset keys if they are edited by admin
if !apiSpec.DontSetQuotasOnCreate {
Expand All @@ -127,7 +127,7 @@ func doAddOrUpdate(keyName string, newSession SessionState, dontReset bool) erro
newSession.QuotaRenews = time.Now().Unix() + newSession.QuotaRenewalRate
}

err := apiSpec.SessionManager.UpdateSession(keyName, newSession, getLifetime(apiSpec, &newSession))
err := apiSpec.SessionManager.UpdateSession(keyName, newSession, getLifetime(apiSpec, newSession))
if err != nil {
return err
}
Expand All @@ -145,8 +145,8 @@ func doAddOrUpdate(keyName string, newSession SessionState, dontReset bool) erro
spec.SessionManager.ResetQuota(keyName, newSession)
newSession.QuotaRenews = time.Now().Unix() + newSession.QuotaRenewalRate
}
checkAndApplyTrialPeriod(keyName, spec.APIID, &newSession)
err := spec.SessionManager.UpdateSession(keyName, newSession, getLifetime(spec, &newSession))
checkAndApplyTrialPeriod(keyName, spec.APIID, newSession)
err := spec.SessionManager.UpdateSession(keyName, newSession, getLifetime(spec, newSession))
if err != nil {
return err
}
Expand Down Expand Up @@ -248,7 +248,7 @@ func handleAddOrUpdate(keyName string, r *http.Request) ([]byte, int) {

}
suppressReset := r.FormValue("suppress_reset") == "1"
if err := doAddOrUpdate(keyName, newSession, suppressReset); err != nil {
if err := doAddOrUpdate(keyName, &newSession, suppressReset); err != nil {
return createError("Failed to create key, ensure security settings are correct."), 500
}

Expand Down Expand Up @@ -398,7 +398,7 @@ func handleDeleteKey(keyName, apiID string) ([]byte, int) {
// Go through ALL managed API's and delete the key
for _, spec := range ApiSpecRegister {
spec.SessionManager.RemoveSession(keyName)
spec.SessionManager.ResetQuota(keyName, SessionState{})
spec.SessionManager.ResetQuota(keyName, &SessionState{})
}

log.WithFields(logrus.Fields{
Expand All @@ -421,7 +421,7 @@ func handleDeleteKey(keyName, apiID string) ([]byte, int) {
}

sessionManager.RemoveSession(keyName)
sessionManager.ResetQuota(keyName, SessionState{})
sessionManager.ResetQuota(keyName, &SessionState{})

statusObj := APIModifyKeySuccess{keyName, "ok", "deleted"}
responseMessage, err = json.Marshal(&statusObj)
Expand Down Expand Up @@ -920,9 +920,9 @@ func orgHandler(w http.ResponseWriter, r *http.Request) {
}

func handleOrgAddOrUpdate(keyName string, r *http.Request) ([]byte, int) {
var newSession SessionState
newSession := new(SessionState)

if err := json.NewDecoder(r.Body).Decode(&newSession); err != nil {
if err := json.NewDecoder(r.Body).Decode(newSession); err != nil {
log.Error("Couldn't decode new session object: ", err)
return createError("Request malformed"), 400
}
Expand Down Expand Up @@ -1146,8 +1146,8 @@ func createKeyHandler(w http.ResponseWriter, r *http.Request) {
return
}

var newSession SessionState
if err := json.NewDecoder(r.Body).Decode(&newSession); err != nil {
newSession := new(SessionState)
if err := json.NewDecoder(r.Body).Decode(newSession); err != nil {
log.WithFields(logrus.Fields{
"prefix": "api",
"status": "fail",
Expand All @@ -1168,14 +1168,14 @@ func createKeyHandler(w http.ResponseWriter, r *http.Request) {
for apiID := range newSession.AccessRights {
apiSpec := GetSpecForApi(apiID)
if apiSpec != nil {
checkAndApplyTrialPeriod(newKey, apiID, &newSession)
checkAndApplyTrialPeriod(newKey, apiID, newSession)
// If we have enabled HMAC checking for keys, we need to generate a secret for the client to use
if !apiSpec.DontSetQuotasOnCreate {
// Reset quota by default
apiSpec.SessionManager.ResetQuota(newKey, newSession)
newSession.QuotaRenews = time.Now().Unix() + newSession.QuotaRenewalRate
}
err := apiSpec.SessionManager.UpdateSession(newKey, newSession, getLifetime(apiSpec, &newSession))
err := apiSpec.SessionManager.UpdateSession(newKey, newSession, getLifetime(apiSpec, newSession))
if err != nil {
responseMessage := createError("Failed to create key - " + err.Error())
doJSONWrite(w, 403, responseMessage)
Expand Down Expand Up @@ -1209,13 +1209,13 @@ func createKeyHandler(w http.ResponseWriter, r *http.Request) {
}).Warning("No API Access Rights set on key session, adding key to all APIs.")

for _, spec := range ApiSpecRegister {
checkAndApplyTrialPeriod(newKey, spec.APIID, &newSession)
checkAndApplyTrialPeriod(newKey, spec.APIID, newSession)
if !spec.DontSetQuotasOnCreate {
// Reset quote by default
spec.SessionManager.ResetQuota(newKey, newSession)
newSession.QuotaRenews = time.Now().Unix() + newSession.QuotaRenewalRate
}
err := spec.SessionManager.UpdateSession(newKey, newSession, getLifetime(spec, &newSession))
err := spec.SessionManager.UpdateSession(newKey, newSession, getLifetime(spec, newSession))
if err != nil {
responseMessage := createError("Failed to create key - " + err.Error())
doJSONWrite(w, 403, responseMessage)
Expand Down Expand Up @@ -1740,20 +1740,19 @@ func healthCheckhandler(w http.ResponseWriter, r *http.Request) {

func UserRatesCheck() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
sessionState := context.Get(r, SessionData)
if sessionState == nil {
session := ctxGetSession(r)
if session == nil {
responseMessage := createError("Health checks are not enabled for this node")
doJSONWrite(w, 405, responseMessage)
return
}

userSession := sessionState.(SessionState)
returnSession := PublicSessionState{}
returnSession.Quota.QuotaRenews = userSession.QuotaRenews
returnSession.Quota.QuotaRemaining = userSession.QuotaRemaining
returnSession.Quota.QuotaMax = userSession.QuotaMax
returnSession.RateLimit.Rate = userSession.Rate
returnSession.RateLimit.Per = userSession.Per
returnSession.Quota.QuotaRenews = session.QuotaRenews
returnSession.Quota.QuotaRemaining = session.QuotaRemaining
returnSession.Quota.QuotaMax = session.QuotaMax
returnSession.RateLimit.Rate = session.Rate
returnSession.RateLimit.Per = session.Per

responseMessage, err := json.Marshal(returnSession)
if err != nil {
Expand Down Expand Up @@ -1836,3 +1835,17 @@ func ctxSetData(r *http.Request, m map[string]interface{}) {
}
context.Set(r, ContextData, m)
}

func ctxGetSession(r *http.Request) *SessionState {
if v := context.Get(r, SessionData); v != nil {
return v.(*SessionState)
}
return nil
}

func ctxSetSession(r *http.Request, s *SessionState) {
if s == nil {
panic("setting a nil context SessionData")
}
context.Set(r, SessionData, s)
}
33 changes: 25 additions & 8 deletions api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,8 @@ func TestHealthCheckEndpoint(t *testing.T) {
}
}

func createSampleSession() SessionState {
return SessionState{
func createSampleSession() *SessionState {
return &SessionState{
Rate: 5.0,
Allowance: 5.0,
LastCheck: time.Now().Unix(),
Expand All @@ -119,7 +119,7 @@ func TestApiHandler(t *testing.T) {

for _, uri := range uris {
sampleKey := createSampleSession()
body, _ := json.Marshal(&sampleKey)
body, _ := json.Marshal(sampleKey)

recorder := httptest.NewRecorder()

Expand Down Expand Up @@ -154,7 +154,7 @@ func TestApiHandler(t *testing.T) {
func TestApiHandlerGetSingle(t *testing.T) {
uri := "/tyk/apis/1"
sampleKey := createSampleSession()
body, _ := json.Marshal(&sampleKey)
body, _ := json.Marshal(sampleKey)

recorder := httptest.NewRecorder()

Expand Down Expand Up @@ -272,7 +272,7 @@ func TestKeyHandlerNewKey(t *testing.T) {
for _, api_id := range []string{"1", "none", ""} {
uri := "/tyk/keys/1234"
sampleKey := createSampleSession()
body, _ := json.Marshal(&sampleKey)
body, _ := json.Marshal(sampleKey)

recorder := httptest.NewRecorder()
param := make(url.Values)
Expand Down Expand Up @@ -309,7 +309,7 @@ func TestKeyHandlerUpdateKey(t *testing.T) {
for _, api_id := range []string{"1", "none", ""} {
uri := "/tyk/keys/1234"
sampleKey := createSampleSession()
body, _ := json.Marshal(&sampleKey)
body, _ := json.Marshal(sampleKey)

recorder := httptest.NewRecorder()
param := make(url.Values)
Expand Down Expand Up @@ -378,7 +378,7 @@ func TestKeyHandlerGetKey(t *testing.T) {
func createKey() {
uri := "/tyk/keys/1234"
sampleKey := createSampleSession()
body, _ := json.Marshal(&sampleKey)
body, _ := json.Marshal(sampleKey)

recorder := httptest.NewRecorder()
req, _ := http.NewRequest("POST", uri, bytes.NewReader(body))
Expand Down Expand Up @@ -429,7 +429,7 @@ func TestCreateKeyHandlerCreateNewKey(t *testing.T) {
uri := "/tyk/keys/create"

sampleKey := createSampleSession()
body, _ := json.Marshal(&sampleKey)
body, _ := json.Marshal(sampleKey)

recorder := httptest.NewRecorder()
param := make(url.Values)
Expand Down Expand Up @@ -715,3 +715,20 @@ func TestContextData(t *testing.T) {
}()
ctxSetData(r, nil)
}

func TestContextSession(t *testing.T) {
r := new(http.Request)
if ctxGetSession(r) != nil {
t.Fatal("expected ctxGetSession to return nil")
}
ctxSetSession(r, &SessionState{})
if ctxGetSession(r) == nil {
t.Fatal("expected ctxGetSession to return non-nil")
}
defer func() {
if r := recover(); r == nil {
t.Fatal("expected ctxSetSession of zero val to panic")
}
}()
ctxSetSession(r, nil)
}
8 changes: 4 additions & 4 deletions auth_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,12 @@ type AuthorisationHandler interface {
// SessionState objects, not identity
type SessionHandler interface {
Init(store StorageHandler)
UpdateSession(keyName string, session SessionState, resetTTLTo int64) error
UpdateSession(keyName string, session *SessionState, resetTTLTo int64) error
RemoveSession(keyName string)
GetSessionDetail(keyName string) (SessionState, bool)
GetSessions(filter string) []string
GetStore() StorageHandler
ResetQuota(string, SessionState)
ResetQuota(string, *SessionState)
}

// DefaultAuthorisationManager implements AuthorisationHandler,
Expand Down Expand Up @@ -87,7 +87,7 @@ func (b *DefaultSessionManager) GetStore() StorageHandler {
return b.Store
}

func (b *DefaultSessionManager) ResetQuota(keyName string, session SessionState) {
func (b *DefaultSessionManager) ResetQuota(keyName string, session *SessionState) {

rawKey := QuotaKeyPrefix + publicHash(keyName)
log.WithFields(logrus.Fields{
Expand All @@ -105,7 +105,7 @@ func (b *DefaultSessionManager) ResetQuota(keyName string, session SessionState)
}

// UpdateSession updates the session state in the storage engine
func (b *DefaultSessionManager) UpdateSession(keyName string, session SessionState, resetTTLTo int64) error {
func (b *DefaultSessionManager) UpdateSession(keyName string, session *SessionState, resetTTLTo int64) error {
if !session.HasChanged() {
log.Debug("Session has not changed, not updating")
return nil
Expand Down
15 changes: 7 additions & 8 deletions coprocess.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,10 +112,9 @@ func (c *CoProcessor) GetObjectFromRequest(r *http.Request) *coprocess.Object {

// Encode the session object (if not a pre-process & not a custom key check):
if c.HookType != coprocess.HookType_Pre && c.HookType != coprocess.HookType_CustomKeyCheck {
session := context.Get(r, SessionData)
session := ctxGetSession(r)
if session != nil {
sessionState := session.(SessionState)
object.Session = ProtoSessionState(sessionState)
object.Session = ProtoSessionState(session)
}
}

Expand Down Expand Up @@ -294,17 +293,17 @@ func (m *CoProcessMiddleware) ProcessRequest(w http.ResponseWriter, r *http.Requ
return errors.New("Key not authorised"), 403
}

returnedSessionState := TykSessionState(returnObject.Session)
returnedSession := TykSessionState(returnObject.Session)

if extractor == nil {
sessionLifetime := getLifetime(m.Spec, &returnedSessionState)
sessionLifetime := getLifetime(m.Spec, returnedSession)
// This API is not using the ID extractor, but we've got a session:
m.Spec.SessionManager.UpdateSession(authHeaderValue, returnedSessionState, sessionLifetime)
context.Set(r, SessionData, returnedSessionState)
m.Spec.SessionManager.UpdateSession(authHeaderValue, returnedSession, sessionLifetime)
ctxSetSession(r, returnedSession)
context.Set(r, AuthHeaderValue, authHeaderValue)
} else {
// The CP middleware did setup a session, we should pass it to the ID extractor (caching):
extractor.PostProcess(r, returnedSessionState, sessionID)
extractor.PostProcess(r, returnedSession, sessionID)
}
}

Expand Down
Loading

0 comments on commit ac8ddfc

Please sign in to comment.