Skip to content

Commit

Permalink
Merge pull request CZ-NIC#379 from rohe/kmk
Browse files Browse the repository at this point in the history
Kmk
  • Loading branch information
rohe authored Jun 22, 2017
2 parents ccc51cc + 95fee45 commit d974e5a
Show file tree
Hide file tree
Showing 3 changed files with 176 additions and 15 deletions.
13 changes: 13 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,19 @@ The format is based on the [KeepAChangeLog] project.
- [#374]: Made the to_jwe/from_jwe methods of Message accept list of keys value of parameter keys.
- [#324]: Make the Provider `symkey` argument optional.

### Fixed
- [#369]: The AuthnEvent object is now serialized to JSON for the session.
- [#373]: Made the standard way the default when dealing with signed JWTs without 'kid'. Added the possibility to override this behavior if necessary.

### Security
- [#363]: Fixed IV reuse for CookieDealer class. Replaced the encrypt-then-mac construction with a proper AEAD (AES-SIV).

[#324]: https://github.com/OpenIDC/pyoidc/pull/324
[#369]: https://github.com/OpenIDC/pyoidc/pull/369
[#363]: https://github.com/OpenIDC/pyoidc/issue/363

## 0.10.0.1 [UNRELEASED]

### Fixed
- [#362]: Fix bad package settings URL
- [#358]: Fixed claims_match
Expand Down
74 changes: 59 additions & 15 deletions src/oic/oauth2/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -473,28 +473,63 @@ def to_jwt(self, key=None, algorithm="", lev=0):
_jws = JWS(self.to_json(lev), alg=algorithm)
return _jws.sign_compact(key)

def _add_key(self, keyjar, issuer, key, key_type=''):
def _add_key(self, keyjar, issuer, key, key_type='', kid='',
no_kid_issuer=None):

try:
logger.debug('Key set summary for {}: {}'.format(
issuer, key_summary(keyjar, issuer)))
except KeyError:
logger.error('Issuer "{}" not in keyjar'.format(issuer))

try:
kl = keyjar.get_verify_key(owner=issuer, key_type=key_type)
except KeyError:
pass
if kid:
_key = keyjar.get_key_by_kid(kid, issuer)
if _key and _key not in key:
key.append(_key)
return
else:
for k in kl:
if k not in key:
key.append(k)
try:
kl = keyjar.get_verify_key(owner=issuer, key_type=key_type)
except KeyError:
pass
else:
if len(kl) == 1:
if kl[0] not in key:
key.append(kl[0])
elif no_kid_issuer:
try:
allowed_kids = no_kid_issuer[issuer]
except KeyError:
return
else:
if allowed_kids:
key.extend([k for k in kl if k.kid in allowed_kids])
else:
key.extend(kl)

def get_verify_keys(self, keyjar, key, jso, header, jwt, **kwargs):
"""
Get keys from a keyjar that can be used to verify a signed JWT
:param keyjar: A KeyJar instance
:param key: List of keys to start with
:param jso: The payload of the JWT, expected to be a dictionary.
:param header: The header of the JWT
:param jwt: A jwkest.jwt.JWT instance
:param kwargs: Other key word arguments
:return: list of usable keys
"""
try:
_kid = header['kid']
except KeyError:
_kid = ''

try:
_iss = jso["iss"]
except KeyError:
pass
else:
# First extend the keyjar if allowed
if "jku" in header:
if not keyjar.find(header["jku"], _iss):
# This is really questionable
Expand All @@ -505,25 +540,34 @@ def get_verify_keys(self, keyjar, key, jso, header, jwt, **kwargs):
except KeyError:
pass

if "kid" in header and header["kid"]:
jwt["kid"] = header["kid"]
# If there is a kid and a key is found with that kid at the issuer
# then I'm done
if _kid:
jwt["kid"] = _kid
try:
_key = keyjar.get_key_by_kid(header["kid"], _iss)
_key = keyjar.get_key_by_kid(_kid, _iss)
if _key:
key.append(_key)
return key
except KeyError:
pass

try:
self._add_key(keyjar, kwargs["opponent_id"], key)
nki = kwargs['no_kid_issuer']
except KeyError:
pass
nki = {}

try:
_key_type = alg2keytype(header['alg'])
except KeyError:
_key_type = ''

try:
self._add_key(keyjar, kwargs["opponent_id"], key, _key_type, _kid,
nki)
except KeyError:
pass

for ent in ["iss", "aud", "client_id"]:
if ent not in jso:
continue
Expand All @@ -534,9 +578,9 @@ def get_verify_keys(self, keyjar, key, jso, header, jwt, **kwargs):
else:
_aud = jso["aud"]
for _e in _aud:
self._add_key(keyjar, _e, key, _key_type)
self._add_key(keyjar, _e, key, _key_type, _kid, nki)
else:
self._add_key(keyjar, jso[ent], key, _key_type)
self._add_key(keyjar, jso[ent], key, _key_type, _kid, nki)
return key

def from_jwt(self, txt, key=None, verify=True, keyjar=None, **kwargs):
Expand Down
104 changes: 104 additions & 0 deletions tests/test_oauth2_message.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,23 @@
{"type": "EC", "crv": "P-256", "use": ["enc"]},
]

keym = [
{"type": "RSA", "use": ["sig"]},
{"type": "RSA", "use": ["sig"]},
{"type": "RSA", "use": ["sig"]},
]

KEYJAR = build_keyjar(keys)[1]
IKEYJAR = build_keyjar(keys)[1]
IKEYJAR.issuer_keys['issuer'] = IKEYJAR.issuer_keys['']
del IKEYJAR.issuer_keys['']

KEYJARS = {}
for iss in ['A','B','C']:
_kj = build_keyjar(keym)[1]
_kj.issuer_keys[iss] = _kj.issuer_keys['']
del _kj.issuer_keys['']
KEYJARS[iss] = _kj


def url_compare(url1, url2):
Expand Down Expand Up @@ -606,6 +622,94 @@ def test_missing_required(self):
with pytest.raises(MissingRequiredAttribute):
err.to_urlencoded()

@pytest.mark.parametrize("keytype,alg",[
('RSA', 'RS256'),
('EC', 'ES256')
])
def test_to_jwt(keytype, alg):
msg = Message(a='foo', b='bar', c='tjoho')
_jwt = msg.to_jwt(KEYJAR.get_signing_key(keytype, ''), alg)
msg1 = Message().from_jwt(_jwt, KEYJAR.get_signing_key(keytype, ''))
assert msg1 == msg


@pytest.mark.parametrize("keytype,alg,enc",[
('RSA', 'RSA1_5', 'A128CBC-HS256'),
('EC', 'ECDH-ES', 'A128GCM'),
])
def test_to_jwe(keytype, alg, enc):
msg = Message(a='foo', b='bar', c='tjoho')
_jwe = msg.to_jwe(KEYJAR.get_encrypt_key(keytype, ''), alg=alg, enc=enc)
msg1 = Message().from_jwe(_jwe, KEYJAR.get_encrypt_key(keytype, ''))
assert msg1 == msg


def test_get_verify_keys_no_kid_multiple_keys():
msg = Message()
header = {'alg': 'RS256'}
keys = []
msg.get_verify_keys(KEYJARS['A'], keys, {'iss': 'A'}, header, {})
assert keys == []


def test_get_verify_keys_no_kid_single_key():
msg = Message()
header = {'alg': 'RS256'}
keys = []
msg.get_verify_keys(IKEYJAR, keys, {'iss': 'issuer'}, header, {})
assert len(keys) == 1


def test_get_verify_keys_no_kid_multiple_keys_no_kid_issuer():
msg = Message()
header = {'alg': 'RS256'}
keys = []

a_kids = [k.kid for k in
KEYJARS['A'].get_verify_key(owner='A', key_type='RSA')]
no_kid_issuer = {'A': a_kids}

msg.get_verify_keys(KEYJARS['A'], keys, {'iss': 'A'}, header, {},
no_kid_issuer=no_kid_issuer)
assert len(keys) == 3
assert set([k.kid for k in keys]) == set(a_kids)


def test_get_verify_keys_no_kid_multiple_keys_no_kid_issuer_lim():
msg = Message()
header = {'alg': 'RS256'}
keys = []

a_kids = [k.kid for k in
KEYJARS['A'].get_verify_key(owner='A', key_type='RSA')]
# get rid of one kid
a_kids = a_kids[:-1]
no_kid_issuer = {'A': a_kids}

msg.get_verify_keys(KEYJARS['A'], keys, {'iss': 'A'}, header, {},
no_kid_issuer=no_kid_issuer)
assert len(keys) == 2
assert set([k.kid for k in keys]) == set(a_kids)


def test_get_verify_keys_matching_kid():
msg = Message()
a_kids = [k.kid for k in
KEYJARS['A'].get_verify_key(owner='A', key_type='RSA')]
header = {'alg': 'RS256', 'kid': a_kids[0]}
keys = []
msg.get_verify_keys(KEYJARS['A'], keys, {'iss': 'A'}, header, {})
assert len(keys) == 1
assert keys[0].kid == a_kids[0]


def test_get_verify_keys_no_matching_kid():
msg = Message()
header = {'alg': 'RS256', 'kid': 'aaaaaaa'}
keys = []
msg.get_verify_keys(KEYJARS['A'], keys, {'iss': 'A'}, header, {})
assert keys == []
=======

def test_to_jwt_rsa():
msg = Message(a='foo', b='bar', c='tjoho')
Expand Down

0 comments on commit d974e5a

Please sign in to comment.