Skip to content

Commit

Permalink
refactor: Update display_article_list function to show article previews
Browse files Browse the repository at this point in the history
  • Loading branch information
jaigouk committed Aug 5, 2024
1 parent e612ed5 commit 480bec6
Show file tree
Hide file tree
Showing 5 changed files with 108 additions and 67 deletions.
28 changes: 25 additions & 3 deletions pages_util/CreateNewArticle.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,10 @@
from pages_util.Settings import (
load_search_options,
save_search_options,
SEARCH_ENGINES,
load_llm_settings,
save_llm_settings,
list_downloaded_models,
SEARCH_ENGINES,
LLM_MODELS,
)
from .Settings import list_downloaded_models
Expand Down Expand Up @@ -47,6 +48,7 @@ def initialize_session_state():

def display_article_form():
st.header("Create New Article")
categories = FileIOHelper.load_categories()
with st.form(key="search_form"):
selected_category = st.selectbox("Select category", categories, index=0)
st.text_input(
Expand Down Expand Up @@ -97,6 +99,25 @@ def display_sidebar_options():
llm_settings = load_llm_settings()
primary_model, fallback_model, model_settings = display_llm_options(llm_settings)

# Save search options if they have changed
if (
primary_engine != search_options["primary_engine"]
or fallback_engine != search_options["fallback_engine"]
or search_top_k != search_options["search_top_k"]
or retrieve_top_k != search_options["retrieve_top_k"]
):
save_search_options(
primary_engine, fallback_engine, search_top_k, retrieve_top_k
)

# Save LLM settings if they have changed
if (
primary_model != llm_settings["primary_model"]
or fallback_model != llm_settings["fallback_model"]
or model_settings != llm_settings["model_settings"]
):
save_llm_settings(primary_model, fallback_model, model_settings)

return {
"search_options": {
"primary_engine": primary_engine,
Expand Down Expand Up @@ -156,7 +177,7 @@ def display_llm_options(llm_settings):
key="primary_model",
)
if primary_model == "ollama":
display_ollama_options(llm_settings)
llm_settings["model_settings"] = display_ollama_options(llm_settings)

fallback_model_options = [None] + [
model for model in LLM_MODELS.keys() if model != primary_model
Expand Down Expand Up @@ -190,6 +211,7 @@ def display_ollama_options(llm_settings):
)
llm_settings["model_settings"]["ollama"]["model"] = selected_ollama_model
llm_settings["model_settings"]["ollama"]["max_tokens"] = max_tokens
return llm_settings["model_settings"]


def run_storm_process(status, progress_bar, progress_text):
Expand Down Expand Up @@ -337,7 +359,7 @@ def cleanup_folder(current_working_dir):
def create_new_article_page():
load_and_apply_theme()
initialize_session_state()
categories = FileIOHelper.load_categories()

if st.session_state["page3_write_article_state"] == "not started":
submit_button, selected_category = display_article_form()
handle_form_submission(submit_button, selected_category)
Expand Down
86 changes: 51 additions & 35 deletions pages_util/MyArticles.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,21 @@
import logging
import streamlit as st
from util.file_io import FileIOHelper
from util.ui_components import UIComponents
from util.ui_components import UIComponents, StreamlitCallbackHandler
from util.theme_manager import load_and_apply_theme, get_my_articles_css
from pages_util.Settings import (
load_search_options,
save_search_options,
load_llm_settings,
save_llm_settings,
load_categories,
save_categories,
load_general_settings,
save_general_settings,
save_categories,
)

from util.storm_runner import run_storm_with_config

logging.basicConfig(level=logging.DEBUG)


Expand Down Expand Up @@ -139,41 +145,51 @@ def my_articles_page():
st.title("My Articles")

if "page2_selected_my_article" in st.session_state:
category, article_key = st.session_state.page2_selected_my_article
selected_article_file_path_dict = st.session_state.user_articles[category][
article_key
]
UIComponents.display_article_page(
article_key,
selected_article_file_path_dict,
show_title=True,
show_main_article=True,
show_feedback_form=False,
show_qa_panel=False,
)
if st.button("Back to Article List"):
del st.session_state.page2_selected_my_article
st.rerun()
display_selected_article()
else:
# Sidebar controls - only show when listing articles
with st.sidebar:
category_options = ["All Categories"] + list(
st.session_state.user_articles.keys()
)
st.session_state.selected_category = st.selectbox(
"Select Category", category_options
)
st.session_state.page_size = st.selectbox(
"Items per page", [12, 24, 48, 96], index=1
)
st.session_state.num_columns = st.number_input(
"Number of columns",
min_value=1,
max_value=4,
value=st.session_state.num_columns,
)
display_article_list_and_controls()


def display_selected_article():
category, article_key = st.session_state.page2_selected_my_article
selected_article_file_path_dict = st.session_state.user_articles[category][
article_key
]
UIComponents.display_article_page(
article_key,
selected_article_file_path_dict,
show_title=True,
show_main_article=True,
show_feedback_form=False,
show_qa_panel=False,
)
if st.button("Back to Article List"):
del st.session_state.page2_selected_my_article
st.rerun()


def display_article_list_and_controls():
st.subheader("Article List")

# Sidebar controls
with st.sidebar:
category_options = ["All Categories"] + list(
st.session_state.user_articles.keys()
)
st.session_state.selected_category = st.selectbox(
"Select Category", category_options
)
st.session_state.page_size = st.selectbox(
"Items per page", [12, 24, 48, 96], index=1
)
st.session_state.num_columns = st.number_input(
"Number of columns",
min_value=1,
max_value=4,
value=st.session_state.num_columns,
)

display_article_list(st.session_state.page_size, st.session_state.num_columns)
display_article_list(st.session_state.page_size, st.session_state.num_columns)

# Save the number of columns setting
general_settings = load_general_settings()
Expand Down
1 change: 0 additions & 1 deletion pages_util/Settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,6 @@ def load_search_options():

if result:
stored_options = json.loads(result[0])
# Update default options with stored values, keeping default for missing keys
default_options.update(stored_options)

return default_options
Expand Down
50 changes: 27 additions & 23 deletions util/storm_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,19 +79,20 @@ def add_examples_to_runner(runner):
]


def log_progress(callback_handler, message: str):
st.info(message)
logger.info(message)
if callback_handler:
callback_handler.on_information_gathering_start(message=message)


def run_storm_with_fallback(
topic: str,
current_working_dir: str,
callback_handler=None,
runner=None,
):
def log_progress(message: str):
st.info(message)
logger.info(message)
if callback_handler:
callback_handler.on_information_gathering_start(message=message)

log_progress("Starting STORM process...")
log_progress(callback_handler, "Starting STORM process...")

if runner is None:
raise ValueError("Runner is not initialized")
Expand Down Expand Up @@ -210,24 +211,24 @@ def run_storm_with_config(
current_working_dir: str,
callback_handler=None,
):
progress_placeholder = st.empty()

def update_progress(message):
progress_placeholder.write(message)
if callback_handler:
callback_handler.on_information_gathering_start(message=message)
logger.info(message)

update_progress("Loading configurations...")
llm_settings = load_llm_settings()
search_options = load_search_options()

update_progress("Setting up LLM...")
primary_model = llm_settings["primary_model"]
fallback_model = llm_settings["fallback_model"]
model_settings = llm_settings["model_settings"]

search_options = load_search_options()
search_top_k = search_options["search_top_k"]
retrieve_top_k = search_options["retrieve_top_k"]
if primary_model is None or fallback_model is None or model_settings is None:
llm_settings = load_llm_settings()
primary_model = llm_settings["primary_model"]
fallback_model = llm_settings["fallback_model"]
model_settings = llm_settings["model_settings"]

if search_top_k is None or retrieve_top_k is None:
search_options = load_search_options()
search_top_k = search_options["search_top_k"]
retrieve_top_k = search_options["retrieve_top_k"]

llm_configs = STORMWikiLMConfigs()

Expand All @@ -247,6 +248,7 @@ def run_storm_with_config(
]:
getattr(llm_configs, f"set_{lm_type}_lm")(primary_lm)

update_progress("Setting up search engine...")
engine_args = STORMWikiRunnerArguments(
output_dir=current_working_dir,
max_conv_turn=3,
Expand All @@ -255,19 +257,21 @@ def run_storm_with_config(
retrieve_top_k=retrieve_top_k,
)

# Set up the search engine with only max_results
rm = CombinedSearchAPI(max_results=engine_args.search_top_k)

update_progress("Initializing STORM runner...")
runner = STORMWikiRunner(engine_args, llm_configs, rm)

# Add this line to ensure engine_args is accessible
runner.engine_args = engine_args

add_examples_to_runner(runner)
return run_storm_with_fallback(

result = run_storm_with_fallback(
topic, current_working_dir, callback_handler, runner=runner
)

update_progress("STORM process completed.")
return result


def set_storm_runner():
current_working_dir = os.getenv("STREAMLIT_OUTPUT_DIR")
Expand Down
10 changes: 5 additions & 5 deletions util/ui_components.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,9 @@ class StreamlitCallbackHandler(BaseCallbackHandler):
def __init__(self, status_container):
self.status_container = status_container

def on_information_gathering_start(self, message, **kwargs):
self.status_container.info(message)

def on_identify_perspective_start(self, **kwargs):
self.status_container.info(
"Start identifying different perspectives for researching the topic."
Expand All @@ -272,9 +275,6 @@ def on_identify_perspective_end(self, perspectives: list[str], **kwargs):
f" from the following perspectives:\n- {perspective_list}"
)

def on_information_gathering_start(self, **kwargs):
self.status_container.info("Start browsing the Internet.")

def on_dialogue_turn_end(self, dlg_turn, **kwargs):
urls = list(set([r.url for r in dlg_turn.search_results]))
for url in urls:
Expand Down Expand Up @@ -302,8 +302,8 @@ def on_information_organization_start(self, **kwargs):

def on_direct_outline_generation_end(self, outline: str, **kwargs):
self.status_container.success(
f"Finish leveraging the internal knowledge of the large language model."
"Finish leveraging the internal knowledge of the large language model."
)

def on_outline_refinement_end(self, outline: str, **kwargs):
self.status_container.success(f"Finish leveraging the collected information.")
self.status_container.success("Finish leveraging the collected information.")

0 comments on commit 480bec6

Please sign in to comment.