Skip to content

Commit

Permalink
Add ApiKeyPlugin to extract X-API-Key header (tomwojcik#26)
Browse files Browse the repository at this point in the history
`X-API-Key` is the header specified by
[AWS API Gateway](https://docs.aws.amazon.com/apigateway/latest/developerguide/api-gateway-api-key-source.html),
and also detailed in [Swagger documentation](https://swagger.io/docs/specification/authentication/api-keys/),
to send an API key.
  • Loading branch information
adamantike authored Oct 17, 2020
1 parent a935fbf commit 0022215
Show file tree
Hide file tree
Showing 6 changed files with 60 additions and 2 deletions.
10 changes: 8 additions & 2 deletions docs/source/plugins.rst
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,12 @@ You add as many plugins as you want to your middleware. You pass them to the mid
app = Starlette(middleware=middleware)
*******
API Key
*******

Extracts header "X-API-Key" and keeps it in context.

**************
Correlation ID
**************
Expand All @@ -50,9 +56,9 @@ Forwarded For

Extracts header "X-Forwarded-For" and keeps it in context.

************
**********
Request ID
************
**********

Extracts header "X-Request-ID" and keeps it in context.
You can pass `force_new_uuid=True` to enforce the creation of a new UUID.
Expand Down
1 change: 1 addition & 0 deletions starlette_context/header_keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@


class HeaderKeys(str, Enum):
api_key = "X-API-Key"
correlation_id = "X-Correlation-ID"
request_id = "X-Request-ID"
date = "Date"
Expand Down
1 change: 1 addition & 0 deletions starlette_context/plugins/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from .api_key import ApiKeyPlugin
from .base import Plugin
from .correlation_id import CorrelationIdPlugin
from .date_header import DateHeaderPlugin
Expand Down
6 changes: 6 additions & 0 deletions starlette_context/plugins/api_key.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from starlette_context.header_keys import HeaderKeys
from starlette_context.plugins.base import Plugin


class ApiKeyPlugin(Plugin):
key = HeaderKeys.api_key
2 changes: 2 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from starlette_context.header_keys import HeaderKeys
from starlette_context.middleware import ContextMiddleware

dummy_api_key = "abcdef12345"
dummy_correlation_id = uuid.uuid4().hex
dummy_request_id = uuid.uuid4().hex
dummy_user_agent = "dummy_user_agent"
Expand All @@ -21,6 +22,7 @@ def headers():
h = MutableHeaders()
h.update(
{
HeaderKeys.api_key: dummy_api_key,
HeaderKeys.correlation_id: dummy_correlation_id,
HeaderKeys.request_id: dummy_request_id,
HeaderKeys.date: dummy_date,
Expand Down
42 changes: 42 additions & 0 deletions tests/test_plugins/test_api_key.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
from starlette import status
from starlette.applications import Starlette
from starlette.middleware import Middleware
from starlette.requests import Request
from starlette.responses import JSONResponse, Response
from starlette.testclient import TestClient

from starlette_context import plugins
from starlette_context.header_keys import HeaderKeys
from starlette_context.middleware import ContextMiddleware
from tests.conftest import dummy_api_key

middleware = [
Middleware(
ContextMiddleware,
plugins=(plugins.ApiKeyPlugin(),),
)
]
app = Starlette(middleware=middleware)
client = TestClient(app)
headers = {HeaderKeys.api_key: dummy_api_key}


@app.route("/")
async def index(request: Request) -> Response:
return JSONResponse(
{"headers": str(request.headers.get(HeaderKeys.api_key))}
)


def test_valid_request_returns_proper_response():
response = client.get("/", headers=headers)
assert response.status_code == status.HTTP_200_OK
assert dummy_api_key in response.text
assert HeaderKeys.api_key not in response.text


def test_missing_forwarded_for_header():
response = client.get("/", headers={})
assert response.status_code == status.HTTP_200_OK
assert dummy_api_key not in response.text
assert HeaderKeys.api_key not in response.headers

0 comments on commit 0022215

Please sign in to comment.