diff --git a/rasa/core/run.py b/rasa/core/run.py index 9222287e2817..e6830637b6d1 100644 --- a/rasa/core/run.py +++ b/rasa/core/run.py @@ -66,6 +66,7 @@ def configure_app( jwt_method: Optional[Text] = None, route: Optional[Text] = "/webhooks/", port: int = constants.DEFAULT_SERVER_PORT, + endpoints: Optional[AvailableEndpoints] = None, log_file: Optional[Text] = None, ): """Run the agent.""" @@ -77,6 +78,7 @@ def configure_app( auth_token=auth_token, jwt_secret=jwt_secret, jwt_method=jwt_method, + endpoints=endpoints, ) else: app = Sanic(__name__, configure_logging=False) @@ -143,6 +145,7 @@ def serve_application( jwt_secret, jwt_method, port=port, + endpoints=endpoints, log_file=log_file, ) diff --git a/rasa/server.py b/rasa/server.py index f3f882200f43..cf6019ee3023 100644 --- a/rasa/server.py +++ b/rasa/server.py @@ -23,16 +23,17 @@ DEFAULT_DOMAIN_PATH, DOCS_BASE_URL, ) +from rasa.core import broker from rasa.core.agent import load_agent, Agent from rasa.core.channels import UserMessage, CollectingOutputChannel from rasa.core.events import Event from rasa.core.test import test from rasa.core.trackers import DialogueStateTracker, EventVerbosity -from rasa.core.utils import dump_obj_as_str_to_file +from rasa.core.utils import dump_obj_as_str_to_file, AvailableEndpoints from rasa.model import get_model_subdirectories, fingerprint_from_path from rasa.nlu.emulators.no_emulator import NoEmulator from rasa.nlu.test import run_evaluation - +from rasa.core.tracker_store import TrackerStore logger = logging.getLogger(__name__) @@ -233,9 +234,29 @@ async def _load_agent( model_path: Optional[Text] = None, model_server: Optional[EndpointConfig] = None, remote_storage: Optional[Text] = None, + endpoints: Optional[AvailableEndpoints] = None, ) -> Agent: try: - loaded_agent = await load_agent(model_path, model_server, remote_storage) + tracker_store = None + generator = None + action_endpoint = None + + if endpoints: + _broker = broker.from_endpoint_config(endpoints.event_broker) + tracker_store = TrackerStore.find_tracker_store( + None, endpoints.tracker_store, _broker + ) + generator = endpoints.nlg + action_endpoint = endpoints.action + + loaded_agent = await load_agent( + model_path, + model_server, + remote_storage, + generator=generator, + tracker_store=tracker_store, + action_endpoint=action_endpoint, + ) except Exception as e: logger.debug(traceback.format_exc()) raise ErrorResponse( @@ -259,6 +280,7 @@ def create_app( auth_token: Optional[Text] = None, jwt_secret: Optional[Text] = None, jwt_method: Text = "HS256", + endpoints: Optional[AvailableEndpoints] = None, ): """Class representing a Rasa HTTP server.""" @@ -796,7 +818,9 @@ async def load_model(request: Request): model_server = request.json.get("model_server", None) remote_storage = request.json.get("remote_storage", None) - app.agent = await _load_agent(model_path, model_server, remote_storage) + app.agent = await _load_agent( + model_path, model_server, remote_storage, endpoints + ) logger.debug("Successfully loaded model '{}'.".format(model_path)) return response.json(None, status=204)