Skip to content

Commit

Permalink
Add tests for getSessionId helper
Browse files Browse the repository at this point in the history
  • Loading branch information
sosedoff committed Feb 26, 2016
1 parent 86f63ee commit c57b477
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 7 deletions.
4 changes: 2 additions & 2 deletions pkg/api/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ var (

func DB(c *gin.Context) *client.Client {
if command.Opts.Sessions {
return DbSessions[getSessionId(c)]
return DbSessions[getSessionId(c.Request)]
} else {
return DbClient
}
Expand All @@ -39,7 +39,7 @@ func setClient(c *gin.Context, newClient *client.Client) error {
return nil
}

sessionId := getSessionId(c)
sessionId := getSessionId(c.Request)
if sessionId == "" {
return errors.New("Session ID is required")
}
Expand Down
7 changes: 4 additions & 3 deletions pkg/api/helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package api
import (
"fmt"
"mime"
"net/http"
"path/filepath"
"strconv"
"strings"
Expand Down Expand Up @@ -70,10 +71,10 @@ func desanitize64(query string) string {
return query
}

func getSessionId(c *gin.Context) string {
id := c.Request.Header.Get("x-session-id")
func getSessionId(req *http.Request) string {
id := req.Header.Get("x-session-id")
if id == "" {
id = c.Request.URL.Query().Get("_session_id")
id = req.URL.Query().Get("_session_id")
}
return id
}
Expand Down
15 changes: 14 additions & 1 deletion pkg/api/helpers_test.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
package api

import (
"github.com/stretchr/testify/assert"
"net/http"
"net/url"
"testing"

"github.com/stretchr/testify/assert"
)

func Test_desanitize64(t *testing.T) {
Expand All @@ -23,3 +26,13 @@ func Test_cleanQuery(t *testing.T) {
assert.Equal(t, "", cleanQuery("--something"))
assert.Equal(t, "test", cleanQuery("--test\ntest\n -- test\n"))
}

func Test_getSessionId(t *testing.T) {
req := &http.Request{Header: http.Header{}}
req.Header.Add("x-session-id", "token")
assert.Equal(t, "token", getSessionId(req))

req = &http.Request{}
req.URL, _ = url.Parse("http://foobar/?_session_id=token")
assert.Equal(t, "token", getSessionId(req))
}
2 changes: 1 addition & 1 deletion pkg/api/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ func dbCheckMiddleware() gin.HandlerFunc {
return
}

sessionId := getSessionId(c)
sessionId := getSessionId(c.Request)
if sessionId == "" {
c.JSON(400, Error{"Session ID is required"})
c.Abort()
Expand Down

0 comments on commit c57b477

Please sign in to comment.