diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 9963b1c21a08..1c4c3d8e7e21 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -30,11 +30,14 @@ Added - intent_featurizer_count_vectors adds features to text_features instead of overwriting them - add basic OOV support to intent_featurizer_count_vectors (make previously trained models impossible to load) - add a feature for each regex in the training set for crf_entity_extractor +- Current training processes count for server and projects. Changed ------- - L1 and L2 regularisation defaults in ``ner_crf`` both set to 0.1 - ``whitespace_tokenizer`` ignores punctuation ``.,!?`` before whitespace or end of string +- Allow multiple training processes per project +- Changed AlreadyTrainingError to MaxTrainingError. The former was used to indicate that the project was already training. The latter is raised when the server isn't able to train more models. Removed ------- diff --git a/rasa_nlu/data_router.py b/rasa_nlu/data_router.py index 976a65c9b141..898f7af4039c 100644 --- a/rasa_nlu/data_router.py +++ b/rasa_nlu/data_router.py @@ -43,16 +43,16 @@ DEFERRED_RUN_IN_REACTOR_THREAD = True -class AlreadyTrainingError(Exception): - """Raised when a training is requested for a project that is - already training. +class MaxTrainingError(Exception): + """Raised when a training is requested and the server has + reached the max count of training processes. Attributes: message -- explanation of why the request is invalid """ def __init__(self): - self.message = 'The project is already being trained!' + self.message = 'The server can\'t train more models right now!' 
def __str__(self): return self.message @@ -94,6 +94,7 @@ def __init__(self, remote_storage=None, component_builder=None): self._training_processes = max(max_training_processes, 1) + self._current_training_processes = 0 self.responses = self._create_query_logger(response_log) self.project_dir = config.make_path_absolute(project_dir) self.emulator = self._create_emulator(emulation_mode) @@ -276,6 +277,8 @@ def get_status(self): # be other trainings run in different processes we don't know about. return { + "max_training_processes": self._training_processes, + "current_training_processes": self._current_training_processes, "available_projects": { name: project.as_dict() for name, project in self.project_store.items() @@ -295,8 +298,8 @@ def start_train_process(self, raise InvalidProjectError("Missing project name to train") if project in self.project_store: - if self.project_store[project].status == 1: - raise AlreadyTrainingError + if self._training_processes <= self._current_training_processes: + raise MaxTrainingError else: self.project_store[project].status = 1 elif project not in self.project_store: @@ -308,18 +311,31 @@ def start_train_process(self, def training_callback(model_path): model_dir = os.path.basename(os.path.normpath(model_path)) self.project_store[project].update(model_dir) + self._current_training_processes -= 1 + self.project_store[project].current_training_processes -= 1 + if (self.project_store[project].status == 1 and + self.project_store[project].current_training_processes == + 0): + self.project_store[project].status = 0 return model_dir def training_errback(failure): logger.warn(failure) target_project = self.project_store.get( failure.value.failed_target_project) - if target_project: + self._current_training_processes -= 1 + self.project_store[project].current_training_processes -= 1 + if (target_project and + self.project_store[project].current_training_processes == + 0): target_project.status = 0 return failure logger.debug("New training 
queued") + self._current_training_processes += 1 + self.project_store[project].current_training_processes += 1 + result = self.pool.submit(do_train_in_worker, train_config, data_file, @@ -381,7 +397,7 @@ def unload_model(self, project, model): """Unload a model from server memory.""" if project is None: - raise InvalidProjectError("No project specified".format(project)) + raise InvalidProjectError("No project specified") elif project not in self.project_store: raise InvalidProjectError("Project {} could not " "be found".format(project)) diff --git a/rasa_nlu/project.py b/rasa_nlu/project.py index 27391177a5fa..bb8e073cb146 100644 --- a/rasa_nlu/project.py +++ b/rasa_nlu/project.py @@ -35,6 +35,7 @@ def __init__(self, self._component_builder = component_builder self._models = {} self.status = 0 + self.current_training_processes = 0 self._reader_lock = Lock() self._loader_lock = Lock() self._writer_lock = Lock() @@ -151,7 +152,6 @@ def update(self, model_name): self._writer_lock.acquire() self._models[model_name] = None self._writer_lock.release() - self.status = 0 def unload(self, model_name): self._writer_lock.acquire() @@ -215,6 +215,7 @@ def _read_model_metadata(self, model_name): def as_dict(self): return {'status': 'training' if self.status else 'ready', + 'current_training_processes': self.current_training_processes, 'available_models': list(self._models.keys()), 'loaded_models': self._list_loaded_models()} diff --git a/rasa_nlu/server.py b/rasa_nlu/server.py index 2127520fea97..8457663a7887 100644 --- a/rasa_nlu/server.py +++ b/rasa_nlu/server.py @@ -18,7 +18,7 @@ from rasa_nlu.config import RasaNLUModelConfig from rasa_nlu.data_router import ( DataRouter, InvalidProjectError, - AlreadyTrainingError) + MaxTrainingError) from rasa_nlu.train import TrainingException from rasa_nlu.utils import json_to_string from rasa_nlu.version import __version__ @@ -352,7 +352,7 @@ def train(self, request): returnValue(json_to_string({'info': 'new model trained: {}' 
''.format(response)})) - except AlreadyTrainingError as e: + except MaxTrainingError as e: request.setResponseCode(403) returnValue(json_to_string({"error": "{}".format(e)})) except InvalidProjectError as e: diff --git a/tests/base/test_server.py b/tests/base/test_server.py index c754437a3d51..35024fa50265 100644 --- a/tests/base/test_server.py +++ b/tests/base/test_server.py @@ -57,6 +57,8 @@ def test_status(app): response = yield app.get("http://dummy-uri/status") rjs = yield response.json() assert response.code == 200 and "available_projects" in rjs + assert "current_training_processes" in rjs + assert "max_training_processes" in rjs assert "default" in rjs["available_projects"]