Skip to content

Commit

Permalink
Merge pull request jmcarp#40 from Tackitt/master
Browse files Browse the repository at this point in the history
Defer registering views until after init_app
  • Loading branch information
sloria authored Apr 1, 2017
2 parents 62b8a09 + c36b8ec commit d9dd1e2
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 7 deletions.
49 changes: 42 additions & 7 deletions flask_apispec/extension.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
# -*- coding: utf-8 -*-

import types

import flask
import functools
import types
from apispec import APISpec

from flask_apispec import ResourceMeta
from flask_apispec.apidoc import ViewConverter, ResourceConverter


class FlaskApiSpec(object):
"""Flask-apispec extension.
Expand Down Expand Up @@ -35,17 +35,35 @@ def get_pet(pet_id):
:param Flask app: App associated with API documentation
:param APISpec spec: apispec specification associated with API documentation
"""

def __init__(self, app=None):
self._deferred = []
self.app = app
self.view_converter = None
self.resource_converter = None
self.spec = None

if app:
self.init_app(app)

def init_app(self, app):
self.app = app
self.view_converter = ViewConverter(self.app)
self.resource_converter = ResourceConverter(self.app)
self.spec = self.app.config.get('APISPEC_SPEC') or make_apispec()
self.spec = self.app.config.get('APISPEC_SPEC') or \
make_apispec(self.app.config.get('APISPEC_TITLE', 'flask-apispec'),
self.app.config.get('APISPEC_VERSION', 'v1'))
self.add_routes()

for deferred in self._deferred:
deferred()

def _defer(self, callable, *args, **kwargs):
bound = functools.partial(callable, *args, **kwargs)
self._deferred.append(bound)
if self.app:
bound()

def add_routes(self):
blueprint = flask.Blueprint(
'flask-apispec',
Expand Down Expand Up @@ -75,6 +93,22 @@ def register(self, target, endpoint=None, blueprint=None,
resource_class_args=None, resource_class_kwargs=None):
"""Register a view.
:param target: view function or view class.
:param endpoint: (optional) endpoint name.
:param blueprint: (optional) blueprint name.
:param tuple resource_class_args: (optional) args to be forwarded to the
view class constructor.
:param dict resource_class_kwargs: (optional) kwargs to be forwarded to
the view class constructor.
"""

self._defer(self._register, target, endpoint, blueprint,
resource_class_args, resource_class_kwargs)

def _register(self, target, endpoint=None, blueprint=None,
resource_class_args=None, resource_class_kwargs=None):
"""Register a view.
:param target: view function or view class.
:param endpoint: (optional) endpoint name.
:param blueprint: (optional) blueprint name.
Expand All @@ -98,9 +132,10 @@ def register(self, target, endpoint=None, blueprint=None,
for path in paths:
self.spec.add_path(**path)

def make_apispec():

def make_apispec(title='flask-apispec', version='v1'):
return APISpec(
title='flask-apispec',
version='v1',
title=title,
version=version,
plugins=['apispec.ext.marshmallow'],
)
28 changes: 28 additions & 0 deletions tests/test_extension.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,34 @@
# -*- coding: utf-8 -*-

import pytest
from flask import Blueprint

from flask_apispec import doc
from flask_apispec.extension import FlaskApiSpec
from flask_apispec.views import MethodResource


@pytest.fixture
def docs(app):
return FlaskApiSpec(app)

class TestExtension:
def test_deferred_register(self, app):
blueprint = Blueprint('test', __name__)
docs = FlaskApiSpec()

@doc(tags=['band'])
class BandResource(MethodResource):
def get(self, **kwargs):
return 'slowdive'

blueprint.add_url_rule('/bands/<band_id>/', view_func=BandResource.as_view('band'))
docs.register(BandResource, endpoint='band', blueprint=blueprint.name)

app.register_blueprint(blueprint)
docs.init_app(app)

assert '/bands/{band_id}/' in docs.spec._paths

def test_register_function(self, app, docs):
@app.route('/bands/<int:band_id>/')
Expand Down Expand Up @@ -62,3 +80,13 @@ def test_serve_swagger_ui_custom_url(self, app, client):
app.config['APISPEC_SWAGGER_UI_URL'] = '/swagger-ui.html'
FlaskApiSpec(app)
client.get('/swagger-ui.html')

def test_apispec_config(self, app):
app.config['APISPEC_TITLE'] = 'test-extension'
app.config['APISPEC_VERSION'] = '2.1'
docs = FlaskApiSpec(app)

assert docs.spec.info == {
'title': 'test-extension',
'version': '2.1',
}

0 comments on commit d9dd1e2

Please sign in to comment.