Skip to content

Commit

Permalink
Back server.py to 0.7.5.post0
Browse files Browse the repository at this point in the history
  • Loading branch information
makseq committed Oct 6, 2020
1 parent 7230d1b commit d0ddb68
Showing 1 changed file with 19 additions and 53 deletions.
72 changes: 19 additions & 53 deletions label_studio/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
from label_studio.utils.misc import (
exception_treatment, exception_treatment_page,
config_line_stripped, get_config_templates, convert_string_to_hash, serialize_class,
DirectionSwitch, check_port_in_use, timestamp_to_local_datetime
DirectionSwitch, check_port_in_use
)
from label_studio.utils.argparser import parse_input_args
from label_studio.utils.uri_resolver import resolve_task_data_uri
Expand Down Expand Up @@ -143,7 +143,7 @@ def send_upload(path):
logger.warning('Task path starting with "/upload/" is deprecated and will be removed in next releases, '
'replace "/upload/" => "/data/upload/" in your tasks.json files')
project = project_get_or_create()
project_dir = os.path.join(project.path, 'upload')
project_dir = os.path.join(project.name, 'upload')
return open(os.path.join(project_dir, path), 'rb').read()


Expand Down Expand Up @@ -241,7 +241,7 @@ def setup_page():
"""
project = project_get_or_create()

templates = get_config_templates(project.config)
templates = get_config_templates()
input_values = {}
project.analytics.send(getframeinfo(currentframe()).function)
return flask.render_template(
Expand Down Expand Up @@ -303,16 +303,9 @@ def model_page():
ml_backend.training_in_progress = training_status['is_training']
ml_backend.model_version = training_status['model_version']
ml_backend.is_connected = True
ml_backend.is_error = False
except Exception as exc:
logger.error(str(exc), exc_info=True)
ml_backend.is_error = True
try:
# try to parse json as the result of @exception_treatment
ml_backend.error = json.loads(str(exc))
except ValueError:
ml_backend.error = {'detail': "Can't parse exception message from ML Backend"}

else:
ml_backend.is_connected = False
ml_backends.append(ml_backend)
Expand Down Expand Up @@ -567,12 +560,12 @@ def api_generate_next_task():

task = resolve_task_data_uri(task)

project.analytics.send(getframeinfo(currentframe()).function)
#project.analytics.send(getframeinfo(currentframe()).function)

# collect prediction from multiple ml backends
if project.ml_backends_connected:
task = project.make_predictions(task)
logger.debug('Next task:\n' + str(task.get('id', None)))
logger.debug('Next task:\n' + json.dumps(task))
return make_response(jsonify(task), 200)


Expand All @@ -593,7 +586,6 @@ def api_project():

output = project.serialize()
output['multi_session_mode'] = input_args.command != 'start-multi-session'
project.analytics.send(getframeinfo(currentframe()).function, method=request.method)
return make_response(jsonify(output), code)


Expand Down Expand Up @@ -712,8 +704,6 @@ def api_all_tasks():
# get tasks with completions
tasks = []
for item in paginated:
if item['completed_at'] != 'undefined' and item['completed_at'] is not None:
item['completed_at'] = timestamp_to_local_datetime(item['completed_at']).strftime('%Y-%m-%d %H:%M:%S')
i = item['id']
task = project.get_task_with_completions(i)
if task is None: # no completion at task
Expand All @@ -739,15 +729,11 @@ def api_tasks(task_id):
if request.method == 'GET':
task_data = project.get_task_with_completions(task_id) or project.source_storage.get(task_id)
task_data = resolve_task_data_uri(task_data)

if project.ml_backends_connected:
task_data = project.make_predictions(task_data)

project.analytics.send(getframeinfo(currentframe()).function, method=request.method)
project.analytics.send(getframeinfo(currentframe()).function)
return make_response(jsonify(task_data), 200)
elif request.method == 'DELETE':
project.remove_task(task_id)
project.analytics.send(getframeinfo(currentframe()).function, method=request.method)
project.analytics.send(getframeinfo(currentframe()).function)
return make_response(jsonify('Task deleted.'), 204)


Expand Down Expand Up @@ -788,33 +774,14 @@ def api_completions(task_id):
completion.pop('skipped', None)
completion.pop('was_cancelled', None)
completion_id = project.save_completion(int(task_id), completion)
project.analytics.send(getframeinfo(currentframe()).function, method=request.method)
project.analytics.send(getframeinfo(currentframe()).function)
return make_response(json.dumps({'id': completion_id}), 201)

else:
project.analytics.send(getframeinfo(currentframe()).function, error=500, method=request.method)
project.analytics.send(getframeinfo(currentframe()).function, error=500)
return make_response('Incorrect request method', 500)


@app.route('/api/project/completions/', methods=['DELETE'])
@requires_auth
@exception_treatment
def api_all_completions():
""" Delete all completions
"""
project = project_get_or_create()

if request.method == 'DELETE':
project.delete_all_completions()
project.analytics.send(getframeinfo(currentframe()).function, method=request.method)
return make_response('done', 201)

else:
project.analytics.send(getframeinfo(currentframe()).function, error=500, method=request.method)
return make_response('Incorrect request method', 500)



@app.route('/api/tasks/<task_id>/cancel', methods=['POST'])
@requires_auth
@exception_treatment
Expand Down Expand Up @@ -842,13 +809,13 @@ def api_completion_by_id(task_id, completion_id):
if request.method == 'DELETE':
if project.config.get('allow_delete_completions', False):
project.delete_completion(int(task_id))
project.analytics.send(getframeinfo(currentframe()).function, method=request.method)
project.analytics.send(getframeinfo(currentframe()).function)
return make_response('deleted', 204)
else:
project.analytics.send(getframeinfo(currentframe()).function, error=422, method=request.method)
project.analytics.send(getframeinfo(currentframe()).function, error=422)
return make_response('Completion removing is not allowed in server config', 422)
else:
project.analytics.send(getframeinfo(currentframe()).function, error=500, method=request.method)
project.analytics.send(getframeinfo(currentframe()).function, error=500)
return make_response('Incorrect request method', 500)


Expand Down Expand Up @@ -1052,9 +1019,8 @@ def main():
label_studio.utils.auth.PASSWORD = input_args.password or config.get('password', '')

# set host name
host = input_args.host or config.get('host', 'localhost') # name for external links generation
host = input_args.host or config.get('host', 'localhost')
port = input_args.port or config.get('port', 8080)
server_host = 'localhost' if host == 'localhost' else '0.0.0.0' # web server host

# ssl certificate and key
cert_file = input_args.cert_file or config.get('cert')
Expand All @@ -1072,29 +1038,29 @@ def main():
'* Trying to start at ' + str(port) +
'\n****************\n')

set_web_protocol(input_args.protocol or config.get('protocol', 'http://'))
set_web_protocol(config.get('protocol', 'http://'))
set_full_hostname(get_web_protocol() + host.replace('0.0.0.0', 'localhost') + ':' + str(port))

start_browser('http://localhost:' + str(port), input_args.no_browser)
if input_args.use_gevent:
app.debug = input_args.debug
ssl_args = {'keyfile': key_file, 'certfile': cert_file} if ssl_context else {}
http_server = WSGIServer((server_host, port), app, log=app.logger, **ssl_args)
http_server = WSGIServer((host, port), app, log=app.logger, **ssl_args)
http_server.serve_forever()
else:
app.run(host=server_host, port=port, debug=input_args.debug, ssl_context=ssl_context)
app.run(host=host, port=port, debug=input_args.debug, ssl_context=ssl_context)

# On `start-multi-session` command, server creates one project per each browser sessions
elif input_args.command == 'start-multi-session':
server_host = input_args.host or '0.0.0.0'
host = input_args.host or '0.0.0.0'
port = input_args.port or 8080

if input_args.use_gevent:
app.debug = input_args.debug
http_server = WSGIServer((server_host, port), app, log=app.logger)
http_server = WSGIServer((host, port), app, log=app.logger)
http_server.serve_forever()
else:
app.run(host=server_host, port=port, debug=input_args.debug)
app.run(host=host, port=port, debug=input_args.debug)


if __name__ == "__main__":
Expand Down

0 comments on commit d0ddb68

Please sign in to comment.