Skip to content

Commit

Permalink
Merge pull request RasaHQ#1081 from dcalvom/concurrent-training-proce…
Browse files Browse the repository at this point in the history
…sses-for-project

Concurrent training processes for project
  • Loading branch information
tmbo authored Jul 9, 2018
2 parents cc5318a + fe5f299 commit 523fe27
Show file tree
Hide file tree
Showing 5 changed files with 33 additions and 11 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,14 @@ Added
- intent_featurizer_count_vectors adds features to text_features instead of overwriting them
- add basic OOV support to intent_featurizer_count_vectors (make previously trained models impossible to load)
- add a feature for each regex in the training set for crf_entity_extractor
- Current training processes count for server and projects.

Changed
-------
- L1 and L2 regularisation defaults in ``ner_crf`` both set to 0.1
- ``whitespace_tokenizer`` ignores punctuation ``.,!?`` before whitespace or end of string
- Allow multiple training processes per project
- Changed AlreadyTrainingError to MaxTrainingError. The first one was used to indicate that the project was already training. The latest will show an error when the server isn't able to training more models.

Removed
-------
Expand Down
32 changes: 24 additions & 8 deletions rasa_nlu/data_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,16 +43,16 @@
DEFERRED_RUN_IN_REACTOR_THREAD = True


class AlreadyTrainingError(Exception):
"""Raised when a training is requested for a project that is
already training.
class MaxTrainingError(Exception):
"""Raised when a training is requested and the server has
reached the max count of training processes.
Attributes:
message -- explanation of why the request is invalid
"""

def __init__(self):
self.message = 'The project is already being trained!'
self.message = 'The server can\'t train more models right now!'

def __str__(self):
return self.message
Expand Down Expand Up @@ -94,6 +94,7 @@ def __init__(self,
remote_storage=None,
component_builder=None):
self._training_processes = max(max_training_processes, 1)
self._current_training_processes = 0
self.responses = self._create_query_logger(response_log)
self.project_dir = config.make_path_absolute(project_dir)
self.emulator = self._create_emulator(emulation_mode)
Expand Down Expand Up @@ -276,6 +277,8 @@ def get_status(self):
# be other trainings run in different processes we don't know about.

return {
"max_training_processes": self._training_processes,
"current_training_processes": self._current_training_processes,
"available_projects": {
name: project.as_dict()
for name, project in self.project_store.items()
Expand All @@ -295,8 +298,8 @@ def start_train_process(self,
raise InvalidProjectError("Missing project name to train")

if project in self.project_store:
if self.project_store[project].status == 1:
raise AlreadyTrainingError
if self._training_processes <= self._current_training_processes:
raise MaxTrainingError
else:
self.project_store[project].status = 1
elif project not in self.project_store:
Expand All @@ -308,18 +311,31 @@ def start_train_process(self,
def training_callback(model_path):
model_dir = os.path.basename(os.path.normpath(model_path))
self.project_store[project].update(model_dir)
self._current_training_processes -= 1
self.project_store[project].current_training_processes -= 1
if (self.project_store[project].status == 1 and
self.project_store[project].current_training_processes ==
0):
self.project_store[project].status = 0
return model_dir

def training_errback(failure):
logger.warn(failure)
target_project = self.project_store.get(
failure.value.failed_target_project)
if target_project:
self._current_training_processes -= 1
self.project_store[project].current_training_processes -= 1
if (target_project and
self.project_store[project].current_training_processes ==
0):
target_project.status = 0
return failure

logger.debug("New training queued")

self._current_training_processes += 1
self.project_store[project].current_training_processes += 1

result = self.pool.submit(do_train_in_worker,
train_config,
data_file,
Expand Down Expand Up @@ -381,7 +397,7 @@ def unload_model(self, project, model):
"""Unload a model from server memory."""

if project is None:
raise InvalidProjectError("No project specified".format(project))
raise InvalidProjectError("No project specified")
elif project not in self.project_store:
raise InvalidProjectError("Project {} could not "
"be found".format(project))
Expand Down
3 changes: 2 additions & 1 deletion rasa_nlu/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ def __init__(self,
self._component_builder = component_builder
self._models = {}
self.status = 0
self.current_training_processes = 0
self._reader_lock = Lock()
self._loader_lock = Lock()
self._writer_lock = Lock()
Expand Down Expand Up @@ -151,7 +152,6 @@ def update(self, model_name):
self._writer_lock.acquire()
self._models[model_name] = None
self._writer_lock.release()
self.status = 0

def unload(self, model_name):
self._writer_lock.acquire()
Expand Down Expand Up @@ -215,6 +215,7 @@ def _read_model_metadata(self, model_name):

def as_dict(self):
return {'status': 'training' if self.status else 'ready',
'current_training_processes': self.current_training_processes,
'available_models': list(self._models.keys()),
'loaded_models': self._list_loaded_models()}

Expand Down
4 changes: 2 additions & 2 deletions rasa_nlu/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from rasa_nlu.config import RasaNLUModelConfig
from rasa_nlu.data_router import (
DataRouter, InvalidProjectError,
AlreadyTrainingError)
MaxTrainingError)
from rasa_nlu.train import TrainingException
from rasa_nlu.utils import json_to_string
from rasa_nlu.version import __version__
Expand Down Expand Up @@ -352,7 +352,7 @@ def train(self, request):

returnValue(json_to_string({'info': 'new model trained: {}'
''.format(response)}))
except AlreadyTrainingError as e:
except MaxTrainingError as e:
request.setResponseCode(403)
returnValue(json_to_string({"error": "{}".format(e)}))
except InvalidProjectError as e:
Expand Down
2 changes: 2 additions & 0 deletions tests/base/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@ def test_status(app):
response = yield app.get("http://dummy-uri/status")
rjs = yield response.json()
assert response.code == 200 and "available_projects" in rjs
assert "current_training_processes" in rjs
assert "max_training_processes" in rjs
assert "default" in rjs["available_projects"]


Expand Down

0 comments on commit 523fe27

Please sign in to comment.