Skip to content

Commit

Permalink
jwk.Set is an interface, and has a new API
Browse files Browse the repository at this point in the history
refs lestrrat-go#295

* jwk.Set has an API to mutate and examine its contents.
* jwk.Set is now an interface, forcing users to explicitly initialize it
  properly.
* jwk.Set can now be used concurrently
* jwk.Set marshals correctly into JSON
  • Loading branch information
lestrrat committed Jan 17, 2021
1 parent f7a2cf9 commit 28f8eec
Show file tree
Hide file tree
Showing 18 changed files with 170 additions and 162 deletions.
8 changes: 7 additions & 1 deletion cmd/jwx/jwx.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,14 @@ func doJWK() int {
}

// TODO make it flexible
firstKey, ok := key.Get(0)
if !ok {
log.Printf("empty keyset")
return 0
}

var pubkey interface{}
if err := key.Keys[0].Raw(&pubkey); err != nil {
if err := firstKey.Raw(&pubkey); err != nil {
log.Printf("%s", err)
return 0
}
Expand Down
8 changes: 5 additions & 3 deletions examples/jwt_example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ func ExampleJWT_ParseJWKS() {
// Then jwt.Parse() will automatically find the matching key

var payload []byte
var keyset *jwk.Set
var keyset jwk.Set
{ // Preparation:
// For demonstration purposes, we need to do some preparation
// Create a JWK key to sign the token (and also give a KeyID)
Expand Down Expand Up @@ -76,7 +76,9 @@ func ExampleJWT_ParseJWKS() {
bogusKey := jwk.NewSymmetricKey()

// This key set contains two keys, the first one is the correct one
keyset = &jwk.Set{Keys: []jwk.Key{pubKey, bogusKey}}
keyset = jwk.NewSet()
keyset.Add(pubKey)
keyset.Add(bogusKey)
}

{ // Actual verification:
Expand Down Expand Up @@ -104,7 +106,7 @@ func ExampleJWT_ParseJWKS() {
// key set. It would be an error if you have multiple keys in the KeySet.

var payload []byte
var keyset *jwk.Set
var keyset jwk.Set
{ // Preparation:
// Unlike our previous example, we DO NOT want to sign the payload.
// Therefore we do NOT set the "kid" value
Expand Down
12 changes: 6 additions & 6 deletions jwe/jwe_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -466,11 +466,11 @@ func Test_GHIssue207(t *testing.T) {
for _, tc := range testcases {
tc := tc
t.Run(tc.Name, func(t *testing.T) {
webKeys, err := jwk.ParseString(tc.Key)
if !assert.NoError(t, err, `jwk.ParseString should succeed`) {
webKey, err := jwk.ParseKey([]byte(tc.Key))
if !assert.NoError(t, err, `jwk.ParseKey should succeed`) {
return
}
webKey := webKeys.Keys[0]

thumbprint, err := webKey.Thumbprint(crypto.SHA1)
if !assert.NoError(t, err, `jwk.Thumbprint should succeed`) {
return
Expand Down Expand Up @@ -596,11 +596,11 @@ func TestDecodePredefined_Direct(t *testing.T) {
for _, tc := range testcases {
tc := tc
t.Run(tc.Algorithm.String(), func(t *testing.T) {
webKeys, err := jwk.ParseString(tc.Key)
if !assert.NoError(t, err, `jwk.ParseString should succeed`) {
webKey, err := jwk.ParseKey([]byte(tc.Key))
if !assert.NoError(t, err, `jwk.ParseKey should succeed`) {
return
}
webKey := webKeys.Keys[0]

thumbprint, err := webKey.Thumbprint(crypto.SHA1)
if !assert.NoError(t, err, `jwk.Thumbprint should succeed`) {
return
Expand Down
44 changes: 31 additions & 13 deletions jwk/ecdsa_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,16 @@ func TestECDSA(t *testing.T) {
return
}

if !assert.Len(t, set.Keys, 1, `should be 1 key`) {
if !assert.Equal(t, set.Len(), 1, `should be 1 key`) {
return
}

privKey, ok := set.Keys[0].(jwk.ECDSAPrivateKey)
akey, ok := set.Get(0)
if !assert.True(t, ok, `set.Get(0) should succeed`) {
return
}

privKey, ok := akey.(jwk.ECDSAPrivateKey)
if !assert.True(t, ok, `should be jwk.ECDSAPrivateKey`) {
return
}
Expand Down Expand Up @@ -172,26 +177,31 @@ func TestECDSA(t *testing.T) {
}`
expectedPublicKey := `{"crv":"P-256","kty":"EC","x":"MKBCTNIcKUSDii11ySs3526iDZ8AiTo7Tu6KPAqv7D4","y":"4Etl6SRW2YiLUrN5vfvVHuhp7x8PxltmWWlbbM4IFyM"}`

var set jwk.Set
set := jwk.NewSet()
if !assert.NoError(t, json.Unmarshal([]byte(s), &set), "unmarshal(set) should be successful") {
return
}

if _, ok := set.Keys[0].(jwk.ECDSAPrivateKey); !assert.True(t, ok, "first key should be ECDSAPrivateKey") {
akey, ok := set.Get(0)
if !assert.True(t, ok, `set.Get(0) should succeed`) {
return
}

privKey, ok := akey.(jwk.ECDSAPrivateKey)
if !assert.True(t, ok, `should be jwk.ECDSAPrivateKey`) {
return
}
key := set.Keys[0].(jwk.ECDSAPrivateKey)

var rawKey ecdsa.PrivateKey
if !assert.NoError(t, key.Raw(&rawKey), `materialize should succeed`) {
if !assert.NoError(t, privKey.Raw(&rawKey), `materialize should succeed`) {
return
}

if !assert.Equal(t, jwa.P256, key.Crv(), `curve name should match`) {
if !assert.Equal(t, jwa.P256, privKey.Crv(), `curve name should match`) {
return
}

pubKey, err := key.PublicKey()
pubKey, err := privKey.PublicKey()
if !assert.NoError(t, err, `should PublicKey succeed`) {
return
}
Expand Down Expand Up @@ -244,11 +254,19 @@ func TestECDSA(t *testing.T) {
if err != nil {
t.Fatal("Failed to parse JWK ECDSA")
}
ECDSAPrivateKey := set.Keys[0].(jwk.ECDSAPrivateKey)
akey, ok := set.Get(0)
if !assert.True(t, ok, `set.Get(0) should succeed`) {
return
}

privKeyBytes, err := json.Marshal(ECDSAPrivateKey)
if err != nil {
t.Fatal("Failed to marshal ECDSAPrivateKey")
privKey, ok := akey.(jwk.ECDSAPrivateKey)
if !assert.True(t, ok, `should be jwk.ECDSAPrivateKey`) {
return
}

privKeyBytes, err := json.Marshal(privKey)
if !assert.NoError(t, err, `json.Marshal should succeed`) {
return
}
// verify marshal

Expand All @@ -265,7 +283,7 @@ func TestECDSA(t *testing.T) {
return
}

if !assert.Equal(t, expECDSAPrivateKey, ECDSAPrivateKey, "ECDSAPrivate keys should match") {
if !assert.Equal(t, expECDSAPrivateKey, privKey, "ECDSAPrivate keys should match") {
return
}
})
Expand Down
21 changes: 17 additions & 4 deletions jwk/interface.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
package jwk

import (
"context"
"crypto/x509"
"sync"

"github.com/lestrrat-go/iter/arrayiter"
"github.com/lestrrat-go/iter/mapiter"
Expand Down Expand Up @@ -38,10 +40,21 @@ const (
KeyOpDeriveBits KeyOperation = "deriveBits" // (derive bits not to be used as a key)
)

// Set is a convenience struct to allow generating and parsing
// JWK sets as opposed to single JWKs
type Set struct {
Keys []Key
// Set represents JWKS object, a collection of jwk.Key objects
type Set interface {
Add(Key) bool
Clear()
Get(int) (Key, bool)
Index(Key) int
Len() int
LookupKeyID(string) (Key, bool)
Remove(Key) bool
Iterate(context.Context) KeyIterator
}

type set struct {
keys []Key
mu sync.RWMutex
}

type HeaderVisitor = iter.MapVisitor
Expand Down
83 changes: 9 additions & 74 deletions jwk/jwk.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ import (
"os"
"strings"

"github.com/lestrrat-go/iter/arrayiter"
"github.com/lestrrat-go/jwx/internal/base64"
"github.com/lestrrat-go/jwx/internal/json"
"github.com/lestrrat-go/jwx/jwa"
Expand Down Expand Up @@ -170,7 +169,7 @@ func PublicKeyOf(v interface{}) (interface{}, error) {
}

// Fetch fetches a JWK resource specified by a URL
func Fetch(urlstring string, options ...Option) (*Set, error) {
func Fetch(urlstring string, options ...Option) (Set, error) {
u, err := url.Parse(urlstring)
if err != nil {
return nil, errors.Wrap(err, `failed to parse url`)
Expand All @@ -192,12 +191,12 @@ func Fetch(urlstring string, options ...Option) (*Set, error) {
}

// FetchHTTP wraps FetchHTTPWithContext using the background context.
func FetchHTTP(jwkurl string, options ...Option) (*Set, error) {
func FetchHTTP(jwkurl string, options ...Option) (Set, error) {
return FetchHTTPWithContext(context.Background(), jwkurl, options...)
}

// FetchHTTPWithContext fetches the remote JWK and parses its contents
func FetchHTTPWithContext(ctx context.Context, jwkurl string, options ...Option) (*Set, error) {
func FetchHTTPWithContext(ctx context.Context, jwkurl string, options ...Option) (Set, error) {
httpcl := http.DefaultClient
for _, option := range options {
switch option.Ident() {
Expand Down Expand Up @@ -287,33 +286,6 @@ func ParseKey(data []byte) (Key, error) {
return key, nil
}

func (s *Set) UnmarshalJSON(data []byte) error {
var proxy struct {
Keys []json.RawMessage `json:"keys"`
}

if err := json.Unmarshal(data, &proxy); err != nil {
return errors.Wrap(err, `failed to unmarshal into Key (proxy)`)
}

if len(proxy.Keys) == 0 {
k, err := ParseKey(data)
if err != nil {
return errors.Wrap(err, `failed to unmarshal key from JSON headers`)
}
s.Keys = append(s.Keys, k)
} else {
for i, buf := range proxy.Keys {
k, err := ParseKey([]byte(buf))
if err != nil {
return errors.Wrapf(err, `failed to unmarshal key #%d (total %d) from multi-key JWK set`, i+1, len(proxy.Keys))
}
s.Keys = append(s.Keys, k)
}
}
return nil
}

// Parse parses JWK from the incoming io.Reader. This function can handle
// both single-key and multi-key formats. If you know before hand which
// format the incoming data is in, you might want to consider using
Expand All @@ -323,12 +295,12 @@ func (s *Set) UnmarshalJSON(data []byte) error {
//
// Parse will be removed in v1.1.0.
// v1.1.0 will introduce `Parse([]byte)` and `ParseReader(`io.Reader`)
func Parse(in io.Reader) (*Set, error) {
var s Set
if err := json.NewDecoder(in).Decode(&s); err != nil {
func Parse(in io.Reader) (Set, error) {
s := NewSet()
if err := json.NewDecoder(in).Decode(s); err != nil {
return nil, errors.Wrap(err, "failed to unmarshal JWK")
}
return &s, nil
return s, nil
}

// ParseBytes parses JWK from the incoming byte buffer.
Expand All @@ -337,7 +309,7 @@ func Parse(in io.Reader) (*Set, error) {
//
// ParseBytes will be removed in v1.1.0.
// v1.1.0 will introduce `Parse([]byte)` and `ParseReader(`io.Reader`)
func ParseBytes(buf []byte) (*Set, error) {
func ParseBytes(buf []byte) (Set, error) {
return Parse(bytes.NewReader(buf))
}

Expand All @@ -347,47 +319,10 @@ func ParseBytes(buf []byte) (*Set, error) {
//
// ParseString will be removed in v1.1.0.
// v1.1.0 will introduce `Parse([]byte)` and `ParseReader(`io.Reader`)
func ParseString(s string) (*Set, error) {
func ParseString(s string) (Set, error) {
return Parse(strings.NewReader(s))
}

// LookupKeyID looks for keys matching the given key id. Note that the
// Set *may* contain multiple keys with the same key id
func (s Set) LookupKeyID(kid string) []Key {
var keys []Key
for iter := s.Iterate(context.TODO()); iter.Next(context.TODO()); {
pair := iter.Pair()
key := pair.Value.(Key)
if key.KeyID() == kid {
keys = append(keys, key)
}
}
return keys
}

func (s *Set) Len() int {
return len(s.Keys)
}

func (s *Set) Iterate(ctx context.Context) KeyIterator {
ch := make(chan *KeyPair, s.Len())
go iterate(ctx, s.Keys, ch)
return arrayiter.New(ch)
}

func iterate(ctx context.Context, keys []Key, ch chan *KeyPair) {
defer close(ch)

for i, key := range keys {
pair := &KeyPair{Index: i, Value: key}
select {
case <-ctx.Done():
return
case ch <- pair:
}
}
}

// AssignKeyID is a convenience function to automatically assign the "kid"
// section of the key, if it already doesn't have one. It uses Key.Thumbprint
// method with crypto.SHA256 as the default hashing algorithm
Expand Down
Loading

0 comments on commit 28f8eec

Please sign in to comment.