Skip to content

Commit

Permalink
Add PEP8 corrections
Browse files Browse the repository at this point in the history
  • Loading branch information
Nicolas Olivier committed Jan 8, 2020
1 parent 6095a19 commit 86cf2f3
Showing 1 changed file with 58 additions and 57 deletions.
115 changes: 58 additions & 57 deletions warrant/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,23 +20,26 @@ def cognito_to_dict(attr_list, attr_map=None):
value = a.get('Value')
if value in ['true', 'false']:
value = ast.literal_eval(value.capitalize())
name = attr_map.get(name,name)
name = attr_map.get(name, name)
attr_dict[name] = value
return attr_dict


def dict_to_cognito(attributes, attr_map=None):
"""
:param attributes: Dictionary of User Pool attribute names/values
:param attr_map: Dictonnary with attributes mapping
:return: list of User Pool attribute formatted dicts: {'Name': <attr_name>, 'Value': <attr_value>}
"""
if attr_map is None:
attr_map = {}
for k,v in attr_map.items():
for k, v in attr_map.items():
if v in attributes.keys():
attributes[k] = attributes.pop(v)

return [{'Name': key, 'Value': value} for key, value in attributes.items()]


def camel_to_snake(camel_str):
"""
:param camel_str: string
Expand All @@ -45,6 +48,7 @@ def camel_to_snake(camel_str):
s1 = re.sub('(.)([A-Z][a-z]+)', r'\1_\2', camel_str)
return re.sub('([a-z0-9])([A-Z])', r'\1_\2', s1).lower()


def snake_to_camel(snake_str):
"""
:param snake_str: string
Expand All @@ -66,10 +70,10 @@ def __init__(self, username, attribute_list, cognito_obj, metadata=None, attr_ma
self.pk = username
self._cognito = cognito_obj
self._attr_map = {} if attr_map is None else attr_map
self._data = cognito_to_dict(attribute_list,self._attr_map)
self.sub = self._data.pop('sub',None)
self.email_verified = self._data.pop('email_verified',None)
self.phone_number_verified = self._data.pop('phone_number_verified',None)
self._data = cognito_to_dict(attribute_list, self._attr_map)
self.sub = self._data.pop('sub', None)
self.email_verified = self._data.pop('email_verified', None)
self.phone_number_verified = self._data.pop('phone_number_verified', None)
self._metadata = {} if metadata is None else metadata

def __repr__(self):
Expand All @@ -80,24 +84,24 @@ def __unicode__(self):
return self.username

def __getattr__(self, name):
if name in list(self.__dict__.get('_data',{}).keys()):
if name in list(self.__dict__.get('_data', {}).keys()):
return self._data.get(name)
if name in list(self.__dict__.get('_metadata',{}).keys()):
if name in list(self.__dict__.get('_metadata', {}).keys()):
return self._metadata.get(name)

def __setattr__(self, name, value):
if name in list(self.__dict__.get('_data',{}).keys()):
if name in list(self.__dict__.get('_data', {}).keys()):
self._data[name] = value
else:
super(UserObj, self).__setattr__(name, value)

def save(self,admin=False):
def save(self, admin=False):
if admin:
self._cognito.admin_update_profile(self._data, self._attr_map)
return
self._cognito.update_profile(self._data,self._attr_map)
self._cognito.update_profile(self._data, self._attr_map)

def delete(self,admin=False):
def delete(self, admin=False):
if admin:
self._cognito.admin_delete_user()
return
Expand Down Expand Up @@ -129,16 +133,15 @@ def __repr__(self):


class Cognito(object):

user_class = UserObj
group_class = GroupObj

def __init__(
self, user_pool_id, client_id,user_pool_region=None,
self, user_pool_id, client_id, user_pool_region=None,
username=None, id_token=None, refresh_token=None,
access_token=None, client_secret=None,
access_key=None, secret_key=None,
):
):
"""
:param user_pool_id: Cognito User Pool ID
:param client_id: Cognito User Pool Application client ID
Expand All @@ -161,6 +164,7 @@ def __init__(
self.token_type = None
self.custom_attributes = None
self.base_attributes = None
self.pool_jwk = None

boto3_client_kwargs = {}
if access_key and secret_key:
Expand All @@ -176,37 +180,37 @@ def get_keys(self):
try:
return self.pool_jwk
except AttributeError:
#Check for the dictionary in environment variables.
pool_jwk_env = env('COGNITO_JWKS', {},var_type='dict')
# Check for the dictionary in environment variables.
pool_jwk_env = env('COGNITO_JWKS', {}, var_type='dict')
if len(pool_jwk_env.keys()) > 0:
self.pool_jwk = pool_jwk_env
return self.pool_jwk
#If it is not there use the requests library to get it
# If it is not there use the requests library to get it
self.pool_jwk = requests.get(
'https://cognito-idp.{}.amazonaws.com/{}/.well-known/jwks.json'.format(
self.user_pool_region,self.user_pool_id
self.user_pool_region, self.user_pool_id
)).json()
return self.pool_jwk

def get_key(self,kid):
def get_key(self, kid):
keys = self.get_keys().get('keys')
key = list(filter(lambda x:x.get('kid') == kid,keys))
key = list(filter(lambda x: x.get('kid') == kid, keys))
return key[0]

def verify_token(self,token,id_name,token_use):
def verify_token(self, token, id_name, token_use):
kid = jwt.get_unverified_header(token).get('kid')
unverified_claims = jwt.get_unverified_claims(token)
token_use_verified = unverified_claims.get('token_use') == token_use
if not token_use_verified:
raise TokenVerificationException('Your {} token use could not be verified.')
hmac_key = self.get_key(kid)
try:
verified = jwt.decode(token,hmac_key,algorithms=['RS256'],
audience=unverified_claims.get('aud'),
issuer=unverified_claims.get('iss'))
verified = jwt.decode(token, hmac_key, algorithms=['RS256'],
audience=unverified_claims.get('aud'),
issuer=unverified_claims.get('iss'))
except JWTError:
raise TokenVerificationException('Your {} token could not be verified.')
setattr(self,id_name,token)
setattr(self, id_name, token)
return verified

def get_user_obj(self, username=None, attribute_list=None, metadata=None,
Expand All @@ -221,9 +225,9 @@ def get_user_obj(self, username=None, attribute_list=None, metadata=None,
what we'd like to display to the users
:return:
"""
return self.user_class(username=username,attribute_list=attribute_list,
return self.user_class(username=username, attribute_list=attribute_list,
cognito_obj=self,
metadata=metadata,attr_map=attr_map)
metadata=metadata, attr_map=attr_map)

