Skip to content

Commit

Permalink
Section 7.4 - Prediction Endpoint
Browse files Browse the repository at this point in the history
  • Loading branch information
ChristopherGS committed Mar 29, 2020
1 parent 17948c1 commit b02a156
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 1 deletion.
19 changes: 18 additions & 1 deletion packages/ml_api/api/controller.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from flask import Blueprint, request
from flask import Blueprint, request, jsonify
from regression_model.predict import make_prediction

from api.config import get_logger

Expand All @@ -13,3 +14,19 @@ def health():
if request.method == 'GET':
_logger.info('health status OK')
return 'ok'


@prediction_app.route('/v1/predict/regression', methods=['POST'])
def predict():
if request.method == 'POST':
json_data = request.get_json()
_logger.info(f'Inputs: {json_data}')

result = make_prediction(input_data=json_data)
_logger.info(f'Outputs: {result}')

predictions = result.get('predictions')[0]
version = result.get('version')

return jsonify({'predictions': predictions,
'version': version})
30 changes: 30 additions & 0 deletions packages/ml_api/tests/test_controller.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,36 @@
from regression_model.config import config as model_config
from regression_model.processing.data_management import load_dataset
from regression_model import __version__ as _version

import json
import math


def test_health_endpoint_returns_200(flask_test_client):
# When
response = flask_test_client.get('/health')

# Then
assert response.status_code == 200


def test_prediction_endpoint_returns_prediction(flask_test_client):
# Given
# Load the test data from the regression_model package
# This is important as it makes it harder for the test
# data versions to get confused by not spreading it
# across packages.
test_data = load_dataset(file_name=model_config.TESTING_DATA_FILE)
post_json = test_data[0:1].to_json(orient='records')

# When
response = flask_test_client.post('/v1/predict/regression',
json=post_json)

# Then
assert response.status_code == 200
response_json = json.loads(response.data)
prediction = response_json['predictions']
response_version = response_json['version']
assert math.ceil(prediction) == 112476
assert response_version == _version

0 comments on commit b02a156

Please sign in to comment.