Skip to content

Commit

Permalink
Merge pull request NVIDIA#1906 from IsaacYangSLA/dev/publish
Browse files Browse the repository at this point in the history
Add publish to inference server
  • Loading branch information
IsaacYangSLA authored Dec 6, 2017
2 parents 49fae45 + f952607 commit 4508e88
Show file tree
Hide file tree
Showing 9 changed files with 228 additions and 19 deletions.
8 changes: 5 additions & 3 deletions digits/model/images/classification/job.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,12 @@ def job_type(self):
return 'Image Classification Model'

@override
def download_files(self, epoch=-1):
def download_files(self, epoch=-1, frozen_file=False):
task = self.train_task()

snapshot_filenames = task.get_snapshot(epoch, download=True)
if frozen_file:
snapshot_filenames = task.get_snapshot(epoch, frozen_file=True)
else:
snapshot_filenames = task.get_snapshot(epoch, download=True)

# get model files
model_files = task.get_model_files()
Expand Down
3 changes: 2 additions & 1 deletion digits/model/tasks/caffe_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -1155,7 +1155,8 @@ def get_task_stats(self, epoch=-1):
"solver file": self.solver_file,
"train_val file": self.train_val_file,
"deploy file": self.deploy_file,
"framework": "caffe"
"framework": "caffe",
"mean subtraction": self.use_mean
}

# These attributes only available in more recent jobs:
Expand Down
7 changes: 5 additions & 2 deletions digits/model/tasks/tensorflow_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ def before_run(self):
return True

@override
def get_snapshot(self, epoch=-1, download=False):
def get_snapshot(self, epoch=-1, download=False, frozen_file=False):
"""
return snapshot file for specified epoch
"""
Expand All @@ -164,6 +164,8 @@ def get_snapshot(self, epoch=-1, download=False):
meta_file = snapshot_pre + ".meta"
index_file = snapshot_pre + ".index"
snapshot_files = [snapshot_file, meta_file, index_file]
elif frozen_file:
snapshot_files = os.path.join(os.path.dirname(snapshot_pre), "frozen_model.pb")
else:
snapshot_files = snapshot_pre

Expand Down Expand Up @@ -968,7 +970,8 @@ def get_task_stats(self, epoch=-1):
"mean file": mean_file,
"snapshot file": self.get_snapshot_filename(epoch),
"model file": self.model_file,
"framework": "tensorflow"
"framework": "tensorflow",
"mean subtraction": self.use_mean
}

if hasattr(self, "digits_version"):
Expand Down
69 changes: 69 additions & 0 deletions digits/model/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
import zipfile

import flask
from flask import flash
import requests
import werkzeug.exceptions

from . import images as model_images
Expand Down Expand Up @@ -257,6 +259,73 @@ def to_pretrained(job_id):
return flask.redirect(flask.url_for('digits.views.home', tab=3)), 302


@blueprint.route('/<job_id>/publish_inference', methods=['POST'])
def publish_inference(job_id):
"""
Publish model to inference server
"""
rie_url = os.environ.get('RIE_URL', "http://localhost:5055")

publish_endpoint = rie_url+'/models'

# Get data from the modal form
description = flask.request.form.get('description')
modality = flask.request.form.getlist('modality')
output_layer = flask.request.form.get('output_layer')
input_layer = flask.request.form.get('input_layer')
input_shape = flask.request.form.get('input_shape')
output_shape = flask.request.form.get('output_shape')

job = scheduler.get_job(job_id)

if job is None:
raise werkzeug.exceptions.NotFound('Job not found')

epoch = -1
# GET ?epoch=n
if 'epoch' in flask.request.args:
epoch = float(flask.request.args['epoch'])

# POST ?snapshot_epoch=n (from form)
elif 'snapshot_epoch' in flask.request.form:
epoch = float(flask.request.form['snapshot_epoch'])

# Write the stats of the job to json,
# and store in tempfile (for archive)
job_dict = job.json_dict(verbose=False, epoch=epoch)
job_dict.update({"output_layer": output_layer,
"description": description,
"input_layer": input_layer,
"input_shape": input_shape,
"output_shape": output_shape,
"modality": modality})
info = json.dumps(job_dict, sort_keys=True, indent=4, separators=(',', ': '))
info_io = io.BytesIO()
info_io.write(info)

b = io.BytesIO()
mode = ''
with tarfile.open(fileobj=b, mode='w:%s' % mode) as tar:
for path, name in job.download_files(epoch, frozen_file=(job_dict['framework'] == 'tensorflow')):
tar.add(path, arcname=name)
tar_info = tarfile.TarInfo("info.json")
tar_info.size = len(info_io.getvalue())
info_io.seek(0)
tar.addfile(tar_info, info_io)

temp_buffer = b.getvalue()
files = {'model': ('tmp.tgz', temp_buffer)}
try:
r = requests.post(publish_endpoint, files=files)
except Exception as e:
return flask.make_response(e)
if r.status_code != requests.codes.ok:
raise werkzeug.exceptions.BadRequest("Bad Request")
end_point = json.loads(r.text)["location"]
flash('Model successfully published to RIE.<p>New endpoint at {}'.format(end_point))
return flask.redirect(flask.request.referrer), 302


