Skip to content

Commit

Permalink
Merge pull request lestrrat-go#190 from lestrrat-go/topic/coverage
Browse files Browse the repository at this point in the history
Play the coverage game, take 3
  • Loading branch information
lestrrat authored May 7, 2020
2 parents 45a269f + 35a33c4 commit 4aef210
Show file tree
Hide file tree
Showing 25 changed files with 1,046 additions and 322 deletions.
5 changes: 5 additions & 0 deletions Changes
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
Changes
=======

v1.0.2
* Add jwk.AssignKeyID to automatically assign a `kid` field to a JWK
* Fix jwe.Encrypt / jwe.Decrypt to properly look at the `zip` field
* Change jwe.Message accessors to return []byte, not buffer.Buffer

v1.0.1 - 04 May 2020
* Normalize all JWK serialization to use padding-less base64 encoding (#185)
* Fix edge case unmarshaling openid.AddressClaim within a openid.Token
Expand Down
2 changes: 1 addition & 1 deletion internal/base64/base64.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import (
)

func EncodeToStringStd(src []byte) string {
return base64.StdEncoding.EncodeToString(src)
return base64.RawStdEncoding.EncodeToString(src)
}

func EncodeToString(src []byte) string {
Expand Down
35 changes: 35 additions & 0 deletions jwe/compress.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
package jwe

import (
"bytes"
"compress/flate"
"io/ioutil"

"github.com/lestrrat-go/jwx/jwa"
"github.com/pkg/errors"
)

func uncompress(plaintext []byte) ([]byte, error) {
return ioutil.ReadAll(flate.NewReader(bytes.NewReader(plaintext)))
}

func compress(plaintext []byte, alg jwa.CompressionAlgorithm) ([]byte, error) {
if alg == jwa.NoCompress {
return plaintext, nil
}

var output bytes.Buffer
w, _ := flate.NewWriter(&output, 1)
in := plaintext
for len(in) > 0 {
n, err := w.Write(in)
if err != nil {
return nil, errors.Wrap(err, `failed to write to compression writer`)
}
in = in[n:]
}
if err := w.Close(); err != nil {
return nil, errors.Wrap(err, "failed to close compression writer")
}
return output.Bytes(), nil
}
44 changes: 37 additions & 7 deletions jwe/encrypt.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ import (
"context"
"sync"

"github.com/lestrrat-go/jwx/buffer"
"github.com/lestrrat-go/jwx/jwa"
"github.com/lestrrat-go/pdebug"
"github.com/pkg/errors"
)
Expand All @@ -22,6 +24,7 @@ func releaseEncryptCtx(ctx *encryptCtx) {
ctx.contentEncrypter = nil
ctx.generator = nil
ctx.keyEncrypters = nil
ctx.compress = jwa.NoCompress
encryptCtxPool.Put(ctx)
}

Expand All @@ -42,7 +45,14 @@ func (e encryptCtx) Encrypt(plaintext []byte) (*Message, error) {

protected := NewHeaders()
if err := protected.Set(ContentEncryptionKey, e.contentEncrypter.Algorithm()); err != nil {
return nil, errors.Wrap(err, "failed to set enc in protected header")
return nil, errors.Wrap(err, `failed to set "enc" in protected header`)
}

compression := e.compress
if compression != jwa.NoCompress {
if err := protected.Set(CompressionKey, compression); err != nil {
return nil, errors.Wrap(err, `failed to set "zip" in protected header`)
}
}

// In JWE, multiple recipients may exist -- they receive an
Expand Down Expand Up @@ -95,6 +105,11 @@ func (e encryptCtx) Encrypt(plaintext []byte) (*Message, error) {
return nil, errors.Wrap(err, "failed to base64 encode protected headers")
}

plaintext, err = compress(plaintext, compression)
if err != nil {
return nil, errors.Wrap(err, `failed to compress payload before encryption`)
}

// ...on the other hand, there's only one content cipher.
iv, ciphertext, tag, err := e.contentEncrypter.Encrypt(cek, plaintext, aad)
if err != nil {
Expand All @@ -113,14 +128,29 @@ func (e encryptCtx) Encrypt(plaintext []byte) (*Message, error) {
}

msg := NewMessage()
if err := msg.authenticatedData.Base64Decode(aad); err != nil {

decodedAad, err := buffer.FromBase64(aad)
if err != nil {
return nil, errors.Wrap(err, "failed to decode base64")
}
msg.cipherText = ciphertext
msg.initializationVector = iv
msg.protectedHeaders = protected
msg.recipients = recipients
msg.tag = tag
if err := msg.Set(AuthenticatedDataKey, decodedAad.Bytes()); err != nil {
return nil, errors.Wrapf(err, `failed to set %s`, AuthenticatedDataKey)
}
if err := msg.Set(CipherTextKey, ciphertext); err != nil {
return nil, errors.Wrapf(err, `failed to set %s`, CipherTextKey)
}
if err := msg.Set(InitializationVectorKey, iv); err != nil {
return nil, errors.Wrapf(err, `failed to set %s`, InitializationVectorKey)
}
if err := msg.Set(ProtectedHeadersKey, protected); err != nil {
return nil, errors.Wrapf(err, `failed to set %s`, ProtectedHeadersKey)
}
if err := msg.Set(RecipientsKey, recipients); err != nil {
return nil, errors.Wrapf(err, `failed to set %s`, RecipientsKey)
}
if err := msg.Set(TagKey, tag); err != nil {
return nil, errors.Wrapf(err, `failed to set %s`, TagKey)
}

return msg, nil
}
84 changes: 62 additions & 22 deletions jwe/headers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,43 +2,75 @@ package jwe_test

import (
"context"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"testing"

"github.com/lestrrat-go/jwx/buffer"
"github.com/lestrrat-go/jwx/jwa"
"github.com/lestrrat-go/jwx/jwe"
"github.com/lestrrat-go/jwx/jwk"
"github.com/stretchr/testify/assert"
)

func TestHeaders(t *testing.T) {
t.Run("Set/Get", func(t *testing.T) {
h := jwe.NewHeaders()
rawKey, err := ecdsa.GenerateKey(elliptic.P521(), rand.Reader)
if !assert.NoError(t, err, `ecdsa.GenerateKey should succeed`) {
return
}
privKey, err := jwk.New(rawKey)
if !assert.NoError(t, err, `jwk.New should succeed`) {
return
}

pubKey, err := jwk.New(rawKey.PublicKey)
if !assert.NoError(t, err, `jwk.PublicKey should succeed`) {
return
}

data := []struct {
Key string
Value interface{}
Expected interface{}
}{
{Key: jwe.AgreementPartyUInfoKey, Value: []byte("apu foobarbaz"), Expected: buffer.Buffer("apu foobarbaz")},
{Key: jwe.AgreementPartyVInfoKey, Value: []byte("apv foobarbaz"), Expected: buffer.Buffer("apv foobarbaz")},
{Key: jwe.CompressionKey, Value: jwa.Deflate},
{Key: jwe.ContentEncryptionKey, Value: jwa.A128GCM},
{Key: jwe.ContentTypeKey, Value: "application/json"},
{Key: jwe.CriticalKey, Value: []string{"crit blah"}},
{Key: jwe.EphemeralPublicKeyKey, Value: pubKey},
{Key: jwe.JWKKey, Value: privKey},
{Key: jwe.JWKSetURLKey, Value: "http://github.com/lestrrat-go/jwx"},
{Key: jwe.KeyIDKey, Value: "kid blah"},
{Key: jwe.TypeKey, Value: "typ blah"},
{Key: jwe.X509CertThumbprintKey, Value: "x5t blah"},
{Key: jwe.X509CertThumbprintS256Key, Value: "x5t#256 blah"},
{Key: jwe.X509URLKey, Value: "http://github.com/lestrrat-go/jwx"},
}

data := map[string]struct {
Value interface{}
Expected interface{}
}{
"kid": {Value: "kid blah"},
"enc": {Value: jwa.A128GCM},
"cty": {Value: "application/json"},
"typ": {Value: "typ blah"},
"x5t": {Value: "x5t blah"},
"x5t#256": {Value: "x5t#256 blah"},
"crit": {Value: []string{"crit blah"}},
"jku": {Value: "http://github.com/lestrrat-go/jwx"},
"x5u": {Value: "http://github.com/lestrrat-go/jwx"},
base := jwe.NewHeaders()

t.Run("Set values", func(t *testing.T) {
for _, tc := range data {
if !assert.NoError(t, base.Set(tc.Key, tc.Value), "Headers.Set should succeed") {
return
}
}
})

for name, testcase := range data {
h.Set(name, testcase.Value)
got, ok := h.Get(name)
if !assert.True(t, ok, "value should exist") {
t.Run("Set/Get", func(t *testing.T) {
h := base
for _, tc := range data {
got, ok := h.Get(tc.Key)
if !assert.True(t, ok, "value for %s should exist", tc.Key) {
return
}

expected := testcase.Expected
expected := tc.Expected
if expected == nil {
expected = testcase.Value
expected = tc.Value
}
if !assert.Equal(t, expected, got, "value should match") {
return
Expand Down Expand Up @@ -67,7 +99,15 @@ func TestHeaders(t *testing.T) {

t.Run("Iterator", func(t *testing.T) {
expected := map[string]interface{}{}
v := jwe.NewHeaders()
for _, tc := range data {
v := tc.Value
if expected := tc.Expected; expected != nil {
v = expected
}
expected[tc.Key] = v
}

v := base
t.Run("Iterate", func(t *testing.T) {
seen := make(map[string]interface{})
for iter := v.Iterate(context.TODO()); iter.Next(context.TODO()); {
Expand Down
9 changes: 5 additions & 4 deletions jwe/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,12 @@ type stdRecipient struct {

// Message contains the entire encrypted JWE message
type Message struct {
authenticatedData buffer.Buffer
cipherText buffer.Buffer
initializationVector buffer.Buffer
authenticatedData *buffer.Buffer
cipherText *buffer.Buffer
initializationVector *buffer.Buffer
protectedHeaders Headers
recipients []Recipient
tag buffer.Buffer
tag *buffer.Buffer
unprotectedHeaders Headers
}

Expand All @@ -49,6 +49,7 @@ type encryptCtx struct {
contentEncrypter contentEncrypter
generator keygen.Generator
keyEncrypters []keyenc.Encrypter
compress jwa.CompressionAlgorithm
}

// populater is an interface for things that may modify the
Expand Down
7 changes: 0 additions & 7 deletions jwe/internal/cipher/cipher.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package cipher
import (
"crypto/aes"
"crypto/cipher"
"crypto/rsa"
"fmt"

"github.com/lestrrat-go/jwx/jwa"
Expand Down Expand Up @@ -185,9 +184,3 @@ func (c AesContentCipher) Decrypt(cek, iv, ciphertxt, tag, aad []byte) (plaintex
plaintext, err = aead.Open(nil, iv, combined, aad)
return
}

func NewRsaContentCipher(alg jwa.ContentEncryptionAlgorithm, pubkey *rsa.PublicKey) (*RsaContentCipher, error) {
return &RsaContentCipher{
pubkey: pubkey,
}, nil
}
6 changes: 0 additions & 6 deletions jwe/internal/cipher/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package cipher

import (
"crypto/cipher"
"crypto/rsa"

"github.com/lestrrat-go/jwx/jwe/internal/keygen"
)
Expand Down Expand Up @@ -33,8 +32,3 @@ type AesContentCipher struct {
keysize int
tagsize int
}

// RsaContentCipher represents a cipher based on RSA
type RsaContentCipher struct {
pubkey *rsa.PublicKey
}
26 changes: 20 additions & 6 deletions jwe/jwe.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ func Encrypt(payload []byte, keyalg jwa.KeyEncryptionAlgorithm, key interface{},
encctx.contentEncrypter = contentcrypt
encctx.generator = keygen.NewRandom(keysize)
encctx.keyEncrypters = []keyenc.Encrypter{enc}
encctx.compress = compressalg
msg, err := encctx.Encrypt(payload)
if err != nil {
if pdebug.Enabled {
Expand Down Expand Up @@ -212,16 +213,29 @@ func parseCompact(buf []byte) (*Message, error) {
}

m := NewMessage()
m.authenticatedData.SetBytes(hdrbuf.Bytes())
m.protectedHeaders = protected
m.tag = tagbuf
m.cipherText = ctbuf
m.initializationVector = ivbuf
m.recipients = []Recipient{
if err := m.Set(AuthenticatedDataKey, hdrbuf.Bytes()); err != nil {
return nil, errors.Wrapf(err, `failed to set %s`, AuthenticatedDataKey)
}
if err := m.Set(CipherTextKey, ctbuf); err != nil {
return nil, errors.Wrapf(err, `failed to set %s`, CipherTextKey)
}
if err := m.Set(InitializationVectorKey, ivbuf); err != nil {
return nil, errors.Wrapf(err, `failed to set %s`, InitializationVectorKey)
}
if err := m.Set(ProtectedHeadersKey, protected); err != nil {
return nil, errors.Wrapf(err, `failed to set %s`, ProtectedHeadersKey)
}

if err := m.Set(RecipientsKey, []Recipient{
&stdRecipient{
headers: hdr,
encryptedKey: enckeybuf,
},
}); err != nil {
return nil, errors.Wrapf(err, `failed to set %s`, RecipientsKey)
}
if err := m.Set(TagKey, tagbuf); err != nil {
return nil, errors.Wrapf(err, `failed to set %s`, TagKey)
}
return m, nil
}
Expand Down
Loading

0 comments on commit 4aef210

Please sign in to comment.