Skip to content

Commit

Permalink
Merge pull request CZ-NIC#667 from OpenIDC/deprecate-refresh-token-db
Browse files Browse the repository at this point in the history
Deprecated RefreshDB
  • Loading branch information
tpazderka authored Jun 17, 2019
2 parents 52a71c1 + 5814180 commit d397bd0
Show file tree
Hide file tree
Showing 7 changed files with 123 additions and 18 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ The format is based on the [KeepAChangeLog] project.
- [#629] Duplicated methods in oic.oic classes were removed.
- [#642] Deprecated `bearer_auth` method.
- [#631] Refactored message type handling in Client/Provider.
- [#644] refresh_db kwarg in SessionDB has been deprecated

### Added
- [#655] Host can be forced on webfinger discovery
Expand Down Expand Up @@ -58,6 +59,7 @@ The format is based on the [KeepAChangeLog] project.
[#638]: https://github.com/OpenIDC/pyoidc/issues/638
[#146]: https://github.com/OpenIDC/pyoidc/issues/146
[#664]: https://github.com/OpenIDC/pyoidc/pull/664
[#644]: https://github.com/OpenIDC/pyoidc/pull/644

## 0.15.1 [2019-01-31]

Expand Down
3 changes: 2 additions & 1 deletion src/oic/extension/token.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,9 @@ class JWTToken(Token, JWT):

def __init__(self, typ, keyjar, lt_pattern=None, extra_claims=None, **kwargs):
self.type = typ
JWT.__init__(self, keyjar, msgtype=TokenAssertion, **kwargs)
Token.__init__(self, typ, **kwargs)
kwargs.pop("token_storage", None)
JWT.__init__(self, keyjar, msgtype=TokenAssertion, **kwargs)
self.lt_pattern = lt_pattern or {}
self.db = {} # type: Dict[str,str]
self.session_info = {"": 600}
Expand Down
88 changes: 76 additions & 12 deletions src/oic/utils/sdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,13 @@
import warnings
from abc import ABCMeta
from abc import abstractmethod
from binascii import Error
from typing import Dict # noqa
from typing import List # noqa
from typing import Optional # noqa

from cryptography.fernet import Fernet
from cryptography.fernet import InvalidToken

from oic import rndstr
from oic.exception import ImproperlyConfigured
Expand Down Expand Up @@ -84,10 +86,16 @@ def decrypt(self, ciphertext):


class Token(object):
def __init__(self, typ, lifetime=0, **kwargs):
def __init__(self, typ, lifetime=0, token_storage=None, **kwargs):
self.type = typ
self.lifetime = lifetime
self.args = kwargs
if typ == "R":
if token_storage is None:
raise ImproperlyConfigured(
"token_storage kwarg must be passed in for refresh token."
)
self.token_storage = token_storage

def __call__(self, sid, *args, **kwargs):
"""
Expand Down Expand Up @@ -148,11 +156,23 @@ def is_expired(self, token, when=None):
return bool(now > eat)

def invalidate(self, token):
pass
"""Mark the refresh token as invalidated."""
if self.get_type(token) != "R":
return
sid = self.get_key(token)
self.token_storage[sid]["revoked"] = True

def valid(self, token):
self.type_and_key(token)
return True
try:
typ, key = self.type_and_key(token)
except (Error, InvalidToken):
raise WrongTokenType()
if typ != self.type:
raise WrongTokenType()
if typ == "R":
return not self.token_storage[key].get("revoked", False)
else:
return True


class DefaultToken(Token):
Expand Down Expand Up @@ -180,6 +200,9 @@ def __call__(self, sid="", ttype="", **kwargs):
rnd = rndstr(32) # Ultimate length multiple of 16

issued_at = "{}".format(utc_time_sans_frac())
if ttype == "R":
# kwargs["sinfo"] is a dictionary and we do not want updates...
self.token_storage[sid] = copy.deepcopy(kwargs["sinfo"])

return base64.b64encode(
self.crypt.encrypt(lv_pack(rnd, ttype, sid, issued_at).encode())
Expand Down Expand Up @@ -293,6 +316,13 @@ def from_json(cls, json_struct):
class RefreshDB(object):
"""Database for refresh token storage."""

def __init__(self):
warnings.warn(
"Using `RefreshDB` is deprecated, please use `Token` and `refresh_token_factory` instead.",
DeprecationWarning,
stacklevel=2,
)

def get(self, refresh_token):
"""
Retrieve info about the authentication proces from the refresh token.
Expand Down Expand Up @@ -365,6 +395,11 @@ class DictRefreshDB(RefreshDB):

def __init__(self):
super(DictRefreshDB, self).__init__()
warnings.warn(
"Using `DictRefreshDB` is deprecated, please use `Token` and `refresh_token_factory` instead.",
DeprecationWarning,
stacklevel=2,
)
self._db = {} # type: Dict[str, Dict[str, str]]

def get(self, refresh_token):
Expand Down Expand Up @@ -408,15 +443,17 @@ def create_session_db(
code_factory = DefaultToken(secret, password, typ="A", lifetime=grant_expires_in)
token_factory = DefaultToken(secret, password, typ="T", lifetime=token_expires_in)
db = DictSessionBackend() if db is None else db
refresh_token_factory = DefaultToken(
secret, password, typ="R", lifetime=refresh_token_expires_in, token_storage={}
)

return SessionDB(
base_url,
db,
refresh_db=None,
code_factory=code_factory,
token_factory=token_factory,
refresh_token_expires_in=refresh_token_expires_in,
refresh_token_factory=None,
refresh_token_factory=refresh_token_factory,
)


Expand Down Expand Up @@ -523,7 +560,7 @@ def __init__(
base_url,
db,
refresh_db=None,
refresh_token_expires_in=86400,
refresh_token_expires_in=None,
token_factory=None,
code_factory=None,
refresh_token_factory=None,
Expand All @@ -533,6 +570,12 @@ def __init__(
:param db: Database for storing the session information.
"""
if refresh_token_expires_in is not None:
warnings.warn(
"Setting a `refresh_token_expires_in` has no effect, please set the expiration on "
"`refresh_token_factory`.",
DeprecationWarning,
)
self.base_url = base_url
if not isinstance(db, SessionBackend):
warnings.warn(
Expand All @@ -557,9 +600,16 @@ def __init__(
self.token_factory["refresh_token"] = refresh_token_factory
self.token_factory_order.append("refresh_token")
elif refresh_db:
warnings.warn(
"Using `refresh_db` is deprecated, please use `refresh_token_factory`",
DeprecationWarning,
stacklevel=2,
)
self._refresh_db = refresh_db
else:
self._refresh_db = DictRefreshDB()
# Not configured
self._refresh_db = None
self.token_factory["refresh_token"] = None

self.access_token = self.token_factory["access_token"]
self.token = self.access_token
Expand Down Expand Up @@ -797,9 +847,10 @@ def upgrade_to_token(
dic["authzreq"],
key,
)
else:
dic["refresh_token"] = refresh_token
elif self.token_factory["refresh_token"] is not None:
refresh_token = self.token_factory["refresh_token"](key, sinfo=dic)
dic["refresh_token"] = refresh_token
dic["refresh_token"] = refresh_token
self._db[key] = dic
return dic

Expand Down Expand Up @@ -841,9 +892,17 @@ def refresh_token(self, rtoken, client_id):
self.access_token.invalidate(at)
else:
raise ExpiredToken()
elif self.token_factory["refresh_token"] is None:
raise WrongTokenType()
elif self.token_factory["refresh_token"].valid(rtoken):
if self.token_factory["refresh_token"].is_expired(rtoken):
raise ExpiredToken()
sid = self.token_factory["refresh_token"].get_key(rtoken)
dic = self._db[sid]
try:
dic = self._db[sid]
except KeyError:
# Session is cleared, use the storage in token factory
dic = self.token_factory["refresh_token"].token_storage[sid]
access_token = self.access_token(sid=sid, sinfo=dic)

try:
Expand Down Expand Up @@ -898,6 +957,11 @@ def is_valid(self, token, client_id=None):
if not self.access_token.valid(token):
return False

elif typ == "R":
if self.token_factory["refresh_token"] is None:
return False
if not self.token_factory["refresh_token"].valid(token):
return False
return True

def is_revoked(self, sid):
Expand Down Expand Up @@ -925,7 +989,7 @@ def revoke_refresh_token(self, rtoken):
"""
if self._refresh_db:
self._refresh_db.revoke_token(rtoken)
else:
elif self.token_factory["refresh_token"] is not None:
self.token_factory["refresh_token"].invalidate(rtoken)

return True
Expand Down
6 changes: 5 additions & 1 deletion src/oic/utils/token_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,11 @@ def __init__(

if refresh_token_factory is None:
self.refresh_token_factory = JWTToken(
"R", keyjar=keyjar, iss="https://example.com/as", sign_alg=sign_alg
"R",
keyjar=keyjar,
iss="https://example.com/as",
sign_alg=sign_alg,
token_storage={},
)
else:
self.refresh_token_factory = refresh_token_factory
Expand Down
30 changes: 28 additions & 2 deletions tests/test_oic_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -1884,10 +1884,36 @@ def test_refresh_token_grant_type_wrong_token(self):
assert atr["error_description"] == "Not a refresh token"

def test_refresh_token_grant_type_expired(self):
# Missing refresh_token also raises Expired
authreq = AuthorizationRequest(
state="state",
redirect_uri="http://example.com/authz",
client_id=CLIENT_ID,
response_type="code",
scope=["openid", "offline_access"],
prompt="consent",
)

_sdb = self.provider.sdb
sid = _sdb.access_token.key(user="sub", areq=authreq)
access_grant = _sdb.access_token(sid=sid)
ae = AuthnEvent("user", "salt")
_sdb[sid] = {
"oauth_state": "authz",
"authn_event": ae.to_json(),
"authzreq": authreq.to_json(),
"client_id": CLIENT_ID,
"code": access_grant,
"code_used": False,
"scope": ["openid", "offline_access"],
"redirect_uri": "http://example.com/authz",
}
_sdb.do_sub(sid, "client_salt")
with freeze_time("2000-01-01"):
info = _sdb.upgrade_to_token(access_grant, issue_refresh=True)

rareq = RefreshAccessTokenRequest(
grant_type="refresh_token",
refresh_token="Refresh_some_other_refresh_token",
refresh_token=info["refresh_token"],
client_id=CLIENT_ID,
client_secret=CLIENT_SECRET,
scope=["openid"],
Expand Down
6 changes: 5 additions & 1 deletion tests/test_token.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,11 @@ def create_sdb(self):
sign_alg="RS256",
),
refresh_token_factory=JWTToken(
"R", keyjar=kj, lt_pattern={"": 24 * 3600}, iss="https://example.com/as"
"R",
keyjar=kj,
lt_pattern={"": 24 * 3600},
iss="https://example.com/as",
token_storage={},
),
)

Expand Down
6 changes: 5 additions & 1 deletion tests/test_x_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,11 @@ def create_provider(self):
sign_alg="RS256",
),
refresh_token_factory=JWTToken(
"R", keyjar=kj, lt_pattern={"": 24 * 3600}, iss="https://example.com/as"
"R",
keyjar=kj,
lt_pattern={"": 24 * 3600},
iss="https://example.com/as",
token_storage={},
),
)
# name, sdb, cdb, authn_broker, authz, client_authn,
Expand Down

0 comments on commit d397bd0

Please sign in to comment.