Skip to content

Commit

Permalink
Merge pull request lestrrat-go#99 from lestrrat-go/topic/coverage
Browse files Browse the repository at this point in the history
coverage for jwk package
  • Loading branch information
lestrrat authored Apr 12, 2019
2 parents 985f89f + 6894ef5 commit 88d7d7d
Show file tree
Hide file tree
Showing 2 changed files with 119 additions and 40 deletions.
31 changes: 17 additions & 14 deletions jwk/jwk.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,16 @@
package jwk

import (
"bytes"
"crypto/ecdsa"
"crypto/rsa"
"encoding/json"
"io"
"io/ioutil"
"net/http"
"net/url"
"os"
"strings"

"github.com/lestrrat-go/jwx/internal/base64"
"github.com/lestrrat-go/jwx/jwa"
Expand Down Expand Up @@ -71,7 +74,6 @@ func Fetch(urlstring string, options ...Option) (*Set, error) {
return nil, errors.Wrap(err, `failed to parse url`)
}

var src []byte
switch u.Scheme {
case "http", "https":
return FetchHTTP(urlstring, options...)
Expand All @@ -86,12 +88,9 @@ func Fetch(urlstring string, options ...Option) (*Set, error) {
if err != nil {
return nil, errors.Wrap(err, `failed read content from jwk file`)
}
src = buf
default:
return nil, errors.Errorf(`invalid url scheme %s`, u.Scheme)
return ParseBytes(buf)
}

return Parse(src)
return nil, errors.Errorf(`invalid url scheme %s`, u.Scheme)
}

// FetchHTTP fetches the remote JWK and parses its contents
Expand All @@ -114,28 +113,27 @@ func FetchHTTP(jwkurl string, options ...Option) (*Set, error) {
return nil, errors.New("failed to fetch remote JWK (status != 200)")
}

// XXX Check for maximum length to read?
buf, err := ioutil.ReadAll(res.Body)
if err != nil {
return nil, errors.Wrap(err, "failed to read JWK HTTP response body")
}

return Parse(buf)
return ParseBytes(buf)
}

func (set *Set) UnmarshalJSON(data []byte) error {
v, err := Parse(data)
v, err := ParseBytes(data)
if err != nil {
return errors.Wrap(err, `failed to parse jwk.Set`)
}
*set = *v
return nil
}

