Skip to content

Commit

Permalink
Merge pull request RasaHQ#3642 from tdzienniak/master
Browse files Browse the repository at this point in the history
Use existing endpoints configuration when loading new model via PUT request
  • Loading branch information
tmbo authored Jun 12, 2019
2 parents a068663 + 0a0e67b commit 1f5e1e1
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 4 deletions.
3 changes: 3 additions & 0 deletions rasa/core/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ def configure_app(
jwt_method: Optional[Text] = None,
route: Optional[Text] = "/webhooks/",
port: int = constants.DEFAULT_SERVER_PORT,
endpoints: Optional[AvailableEndpoints] = None,
log_file: Optional[Text] = None,
):
"""Run the agent."""
Expand All @@ -77,6 +78,7 @@ def configure_app(
auth_token=auth_token,
jwt_secret=jwt_secret,
jwt_method=jwt_method,
endpoints=endpoints,
)
else:
app = Sanic(__name__, configure_logging=False)
Expand Down Expand Up @@ -143,6 +145,7 @@ def serve_application(
jwt_secret,
jwt_method,
port=port,
endpoints=endpoints,
log_file=log_file,
)

Expand Down
32 changes: 28 additions & 4 deletions rasa/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,16 +23,17 @@
DEFAULT_DOMAIN_PATH,
DOCS_BASE_URL,
)
from rasa.core import broker
from rasa.core.agent import load_agent, Agent
from rasa.core.channels import UserMessage, CollectingOutputChannel
from rasa.core.events import Event
from rasa.core.test import test
from rasa.core.trackers import DialogueStateTracker, EventVerbosity
from rasa.core.utils import dump_obj_as_str_to_file
from rasa.core.utils import dump_obj_as_str_to_file, AvailableEndpoints
from rasa.model import get_model_subdirectories, fingerprint_from_path
from rasa.nlu.emulators.no_emulator import NoEmulator
from rasa.nlu.test import run_evaluation

from rasa.core.tracker_store import TrackerStore

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -233,9 +234,29 @@ async def _load_agent(
model_path: Optional[Text] = None,
model_server: Optional[EndpointConfig] = None,
remote_storage: Optional[Text] = None,
endpoints: Optional[AvailableEndpoints] = None,
) -> Agent:
try:
loaded_agent = await load_agent(model_path, model_server, remote_storage)
tracker_store = None
generator = None
action_endpoint = None

if endpoints:
_broker = broker.from_endpoint_config(endpoints.event_broker)
tracker_store = TrackerStore.find_tracker_store(
None, endpoints.tracker_store, _broker
)
generator = endpoints.nlg
action_endpoint = endpoints.action

loaded_agent = await load_agent(
model_path,
model_server,
remote_storage,
generator=generator,
tracker_store=tracker_store,
action_endpoint=action_endpoint,
)
except Exception as e:
logger.debug(traceback.format_exc())
raise ErrorResponse(
Expand All @@ -259,6 +280,7 @@ def create_app(
auth_token: Optional[Text] = None,
jwt_secret: Optional[Text] = None,
jwt_method: Text = "HS256",
endpoints: Optional[AvailableEndpoints] = None,
):
"""Class representing a Rasa HTTP server."""

Expand Down Expand Up @@ -796,7 +818,9 @@ async def load_model(request: Request):
model_server = request.json.get("model_server", None)
remote_storage = request.json.get("remote_storage", None)

app.agent = await _load_agent(model_path, model_server, remote_storage)
app.agent = await _load_agent(
model_path, model_server, remote_storage, endpoints
)

logger.debug("Successfully loaded model '{}'.".format(model_path))
return response.json(None, status=204)
Expand Down

0 comments on commit 1f5e1e1

Please sign in to comment.