Skip to content

Commit

Permalink
Implement jwe.Encrypt/jwe.Decrypt
Browse files Browse the repository at this point in the history
  • Loading branch information
lestrrat committed Nov 15, 2015
1 parent 983ac9c commit 42b4c03
Show file tree
Hide file tree
Showing 10 changed files with 209 additions and 134 deletions.
9 changes: 7 additions & 2 deletions internal/debug/debug_on.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,13 @@

package debug

import "log"
import (
"log"
"os"
)

var logger = log.New(os.Stdout, "|DEBUG| ", 0)

func Printf(f string, args ...interface{}) {
log.Printf(f, args...)
logger.Printf(f, args...)
}
7 changes: 4 additions & 3 deletions jwa/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ type KeyType string
const (
EC KeyType = "EC" // Elliptic Curve
RSA KeyType = "RSA" // RSA
OctetSeq KeyType = "oct" // Octet sequence (used to represent symmetric keys)
OctetSeq KeyType = "oct" // Octet sequence (used to represent symmetric keys)
)

type EllipticCurveAlgorithm string
Expand Down Expand Up @@ -63,6 +63,7 @@ const (
// ContentEncryptionAlgorithm represents the various encryption
// algorithms as described in https://tools.ietf.org/html/rfc7518#section-5
type ContentEncryptionAlgorithm string

const (
A128CBC_HS256 ContentEncryptionAlgorithm = "A128CBC-HS256" // AES-CBC + HMAC-SHA256 (128)
A192CBC_HS384 ContentEncryptionAlgorithm = "A192CBC-HS384" // AES-CBC + HMAC-SHA384 (192)
Expand All @@ -77,6 +78,6 @@ const (
type CompressionAlgorithm string

const (
NoCompression CompressionAlgorithm = "" // No compression
Deflate CompressionAlgorithm = "DEF" // DEFLATE (RFC 1951)
NoCompress CompressionAlgorithm = "" // No compression
Deflate CompressionAlgorithm = "DEF" // DEFLATE (RFC 1951)
)
6 changes: 3 additions & 3 deletions jwe/aescbc/aescbc.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,9 @@ func New(key []byte, f BlockCipherFunc) (*AesCbcHmac, error) {
ikey := key[:keysize]
ekey := key[keysize:]

debug.Printf("New: cek (key) = %x\n", key)
debug.Printf("New: ikey = %x\n", ikey)
debug.Printf("New: ekey = %x\n", ekey)
debug.Printf("New: cek (key) = %x (%d)\n", key, len(key))
debug.Printf("New: ikey = %x (%d)\n", ikey, len(ikey))
debug.Printf("New: ekey = %x (%d)\n", ekey, len(ekey))

bc, err := f(ekey)
if err != nil {
Expand Down
1 change: 1 addition & 0 deletions jwe/cipher.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ var GcmAeadFetch = AeadFetchFunc(func(key []byte) (cipher.AEAD, error) {
aescipher, err := aes.NewCipher(key)
if err != nil {
debug.Printf("GcmAeadFetch: failed to create cipher")
panic(err)
return nil, err
}

Expand Down
8 changes: 6 additions & 2 deletions jwe/doc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,11 @@ func ExampleEncrypt() {
return
}

k := NewRSAKeyEncrypt(jwa.RSA1_5, &privkey.PublicKey)
k, err := NewRSAPKCSKeyEncrypt(jwa.RSA1_5, &privkey.PublicKey)
if err != nil {
log.Printf("failed to create key encrypter: %s", err)
return
}
kg := NewRandomKeyGenerate(c.KeySize())

e := NewMultiEncrypt(c, kg, k)
Expand All @@ -31,7 +35,7 @@ func ExampleEncrypt() {
return
}

decrypted, err := DecryptMessage(msg, privkey)
decrypted, err := DecryptMessage(msg, jwa.RSA1_5, privkey)
if err != nil {
log.Printf("failed to decrypt: %s", err)
return
Expand Down
6 changes: 6 additions & 0 deletions jwe/encrypt.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ func NewMultiEncrypt(cc ContentEncrypter, kg KeyGenerator, ke ...KeyEncrypter) *
func (e MultiEncrypt) Encrypt(plaintext []byte) (*Message, error) {
cek, err := e.KeyGenerator.KeyGenerate()
if err != nil {
debug.Printf("Failed to generate key: %s", err)
return nil, err
}
debug.Printf("Encrypt: generated cek len = %d", len(cek))
Expand All @@ -37,6 +38,7 @@ func (e MultiEncrypt) Encrypt(plaintext []byte) (*Message, error) {
}
enckey, err := enc.KeyEncrypt(cek)
if err != nil {
debug.Printf("Failed to encrypt key: %s", err)
return nil, err
}
r.EncryptedKey = enckey
Expand All @@ -57,6 +59,10 @@ func (e MultiEncrypt) Encrypt(plaintext []byte) (*Message, error) {

// ...on the other hand, there's only one content cipher.
iv, ciphertext, tag, err := e.ContentEncrypter.Encrypt(cek, plaintext, aad)
if err != nil {
debug.Printf("Failed to encrypt: %s", err)
return nil, err
}

debug.Printf("Encrypt.Encrypt: cek = %x (%d)", cek, len(cek))
debug.Printf("Encrypt.Encrypt: aad = %x", aad)
Expand Down
9 changes: 8 additions & 1 deletion jwe/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -188,8 +188,15 @@ type RSAPKCS15KeyDecrypt struct {
generator KeyGenerator
}

type RSAKeyEncrypt struct {
type RSAOAEPKeyEncrypt struct {
alg jwa.KeyEncryptionAlgorithm
pubkey *rsa.PublicKey
KeyID string
}

type RSAPKCSKeyEncrypt struct {
alg jwa.KeyEncryptionAlgorithm
pubkey *rsa.PublicKey
KeyID string
}

91 changes: 74 additions & 17 deletions jwe/jwe.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,45 +16,102 @@ import (
)

func Encrypt(payload []byte, keyalg jwa.KeyEncryptionAlgorithm, key interface{}, contentalg jwa.ContentEncryptionAlgorithm, compressalg jwa.CompressionAlgorithm) ([]byte, error) {
contentcrypt, err := NewAesCrypt(contentalg)
if err != nil {
return nil, err
}

var keyenc KeyEncrypter
var keysize int
switch keyalg {
case jwa.RSA1_5, jwa.RSA_OAEP, jwa.RSA_OAEP_256:
case jwa.RSA1_5:
pubkey, ok := key.(*rsa.PublicKey)
if !ok {
return nil, errors.New("invalid key: *rsa.PublicKey required")
}
keyenc = NewRSAKeyEncrypt(keyalg, pubkey)
keyenc, err = NewRSAPKCSKeyEncrypt(keyalg, pubkey)
if err != nil {
return nil, err
}
keysize = contentcrypt.KeySize()/2
case jwa.RSA_OAEP, jwa.RSA_OAEP_256:
pubkey, ok := key.(*rsa.PublicKey)
if !ok {
return nil, errors.New("invalid key: *rsa.PublicKey required")
}
keyenc, err = NewRSAOAEPKeyEncrypt(keyalg, pubkey)
if err != nil {
return nil, err
}
keysize = contentcrypt.KeySize()/2
case jwa.A128KW, jwa.A192KW, jwa.A256KW:
sharedkey, ok := key.([]byte)
if !ok {
return nil, errors.New("invalid key: []byte required")
}
kwenc, err := NewAesKeyWrap(keyalg, sharedkey)
keyenc, err = NewAesKeyWrap(keyalg, sharedkey)
if err != nil {
return nil, err
}
keyenc = kwenc
keysize = contentcrypt.KeySize()
case jwa.ECDH_ES, jwa.ECDH_ES_A128KW, jwa.ECDH_ES_A192KW, jwa.ECDH_ES_A256KW:
fallthrough
case jwa.A128GCMKW, jwa.A192GCMKW, jwa.A256GCMKW:
fallthrough
case jwa.PBES2_HS256_A128KW, jwa.PBES2_HS384_A192KW, jwa.PBES2_HS512_A256KW:
fallthrough
default:
debug.Printf("Encrypt: unknown key encryption algorithm: %s", keyalg)
return nil, ErrUnsupportedAlgorithm
}

contentcrypt, err := NewAesCrypt(contentalg)
enc := NewMultiEncrypt(contentcrypt, NewRandomKeyGenerate(keysize), keyenc)
msg, err := enc.Encrypt(payload)
if err != nil {
debug.Printf("Encrypt: failed to encrypt: %s", err)
return nil, err
}
enc := NewMultiEncrypt(contentcrypt, NewRandomKeyGenerate(contentcrypt.KeySize()), keyenc)
msg, err := enc.Encrypt(payload)

return CompactSerialize{}.Serialize(msg)
}

func Decrypt(buf []byte, alg jwa.KeyEncryptionAlgorithm, key interface{}) ([]byte, error) {
/*
var keydec KeyEncrypter
switch keyalg {
case jwa.RSA1_5, jwa.RSA_OAEP, jwa.RSA_OAEP_256:
pubkey, ok := key.(*rsa.PrivateKey)
if !ok {
return nil, errors.New("invalid key: *rsa.PrivateKey required")
}
keydec = NewRSAKeyDecrypt(keyalg, pubkey)
case jwa.A128KW, jwa.A192KW, jwa.A256KW:
sharedkey, ok := key.([]byte)
if !ok {
return nil, errors.New("invalid key: []byte required")
}
kwenc, err := NewAesKeyWrap(keyalg, sharedkey)
if err != nil {
return nil, err
}
keydec = kwenc
case jwa.ECDH_ES, jwa.ECDH_ES_A128KW, jwa.ECDH_ES_A192KW, jwa.ECDH_ES_A256KW:
fallthrough
case jwa.A128GCMKW, jwa.A192GCMKW, jwa.A256GCMKW:
fallthrough
case jwa.PBES2_HS256_A128KW, jwa.PBES2_HS384_A192KW, jwa.PBES2_HS512_A256KW:
fallthrough
default:
return nil, ErrUnsupportedAlgorithm
}
*/

msg, err := Parse(buf)
if err != nil {
return nil, err
}

return CompactSerialize{}.Serialize(msg)
return DecryptMessage(msg, alg, key)
}

func Parse(buf []byte) (*Message, error) {
Expand Down Expand Up @@ -187,13 +244,13 @@ func BuildKeyDecrypter(alg jwa.KeyEncryptionAlgorithm, key interface{}, keysize
if !ok {
return nil, errors.New("*rsa.PrivateKey is required as the key to build this key decrypter")
}
return NewRSAPKCS15KeyDecrypt(alg, privkey, keysize), nil
return NewRSAPKCS15KeyDecrypt(alg, privkey, keysize/2), nil
case jwa.RSA_OAEP, jwa.RSA_OAEP_256:
privkey, ok := key.(*rsa.PrivateKey)
if !ok {
return nil, errors.New("*rsa.PrivateKey is required as the key to build this key decrypter")
}
return NewRSAOAEPKeyDecrypt(alg, privkey), nil
return NewRSAOAEPKeyDecrypt(alg, privkey)
case jwa.A128KW, jwa.A192KW, jwa.A256KW:
sharedkey, ok := key.([]byte)
if !ok {
Expand All @@ -214,7 +271,7 @@ func BuildContentCipher(alg jwa.ContentEncryptionAlgorithm) (ContentCipher, erro
return nil, ErrUnsupportedAlgorithm
}

func DecryptMessage(m *Message, key interface{}) ([]byte, error) {
func DecryptMessage(m *Message, alg jwa.KeyEncryptionAlgorithm, key interface{}) ([]byte, error) {
var err error

if len(m.Recipients) == 0 {
Expand Down Expand Up @@ -258,6 +315,10 @@ func DecryptMessage(m *Message, key interface{}) ([]byte, error) {

var plaintext []byte
for _, recipient := range m.Recipients {
if recipient.Header.Algorithm != alg {
continue
}

h2 := NewHeader()
if err := h2.Copy(h); err != nil {
debug.Printf("failed to copy header: %s", err)
Expand All @@ -280,19 +341,15 @@ func DecryptMessage(m *Message, key interface{}) ([]byte, error) {
cek, err := k.KeyDecrypt(recipient.EncryptedKey.Bytes())
if err != nil {
debug.Printf("failed to decrypt key: %s", err)
return nil, errors.New("failed to decrypt key")
continue
}

debug.Printf("DecryptMessage: cek = %x (%d)", cek, len(cek))
debug.Printf("DecryptMessage: iv = %x", iv)
debug.Printf("DecryptMessage: ciphertext = %x", ciphertext)
debug.Printf("DecryptMessage: tag = %x", tag)
debug.Printf("DecryptMessage: aad = %x", aad)
plaintext, err = cipher.decrypt(cek, iv, ciphertext, tag, aad)
if err == nil {
break
}
debug.Printf("DecryptMessage: cipher.decrypt: %s", err)
debug.Printf("DecryptMessage: failed to decrypt using %s: %s", h2.Algorithm, err)
}

if plaintext == nil {
Expand Down
Loading

0 comments on commit 42b4c03

Please sign in to comment.