Skip to content

Commit

Permalink
add show description method
Browse files Browse the repository at this point in the history
  • Loading branch information
snova-jorgep committed Dec 13, 2024
1 parent 3d1a69d commit b8d9c07
Showing 1 changed file with 20 additions and 4 deletions.
24 changes: 20 additions & 4 deletions multimodal_knowledge_retriever/streamlit/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import time
import uuid
from threading import Thread
from typing import Optional
from typing import Any,Optional

import yaml

Expand All @@ -27,6 +27,7 @@
logging.info('URL: http://localhost:8501')

CONFIG_PATH = os.path.join(kit_dir, 'config.yaml')
APP_DESCRIPTION_PATH = os.path.join(kit_dir, 'streamlit', 'app_description.yaml')
ADDITIONAL_ENV_VARS: list[str] = []
# Available models in dropdown menu
LVLM_MODELS = [
Expand All @@ -38,6 +39,13 @@
# Minutes for scheduled cache deletion
EXIT_TIME_DELTA = 30

def load_config() -> Any:
with open(CONFIG_PATH, 'r') as yaml_file:
return yaml.safe_load(yaml_file)

def load_app_description() -> Any:
with open(APP_DESCRIPTION_PATH, 'r') as yaml_file:
return yaml.safe_load(yaml_file)

def delete_temp_dir(temp_dir: str) -> None:
"""Delete the temporary directory and its contents."""
Expand Down Expand Up @@ -120,6 +128,14 @@ def handle_user_input(user_question: str) -> None:
with c2.expander('Images'):
for image in image_source:
st.image(image)

# show overview message when chat history is empty
if len(st.session_state.chat_history) == 0:
with st.chat_message(
'ai',
avatar='https://sambanova.ai/hubfs/logotype_sambanova_orange.png',
):
st.write(load_app_description().get('app_overview'))


def initialize_multimodal_retrieval() -> Optional[MultimodalRetrieval]:
Expand All @@ -133,8 +149,7 @@ def initialize_multimodal_retrieval() -> Optional[MultimodalRetrieval]:


def main() -> None:
with open(CONFIG_PATH, 'r') as yaml_file:
config = yaml.safe_load(yaml_file)
config = load_config()

prod_mode = config.get('prod_mode', False)
llm_type = 'SambaStudio' if config.get('llm', {}).get('type') == 'sambastudio' else 'SambaNova Cloud'
Expand Down Expand Up @@ -180,7 +195,7 @@ def main() -> None:
user_question = st.chat_input('Ask questions about your data', disabled=st.session_state.input_disabled)
if user_question is not None:
st.session_state.mp_events.input_submitted('chat_input')
handle_user_input(user_question)
handle_user_input(user_question)

with st.sidebar:
st.title('Setup')
Expand Down Expand Up @@ -267,6 +282,7 @@ def main() -> None:
st.session_state.image_sources_history = []
st.session_state.multimodal_retriever.init_memory()
st.toast('Conversation reset. The next response will clear the history on the screen')
st.rerun()


if __name__ == '__main__':
Expand Down

0 comments on commit b8d9c07

Please sign in to comment.