Merge pull request #1 from Scylidose/streamlit

Streamlit and GPT models integration

Scylidose authored Sep 21, 2023
2 parents d5adbc1 + de39fec commit 58a916b
Showing 6 changed files with 147 additions and 83 deletions.
11 changes: 11 additions & 0 deletions .streamlit/config.toml
@@ -0,0 +1,11 @@
+[logger]
+
+# Level of logging: 'error', 'warning', 'info', or 'debug'.
+# Default: 'info'
+level = "info"
+
+[server]
+
+# If false, will attempt to open a browser window on start.
+# Default: false unless (1) we are on a Linux box where DISPLAY is unset, or (2) we are running in the Streamlit Atom plugin.
+headless = true
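With headless = true, Streamlit will not try to open a browser window on start, which suits server deployments. Settings in .streamlit/config.toml are picked up automatically when the app is launched with streamlit run main.py.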
62 changes: 43 additions & 19 deletions main.py
@@ -1,33 +1,57 @@
 import os
+import streamlit as st
+import openai
 
 from src import extract, preprocess, scraping, answer, model, export_doc
 
 def main():
-    website = "https://nomanssky.fandom.com/"
+    openai.api_key = os.environ['OPENAI_API_KEY']
+
+    st.title("NMS Data Extraction and Question Answering")
+
+    website = st.text_input("Enter the website URL:", "https://nomanssky.fandom.com/")
     websites_file = "data/websites.json"
     documents_dir = "data/documents"
     output_file = "data/links.csv"
     db_dir = "data/whoosh"
-    child_depth = 1
+    child_depth = st.slider("Select child depth:", 1, 10, 1)
 
-    print("\n-----------------------\n")
-    scraping.check_websites(website, websites_file, child_depth)
-    print("\n-----------------------\n")
-    extract.json_to_csv(websites_file, output_file)
-    print("\n-----------------------\n")
-    extract.extract_html_text(output_file)
-    print("\n-----------------------\n")
-    preprocess.add_preprocessed_text_website(output_file)
+    if st.button("Start Data Extraction"):
+        st.text("Starting data extraction...")
+        scraping.check_websites(website, websites_file, child_depth)
+        st.text("Data extraction completed.")
 
-    query_text = "What is the release date of No Man\'s sky?"
-    model_choice = "DeepPavlov"
+    if st.button("Export to CSV"):
+        st.text("Exporting data to CSV...")
+        extract.json_to_csv(websites_file, output_file)
+        st.text("Data exported to CSV.")
 
-    if model_choice == "DeepPavlov":
-        model_object = model.configure_deeppavlov()
-    elif model_choice == "Haystack":
-        export_doc.export_documents(output_file, documents_dir)
-        model_object = model.configure_haystack(documents_dir)
+    if st.button("Extract HTML Text"):
+        st.text("Extracting HTML text...")
+        extract.extract_html_text(output_file)
+        st.text("HTML text extraction completed.")
 
-    answer.answer_question(model_choice, model_object, query_text, output_file, db_dir)
+    if st.button("Preprocess Text"):
+        st.text("Preprocessing text...")
+        preprocess.add_preprocessed_text_website(output_file)
+        st.text("Text preprocessing completed.")
+
+    query_text = st.text_input("Enter your question:", "What is the release date of No Man's Sky?")
+    model_choice = st.selectbox("Select a model:", ["DeepPavlov", "GPT 3.5 - 4k token", "GPT 3.5 - 16k token"])
+
+    if st.button("Answer Question"):
+        st.text("Answering the question...")
+        if model_choice == "DeepPavlov":
+            model_object = model.configure_deeppavlov()
+        elif model_choice == "GPT 3.5 - 4k token":
+            model_object = "gpt-3.5-4k-tokens"
+            model_choice = "gpt-3.5-4k-tokens"
+        elif model_choice == "GPT 3.5 - 16k token":
+            model_object = "gpt-3.5-16k-tokens"
+            model_choice = "gpt-3.5-16k-tokens"
+
+        answered_question = answer.answer_question(model_choice, model_object, query_text, output_file, db_dir)
+        st.text_area("Answer:", answered_question)
 
 if __name__ == '__main__':
     main()
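A note on this layout: Streamlit reruns the whole script on every widget interaction, so the DeepPavlov model built inside the "Answer Question" branch is rebuilt on each click. A possible refinement (a sketch, not part of this commit, assuming a Streamlit version that ships st.cache_resource) is to cache the heavyweight model across reruns:

import streamlit as st

from src import model

# Hypothetical helper: build the DeepPavlov QA model once and reuse the same
# instance across Streamlit reruns instead of rebuilding BERT on every click.
@st.cache_resource
def get_deeppavlov_model():
    return model.configure_deeppavlov()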
48 changes: 36 additions & 12 deletions src/answer.py
@@ -1,15 +1,31 @@
-import index, model
+from . import index, model
+
+models = {
+    "gpt-3.5-4k-tokens": {
+        "name": "gpt-3.5-turbo",
+        "token_length": 4096,
+        "input_cost": 0.0015,
+        "output_cost": 0.002,
+        "available": True
+    },
+    "gpt-3.5-16k-tokens": {
+        "name": "gpt-3.5-turbo-16k",
+        "token_length": 16384,
+        "input_cost": 0.003,
+        "output_cost": 0.004,
+        "available": True
+    }
+}
 
 def answer_question(model_choice, model_object, query_text, output_file, db_dir):
     """
     Use the specified model to answer a question based on the given query text.
     Args:
         model_choice (str): The name of the model to use for answering the question. Must be either "DeepPavlov"
-        or "Haystack".
+        or "ChatGPT".
         model_object: An instance of the model to use. For DeepPavlov, this should be an instance of the
-        `deeppavlov.models.bertqa.BertQA` class. For Haystack, this should be an instance of the `ExtractiveQAPipeline`
-        class.
+        `deeppavlov.models.bertqa.BertQA` class.
         query_text (str): The text of the question to answer.
         output_file (str): The path to the file where the search index should be saved (only used for DeepPavlov).
         db_dir (str): The path to the directory where the search index should be stored (only used for DeepPavlov).
@@ -18,17 +34,25 @@ def answer_question(model_choice, model_object, query_text, output_file, db_dir):
         The answer to the question, as determined by the specified model. If the model is unable to find an
         answer, returns an empty string.
     """
-    if model_choice == "DeepPavlov":
-        index_db = index.create_database(output_file, db_dir)
-        if query_text is None:
-            query_text = ''
-
-        documents = index.search_query(query_text, index_db)
-
-        if len(documents) > 0:
-            return model.deeppavlov_answer(model_object, documents[0]['content'], query_text)
+    index_db = index.create_database(output_file, db_dir)
+    if query_text is None:
+        query_text = ''
+
+    documents = index.search_query(query_text, index_db)
+
+    if model_choice == "DeepPavlov":
+        if len(documents) > 0:
+            return model.deeppavlov_answer(model_object, documents[0]['content'], query_text)
+    else:
+        body = ' '.join(item['content'] for item in documents)
+        negResponse = "I'm unable to answer the question based on the information I have."
+
+        prompt = f"Answer this question: {query_text}\nUsing only the information from this documentation: {body}\nIf the answer is not contained in the supplied doc reply '{negResponse}' and nothing else"
+
+        truncated_prompt = model.configure_gpt(prompt, models[model_choice], 1024)
 
-    elif model_choice == "Haystack":
-        return model.haystack_answer(model_object, query_text)
+        if model_choice.startswith("gpt"):
+            return model.chat_with_gpt(truncated_prompt, models[model_choice])
 
     return ''
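The models table carries enough metadata to estimate spend per call. A minimal sketch (hypothetical helper, not in this commit), assuming input_cost and output_cost are USD per 1,000 tokens, matching OpenAI's 2023 pricing for these two models:

def estimate_cost(model_info, prompt_tokens, completion_tokens):
    # Assumes costs are quoted in USD per 1,000 tokens.
    return (prompt_tokens / 1000) * model_info["input_cost"] \
        + (completion_tokens / 1000) * model_info["output_cost"]

# Example: a near-limit prompt plus a 1,024-token reply on the 4k model:
# estimate_cost(models["gpt-3.5-4k-tokens"], 3072, 1024) -> ~0.0067 (USD)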
77 changes: 28 additions & 49 deletions src/model.py
@@ -1,10 +1,6 @@
 from deeppavlov import build_model, configs
 
-from haystack.document_stores import InMemoryDocumentStore
-from haystack.pipelines.standard_pipelines import TextIndexingPipeline
-from haystack.nodes import BM25Retriever, FARMReader
-from haystack.pipelines import ExtractiveQAPipeline
-from haystack.utils import print_answers
+import openai
+from . import preprocess
 
 import os
@@ -20,33 +16,6 @@ def configure_deeppavlov():
"""
return build_model(configs.squad.squad_bert)

def configure_haystack(doc_dir):
"""
Configure a Haystack Extractive Question Answering (QA) Pipeline.
The function uses a directory containing text documents to create an in-memory document store, index the
documents using BM25, load a pre-trained FARM Reader model for extractive QA, and create a pipeline that combines
the reader and the retriever.
Args:
doc_dir (str): The path to the directory containing the text documents to be indexed and searched.
Returns:
A Haystack ExtractiveQAPipeline object that can be used to perform extractive QA.
"""
document_store = InMemoryDocumentStore(use_bm25=True)

files_to_index = [doc_dir + "/" + f for f in os.listdir(doc_dir)]
indexing_pipeline = TextIndexingPipeline(document_store)
indexing_pipeline.run_batch(file_paths=files_to_index)

retriever = BM25Retriever(document_store=document_store)

reader = FARMReader(model_name_or_path="deepset/roberta-base-squad2", use_gpu=True)

return ExtractiveQAPipeline(reader, retriever)


def deeppavlov_answer(model, documents, query):
"""
Answer a question based on a list of documents and a query using DeepPavol model.
@@ -69,30 +38,40 @@ def deeppavlov_answer(model, documents, query):

     return answer
 
-model = build_model(configs.squad.squad_bert)
-
-def haystack_answer(pipe, query):
-    """
-    Use a Haystack Extractive Question Answering (QA) Pipeline to answer a question.
-    The function takes a query (a question) and uses the Haystack pipeline to retrieve the top-k most relevant
-    documents from the in-memory document store using BM25, and then uses a pre-trained FARM Reader model for
-    extractive QA to extract the answer from the retrieved documents.
-    Args:
-        pipe (ExtractiveQAPipeline): A Haystack ExtractiveQAPipeline object configured with a retriever and a reader.
-        query (str): The question to be answered.
-    Returns:
-        The extracted answer to the question.
-    """
-    prediction = pipe.run(
-        query=query,
-        params={
-            "Retriever": {"top_k": 10},
-            "Reader": {"top_k": 5}
-        }
-    )
-
-    return prediction['answers'][0].answer
+def configure_gpt(message, model, max_tokens):
+    """
+    Configure a GPT-based language model for generating responses to a message.
+    Args:
+        message (str): The input message or prompt.
+        model (dict): A dictionary representing the GPT model, including its properties.
+        max_tokens (int): The maximum number of tokens allowed in the generated response.
+    Returns:
+        str: A truncated input message suitable for GPT-based model generation.
+    """
+    safety_margin = int(model['token_length'] * 0.25)
+    truncated_prompt, word_count = preprocess.truncate_text(message, model['token_length'] - max_tokens - safety_margin)
+
+    return truncated_prompt
+
+def chat_with_gpt(message, model):
+    """
+    Generate a response using GPT-based chat capabilities.
+    Args:
+        message (str): The user message to send to the model.
+        model (dict): A dictionary containing the GPT-based chat model information.
+    Returns:
+        str: A generated response based on the input message.
+    """
+    response = openai.ChatCompletion.create(
+        model=model['name'],
+        messages=[
+            {"role": "system", "content": "You are a helpful assistant."},
+            {"role": "user", "content": message}
+        ]
+    )
+
+    return response.choices[0].message.content
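The budget arithmetic in configure_gpt is worth spelling out. With the 4k model and max_tokens=1024 (the value answer_question passes), the numbers work out as below. Note that truncate_text counts whitespace-separated words, not model tokens, so the 25% safety margin hedges against words that expand into several BPE tokens:

# Worked example, using the values from the code above:
token_length = 4096                                   # models["gpt-3.5-4k-tokens"]
max_tokens = 1024                                     # reserved for the reply
safety_margin = int(token_length * 0.25)              # 1024
budget = token_length - max_tokens - safety_margin    # 4096 - 1024 - 1024 = 2048
# -> the prompt is cut to its first 2048 whitespace-separated words.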
28 changes: 27 additions & 1 deletion src/preprocess.py
@@ -73,8 +73,17 @@ def preprocess_text(text):

     return text
 
 
 def remove_common_text(output_file, texts_to_remove):
+    """
+    Remove common text values from a specified column in a CSV file and save the updated data.
+    Args:
+        output_file (str): The path to the CSV file to be processed and updated.
+        texts_to_remove (list): A list of text values to be removed from the specified column.
+    Returns:
+        None
+    """
+    # Load the CSV file into a pandas DataFrame
     df = pd.read_csv(output_file)
@@ -106,3 +115,20 @@ def add_preprocessed_text_website(output_file):
     df.to_csv(output_file, index=False)
 
     remove_common_text(output_file, ["add category cancel save", "community content available cc by-nc-sa unless otherwise noted advertisement fan feed no man sky wiki starship freighter starbirth explore wikis universal conquest wiki let go luna wiki club wiki explore property fandom muthead futhead fanatical follow overview fandom career press contact term use privacy policy global sitemap local sitemap community community central support help sell info advertise medium kit fandomatic contact fandom apps take favorite fandom never miss beat no man sky wiki fandom game community view mobile site follow ig tiktok join fan lab", "no man sky wiki no man sky wiki explore main page page interactive map navigation main page community portal recent change random page admin noticeboard portal official site community site reddit playstation steam universe galaxy star system planet space station specie resource sentinel technology crafting freighter starship exocraft exosuit multi-tool base building blueprint visual catalogue creativity story mission industrial mining refining cooking tech tree currency additional journal civilized space galactic hub company faction portal lore gamepedia gamepedia support report bad ad help wiki contact fandom home fan central beta game anime movie tv video wikis explore wikis community central start wiki account register sign advertisement no man sky wiki page explore main page page interactive map navigation main page community portal recent change random page admin noticeboard portal official site community site reddit playstation steam universe galaxy star system planet space station specie resource sentinel technology crafting freighter starship exocraft exosuit multi-tool base building blueprint visual catalogue creativity story mission industrial mining refining cooking tech tree currency additional journal civilized space galactic hub company faction portal lore gamepedia gamepedia support report bad ad help wiki contact"])
 
+def truncate_text(text, max_tokens):
+    """
+    Truncate a given text to a specified maximum number of tokens.
+    Args:
+        text (str): The input text to be truncated.
+        max_tokens (int): The maximum number of tokens to retain in the truncated text.
+    Returns:
+        tuple: A tuple containing the truncated text and the count of tokens in the original text.
+    """
+    tokens = text.split()
+    if len(tokens) <= max_tokens:
+        return text, len(tokens)
+
+    return ' '.join(tokens[:max_tokens]), len(tokens)
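A quick usage sketch of truncate_text: the second return value is always the word count of the original text, whether or not truncation happened:

truncated, count = truncate_text("No Man's Sky was released on 9 August 2016", 4)
# truncated == "No Man's Sky was", count == 9 (original word count)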
4 changes: 2 additions & 2 deletions website/atlasfind/views.py
@@ -11,7 +11,7 @@
 db_dir = "data/whoosh"
 
 export_doc.export_documents(output_file, documents_dir)
-model_object = model.configure_haystack(documents_dir)
+model_object = model.configure_deeppavlov()
 
 # Create your views here.
 def home(request):
@@ -22,7 +22,7 @@ def search(request):
     search_query = request.GET.get('search_query', None)
 
     # perform search and retrieve results
-    results = answer.answer_question("Haystack", model_object, search_query, output_file, db_dir)
+    results = answer.answer_question("DeepPavlov", model_object, search_query, output_file, db_dir)
 
     # prepare the data to be sent back to the client
     if not results:
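Since model_object is assigned at module level, configure_deeppavlov() now runs once when Django imports this views module, loading the BERT QA model at startup rather than on each request.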