API Update
- Updated the API code to use the prompt template.
- Removed unused code from run_localGPT_API.py.
- Standardized the API endpoint URLs in the localGPTUI.
PromtEngineer committed Sep 26, 2023
1 parent db1b36e commit 15e9648
Showing 3 changed files with 53 additions and 35 deletions.
25 changes: 15 additions & 10 deletions localGPTUI/localGPTUI.py
@@ -12,6 +12,8 @@
app = Flask(__name__)
app.secret_key = "LeafmanZSecretKey"

API_HOST = "http://localhost:5110/api"


# PAGES #
@app.route("/", methods=["GET", "POST"])
@@ -21,19 +21,19 @@ def home_page():
user_prompt = request.form["user_prompt"]
print(f"User Prompt: {user_prompt}")

main_prompt_url = "http://localhost:5110/api/prompt_route"
main_prompt_url = f"{API_HOST}/prompt_route"
response = requests.post(main_prompt_url, data={"user_prompt": user_prompt})
print(response.status_code) # print HTTP response status code for debugging
if response.status_code == 200:
# print(response.json()) # Print the JSON data from the response
return render_template("home.html", show_response_modal=True, response_dict=response.json())
elif "documents" in request.files:
delete_source_url = "http://localhost:5110/api/delete_source" # URL of the /api/delete_source endpoint
delete_source_url = f"{API_HOST}/delete_source" # URL of the /api/delete_source endpoint
if request.form.get("action") == "reset":
response = requests.get(delete_source_url)

save_document_url = "http://localhost:5110/api/save_document"
run_ingest_url = "http://localhost:5110/api/run_ingest" # URL of the /api/run_ingest endpoint
save_document_url = f"{API_HOST}/save_document"
run_ingest_url = f"{API_HOST}/run_ingest" # URL of the /api/run_ingest endpoint
files = request.files.getlist("documents")
for file in files:
print(file.filename)
@@ -57,11 +59,14 @@ def home_page():

if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--port", type=int, default=5111,
help="Port to run the UI on. Defaults to 5111.")
parser.add_argument("--host", type=str, default="127.0.0.1",
help="Host to run the UI on. Defaults to 127.0.0.1. "
"Set to 0.0.0.0 to make the UI externally "
"accessible from other devices.")
parser.add_argument("--port", type=int, default=5111, help="Port to run the UI on. Defaults to 5111.")
parser.add_argument(
"--host",
type=str,
default="127.0.0.1",
help="Host to run the UI on. Defaults to 127.0.0.1. "
"Set to 0.0.0.0 to make the UI externally "
"accessible from other devices.",
)
args = parser.parse_args()
app.run(debug=False, host=args.host, port=args.port)
32 changes: 20 additions & 12 deletions prompt_template_utils.py
@@ -1,22 +1,21 @@
'''
"""
This file implements prompt template for llama based models.
Modify the prompt template based on the model you select.
This seems to have significant impact on the output of the LLM.
'''
"""

from langchain.memory import ConversationBufferMemory
from langchain.prompts import PromptTemplate

# this is specific to Llama-2.
# this is specific to Llama-2.

system_prompt = """You are a helpful assistant, you will use the provided context to answer user questions.
Read the given context before answering questions and think step by step. If you can not answer a user question based on
the provided context, inform the user. Do not use any other information for answering user"""
the provided context, inform the user. Do not use any other information for answering user. Provide a detailed answer to the question."""


def get_prompt_template(system_prompt=system_prompt, promptTemplate_type=None, history=False):

if promptTemplate_type=="llama":
if promptTemplate_type == "llama":
B_INST, E_INST = "[INST]", "[/INST]"
B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
SYSTEM_PROMPT = B_SYS + system_prompt + E_SYS
@@ -25,33 +24,42 @@ def get_prompt_template(system_prompt=system_prompt, promptTemplate_type=None, h
Context: {history} \n {context}
User: {question}"""

prompt_template = B_INST + SYSTEM_PROMPT + instruction + E_INST
prompt_template = B_INST + SYSTEM_PROMPT + instruction + E_INST
prompt = PromptTemplate(input_variables=["history", "context", "question"], template=prompt_template)
else:
instruction = """
Context: {context}
User: {question}"""

prompt_template = B_INST + SYSTEM_PROMPT + instruction + E_INST
prompt_template = B_INST + SYSTEM_PROMPT + instruction + E_INST
prompt = PromptTemplate(input_variables=["context", "question"], template=prompt_template)

else:
# change this based on the model you have selected.
# change this based on the model you have selected.
if history:
prompt_template = system_prompt + """
prompt_template = (
system_prompt
+ """
Context: {history} \n {context}
User: {question}
Answer:"""
)
prompt = PromptTemplate(input_variables=["history", "context", "question"], template=prompt_template)
else:
prompt_template = system_prompt + """
prompt_template = (
system_prompt
+ """
Context: {context}
User: {question}
Answer:"""
)
prompt = PromptTemplate(input_variables=["context", "question"], template=prompt_template)

memory = ConversationBufferMemory(input_key="question", memory_key="history")

return prompt, memory,
return (
prompt,
memory,
)
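
The updated helper returns a (prompt, memory) tuple built from the Llama-2 [INST]/<<SYS>> markers. A hedged usage sketch of wiring it into a RetrievalQA chain, mirroring the run_localGPT_API.py change below; build_qa_chain is an illustrative wrapper, not a function in the repository, and the caller supplies the llm and retriever objects.

from langchain.chains import RetrievalQA
from prompt_template_utils import get_prompt_template

def build_qa_chain(llm, retriever, show_sources=True):
    # history=False leaves the {history} slot out of the template,
    # although a ConversationBufferMemory object is still returned.
    prompt, memory = get_prompt_template(promptTemplate_type="llama", history=False)
    return RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=retriever,
        return_source_documents=show_sources,
        chain_type_kwargs={"prompt": prompt},  # inject the template into the chain
    )
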
31 changes: 18 additions & 13 deletions run_localGPT_API.py
@@ -4,25 +4,16 @@
import subprocess

import torch
from auto_gptq import AutoGPTQForCausalLM
from flask import Flask, jsonify, request
from langchain.chains import RetrievalQA
from langchain.embeddings import HuggingFaceInstructEmbeddings

# from langchain.embeddings import HuggingFaceEmbeddings
from langchain.llms import HuggingFacePipeline
from run_localGPT import load_model
from prompt_template_utils import get_prompt_template

# from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.vectorstores import Chroma
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
GenerationConfig,
LlamaForCausalLM,
LlamaTokenizer,
pipeline,
)
from werkzeug.utils import secure_filename

from constants import CHROMA_SETTINGS, EMBEDDING_MODEL_NAME, PERSIST_DIRECTORY, MODEL_ID, MODEL_BASENAME
@@ -65,9 +56,16 @@
RETRIEVER = DB.as_retriever()

LLM = load_model(device_type=DEVICE_TYPE, model_id=MODEL_ID, model_basename=MODEL_BASENAME)
prompt, memory = get_prompt_template(promptTemplate_type="llama", history=False)

QA = RetrievalQA.from_chain_type(
llm=LLM, chain_type="stuff", retriever=RETRIEVER, return_source_documents=SHOW_SOURCES
llm=LLM,
chain_type="stuff",
retriever=RETRIEVER,
return_source_documents=SHOW_SOURCES,
chain_type_kwargs={
"prompt": prompt,
},
)

app = Flask(__name__)
Expand Down Expand Up @@ -120,7 +118,7 @@ def run_ingest_route():
if DEVICE_TYPE == "cpu":
run_langest_commands.append("--device_type")
run_langest_commands.append(DEVICE_TYPE)

result = subprocess.run(run_langest_commands, capture_output=True)
if result.returncode != 0:
return "Script execution failed: {}".format(result.stderr.decode("utf-8")), 500
@@ -131,9 +129,16 @@ def run_ingest_route():
client_settings=CHROMA_SETTINGS,
)
RETRIEVER = DB.as_retriever()
prompt, memory = get_prompt_template(promptTemplate_type="llama", history=False)

QA = RetrievalQA.from_chain_type(
llm=LLM, chain_type="stuff", retriever=RETRIEVER, return_source_documents=SHOW_SOURCES
llm=LLM,
chain_type="stuff",
retriever=RETRIEVER,
return_source_documents=SHOW_SOURCES,
chain_type_kwargs={
"prompt": prompt,
},
)
return "Script executed successfully: {}".format(result.stdout.decode("utf-8")), 200
except Exception as e:
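
With the prompt template wired into the API, a minimal smoke test of the updated service might look like the sketch below. It assumes the server is running on the default port 5110; the exact shape of the JSON payload is defined by the prompt_route handler, which is not shown in this diff.

import requests

API_HOST = "http://localhost:5110/api"

response = requests.post(
    f"{API_HOST}/prompt_route",
    data={"user_prompt": "What does the ingested document say about X?"},
)
print(response.status_code)
if response.status_code == 200:
    # Payload structure comes from the prompt_route handler (not shown in this diff).
    print(response.json())
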