def get_group_obj(self, group_data):
"""
Expand All @@ -233,7 +237,7 @@ def get_group_obj(self, group_data):
"""
return self.group_class(group_data=group_data, cognito_obj=self)

def switch_session(self,session):
def switch_session(self, session):
"""
Primarily used for unit testing so we can take advantage of the
placebo library (https://githhub.com/garnaat/placebo)
Expand Down Expand Up @@ -268,11 +272,11 @@ def add_base_attributes(self, **kwargs):
def add_custom_attributes(self, **kwargs):
custom_key = 'custom'
custom_attributes = {}

for old_key, value in kwargs.items():
new_key = custom_key + ':' + old_key
custom_attributes[new_key] = value

self.custom_attributes = custom_attributes

def register(self, username, password, attr_map=None):
Expand Down Expand Up @@ -330,7 +334,7 @@ def admin_confirm_sign_up(self, username=None):
Username=username,
)

def confirm_sign_up(self,confirmation_code,username=None):
def confirm_sign_up(self, confirmation_code, username=None):
"""
Using the confirmation code that is either sent via email or text
message.
Expand All @@ -353,9 +357,9 @@ def admin_authenticate(self, password):
:return:
"""
auth_params = {
'USERNAME': self.username,
'PASSWORD': password
}
'USERNAME': self.username,
'PASSWORD': password
}
self._add_secret_hash(auth_params, 'SECRET_HASH')
tokens = self.client.admin_initiate_auth(
UserPoolId=self.user_pool_id,
Expand All @@ -365,9 +369,9 @@ def admin_authenticate(self, password):
AuthParameters=auth_params,
)

self.verify_token(tokens['AuthenticationResult']['IdToken'], 'id_token','id')
self.verify_token(tokens['AuthenticationResult']['IdToken'], 'id_token', 'id')
self.refresh_token = tokens['AuthenticationResult']['RefreshToken']
self.verify_token(tokens['AuthenticationResult']['AccessToken'], 'access_token','access')
self.verify_token(tokens['AuthenticationResult']['AccessToken'], 'access_token', 'access')
self.token_type = tokens['AuthenticationResult']['TokenType']

def authenticate(self, password):
Expand All @@ -380,16 +384,16 @@ def authenticate(self, password):
client_id=self.client_id, client=self.client,
client_secret=self.client_secret)
tokens = aws.authenticate_user()
self.verify_token(tokens['AuthenticationResult']['IdToken'],'id_token','id')
self.verify_token(tokens['AuthenticationResult']['IdToken'], 'id_token', 'id')
self.refresh_token = tokens['AuthenticationResult']['RefreshToken']
self.verify_token(tokens['AuthenticationResult']['AccessToken'], 'access_token','access')
self.verify_token(tokens['AuthenticationResult']['AccessToken'], 'access_token', 'access')
self.token_type = tokens['AuthenticationResult']['TokenType']

def new_password_challenge(self, password, new_password):
"""
Respond to the new password challenge using the SRP protocol
:param password: The user's current passsword
:param password: The user's new passsword
:param new_password: The user's new passsword
"""
aws = AWSSRP(username=self.username, password=password, pool_id=self.user_pool_id,
client_id=self.client_id, client=self.client,
Expand Down Expand Up @@ -419,9 +423,9 @@ def logout(self):
def admin_update_profile(self, attrs, attr_map=None):
user_attrs = dict_to_cognito(attrs, attr_map)
self.client.admin_update_user_attributes(
UserPoolId = self.user_pool_id,
Username = self.username,
UserAttributes = user_attrs
UserPoolId=self.user_pool_id,
Username=self.username,
UserAttributes=user_attrs
)

def update_profile(self, attrs, attr_map=None):
Expand All @@ -431,7 +435,7 @@ def update_profile(self, attrs, attr_map=None):
:param attr_map: Dictionary map from Cognito attributes to attribute
names we would like to show to our users
"""
user_attrs = dict_to_cognito(attrs,attr_map)
user_attrs = dict_to_cognito(attrs, attr_map)
self.client.update_user_attributes(
UserAttributes=user_attrs,
AccessToken=self.access_token
Expand All @@ -446,8 +450,8 @@ def get_user(self, attr_map=None):
:return:
"""
user = self.client.get_user(
AccessToken=self.access_token
)
AccessToken=self.access_token
)

user_metadata = {
'username': user.get('Username'),
Expand All @@ -457,7 +461,7 @@ def get_user(self, attr_map=None):
}
return self.get_user_obj(username=self.username,
attribute_list=user.get('UserAttributes'),
metadata=user_metadata,attr_map=attr_map)
metadata=user_metadata, attr_map=attr_map)

def get_users(self, attr_map=None):
"""
Expand All @@ -466,12 +470,12 @@ def get_users(self, attr_map=None):
:param attr_map:
:return:
"""
kwargs = {"UserPoolId":self.user_pool_id}
kwargs = {"UserPoolId": self.user_pool_id}

response = self.client.list_users(**kwargs)
return [self.get_user_obj(user.get('Username'),
attribute_list=user.get('Attributes'),
metadata={'username':user.get('Username')},
metadata={'username': user.get('Username')},
attr_map=attr_map)
for user in response.get('Users')]

Expand All @@ -483,19 +487,19 @@ def admin_get_user(self, attr_map=None):
:return: UserObj object
"""
user = self.client.admin_get_user(
UserPoolId=self.user_pool_id,
Username=self.username)
UserPoolId=self.user_pool_id,
Username=self.username)
user_metadata = {
'enabled': user.get('Enabled'),
'user_status':user.get('UserStatus'),
'username':user.get('Username'),
'user_status': user.get('UserStatus'),
'username': user.get('Username'),
'id_token': self.id_token,
'access_token': self.access_token,
'refresh_token': self.refresh_token
}
return self.get_user_obj(username=self.username,
attribute_list=user.get('UserAttributes'),
metadata=user_metadata,attr_map=attr_map)
metadata=user_metadata, attr_map=attr_map)

def admin_create_user(self, username, temporary_password='', attr_map=None, **kwargs):
"""
Expand Down Expand Up @@ -575,14 +579,12 @@ def initiate_forgot_password(self):
self._add_secret_hash(params, 'SecretHash')
self.client.forgot_password(**params)


def delete_user(self):

self.client.delete_user(
AccessToken=self.access_token
)


def admin_delete_user(self):
self.client.admin_delete_user(
UserPoolId=self.user_pool_id,
Expand Down Expand Up @@ -661,4 +663,3 @@ def get_groups(self):
response = self.client.list_groups(UserPoolId=self.user_pool_id)
return [self.get_group_obj(group_data)
for group_data in response.get('Groups')]

0 comments on commit 86cf2f3

Please sign in to comment.