// Parse parses JWK from the incoming byte buffer.
func Parse(buf []byte) (*Set, error) {
// Parse parses JWK from the incoming io.Reader.
func Parse(in io.Reader) (*Set, error) {
m := make(map[string]interface{})
if err := json.Unmarshal(buf, &m); err != nil {
if err := json.NewDecoder(in).Decode(&m); err != nil {
return nil, errors.Wrap(err, "failed to unmarshal JWK")
}

Expand All @@ -158,9 +156,14 @@ func Parse(buf []byte) (*Set, error) {
return &Set{Keys: []Key{k}}, nil
}

// ParseBytes parses JWK from the incoming byte buffer.
func ParseBytes(buf []byte) (*Set, error) {
return Parse(bytes.NewReader(buf))
}

// ParseString parses JWK from the incoming string.
func ParseString(s string) (*Set, error) {
return Parse([]byte(s))
return Parse(strings.NewReader(s))
}

// LookupKeyID looks for keys matching the given key id. Note that the
Expand Down Expand Up @@ -205,7 +208,7 @@ func (s *Set) ExtractMap(m map[string]interface{}) error {
}

func constructKey(m map[string]interface{}) (Key, error) {
kty, ok := m["kty"].(string)
kty, ok := m[KeyTypeKey].(string)
if !ok {
return nil, errors.Errorf(`unsupported kty type %T`, m[KeyTypeKey])
}
Expand Down
128 changes: 102 additions & 26 deletions jwk/jwk_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,10 @@ import (
"crypto/rsa"
"encoding/json"
"io"
"io/ioutil"
"net/http"
"net/http/httptest"
"os"
"testing"

"github.com/lestrrat-go/jwx/internal/base64"
Expand All @@ -18,6 +20,16 @@ import (
"github.com/stretchr/testify/assert"
)

func TestNew(t *testing.T) {
k, err := jwk.New(nil)
if !assert.Nil(t, k, "key should be nil") {
return
}
if !assert.Error(t, err, "nil key should cause an error") {
return
}
}

func TestParse(t *testing.T) {
verify := func(t *testing.T, src string, expected interface{}) {
t.Run("json.Unmarshal", func(t *testing.T) {
Expand All @@ -36,7 +48,7 @@ func TestParse(t *testing.T) {
}
})
t.Run("jwk.Parse", func(t *testing.T) {
set, err := jwk.Parse([]byte(src))
set, err := jwk.ParseBytes([]byte(src))
if !assert.NoError(t, err, `jwk.Parse should succeed`) {
return
}
Expand All @@ -48,6 +60,18 @@ func TestParse(t *testing.T) {
if !assert.IsType(t, expected, key, "key should be a jwk.RSAPublicKey") {
return
}

switch key := key.(type) {
case *jwk.RSAPrivateKey, *jwk.ECDSAPrivateKey:
realKey, err := key.(jwk.Key).Materialize()
if !assert.NoError(t, err, "failed to get underlying private key") {
return
}

if _, err := jwk.GetPublicKey(realKey); !assert.NoError(t, err, `failed to get public key from underlying private key`) {
return
}
}
}
})
}
Expand Down Expand Up @@ -76,6 +100,28 @@ func TestParse(t *testing.T) {
}`
verify(t, src, &jwk.RSAPrivateKey{})
})
t.Run("ECDSA Private Key", func(t *testing.T) {
const src = `{
"kty" : "EC",
"crv" : "P-256",
"x" : "SVqB4JcUD6lsfvqMr-OKUNUphdNn64Eay60978ZlL74",
"y" : "lf0u0pMj4lGAzZix5u4Cm5CMQIgMNpkwy163wtKYVKI",
"d" : "0g5vAEKzugrXaRbgKG0Tj2qJ5lMP4Bezds1_sTybkfk"
}`
verify(t, src, &jwk.ECDSAPrivateKey{})
})
t.Run("Invalid ECDSA Private Key", func(t *testing.T) {
const src = `{
"kty" : "EC",
"crv" : "P-256",
"y" : "lf0u0pMj4lGAzZix5u4Cm5CMQIgMNpkwy163wtKYVKI",
"d" : "0g5vAEKzugrXaRbgKG0Tj2qJ5lMP4Bezds1_sTybkfk"
}`
_, err := jwk.ParseString(src)
if !assert.Error(t, err, `jwk.ParseString should fail`) {
return
}
})
}

func TestRoundtrip(t *testing.T) {
Expand Down Expand Up @@ -192,7 +238,7 @@ func TestRoundtrip(t *testing.T) {
return
}

ks2, err := jwk.Parse(buf)
ks2, err := jwk.ParseBytes(buf)
if !assert.NoError(t, err, "JSON unmarshal succeeded") {
t.Logf("%s", buf)
return
Expand Down Expand Up @@ -306,7 +352,7 @@ func TestAppendix(t *testing.T) {
]
}`)

set, err := jwk.Parse(jwksrc)
set, err := jwk.ParseBytes(jwksrc)
if !assert.NoError(t, err, "Parse should succeed") {
return
}
Expand Down Expand Up @@ -354,7 +400,7 @@ func TestAppendix(t *testing.T) {
"kid":"HMAC key used in JWS spec Appendix A.1 example"}
]
}`)
set, err := jwk.Parse(jwksrc)
set, err := jwk.ParseBytes(jwksrc)
if !assert.NoError(t, err, "Parse should succeed") {
return
}
Expand Down Expand Up @@ -420,7 +466,7 @@ func TestAppendix(t *testing.T) {
]
}]}`)

set, err := jwk.Parse(jwksrc)
set, err := jwk.ParseBytes(jwksrc)
if !assert.NoError(t, err, "Parse should succeed") {
return
}
Expand Down Expand Up @@ -458,30 +504,60 @@ func TestFetch(t *testing.T) {
]
}`

srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/":
w.WriteHeader(http.StatusOK)
io.WriteString(w, jwksrc)
default:
w.WriteHeader(http.StatusNotFound)
verify := func(t *testing.T, set *jwk.Set) {
key, ok := set.Keys[0].(*jwk.ECDSAPublicKey)
if !assert.True(t, ok, "set.Keys[0] should be a EcdsaPublicKey") {
return
}
}))
defer srv.Close()

cl := srv.Client()

set, err := jwk.Fetch(srv.URL, jwk.WithHTTPClient(cl))
if !assert.NoError(t, err, `failed to fetch jwk`) {
return
if !assert.Equal(t, jwa.P256, key.Curve(), "curve is P-256") {
return
}
}
t.Run("HTTP", func(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/":
w.WriteHeader(http.StatusOK)
io.WriteString(w, jwksrc)
default:
w.WriteHeader(http.StatusNotFound)
}
}))
defer srv.Close()

key, ok := set.Keys[0].(*jwk.ECDSAPublicKey)
if !assert.True(t, ok, "set.Keys[0] should be a EcdsaPublicKey") {
return
}
cl := srv.Client()

if !assert.Equal(t, jwa.P256, key.Curve(), "curve is P-256") {
return
}
set, err := jwk.Fetch(srv.URL, jwk.WithHTTPClient(cl))
if !assert.NoError(t, err, `failed to fetch jwk`) {
return
}
verify(t, set)
})
t.Run("Local File", func(t *testing.T) {
f, err := ioutil.TempFile("", "jwk-fetch-test")
if !assert.NoError(t, err, `failed to generate temporary file`) {
return
}
defer f.Close()
defer os.Remove(f.Name())

io.WriteString(f, jwksrc)
f.Sync()

set, err := jwk.Fetch("file://" + f.Name())
if !assert.NoError(t, err, `failed to fetch jwk`) {
return
}
verify(t, set)
})
t.Run("Invalid Scheme", func(t *testing.T) {
set, err := jwk.Fetch("gopher://foo/bar")
if !assert.Nil(t, set, `set should be nil`) {
return
}
if !assert.Error(t, err, `invalid sche,e should be an error`) {
return
}
})
}

0 comments on commit 88d7d7d

Please sign in to comment.