Skip to content

Commit

Permalink
Bug fixes for MongoAlchemyDbAdapter.
Browse files Browse the repository at this point in the history
  • Loading branch information
lingthio committed Aug 27, 2017
1 parent 121599f commit 7ea2e1e
Show file tree
Hide file tree
Showing 10 changed files with 91 additions and 59 deletions.
10 changes: 4 additions & 6 deletions example_apps/quickstart_mongodb_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,10 @@ class ConfigClass(object):
USER_ENABLE_USERNAME = False # Disable username authentication

# For debugging purposes
# USER_SEND_PASSWORD_CHANGED_EMAIL=False
# USER_SEND_REGISTERED_EMAIL=False
# USER_SEND_USERNAME_CHANGED_EMAIL=False
USER_SEND_PASSWORD_CHANGED_EMAIL=False
USER_SEND_REGISTERED_EMAIL=False
USER_SEND_USERNAME_CHANGED_EMAIL=False
USER_ENABLE_CONFIRM_EMAIL=False


def create_app():
Expand All @@ -50,9 +51,6 @@ def create_app():
# Initialize MongoDB
db = MongoAlchemy(app)

# Drop existing table
db.session.db.connection.drop_database(app.config.get('MONGOALCHEMY_DATABASE', ''))

# Define the User data model.
# NB: Make sure to add flask_user UserMixin !!!
class User(db.Document, UserMixin):
Expand Down
10 changes: 8 additions & 2 deletions flask_user/db_adapters/alchemy_db_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ class BaseAlchemyDbAdapter(DbAdapter):
""" This class the base class for SQLAlchemyAdapter and MongoAlchemyAdapter."""

def get_object(self, ObjectClass, id):
""" Retrieve one object specified by the primary key 'pk' """
""" Retrieve object by ID """
return ObjectClass.query.get(id)

def find_objects(self, ObjectClass, **kwargs):
Expand Down Expand Up @@ -98,6 +98,12 @@ def commit(self):
class MongoAlchemyDbAdapter(BaseAlchemyDbAdapter):
""" This class shields the code from MongoAlchemy specifics."""

def get_object(self, ObjectClass, id):
""" Retrieve object by ID """
# Translate Flask-User integer ID to MongoDB ObjectID
hex_id = format(id, 'x')
return ObjectClass.query.get(hex_id)

def ifind_first_object(self, ObjectClass, **kwargs):
""" Retrieve the first object matching the case insensitive filters in 'kwargs'. """

Expand All @@ -118,7 +124,7 @@ def ifind_first_object(self, ObjectClass, **kwargs):

def update_object(self, object, **kwargs):
# Update object
super(MongoAlchemyDbAdapter, self).__init__(object, **kwargs)
super(MongoAlchemyDbAdapter, self).update_object(object, **kwargs)
# Save changes to DB
object.save()

Expand Down
2 changes: 1 addition & 1 deletion flask_user/templates/base.html
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
<div class="pull-left"><a href="/"><h1 class="no-margins">{{ user_manager.app_name }}</h1></a></div>
<div class="pull-right">
{% if call_or_get(current_user.is_authenticated) %}
<a href="{{ url_for('user.profile') }}">{{ current_user.username }}</a>
<a href="{{ url_for('user.profile') }}">{{ current_user.username or current_user.email }}</a>
&nbsp; | &nbsp;
<a href="{{ url_for('user.logout') }}">Sign out</a>
{% else %}
Expand Down
3 changes: 2 additions & 1 deletion flask_user/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@ def teardown():
def db(app, request):
"""Session-wide test database."""
def teardown():
app.db.drop_all()
if hasattr(app.db, 'drop_all'):
app.db.drop_all()

if hasattr(app.db, 'create_all'):
app.db.create_all()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,28 +74,25 @@ def test_init(db):

hashed_password = um.hash_password('Password1')
User = um.UserModel
add_object = um.db_adapter.add_object

# Create user1 with username and email
user1 = User(username='user1', email='[email protected]', password=hashed_password)
user1 = add_object(User, username='user1', email='[email protected]', password=hashed_password)
assert user1
db.session.add(user1)

# Create user1 with email only
user2 = User(email='[email protected]', password=hashed_password,)
user2 = add_object(User, email='[email protected]', password=hashed_password,)
assert user2
db.session.add(user2)

# Create user3 with username and email
user3 = User(username='user3', email='[email protected]', password=hashed_password)
user3 = add_object(User, username='user3', email='[email protected]', password=hashed_password)
assert user3
db.session.add(user3)

# Create user4 with email only
user4 = User(email='[email protected]', password=hashed_password)
user4 = add_object(User, email='[email protected]', password=hashed_password)
assert user4
db.session.add(user4)

db.session.commit()
um.db_adapter.commit()


