forked from ArjanCodes/examples
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
4 changed files
with
192 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
from fastapi import FastAPI, Request | ||
from fastapi.testclient import TestClient | ||
from slowapi import Limiter | ||
from slowapi.util import get_remote_address | ||
|
||
RATE_LIMITING_ENABLED = True | ||
|
||
app = FastAPI() | ||
limiter = Limiter( | ||
key_func=get_remote_address, | ||
strategy="fixed-window", | ||
storage_uri="memory://", | ||
enabled=RATE_LIMITING_ENABLED, | ||
) | ||
|
||
|
||
@app.get("/limited") | ||
@limiter.limit("2/second", per_method=True) | ||
async def limited_route(request: Request) -> dict[str, str]: | ||
return {"message": "This is a limited route"} | ||
|
||
|
||
@app.get("/unlimited") | ||
async def unlimited_route(request: Request) -> dict[str, str]: | ||
return {"message": "This is an unlimited route"} | ||
|
||
|
||
def main() -> None: | ||
client = TestClient(app) | ||
for _ in range(5): | ||
response = client.get("/unlimited") | ||
print(response.status_code, response.json()) | ||
|
||
for _ in range(5): | ||
response = client.get("/limited") | ||
print(response.status_code, response.json()) | ||
|
||
|
||
if __name__ == "__main__": | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
[tool.poetry] | ||
name = "rate-limiting" | ||
version = "0.1.0" | ||
description = "" | ||
authors = ["ArjanCodes"] | ||
|
||
[tool.poetry.dependencies] | ||
python = "^3.12" | ||
fastapi = "^0.110.2" | ||
slowapi = "^0.1.9" | ||
httpx = "^0.27.0" | ||
uvicorn = "^0.29.0" | ||
|
||
|
||
[build-system] | ||
requires = ["poetry-core"] | ||
build-backend = "poetry.core.masonry.api" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,80 @@ | ||
import hashlib | ||
import time | ||
from dataclasses import dataclass | ||
from functools import wraps | ||
from typing import Any, Callable | ||
|
||
from fastapi import FastAPI, HTTPException, Request | ||
|
||
app = FastAPI() | ||
|
||
|
||
# Mock database of API keys and their respective limits | ||
@dataclass | ||
class RateLimit: | ||
max_calls: int | ||
period: int | ||
|
||
|
||
api_key_limits = { | ||
"api_key_1": RateLimit(max_calls=5, period=60), | ||
"api_key_2": RateLimit(max_calls=10, period=60), | ||
} | ||
|
||
|
||
def rate_limit(): | ||
def decorator(func: Callable[[Request], Any]) -> Callable[[Request], Any]: | ||
usage: dict[str, list[float]] = {} | ||
|
||
@wraps(func) | ||
async def wrapper(request: Request) -> Any: | ||
# get the API key | ||
api_key = request.headers.get("X-API-KEY") | ||
if not api_key: | ||
raise HTTPException(status_code=400, detail="API key missing") | ||
|
||
# check if the API key is valid | ||
if api_key not in api_key_limits: | ||
raise HTTPException(status_code=403, detail="Invalid API key") | ||
|
||
# get the rate limits for the API key | ||
limits = api_key_limits[api_key] | ||
|
||
# get the client's IP address | ||
if not request.client: | ||
raise ValueError("Request has no client information") | ||
ip_address: str = request.client.host | ||
|
||
# create a unique identifier for the client | ||
unique_id: str = hashlib.sha256((api_key + ip_address).encode()).hexdigest() | ||
|
||
# update the timestamps | ||
now = time.time() | ||
if unique_id not in usage: | ||
usage[unique_id] = [] | ||
timestamps = usage[unique_id] | ||
timestamps[:] = [t for t in timestamps if now - t < limits.period] | ||
|
||
if len(timestamps) < limits.max_calls: | ||
timestamps.append(now) | ||
return await func(request) | ||
|
||
# calculate the time to wait before the next request | ||
wait = limits.period - (now - timestamps[0]) | ||
raise HTTPException( | ||
status_code=429, | ||
detail=f"Rate limit exceeded. Retry after {wait:.2f} seconds", | ||
) | ||
|
||
return wrapper | ||
|
||
return decorator | ||
|
||
|
||
@app.get("/") | ||
@rate_limit() | ||
async def read_root(request: Request): | ||
return {"message": "Hello, World!"} | ||
|
||
|
||
# Run the server using `uvicorn script_name:app --reload` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
import hashlib | ||
import time | ||
from dataclasses import dataclass | ||
from functools import wraps | ||
from typing import Any, Callable | ||
|
||
from fastapi import FastAPI, HTTPException, Request | ||
|
||
app = FastAPI() | ||
|
||
|
||
def rate_limit(max_calls: int, period: int): | ||
def decorator(func: Callable[[Request], Any]) -> Callable[[Request], Any]: | ||
usage: dict[str, list[float]] = {} | ||
|
||
@wraps(func) | ||
async def wrapper(request: Request) -> Any: | ||
# get the client's IP address | ||
if not request.client: | ||
raise ValueError("Request has no client information") | ||
ip_address: str = request.client.host | ||
|
||
# create a unique identifier for the client | ||
unique_id: str = hashlib.sha256((ip_address).encode()).hexdigest() | ||
|
||
# update the timestamps | ||
now = time.time() | ||
if unique_id not in usage: | ||
usage[unique_id] = [] | ||
timestamps = usage[unique_id] | ||
timestamps[:] = [t for t in timestamps if now - t < period] | ||
|
||
if len(timestamps) < max_calls: | ||
timestamps.append(now) | ||
return await func(request) | ||
|
||
# calculate the time to wait before the next request | ||
wait = period - (now - timestamps[0]) | ||
raise HTTPException( | ||
status_code=429, | ||
detail=f"Rate limit exceeded. Retry after {wait:.2f} seconds", | ||
) | ||
|
||
return wrapper | ||
|
||
return decorator | ||
|
||
|
||
@app.get("/") | ||
@rate_limit(max_calls=5, period=60) | ||
async def read_root(request: Request): | ||
return {"message": "Hello, World!"} | ||
|
||
|
||
# Run the server using `uvicorn script_name:app --reload` |