@blueprint.route('/<job_id>/download',
methods=['GET', 'POST'],
defaults={'extension': 'tar.gz'})
Expand Down
2 changes: 1 addition & 1 deletion digits/templates/helper.html
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
{% if messages %}
{% for category, message in messages %}
<div class="alert alert-{{ category }}">
{{ message }}
{{ message | safe}}
</div>
{% endfor %}
{% endif %}
Expand Down
102 changes: 100 additions & 2 deletions digits/templates/models/images/classification/show.html
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,93 @@ <h4 class='text-center'>Dataset</h4>
</div>
</div>

<div id="nuance_modal" class="modal fade" role="dialog">
<div class="modal-dialog">
<!-- Modal content-->
<div class="modal-content">
<form method="post" id="nuance_form"
action="{{ url_for('digits.model.views.publish_inference', job_id=job.id()) }}">
<div class="modal-header">
<button type="button" class="close" data-dismiss="modal">
&times;
</button>
<h4 class="modal-title">Publish to inference server</h4>
</div>
<div class="modal-body">
<p>
<div class="row">
<div class="col-sm-6" align="right">
<label for="description">Description:</label>
</div>
<div class="col-sm-6">
<input type="text" name="description" id="description"
class="form-control" placeholder="Description">
</div>
</div>
<div class="row">
<div class="col-sm-6" align="right">
<label for="output_layer">Output layer:</label>
</div>
<div class="col-sm-6">
<input type="text" name="output_layer" id="output_layer"
class="form-control" placeholder="softmax">
</div>
</div>
<div class="row">
<div class="col-sm-6" align="right">
<label for="input_layer">Input layer:</label>
</div>
<div class="col-sm-6">
<input type="text" name="input_layer" id="input_layer"
class="form-control" placeholder="in">
</div>
</div>
<div class="row">
<div class="col-sm-6" align="right">
<label for="input_shape">Input shape:</label>
</div>
<div class="col-sm-6">
<input type="text" name="input_shape" id="input_shape"
class="form-control" placeholder="224, 224">
</div>
</div>
<div class="row">
<div class="col-sm-6" align="right">
<label for="output_shape">Output shape:</label>
</div>
<div class="col-sm-6">
<input type="text" name="output_shape" id="output_shape"
class="form-control" placeholder="10">
</div>
</div>
<div class="row">
<div class="col-sm-6" align="right">
<label>Modality:</label>
</div>
<div class="col-sm-6">
<div class="checkbox-inline">
<label><input type="checkbox" name="modality" value="CT">CT</label>
</div>
<div class="checkbox-inline">
<label><input type="checkbox" name="modality" value="XA">X-ray</label>
</div>
<div class="checkbox-inline">
<label><input type="checkbox" name="modality" value="MR">MR</label>
</div>
</div>
</div>
</p>
</div>
<div class="modal-footer">
<button type="submit" class="btn btn-info">Submit
</button>
</div>
</form>
</div>
</div>
</div>


<div class="row">
<div class="col-sm-12">
<div class="well">
Expand Down Expand Up @@ -114,7 +201,7 @@ <h2>Trained Models</h2>
</div>
</div>
<div class="row">
<div class="col-sm-6">
<div class="col-sm-3">
<div class="form-group">
<select id="snapshot_epoch" name="snapshot_epoch" class="form-control">
</select>
Expand Down Expand Up @@ -146,7 +233,7 @@ <h2>Trained Models</h2>
</script>
</div>
</div>
<div class="col-sm-6">
<div class="col-sm-9">
<button
formaction="{{url_for('digits.model.views.download', job_id=job.id())}}"
formmethod="post"
Expand All @@ -161,6 +248,11 @@ <h2>Trained Models</h2>
class="btn btn-success">
Make Pretrained Model
</button>
<!-- Trigger the modal with a button -->
<button type="button" class="btn btn-success" data-toggle="modal"
data-target="#nuance_modal">
Publish to inference server
</button>
</div>
</div>
{% if task.get_framework_id() in framework_ids %}
Expand Down Expand Up @@ -280,5 +372,11 @@ <h3>Test a list of images</h3>
</div>
</div>

<script>
$('#nuance_form').submit(function() {
$('#nuance_modal').modal('toggle');
});
</script>

{% endblock %}

4 changes: 2 additions & 2 deletions digits/templates/models/images/generic/show.html
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ <h2>Trained Models</h2>
</div>
</div>
<div class="row">
<div class="col-sm-8">
<div class="col-sm-4">
<div class="form-group">
<select id="snapshot_epoch" name="snapshot_epoch" class="form-control">
</select>
Expand Down Expand Up @@ -142,7 +142,7 @@ <h2>Trained Models</h2>
</script>
</div>
</div>
<div class="col-sm-4">
<div class="col-sm-8">
<button
formaction="{{url_for('digits.model.views.download', job_id=job.id())}}"
formmethod="post"
Expand Down
Loading

0 comments on commit 4508e88

Please sign in to comment.