def test_invalid_register_with_username_form(client):
Expand Down Expand Up @@ -469,11 +466,12 @@ def test_cleanup(db):
Delete user1 and user2
"""
global user1, user2, user3, user4
db.session.delete(user1)
db.session.delete(user2)
db.session.delete(user3)
db.session.delete(user4)
db.session.commit()
um = current_app.user_manager
um.db_adapter.delete_object(user1)
um.db_adapter.delete_object(user2)
um.db_adapter.delete_object(user3)
um.db_adapter.delete_object(user4)
um.db_adapter.commit()
user1 = None
user2 = None
user3 = None
Expand Down
File renamed without changes.
22 changes: 16 additions & 6 deletions flask_user/tests/test_valid_forms.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ def check_valid_register_form(um, client, db):
# Create User
valid_user = User(confirmed_at=datetime.datetime.utcnow(), **kwargs)
db.session.add(valid_user)
db.session.commit()
um.db_adapter.commit()
assert valid_user

def check_valid_resend_confirm_email_form(um, client):
Expand All @@ -185,6 +185,7 @@ def check_valid_confirm_email_page(um, client):
if not um.enable_confirm_email: return

print("test_valid_confirm_email_page")
global valid_user

# Generate confirmation token for user 1
confirmation_token = um.generate_token(valid_user.id)
Expand All @@ -193,6 +194,7 @@ def check_valid_confirm_email_page(um, client):
client.get_valid_page(url_for('user.confirm_email', token=confirmation_token))

# Verify operations
valid_user = um.db_adapter.get_object(um.UserModel, valid_user.id)
assert valid_user.confirmed_at != None

def check_valid_login_form(um, client):
Expand All @@ -217,6 +219,7 @@ def check_valid_change_password_form(um, client):
if not um.enable_change_password: return

print("test_valid_change_password_form")
global valid_user

# Define defaults
new_password = 'Password9'
Expand All @@ -233,23 +236,26 @@ def check_valid_change_password_form(um, client):
client.post_valid_form(url_for('user.change_password'), **kwargs)

# Verify operations
valid_user = um.db_adapter.get_object(um.UserModel, valid_user.id)
assert um.verify_password(valid_user, new_password)

# Change password back to old password for subsequent tests
valid_user.password = old_hashed_password
um.db_adapter.update_object(valid_user, password=old_hashed_password)

def check_valid_change_username_form(um, client):
# Skip test for certain config combinations
if not um.enable_change_username: return

print("test_valid_change_username_form")
global valid_user

new_username = 'user9'

# Submit form and verify that response has no errors
client.post_valid_form(url_for('user.change_username'), new_username=new_username, old_password=VALID_PASSWORD)

# Verify operations
valid_user = um.db_adapter.get_object(um.UserModel, valid_user.id)
assert valid_user.username == new_username

# Change username back to old password for subsequent tests
Expand All @@ -276,6 +282,7 @@ def check_valid_reset_password_page(um, client):
if not um.enable_forgot_password: return

print("test_valid_reset_password_page")
global valid_user

# Simulate a valid forgot password form
token = um.generate_token(valid_user.id)
Expand All @@ -296,6 +303,7 @@ def check_valid_reset_password_page(um, client):
client.post_valid_form(url, **kwargs)

# Verify operations
valid_user = um.db_adapter.get_object(um.UserModel, valid_user.id)
assert um.verify_password(valid_user, new_password)

# Change password back to old password for subsequent tests
Expand All @@ -319,8 +327,9 @@ def delete_valid_user(db):

if valid_user:
# Delete valid_user
db.session.delete(valid_user)
db.session.commit()
um = current_app.user_manager
um.db_adapter.delete_object(valid_user)
um.db_adapter.commit()
valid_user = None

def delete_valid_user_invite(db):
Expand All @@ -329,6 +338,7 @@ def delete_valid_user_invite(db):

if valid_user_invite:
# Delete valid_user_invite
db.session.delete(valid_user_invite)
db.session.commit()
um = current_app.user_manager
um.db_adapter.delete_object(valid_user_invite)
um.db_adapter.commit()
valid_user_invite = None
39 changes: 30 additions & 9 deletions flask_user/tests/tst_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from flask_user import login_required, UserManager, UserMixin
from flask_user import roles_required, confirm_email_required

ORM_type = 'SQLAlchemy'
ORM_type = 'MongoAlchemy' # SQLAlchemy or MongoAlchemy

app = Flask(__name__)

Expand All @@ -31,6 +31,11 @@ class ConfigClass(object):
MAIL_PORT = int(os.getenv('MAIL_PORT', '465'))
MAIL_USE_SSL = os.getenv('MAIL_USE_SSL', True)

# Disable email sending
USER_SEND_PASSWORD_CHANGED_EMAIL=False
USER_SEND_REGISTERED_EMAIL=False
USER_SEND_USERNAME_CHANGED_EMAIL=False

# Read config from ConfigClass defined above
app.config.from_object(__name__+'.ConfigClass')

Expand Down Expand Up @@ -97,6 +102,12 @@ class UserRoles(db.Model):
from flask_mongoalchemy import MongoAlchemy
db = MongoAlchemy(app)


class Role(db.Document):
name = db.StringField()
label = db.StringField(default='')


# Define the User data model.
# NB: Make sure to add flask_user UserMixin !!!
class User(db.Document, UserMixin):
Expand All @@ -114,18 +125,17 @@ def id(self):
# self._id = format(value, 'x')

# User authentication information
username = db.StringField(required=False)
email = db.StringField(required=False)
username = db.StringField(default='')
email = db.StringField(default='')
password = db.StringField()
confirmed_at = db.DateTimeField(required=False)
confirmed_at = db.DateTimeField(default=None)

# User information
first_name = db.StringField(required=False)
last_name = db.StringField(required=False)



first_name = db.StringField(default='')
last_name = db.StringField(default='')

# Relationships
roles = db.ListField(db.DocumentField(Role), required=False, default=[])


# Define custom UserManager class
Expand Down Expand Up @@ -155,12 +165,23 @@ def init_app(app, test_config=None): # For automated tests
if ORM_type == 'SQLAlchemy':
db.create_all()

if ORM_type == 'MongoAlchemy':
# Drop existing table
db.session.db.connection.drop_database(app.config.get('MONGOALCHEMY_DATABASE', ''))


# Setup Flask-User
if ORM_type == 'SQLAlchemy':
user_manager = CustomUserManager(app, db, User, UserInvitationClass=UserInvitation)
else:
user_manager = CustomUserManager(app, db, User)

# For debugging purposes
# id = int('59a2258f9ebea4e67d20596f', 16)
# encrypted_id = user_manager._encrypt_id(id)
# decrypted_id = user_manager._decrypt_id(encrypted_id)
# assert(decrypted_id==id)

# Create regular 'member' user
if not User.query.filter(User.username=='member').first():
user = User(username='member', email='[email protected]',
Expand Down
31 changes: 14 additions & 17 deletions flask_user/token_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,17 +65,15 @@ def _encrypt_id(self, id):
hex_bytes = hex_str.encode() # Convert to bytes
padded_bytes = pad(hex_bytes, 16) # Pad to multiples of 16
encrypted_bytes = self.cipher.encrypt(padded_bytes) # Encrypt
url_safe_str = base64.urlsafe_b64encode(encrypted_bytes) # Convert to URL-safe string
encrypted_id = url_safe_str[0:-2] # Remove trailing base64 '=='
encrypted_id = base64.urlsafe_b64encode(encrypted_bytes) # Convert to URL-safe string

# For debug purposes
print('TokenMixin._encrypt_id()')
print('hex_str', hex_str)
print('hex_bytes', hex_bytes)
print('padded_bytes', padded_bytes)
print('encrypted_bytes', encrypted_bytes)
print('url_safe_str', url_safe_str)
print('encrypted_id', encrypted_id)
# print('TokenMixin._encrypt_id()')
# print('hex_str', hex_str)
# print('hex_bytes', hex_bytes)
# print('padded_bytes', padded_bytes)
# print('encrypted_bytes', encrypted_bytes)
# print('encrypted_id', encrypted_id)

return encrypted_id

Expand All @@ -86,19 +84,18 @@ def _decrypt_id(self, encrypted_id):
encrypted_id = encrypted_id.encode('ascii', 'ignore')

try:
url_safe_str = encrypted_id + b'==' # Add trailing base64 '=='
encrypted_bytes = base64.urlsafe_b64decode(url_safe_str) # Convert to bytes
encrypted_bytes = base64.urlsafe_b64decode(encrypted_id) # Convert to bytes
padded_bytes = self.cipher.decrypt(encrypted_bytes) # Decrypt
hex_bytes = unpad(padded_bytes, 16) # Remove padding
id = int(hex_bytes, 16) # Convert hex to integer

# For debug purposes
print('TokenMixin._decrypt_id()')
print('url_safe_str', url_safe_str)
print('encrypted_bytes', encrypted_bytes)
print('padded_bytes', padded_bytes)
print('hex_bytes', hex_bytes)
print('id', id)
# print('TokenMixin._decrypt_id()')
# print('encrypted_id', encrypted_id)
# print('encrypted_bytes', encrypted_bytes)
# print('padded_bytes', padded_bytes)
# print('hex_bytes', hex_bytes)
# print('id', id)

return id
except Exception as e: # pragma: no cover
Expand Down
7 changes: 4 additions & 3 deletions flask_user/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,23 +55,24 @@ def confirm_email(token):
if user_manager.UserEmailModel:
user_email = user_manager.get_user_email_by_id(object_id)
if user_email:
user_email.confirmed_at = datetime.utcnow()
db_adapter.update_object(user_email, confirmed_at=datetime.utcnow())
user = user_email.user
else:
user_email = None
user = user_manager.get_user_by_id(object_id)
if user:
user.confirmed_at = datetime.utcnow()
db_adapter.update_object(user, confirmed_at=datetime.utcnow())

if user:
# If User.active exists: activate User
if hasattr(user, 'active'):
db_adapter.update_object(user, active=True)
db_adapter.commit()
else: # pragma: no cover
flash(_('Invalid confirmation token.'), 'error')
return redirect(url_for('user.login'))

db_adapter.commit()

# Send email_confirmed signal
signals.user_confirmed_email.send(current_app._get_current_object(), user=user)

Expand Down

0 comments on commit 7ea2e1e

Please sign in to comment.