Skip to content

Commit

Permalink
Refactor/improvements (#11)
Browse files Browse the repository at this point in the history
* Database

1. AsyncAttrs declared in base class instead of children.
2. expire_on_commit=False, no need to refresh session

* UUID replaces str for Base id

For authorize we log the user_id and not the token.

* Created dependencies logic for auth
  • Loading branch information
tomasemilio authored Sep 25, 2024
1 parent cc38dad commit 2feb78e
Show file tree
Hide file tree
Showing 16 changed files with 103 additions and 65 deletions.
8 changes: 5 additions & 3 deletions app/api/v1/routes/admin.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from uuid import UUID

from fastapi import APIRouter, BackgroundTasks, Security

from app.api.v1.routes.user import request_reset_password
Expand Down Expand Up @@ -37,17 +39,17 @@ async def get_users(async_session: sessDep):


@router.get("/user/{id}", response_model=UserOut, status_code=200)
async def get_user(async_session: sessDep, id: str):
async def get_user(async_session: sessDep, id: UUID):
return await User.get(async_session, id)


@router.delete("/user/{id}", status_code=204)
async def delete_user(async_session: sessDep, id: str):
async def delete_user(async_session: sessDep, id: UUID):
user = await User.get(async_session, id)
await user.delete(async_session)


@router.put("/user/{id}", response_model=UserOut, status_code=200)
async def update_user(async_session: sessDep, id: str, user_in: UserIn):
async def update_user(async_session: sessDep, id: UUID, user_in: UserIn):
user = await User.get(async_session, id)
return await user.update(async_session, **user_in.model_dump(exclude_unset=True))
10 changes: 5 additions & 5 deletions app/api/v1/routes/auth.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
from fastapi import APIRouter, Depends, status
from fastapi import APIRouter, status

from app.models.auth.functions import authenticate
from app.models.auth.token import Token, TokenDecode, TokenEncode
from app.models.user import User
from app.models.auth.dependencies import userDep
from app.models.auth.schemas import TokenDecode, TokenEncode
from app.models.auth.token import Token

router = APIRouter(prefix="/auth", tags=["Auth"])


@router.post("/token", response_model=TokenEncode, status_code=status.HTTP_200_OK)
async def token(user: User = Depends(authenticate)):
async def token(user: userDep):
return Token(id=user.id, scope=user.scope).encode()


Expand Down
30 changes: 10 additions & 20 deletions app/api/v1/routes/post.py
Original file line number Diff line number Diff line change
@@ -1,46 +1,36 @@
from fastapi import APIRouter, Depends
from uuid import UUID

from fastapi import APIRouter

from app.database.dependencies import sessDep
from app.models.auth.functions import authorize, authorize_and_load, authorize_limited
from app.models.auth.token import TokenDecode
from app.models.auth.dependencies import authDep, authDepLimit, authLoadDep
from app.models.post import Post
from app.models.post.schemas import PostIn, PostOut
from app.models.user import User

router = APIRouter(prefix="/post", tags=["Post"])


@router.post("", response_model=PostOut, status_code=201)
async def create_post(
post_in: PostIn,
async_session: sessDep,
token: TokenDecode = Depends(authorize),
):
async def create_post(post_in: PostIn, async_session: sessDep, token: authDep):
return await Post(**post_in.model_dump(), user_id=token.id).save(async_session)


@router.get("", response_model=list[PostOut], status_code=200)
async def get_posts(user: User = Depends(authorize_and_load)):
async def get_posts(user: authLoadDep):
return await user.awaitable_attrs.posts


@router.get("/{id}", response_model=PostOut, status_code=200)
async def get_post(
async_session: sessDep, id: str, token: TokenDecode = Depends(authorize)
):
async def get_post(async_session: sessDep, id: UUID, token: authDep):
return await Post.find(async_session, id=id, user_id=token.id, raise_=True)


@router.get("/rate-limited/{id}", response_model=PostOut, status_code=200)
async def get_post_rate_limited(
async_session: sessDep, id: str, token: TokenDecode = Depends(authorize_limited)
):
async def get_post_rate_limited(async_session: sessDep, id: UUID, token: authDepLimit):
return await Post.find(async_session, id=id, user_id=token.id, raise_=True)


@router.delete("/{id}", status_code=204)
async def delete_post(
async_session: sessDep, id: str, token: TokenDecode = Depends(authorize)
):
async def delete_post(async_session: sessDep, id: UUID, token: authDep):
post = await Post.find(async_session, id=id, user_id=token.id, raise_=True)
await post.delete(async_session) # type: ignore
await post.delete(async_session)
8 changes: 4 additions & 4 deletions app/api/v1/routes/user.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from fastapi import APIRouter, BackgroundTasks, Depends, Security
from fastapi import APIRouter, BackgroundTasks
from pydantic import EmailStr

from app.config import config
from app.database.dependencies import sessDep
from app.functions.emailer import send_email
from app.models.auth.functions import authorize_and_load
from app.models.auth.dependencies import authLoadDep, resetLoadDep
from app.models.auth.role import Role
from app.models.auth.token import Token
from app.models.user import User
Expand All @@ -14,7 +14,7 @@


@router.get("/me", response_model=UserOut, status_code=200)
async def me(user: User = Depends(authorize_and_load)):
async def me(user: authLoadDep):
return user


Expand Down Expand Up @@ -44,7 +44,7 @@ async def request_reset_password(
async def reset_password(
async_session: sessDep,
passwords: PasswordsIn,
user: User = Security(authorize_and_load, scopes=[Role.RESET]),
user: resetLoadDep,
token: str | None = None, # noqa: F841
):
user = await User.get(
Expand Down
3 changes: 2 additions & 1 deletion app/database/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,5 @@
config.DB_URL, connect_args=connect_args, pool_pre_ping=True
)

local_session = async_sessionmaker(engine)
local_session = async_sessionmaker(engine, expire_on_commit=False)

38 changes: 24 additions & 14 deletions app/database/base.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,20 @@
from datetime import UTC, datetime
from typing import Self, Sequence
from uuid import uuid4
from typing import Literal, Self, Sequence, overload
from uuid import UUID, uuid4

from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncAttrs
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
from sqlalchemy.types import TIMESTAMP

from app.database.dependencies import sessDep
from app.functions.exceptions import not_found


class Base(DeclarativeBase):
class Base(AsyncAttrs, DeclarativeBase):
__abstract__ = True
id: Mapped[str] = mapped_column(
primary_key=True,
default=lambda: uuid4().hex,
sort_order=-3,
id: Mapped[UUID] = mapped_column(
primary_key=True, default=lambda: uuid4(), sort_order=-3
)
created_at: Mapped[datetime] = mapped_column(
default=lambda: datetime.now(UTC),
Expand All @@ -38,7 +37,7 @@ async def save(self, async_session: sessDep) -> Self:
return self

@classmethod
async def get(cls, async_session: sessDep, id: str) -> Self:
async def get(cls, async_session: sessDep, id: UUID) -> Self:
result = await async_session.get(cls, id)
if not result:
raise not_found(msg=f"{cls.__name__} not found")
Expand All @@ -57,15 +56,26 @@ async def update(self, async_session: sessDep, **kwargs) -> Self:
for key, value in kwargs.items():
setattr(self, key, value)
await async_session.commit()
await async_session.refresh(self)
return self

@overload
@classmethod
async def find(
cls, async_session: sessDep, raise_: Literal[True], **kwargs
) -> Self: ...

@overload
@classmethod
async def find(
cls, async_session: sessDep, raise_: bool = False, **kwargs
cls, async_session: sessDep, raise_: Literal[False], **kwargs
) -> Self | None: ...

@classmethod
async def find(
cls, async_session: sessDep, raise_: bool = True, **kwargs
) -> Self | None:
result = await async_session.execute(select(cls).filter_by(**kwargs))
result = result.scalars().first()
if not result and raise_:
stmt = select(cls).filter_by(**kwargs)
resp = await async_session.scalar(stmt)
if not resp and raise_:
raise not_found(msg=f"{cls.__name__} not found")
return result
return resp
2 changes: 1 addition & 1 deletion app/functions/middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,10 @@
async def lifespan(_: FastAPI):
setup_logger()
logger.warning(f"Starting the application: ENV={config.ENV_STATE}")
await drop_all()
await create_all()
await create_admin_user()
yield
await drop_all()
logger.warning("Shutting down the application")


Expand Down
26 changes: 26 additions & 0 deletions app/models/auth/dependencies.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from typing import Annotated

from fastapi import Depends, Security

from app.models.auth.functions import (
authenticate,
authenticate_and_token,
authorize,
authorize_and_load,
authorize_limited,
)
from app.models.auth.role import Role
from app.models.auth.schemas import TokenDecode
from app.models.user import User

authDep = Annotated[TokenDecode, Depends(authorize)]

userDep = Annotated[User, Depends(authenticate)]

authLoadDep = Annotated[User, Depends(authorize_and_load)]

authTokenDep = Annotated[User, Depends(authenticate_and_token)]

resetLoadDep = Annotated[User, Security(authorize_and_load, scopes=[Role.RESET])]

authDepLimit = Annotated[TokenDecode, Depends(authorize_limited)]
13 changes: 9 additions & 4 deletions app/models/auth/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
async def authenticate(
async_session: sessDep, credentials: OAuth2PasswordRequestForm = Depends()
) -> User:
user = await User.find(async_session, email=credentials.username)
user = await User.find(async_session, email=credentials.username, raise_=False)
if not user or not user.check_password(credentials.password):
raise unauthorized_basic()
elif user.verified is False:
Expand All @@ -37,12 +37,17 @@ def authorize(
token: Annotated[str, Depends(oauth2_scheme)],
security_scopes: Annotated[SecurityScopes, Depends],
) -> TokenDecode:
logger.info(f"Authorizing token:{token}")
return Token.decode(token=token, scope=[Role(i) for i in security_scopes.scopes])
decoded_token = Token.decode(
token=token, scope=[Role(i) for i in security_scopes.scopes]
)
logger.info(
f"Authorizing user ID: {decoded_token.id} with scopes: {decoded_token.scope}"
)
return decoded_token


def authorize_limited(token: Annotated[TokenDecode, Depends(authorize)]) -> TokenDecode:
rate_limiter(token.id)
rate_limiter(token.id.hex)
return token


Expand Down
3 changes: 2 additions & 1 deletion app/models/auth/schemas.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from datetime import UTC, datetime
from uuid import UUID

from pydantic import BaseModel, computed_field

Expand All @@ -13,7 +14,7 @@ class TokenEncode(BaseModel):


class TokenDecode(BaseModel):
id: str
id: UUID
iat: datetime
exp: datetime
scope: list[Role]
Expand Down
6 changes: 4 additions & 2 deletions app/models/auth/token.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from datetime import UTC, datetime, timedelta
from typing import Annotated
from uuid import UUID

from pydantic import BaseModel, Field, computed_field
from pydantic import AfterValidator, BaseModel, Field, computed_field

from app.config import config
from app.functions.exceptions import forbidden
Expand All @@ -10,7 +12,7 @@


class Token(BaseModel):
id: str
id: Annotated[UUID, AfterValidator(lambda x: x.hex)]
scope: list[Role] = [Role.USER]
expires_in: int = config.TOKEN_EXPIRE_SECONDS
iat: datetime = Field(default_factory=lambda: datetime.now(UTC))
Expand Down
6 changes: 3 additions & 3 deletions app/models/post/main.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import TYPE_CHECKING
from uuid import UUID

from sqlalchemy import ForeignKey
from sqlalchemy.ext.asyncio import AsyncAttrs
from sqlalchemy.orm import Mapped, mapped_column, relationship

from app.database.base import Base
Expand All @@ -10,9 +10,9 @@
from app.models.user import User


class Post(AsyncAttrs, Base):
class Post(Base):
__tablename__ = "post"
title: Mapped[str] = mapped_column()
content: Mapped[str] = mapped_column()
user_id: Mapped[str] = mapped_column(ForeignKey("user.id"))
user_id: Mapped[UUID] = mapped_column(ForeignKey("user.id"))
user: Mapped["User"] = relationship(back_populates="posts", lazy="select")
6 changes: 4 additions & 2 deletions app/models/post/schemas.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from uuid import UUID

from pydantic import BaseModel


Expand All @@ -7,5 +9,5 @@ class PostIn(BaseModel):


class PostOut(PostIn):
id: str
user_id: str
id: UUID
user_id: UUID
2 changes: 1 addition & 1 deletion app/models/user/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

async def create_admin_user() -> User:
async with local_session() as async_session:
admin = await User.find(async_session, email=config.ADMIN_EMAIL)
admin = await User.find(async_session, email=config.ADMIN_EMAIL, raise_=False)
if not admin:
logger.info("Admin user not found. Creating one.")
admin = await User(
Expand Down
3 changes: 1 addition & 2 deletions app/models/user/main.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from sqlalchemy import JSON
from sqlalchemy.ext.asyncio import AsyncAttrs
from sqlalchemy.orm import Mapped, mapped_column, relationship

from app.database.base import Base
Expand All @@ -8,7 +7,7 @@
from app.models.post import Post


class User(AsyncAttrs, Base):
class User(Base):
__tablename__ = "user"
name: Mapped[str] = mapped_column()
email: Mapped[str] = mapped_column(unique=True)
Expand Down
4 changes: 2 additions & 2 deletions app/models/user/schemas.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from datetime import datetime
from typing import Annotated, ClassVar
from uuid import uuid4
from uuid import UUID, uuid4

from pydantic import BaseModel, ConfigDict, EmailStr, Field, SecretStr, model_validator

Expand All @@ -26,7 +26,7 @@ class UserIn(BaseModel):

class UserOut(BaseModel):
model_config = ConfigDict(from_attributes=True)
id: str
id: UUID
created_at: datetime
updated_at: datetime
name: str
Expand Down

0 comments on commit 2feb78e

Please sign in to comment.